In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch import autograd
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import numpy as np
import h5py
import matplotlib.pyplot as plt
import gc

from adv_models import ConvAE, DiscriminateModel
from data import PCam_Dataset_local

In [2]:
def calc_gradient_penalty(netD, real_data, fake_data, batch_size, dtype, use_cuda=True, gpu=0, LAMBDA=10):
    """Compute the WGAN-GP gradient penalty of a critic on random interpolates.

    For every sample a coefficient ``alpha ~ U[0, 1)`` is drawn and the critic
    is evaluated at ``alpha * real + (1 - alpha) * fake``. The penalty is
    ``LAMBDA * mean((||grad_x D(x)||_2 - 1) ** 2)`` over the batch.

    Args:
        netD: Callable critic; receives the interpolated batch.
        real_data: Tensor of real samples whose first dimension is the batch.
            Any number of trailing dimensions is accepted.
        fake_data: Tensor of generated samples, same shape as ``real_data``.
        batch_size: Number of samples in ``real_data`` / ``fake_data``.
        dtype: Tensor type passed to ``Tensor.type`` (e.g.
            ``torch.cuda.FloatTensor``) for the mixing coefficients.
        use_cuda: When true the interpolates are also cast to ``dtype``.
        gpu: Unused; kept for backward compatibility with existing callers.
        LAMBDA: Multiplier applied to the penalty.

    Returns:
        A 0-dim tensor holding the penalty. It is built with
        ``create_graph=True`` so it can be back-propagated into the critic.
    """
    # One coefficient per sample, shaped (batch, 1, ..., 1) so it broadcasts
    # over every non-batch dimension instead of assuming 3x48x48 images.
    alpha_shape = (batch_size,) + (1,) * (real_data.dim() - 1)
    alpha = torch.rand(batch_size, 1).view(alpha_shape).type(dtype)

    interpolates = alpha * real_data + ((1 - alpha) * fake_data)
    if use_cuda:
        interpolates = interpolates.type(dtype)

    # Make the interpolates a fresh leaf so the gradient is taken with respect
    # to the critic input only, not to whatever produced real/fake data.
    interpolates = interpolates.detach().requires_grad_(True)

    disc_interpolates = netD(interpolates)

    # ones_like keeps grad_outputs on the same device/dtype as the critic output.
    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones_like(disc_interpolates),
                              create_graph=True, retain_graph=True, only_inputs=True)[0]
    gradients = gradients.reshape(gradients.size(0), -1)

    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
    return gradient_penalty

In [3]:
def _set_requires_grad(net, flag):
    """Enable or disable gradient tracking for every parameter of ``net``."""
    for param in net.parameters():
        param.requires_grad = flag


def adv_training(G_net, D_net, D_solver, G_solver, N_critic, weight, dtype, part=0, batch_size=32, num_epochs=20, print_every=100):
    """Train a generator/critic pair with a WGAN-GP plus reconstruction objective.

    Batches ``(x, y)`` come from ``PCam_Dataset_local('data.h5', ...)``; ``x``
    is fed to the generator and ``y`` is used both as the critic's real sample
    and as the reconstruction target. Within an epoch, ``N_critic`` consecutive
    batches update the critic, the next batch updates the generator, and the
    cycle repeats. After every epoch both networks and both optimizers are
    checkpointed to ``'chkpntWES.pt'``.

    Args:
        G_net: Generator module.
        D_net: Critic module.
        D_solver: Optimizer over ``D_net`` parameters.
        G_solver: Optimizer over ``G_net`` parameters.
        N_critic: Number of critic batches between two generator batches.
        weight: Weight of the adversarial term; the generator minimises
            ``(1 - weight) * MSE + weight * (-D(G(x)).mean())``.
        dtype: Tensor type used to cast every batch (e.g.
            ``torch.cuda.FloatTensor``).
        part: Currently not forwarded to the dataset (see note in the body).
        batch_size: Batch size for the ``DataLoader``.
        num_epochs: Number of passes over the dataset.
        print_every: Log statistics every this many iterations.

    Returns:
        The tuple ``(G_net, D_net)`` after training.
    """
    G_net.train()
    D_net.train()

    MSELoss = nn.MSELoss()

    # Logged statistics are always detached 0-dim tensors: `.item()` is valid
    # even before the first generator step, and no autograd graph is kept
    # alive between iterations.
    Wasserstein_D = D_cost = G_cost = rec_loss = torch.zeros(())

    # NOTE(review): `part` is accepted but the dataset is always opened with
    # part=-1; confirm whether the argument should be forwarded.
    pcam = PCam_Dataset_local('data.h5', part=-1, download=False, train=True)
    loader_train = DataLoader(pcam, batch_size=batch_size, shuffle=True)

    _set_requires_grad(D_net, True)

    for epoch in range(num_epochs):

        print('Starting epoch %d / %d' % (epoch, num_epochs))
        critic_iters = 0

        for iter_, (x, y) in enumerate(loader_train):

            working_size = x.shape[0]
            cropped_images = x.type(dtype)
            real_images = y.type(dtype)

            if critic_iters < N_critic:
                # ---- critic step: minimise D(fake) - D(real) + penalty ----
                D_net.zero_grad()

                D_real = D_net(real_images).mean()
                (-D_real).backward()

                # The generator is not trained here, so skip building its graph.
                with torch.no_grad():
                    fake_images = G_net(cropped_images)
                D_fake = D_net(fake_images).mean()
                D_fake.backward()

                gradient_penalty = calc_gradient_penalty(D_net, real_images, fake_images, working_size, dtype, use_cuda=True, gpu=0)
                gradient_penalty.backward()
                D_solver.step()

                # D_cost mirrors the loss actually back-propagated above.
                D_cost = (D_fake - D_real + gradient_penalty).detach()
                Wasserstein_D = (D_real - D_fake).detach()

                critic_iters += 1

            else:
                # ---- generator step: critic weights are frozen ----
                _set_requires_grad(D_net, False)

                G_net.zero_grad()
                fake_imgs = G_net(cropped_images)
                adv_loss = -D_net(fake_imgs).mean()
                mse_loss = MSELoss(fake_imgs, real_images)
                total_loss = (1 - weight) * mse_loss + weight * adv_loss

                total_loss.backward()
                # Step the generator's optimizer; stepping D_solver here would
                # leave the generator untrained.
                G_solver.step()

                # Leave the critic trainable for the next step and for callers.
                _set_requires_grad(D_net, True)

                G_cost = adv_loss.detach()
                rec_loss = mse_loss.detach()
                critic_iters = 0

            if (iter_ % print_every == 0) and iter_ > 0:
                print('Iter = {0}, Wasserstein_D = {1}, D_cost = {2}, rec_loss = {3}, G_cost = {4}'
                      .format(iter_, Wasserstein_D.item(), D_cost.item(), rec_loss.item(), G_cost.item()))

        torch.save({
            'G': G_net.state_dict(),
            'G_opt': G_solver.state_dict(),
            'D': D_net.state_dict(),
            'D_opt': D_solver.state_dict()
        }, 'chkpntWES.pt')

    return G_net, D_net

