In [1]:
## load necessary modules
import argparse
import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

from utils.tools import *
from utils.losses import *
from models.mnist import *

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.datasets as dsets
import torchvision.transforms as transforms

from torch.utils.data import DataLoader

In [2]:
## hyper-parameters
# --- reproducibility ---
random_seed = 2020                     # RNG seed; only has an effect where passed to torch.manual_seed

# --- optimisation schedule ---
epochs = 100                           # outer iterations of the training loop
critic_iter, critic_iter_d = 10, 10    # G/I updates and D updates per outer iteration
batch_size = 500                       # images per DataLoader batch; also the noise batch size
learning_rate = 0.001                  # Adam step size shared by netI, netG and netD
weight_decay = 0.                      # passed only to the critic's (netD) optimizer

# --- latent space / architecture sizes ---
z_dim = 8                              # width of the latent codes drawn with torch.randn
structure_dim = 128
d_dim, g_dim = 512, 512

# --- loss weights ---
lambda_mmd = 10.                       # weight on the MMD penalty in the G/I objective
lambda_gp = 5.                         # weight on the gradient penalty in the critic objective
lambda_power = 1.

# --- misc ---
std = 0.1
# NOTE(review): std, structure_dim, d_dim, g_dim and lambda_power are not referenced
# by any cell in this notebook — confirm whether they are still needed.

In [3]:
## Training setup: seeding, models, optimizers, data and progress logs
# Seed torch (whose global RNG also drives DataLoader shuffling and torch.randn)
# and numpy, so that a Restart & Run All reproduces the same run.
torch.manual_seed(random_seed)
np.random.seed(random_seed)

cuda_available = torch.cuda.is_available()  # check if gpu is available
device = torch.device("cuda" if cuda_available else "cpu")

## initialize models on the selected device
netI = I_MNIST().to(device)  # encoder: maps images to latent codes (netI(x) in the training loop)
netG = G_MNIST().to(device)  # generator: maps latent noise to images
netD = D_MNIST().to(device)  # critic

## set up optimizers (weight decay is applied to the critic only)
optim_I = optim.Adam(netI.parameters(), lr=learning_rate)
optim_G = optim.Adam(netG.parameters(), lr=learning_rate)
optim_D = optim.Adam(netD.parameters(), lr=learning_rate, weight_decay=weight_decay)

## load datasets
# ToTensor yields pixels in [0, 1]; Normalize((0.5,), (0.5,)) maps them to [-1, 1].
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_gen = dsets.MNIST(root="./datasets", train=True, transform=transform, download=True)
test_gen = dsets.MNIST(root="./datasets", train=False, transform=transform, download=True)
# drop_last=True guarantees every batch holds exactly `batch_size` images. The
# default 500 divides the 60000 training images, but other sizes would leave a
# short final batch that breaks fixed-size reshapes.
train_loader = DataLoader(train_gen, batch_size=batch_size, shuffle=True, drop_last=True)

## progress logs (filled by the training loop)
primal_loss_GI = []  # primal(...) after each epoch's G/I phase
dual_loss_GI = []    # dual(...) after each epoch's G/I phase
loss_mmd = []        # MMD penalty values recorded during training
gp = []              # critic gradient penalty, last value of each epoch
re = []              # primal(...) after each epoch's critic phase (NOTE: shadows the stdlib name `re`)

In [None]:
# ************************
# *** iWGANs Algorithm ***
# ************************
# Each outer iteration ("epoch") alternates two phases on a freshly shuffled
# iterator over `train_loader`:
#   1. `critic_iter` updates of encoder netI and generator netG with netD frozen,
#      minimising GI_loss + lambda_mmd * MMD(netI(x), z);
#   2. `critic_iter_d` updates of critic netD with netI / netG frozen,
#      minimising D_loss + lambda_gp * gradient_penalty_dual.
# NOTE: an "epoch" consumes 2 * (critic_iter + critic_iter_d) batches, not a
# full pass over the training set.
device = torch.device("cuda" if cuda_available else "cpu")
out_dir = "./outputs/MNIST"
save_every = 1  # write sample grids every `save_every` epochs

batches_per_epoch = 2 * (critic_iter + critic_iter_d)
assert batches_per_epoch <= len(train_loader), (
    f"each epoch needs {batches_per_epoch} batches but train_loader only yields {len(train_loader)}"
)
os.makedirs(out_dir, exist_ok=True)  # save_image does not create missing directories

# Fixed noise so that the saved sample grids are comparable across epochs.
z_sample = torch.randn(batch_size, z_dim).to(device)


def _next_real(batch_iter):
    """Return (flat, images): the next batch reshaped to (N, 784) on `device`, and the raw batch tensor."""
    images, _labels = next(batch_iter)
    return images.view(images.size(0), 784).to(device), images


