In [1]:
%load_ext autoreload
%autoreload 2

import os, argparse, sys 
sys.path.append('..')
import torch
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
import torchvision.utils as vutils
from ALI_BiGAN.model import Generator, Encoder, Discriminator
from ALI_BiGAN.util import weights_init, log_sum_exp, get_log_odds

# --- Training hyper-parameters ------------------------------------------------
batch_size, latent_size, num_epochs = 100, 256, 100
lr = 1e-4

# Run on the GPU when torch reports one as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# --- Dataset choice and output locations --------------------------------------
# argparse serves purely as a container of defaults here: parse_args([]) is
# given an empty argument list, so nothing is read from the process's
# command line.
parser = argparse.ArgumentParser()
for _flag, _kwargs in (
    ('--dataset', dict(help='cifar10 | svhn', default="cifar10")),
    ('--dataroot', dict(help='path to dataset', default="../data/ALI_BiGAN/cifar10")),
    ('--save_model_dir', dict(default="../models/ALI_BiGAN")),
    ('--save_image_dir', dict(default="../data/ALI_BiGAN")),
):
    parser.add_argument(_flag, **_kwargs)
opt = parser.parse_args([])

In [2]:
# Both supported datasets share the same preprocessing pipeline: a Compose
# holding a single ToTensor() step.
to_tensor = transforms.Compose([transforms.ToTensor()])

# Pick the dataset named in the options; only the selected one is constructed
# (and therefore downloaded). Any other name is rejected.
if opt.dataset == 'cifar10':
    train_set = datasets.CIFAR10(root=opt.dataroot, train=True, download=True,
                                 transform=to_tensor)
elif opt.dataset == 'svhn':
    train_set = datasets.SVHN(root=opt.dataroot, split='extra', download=True,
                              transform=to_tensor)
else:
    raise NotImplementedError

train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)

# Build the three networks on the target device and run weights_init over
# every sub-module via Module.apply.
# NOTE(review): the meaning of the positional arguments (True for the encoder,
# 0.2 and 1 for the discriminator) is defined in ALI_BiGAN.model — confirm there.
E = Encoder(latent_size, True).to(device)
E.apply(weights_init)
G = Generator(latent_size).to(device)
G.apply(weights_init)
D = Discriminator(latent_size, 0.2, 1).to(device)
D.apply(weights_init)

# One Adam optimizer drives encoder + generator together (one parameter group
# each); the discriminator gets its own. Both share lr and betas.
adam_betas = (0.5, 0.999)
optG = optim.Adam([{'params': net.parameters()} for net in (E, G)], lr=lr, betas=adam_betas)
optD = optim.Adam(D.parameters(), lr=lr, betas=adam_betas)

# Binary cross-entropy applied to the discriminator's outputs.
criterion = nn.BCELoss()

Files already downloaded and verified


In [4]:
# Adversarial training of encoder E, generator G and discriminator D.
#
# Per batch, D scores two (image, latent) pairs:
#   * (real image + noise, latent sampled from E's output for that image)
#   * (generated image + noise, the prior sample the image was generated from)
# D is trained to label the first pair 1 and the second 0; E and G are trained
# on the swapped labels. D's update is skipped whenever lossG is already >= 3.5.

# Create the output folders up front so the first save does not fail on a
# fresh checkout.
os.makedirs(opt.save_model_dir, exist_ok=True)
os.makedirs(opt.save_image_dir, exist_ok=True)

# The targets are identical for every (full) batch, so build them once,
# directly on the training device.
real_label = torch.ones(batch_size, device=device)
fake_label = torch.zeros(batch_size, device=device)

