In [1]:
## load necessary modules
import argparse
import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

from utils.tools import *
from utils.losses import *
from models.mnist_my_models import *

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.datasets as dsets
import torchvision.transforms as transforms

from torch.utils.data import DataLoader

In [2]:
## hyper-parameters
# Every value below is read by the setup / training cells that follow; names must
# stay exactly as they are.

# --- reproducibility ---
random_seed = 2020

# --- training schedule ---
# `epochs` counts outer G/I -> D alternations, not full passes over the dataset.
epochs = 30
critic_iter, critic_iter_d = 3, 10   # G/I updates, then D updates, per outer iteration
batch_size = 250

# --- optimisation ---
learning_rate = 2e-4                 # shared by the I, G and D Adam optimizers
weight_decay = 0.0                   # passed to the discriminator's optimizer only

# --- architecture ---
nc = 1                               # presumably image channels (MNIST is single-channel)
z_dim = 5                            # latent size: G input and I output
structure_dim = 128
d_dim, g_dim = 512, 512

# --- loss weights / misc ---
lambda_mmd = 10.0                    # weight of the MMD term in the G/I objective
lambda_gp = 5.0                      # weight of the gradient penalty in the D objective
lambda_power = 1.0
std = 1.0

In [3]:
## Training setup
# Seed every RNG the notebook uses so a Restart & Run All reproduces the run.
# torch.manual_seed seeds the CPU and all CUDA devices. cuDNN kernels may still
# be non-deterministic on GPU.
torch.manual_seed(random_seed)
np.random.seed(random_seed)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # use the GPU when one is available

## initialize models
netI = I_MNIST(num_classes=z_dim).to(device)   # I maps an image to a z_dim latent code
netG = G_MNIST(nz=z_dim).to(device)            # G maps a z_dim latent code to an image
netD = D_MNIST().to(device)

