<a href="https://colab.research.google.com/github/KrishnaManeeshaDendukuri/IFT6135_Programming/blob/main/assignment3/q2_solution.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!git clone https://github.com/KrishnaManeeshaDendukuri/IFT6135_Programming.git

Cloning into 'IFT6135_Programming'...
remote: Enumerating objects: 172, done.[K
remote: Counting objects: 100% (172/172), done.[K
remote: Compressing objects: 100% (145/145), done.[K
remote: Total 172 (delta 79), reused 73 (delta 24), pack-reused 0[K
Receiving objects: 100% (172/172), 3.72 MiB | 19.55 MiB/s, done.
Resolving deltas: 100% (79/79), done.


In [2]:
import sys
sys.path.append("/content/IFT6135_Programming/assignment3")

In [7]:
import torch
from q2_sampler import svhn_sampler
from q2_model import Critic, Generator
from torch import optim
from torchvision.utils import save_image


In [8]:

def lp_reg(x, y, critic):
    """
    COMPLETE ME. DONT MODIFY THE PARAMETERS OF THE FUNCTION. Otherwise, tests might fail.

    *** The notation used for the parameters follow the one from Petzka et al: https://arxiv.org/pdf/1709.08894.pdf
    In other word, x are samples from the distribution mu and y are samples from the distribution nu. The critic is the
    equivalent of f in the paper. Also consider that the norm used is the L2 norm. This is important to consider,
    because we make the assumption that your implementation follows this notation when testing your function. ***

    :param x: (FloatTensor) - shape: (batchsize x featuresize) - Samples from a distribution P.
    :param y: (FloatTensor) - shape: (batchsize x featuresize) - Samples from a distribution Q.
    :param critic: (Module) - torch module that you want to regularize.
    :return: (FloatTensor) - shape: (1,) - Lipschitz penalty
    """
    # u_dist = torch.distributions.Uniform(0,1)
    # to be used as loss? 
    batch_size = x.size(2)
    a = torch.FloatTensor(batch_size).uniform_(0, 1)
    
    one_minus_a = (1 - a[:,None]).to(device)
    a = a[:,None].to(device)
    # print(a.get_device())
    # print(one_minus_a.get_device())
    # print(a.shape, one_minus_a.shape)
    z = x*a + y*one_minus_a

    z = torch.autograd.Variable(z,requires_grad=True)

    f_z = critic(z)

    grad_z = torch.autograd.grad(outputs=f_z, inputs=z,
                               grad_outputs=torch.ones(f_z.size()).to(device),
                               create_graph=True, retain_graph=True)[0]

    grad_z = grad_z.view(grad_z.size(0),-1)

    lp = torch.mean(torch.relu(torch.norm(grad_z,p=2,dim=-1, keepdim=True)-1)**2,dim=0)
    return lp

In [9]:
def vf_wasserstein_distance(p, q, critic):
    """
    COMPLETE ME. DONT MODIFY THE PARAMETERS OF THE FUNCTION. Otherwise, tests might fail.

    *** The notation used for the parameters follow the one from Petzka et al: https://arxiv.org/pdf/1709.08894.pdf
    In other word, x are samples from the distribution mu and y are samples from the distribution nu. The critic is the
    equivalent of f in the paper. This is important to consider, because we make the assuption that your implementation
    follows this notation when testing your function. ***

    :param p: (FloatTensor) - shape: (batchsize x featuresize) - Samples from a distribution p.
    :param q: (FloatTensor) - shape: (batchsize x featuresize) - Samples from a distribution q.
    :param critic: (Module) - torch module used to compute the Wasserstein distance
    :return: (FloatTensor) - shape: (1,) - Estimate of the Wasserstein distance
    """
    f_p = critic(p)
    f_q = critic(q)

    wass_dist = torch.mean(f_p, dim=0) - torch.mean(f_q, dim=0)
    return wass_dist


In [10]:
if __name__ == '__main__':
    # Example of usage of the code provided and recommended hyper parameters for training GANs.
    data_root = './'
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    n_iter = 50000 # N training iterations
    n_critic_updates = 5 # N critic updates per generator update
    lp_coeff = 10 # Lipschitz penalty coefficient
    train_batch_size = 64
    test_batch_size = 64
    lr = 1e-4
    beta1 = 0.5
    beta2 = 0.9
    z_dim = 100

    train_loader, valid_loader, test_loader = svhn_sampler(data_root, train_batch_size, test_batch_size)

    generator = Generator(z_dim=z_dim).to(device)
    critic = Critic().to(device)

    optim_critic = optim.Adam(critic.parameters(), lr=lr, betas=(beta1, beta2))
    optim_generator = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, beta2))

    # COMPLETE TRAINING PROCEDURE
    train_iter = iter(train_loader)
    valid_iter = iter(valid_loader)
    test_iter = iter(test_loader)
    for i in range(n_iter):
        generator.train()
        critic.train()
        for _ in range(n_critic_updates):
            try:
                data = next(train_iter)[0].to(device)
            except Exception:
                train_iter = iter(train_loader)
                data = next(train_iter)[0].to(device)
            #####
            # train the critic model here
            #####
            # print(data[0].shape)
            optim_critic.zero_grad()
            generated_data = generator(torch.rand(data.shape[0], z_dim,1,1).to(device))
            
            d_real = critic(data)
            d_generated = critic(generated_data)
            
            # print(d_real.shape)
            # print(d_generated.shape)
            gp = lp_reg(data,generated_data,critic)
            d_loss = vf_wasserstein_distance(data, generated_data, critic) + lp_coeff * gp

            d_loss.backward()
            optim_critic.step()


        #####
        # train the generator model here
        #####
        optim_generator.zero_grad()
        generated_data = generator(torch.rand(data.shape[0], z_dim,1,1).to(device))
        d_generated = critic(generated_data)
        g_loss = d_generated.mean()
        g_loss.backward()
        optim_generator.step()

        # Save sample images 
        if i % 100 == 0:
            z = torch.randn(64, z_dim, device=device)
            imgs = generator(z)
            save_image(imgs, f'imgs_{i}.png', normalize=True, value_range=(-1, 1))


    # COMPLETE QUALITATIVE EVALUATION


Downloading http://ufldl.stanford.edu/housenumbers/train_32x32.mat to ./train_32x32.mat


  0%|          | 0/182040794 [00:00<?, ?it/s]

Using downloaded and verified file: ./train_32x32.mat
Downloading http://ufldl.stanford.edu/housenumbers/test_32x32.mat to ./test_32x32.mat


  0%|          | 0/64275384 [00:00<?, ?it/s]

  cpuset_checked))