def _prior(n):
    """Draw `n` standard-normal latent codes of width `z_dim` (sampled on CPU, then moved to `device`)."""
    return torch.randn(n, z_dim).to(device)


def _set_trainable(net, flag):
    """Set requires_grad on every parameter of `net`, freezing (False) or unfreezing (True) it."""
    for p in net.parameters():
        p.requires_grad = flag


def _to_unit_range(t):
    """Undo Normalize((0.5,), (0.5,)): map values from [-1, 1] back to [0, 1], the range save_image expects."""
    return t * 0.5 + 0.5


for epoch in range(epochs):
    data = iter(train_loader)

    # 1. Update G and I with D frozen.
    _set_trainable(netD, False)
    _set_trainable(netI, True)
    _set_trainable(netG, True)
    for _step in range(critic_iter):
        real_data, _raw = _next_real(data)
        fake_data = netG(_prior(real_data.size(0)))
        netI.zero_grad()
        netG.zero_grad()
        cost_GI = GI_loss(netI, netG, netD, real_data, fake_data)
        # A second, independent batch drives the MMD match between encoded
        # images netI(x) and prior samples z.
        x, _raw = _next_real(data)
        z = _prior(x.size(0))
        z_hat = netI(x)
        mmd = mmd_penalty(z_hat, z, kernel="IMQ")
        primal_cost = cost_GI + lambda_mmd * mmd
        primal_cost.backward()
        optim_I.step()
        optim_G.step()
        # Log each MMD value where it is computed (previously the last value was
        # re-appended critic_iter_d times from inside the critic loop).
        loss_mmd.append(mmd.item())

    # Evaluate once so the printed and logged values are the same number.
    gi_primal = primal(netI, netG, netD, real_data).item()
    print(f"GI: {gi_primal}")
    primal_loss_GI.append(gi_primal)
    dual_loss_GI.append(dual(netI, netG, netD, real_data, fake_data).item())

    # 2. Update D with G and I frozen.
    _set_trainable(netD, True)
    _set_trainable(netI, False)
    _set_trainable(netG, False)
    for _step in range(critic_iter_d):
        real_data, _raw = _next_real(data)
        fake_data = netG(_prior(real_data.size(0)))
        netD.zero_grad()
        cost_D = D_loss(netI, netG, netD, real_data, fake_data)
        x, images = _next_real(data)
        z = _prior(x.size(0))
        gp_D = gradient_penalty_dual(x.detach(), z.detach(), netD, netG, netI)
        dual_cost = cost_D + lambda_gp * gp_D
        dual_cost.backward()
        optim_D.step()

    d_primal = primal(netI, netG, netD, real_data).item()
    print(f"D: {d_primal}")
    gp.append(gp_D.item())
    re.append(d_primal)

    if (epoch + 1) % save_every == 0:
        with torch.no_grad():
            fake_grid = netG(z_sample).view(-1, 1, 28, 28)[:25]
        # NOTE(review): assumes netG outputs on the same [-1, 1] scale as the
        # normalized real data (e.g. a tanh output layer) — TODO confirm.
        torchvision.utils.save_image(_to_unit_range(fake_grid), f"{out_dir}/fake.png", nrow=5)
        torchvision.utils.save_image(_to_unit_range(images[:25]), f"{out_dir}/real.png", nrow=5)

  return F.conv_transpose2d(


GI: 27.129621505737305
D: 27.1368408203125
GI: 26.942691802978516
D: 26.946292877197266
GI: 26.938915252685547
D: 26.93610191345215
GI: 26.950302124023438
D: 26.939496994018555
GI: 26.946680068969727
D: 26.931333541870117
GI: 26.943883895874023
D: 26.937591552734375
GI: 26.936716079711914
D: 26.941329956054688
GI: 26.943450927734375
D: 26.9330997467041
GI: 26.925891876220703
D: 26.945804595947266
GI: 26.927154541015625
D: 26.936641693115234
GI: 26.907766342163086
D: 26.931110382080078
GI: 26.944108963012695
D: 26.918224334716797
GI: 26.942060470581055
D: 26.927263259887695
GI: 26.938873291015625
D: 26.959138870239258
GI: 26.937654495239258
D: 26.92211151123047
GI: 26.923307418823242
D: 26.960264205932617
GI: 26.936552047729492
D: 26.9279727935791
GI: 26.9129581451416
D: 26.9408016204834
GI: 26.937633514404297
D: 26.923248291015625
GI: 26.91982078552246
D: 26.928829193115234
GI: 26.963851928710938
D: 26.93140983581543
GI: 26.926429748535156
D: 26.933063507080078
GI: 26.90867042541504
D:

In [None]:
# print(primal_loss_GI)
# print(dual_loss_GI)
# print(loss_mmd)
# print(gp)
# print(re)