In [None]:
# Assignment: Sinkhorn based Generative Modelling

# In this assignment you will experiment with the following code, which uses Sinkhorn-based generative models, and report the results.
# There is no coding to be done in this assignment, but you will try out various options.
# You will clone a GitHub repo that lets you experiment with these options.
# You are encouraged to explore a lot more than what is suggested and to report anything interesting you find!

# As is known, Sinkhorn-based models on their own do not perform well! However, when extended with a GAN, they do seem to
# show promising results. It remains to be seen how well this works with a GAN on CIFAR10. You may try exploring this.

# Paper: "Improving GANs Using Optimal Transport" by Tim Salimans, Han Zhang, Alec Radford,
# Dimitris Metaxas (Link: https://arxiv.org/abs/1803.05573)

# Submit the notebook file with generated figures and data. 

In [None]:
!pip3 install pickleshare

In [None]:
# Import our git repository
!git clone --recursive https://github.com/Alexandre-Rio/ot_generative_models.git
#%mkdir ot_generative_models
%cd ot_generative_models

In [None]:
# Run nvidia-smi in a shell to show which NVIDIA GPU(s), if any, this runtime can see.
!nvidia-smi

In [None]:
# Imports for this cell: torchvision provides the MNIST dataset class, DataLoader batches
# it, and mnist_transforms comes from the project-local data_preprocessing module.
import torchvision
from torch.utils.data import DataLoader

from data_preprocessing import mnist_transforms

# MNIST training split, stored under ./data (download=True fetches it when it is missing).
mnist = torchvision.datasets.MNIST(
    './data',
    train=True,
    transform=mnist_transforms,
    download=True,
)
num_train_samples = len(mnist)
print("Number of samples in MNIST training dataset: {}".format(num_train_samples))

# Shuffled loader that yields batches of 64 samples; used to display a batch further down.
mnist_data_loader = DataLoader(mnist, batch_size=64, shuffle=True)

In [None]:
# Draw one batch from the MNIST loader and hand its images to plot_grid.
from utils import plot_grid

batch_iterator = iter(mnist_data_loader)
# Each batch is an (images, labels) pair; the labels are not needed for plotting.
data, _ = next(batch_iterator)
plot_grid(data)

In [None]:
# TODO: run the following

In [None]:
# Run main.py with the 'sinkhorn_gan' model, the 'simple' architecture (hidden_dim=500)
# and a 2-dimensional 'uniform' latent space on MNIST, with --display=True.
# NOTE(review): main.py is not visible from this notebook; presumably it trains the model
# and writes generator weights under models/saved_models/ (later cells load files from
# that directory) — confirm.
%run main.py --model='sinkhorn_gan' --architecture='simple' --display=True --dataset='mnist' --hidden_dim=500 --entropy_regularization=1 --sinkhorn_iterations=10 --latent_dim=2 --latent_space='uniform' --batch_size=200 --learning_rate=1e-4 --generator_steps=3 --checkpoints 10 20 50 100 150

In [None]:
from utils import generate_plot_grid

# Build a generator with the same sizes as the run above (latent_dim=2, hidden_dim=500).
# No weights are loaded in this cell, so generate_plot_grid receives the freshly
# constructed network.
# NOTE(review): `Generator` is not imported anywhere in this notebook; it is presumably
# left in the namespace by `%run main.py` — confirm, or import it explicitly.
generator = Generator(
    input_dim=2,
    hidden_dim=500,
    output_dim=1024,
)
generate_plot_grid(generator)

In [None]:
# `torch` is never imported by this notebook; without this import the cell only works
# when `%run main.py` happens to leave a `torch` name in the interactive namespace.
import torch

# Load the saved generator weights into the generator built in the previous cell and
# pass it to generate_plot_grid again.
generator.load_state_dict(torch.load('models/saved_models/sinkhorn_gan_generator.pth'))
generate_plot_grid(generator)

In [None]:
# What do you observe? Do you see any good output? If not, continue with the following run.