## set up optimizers
optim_I = optim.Adam(netI.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optim_G = optim.Adam(netG.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optim_D = optim.Adam(netD.parameters(), lr=learning_rate, betas=(0.5, 0.999),
                     weight_decay=weight_decay)

## load datasets
# Images are scaled to [-1, 1] (mean 0.5, std 0.5 on the [0, 1] ToTensor output).
transform    = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_gen    = dsets.MNIST(root="./datasets", train=True, transform=transform, download=True)
test_gen     = dsets.MNIST(root="./datasets", train=False, transform=transform, download=True)
# drop_last=True: the training loop pairs each real batch with `batch_size` noise
# vectors, so every batch must be full even when batch_size does not divide 60000.
train_loader = DataLoader(train_gen, batch_size=batch_size, shuffle=True, drop_last=True)

## empty lists that record training progress (one entry per outer iteration)
primal_loss_GI = []
dual_loss_GI = []
loss_mmd = []
gp = []
re = []

In [4]:
# ************************
# *** iWGANs Algorithm ***
# ************************
# Alternating training. Each outer iteration ("epoch") does two things:
#   1. runs `critic_iter` joint updates of G and I with D frozen;
#   2. runs `critic_iter_d` updates of D with G and I frozen.
# Every inner step draws two real mini-batches, so one "epoch" here is
# 2 * (critic_iter + critic_iter_d) batches, not a full pass over MNIST.

output_dir = "./outputs/MNIST"
# save_image does not create parent directories; make sure the target exists.
os.makedirs(output_dir, exist_ok=True)


def _set_requires_grad(net, flag):
    """Turn gradient tracking on (flag=True) or off (flag=False) for every parameter of `net`."""
    for p in net.parameters():
        p.requires_grad = flag


def _infinite_batches(loader):
    """Yield image batches from `loader` forever, dropping the labels.

    A new pass, reshuffled by the loader, starts whenever the previous one is
    exhausted. The loop below therefore never hits StopIteration, however large
    `critic_iter` / `critic_iter_d` are.
    """
    while True:
        for images, _ in loader:
            yield images


def _next_real_batch(batches):
    """Return the next real batch flattened to (n, 784) on `device`.

    n is the batch's own size, not the global batch_size.
    """
    images = next(batches)
    return images.view(images.size(0), -1).to(device)


def _sample_noise(n):
    """Draw n standard-normal latent vectors of length `z_dim` and move them to `device`."""
    return torch.randn(n, z_dim).to(device)


def _to_unit_range(flat_images):
    """Reshape flattened images to (n, 1, 28, 28) and undo Normalize((0.5,), (0.5,)).

    save_image clamps to [0, 1], so values left in [-1, 1] would be clipped.
    G is trained to match the normalized real data, so the same inverse is
    applied to its samples.
    """
    return flat_images.view(-1, 1, 28, 28) * 0.5 + 0.5


batches = _infinite_batches(train_loader)
z_sample = _sample_noise(batch_size)  # fixed latents so saved samples are comparable across epochs

for epoch in range(epochs):
    # 1. Update G and I; D is frozen so no gradients flow into its parameters.
    _set_requires_grad(netD, False)
    _set_requires_grad(netI, True)
    _set_requires_grad(netG, True)
    for _ in range(critic_iter):
        real_data = _next_real_batch(batches)
        fake_data = netG(_sample_noise(real_data.size(0)))
        netI.zero_grad()
        netG.zero_grad()
        cost_GI = GI_loss(netI, netG, netD, real_data, fake_data)
        # MMD (IMQ kernel) between I's codes for a fresh real batch and prior samples.
        x = _next_real_batch(batches)
        z = _sample_noise(x.size(0))
        mmd = mmd_penalty(netI(x), z, kernel="IMQ")
        primal_cost = cost_GI + lambda_mmd * mmd
        primal_cost.backward()
        optim_I.step()
        optim_G.step()
    print('GI: ' + str(cost_GI.item()))
    # Record G/I-phase diagnostics once per outer iteration (last inner step).
    primal_loss_GI.append(primal(netI, netG, netD, real_data).item())
    dual_loss_GI.append(dual(netI, netG, netD, real_data, fake_data).item())
    loss_mmd.append(mmd.item())

    # 2. Update D; G and I are frozen.
    _set_requires_grad(netD, True)
    _set_requires_grad(netI, False)
    _set_requires_grad(netG, False)
    for _ in range(critic_iter_d):
        real_data = _next_real_batch(batches)
        fake_data = netG(_sample_noise(real_data.size(0)))
        netD.zero_grad()
        cost_D = D_loss(netI, netG, netD, real_data, fake_data)
        # Gradient penalty evaluated on a separate real batch and fresh prior samples.
        x = _next_real_batch(batches)
        z = _sample_noise(x.size(0))
        gp_D = gradient_penalty_dual(x.detach(), z.detach(), netD, netG, netI)
        dual_cost = cost_D + lambda_gp * gp_D
        dual_cost.backward()
        optim_D.step()
    print('D: ' + str(cost_D.item()))
    gp.append(gp_D.item())
    re.append(primal(netI, netG, netD, real_data).item())

    # Every 5 outer iterations dump 25 samples from the fixed latents plus 25 real images.
    if (epoch + 1) % 5 == 0:
        with torch.no_grad():
            fake_grid = _to_unit_range(netG(z_sample))[:25]
        torchvision.utils.save_image(fake_grid, os.path.join(output_dir, "fake" + str(epoch) + ".png"), nrow=5)
        torchvision.utils.save_image(_to_unit_range(x)[:25], os.path.join(output_dir, "real.png"), nrow=5)

  return F.conv_transpose2d(


GI: 25.71761131286621
D: -0.38953298330307007
GI: 19.711950302124023
D: 0.010419532656669617
GI: 17.494157791137695
D: -0.06165475770831108
GI: 16.053998947143555
D: -0.045342255383729935
GI: 15.201705932617188
D: -0.14945997297763824
GI: 15.73883056640625
D: -0.48592621088027954
GI: 14.823787689208984
D: -0.49539533257484436
GI: 14.01698112487793
D: -0.4445992410182953
GI: 13.336373329162598
D: -0.5425910949707031
GI: 13.604642868041992
D: -0.4207318425178528
GI: 12.939006805419922
D: -0.3854500949382782
GI: 12.672863960266113
D: -0.28831830620765686
GI: 12.528281211853027
D: -0.4452279806137085
GI: 12.64477252960205
D: -0.009348268620669842
GI: 12.417774200439453
D: -0.22624725103378296
GI: 12.66002082824707
D: -0.19193357229232788
GI: 11.558305740356445
D: -0.04383554309606552
GI: 11.75400161743164
D: 0.2237234264612198
GI: 11.378434181213379
D: 0.2743649482727051
GI: 11.095952033996582
D: 0.147584930062294
GI: 11.280893325805664
D: 0.21830955147743225
GI: 11.2061185836792
D: 0.1509

In [5]:
# print(primal_loss_GI)
# print(dual_loss_GI)
# print(loss_mmd)
# print(gp)
# print(re)