KeyboardInterrupt: ignored

In [13]:
d_real[0]

tensor(1.0514, device='cuda:0', grad_fn=<SelectBackward0>)

In [14]:
perturb_val = 50
for i in range(z_dim):
  z_perturb = torch.randn(64, z_dim, device=device)
  print(z_perturb.shape)
  z_perturb[:, i] += perturb_val
  imgs = generator(z_perturb)
  save_image(imgs, f'disentangled_{i}.png', normalize=True, value_range=(-1, 1))


torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size([64, 100])
torch.Size

In [15]:
torch.save(generator, 'gan_model.pt')

In [12]:
from google.colab import files
files.download("gan_model.pt")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [16]:
gen = torch.load('gan_model.pt')

In [18]:
torch.save(critic, 'gan_critic.pt')
crit = torch.load('gan_critic.pt')

# Perturbation

In [17]:
perturb_val = 50
for i in range(72,73):
  z_perturb = torch.randn(64, z_dim, device=device)
  print(z_perturb.shape)
  z_perturb[:, i] += perturb_val
  imgs = gen(z_perturb)
  save_image(imgs, f'zz_test_{i}.png', normalize=True, value_range=(-1, 1))


torch.Size([64, 100])


# Latent and Data Space Interpolations

In [54]:
import numpy as np
alpha_set = np.round(np.linspace(0,1,10), 1)
alpha_set = torch.from_numpy(alpha_set)
alpha_set

tensor([0.0000, 0.1000, 0.2000, 0.3000, 0.4000, 0.6000, 0.7000, 0.8000, 0.9000,
        1.0000], dtype=torch.float64)

In [55]:
torch.manual_seed(42)
 
z_0 = torch.randn(64, z_dim, device=device)
z_1 = torch.randn(64, z_dim, device=device)

In [56]:
for alpha in alpha_set:
  z_alpha = alpha*z_0 + (1-alpha)*z_1
  imgs = gen(z_alpha)

  save_image(imgs, f'data_space_{int(alpha*10)}.png', normalize=True, value_range=(-1, 1))
  

In [57]:
x_0 = gen(z_0)
x_1 = gen(z_1)

In [58]:
for alpha in alpha_set:
  imgs = alpha*x_0 + (1-alpha)*x_1
  save_image(imgs, f'latent_space_{int(alpha*10)}.png', normalize=True, value_range=(-1, 1))

