### Import Libraries

In [None]:
!pip install torchinfo
import pandas as pd
import numpy as np
import itertools
import glob
import os
from tqdm.notebook import tqdm
from torchinfo import summary

import torchvision.transforms as transforms
from torchvision.utils import save_image
from PIL import Image
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import datasets
from torch.autograd import Variable
import torch.autograd as autograd

import torch.nn as nn
import torch.nn.functional as F
import torch

# Resume from a previously saved StarGAN checkpoint instead of training from scratch.
LOAD_FROM_CHECKPOINT: bool = True
# Directory searched for 'StarGAN_checkpoint_<epoch>_epochs.pt' files.
CHECKPOINT_ROOT: str = '/kaggle/input/stargan-checkpoint/saved_models'

### Initial Setting

In [None]:
# ---------
# training
# ---------
epoch = 0  # epoch to start training from; also picks the checkpoint file StarGAN_checkpoint_{epoch}_epochs.pt when loading
n_epochs = 25  # training runs epochs epoch .. n_epochs-1, i.e. this is the end epoch, not a count (suggested default: 200)
batch_size = 16  # number of images per mini-batch
lr = 0.0002  # Adam: learning rate
b1 = 0.5  # Adam: decay rate of the first-moment (mean) estimate of the gradient
b2 = 0.999  # Adam: decay rate of the second-moment (uncentred variance) estimate of the gradient

# ---------
# image data
# ---------
root = '/kaggle/input/face-expression-recognition-dataset/images'  # holds the 'train' and 'validation' class folders
img_height = 128  # image height fed to the networks
img_width = 128  # image width fed to the networks
channels = 3  # number of image channels

# ---------
# modeling
# ---------
residual_blocks = 6  # number of residual blocks in the generator bottleneck
n_critic = 5  # discriminator updates per generator update (generator trains when i % n_critic == 0)
# selected_attrs = ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young'] # selected attributes for the CelebA dataset
# Expression domains; only len(selected_attrs) is used (as c_dim). Labels come from the dataset's folder indices.
selected_attrs = ['angry', 'disgusted', 'fearful', 'happy', 'neutral', 'sad', 'surprised']

In [None]:
# List the processor entries in /proc/cpuinfo for this Kaggle session (GPU accelerator),
# e.g. to choose a value for n_cpu.
!cat /proc/cpuinfo | grep processor

In [None]:
n_cpu: int = 2  # CPU workers intended for batch generation

In [None]:
# Length of the one-hot domain vector given to both networks (one slot per expression).
c_dim: int = len(selected_attrs)
c_dim

In [None]:
# Channel-first (C, H, W) layout, as PyTorch convolution layers expect.
img_shape: tuple = (channels, img_height, img_width)
img_shape

In [None]:
# Whether a CUDA device exists; consulted below when placing models and tensors.
cuda: bool = torch.cuda.is_available()
cuda

### Define Generator

In [None]:
class ResidualBlock(nn.Module):
    """Shape-preserving residual unit used in the generator bottleneck.

    Two ``Conv2d(3x3) -> InstanceNorm2d`` stages, with a ReLU between them,
    produce a residual that is added back onto the input.

    Args:
        in_features: number of channels of the input (and of the output).
    """

    def __init__(self, in_features):
        super().__init__()

        def conv_norm(n_ch):
            # 3x3 / stride 1 / padding 1 keeps the spatial size unchanged.
            return [
                nn.Conv2d(n_ch, n_ch, kernel_size=3, stride=1, padding=1, bias=False),
                nn.InstanceNorm2d(n_ch, affine=True, track_running_stats=True),
            ]

        # Indices 0..4 inside conv_block must stay stable: saved state_dicts use them.
        stages = conv_norm(in_features) + [nn.ReLU(inplace=True)] + conv_norm(in_features)
        self.conv_block = nn.Sequential(*stages)

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual

In [None]:
class GeneratorResNet(nn.Module):
    """StarGAN generator: encoder, residual bottleneck and decoder conditioned on a domain vector.

    The target-domain vector ``c`` is tiled over the spatial grid and
    concatenated to the image as ``c_dim`` extra input channels; the network
    returns an image with the original channel count squashed into (-1, 1).

    Args:
        img_shape: (channels, height, width) of the images.
        res_blocks: number of ResidualBlock units in the bottleneck.
        c_dim: length of the domain vector ``c``.
    """

    def __init__(self, img_shape=(3,128,128), res_blocks=9, c_dim=5):
        super().__init__()
        in_channels, _, _ = img_shape

        width = 64
        # Stem: 7x7 conv over image + tiled domain channels.
        layers = [
            nn.Conv2d(in_channels + c_dim, width, 7, stride=1, padding=3, bias=False),
            nn.InstanceNorm2d(width, affine=True, track_running_stats=True),
            nn.ReLU(inplace=True),
        ]

        # Two stride-2 convolutions: each halves H and W and doubles the channels.
        for _ in range(2):
            layers.extend([
                nn.Conv2d(width, 2 * width, 4, stride=2, padding=1, bias=False),
                nn.InstanceNorm2d(2 * width, affine=True, track_running_stats=True),
                nn.ReLU(inplace=True),
            ])
            width *= 2

        layers.extend(ResidualBlock(width) for _ in range(res_blocks))

        # Two stride-2 transposed convolutions undo the downsampling.
        for _ in range(2):
            layers.extend([
                nn.ConvTranspose2d(width, width // 2, 4, stride=2, padding=1, bias=False),
                nn.InstanceNorm2d(width // 2, affine=True, track_running_stats=True),
                nn.ReLU(inplace=True),
            ])
            width //= 2

        # Project back to image channels; Tanh bounds pixels to (-1, 1).
        layers.extend([nn.Conv2d(width, in_channels, 7, stride=1, padding=3), nn.Tanh()])

        self.model = nn.Sequential(*layers)

    def forward(self, x, c):
        n_items, n_domains = c.size(0), c.size(1)
        # Broadcast each domain vector to a constant H x W map per entry.
        domain_maps = c.view(n_items, n_domains, 1, 1).expand(-1, -1, x.size(2), x.size(3))
        return self.model(torch.cat((x, domain_maps), dim=1))

### Define Discriminator

In [None]:
class Discriminator(nn.Module):
    """StarGAN critic: a PatchGAN real/fake head plus an auxiliary domain classifier.

    A trunk of stride-2 ``Conv2d(4x4) + LeakyReLU(0.01)`` stages (64 filters,
    doubling per stage) shrinks the input; two bias-free heads read the
    shared feature map.

    Args:
        img_shape: (channels, height, width); ``height`` sizes the classifier kernel.
        c_dim: number of domain logits produced by the classifier head.
        n_strided: number of stride-2 stages in the trunk.
    """

    def __init__(self, img_shape=(3,128,128), c_dim=5, n_strided=6):
        super().__init__()
        in_channels, img_size, _ = img_shape

        # The first stage is always built, even when n_strided < 1.
        n_stages = max(n_strided, 1)
        widths = [in_channels] + [64 * 2 ** k for k in range(n_stages)]
        trunk = []
        for c_in, c_out in zip(widths, widths[1:]):
            trunk += [nn.Conv2d(c_in, c_out, 4, stride=2, padding=1), nn.LeakyReLU(0.01)]
        self.model = nn.Sequential(*trunk)

        feat_dim = widths[-1]
        # Real/fake score map, one score per spatial patch.
        self.out1 = nn.Conv2d(feat_dim, 1, 3, padding=1, bias=False)
        # For square inputs divisible by 2**n_strided the kernel covers the whole
        # final feature map, leaving a 1x1 map of c_dim logits.
        self.out2 = nn.Conv2d(feat_dim, c_dim, img_size // 2 ** n_strided, bias=False)

    def forward(self, img):
        """Return ``(out_adv, out_cls)``: the patch score map and flattened class logits."""
        features = self.model(img)
        class_map = self.out2(features)
        return self.out1(features), class_map.view(len(class_map), -1)
        

### Define Loss Functions and Initialize Loss Weights

In [None]:
# Cycle-reconstruction loss: mean absolute error between input and reconstruction.
criterion_cycle = nn.L1Loss()

In [None]:
# Loss function - Domain-Class loss
def criterion_cls(logit, target):
    """Domain-classification loss: BCE-with-logits summed over elements, averaged over the batch.

    Args:
        logit: raw classifier outputs, e.g. the (N, c_dim) ``out_cls`` of the Discriminator.
        target: float targets with the same shape as ``logit``.
    Returns:
        Scalar tensor equal to the summed BCE divided by ``logit.size(0)``.
    """
    # ``size_average=False`` is deprecated; ``reduction='sum'`` is its exact replacement.
    return F.binary_cross_entropy_with_logits(logit, target, reduction='sum') / logit.size(0)

In [None]:
# Loss weights (StarGAN paper defaults).
lambda_cls: int = 1   # weight of the domain-classification terms
lambda_rec: int = 10  # weight of the cycle-reconstruction term
lambda_gp: int = 10   # weight of the WGAN-GP gradient penalty

### Initialize Generator and Discriminator

In [None]:
# Build both networks from the configuration cells above.
generator = GeneratorResNet(res_blocks=residual_blocks, c_dim=c_dim, img_shape=img_shape)
discriminator = Discriminator(c_dim=c_dim, img_shape=img_shape)

### GPU Setting

In [None]:
if cuda:
    # Move both networks (and the parameter-free L1 criterion) to the GPU.
    generator, discriminator = generator.cuda(), discriminator.cuda()
    criterion_cycle.cuda()

### Weight Initialization

In [None]:
def weights_init_normal(m):
    """Draw the weight of any conv-type module from N(0, 0.02); meant for ``Module.apply``.

    A module counts as conv-type when its class name contains 'Conv'
    (e.g. Conv2d, ConvTranspose2d). Other modules, and all biases, are left untouched.
    """
    if 'Conv' in type(m).__name__:
        torch.nn.init.normal_(m.weight.data, mean=0.0, std=0.02)

In [None]:
# Re-initialise every conv weight of both networks.
for net in (generator, discriminator):
    net.apply(weights_init_normal)

In [None]:
if LOAD_FROM_CHECKPOINT:
    checkpoint_path = os.path.join(CHECKPOINT_ROOT, f'StarGAN_checkpoint_{epoch}_epochs.pt')
    # map_location: a checkpoint written in a GPU session also opens in a CPU-only
    # session, and lands on the GPU when one is available.
    # weights_only=False: this notebook's checkpoints contain pickled optimizer objects,
    # which the weights-only loader (the default since PyTorch 2.6) rejects. This
    # unpickles arbitrary objects, so only point CHECKPOINT_ROOT at trusted files.
    checkpoint = torch.load(
        checkpoint_path,
        map_location='cuda' if cuda else 'cpu',
        weights_only=False,
    )
    generator.load_state_dict(checkpoint['generator_state_dict'])
    discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
    print("LOADED MODELS FROM CHECKPOINT")

In [None]:
# Dummy inputs are created on the generator's own device and shaped from the
# configuration, so the summary works with or without a GPU.
_summary_device = next(generator.parameters()).device
summary(generator, input_data=[torch.rand((1, *img_shape), device=_summary_device), torch.rand((1, c_dim), device=_summary_device)])

> Read More
- [CycleGAN Tutorial : Monet-to-Photo - Step 8. Weight Setting](https://www.kaggle.com/songseungwon/cyclegan-tutorial-monet-to-photo)

### Configure Optimizers

In [None]:
# Fresh Adam optimizers are always bound to the parameters of the models built in
# this session, so optimizer.step() updates the networks actually being trained.
optimizer_G = torch.optim.Adam(
    generator.parameters(),
    lr=lr,
    betas=(b1,b2)
)
optimizer_D = torch.optim.Adam(
    discriminator.parameters(),
    lr=lr,
    betas=(b1,b2)
)

if LOAD_FROM_CHECKPOINT:
    # Restore Adam state (moment estimates, step counts, lr/betas) into the new
    # optimizers. Checkpoints written by this notebook's save cell store whole
    # pickled optimizer objects under 'optimizer_G'/'optimizer_D'. Their parameters
    # are unpickled copies, not this session's model parameters, so stepping such
    # an object would leave the models unchanged. Only their state_dict is used.
    for key, optimizer in (('optimizer_G', optimizer_G), ('optimizer_D', optimizer_D)):
        saved_state = checkpoint.get(f'{key}_state_dict')
        if saved_state is None:
            saved_state = checkpoint[key]
        if isinstance(saved_state, torch.optim.Optimizer):
            saved_state = saved_state.state_dict()
        optimizer.load_state_dict(saved_state)
    print("LOADED OPTIMIZERS FROM CHECKPOINT")

### Set Transforms and Data Loaders

In [None]:
# Map ToTensor's [0, 1] range to [-1, 1] (the generator's Tanh range) and back again.
processor = transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5))
inverse_processor = transforms.Normalize(mean = (-1.0, -1.0, -1.0), std = (2.0, 2.0, 2.0))

def transform_images(x):
    """Turn a PIL image into a normalised, randomly flipped (C, H, W) tensor.

    The output size comes from ``img_width``/``img_height``, so it always
    matches the ``img_shape`` used to build the networks. The discriminator's
    class-head kernel is derived from that size.
    """
    # PIL's resize takes (width, height).
    x = x.resize((img_width, img_height))
    x = transforms.ToTensor()(x)
#     x = transforms.RandomRotation(15)(x)
    x = transforms.RandomHorizontalFlip(0.25)(x)
#     x = transforms.RandomVerticalFlip(0.25)(x)
    x = processor(x)
    return x

In [None]:
def collate_fn(batch):
    """Collate dataset samples into an image batch and a one-hot label batch.

    Args:
        batch: sequence of samples where ``sample[0]`` is the transformed image
            tensor and ``sample[1]`` its integer class index.
    Returns:
        ``(images, labels)``: images stacked along a new leading batch dim, and
        float one-hot labels of shape (len(batch), c_dim).
    """
    images = torch.stack([sample[0] for sample in batch])
    class_ids = torch.tensor([sample[1] for sample in batch])
    # One vectorised one_hot call instead of one call per sample.
    labels = nn.functional.one_hot(class_ids, num_classes = c_dim).float()
    return images, labels

In [None]:
# One sub-folder per expression class; ImageFolder assigns labels by folder index.
train_dir = os.path.join(root, 'train')
test_dir = os.path.join(root, 'validation')
train = datasets.ImageFolder(train_dir, transform=transform_images)
test = datasets.ImageFolder(test_dir, transform=transform_images)

In [None]:
# Both loaders use collate_fn, so labels always arrive as one-hot float vectors,
# and n_cpu worker processes prepare batches in parallel with the training loop.
train_data = DataLoader(train, batch_size = batch_size, shuffle = True, collate_fn = collate_fn, num_workers = n_cpu)
test_data = DataLoader(test, batch_size = batch_size, shuffle = True, collate_fn = collate_fn, num_workers = n_cpu)

In [None]:
# Lookup tables between ImageFolder class indices and class (folder) names.
num_to_class = dict(enumerate(train.classes))
class_to_num = {name: idx for idx, name in num_to_class.items()}
num_to_class

In [None]:
# Preview one training image, mapped back from [-1, 1] to [0, 1] and to (H, W, C) for display.
preview_batch, _ = next(iter(train_data))
plt.imshow(inverse_processor(preview_batch[2]).permute(1, 2, 0))
plt.axis('off')

### Define Gradient Penalty Function

In [None]:
# Legacy typed-tensor constructor for the active device (float32 on GPU or CPU).
if cuda:
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.FloatTensor

In [None]:
def compute_gradient_penalty(D, real_samples, fake_samples):
    """Calculates the gradient penalty loss for WGAN-GP.

    Samples points on straight lines between paired real and fake samples and
    penalises the critic when the L2 norm of its input gradient there is not 1.

    Args:
        D: critic returning ``(adv_output, cls_output)``; only ``adv_output`` is used.
        real_samples: batch of real samples, shape (N, ...).
        fake_samples: generated samples with the same shape as ``real_samples``.
    Returns:
        Scalar tensor ``mean((||dD/dx_hat||_2 - 1) ** 2)``, built with
        ``create_graph=True`` so it can be back-propagated into ``D``.
    """
    # One interpolation weight per sample, broadcast over all remaining dims.
    # Created on the inputs' own device/dtype (rather than via the global legacy
    # ``Tensor`` type and NumPy), so the function works on CPU and GPU alike.
    alpha_shape = (real_samples.size(0),) + (1,) * (real_samples.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_samples.device, dtype=real_samples.dtype)
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    d_interpolates, _ = D(interpolates)  # adv_info, cls_info = discriminator(interpolated sample)
    # Gradient of the critic output w.r.t. the interpolated inputs.
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradients = gradients.reshape(gradients.size(0), -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

### Define a Function to Display Translations into Every Target Domain

In [None]:
def sample_images():
    """Show a generated sample of domain translations.

    For up to 10 images from one validation batch, builds a row holding the
    source image followed by its translation into each of the c_dim domains
    (column k + 1 = domain k). The rows are stacked and displayed with matplotlib.
    """
    val_imgs, _ = next(iter(test_data))
    val_imgs = val_imgs.type(Tensor)
    # Row k of the identity matrix is the one-hot vector for domain k.
    target_labels = torch.eye(c_dim).type(Tensor)
    # A batch may hold fewer than 10 images (small dataset or small batch_size).
    n_rows = min(10, val_imgs.size(0))
    rows = []
    # Display only: no autograd graph is needed.
    with torch.no_grad():
        for img in val_imgs[:n_rows]:
            # Translate one image into every domain at once.
            gen_imgs = generator(img.repeat(c_dim, 1, 1, 1), target_labels)
            # Source image, then its translations, concatenated along the width.
            row = torch.cat([img, *gen_imgs], -1)
            rows.append(inverse_processor(row))
    # Stack rows along the height.
    img_samples = torch.cat(rows, -2)
    plt.figure(figsize=(16,32))
    plt.imshow(img_samples.permute(1,2,0).detach().cpu())
    plt.axis('off')
    plt.show()

In [None]:
# Translations produced by the current weights, before the training loop below runs.
sample_images()

### Training

In [None]:
# import warnings
# warnings.filterwarnings(action='ignore')

# Mean loss of every epoch trained in this run (plotted after training).
generator_losses = []
discriminator_losses = []
# Starting epoch of this run; the for-loop below rebinds `epoch`.
initial_epoch = epoch

for epoch in range(epoch, n_epochs):
    # Per-batch losses of the current epoch, averaged at the end of the epoch.
    generator_epoch_losses = []
    discriminator_epoch_losses = []
    for i, (imgs, labels) in enumerate(tqdm(train_data)):
        # Model inputs (labels are the one-hot source domains from collate_fn)
        imgs = imgs.type(Tensor)
        labels = labels.type(Tensor)
        
        # Sample labels as generator inputs
#         sampled_c = Tensor(np.random.randint(0, 2, (imgs.size(0), c_dim)))
        # One uniformly drawn target domain per image, one-hot encoded.
        sampled_c = F.one_hot(torch.Tensor(np.random.randint(0, c_dim, imgs.size(0))).long(), c_dim).float()
        if cuda:
            sampled_c = sampled_c.cuda()
        # Generate fake batch of images (only used detached / via .data in the D step;
        # NOTE(review): the G step below regenerates its own translations).
        fake_imgs = generator(imgs, sampled_c)
        
        # -------------------
        # Train Discriminator
        # -------------------
        # Also clears grads that the previous generator backward left on D's parameters.
        optimizer_D.zero_grad()
    
        # Real images
        real_validity, pred_cls = discriminator(imgs)
        # Fake images (detached so the D loss does not reach the generator)
        fake_validity, _ = discriminator(fake_imgs.detach())
        # Gradient penalty
        gradient_penalty = compute_gradient_penalty(discriminator, imgs.data, fake_imgs.data)
        # Adversarial (WGAN-GP critic) loss: raise real scores, lower fake scores;
        # note that it already includes lambda_gp * gradient_penalty.
        loss_D_adv = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp*gradient_penalty
        # Classification loss: D must recognise the real images' source domains
        loss_D_cls = criterion_cls(pred_cls, labels)
        # Total loss
        loss_D = loss_D_adv + lambda_cls*loss_D_cls
        
        discriminator_epoch_losses.append(loss_D.item())
        
        loss_D.backward()
        optimizer_D.step()
        
        optimizer_G.zero_grad()
        
        # Every n_critic times update generator
        if i % n_critic == 0: # n_critic : 5

            # -------------------
            # Train Generator
            # -------------------
            # Translate to the sampled domain, then back to the source domain (labels)
            gen_imgs = generator(imgs, sampled_c)
            recov_imgs = generator(gen_imgs, labels)
            # Discriminator evaluates translated image
            fake_validity, pred_cls = discriminator(gen_imgs)
            # Adversarial loss
            loss_G_adv = -torch.mean(fake_validity)
            # Classification loss: translations should be classified as the sampled domain
            loss_G_cls = criterion_cls(pred_cls, sampled_c)
            # Reconstruction (cycle) loss
            loss_G_rec = criterion_cycle(recov_imgs, imgs)
            # Total loss
            loss_G = loss_G_adv + lambda_cls*loss_G_cls + lambda_rec*loss_G_rec
            
            generator_epoch_losses.append(loss_G.item())
            
            # Only optimizer_G steps here; the grads this also writes into D's
            # parameters are cleared by optimizer_D.zero_grad() next iteration.
            loss_G.backward()
            optimizer_G.step()
            
        # -------------------
        # Show Progress
        # -------------------
        # The G values printed come from the most recent generator update
        # (i % n_critic == 0); "D adv" includes the gradient-penalty term.
        if (i+1) % 50 == 0: 
            print("[Epoch %d/%d] [Batch %d/%d] [D adv: %f, aux: %f] [G loss: %f, adv: %f, aux: %f, cycle: %f]"
                % (
                    epoch+1, n_epochs,                     # Epoch
                    i+1,len(train_data),                   # Batch
                    loss_D_adv.item(),loss_D_cls.item(),   # D loss
                    loss_G.item(),loss_G_adv.item(),       # G loss (total, adv)
                    loss_G_cls.item(),loss_G_rec.item(),   # G loss (cls, cycle)
                ))
    # Epoch means, then a visual check of the current translations.
    generator_losses.append(np.mean(generator_epoch_losses))
    discriminator_losses.append(np.mean(discriminator_epoch_losses))
    sample_images()


In [None]:
# Mean generator loss of each epoch trained in this run.
print(generator_losses)

In [None]:
# Mean discriminator loss of each epoch trained in this run.
print(discriminator_losses)

In [None]:
# x values are the real epoch indices, so they line up with the xticks below
# even when training resumed from a checkpoint (initial_epoch > 0).
plt.plot(range(initial_epoch, initial_epoch + len(generator_losses)), generator_losses)
plt.title("Generator Losses")
plt.xlabel("Epoch")
plt.xticks([*range(initial_epoch, n_epochs, 5)])
plt.ylabel("Loss")

In [None]:
# x values are the real epoch indices, so they line up with the xticks below
# even when training resumed from a checkpoint (initial_epoch > 0).
plt.plot(range(initial_epoch, initial_epoch + len(discriminator_losses)), discriminator_losses)
plt.title("Discriminator Losses")
plt.xlabel("Epoch")
plt.xticks([*range(initial_epoch, n_epochs, 5)])
plt.ylabel("Loss")

In [None]:
os.makedirs('saved_models', exist_ok = True)

checkpoint = {
    'generator_state_dict': generator.state_dict(),
    'discriminator_state_dict': discriminator.state_dict(),
    # Portable optimizer state (Adam moments, step counts, hyper-parameters); restore
    # it with Optimizer.load_state_dict() on optimizers built for the new models.
    'optimizer_G_state_dict': optimizer_G.state_dict(),
    'optimizer_D_state_dict': optimizer_D.state_dict(),
    # Whole pickled optimizer objects, kept so loading code that reads these keys
    # keeps working. Their parameters are copies once unpickled: prefer the
    # *_state_dict entries above.
    'optimizer_G': optimizer_G,
    'optimizer_D': optimizer_D
}

# The file name records the epoch training stopped at; to resume from this file,
# set `epoch` to that value, since the loading cell builds the name from `epoch`.
torch.save(checkpoint, f'saved_models/StarGAN_checkpoint_{n_epochs}_epochs.pt')
# Generator weights alone, e.g. for inference.
torch.save(generator.state_dict(), f'saved_models/StarGAN_generator_{n_epochs}_epochs.pt')