In [None]:
# Second run of the 'simple' Sinkhorn GAN setup. Relative to the first run it adds
# --distance='cosine', passes --display=False and no longer passes --generator_steps.
# NOTE(review): if main.py declares --display with argparse `type=bool`, the string
# 'False' is truthy and display would stay enabled — verify how the flag is parsed.
%run main.py --model='sinkhorn_gan' --architecture='simple' --display=False --dataset='mnist' --hidden_dim=500 --entropy_regularization=1 --sinkhorn_iterations=10 --latent_dim=2 --latent_space='uniform' --distance='cosine' --batch_size=200 --learning_rate=1e-4 --checkpoints 10 20 50 100 150

In [None]:
# Restore the generator from the `cp10epochs` checkpoint file (10 is the first value
# passed to --checkpoints above) and pass it to generate_plot_grid.
checkpoint_path = 'models/saved_models/sinkhorn_gan_generator_cp10epochs.pth'
checkpoint_weights = torch.load(checkpoint_path)
generator.load_state_dict(checkpoint_weights)
generate_plot_grid(generator)

In [None]:
# Same as the previous cell, but using the `cp50epochs` checkpoint file.
checkpoint_path = 'models/saved_models/sinkhorn_gan_generator_cp50epochs.pth'
checkpoint_weights = torch.load(checkpoint_path)
generator.load_state_dict(checkpoint_weights)
generate_plot_grid(generator)

In [None]:
# Load the generator file without a checkpoint suffix and pass the generator to
# generate_plot_grid.
# NOTE(review): presumably this file holds the weights from the end of the latest run — confirm.
final_weights_path = 'models/saved_models/sinkhorn_gan_generator.pth'
final_weights = torch.load(final_weights_path)
generator.load_state_dict(final_weights)
generate_plot_grid(generator)

In [None]:
# TODO: Any success? 

In [None]:
# Now try the convolutional architecture.
# TODO: What do you observe?

In [None]:
# Sinkhorn GAN run switched to the 'conv' architecture with a 50-dimensional 'uniform'
# latent space and the 'cosine' distance; --hidden_dim is no longer passed.
%run main.py --model='sinkhorn_gan' --architecture='conv' --display=False --dataset='mnist' --entropy_regularization=1 --sinkhorn_iterations=10 --latent_dim=50 --latent_space='uniform' --distance='cosine' --batch_size=200 --learning_rate=1e-4 --checkpoints 10 20 50 100 150

In [None]:
# Now let us generate using the following code. 
# TODO: What do you observe? Does it generate anything relevant? Is there mode collapse?

In [None]:
from architectures import ConvGenerator

# Build a convolutional generator matching the conv run above (latent size 50, MNIST mode),
# load the saved Sinkhorn GAN generator weights into it, then pass it to generate_plot_grid.
# NOTE(review): this is the same file path the 'simple' generator was loaded from earlier,
# so the conv run presumably overwrote it — confirm.
generator = ConvGenerator(50, mode='mnist')
conv_weights = torch.load('models/saved_models/sinkhorn_gan_generator.pth')
generator.load_state_dict(conv_weights)
generate_plot_grid(generator, model='sinkhorn_gan', latent_dim=50)

In [None]:
# Now try using a critic (discriminator).
# Does the result improve? Is there mode collapse?

In [None]:
# Same command as the previous conv Sinkhorn GAN run, with --use_critic=True added.
%run main.py --model='sinkhorn_gan' --architecture='conv' --use_critic=True --display=False --dataset='mnist' --entropy_regularization=1 --sinkhorn_iterations=10 --latent_dim=50 --latent_space='uniform' --distance='cosine' --batch_size=200 --learning_rate=1e-4 --checkpoints 10 20 50 100 150

In [None]:
# Now let us try OT-GAN

In [None]:
# Switch the model to 'ot_gan' (conv architecture, critic enabled, 50-dimensional
# 'uniform' latent space). Relative to the previous Sinkhorn GAN run, --distance is no
# longer passed and --generator_steps=3 is.
%run main.py --model='ot_gan' --architecture='conv' --use_critic=True --display=False --dataset='mnist' --entropy_regularization=1 --sinkhorn_iterations=10 --latent_dim=50 --latent_space='uniform' --batch_size=200 --learning_rate=1e-4 --generator_steps=3 --checkpoints 10 20 50 100 150

