### Import Libraries

In [None]:
!pip install torchinfo
import pandas as pd
import numpy as np
import itertools
import glob
import os
from tqdm.notebook import tqdm
from torchinfo import summary

import torchvision.transforms as transforms
from torchvision.utils import save_image
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns
import shutil

from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import datasets
from torch.autograd import Variable
import torch.autograd as autograd

import torch.nn as nn
import torch.nn.functional as F
import torch

# Seed PyTorch's global RNG so torch-driven randomness is repeatable between runs.
torch.manual_seed(42)

# Directory containing the saved StarGAN checkpoint files that are loaded further down.
CHECKPOINT_ROOT = '/kaggle/input/stargan-checkpoint/saved_models'

### Initial Setting

In [None]:
# ---------
# training
# ---------
epoch = 100 # epoch of the saved checkpoint to resume from (used to build the checkpoint filename)
batch_size = 16 # size of the batches. suggested.
lr = 0.0002 # adam : learning rate
b1 = 0.5 # adam : decay of first order momentum of gradient
b2 = 0.999 # adam : decay of second order momentum of gradient

# ---------
# image data
# ---------
root = '/kaggle/input/face-expression-recognition-dataset/images'
img_height = 128 # size of image height
img_width = 128 # size of image width
channels = 3 # number of image channels

# ---------
# modeling
# ---------
residual_blocks = 6 # number of residual blocks in generator
n_critic = 5 # number of training iterations for WGAN discriminator
# selected_attrs = ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young'] # selected attributes for the CelebA dataset
# NOTE(review): in the code shown only len(selected_attrs) is used (to derive c_dim);
# class names elsewhere come from the dataset folders, so these labels need not match them.
selected_attrs = ['angry', 'disgusted', 'fearful', 'happy', 'neutral', 'sad', 'surprised']

In [None]:
# List the CPUs reported by /proc/cpuinfo (one "processor" line per logical CPU)
# on the Kaggle server (Accelerator : GPU); presumably used to choose n_cpu below.
!cat /proc/cpuinfo | grep processor

In [None]:
# NOTE(review): n_cpu is not passed to any DataLoader in the code shown
# (no num_workers argument), so it currently has no effect.
n_cpu = 2 # number of cpu threads to use during batch generation

In [None]:
# Width of the one-hot domain label: one entry per selected attribute.
c_dim = len(selected_attrs) # number of input-attributes
c_dim

In [None]:
# (channels, height, width) ordering, as expected by the model constructors below.
img_shape = (channels, img_height, img_width) # set image shape for pytorch
img_shape

In [None]:
# True when a CUDA device is usable; later cells use this flag to decide whether to move tensors/models to the GPU.
cuda = torch.cuda.is_available()
cuda

### Define Generator

