## Representation Learning using Deep InfoMAX 

In this notebook we implement Deep InfoMax (DIM), the representation-learning method from the paper by Hjelm et al.

We train the model on the publicly available Chinese Fine Art dataset, which you can download from Kaggle: https://www.kaggle.com/rickyjli/chinese-fine-art.

So, let's get started by importing the necessary packages. 

In [None]:
## Importing the necessary packages ##

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import make_grid
from torchvision.models import alexnet
from torchvision.transforms import transforms

from tqdm import tqdm
import os
import PIL.Image as Image
import numpy as np
import random
import matplotlib.pyplot as plt

Okay done!

We are all set with the packages. Now let's get going with our dataset.

In [None]:
## Setting transformations ##

# Resize every image to a fixed 256x256 and convert it to a tensor
# (channel-first, which is why the plots below call `.permute(1, 2, 0)`).
aug = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
    ]
)

## Creating our dataset ##

class ArtDataset(Dataset):
    """Flat image-folder dataset that returns ``(image, other_image)`` pairs.

    Item ``index`` is the file at ``index % len(self)``. It is paired with a
    second file drawn uniformly at random from all *other* files in ``root``.
    The training loop feeds that second image through the model as the
    "fake" sample.

    Args:
        augment: callable applied to each RGB PIL image. Defaults to the
            module-level ``aug`` transform.
        root: directory whose entries are all treated as image files.
    """

    def __init__(self, augment=aug, root='Images'):
        self.root = root
        # NOTE(review): os.listdir order is arbitrary and includes every entry
        # (sub-directories, non-image files); consider filtering and sorting.
        self.images = os.listdir(root)
        self.augment = augment
        self.length = len(self.images)

    def _load(self, idx):
        """Open ``self.images[idx]``, convert it to RGB and apply ``self.augment``."""
        img_path = os.path.join(self.root, self.images[idx])
        # The context manager closes the file handle once convert() has
        # produced an independent in-memory image.
        with Image.open(img_path) as img_pil:
            rgb = img_pil.convert('RGB')
        # Use the transform passed to __init__. The original always used the
        # global `aug` here, silently ignoring a custom `augment`.
        return self.augment(rgb)

    def __getitem__(self, index):
        """Return ``(real_tensor, fake_tensor)`` for ``index`` (wrapped modulo length).

        Raises:
            IndexError: if ``root`` holds fewer than two entries, so no
                distinct "fake" image can be drawn.
        """
        if self.length < 2:
            raise IndexError(
                'ArtDataset needs at least two images to draw a distinct fake; '
                'found {} in {!r}'.format(self.length, self.root))

        idx = index % self.length

        # Uniform draw over the other length-1 indices in O(1): sample from
        # [0, length - 1) and skip over idx. This matches choosing from the
        # list of all indices != idx without rebuilding that list every call.
        fake_idx = random.randrange(self.length - 1)
        if fake_idx >= idx:
            fake_idx += 1

        return self._load(idx), self._load(fake_idx)

    def __len__(self):
        """Number of directory entries found in ``root``."""
        return self.length
    
## Creating our dataset object ##

art_dataset = ArtDataset()


## Sanity checking our dataset ##

# Draw one scalar index directly. The original requested a size-1 array and
# cast it with int(); recent NumPy versions deprecate that array-to-scalar
# conversion.
random_idx = int(np.random.randint(low=0, high=len(art_dataset)))

real, fake = art_dataset[random_idx]

# Show the real image, then its randomly drawn "fake" partner (CHW -> HWC for imshow).
for sample in (real, fake):
    plt.imshow(sample.permute(1, 2, 0))
    plt.show()

Now we need to set up our dataloader, which will feed in batches of data from the dataset.

In [None]:
## Setting our dataloader ##

# Mini-batches of 4 (real, fake) pairs, reshuffled on every pass.
art_dataloader = DataLoader(art_dataset, batch_size=4, shuffle=True)


## Setting a visualization utility function ##

def _show_grid(images, title):
    """Plot a batch of channel-first image tensors as a 2-per-row grid in a small figure."""
    fig, ax = plt.subplots(figsize=(2, 2))
    # Drop autograd history and move to CPU so matplotlib can read the pixels;
    # permute CHW -> HWC for imshow.
    grid = make_grid(images.detach().to('cpu'), nrow=2)
    ax.imshow(grid.permute(1, 2, 0))
    # With no ticks there are no tick labels to clear.
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title)
    plt.show()


def show_img(real, fake):
    """Display a batch of real images and a batch of fake images as two separate grids."""
    _show_grid(real, 'Real Images')
    _show_grid(fake, 'Fake Images')
    
## Sanity checking ##

# Display only the first mini-batch, if the loader yields any.
first_batch = next(iter(art_dataloader), None)
if first_batch is not None:
    real, fake = first_batch
    show_img(real, fake)

Next, we move every batch produced by the dataloader onto the GPU (when one is available) so the model's computations can run there in parallel. The following cell contains the utilities for that transfer.

In [None]:
## GPU Transfer utility functions ##

## Checking if cuda is available and setting default device as cuda ##

def get_device():
    """Return the CUDA device when one is available, otherwise the CPU device."""
    device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(device_type)

## Setting the default device ##

device = get_device()

## Transferring data to a specific device ##

def transfer_data(data, device):
    """Recursively move ``data`` to ``device``.

    Lists and tuples are walked element by element and always come back as a
    *list*. Any other object must provide ``.to(device)`` (e.g. a tensor).
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device)
    return [transfer_data(item, device) for item in data]

## Setting the GPU Dataloader ##

class GPUDataloader:
    """Wrap a dataloader so every batch it yields is first moved to ``device``.

    Each ``iter()`` call re-iterates the wrapped loader, so the wrapper can be
    reused across epochs. ``len()`` is forwarded to the wrapped loader.
    """

    def __init__(self, dl, device):
        self.dl = dl
        self.device = device

    def __iter__(self):
        # A generator body, so the wrapped loader is only iterated on first next().
        yield from (transfer_data(batch, self.device) for batch in self.dl)

    def __len__(self):
        return len(self.dl)
    
## Making our GPU Dataloader object ##

art_dl = GPUDataloader(art_dataloader, device)

## Visualizing a mini-batch ##

# Pull just the first device-resident batch (if any) and display it.
first_batch = next(iter(art_dl), None)
if first_batch is not None:
    real, fake = first_batch
    show_img(real, fake)

Now we are going to **build our model**.

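Our model consists of three basic parts: a global feature extractor (a convolutional encoder followed by fully connected layers that produce a 64-dimensional global code), a local feature extractor that works on the encoder's feature map combined with that global code, and the discriminator layers that score these features.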

In [None]:
## Let's first build the global feature extractor ##

## It is the output from the last conv layer ##

class Global_DIM(nn.Module):
    """Convolutional encoder plus the global mutual-information head.

    ``forward(image)`` returns a 3-tuple:
      * the score from ``mutual_info`` applied to the flattened conv features
        concatenated with the 64-d global code,
      * the conv feature map produced by ``global_encoder``,
      * the 64-d global code produced by ``extended_encoder``.

    NOTE(review): both inputs to ``mutual_info`` come from the same image, so
    "fake" scores in the training loop are also computed on matched pairs.
    DIM-style negatives usually pair one image's code with another image's
    features; confirm this is intended.
    """

    def __init__(self):
        super().__init__()

        # 57600 flattened conv features: presumably 256 channels x 15 x 15 for
        # the 256x256 inputs produced by `aug`. Update this if the input size
        # changes. 64 is the size of the global code.
        feat_dim = 57600
        code_dim = 64

        # NOTE(review): newer torchvision releases prefer `weights=None` over
        # `pretrained=False`; confirm against the installed version.
        self.global_encoder = alexnet(pretrained=False).features[:11]
        self.flat = nn.Flatten()
        self.extended_encoder = nn.Sequential(
            nn.Linear(feat_dim, 4096),
            nn.ReLU(),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Linear(4096, code_dim),
        )
        self.mutual_info = nn.Sequential(
            nn.Linear(feat_dim + code_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
        )

    def forward(self, image):
        feature_map = self.global_encoder(image)
        flat_features = self.flat(feature_map)
        code = self.extended_encoder(flat_features)
        joint = torch.cat((flat_features, code), dim=1)
        return self.mutual_info(joint), feature_map, code
    
    
## Now let's build the local feature extractor ##

class Local_DIM(nn.Module):
    """Stack of 1x1 convolutions that scores every spatial location independently.

    Expects a 320-channel map (in DeepInfomax: the conv channels plus the 64-d
    global code tiled over every location). Returns a 1-channel score map
    with the same spatial size.
    """

    def __init__(self):
        super().__init__()
        widths = (320, 512, 512)
        layers = []
        # Hidden layers: conv(1x1) + ReLU for 320->512 and 512->512.
        for c_in, c_out in zip(widths, widths[1:]):
            layers += [nn.Conv2d(c_in, c_out, kernel_size=1, bias=True), nn.ReLU()]
        # Final projection to one score per location (no activation).
        layers.append(nn.Conv2d(widths[-1], 1, kernel_size=1, bias=True))
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)
    
## Now our Deep InfoMAX model ##

class DeepInfomax(nn.Module):
    """Combine the global head (``Global_DIM``) and the local head (``Local_DIM``).

    ``forward(img)`` returns ``(global_out, local_out)``:
      * the per-sample score from ``Global_DIM``,
      * the per-location score map from ``Local_DIM``. That map is computed on
        the conv feature map concatenated channel-wise with the global code
        copied to every spatial location.
    """

    def __init__(self):
        super().__init__()
        self.global_dim = Global_DIM()
        self.local_dim = Local_DIM()

    def forward(self, img):
        global_out, feature_map, code = self.global_dim(img)
        height, width = feature_map.shape[2], feature_map.shape[3]
        # (N, C) -> (N, C, 1, 1) -> (N, C, H, W): the same code at every location.
        tiled_code = code[:, :, None, None].expand(-1, -1, height, width)
        local_in = torch.cat([feature_map, tiled_code], dim=1)
        return global_out, self.local_dim(local_in)
    

## Creating our model ##

model = DeepInfomax()
model = model.to(device)

Now let's set up our loss functions and optimizer.

In [None]:
## Global Loss functions ##

def global_loss(tensor1, tensor2):
    """Loss for the global scores: ``-(mean(tensor1) - log(sum(exp(tensor2))))``.

    Args:
        tensor1: scores for the real inputs (any shape).
        tensor2: scores for the fake inputs (any shape). The log-sum-exp runs
            over *every* element.

    Returns:
        A 0-dim tensor; lower means real scores sit higher relative to fake ones.

    NOTE: this uses log-*sum*-exp; a log-*mean*-exp would differ only by the
    constant log(tensor2.numel()), which does not change gradients.
    """
    # torch.logsumexp is the numerically stable form of log(sum(exp(x))). The
    # naive form overflows to inf once any float32 score exceeds ~88.
    log_sum_exp_fake = torch.logsumexp(tensor2.reshape(-1), dim=0)
    return -(tensor1.mean() - log_sum_exp_fake)

## Local Loss function ##

def local_loss(tensor1, tensor2):
    """Loss for the local score maps, scaled by the number of spatial locations.

    Computes ``-(mean(tensor1) - log(sum(exp(tensor2)))) / (H * W)`` where
    ``H, W = tensor1.shape[2], tensor1.shape[3]``.

    Args:
        tensor1: real score map. It is indexed at dims 2 and 3, so it must be at
            least 4-D (e.g. the ``(N, 1, H, W)`` output of ``Local_DIM``).
        tensor2: fake score map. The log-sum-exp runs over every element.

    Returns:
        A 0-dim tensor.
    """
    num_locations = tensor1.shape[2] * tensor1.shape[3]
    # Numerically stable log(sum(exp(x))); the naive form overflows for large scores.
    log_sum_exp_fake = torch.logsumexp(tensor2.reshape(-1), dim=0)
    return -(tensor1.mean() - log_sum_exp_fake) / num_locations

## Setting optimizer ##

# Adam over every parameter of the DeepInfomax model.
learning_rate = 3e-4
optim = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

Now let's train our network!!

In [None]:
## Training our model ##

num_epochs = 50

model.train()

for epoch in range(num_epochs):

    # Build a fresh progress bar each epoch. A tqdm wrapper closes itself at
    # the end of a full pass, so reusing one instance would only show epoch 1.
    loop = tqdm(art_dl)
    loop.set_description('Epochs : {} / {}'.format(epoch + 1, num_epochs))

    for real, fake in loop:

        real_global_out, real_local_out = model(real)
        fake_global_out, fake_local_out = model(fake)

        glob_loss = global_loss(real_global_out, fake_global_out)
        loc_loss = local_loss(real_local_out, fake_local_out)
        total_loss = glob_loss + loc_loss

        # Clear stale gradients *before* backprop, then step. The previous order
        # (backward -> zero_grad -> step) erased the fresh gradients before the
        # optimizer could use them, so the parameters never changed.
        optim.zero_grad()
        total_loss.backward()
        optim.step()

        loop.set_postfix(loss=total_loss.item())