In [1]:
## load necessary modules
# NOTE: the relative order of the wildcard imports and the torch imports is
# kept as-is, because reordering could change which binding wins for any
# name the project modules happen to re-export.
import argparse
import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

from utils.tools import *
from utils.losses import *
from models.mnist import *

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.datasets as dsets
import torchvision.transforms as transforms

from torch.utils.data import DataLoader

In [2]:
## hyper-parameters
# --- reproducibility ---
# Seed intended for torch.manual_seed in the setup cell.
random_seed = 2020

# --- training schedule ---
# Number of outer iterations of the training loop.
epochs = 30
# G/I update steps and D update steps performed per outer iteration.
critic_iter, critic_iter_d = 3, 10
# Size of every data mini-batch and of every latent noise draw.
batch_size = 250

# --- optimisation ---
# Adam step size for I and G; the D optimiser uses 5x this value.
learning_rate = 2e-4
# Weight decay handed to the D optimiser only.
weight_decay = 0.1

# --- model dimensions ---
# Latent dimension passed to the I, G and D constructors.
z_dim = 5
# The next values are not referenced in the visible cells
# (presumably channel count, hidden widths and a noise scale -- TODO confirm).
nc = 1
d_dim = g_dim = 512
std = 0.1

# --- loss weights ---
# Multiplier on the MMD penalty in the G/I objective.
lambda_mmd = 10.0
# Multiplier on the gradient penalty in the D objective.
lambda_gp = 0.1
# lambda_power = 1.0