In [4]:
def run_adv(adv_rec_weight, batch_size, chkpnt_file=None, n_critic=4, num_epochs=50, print_every=100, require_cuda=True):
    """Build the generator/critic pair, optionally restore a checkpoint, and train.

    Args:
        adv_rec_weight: Weight of the adversarial term, forwarded to
            ``adv_training`` as ``weight``.
        batch_size: Batch size forwarded to ``adv_training``.
        chkpnt_file: Optional path to a checkpoint holding the keys ``'G'``,
            ``'D'``, ``'G_opt'`` and ``'D_opt'``; when given, model and
            optimizer states are restored before training.
        n_critic: Critic batches per generator batch (previously fixed at 4).
        num_epochs: Number of training epochs (previously fixed at 50).
        print_every: Logging interval in iterations (previously fixed at 100).
        require_cuda: When true (the default, matching the previous
            behaviour) training refuses to start without CUDA. Pass ``False``
            to fall back to CPU tensors.

    Returns:
        The trained ``(generator, critic)`` pair returned by ``adv_training``.

    Raises:
        AssertionError: If ``require_cuda`` is true and CUDA is not available.
    """
    cuda_available = torch.cuda.is_available()
    if require_cuda:
        assert cuda_available, "CUDA is not available"

    gpu_dtype = torch.cuda.FloatTensor if cuda_available else torch.FloatTensor

    G_net = ConvAE(3, 3).type(gpu_dtype)
    D_net = DiscriminateModel(3, 1).type(gpu_dtype)

    optimizerG = optim.Adam(G_net.parameters(), lr=1e-3)
    optimizerD = optim.Adam(D_net.parameters(), lr=1e-3)

    if chkpnt_file is not None:
        # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
        # Loading onto the CPU first lets load_state_dict place the tensors on
        # whatever device the freshly built modules live on.
        saved_dict = torch.load(chkpnt_file, map_location='cpu')

        G_net.load_state_dict(saved_dict['G'])
        D_net.load_state_dict(saved_dict['D'])

        optimizerG.load_state_dict(saved_dict['G_opt'])
        optimizerD.load_state_dict(saved_dict['D_opt'])

    gen, dic = adv_training(G_net, D_net, optimizerD, optimizerG, n_critic, adv_rec_weight, gpu_dtype,
                            part=0, batch_size=batch_size, num_epochs=num_epochs, print_every=print_every)

    return gen, dic

In [5]:
# Train with an adversarial-term weight of 0.001 and a batch size of 32.
# Model and optimizer state are restored from 'chkpntWES.pt' before training,
# which is the same filename adv_training saves to after every epoch, so this
# call presumably resumes an earlier run and needs that file to exist already.
g, d = run_adv(0.001, 32, chkpnt_file='chkpntWES.pt')

AssertionError: CUDA is not available