In [None]:
import torch
from q2_sampler import svhn_sampler
from q2_model import Critic, Generator
from torch import optim
from torchvision.utils import save_image
from torch.autograd import Variable



def lp_reg(x, y, critic):
    """
    Lipschitz penalty (WGAN-LP) of Petzka et al.: https://arxiv.org/pdf/1709.08894.pdf

    The signature must not change: external tests call this function with exactly these
    parameters. Notation follows the paper: ``x`` are samples from the distribution mu,
    ``y`` are samples from the distribution nu, and ``critic`` plays the role of f.

    Points are drawn uniformly on the segments between paired samples,
    x_hat = t * x + (1 - t) * y with t ~ U[0, 1] drawn once per sample. The returned value is
    the one-sided penalty

        mean_i( max(0, ||grad f(x_hat_i)||_2 - 1) ** 2 )

    where the L2 norm of each sample's gradient is taken over all of its non-batch dimensions.

    :param x: (FloatTensor) - shape: (batchsize x featuresize), or (batchsize x ...) such as
        images - Samples from a distribution P.
    :param y: (FloatTensor) - same shape as ``x`` - Samples from a distribution Q.
    :param critic: (Module) - torch module (any callable on tensors) that you want to regularize.
    :return: (FloatTensor) - 0-dim tensor - Lipschitz penalty, differentiable w.r.t. the
        critic's parameters.
    """
    batch_size = x.size(0)
    # One coefficient per sample, shaped (batch, 1, ..., 1) so it broadcasts over every
    # non-batch dimension whatever the input rank (2-D features or 4-D images). A fixed
    # (B, 1, 1, 1) shape would broadcast a 2-D input to (B, B, B, F).
    t = torch.rand(batch_size, *([1] * (x.dim() - 1)), device=x.device)
    interpolation = t * x + (1 - t) * y
    # Callers typically pass detached samples; the interpolates must be a gradient target.
    interpolation.requires_grad_(True)

    critic_out = critic(interpolation)
    gradient = torch.autograd.grad(
        outputs=critic_out,
        inputs=interpolation,
        grad_outputs=torch.ones_like(critic_out),
        # Keep the graph of the gradient itself so the penalty can be backpropagated
        # into the critic's parameters (retain_graph defaults to create_graph).
        create_graph=True,
    )[0]

    grad_norm = gradient.reshape(batch_size, -1).norm(2, dim=1)
    # One-sided: only gradient norms above 1 are penalized. clamp keeps the result on the
    # input's device and avoids broadcasting against an extra (B, 1) tensor.
    return torch.clamp(grad_norm - 1, min=0).pow(2).mean()



def vf_wasserstein_distance(p, q, critic):
    """
    Value-function estimate of the Wasserstein-1 distance between the distributions of p and q.

    The signature must not change: external tests call this function with exactly these
    parameters. Notation follows Petzka et al.: https://arxiv.org/pdf/1709.08894.pdf
    ``p`` plays the role of the samples x from mu, ``q`` the samples y from nu, and ``critic``
    is the function f. The estimate is mean(f(p)) - mean(f(q)).

    :param p: (FloatTensor) - shape: (batchsize x featuresize) - Samples from a distribution p.
    :param q: (FloatTensor) - shape: (batchsize x featuresize) - Samples from a distribution q.
    :param critic: (Module) - torch module used to compute the Wasserstein distance
    :return: (FloatTensor) - 0-dim tensor - Estimate of the Wasserstein distance
    """
    # Score p before q so a stateful critic (e.g. batch norm) sees the same call order.
    mean_score_p = critic(p).mean()
    mean_score_q = critic(q).mean()
    return mean_score_p - mean_score_q

In [None]:
# Example of usage of the code provided and recommended hyper parameters for training GANs.
data_root = './'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
seed = 0  # fixed RNG seed so weight init, noise and data shuffling are reproducible
n_iter = 50000 # N training iterations
n_critic_updates = 5 # N critic updates per generator update
lp_coeff = 10 # Lipschitz penalty coefficient
train_batch_size = 64
test_batch_size = 64
lr = 1e-4
beta1 = 0.5
beta2 = 0.9
z_dim = 100

# Seed before building loaders and models so every random draw below is deterministic
# (torch.manual_seed also seeds all CUDA devices).
torch.manual_seed(seed)

train_loader, valid_loader, test_loader = svhn_sampler(data_root, train_batch_size, test_batch_size)

generator = Generator(z_dim=z_dim).to(device)
critic = Critic().to(device)

optim_critic = optim.Adam(critic.parameters(), lr=lr, betas=(beta1, beta2))
optim_generator = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, beta2))

In [None]:
# COMPLETE TRAINING PROCEDURE
# WGAN-LP training: n_critic_updates critic steps per generator step.
# Critic minimises -(E[f(real)] - E[f(fake)]) + lp_coeff * LP; generator minimises -E[f(fake)].
train_iter = iter(train_loader)
valid_iter = iter(valid_loader)
test_iter = iter(test_loader)
for i in range(n_iter):
    generator.train()
    critic.train()
    for _ in range(n_critic_updates):
        try:
            real = next(train_iter)[0]
        except StopIteration:
            # Epoch exhausted: restart the loader and keep training.
            train_iter = iter(train_loader)
            real = next(train_iter)[0]
        real = real.to(device=device, dtype=torch.float32)

        #####
        # train the critic model here
        # The last batch of an epoch may be smaller than train_batch_size; lp_reg pairs
        # real and fake samples one-to-one, so size the fake batch from the real one.
        # No generator gradients are needed during the critic step.
        with torch.no_grad():
            fake = generator(torch.randn(real.size(0), z_dim, device=device))

        # Clear gradients left over from the previous critic step and from the
        # generator step (whose backward also reaches the critic's parameters).
        optim_critic.zero_grad()
        critic_loss = (-vf_wasserstein_distance(real, fake, critic)
                       + lp_coeff * lp_reg(real, fake, critic))
        critic_loss.backward()
        optim_critic.step()
        #####

    #####
    # train the generator model here
    optim_generator.zero_grad()
    fake = generator(torch.randn(train_batch_size, z_dim, device=device))
    # Minimising -E[f(fake)] pushes generated samples towards higher critic scores.
    generator_loss = -critic(fake).mean()
    generator_loss.backward()
    optim_generator.step()
    #####

    # Save sample images
    if i % 100 == 0:
        with torch.no_grad():
            z = torch.randn(64, z_dim, device=device)
            imgs = generator(z)
        save_image(imgs, f'imgs_{i}.png', normalize=True, value_range=(-1, 1))


# COMPLETE QUALITATIVE EVALUATION


Downloading http://ufldl.stanford.edu/housenumbers/train_32x32.mat to ./train_32x32.mat


  0%|          | 0/182040794 [00:00<?, ?it/s]

Using downloaded and verified file: ./train_32x32.mat
Downloading http://ufldl.stanford.edu/housenumbers/test_32x32.mat to ./test_32x32.mat


  0%|          | 0/64275384 [00:00<?, ?it/s]