In [None]:
from architectures import ConvGenerator

# Construct a new convolutional generator for the OT-GAN experiments. No weights are
# loaded in this cell, so generate_plot_grid receives the freshly constructed network;
# the following cells load saved weights into this same object.
ot_gan_latent_dim = 50
generator = ConvGenerator(ot_gan_latent_dim, mode='mnist')
generate_plot_grid(generator, model='ot_gan', latent_dim=ot_gan_latent_dim)

In [None]:
# Now let us generate from saved checkpoint

In [None]:
# Restore the OT-GAN generator from its `cp10epochs` checkpoint file and pass it to
# generate_plot_grid.
checkpoint_path = 'models/saved_models/ot_gan_generator_cp10epochs.pth'
checkpoint_weights = torch.load(checkpoint_path)
generator.load_state_dict(checkpoint_weights)
generate_plot_grid(generator, model='ot_gan', latent_dim=50)

In [None]:
# let us generate again

In [None]:
# Load the OT-GAN generator file without a checkpoint suffix and pass the generator to
# generate_plot_grid.
final_weights_path = 'models/saved_models/ot_gan_generator.pth'
final_weights = torch.load(final_weights_path)
generator.load_state_dict(final_weights)
generate_plot_grid(generator, model='ot_gan', latent_dim=50)

In [None]:
# TODO: Does the model suffer from mode collapse? Are any digits missing? Do you also see strange digit shapes?
# Let us try using a Gaussian latent distribution.

In [None]:
# Same OT-GAN command as the earlier uniform-latent run, with --latent_space='gaussian'.
%run main.py --model='ot_gan' --architecture='conv' --use_critic=True --display=False --dataset='mnist' --entropy_regularization=1 --sinkhorn_iterations=10 --latent_dim=50 --latent_space='gaussian' --batch_size=200 --learning_rate=1e-4 --generator_steps=3 --checkpoints 10 20 50 100 150

In [None]:
# TODO: Now let us generate images from the generator. What do you observe? Are the images better?
# TODO: If yes, why do you think a Gaussian latent distribution works better than a uniform one?

In [None]:
# Reload ot_gan_generator.pth after the Gaussian-latent run and pass the generator to
# generate_plot_grid.
# NOTE(review): generate_plot_grid is called without any latent-space argument; check
# that it samples from the same ('gaussian') latent distribution used for this run.
gaussian_run_weights = torch.load('models/saved_models/ot_gan_generator.pth')
generator.load_state_dict(gaussian_run_weights)
generate_plot_grid(generator, model='ot_gan', latent_dim=50)

In [None]:
# In the figure generated above, do you observe any mode collapse, that is, are there any digits that are missing? If yes, then which ones?

In [None]:
# Let us try changing a few more parameters: epsilon = 0.1 and Sinkhorn iterations L = 100

In [None]:
# Gaussian-latent OT-GAN run with --entropy_regularization=0.1 and
# --sinkhorn_iterations=100 (the earlier runs passed 1 and 10 respectively).
%run main.py --model='ot_gan' --architecture='conv' --use_critic=True --display=False --dataset='mnist' --entropy_regularization=0.1 --sinkhorn_iterations=100 --latent_dim=50 --latent_space='gaussian' --batch_size=200 --learning_rate=1e-4 --generator_steps=3 --checkpoints 10 20 50 100 150

In [None]:
# Did you see any change? Did it also run slower because of the larger number of Sinkhorn iterations?
# TODO: Try the same epsilon, but with L = 10 this time.

In [None]:
# Same as the previous run, with --sinkhorn_iterations set back to 10 while
# --entropy_regularization stays at 0.1.
%run main.py --model='ot_gan' --architecture='conv' --use_critic=True --display=False --dataset='mnist' --entropy_regularization=0.1 --sinkhorn_iterations=10 --latent_dim=50 --latent_space='gaussian' --batch_size=200 --learning_rate=1e-4 --generator_steps=3 --checkpoints 10 20 50 100 150

In [None]:
# TODO: How were the results? Did it improve? Do you still see mode collapse, that is, is there any digit missing? 