for epoch in range(num_epochs):
    # Std of the Gaussian noise added to D's image inputs; it shrinks linearly
    # with the epoch index, from 0.1 at epoch 0 to 0.1 / num_epochs at the end.
    noise_std = 0.1 * (num_epochs - epoch) / num_epochs

    for i, (d_real, _) in enumerate(train_loader):
        d_real = d_real.to(device)

        # One-off initialisation of the generator's output bias from the very
        # first batch of the run.
        if epoch == 0 and i == 0:
            G.output_bias.data = get_log_odds(d_real)

        # The label tensors and the reshapes below assume exactly batch_size
        # samples, so a trailing short batch is skipped.
        if d_real.size(0) != batch_size:
            continue

        # Independent noise draws for the real and the generated images,
        # sampled on the same device as the data.
        noise1 = torch.randn_like(d_real) * noise_std
        noise2 = torch.randn_like(d_real) * noise_std

        # Generator path: prior sample -> image.
        z_fake = torch.randn(batch_size, latent_size, 1, 1, device=device)
        d_fake = G(z_fake)

        # Encoder path: only the first of E's four outputs is used. Its first
        # latent_size columns are read as the mean and the remaining columns
        # as the log standard deviation of the latent code.
        z_real, _, _, _ = E(d_real)
        z_real = z_real.view(batch_size, -1)
        mu, log_sigma = z_real[:, :latent_size], z_real[:, latent_size:]
        sigma = torch.exp(log_sigma)
        epsilon = torch.randn(batch_size, latent_size, device=device)
        output_z = mu + epsilon * sigma  # reparameterised sample

        output_real, _ = D(d_real + noise1, output_z.view(batch_size, latent_size, 1, 1))
        output_fake, _ = D(d_fake + noise2, z_fake)
        lossD = criterion(output_real, real_label) + criterion(output_fake, fake_label)
        lossG = criterion(output_fake, real_label) + criterion(output_real, fake_label)

        # Both losses share one autograd graph that runs through D's weights.
        # optD.step() changes those weights in place, so every backward pass
        # has to finish BEFORE any optimizer steps; otherwise autograd either
        # raises an "modified by an inplace operation" error or differentiates
        # through weights that no longer match the stored activations.
        update_d = lossG.item() < 3.5

        optG.zero_grad()
        lossG.backward(retain_graph=update_d)  # fills E/G grads (and D's, which are discarded)

        if update_d:
            # Differentiate lossD with respect to D's parameters only, so the
            # E/G gradients computed above are left untouched, and overwrite
            # whatever lossG.backward() accumulated on D.
            d_params = [p for p in D.parameters() if p.requires_grad]
            d_grads = torch.autograd.grad(lossD, d_params, allow_unused=True)
            for param, grad in zip(d_params, d_grads):
                param.grad = grad
            optD.step()

        optG.step()

        if i % 100 == 0:
            print("Epoch :", epoch, "Iter :", i, "D Loss :", lossD.item(), "G loss :", lossG.item(),
                  "D(x) :", output_real.mean().item(), "D(G(x)) :", output_fake.mean().item())
            # Rolling preview of the first 16 generated / real images
            # (overwritten every time).
            vutils.save_image(d_fake.detach()[:16].cpu(), os.path.join(opt.save_image_dir, 'fake.png'))
            vutils.save_image(d_real.detach()[:16].cpu(), os.path.join(opt.save_image_dir, 'real.png'))

    # Every 10th epoch: checkpoint all three networks and keep a sample grid
    # from the last processed batch of the epoch.
    if epoch % 10 == 0:
        torch.save(G.state_dict(), os.path.join(opt.save_model_dir, 'G_epoch_%d.pth' % epoch))
        torch.save(E.state_dict(), os.path.join(opt.save_model_dir, 'E_epoch_%d.pth' % epoch))
        torch.save(D.state_dict(), os.path.join(opt.save_model_dir, 'D_epoch_%d.pth' % epoch))
        vutils.save_image(d_fake.detach()[:16].cpu(), os.path.join(opt.save_image_dir, 'fake_%d.png' % epoch))

Epoch : 0 Iter : 0 D Loss : 0.884990394115448 G loss : 2.8410940170288086 D(x) : 0.6036473512649536 D(G(x)) : 0.22583098709583282
Epoch : 0 Iter : 100 D Loss : 1.146598219871521 G loss : 2.5247802734375 D(x) : 0.6244033575057983 D(G(x)) : 0.379229873418808
Epoch : 0 Iter : 200 D Loss : 1.4644286632537842 G loss : 2.104116439819336 D(x) : 0.726813793182373 D(G(x)) : 0.6078115701675415
Epoch : 0 Iter : 300 D Loss : 0.9332486391067505 G loss : 2.744138717651367 D(x) : 0.7125073075294495 D(G(x)) : 0.37302741408348083
Epoch : 0 Iter : 400 D Loss : 1.1716604232788086 G loss : 2.4608097076416016 D(x) : 0.5363125205039978 D(G(x)) : 0.2998923063278198
Epoch : 1 Iter : 0 D Loss : 0.888060450553894 G loss : 2.695221185684204 D(x) : 0.7542799115180969 D(G(x)) : 0.40297073125839233
Epoch : 1 Iter : 100 D Loss : 0.769252598285675 G loss : 3.285524368286133 D(x) : 0.6518683433532715 D(G(x)) : 0.19086799025535583
Epoch : 1 Iter : 200 D Loss : 0.9318257570266724 G loss : 2.9698657989501953 D(x) : 0.551