In [1]:
## load necessary modules
import argparse
import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

from utils.tools import *
from utils.losses import *
from models.mnist import *

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.datasets as dsets
import torchvision.transforms as transforms

from torch.utils.data import DataLoader

In [2]:
## hyper-parameters
# Seed used by the setup cell for torch / numpy RNGs.
random_seed = 2020

# Schedule: each epoch runs `critic_iter` G/I updates, then `critic_iter_d` D updates.
critic_iter, critic_iter_d = 10, 10
epochs = 100
batch_size = 250

# Optimiser settings (the setup cell gives D ten times this learning rate).
learning_rate = 1e-4
weight_decay = 0.0

# Latent size used for noise / prior draws; the remaining widths are
# presumably consumed by the model definitions (not referenced in this notebook).
z_dim = 8
structure_dim = 128
d_dim, g_dim = 512, 512

# Loss weights: MMD term for I, gradient penalty for D, and lambda_power
# (not referenced in the cells shown here).
lambda_mmd, lambda_gp, lambda_power = 10.0, 5.0, 1.0

# Standard deviation setting (not referenced in the cells shown here).
std = 0.1

In [3]:
## Training setup: seeding, models, optimizers, data, progress logs
# Seed before any model is built so weight init, noise draws and DataLoader
# shuffling repeat across runs (cuDNN kernels may still vary on GPU).
torch.manual_seed(random_seed)
np.random.seed(random_seed)
cuda_available = torch.cuda.is_available()  # check if gpu is available

## initialize models (moved to the GPU when one is available)
netI = I_MNIST()
netI = netI.cuda() if cuda_available else netI
netG = G_MNIST()
netG = netG.cuda() if cuda_available else netG
netD = D_MNIST()
netD = netD.cuda() if cuda_available else netD