In [None]:
!zip -r gan_results.zip /content/sample_data/

  adding: content/sample_data/ (stored 0%)
  adding: content/sample_data/anscombe.json (deflated 83%)
  adding: content/sample_data/README.md (deflated 42%)
  adding: content/sample_data/mnist_test.csv (deflated 88%)
  adding: content/sample_data/mnist_train_small.csv (deflated 88%)
  adding: content/sample_data/california_housing_train.csv (deflated 79%)
  adding: content/sample_data/california_housing_test.csv (deflated 76%)


In [None]:
/content/disentangled_0.png

4

In [59]:
!find /content -name 'latent_space_*' | zip latent_space.zip -@

updating: content/latent_space_8.png (deflated 0%)
updating: content/latent_space_1.png (deflated 0%)
updating: content/latent_space_9.png (deflated 0%)
updating: content/latent_space_7.png (deflated 0%)
updating: content/latent_space_2.png (deflated 0%)
updating: content/latent_space_0.png (deflated 0%)
updating: content/latent_space_3.png (deflated 0%)
updating: content/latent_space_4.png (deflated 0%)
updating: content/latent_space_10.png (deflated 0%)
updating: content/latent_space_6.png (deflated 0%)


In [48]:
!find /content -name 'data_space_*' | zip data_space.zip -@

  adding: content/data_space_6.png (deflated 0%)
  adding: content/data_space_4.png (deflated 0%)
  adding: content/data_space_10.png (deflated 0%)
  adding: content/data_space_0.png (deflated 0%)
  adding: content/data_space_1.png (deflated 0%)
  adding: content/data_space_2.png (deflated 0%)
  adding: content/data_space_8.png (deflated 0%)
  adding: content/data_space_7.png (deflated 0%)
  adding: content/data_space_3.png (deflated 0%)
  adding: content/data_space_9.png (deflated 0%)


In [60]:
!find /content -name 'disentangled_*' | zip disentangled.zip -@

  adding: content/disentangled_95.png (deflated 0%)
  adding: content/disentangled_19.png (deflated 0%)
  adding: content/disentangled_65.png (deflated 0%)
  adding: content/disentangled_44.png (deflated 0%)
  adding: content/disentangled_52.png (deflated 0%)
  adding: content/disentangled_29.png (deflated 0%)
  adding: content/disentangled_56.png (deflated 0%)
  adding: content/disentangled_62.png (deflated 0%)
  adding: content/disentangled_28.png (deflated 0%)
  adding: content/disentangled_99.png (deflated 0%)
  adding: content/disentangled_53.png (deflated 0%)
  adding: content/disentangled_35.png (deflated 0%)
  adding: content/disentangled_6.png (deflated 0%)
  adding: content/disentangled_77.png (deflated 0%)
  adding: content/disentangled_70.png (deflated 0%)
  adding: content/disentangled_3.png (deflated 0%)
  adding: content/disentangled_83.png (deflated 0%)
  adding: content/disentangled_50.png (deflated 0%)
  adding: content/disentangled_45.png (deflated 0%)
  adding: cont

In [None]:
!find /content -name 'imgs_*' | zip training_gan.zip -@

  adding: content/imgs_12000.png (deflated 0%)
  adding: content/imgs_35000.png (deflated 0%)
  adding: content/imgs_3000.png (deflated 0%)
  adding: content/imgs_13200.png (deflated 0%)
  adding: content/imgs_200.png (deflated 0%)
  adding: content/imgs_27100.png (deflated 0%)
  adding: content/imgs_600.png (deflated 0%)
  adding: content/imgs_17400.png (deflated 0%)
  adding: content/imgs_25900.png (deflated 0%)
  adding: content/imgs_7500.png (deflated 0%)
  adding: content/imgs_19400.png (deflated 0%)
  adding: content/imgs_5800.png (deflated 0%)
  adding: content/imgs_17100.png (deflated 0%)
  adding: content/imgs_5600.png (deflated 0%)
  adding: content/imgs_36800.png (deflated 0%)
  adding: content/imgs_2300.png (deflated 0%)
  adding: content/imgs_18300.png (deflated 0%)
  adding: content/imgs_19600.png (deflated 0%)
  adding: content/imgs_1900.png (deflated 0%)
  adding: content/imgs_36300.png (deflated 0%)
  adding: content/imgs_5200.png (deflated 0%)
  adding: content/imgs_2