In [None]:
class ResidualBlock(nn.Module):
    """Residual unit: conv -> norm -> ReLU -> conv -> norm, added back onto its input.

    Both convolutions are 3x3 with stride 1 and padding 1 and keep the channel
    count, so the branch output has the same shape as the input and can be
    summed with it.

    Args:
        in_features: Number of input (and output) channels.
    """

    def __init__(self, in_features):
        super().__init__()

        def conv3x3():
            # Bias is omitted; the following InstanceNorm has its own affine shift.
            return nn.Conv2d(in_features, in_features, 3, stride=1, padding=1, bias=False)

        def instance_norm():
            return nn.InstanceNorm2d(in_features, affine=True, track_running_stats=True)

        # Attribute name and layer order define the state_dict keys; keep them stable.
        self.conv_block = nn.Sequential(
            conv3x3(),
            instance_norm(),
            nn.ReLU(inplace=True),
            conv3x3(),
            instance_norm(),
        )

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch applied to ``x``."""
        branch = self.conv_block(x)
        return x + branch

In [None]:
class GeneratorResNet(nn.Module):
    """Encoder / residual / decoder generator conditioned on a domain label.

    The label vector is broadcast over the spatial dimensions and concatenated
    to the image as extra input channels.

    Args:
        img_shape: 3-tuple whose first entry is the number of image channels.
        res_blocks: Number of ``ResidualBlock`` modules in the bottleneck.
        c_dim: Length of the domain label vector.
    """

    def __init__(self, img_shape=(3,128,128), res_blocks=9, c_dim=5):
        super().__init__()
        channels, _, _ = img_shape  # only the channel count is used

        def norm_relu(width):
            # Normalisation + activation pair that follows every conv except the last.
            return [
                nn.InstanceNorm2d(width, affine=True, track_running_stats=True),
                nn.ReLU(inplace=True),
            ]

        # Stem: image channels plus c_dim label channels -> 64 feature maps.
        width = 64
        layers = [
            nn.Conv2d(channels + c_dim, width, 7, stride=1, padding=3, bias=False),
            *norm_relu(width),
        ]

        # Two stride-2 convolutions, doubling the width each time (64 -> 128 -> 256).
        for _ in range(2):
            layers.append(nn.Conv2d(width, width * 2, 4, stride=2, padding=1, bias=False))
            layers.extend(norm_relu(width * 2))
            width *= 2

        # Bottleneck: res_blocks residual blocks at constant width.
        layers.extend(ResidualBlock(width) for _ in range(res_blocks))

        # Two stride-2 transposed convolutions, halving the width each time (256 -> 128 -> 64).
        for _ in range(2):
            layers.append(nn.ConvTranspose2d(width, width // 2, 4, stride=2, padding=1, bias=False))
            layers.extend(norm_relu(width // 2))
            width //= 2

        # Head: back to the image channel count, squashed into (-1, 1) by tanh.
        layers.append(nn.Conv2d(width, channels, 7, stride=1, padding=3))
        layers.append(nn.Tanh())

        # Flat Sequential: layer indices are part of the checkpoint's state_dict keys.
        self.model = nn.Sequential(*layers)

    def forward(self, x, c):
        """Generate an image from ``x`` conditioned on label ``c``.

        Args:
            x: Image batch; dimensions 2 and 3 are treated as height and width.
            c: Label batch that can be viewed as ``(c.size(0), c.size(1), 1, 1)``.

        Returns:
            Output of the network for ``x`` concatenated with the tiled label
            along dimension 1.
        """
        height, width = x.size(2), x.size(3)
        label_map = c.view(c.size(0), c.size(1), 1, 1).repeat(1, 1, height, width)
        conditioned = torch.cat((x, label_map), 1)
        return self.model(conditioned)

### Define Discriminator

In [None]:
class Discriminator(nn.Module):
    """Strided-conv discriminator with an adversarial head and a class head.

    Args:
        img_shape: 3-tuple ``(channels, img_size, _)``; ``img_size`` sizes the
            classification head's kernel.
        c_dim: Number of class logits produced by the classification head.
        n_strided: Nominal number of stride-2 blocks; at least one block is
            always built.
    """

    def __init__(self, img_shape=(3,128,128), c_dim=5, n_strided=6):
        super().__init__()
        channels, img_size, _ = img_shape

        # Channel widths at each stage: input -> 64 -> 128 -> ... (doubling per block).
        widths = [channels, 64] + [64 * 2 ** i for i in range(1, n_strided)]

        trunk = []
        for in_filters, out_filters in zip(widths, widths[1:]):
            # Each block halves the spatial size (4x4 kernel, stride 2, padding 1).
            trunk.append(nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1))
            trunk.append(nn.LeakyReLU(0.01))
        self.model = nn.Sequential(*trunk)

        feature_dim = widths[-1]

        # Output 1 : PatchGAN - one score per spatial location of the feature map
        self.out1 = nn.Conv2d(feature_dim, 1, 3, padding=1, bias=False)

        # Output 2 : Class prediction - kernel sized to img_size / 2**n_strided
        self.out2 = nn.Conv2d(feature_dim, c_dim, img_size // (2 ** n_strided), bias=False)

    def forward(self, img):
        """Return ``(out_adv, out_cls)`` for an image batch.

        ``out_adv`` is the raw single-channel map from the adversarial head;
        ``out_cls`` is the class head's output flattened to
        ``(batch, -1)``.
        """
        features = self.model(img)
        adv_map = self.out1(features)      # real or fake
        cls_logits = self.out2(features)   # matching-domain
        return adv_map, cls_logits.view(cls_logits.size(0), -1)
        

### Define Loss function and Initialize Loss weights

In [None]:
# Loss function - Cycle loss
# Mean absolute error (L1) criterion.
criterion_cycle = torch.nn.L1Loss()

In [None]:
# Loss function - Domain-Class loss
def criterion_cls(logit, target):
    """Sigmoid binary cross-entropy summed over all elements, divided by batch size.

    Args:
        logit: Raw (pre-sigmoid) class scores; dimension 0 is the batch.
        target: Tensor of the same shape as ``logit`` holding the target
            probabilities (e.g. one-hot labels).

    Returns:
        Scalar tensor: total BCE over every element divided by ``logit.size(0)``.
    """
    # reduction='sum' is the supported spelling of the deprecated size_average=False.
    total = F.binary_cross_entropy_with_logits(logit, target, reduction='sum')
    return total / logit.size(0)

In [None]:
# Loss weights (suggested default in paper)
lambda_cls = 1   # weight of the domain-classification term
lambda_rec = 10  # weight of the reconstruction (cycle) term
lambda_gp = 10   # weight of the gradient-penalty term

### Initialize Generator and Discriminator

In [None]:
# Build both networks from the notebook-level configuration; the architecture
# must match the checkpoint whose state_dicts are loaded into them later.
generator = GeneratorResNet(img_shape=img_shape, res_blocks=residual_blocks, c_dim=c_dim)
discriminator = Discriminator(img_shape=img_shape, c_dim=c_dim)

### GPU Setting

In [None]:
# Move the networks and the cycle-loss module to the GPU when one is available.
if cuda:
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    criterion_cycle.cuda()

### Weight Setting

In [None]:
def weights_init_normal(m):
    """Re-initialise the weights of convolution-like modules from N(0, 0.02).

    Intended for use with ``nn.Module.apply``. A module is selected when its
    class name contains ``'Conv'``; every other module is left untouched.

    Args:
        m: The module visited by ``apply``.
    """
    if 'Conv' not in type(m).__name__:
        return
    # In-place fill of the weight tensor with a Gaussian (mean 0.0, std 0.02).
    torch.nn.init.normal_(m.weight.data, 0.0, 0.02)

In [None]:
# Apply the Gaussian initialiser to every sub-module of both networks.
# These values are subsequently replaced when the checkpoint state_dicts are loaded.
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal);

In [None]:
# Restore both networks from the checkpoint saved after `epoch` epochs.
# map_location remaps GPU-saved tensors onto the CPU when no CUDA device is
# available; with a GPU present the default behaviour (None) is kept.
checkpoint = torch.load(
    os.path.join(CHECKPOINT_ROOT, f'StarGAN_checkpoint_{epoch}_epochs.pt'),
    map_location=None if cuda else 'cpu',
)
generator.load_state_dict(checkpoint['generator_state_dict'])
discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
print("LOADED MODELS FROM CHECKPOINT")
# NOTE(review): eval mode is left disabled, so the normalisation layers keep
# their training-mode behaviour during generation - confirm this is intended.
# generator.eval()

In [None]:
# Print a layer-by-layer summary of the generator using one dummy image/label pair.
# The dummy inputs follow the generator's own device instead of calling .cuda()
# unconditionally, so this cell also runs on a CPU-only host. They are sampled on
# the CPU first and then moved, which keeps the seeded CPU RNG stream unchanged.
summary_device = next(generator.parameters()).device
summary(
    generator,
    input_data=[
        torch.rand((1, *img_shape)).to(summary_device),
        torch.rand((1, c_dim)).to(summary_device),
    ],
)

### Set transforms

In [None]:
# Normalize computes (x - mean) / std per channel:
#   processor:          (x - 0.5) / 0.5  -> maps [0, 1] to [-1, 1] (the generator's tanh range)
#   inverse_processor:  (x + 1.0) / 2.0  -> maps [-1, 1] back to [0, 1] for display / saving
processor = transforms.Normalize(mean = (0.5, 0.5, 0.5), std = (0.5, 0.5, 0.5))
inverse_processor = transforms.Normalize(mean = (-1.0, -1.0, -1.0), std = (2.0, 2.0, 2.0))

def transform_images(x):
    """Turn one loaded image into a normalised 128x128 tensor.

    Steps, in order: resize to 128x128, convert to a tensor, horizontally flip
    with probability 0.25, then normalise with ``processor``.

    Args:
        x: Image object exposing ``resize((w, h))`` - presumably a PIL image
            as produced by ``ImageFolder``'s loader; confirm against caller.

    Returns:
        The normalised image tensor.
    """
    resized = x.resize((128, 128))
    as_tensor = transforms.ToTensor()(resized)
    # Random augmentation: mirrored left-right one time in four.
    maybe_flipped = transforms.RandomHorizontalFlip(0.25)(as_tensor)
    return processor(maybe_flipped)

In [None]:
def collate_fn(batch):
    """Collate ``(image, class_index)`` samples into image and one-hot label batches.

    Args:
        batch: Sequence of samples where ``sample[0]`` is an image tensor and
            ``sample[1]`` is an integer class index.

    Returns:
        Tuple ``(images, labels)``: the images stacked along a new leading
        dimension, and float one-hot labels of width ``c_dim`` with one row
        per sample.
    """
    images = torch.stack([sample[0] for sample in batch])
    class_ids = torch.tensor([sample[1] for sample in batch])
    labels = nn.functional.one_hot(class_ids, num_classes = c_dim).float()
    return images, labels

In [None]:
# One sub-folder per class under each split directory. The same transform
# (including its random horizontal flip) is applied to both splits.
train = datasets.ImageFolder(os.path.join(root, 'train'), transform_images)
test = datasets.ImageFolder(os.path.join(root, 'validation'), transform_images)

In [None]:
# Loaders yield one shuffled sample at a time (batch_size = 1, not the `batch_size` setting).
# NOTE(review): only train_data uses collate_fn (one-hot float labels); test_data falls
# back to default collation, so its labels are presumably class indices - confirm intent.
train_data = DataLoader(train, batch_size = 1, shuffle = True, collate_fn = collate_fn)
test_data = DataLoader(test, batch_size = 1, shuffle = True)

In [None]:
# Lookup tables between class index and class (folder) name, following the
# order of train.classes: index -> name, then the inverse name -> index.
num_to_class = dict(enumerate(train.classes))
class_to_num = {name: idx for idx, name in num_to_class.items()}
num_to_class

In [None]:
# Preview one training sample: take the first image of the first batch, undo the
# normalisation, and move channels last (C, H, W) -> (H, W, C) for matplotlib.
plt.imshow(inverse_processor(next(iter(train_data))[0][0]).permute(1,2,0))
plt.axis('off')
# next(iter(train_data))[0]

### Generate new dataset with synthetic images

In [None]:
# Walk the whole training loader once and keep every sample in memory:
#   'img'        - the image batch as yielded by the loader (batch_size = 1, so it keeps a leading batch dimension)
#   'class_name' - class name recovered from the one-hot label via argmax
df = pd.DataFrame([{'img': img, 'class_name': num_to_class[torch.argmax(c).item()]} for (img, c) in train_data])

In [None]:
# Plot the number of training samples per class before augmentation.
# sns.set_theme(palette='deep')
sns.reset_orig()
# sns.set_theme(rc={'figure.figsize':(11.7,8.27)})

# Categorical with train.classes as categories fixes the class order on the x axis.
ax = sns.histplot(pd.Categorical(df['class_name'], train.classes), discrete = True)
ax.set(xlabel='Class', ylabel='#samples')

In [None]:
# Number of synthetic 'disgust' images needed to bring that class up to the
# size of the 'fear' class.
class_counts = df['class_name'].value_counts()
num_to_generate = class_counts['fear'] - class_counts['disgust']
print(f"Generating {num_to_generate} samples")

# The target label is identical for every sample, so build it (and move it to
# the GPU) once instead of on every iteration.
label = nn.functional.one_hot(torch.tensor([class_to_num['disgust']]), num_classes = c_dim).float()
if cuda:
    label = label.cuda()

generated = []
# Guard against a non-positive target: without it the stop condition below
# could never be met and every non-'disgust' image would be translated.
if num_to_generate > 0:
    for img, class_name in tqdm(zip(df['img'], df['class_name']), total = len(df)):
        if class_name == 'disgust':
            continue  # only faces from other classes are translated to 'disgust'
        if cuda:
            img = img.cuda()
        with torch.no_grad():
            # [0] selects the first (only) image of the generator's output batch.
            generated.append(generator(img, label)[0].cpu())
        if len(generated) >= num_to_generate:
            break

In [None]:
# shutil.rmtree('images')

In [None]:
# Copy the original dataset into a writable ./images tree and add the
# synthetic 'disgust' images to the training split.
os.makedirs('images', exist_ok = True)
for split in ['train', 'validation']:
    print(f"Working on {split}...")
    # NOTE: copytree raises if images/<split> already exists; delete the
    # ./images directory first when re-running this cell.
    shutil.copytree(os.path.join(root, split), f'images/{split}/')
    if split == 'train':
        synthetic_rows = []
        for i, gen_img in tqdm(enumerate(generated), total = len(generated)):
            # Undo normalisation ([-1, 1] -> [0, 1]), move channels last, scale to 8-bit.
            pixels = inverse_processor(gen_img).permute(1, 2, 0).numpy()
            pixels = (pixels * 255).astype(np.uint8)
            Image.fromarray(pixels).save(f'images/train/disgust/synthetic_{i}.jpg')
            synthetic_rows.append({'img': pixels, 'class_name': 'disgust'})
        # Append all new rows in one concat instead of growing df row by row,
        # which would copy the whole frame for each generated image.
        if synthetic_rows:
            df = pd.concat([df, pd.DataFrame(synthetic_rows)], ignore_index = True)

In [None]:
# Plot the per-class sample counts again, now including the synthetic rows added to df.
# sns.set_theme(palette='deep')
sns.reset_orig()
# sns.set_theme(rc={'figure.figsize':(11.7,8.27)})

# Categorical with train.classes as categories fixes the class order on the x axis.
ax = sns.histplot(pd.Categorical(df['class_name'], train.classes), discrete = True)
ax.set(xlabel='Class', ylabel='#samples')