## set up optimizers; D trains with a 10x larger learning rate than G and I
optim_I = optim.Adam(netI.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optim_G = optim.Adam(netG.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optim_D = optim.Adam(netD.parameters(), lr=learning_rate * 10, betas=(0.5, 0.999),
                     weight_decay=weight_decay)

## load datasets
# ToTensor gives pixels in [0, 1]; Normalize((0.5,), (0.5,)) maps them to [-1, 1].
transform    = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_gen    = dsets.MNIST(root="./datasets", train=True, transform=transform, download=True)
test_gen     = dsets.MNIST(root="./datasets", train=False, transform=transform, download=True)
# drop_last=True: the training loop reshapes every batch to exactly
# (batch_size, 784), which would fail on a shorter final batch.
train_loader = DataLoader(train_gen, batch_size=batch_size, shuffle=True, drop_last=True)

## empty lists that accumulate training progress (filled by the training loop)
primal_loss_GI = []
dual_loss_GI = []
loss_mmd = []
gp = []
re = []

In [4]:
# ************************
# *** iWGANs Algorithm ***
# ************************
# Each "epoch" below is one round of `critic_iter` G/I updates followed by
# `critic_iter_d` D updates. Every update draws two mini-batches from a fresh
# iterator over `train_loader`, so an epoch consumes
# 2 * (critic_iter + critic_iter_d) batches rather than a full pass over MNIST.
batches_per_epoch = 2 * (critic_iter + critic_iter_d)
assert batches_per_epoch <= len(train_loader), (
    "train_loader yields %d batches but one epoch needs %d"
    % (len(train_loader), batches_per_epoch))

save_every = 1                          # write sample grids every `save_every` epochs
output_dir = "./outputs/MNIST"
os.makedirs(output_dir, exist_ok=True)  # save_image does not create missing folders


def _to_device(t):
    """Move `t` to the GPU when one is available, otherwise return it unchanged."""
    return t.cuda() if cuda_available else t


def _set_requires_grad(net, flag):
    """Unfreeze (flag=True) or freeze (flag=False) every parameter of `net`."""
    for p in net.parameters():
        p.requires_grad = flag


def _to_unit_range(t):
    """Invert Normalize((0.5,), (0.5,)): map [-1, 1] back to the [0, 1] range save_image expects."""
    return t.mul(0.5).add(0.5).clamp(0.0, 1.0)


# Fixed latent codes so the saved fake grids are comparable across epochs.
z_sample = _to_device(torch.randn(batch_size, z_dim))

for epoch in range(epochs):
    data = iter(train_loader)

    # 1. Update G and I while D is frozen (no gradients accumulate on D).
    _set_requires_grad(netD, False)
    _set_requires_grad(netI, True)
    _set_requires_grad(netG, True)
    for _ in range(critic_iter):
        images, _labels = next(data)
        real_data = _to_device(images.view(batch_size, 784))
        fake_data = netG(_to_device(torch.randn(batch_size, z_dim)))
        netI.zero_grad()
        netG.zero_grad()
        cost_GI = GI_loss(netI, netG, netD, real_data, fake_data)
        # MMD penalty (IMQ kernel) between I(x) on a second batch and
        # standard-normal draws z.
        images, _labels = next(data)
        x = _to_device(images.view(batch_size, 784))
        z = _to_device(torch.randn(batch_size, z_dim))
        z_hat = netI(x)
        mmd = mmd_penalty(z_hat, z, kernel="IMQ")
        primal_cost = cost_GI + lambda_mmd * mmd
        primal_cost.backward()
        optim_I.step()
        optim_G.step()
        loss_mmd.append(mmd.item())  # MMD of this G/I step
    # Record losses after the G/I updates; primal is evaluated once and reused.
    primal_gi = primal(netI, netG, netD, real_data).cpu().item()
    print('GI: ' + str(primal_gi))
    primal_loss_GI.append(primal_gi)
    dual_loss_GI.append(dual(netI, netG, netD, real_data, fake_data).cpu().item())

    # 2. Update D while G and I are frozen.
    _set_requires_grad(netD, True)
    _set_requires_grad(netI, False)
    _set_requires_grad(netG, False)
    for _ in range(critic_iter_d):
        images, _labels = next(data)
        real_data = _to_device(images.view(batch_size, 784))
        fake_data = netG(_to_device(torch.randn(batch_size, z_dim)))
        netD.zero_grad()
        cost_D = D_loss(netI, netG, netD, real_data, fake_data)
        images, _labels = next(data)
        x = _to_device(images.view(batch_size, 784))
        z = _to_device(torch.randn(batch_size, z_dim))
        gp_D = gradient_penalty_dual(x.detach(), z.detach(), netD, netG, netI)
        dual_cost = cost_D + lambda_gp * gp_D
        dual_cost.backward()
        optim_D.step()
    primal_d = primal(netI, netG, netD, real_data).cpu().item()
    print('D: ' + str(primal_d))
    gp.append(gp_D.cpu().item())
    re.append(primal_d)

    if (epoch + 1) % save_every == 0:
        with torch.no_grad():
            fake_grid = netG(z_sample).view(-1, 1, 28, 28)[:25]
        # Assumes G emits samples in the same normalized [-1, 1] space as the
        # real images -- TODO confirm against G_MNIST's output activation.
        torchvision.utils.save_image(
            _to_unit_range(fake_grid),
            os.path.join(output_dir, "fake" + str(epoch) + ".png"), nrow=5)
        # `images` is the last raw (unflattened) real batch drawn above.
        torchvision.utils.save_image(
            _to_unit_range(images[:25]),
            os.path.join(output_dir, "real.png"), nrow=5)

  return F.conv_transpose2d(


GI: 31.49949836730957
D: 31.500247955322266
GI: 28.112903594970703
D: 28.106735229492188
GI: 27.7220516204834
D: 27.701404571533203
GI: 27.39789581298828
D: 27.427082061767578
GI: 27.248533248901367
D: 27.222490310668945
GI: 27.159019470214844
D: 27.18914222717285
GI: 27.106067657470703
D: 27.09993553161621
GI: 27.051029205322266
D: 27.068084716796875
GI: 27.05169105529785
D: 27.028879165649414
GI: 27.01833724975586
D: 27.010408401489258
GI: 27.020153045654297
D: 26.967164993286133
GI: 26.983808517456055
D: 27.015954971313477
GI: 26.978130340576172
D: 27.003190994262695
GI: 26.983919143676758
D: 26.974637985229492
GI: 26.962158203125
D: 26.969371795654297
GI: 26.949581146240234
D: 26.975534439086914
GI: 26.93528175354004
D: 26.939306259155273
GI: 26.934711456298828
D: 26.937637329101562
GI: 26.948707580566406
D: 26.923080444335938
GI: 26.954662322998047
D: 26.93804359436035
GI: 26.946836471557617
D: 26.92667007446289
GI: 26.93946075439453
D: 26.973037719726562
GI: 26.922712326049805
D:

In [5]:
# print(primal_loss_GI)
# print(dual_loss_GI)
# print(loss_mmd)
# print(gp)
# print(re)