# 1. Import Dependencies

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, RandomSampler
from torchvision.utils import make_grid
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import os
from tqdm import tqdm

# 2. Configure Output Folder and CUDA Device

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Folder for the per-epoch sample grids saved by `show(..., save=True)`.
# makedirs(exist_ok=True) also creates missing parents and avoids the
# check-then-create race of `os.path.exists` + `os.mkdir`.
os.makedirs('/kaggle/working/images', exist_ok=True)
device

# 3. Data Loading and Label Processing

In [None]:
# Resize every image to 256x256, convert it to a float tensor in [0, 1], then
# map each of the 3 channels to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

In [None]:
BATCH = 16
# Must stay False: the DataLoader below is given an explicit sampler, and
# DataLoader rejects shuffle=True together with a sampler.
shuffle=False
# Count only class sub-directories (the way ImageFolder discovers classes), so
# stray files in the dataset root cannot inflate the one-hot width.
with os.scandir('/kaggle/input/aid-scene-classification-datasets/AID') as entries:
    num_classes = sum(1 for entry in entries if entry.is_dir())

In [None]:
class OneHotAID(torch.utils.data.Dataset):
    """Wrap a ``(sample, int_label)`` dataset so labels come back one-hot encoded.

    Args:
        base_dataset: Indexable dataset whose items unpack to ``(sample, label)``
            with an integer class label.
        num_classes: Width of the one-hot vector.
    """

    def __init__(self, base_dataset, num_classes):
        self.num_classes = num_classes
        self.dataset = base_dataset

    def __len__(self):
        # Exactly one wrapped item per underlying item.
        return len(self.dataset)

    def __getitem__(self, idx):
        sample, class_idx = self.dataset[idx]
        encoded = F.one_hot(torch.tensor(class_idx), num_classes=self.num_classes)
        return sample, encoded

In [None]:
# ImageFolder assigns one integer label per class sub-directory of AID.
dataset = datasets.ImageFolder(
    root='/kaggle/input/aid-scene-classification-datasets/AID',
    transform=transform,
)
# Sampler (despite the name) that yields 1024 random indices per pass of the loader.
sample_dataset = RandomSampler(data_source=dataset, num_samples=1024)

In [None]:
one_hot_dataset = OneHotAID(base_dataset=dataset, num_classes=num_classes)
# Index order comes from `sample_dataset`; `shuffle` must therefore stay False.
data_loader = DataLoader(
    one_hot_dataset,
    batch_size=BATCH,
    sampler=sample_dataset,
    shuffle=shuffle,
    num_workers=4,
    pin_memory=True,
)

In [None]:
# Take the first batch only (then stop) for the shape / range checks below.
for imgs, labels in data_loader:
    imgs, labels = imgs.to(device), labels.to(device)
    break

In [None]:
# Sanity check: expect (BATCH, C, 256, 256) images (C presumably 3 for RGB) and
# (BATCH, num_classes) one-hot labels.
imgs.shape, labels.shape

In [None]:
# Range check: ToTensor + Normalize(mean=0.5, std=0.5) should put pixels in [-1, 1].
imgs.max(), imgs.min()

# 4. Model Architecture Definition

## 4.1 Generator Definition

In [None]:
def upsample_block(scale_factor, in_channels, out_channels, kernel_size, stride=1, padding=0, last_layer=True):
    """Build the layers of one generator upsampling stage.

    The stage runs bilinear upsampling by ``scale_factor``, then
    BatchNorm + LeakyReLU on the input, then two Conv2d -> BatchNorm2d ->
    activation steps. The convolutions are bias-free because each one feeds a
    BatchNorm.

    Args:
        scale_factor: Spatial upsampling factor.
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both convolutions.
        kernel_size, stride, padding: Shared by both convolutions.
        last_layer: If True the final activation is Tanh (image output);
            otherwise LeakyReLU(0.2).

    Returns:
        A tuple of ``nn.Module`` objects, meant to be unpacked into
        ``nn.Sequential``.
    """
    conv_opts = dict(kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
    final_act = nn.Tanh() if last_layer else nn.LeakyReLU(0.2, inplace=True)
    return (
        nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=True),
        nn.BatchNorm2d(in_channels),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(in_channels, out_channels, **conv_opts),
        nn.BatchNorm2d(out_channels),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(out_channels, out_channels, **conv_opts),
        nn.BatchNorm2d(out_channels),
        final_act,
    )
