# Imports

In [None]:
import os
import torch as to
from torch import nn
from torch.autograd import Variable
from torch.optim import Adam
from torchvision import transforms
from torchvision.utils import make_grid
import torchvision.datasets as data
import numpy as np
import tensorflow as tf

In [None]:
class Configuration:
    """Hyper-parameters and runtime switches for the DCGAN experiment.

    Every setting is a plain class attribute, so it can be read either from
    the class itself or from an instance such as ``opt`` below.
    """

    # --- data ---
    img_size = 64        # images are resized to this size by the input pipeline
    img_c = 1            # number of channels of the input images
    batch_size = 512
    workers = 0          # worker count handed to the data loader

    # --- model ---
    noise_dim = 100      # dimension of the input noise vector
    gen_c = 128          # generator channels
    disc_c = 128         # discriminator channels

    # --- optimisation ---
    learning_rate = 2e-4
    beta = 0.5
    number_epochs = 100

    # --- hardware ---
    gpu = True


# Shared settings object read by the rest of the file.
opt = Configuration()

# Data loading and preprocessing

In [None]:
# Input pipeline: resize to opt.img_size, convert to a [0, 1] tensor, then
# apply (x - 0.5) / 0.5, which maps the values to [-1, 1].
# MNIST images have a single channel, so Normalize receives exactly one
# mean/std value; three-value tuples do not broadcast against a 1-channel
# tensor and raise a shape error in current torchvision releases.
# The pipeline is bound to its own name so the imported ``transforms`` module
# is not overwritten and this cell can be executed more than once.
mnist_transform = transforms.Compose([
    transforms.Resize(opt.img_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])
datasets = data.MNIST(root='mnist/', train=True, transform=mnist_transform, download=True)
dataloader = to.utils.data.DataLoader(
    datasets,
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.workers,
)

def denorm(x):
    """Undo the [-1, 1] normalisation.

    Computes ``(x + 1) / 2`` and clamps the result to the ``[0, 1]`` range.
    """
    return ((x + 1) / 2).clamp(0, 1)

# Model Architecture

In [None]:
# Generator: maps a (N, opt.noise_dim, 1, 1) noise tensor to an
# (N, opt.img_c, 64, 64) image whose values lie in [-1, 1] (Tanh output).
# The noise size and the output channel count come from the configuration
# instead of being hard-coded, so changing ``opt`` keeps the model consistent.
generator = nn.Sequential(
    # 1x1 -> 4x4
    nn.ConvTranspose2d(opt.noise_dim, opt.gen_c * 8, 4, 1, 0),
    nn.BatchNorm2d(opt.gen_c * 8),
    nn.ReLU(True),

    # 4x4 -> 8x8
    nn.ConvTranspose2d(opt.gen_c * 8, opt.gen_c * 4, 4, 2, 1),
    nn.BatchNorm2d(opt.gen_c * 4),
    nn.ReLU(True),

    # 8x8 -> 16x16
    nn.ConvTranspose2d(opt.gen_c * 4, opt.gen_c * 2, 4, 2, 1),
    nn.BatchNorm2d(opt.gen_c * 2),
    nn.ReLU(True),

    # 16x16 -> 32x32
    nn.ConvTranspose2d(opt.gen_c * 2, opt.gen_c, 4, 2, 1),
    nn.BatchNorm2d(opt.gen_c),
    nn.ReLU(True),

    # 32x32 -> 64x64
    nn.ConvTranspose2d(opt.gen_c, opt.img_c, 4, 2, 1),
    nn.Tanh(),
)

# Discriminator: maps an (N, opt.img_c, 64, 64) image to one raw score per
# sample, shape (N, 1, 1, 1). There is no final Sigmoid; the score is a logit.
# The BatchNorm widths follow opt.disc_c (they previously used opt.gen_c,
# which only worked because both settings happen to be equal).
discriminator = nn.Sequential(
    # 64x64 -> 32x32
    nn.Conv2d(opt.img_c, opt.disc_c, 4, 2, 1),
    nn.LeakyReLU(0.2, inplace=True),

    # 32x32 -> 16x16
    nn.Conv2d(opt.disc_c, opt.disc_c * 2, 4, 2, 1),
    nn.BatchNorm2d(opt.disc_c * 2),
    nn.LeakyReLU(0.2, inplace=True),

    # 16x16 -> 8x8
    nn.Conv2d(opt.disc_c * 2, opt.disc_c * 4, 4, 2, 1),
    nn.BatchNorm2d(opt.disc_c * 4),
    nn.LeakyReLU(0.2, inplace=True),

    # 8x8 -> 4x4
    nn.Conv2d(opt.disc_c * 4, opt.disc_c * 8, 4, 2, 1),
    nn.BatchNorm2d(opt.disc_c * 8),
    nn.LeakyReLU(0.2, inplace=True),

    # 4x4 -> 1x1
    nn.Conv2d(opt.disc_c * 8, 1, 4, 1, 0),
)

def weight_init(m):
    """Initialise one layer in place; intended for ``Module.apply``.

    The layer kind is recognised from its class name:

    * name contains ``"Conv"``  -> weight drawn from N(0, 0.02)
    * name contains ``"Norm"``  -> weight drawn from N(1, 0.02)

    Any other module is left untouched, and so is a matching layer that has
    no weight tensor (for example a norm layer built with ``affine=False``),
    which previously raised ``AttributeError``.

    Args:
        m: the module to initialise.
    """
    weight = getattr(m, 'weight', None)
    if weight is None:
        return

    class_name = m.__class__.__name__
    if 'Conv' in class_name:
        weight.data.normal_(0, 0.02)
    elif 'Norm' in class_name:
        weight.data.normal_(1.0, 0.02)


# Re-initialise every layer of both networks with weight_init, then move the
# networks to the GPU (generator first in each step).
for _net in (generator, discriminator):
    _net.apply(weight_init)
for _net in (generator, discriminator):
    _net.cuda()

In [None]:
def plot_loss(d_losses, g_losses, num_epoch, save=False, save_dir='MNIST_DCGAN_results/', show=False):
    """Plot the discriminator and generator loss curves on one figure.

    Args:
        d_losses: sequence of discriminator loss values to plot.
        g_losses: sequence of generator loss values to plot.
        num_epoch: zero-based epoch index; used as the upper x-axis limit
            and, plus one, in the axis label and the output file name.
        save: when true, write the figure as a PNG file into ``save_dir``.
        save_dir: output directory; created (including parents) if missing.
        show: when true, display the figure, otherwise close it.
    """
    # NOTE(review): ``plt`` is expected to be ``matplotlib.pyplot``; no import
    # of it is visible in this file -- confirm it is provided elsewhere.
    fig, ax = plt.subplots()
    ax.set_xlim(0, num_epoch)

    # Scale the y-axis to 110% of the largest loss; skip empty series so a
    # call made before any loss was recorded does not raise.
    peaks = [np.max(series) for series in (d_losses, g_losses) if len(series)]
    if peaks:
        ax.set_ylim(0, max(peaks) * 1.1)

    plt.xlabel('Epoch {0}'.format(num_epoch + 1))
    plt.ylabel('Loss values')
    plt.plot(d_losses, label='Discriminator')
    plt.plot(g_losses, label='Generator')
    plt.legend()

    # save figure
    if save:
        # makedirs tolerates an existing directory and nested paths.
        os.makedirs(save_dir, exist_ok=True)
        save_fn = os.path.join(save_dir, 'MNIST_DCGAN_losses_epoch_{:d}'.format(num_epoch + 1) + '.png')
        plt.savefig(save_fn)

    if show:
        plt.show()
    else:
        plt.close(fig)

def plot_result(generator, noise, num_epoch, save=False, save_dir='MNIST_DCGAN_results/', show=False, fig_size=(5, 5)):
    """Draw a square grid of single-channel images generated from ``noise``.

    The generator is switched to eval mode for the forward pass and put back
    into train mode afterwards. The grid has ``floor(sqrt(len(noise)))`` rows
    and columns; images that do not fit into the grid are not drawn.

    Args:
        generator: module that maps ``noise`` to a batch of images.
        noise: batch of generator inputs; it is moved to the generator's
            device before the forward pass.
        num_epoch: zero-based epoch index; shown (plus one) in the caption
            and used in the output file name.
        save: when true, write the figure as a PNG file into ``save_dir``.
        save_dir: output directory; created (including parents) if missing.
        show: when true, display the figure, otherwise close it.
        fig_size: figure size passed to ``plt.subplots``.
    """
    # NOTE(review): ``plt`` is expected to be ``matplotlib.pyplot``; no import
    # of it is visible in this file -- confirm it is provided elsewhere.

    # Run on whatever device the generator lives on instead of assuming CUDA.
    device = next(generator.parameters()).device

    generator.eval()
    try:
        # Sampling only: no autograd graph is needed.
        with to.no_grad():
            gen_image = denorm(generator(noise.to(device))).cpu()
    finally:
        generator.train()

    side = int(np.sqrt(noise.size(0)))
    # squeeze=False keeps ``axes`` a 2-D array even for a 1x1 grid.
    fig, axes = plt.subplots(side, side, figsize=fig_size, squeeze=False)
    for ax, img in zip(axes.flatten(), gen_image):
        ax.axis('off')
        # 'box-forced' is no longer accepted by Matplotlib; 'box' replaces it.
        ax.set_adjustable('box')
        # (1, H, W) -> (H, W) for grayscale display.
        ax.imshow(img.squeeze(0).numpy(), cmap='gray', aspect='equal')
    plt.subplots_adjust(wspace=0, hspace=0)
    title = 'Epoch {0}'.format(num_epoch + 1)
    fig.text(0.5, 0.04, title, ha='center')

    # save figure
    if save:
        # makedirs tolerates an existing directory and nested paths.
        os.makedirs(save_dir, exist_ok=True)
        save_fn = os.path.join(save_dir, 'MNIST_DCGAN_epoch_{:d}'.format(num_epoch + 1) + '.png')
        plt.savefig(save_fn)

    if show:
        plt.show()
    else:
        plt.close(fig)



# Optimization

In [None]:
# The discriminator ends in a plain convolution (raw logits), so the loss
# applies the sigmoid itself.
criterion = nn.BCEWithLogitsLoss()
# Adam's first-moment decay is taken from the configuration (opt.beta)
# instead of repeating the literal value here.
opt_Gen = Adam(generator.parameters(), lr=opt.learning_rate, betas=(opt.beta, 0.999))
opt_Disc = Adam(discriminator.parameters(), lr=opt.learning_rate, betas=(opt.beta, 0.999))

# Training

In [None]:

# Per-epoch mean losses, stored as plain Python floats so they can be plotted
# directly.
D_avg_losses = []
G_avg_losses = []
num_test_samples = 5*5

# Create every tensor on the device the generator already lives on, and read
# the noise size from the generator's first layer so the two cannot disagree.
device = next(generator.parameters()).device
noise_dim = generator[0].in_channels

# Fixed noise batch reused after every epoch so the sample grids are comparable.
fixed_noise = to.randn(num_test_samples, noise_dim, 1, 1)

for epoch in range(opt.number_epochs):
    D_losses = []
    G_losses = []
    for i, (images, labels) in enumerate(dataloader):  # class labels are not used
        minibatch = images.size(0)
        real_images = images.to(device)
        real_labels = to.ones(minibatch, device=device)
        fake_labels = to.zeros(minibatch, device=device)

        ## Train discriminator
        # First with real data. view(-1) keeps a 1-D shape even for a batch
        # of one, where squeeze() would yield a 0-dim tensor that no longer
        # matches the label shape.
        D_real_Decision = discriminator(real_images).view(-1)
        D_real_loss = criterion(D_real_Decision, real_labels)

        # Then with fake data. detach() stops the discriminator update from
        # back-propagating into the generator.
        z_ = to.randn(minibatch, noise_dim, 1, 1, device=device)
        gen_images = generator(z_)
        D_fake_decision = discriminator(gen_images.detach()).view(-1)
        D_fake_loss = criterion(D_fake_decision, fake_labels)

        ## back propagation
        D_loss = D_real_loss + D_fake_loss
        discriminator.zero_grad()
        D_loss.backward()
        opt_Disc.step()

        ## Train generator: fresh fakes are scored against the "real" labels.
        z_ = to.randn(minibatch, noise_dim, 1, 1, device=device)
        gen_images = generator(z_)

        D_fake_decisions = discriminator(gen_images).view(-1)
        G_loss = criterion(D_fake_decisions, real_labels)

        discriminator.zero_grad()
        generator.zero_grad()
        G_loss.backward()
        opt_Gen.step()

        # loss values: .item() reads the scalar (indexing a 0-dim tensor with
        # [0] raises in current PyTorch).
        d_loss_value = D_loss.item()
        g_loss_value = G_loss.item()
        D_losses.append(d_loss_value)
        G_losses.append(g_loss_value)

        print('Epoch [%d/%d], Step [%d/%d], D_loss: %.4f, G_loss: %.4f'
              % (epoch+1, opt.number_epochs, i+1, len(dataloader), d_loss_value, g_loss_value))

    D_avg_loss = to.mean(to.FloatTensor(D_losses)).item()
    G_avg_loss = to.mean(to.FloatTensor(G_losses)).item()
    D_avg_losses.append(D_avg_loss)
    G_avg_losses.append(G_avg_loss)

    plot_loss(D_avg_losses, G_avg_losses, epoch, save=True)

    plot_result(generator, fixed_noise, epoch, save=True, fig_size=(5, 5))


# Persist the final weights of both networks once training has finished.
to.save(discriminator.state_dict(), 'epoch_wnet_discriminator.pth')
to.save(generator.state_dict(), 'epoch_wnet_generator.pth')
