In [None]:
import torch
from torch import nn
from torch.autograd import  Variable
from torch import optim
from classify_svhn import get_data_loader

import GAN  # needed to allow the reload
import importlib
importlib.reload(GAN)
from GAN import Generator, Discriminator

import matplotlib.pyplot as plt

# --- Hyper-parameters -------------------------------------------------------
num_epoch = 10     # number of passes over the loader made by train()
batch_size = 64    # handed to get_data_loader and used for the noise batches
z_size = 100       # length of the noise vector given to the Generator
im_size = 32       # passed to the Discriminator constructor
n_critic = 5       # the generator is updated on every n_critic-th step only
lr = 1e-4          # Adam learning rate, shared by both optimizers
betas = (0, 0.9)   # Adam betas, shared by both optimizers

# --- Device -----------------------------------------------------------------
# Use the GPU when one is visible to torch, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using cuda" if device.type == "cuda" else "Running on cpu")

# --- Networks ---------------------------------------------------------------
# G is built before D so that parameter initialisation order stays fixed.
G = Generator(z_size).to(device)
D = Discriminator(im_size, device).to(device)

# --- Optimizers (one Adam instance per network, identical settings) ---------
g_optim, d_optim = (
    optim.Adam(net.parameters(), lr=lr, betas=betas) for net in (G, D)
)

In [None]:
def showImg(x):
    """Draw one image tensor with ``plt.imshow``.

    The tensor is detached from the autograd graph and copied to host memory
    first, because ``Tensor.numpy()`` refuses tensors that require grad or
    live on a CUDA device (e.g. a freshly generated sample).

    Args:
        x: 3-D tensor whose axes are reordered with ``permute(1, 2, 0)``
            before plotting -- presumably (channels, height, width);
            TODO confirm against the data loader.
    """
    x = x.detach().cpu().permute(1, 2, 0)
    # Rescale with x * 0.5 + 0.5; presumably undoes a mean=0.5 / std=0.5
    # normalisation so values land in [0, 1] -- verify against the loader.
    plt.imshow((x.numpy() * 0.5) + 0.5)
    
# Build the three loaders returned by get_data_loader for the "svhn" dataset.
(train_loader,
 valid_loader,
 test_loader) = get_data_loader("svhn", batch_size)

# Sanity check: fetch the first batch from the training loader (it yields an
# (images, targets) pair) and display the first image of that batch.
real_sample, target = next(iter(train_loader))
showImg(real_sample[0])

In [None]:
def train(loader):
    """Alternate discriminator and generator updates for ``num_epoch`` epochs.

    The discriminator ``D`` is updated on every batch; the generator ``G`` is
    updated only when the global step is a multiple of ``n_critic``. At the
    end of each epoch the mean discriminator loss per batch is printed and
    one generated sample is displayed through ``createSample``.

    Args:
        loader: sized iterable of batches. Each batch is either an image
            tensor or a ``(images, targets, ...)`` list/tuple, in which case
            only the first element is used and the rest is ignored.
    """
    for epoch in range(num_epoch):
        # createSample() switches G to eval mode at the end of every epoch,
        # so training mode has to be re-enabled before the next one starts.
        G.train()
        D.train()

        d_train_loss = 0.0
        n_batches = 0

        for data_idx, batch in enumerate(loader):
            step = epoch * len(loader) + data_idx + 1

            # The loader yields (images, targets); labels are not needed here.
            real_sample = batch[0] if isinstance(batch, (list, tuple)) else batch
            real_sample = real_sample.to(device)

            # Train the discriminator (done on every batch).
            d_optim.zero_grad()
            g_optim.zero_grad()

            # Size the noise batch after the real batch: the last batch of an
            # epoch can be smaller than batch_size.
            z = torch.randn(real_sample.size(0), z_size, device=device)
            fake_sample = G(z)

            d_loss = D.loss(real_sample, fake_sample)
            d_loss.backward()
            d_optim.step()

            d_train_loss += d_loss.item()
            n_batches += 1

            if step % n_critic == 0:
                # Train the generator. Clearing both optimizers also drops
                # the gradients the discriminator pass left on G's parameters.
                d_optim.zero_grad()
                g_optim.zero_grad()

                z = torch.randn(batch_size, z_size, device=device)

                fake_sample = G(z)
                fake_result = D(fake_sample)
                g_loss = G.loss(fake_result)
                g_loss.backward()
                g_optim.step()

        # max(..., 1) keeps an empty loader from dividing by zero.
        print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, d_train_loss / max(n_batches, 1)))

        createSample(G)

In [None]:
def createSample(generator):
    """Generate one image from random noise and display it.

    The generator is switched to eval mode and is left in that mode; callers
    that keep training must call ``.train()`` again afterwards.

    Args:
        generator: module called with a ``(1, z_size)`` noise tensor on
            ``device``. Its output is assumed to be batch-first (one image
            per noise row), as in the training loop -- TODO confirm.
    """
    generator.eval()

    # Sampling only: no autograd graph is needed, and a graph-free tensor can
    # be converted to numpy for plotting.
    with torch.no_grad():
        # Keep a batch dimension of 1 so the input has the same layout as the
        # (batch, z_size) noise used during training.
        z = torch.randn(1, z_size, device=device)
        im = generator(z)

    # Drop the batch dimension and copy to host memory before plotting.
    showImg(im[0].cpu())
    # Flush the figure now so every call renders its own image instead of
    # successive calls drawing over the same axes.
    plt.show()

In [None]:
# Launch training on the training loader; train() runs num_epoch epochs and
# shows one generated sample after each of them.
train(loader=train_loader)