class Generator(nn.Module):
    """Conditional generator: (noise vector, one-hot label) -> image.

    The noise and the label are each projected by an MLP to 32*32*2 features.
    The two projections are concatenated and reshaped into a (4, 32, 32) seed
    map. That map then passes through ``n_hidden_layers`` upsample blocks, each
    doubling the spatial size, so the output side is 32 * 2**n_hidden_layers.
    The final block ends in Tanh.

    Args:
        z_dim: Length of the noise vector ``z``.
        label_dim: Length of the (one-hot) label vector.
        n_hidden_layers: Number of upsample blocks (>= 1).
        hidden_channels: Output channels of each upsample block. The entry at
            index ``n_hidden_layers - 1`` is the number of image channels.
            Must have at least ``n_hidden_layers`` entries.

    Raises:
        ValueError: If ``n_hidden_layers < 1`` or ``hidden_channels`` is too short.
    """

    def __init__(self, z_dim, label_dim, n_hidden_layers, hidden_channels):
        super(Generator, self).__init__()
        if n_hidden_layers < 1:
            raise ValueError(f"n_hidden_layers must be >= 1, got {n_hidden_layers}")
        if len(hidden_channels) < n_hidden_layers:
            raise ValueError(
                f"hidden_channels needs {n_hidden_layers} entries, got {len(hidden_channels)}"
            )
        self.z_dim = z_dim
        self.n_hidden_layers = n_hidden_layers
        self.hidden_channels = hidden_channels

        # Noise branch: z -> 32*32*2 features (2 channels of the 32x32 seed map).
        self.fc = nn.Sequential(
            nn.Linear(in_features=z_dim, out_features=z_dim*4),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Linear(in_features=z_dim*4, out_features=32*32*2),
            nn.LeakyReLU(0.2,inplace=True)
        )

        # Label branch: one-hot label -> the other 2 channels of the seed map.
        self.label_projector = nn.Sequential(
            nn.Linear(in_features=label_dim, out_features=256),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Linear(in_features=256, out_features=512),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Linear(in_features=512, out_features=32*32*2),
            nn.LeakyReLU(0.2,inplace=True)
        )

        self.in_channels = 4  # noise (2) + label (2) channels of the 32x32 seed map
        layers = []
        block_in = self.in_channels
        for i in range(n_hidden_layers):
            # Only the final block ends in Tanh; the others use LeakyReLU. This also
            # holds when n_hidden_layers == 1, where the first block is the last one.
            layers.extend(upsample_block(
                scale_factor=2,
                in_channels=block_in,
                out_channels=hidden_channels[i],
                kernel_size=3,
                stride=1,
                padding=1,
                last_layer=(i == n_hidden_layers - 1),
            ))
            block_in = hidden_channels[i]

        self.upsample = nn.Sequential(*layers)

    def forward(self, z, labels):
        """Generate images from noise ``z`` (B, z_dim) and labels (B, label_dim)."""
        labels = labels.type(torch.float32)  # one-hot ints -> float for Linear

        projected = self.fc(z)
        label_projected = self.label_projector(labels)

        cat = torch.cat((projected, label_projected), -1)
        reshaped = cat.view((-1, self.in_channels, 32, 32))

        return self.upsample(reshaped)

## 4.2 Generator Instantiation and Shape Check

In [None]:
# Three x2 upsampling stages take the 32x32 seed map to 256x256 with 3 image
# channels; z_dim=64 must match the noise width sampled elsewhere.
generator = Generator(
    z_dim=64,
    label_dim=num_classes,
    n_hidden_layers=3,
    hidden_channels=[128, 128, 3],
).to(device)

