## Generating Chinese characters using a Variational Autoencoder

In this notebook, we are going to implement a simple Variational Autoencoder (VAE) to generate Chinese characters.

As always, we start by importing the necessary packages.

In [None]:
## Importing necessary packages ##

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import make_grid
from torchvision.transforms import transforms

from tqdm import tqdm
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

With that out of the way, let's import our dataset.

The dataset is a CSV file in which each row stores one flattened image. We need to turn those rows back into images before moving ahead.

Let's first inspect the CSV file with pandas.

In [None]:
## Checking the dataset csv file ##

data = pd.read_csv('../input/chinese-mnist/chineseMNIST.csv')

data.head()

There are 4096 pixel columns, so each image is 64 × 64 pixels. Those 4096 values fill exactly one 64 × 64 plane, whereas an RGB image of that size would need 3 × 4096 = 12,288 columns. The images are therefore single-channel grayscale, just like MNIST.

The CSV has 4098 columns in total: the 4096 pixel columns plus a `label` column and a `character` column. We do not need the `character` column, so it will be removed.
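The sketch below shows how one CSV row maps back to an image. It builds a small hypothetical stand-in for `chineseMNIST.csv` and does not read the real file. The column names `pixel_0 ... pixel_4095` and the label values are assumptions. The layout of 4096 pixel columns followed by `label` and `character` matches the `iloc[:, :-2]` slice used later in the dataset class.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for chineseMNIST.csv (column names are assumed):
# 4096 pixel columns followed by a 'label' and a 'character' column.
n_rows, side = 3, 64
pixels = np.random.randint(0, 256, size=(n_rows, side * side))
df = pd.DataFrame(pixels, columns=[f"pixel_{i}" for i in range(side * side)])
df["label"] = [1, 2, 3]
df["character"] = ["一", "二", "三"]

print(df.shape)  # 4096 pixel columns + label + character

# Drop the last two columns to keep only the pixel values.
flat = df.iloc[:, :-2].to_numpy()

# 4096 = 64 * 64 values per row -> one grayscale 64 x 64 image per row.
images = flat.reshape(-1, side, side)
print(images.shape)
```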

Next, let's explore the data a bit more to get a better feel for it.

In [None]:
## Deriving descriptive statistics ##

data.describe()

The pixel values are not in the range [0, 1], so we need to rescale them.

The summary also tells us that there are 15,000 images in total and that the minimum pixel value is 0.
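The dataset class below rescales each image by its own minimum and maximum. Here is a minimal standalone version of that min-max scaling. `min_max_scale` and its `eps` guard against a constant (blank) image are our own illustrative additions. If `describe()` reports a maximum of 255, dividing by 255 is a simpler fixed-range alternative. It gives the same result for any image that actually spans 0 to 255.

```python
import numpy as np

def min_max_scale(img: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Rescale one flattened image to [0, 1] using its own min and max.

    eps avoids a division by zero when every pixel has the same value.
    """
    img = img.astype(np.float32)
    lo, hi = img.min(), img.max()
    return (img - lo) / max(hi - lo, eps)

raw = np.array([0, 64, 128, 255])
scaled = min_max_scale(raw)
fixed = raw / 255.0          # fixed-range alternative for 8-bit pixels
print(scaled, fixed)
```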

Now, let's set up our PyTorch dataset.

In [None]:
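## Setting up pytorch dataset ##

class chinese_dataset(Dataset):
    
    def __init__(self , path):
        
        super().__init__()
        
        self.dataset_csv = pd.read_csv(path)
        
        # All columns except the last two ('label' and 'character') are pixels;
        # cast to float32 so the tensors match the model's weights
        self.imgs = self.dataset_csv.iloc[: , :-2].values.astype(np.float32)
        
        self.labels = self.dataset_csv['label']
        
    def __getitem__(self , idx):
        
        cyclic_idx = idx % len(self.imgs)
        
        unprocessed_img = self.imgs[cyclic_idx]
        
        # Min-max scale each image to [0, 1]; the small epsilon avoids
        # dividing by zero for a constant image
        img_min , img_max = np.min(unprocessed_img) , np.max(unprocessed_img)
        processed_img = (unprocessed_img - img_min) / max(img_max - img_min , 1e-8)
        
        img = torch.from_numpy(processed_img)
        
        # For an autoencoder, the target is the input image itself
        label = img.clone()
        
        return img , label
        
    def __len__(self):
        
        return len(self.imgs)
    
## Creating our dataset instance ##

chinese_data = chinese_dataset('../input/chinese-mnist/chineseMNIST.csv')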

Now let's get some insight into the dataset by visualizing an image.

In [None]:
## Exploring and Visualizing an image ##

# randint without `size` returns a single integer index
random_idx = np.random.randint(low = 0 , high = len(chinese_data))

img , label = chinese_data[random_idx]

print('Shape of an image sample :' , img.shape)

print('Maximum value of the image is :' , torch.max(img))

print('Minimum value of the image is :' , torch.min(img))

# Reshape the flat 4096-vector back into a 64 x 64 grayscale image
plt.imshow(img.reshape(64 , 64) , cmap = 'gray')

plt.show()

Our dataset is ready. Next, we wrap it in a `DataLoader` to iterate over shuffled mini-batches of 16 images.

In [None]:
## Implementing dataloader ##

chinese_dataloader = DataLoader(dataset = chinese_data,
                                batch_size = 16 ,
                                shuffle = True)


## Checking the length of dataloader ##

print('The dataloader has' , len(chinese_dataloader) , 'mini-batches!')


## Visualizing a batch of data ##

for img , label in chinese_dataloader:
    
    fig , ax = plt.subplots(figsize = (4 , 4)) 
    
    ax.set_xticklabels([])
    
    ax.set_yticklabels([])
    
    plt.imshow(make_grid(img.reshape(-1 , 1 , 64 , 64) , 4).permute(1 , 2 , 0))
    
    plt.show()
    
    break
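To sanity-check the printed number of mini-batches: with 15,000 images and `batch_size = 16`, the loader yields ceil(15000 / 16) = 938 batches, and the last one holds only 8 images. The sketch uses a dummy `TensorDataset` of the same size in place of the real data. It also shows that `drop_last=True` (a standard `DataLoader` option, not used in this notebook) would discard that partial batch.

```python
import math
import torch
from torch.utils.data import TensorDataset, DataLoader

# Dummy dataset with as many samples as Chinese MNIST.
n_images, batch_size = 15_000, 16
dummy = TensorDataset(torch.zeros(n_images, 1))

loader = DataLoader(dummy, batch_size=batch_size, shuffle=True)
loader_drop = DataLoader(dummy, batch_size=batch_size, shuffle=True, drop_last=True)

print(len(loader), math.ceil(n_images / batch_size))   # 938 938
print(len(loader_drop), n_images // batch_size)        # 937 937
```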

Next, we implement a small wrapper around the `DataLoader` that moves every batch to the GPU, or keeps it on the CPU if no GPU is available.

In [None]:
## Utility functions for GPU Dataloader ##

def get_device():
    
    if torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')


device = get_device()

def transfer_to_device(data , device):
    
    if isinstance(data , (list , tuple)):
        return [transfer_to_device(each_data , device) for each_data in data]
    return data.to(device)

## Implementing GPU Dataloader ##

class GPUDataLoader:
    
    def __init__(self , dl = chinese_dataloader , device = device):
        
        self.dl = dl
        self.device = device
        
    def __iter__(self):
        
        for batch in self.dl:
            yield transfer_to_device(batch , self.device)
            
    def __len__(self):
        
        return len(self.dl)
    
## Setting the GPU dataloader object ##

chinese_dl = GPUDataLoader(chinese_dataloader , device)
    
## Checking the length of dataloader ##

print('The dataloader has' , len(chinese_dl) , 'mini-batches!')


## Visualizing a batch of data ##

for img , label in chinese_dl:
    
    fig , ax = plt.subplots(figsize = (4 , 4)) 
    
    ax.set_xticklabels([])
    
    ax.set_yticklabels([])
    
    plt.imshow(make_grid(img.to('cpu').reshape(-1 , 1 , 64 , 64) , 4).permute(1 , 2 , 0))
    
    plt.show()
    
    break
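An optional variation on the GPU wrapper above, assuming a CUDA device is present. `DataLoader(pin_memory=True)` places each CPU batch in page-locked memory, and `tensor.to(device, non_blocking=True)` can then copy it asynchronously. `to_device` is a hypothetical variant of `transfer_to_device`. On a CPU-only machine, the sketch simply disables both options.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_cuda = device.type == "cuda"

def to_device(data, device, non_blocking=False):
    """Recursively move a tensor, or a list/tuple of tensors, to `device`."""
    if isinstance(data, (list, tuple)):
        return [to_device(d, device, non_blocking) for d in data]
    return data.to(device, non_blocking=non_blocking)

# Small random stand-in for (image, target) pairs of 4096 pixels each.
dataset = TensorDataset(torch.rand(32, 4096), torch.rand(32, 4096))
loader = DataLoader(dataset, batch_size=16, pin_memory=use_cuda)

batches = [to_device(b, device, non_blocking=use_cuda) for b in loader]
for img, target in batches:
    print(img.device, img.shape)
```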

Now it's time to set up our Variational Autoencoder model.

Let's go!

In [None]:
## Implementing Variational Autoencoder ##

class VAE(nn.Module):
    
    def __init__(self , in_feature , code_feature):
        
        super().__init__()
        
        # Encoder: in_feature -> 8 * code -> 4 * code
        self.hidden1 = nn.Linear(in_feature , code_feature * 8)
        self.hidden2 = nn.Linear(code_feature * 8 , code_feature * 4)
        # Heads for the latent mean and log-variance (gamma = log sigma^2)
        self.mean = nn.Linear(code_feature * 4 , code_feature)
        self.gamma = nn.Linear(code_feature * 4 , code_feature)
        # Decoder: code -> 4 * code -> 8 * code -> in_feature
        self.hidden3 = nn.Linear(code_feature , code_feature * 4)
        self.hidden4 = nn.Linear(code_feature * 4 , code_feature * 8)
        self.hidden5 = nn.Linear(code_feature * 8 , in_feature)
        self.elu = nn.ELU()
        self.sigmoid = nn.Sigmoid()  # not used in forward: the outputs are logits
        
    def forward(self , x):
        x = self.hidden1(x)
        x = self.elu(x)
        x = self.hidden2(x)
        x = self.elu(x)
        x_mean = self.mean(x)
        x_gamma = self.gamma(x)
        # gamma is the log-variance, so sigma = exp(gamma / 2)
        x_sigma = torch.exp(0.5 * x_gamma)
        dist = torch.distributions.Normal(x_mean , x_sigma)
        # rsample() uses the reparameterization trick, so gradients
        # flow back through x_mean and x_sigma
        code_layer = dist.rsample()
        x = self.hidden3(code_layer)
        x = self.elu(x)
        x = self.hidden4(x)
        x = self.elu(x)
        # Raw logits; the sigmoid is applied inside BCEWithLogitsLoss
        out = self.hidden5(x)
        
        return x_gamma , x_mean , out
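Sampling the latent code works through the reparameterization trick. `dist.rsample()` draws z = μ + σ · ε with ε ~ N(0, I), so z stays differentiable with respect to μ and σ. Since `gamma` is read as log σ² (as in the latent loss), the standard deviation is exp(γ / 2). Writing `torch.exp(gamma / 0.5)` would instead compute exp(2γ) = σ⁴. The toy tensors below are illustrative and are not the model's actual activations.

```python
import torch

torch.manual_seed(0)

# Toy latent parameters: batch of 4, 32 latent dimensions.
mean = torch.zeros(4, 32, requires_grad=True)
gamma = torch.zeros(4, 32, requires_grad=True)   # log(sigma^2)

sigma = torch.exp(0.5 * gamma)   # standard deviation
eps = torch.randn_like(sigma)    # noise that does not depend on the parameters
z = mean + sigma * eps           # reparameterized sample

z.sum().backward()               # gradients reach both mean and gamma
print(mean.grad.shape, gamma.grad.shape)

# Dividing by 0.5 instead of multiplying gives exp(2 * gamma) = sigma**4.
g = torch.tensor([0.5, 1.0])
print(torch.exp(g / 0.5), torch.exp(0.5 * g) ** 4)
```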

Now let's create our model instance.

In [None]:
## Creating model object ##

model = VAE(4096 , 32).to(device)

## Testing model ##

var = torch.randn(4096).to(device)

gamma , mean , out = model(var)

print('The shape of output is :' , out.shape)

Okay, cool!

The forward pass runs and returns an output with the same shape as the input (4096 values), so the layer dimensions line up.

Now we need to define our loss function, which combines a reconstruction loss with a latent (KL divergence) loss.
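Here is the standard VAE loss for a single image x, written out with sums. The decoder logits are $\hat{x}$, with $p = \operatorname{sigmoid}(\hat{x})$, and the encoder outputs are $\mu$ and $\gamma = \log \sigma^2$ over $d$ latent dimensions. The first term is the binary cross-entropy over the 4096 pixels. The second is the closed-form KL divergence between $\mathcal{N}(\mu, \operatorname{diag}(\sigma^2))$ and $\mathcal{N}(0, I)$.

$$
\mathcal{L}(x) = -\sum_{i=1}^{4096}\Big[x_i \log p_i + (1 - x_i)\log(1 - p_i)\Big] \;+\; \frac{1}{2}\sum_{j=1}^{d}\Big(e^{\gamma_j} + \mu_j^2 - 1 - \gamma_j\Big)
$$

The code in this notebook replaces both sums with means: `BCEWithLogitsLoss` averages over all pixels in the batch, and `latent_loss` averages over all latent entries. With 4096 pixels and d = 32, that is equivalent, up to an overall constant, to weighting the KL term 4096 / 32 = 128 times more heavily than in the summed form above.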

In [None]:
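## Defining loss function ##

# Reconstruction loss: binary cross-entropy on the decoder's logits
# (the sigmoid is applied internally), averaged over all pixels in the batch
re_loss = nn.BCEWithLogitsLoss()

def latent_loss(gamma , mean):
    
    # Closed-form KL divergence between N(mean, exp(gamma)) and N(0, I),
    # with gamma = log sigma^2; averaged (not summed) over batch and latent dims
    loss = 0.5 * torch.mean(torch.exp(gamma) + torch.square(mean) - 1 - gamma)
    
    return loss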
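We can check the closed form inside `latent_loss` numerically against PyTorch's analytic `torch.distributions.kl_divergence` for two `Normal` distributions. The random `mean` and `gamma` tensors are illustrative stand-ins for encoder outputs. `latent_loss` returns the mean of `closed_form` below.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mean = torch.randn(8, 32)
gamma = torch.randn(8, 32)           # log-variance
sigma = torch.exp(0.5 * gamma)       # standard deviation

# Elementwise KL term used by latent_loss (before averaging).
closed_form = 0.5 * (torch.exp(gamma) + mean ** 2 - 1 - gamma)

# PyTorch's analytic KL between N(mean, sigma^2) and N(0, 1), elementwise.
prior = Normal(torch.zeros_like(mean), torch.ones_like(mean))
reference = kl_divergence(Normal(mean, sigma), prior)

print(torch.allclose(closed_form, reference, atol=1e-5))
```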

Next, let's define our optimizer: Adam with a learning rate of 3e-4.

In [None]:
## Optimizer ##

optim = torch.optim.Adam(model.parameters(), lr = 3e-4)

Let's also define a utility function for visualizing a batch of images.

In [None]:
## Visualization utility function ##

def vis_img(img):
    
    fig , ax = plt.subplots(figsize = (4 , 4)) 
    
    ax.set_xticklabels([])
    
    ax.set_yticklabels([])
    
    plt.imshow(make_grid(img.detach().to('cpu').reshape(-1 , 1 , 64 , 64) , 4).permute(1 , 2 , 0))
    
    plt.show()
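One detail matters when visualizing reconstructions. The model returns raw logits, and `nn.BCEWithLogitsLoss` applies the sigmoid internally, which matches `nn.BCELoss` on `torch.sigmoid(logits)`. `vis_img` plots whatever values it receives, so decoder outputs should go through `torch.sigmoid` before plotting to land in [0, 1]. The random tensors are stand-ins for real outputs and targets.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(16, 4096) * 2     # raw, unbounded decoder outputs
target = torch.rand(16, 4096)          # pixel intensities in [0, 1]

# BCEWithLogitsLoss == BCELoss applied to sigmoid(logits), up to float error.
with_logits = nn.BCEWithLogitsLoss()(logits, target)
explicit = nn.BCELoss()(torch.sigmoid(logits), target)
print(with_logits.item(), explicit.item())

# Logits fall outside [0, 1]; sigmoid maps them into a displayable range.
probs = torch.sigmoid(logits)
print(logits.min().item(), probs.min().item(), probs.max().item())
```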

Now let's train!

In [None]:
## Training our model ##

num_epochs = 50

for epoch in range(num_epochs):
    
    for img , target in chinese_dl:
        
        # Ensure float32 inputs; unlike torch.cuda.FloatTensor,
        # .float() works on both CPU and GPU
        img = img.float()
        
        target = target.float()
        
        gamma , mean , out = model(img)
        
        reconstruction_loss = re_loss(out , target)
        
        code_loss = latent_loss(gamma , mean)
        
        loss = reconstruction_loss + code_loss
        
        optim.zero_grad()
        
        loss.backward()
        
        optim.step()
        
    # Loss of the last mini-batch in this epoch
    print('Epoch : {} / {} --> Loss : {}'.format(epoch + 1  , num_epochs , loss.item()))
    
    # Input images, then their reconstructions (sigmoid maps logits to [0, 1])
    vis_img(img)
    
    vis_img(torch.sigmoid(out))
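The training loop only displays reconstructions of existing images. To generate new characters, we can sample latent codes from the prior N(0, I) and pass them through the decoder half of the model. The sketch below is a minimal version of that idea. `generate` is a hypothetical helper, and the `nn.Sequential` stand-in decoder is untrained, though it mirrors the layer sizes of `VAE(4096, 32)` (hidden3 → ELU → hidden4 → ELU → hidden5).

```python
import torch
import torch.nn as nn

@torch.no_grad()
def generate(decode, n_samples, latent_dim, device="cpu"):
    """Sample z ~ N(0, I) and decode it into images with pixels in [0, 1].

    `decode` is any callable mapping (n_samples, latent_dim) to logits.
    """
    z = torch.randn(n_samples, latent_dim, device=device)
    return torch.sigmoid(decode(z))

# Untrained stand-in with the same sizes as the notebook's decoder layers.
code_feature, in_feature = 32, 4096
decoder = nn.Sequential(
    nn.Linear(code_feature, code_feature * 4), nn.ELU(),
    nn.Linear(code_feature * 4, code_feature * 8), nn.ELU(),
    nn.Linear(code_feature * 8, in_feature),
)

samples = generate(decoder, n_samples=16, latent_dim=code_feature)
print(samples.shape)  # torch.Size([16, 4096])
```

With the trained model, one would pass `lambda z: model.hidden5(model.elu(model.hidden4(model.elu(model.hidden3(z)))))` as `decode` together with `device = device`, then display the result with `vis_img(samples)`.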