In [3]:
## Training
# Seed torch's global RNG so that everything drawn from it after this point
# (parameter initialisation, DataLoader shuffling, latent noise) repeats from
# run to run. Previously `random_seed` was defined but never applied.
# NOTE(review): some GPU kernels can still be non-deterministic; see the
# PyTorch reproducibility notes if bit-exact reruns are required.
torch.manual_seed(random_seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # check if gpu is available

## initialize models
# The three networks are built with the shared latent size `z_dim`, moved to
# `device`, then wrapped in DataParallel.
netI = I_MNIST(num_classes=z_dim)
netG = G_MNIST(nz=z_dim)
netD = D_MNIST(nz=z_dim)
netI = netI.to(device)
netG = netG.to(device)
netD = netD.to(device)
netI = nn.DataParallel(netI)
netG = nn.DataParallel(netG)
netD = nn.DataParallel(netD)

## set up optimizers
# I and G share the base learning rate; D runs at 5x that rate and is the only
# optimiser that receives weight decay.
optim_I = optim.Adam(netI.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optim_G = optim.Adam(netG.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optim_D = optim.Adam(netD.parameters(), lr=learning_rate * 5, betas=(0.5, 0.999),
                     weight_decay=weight_decay)

## load datasets
# train_gen, dev_gen, test_gen = load(batch_size, batch_size)
# data = inf_train_gen_mnist(train_gen)
# ToTensor yields pixels in [0, 1]; Normalize((0.5,), (0.5,)) rescales them to [-1, 1].
transform    = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_gen    = dsets.MNIST(root="./datasets",train=True, transform=transform, download=True)
test_gen     = dsets.MNIST(root="./datasets",train=False, transform=transform, download=True)
train_loader = DataLoader(train_gen, batch_size=batch_size, shuffle=True)


## initial empty lists for training progress
# One entry is appended to each list per outer training iteration.
primal_loss_GI = []   # primal(...) evaluated right after the G/I update
dual_loss_GI = []     # dual(...) evaluated right after the G/I update
# primal_loss = []
# dual_loss = []
# primal_loss_z = []
loss_mmd = []         # MMD penalty values
gp = []               # gradient-penalty value from the last D step
re = []               # primal(...) evaluated after the D update

In [4]:
# ************************
# *** iWGANs Algorithm ***
# ************************
# Each outer iteration draws a fresh iterator over `train_loader`, performs
# `critic_iter` joint updates of G and I with D frozen, then `critic_iter_d`
# updates of D with G and I frozen. Every update step consumes two
# mini-batches: one for the adversarial loss and one for its penalty term.


def _set_requires_grad(net, flag):
    """Switch gradient tracking on or off for every parameter of `net`."""
    for p in net.parameters():
        p.requires_grad = flag


def _draw_batch(data_iter):
    """Fetch the next mini-batch together with a fresh latent sample.

    Returns a tuple (images, x, z):
      images -- the image tensor exactly as yielded by the loader,
      x      -- `images` flattened to (batch_size, 784) and moved to `device`,
      z      -- standard-normal noise of shape (batch_size, z_dim) on `device`.
    Raises StopIteration when `data_iter` is exhausted.
    """
    images, _labels = next(data_iter)
    x = images.view(batch_size, 784).to(device)
    z = torch.randn(batch_size, z_dim).to(device)
    return images, x, z


# Create the image output directory up front so that a missing folder fails
# (or is fixed) immediately rather than after several epochs of training.
output_dir = "./outputs/MNIST"
os.makedirs(output_dir, exist_ok=True)

# Fixed latent batch reused for every saved sample grid.
z_sample = torch.randn(batch_size, z_dim).to(device)

for epoch in range(epochs):
    data = iter(train_loader)

    # 1. Update G, I network
    # (1). Train G and I, freeze D
    _set_requires_grad(netD, False)
    _set_requires_grad(netI, True)
    _set_requires_grad(netG, True)
    # (2). Update G and I
    # NOTE(review): the original also evaluated netG(z) at the top of each
    # step and discarded the result. That dead forward pass is dropped here;
    # restore it if the models rely on forward-pass side effects.
    for _step in range(critic_iter):
        _, x, z = _draw_batch(data)
        fake_z = netI(x)
        netI.zero_grad()
        netG.zero_grad()
        cost_GI = GI_loss(netI, netG, netD, z, fake_z)
        # The MMD penalty is computed on a second, independent batch.
        _, x, z = _draw_batch(data)
        fake_z = netI(x)
        mmd = mmd_penalty(fake_z, z, kernel="IMQ")
        primal_cost = cost_GI + lambda_mmd * mmd
        primal_cost.backward()
        optim_I.step()
        optim_G.step()
#     print('GI: '+str(primal(netI, netG, netD, real_data).cpu().item()))
    print('GI: '+str(cost_GI.cpu().item()))
    # (3). Append primal and dual loss to list
    primal_loss_GI.append(primal(netI, netG, netD, z).cpu().item())
    dual_loss_GI.append(dual(netI, netG, netD, z, fake_z).cpu().item())
    # Record the MMD penalty once per outer iteration, next to the other G/I
    # metrics. It used to be appended inside the D loop, where `mmd` is never
    # recomputed, so the same stale value was logged critic_iter_d times.
    loss_mmd.append(mmd.cpu().item())

    # 2. Update D network
    # (1). Train D, freeze G and I
    _set_requires_grad(netD, True)
    _set_requires_grad(netI, False)
    _set_requires_grad(netG, False)
    # (2). Update D
    for _step in range(critic_iter_d):
        _, x, z = _draw_batch(data)
        fake_z = netI(x)
        netD.zero_grad()
        cost_D = D_loss(netI, netG, netD, z, fake_z)
        # The gradient penalty is computed on a second, independent batch.
        # `images` from this draw is what gets saved as real.png below.
        images, x, z = _draw_batch(data)
        gp_D = gradient_penalty_dual(x.data, z.data, netD, netG, netI)
        dual_cost = cost_D + lambda_gp * gp_D
        dual_cost.backward()
        optim_D.step()
#     print('D: '+str(primal(netI, netG, netD, real_data).cpu().item()))
    print('D: '+str(cost_D.cpu().item()))
    gp.append(gp_D.cpu().item())
    re.append(primal(netI, netG, netD, z).cpu().item())

    # Every 5th outer iteration, save a 5x5 grid of generated samples and a
    # 5x5 grid of the most recent real batch.
    # NOTE(review): the training transform maps pixels to [-1, 1], and
    # save_image without normalisation clamps to [0, 1]; the saved grids may
    # therefore lose the negative half of the range -- confirm whether
    # de-normalising before saving is wanted.
    if (epoch+1) % 5 == 0:
        fake_grid = netG(z_sample).view(-1, 1, 28, 28)[:25]
        torchvision.utils.save_image(fake_grid, output_dir + "/fake" + str(epoch) + ".png", nrow=5)
        torchvision.utils.save_image(images[:25], output_dir + "/real.png", nrow=5)

GI: 2.327436685562134
D: -0.003334291512146592
GI: 1.6014949083328247
D: -0.0024953619576990604
GI: 0.9863755106925964
D: -0.0003815331729128957
GI: 0.8945746421813965
D: -0.00010921848297584802
GI: 0.7781262397766113
D: -0.0011070163454860449
GI: 0.7441544532775879
D: 0.0002312831929884851
GI: 0.7735552191734314
D: 0.00042296017636545
GI: 0.7207300662994385
D: -3.767287853406742e-05
GI: 0.6404482126235962
D: -0.0001131862445618026
GI: 0.6236074566841125
D: 8.521151903551072e-05
GI: 0.5932885408401489
D: -5.1581504521891475e-05
GI: 0.6377440690994263
D: -9.527993825031444e-05
GI: 0.5538505911827087
D: 5.9613586927298456e-05
GI: 0.6012016534805298
D: -7.341147011175053e-06
GI: 0.5540521144866943
D: 1.8147469745599665e-05
GI: 0.6321923732757568
D: 2.563536327215843e-05
GI: 0.6001390218734741
D: 1.7229796867468394e-05
GI: 0.5066551566123962
D: 4.175901722192066e-06
GI: 0.5242589116096497
D: -2.236187538073864e-05
GI: 0.5274445414543152
D: -7.328987408072862e-07
GI: 0.6033384203910828
D: -

In [5]:
# print(primal_loss_GI)
# print(dual_loss_GI)
# print(loss_mmd)
# print(gp)
# print(re)

[2.0107905864715576, 1.3733643293380737, 1.0102617740631104, 0.7813577055931091, 0.8120712041854858, 0.7916543483734131, 0.751029908657074, 0.7200213670730591, 0.6411064863204956, 0.6568925976753235, 0.628759503364563, 0.5745697021484375, 0.628494918346405, 0.5533111691474915, 0.5810304284095764, 0.5128655433654785, 0.5417718291282654, 0.6250351667404175, 0.5309680104255676, 0.5583562254905701, 0.5390822887420654, 0.5868348479270935, 0.5538120865821838, 0.554635763168335, 0.4936140775680542, 0.489455908536911, 0.4334213435649872, 0.44873300194740295, 0.43881747126579285, 0.4886133074760437]