In [None]:
# Display the generator's module structure.
generator

In [None]:
# One 64-dim noise vector per sample in the test batch.
z = torch.randn(BATCH, 64).to(device)

In [None]:
# Both generator inputs should be on the same device as the generator.
z.device, labels.device

In [None]:
# Forward pass on the real batch's labels to check the generator runs end to end.
fake = generator(z, labels=labels)

In [None]:
# Expect (BATCH, 3, 256, 256): three x2 upsampling stages from the 32x32 seed map.
fake.shape

## 4.3 Discriminator Definition

In [None]:
class Discriminator(nn.Module):
    """Conditional critic: scores an (image, one-hot label) pair with one logit.

    Images go through a strided conv stack that reduces a 256x256 input to a
    128x2x2 map, which an MLP turns into 512 features. The label is projected
    to 512 features by a separate MLP. The two vectors are concatenated and
    reduced to one unnormalised score per sample. There is no final Sigmoid, so
    pair this with a logit-space loss such as ``BCEWithLogitsLoss``.

    Args:
        num_classes: Length of the one-hot label vector.
        img_channels: Number of channels of the input images.
    """

    def __init__(self, num_classes, img_channels):
        super(Discriminator, self).__init__()
        # Spatial sizes below assume a 256x256 input; with padding=0,
        # out = floor((in - kernel) / stride) + 1.
        self.critic = nn.Sequential(
            nn.Conv2d(in_channels=img_channels, out_channels=64, kernel_size=4, stride=2, padding=0, bias=False), # 256 -> 127
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=2, stride=2, padding=0, bias=False), # 127 -> 63
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=2, stride=2, padding=0, bias=False), # 63 -> 31
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=2, stride=2, padding=0, bias=False), # 31 -> 15
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=4, stride=2, padding=0, bias=False), # 15 -> 6
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=0, bias=False), # 6 -> 2
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # One-hot label -> 512 features.
        self.label_projector = nn.Sequential(
            nn.Linear(in_features=num_classes, out_features=256),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Linear(in_features=256, out_features=512),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Linear(in_features=512, out_features=512),
            nn.LeakyReLU(0.2,inplace=True)
        )
        # Flattened 128x2x2 conv output -> 512 features.
        self.feature_extractor = nn.Sequential(
            nn.Linear(in_features=2*2*128, out_features=256),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Linear(in_features=256, out_features=512),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Linear(in_features=512, out_features=512),
            nn.LeakyReLU(0.2,inplace=True)
        )
        # Concatenated image + label features (512 + 512) -> single logit.
        self.fc = nn.Sequential(
            nn.Linear(in_features=1024, out_features=512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(in_features=512, out_features=128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(in_features=128, out_features=16),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(in_features=16, out_features=1),
        )

    def forward(self, images, labels):
        """Return a (B, 1) tensor of logits for ``images`` conditioned on ``labels``."""
        labels = labels.type(torch.float32)  # one-hot ints -> float for Linear

        conved = self.critic(images)
        # Flatten everything except the batch dimension (128*2*2 for 256x256 input).
        flat = torch.flatten(conved, start_dim=1)

        features = self.feature_extractor(flat)
        label_projected = self.label_projector(labels)

        cat = torch.cat((features, label_projected), 1)  # concat along features
        return self.fc(cat)

## 4.4 Discriminator Instantiation and Shape Check

In [None]:
# Critic for 3-channel images conditioned on the same one-hot labels as the generator.
discriminator = Discriminator(
    num_classes=num_classes,
    img_channels=3,
).to(device)

In [None]:
# Score the generated batch to check generator -> discriminator compatibility.
pred = discriminator(fake, labels=labels)

In [None]:
# Expect (BATCH, 1): one raw logit per (image, label) pair.
pred.shape

# 5. Loss Functions and Plotting Helper

In [None]:
def disc_loss(loss_fn, gen_model, disc_model, images, labels, z_dim):
    """Compute the discriminator loss for one batch.

    Fake images are generated from fresh noise for ``labels``. Real and fake
    images are then scored against smoothed targets: 0.95 for real, 0.05 for
    fake. The result is the mean of the two losses.

    Args:
        loss_fn: Criterion called as ``loss_fn(pred, target)``.
        gen_model: Generator called as ``gen_model(noise, labels)``.
        disc_model: Discriminator called as ``disc_model(images, labels)``.
        images: Batch of real images.
        labels: Labels for the batch. The batch size and device are taken from them.
        z_dim: Noise vector length.

    Returns:
        Scalar loss tensor with gradients w.r.t. the discriminator only.
    """
    batch = labels.shape[0]
    # Sample the noise directly on the labels' device instead of creating it on
    # the CPU and copying it over (and without depending on a global `device`).
    noise = torch.randn((batch, z_dim), device=labels.device)

    # The generator is not optimised in this step, so skip building its autograd
    # graph. This is equivalent to detaching its output, but cheaper.
    with torch.no_grad():
        fake_imgs = gen_model(noise, labels)

    fake_pred = disc_model(fake_imgs, labels)
    real_pred = disc_model(images, labels)

    # Label smoothing: fakes -> 0.05, reals -> 0.95.
    fake_target = torch.full_like(fake_pred, 0.05)
    real_target = torch.full_like(real_pred, 0.95)

    fake_loss = loss_fn(fake_pred, fake_target)
    real_loss = loss_fn(real_pred, real_target)

    return (fake_loss + real_loss) / 2

def gen_loss(loss_fn, gen_model, disc_model, labels, z_dim):
    """Compute the generator loss for one batch.

    Fake images generated for ``labels`` are scored by the discriminator
    against the smoothed "real" target of 0.95. Gradients flow back through
    the discriminator into the generator.

    Args:
        loss_fn: Criterion called as ``loss_fn(pred, target)``.
        gen_model: Generator called as ``gen_model(noise, labels)``.
        disc_model: Discriminator called as ``disc_model(images, labels)``.
        labels: Labels to condition on. The batch size and device are taken from them.
        z_dim: Noise vector length.

    Returns:
        Scalar loss tensor.
    """
    batch = labels.shape[0]
    # Create the noise directly on the labels' device (no host-to-device copy and
    # no reliance on a global `device`).
    noise = torch.randn((batch, z_dim), device=labels.device)

    fake_images = gen_model(noise, labels)
    fake_pred = disc_model(fake_images, labels)

    fake_target = torch.full_like(fake_pred, 0.95)
    return loss_fn(fake_pred, fake_target)

def show(tensor, ch=3, size=(256, 256), n_imgs=25, epoch=0, save=False):
    """Plot the first ``n_imgs`` images of ``tensor`` as a 5-column grid.

    Args:
        tensor: Image batch, reshaped to (-1, ch, *size). Values are plotted
            as-is, so they should already be in the [0, 1] display range.
        ch: Channels per image.
        size: (height, width) of each image.
        n_imgs: Maximum number of images placed in the grid.
        epoch: Used in the plot title and in the saved file name.
        save: If True, also write the figure to
            ``/kaggle/working/images/fake_{epoch}.png``.
    """
    batch = tensor.detach().cpu().view(-1, ch, *size)
    grid = make_grid(batch[:n_imgs], nrow=5)
    # make_grid returns (C, H, W); matplotlib expects (H, W, C).
    plt.imshow(grid.permute(1, 2, 0))
    plt.title(f"Epoch: {epoch}")
    plt.axis('off')
    if save:
        plt.savefig(f"/kaggle/working/images/fake_{epoch}.png")
    plt.show()

In [None]:
# `imgs` was normalised to [-1, 1] by Normalize(mean=0.5, std=0.5). Map it back to
# [0, 1] (the same rescaling the training loop applies to generator output) so
# imshow does not clip the negative half of the range.
show((imgs + 1) / 2)

# 6. Training Setup and Loop

## 6.1 Hyper Parameters and Optimizers Initialization

In [None]:
lr = 3e-4      # Adam learning rate shared by generator and discriminator
k = 2          # discriminator updates per generator update
log_step = 5   # plot and save a sample grid every `log_step` epochs
EPOCHS = 1000  # passes over the (sampled) data loader

In [None]:
# The discriminator ends in a plain Linear layer (raw logits), hence the logit-space BCE.
base_loss = nn.BCEWithLogitsLoss()
gen_opt = torch.optim.Adam(params=generator.parameters(), lr=lr, betas=(0.5, 0.99))
disc_opt = torch.optim.Adam(params=discriminator.parameters(), lr=lr, betas=(0.5, 0.99))

## 6.2 Training Loop and Visualization

In [None]:
# Fixed noise and labels, so the saved sample grids are comparable across epochs.
z = torch.randn((25, 64)).to(device)
# Classes 1..25 (as before), wrapped modulo num_classes so datasets with fewer
# than 26 classes do not make F.one_hot raise.
test_labels = F.one_hot((torch.arange(25) + 1) % num_classes, num_classes=num_classes).to(device)
for i in range(EPOCHS):
    print(f"EPOCH: {i} starting...")
    mean_gen_loss = 0
    mean_disc_loss = 0
    batch = 0  # number of samples seen this epoch
    for imgs, labels in tqdm(data_loader):
        b = labels.shape[0]

        imgs = imgs.to(device)
        labels = labels.to(device)

        # k discriminator updates per generator update.
        for _ in range(k):
            disc_opt.zero_grad()
            disc_loss_ = disc_loss(loss_fn=base_loss, gen_model=generator, disc_model=discriminator, images=imgs, labels=labels, z_dim=64)
            # Batch-size-weighted, averaged over the k inner steps.
            mean_disc_loss += disc_loss_.item() * b / k
            disc_loss_.backward()
            disc_opt.step()

        gen_opt.zero_grad()
        gen_loss_ = gen_loss(loss_fn=base_loss, gen_model=generator, disc_model=discriminator, labels=labels, z_dim=64)
        mean_gen_loss += gen_loss_.item() * b
        gen_loss_.backward()
        gen_opt.step()
        batch += b

    torch.cuda.empty_cache()

    # The accumulators hold batch-size-weighted sums. Divide by the number of
    # samples seen to report real per-sample means (guarding against an empty loader).
    n_seen = max(batch, 1)
    mean_gen_loss /= n_seen
    mean_disc_loss /= n_seen
    print(f"Mean Gen Loss: {mean_gen_loss}, Mean Disc Loss: {mean_disc_loss}")
    if i % log_step == 0:
        with torch.no_grad():
            # Map the generator's Tanh output from [-1, 1] to [0, 1] for display.
            fake = (generator(z, test_labels) + 1) / 2
        show(fake, epoch=i, save=True)

In [1]:
import glob
from PIL import Image

In [3]:
# Saved per-epoch sample grids, presumably downloaded from /kaggle/working/images.
imgs = list(glob.iglob('results/images/*.png'))
len(imgs)

200

In [4]:
# Sort numerically by the epoch in "fake_<epoch>.png" (plain string order would
# put fake_10 before fake_5).
imgs.sort(key=lambda path: int(path.rpartition('_')[2].partition('.')[0]))
imgs[:5]

['results/images\\fake_0.png',
 'results/images\\fake_5.png',
 'results/images\\fake_10.png',
 'results/images\\fake_15.png',
 'results/images\\fake_20.png']

In [5]:
# Open every grid in epoch order (PIL loads pixel data lazily).
images = list(map(Image.open, imgs))

In [7]:
# Assemble the grids into an animated GIF, 200 ms per frame.
# NOTE(review): per Pillow's GIF docs, `loop` is the repeat count (0 = forever);
# confirm that 200 repeats rather than infinite looping is intended.
images[0].save(
    'generation progress.gif',
    save_all=True,
    append_images=images[1:],
    duration=200,
    loop=200,
)