In [1]:
import torch
import numpy as np
from torch import nn,optim
from torchvision import datasets,transforms

In [3]:
from utils import Logger

In [2]:
def data(root='./dataset', train=True, download=True):
    """Load the MNIST dataset with pixel values normalized to [-1, 1].

    Args:
        root: directory where MNIST is stored / downloaded (default './dataset').
        train: if True load the training split, otherwise the test split.
        download: forwarded to torchvision's MNIST; fetch the files if missing.

    Returns:
        torchvision.datasets.MNIST whose samples are (1, 28, 28) float tensors.
    """
    # MNIST is grayscale (one channel), so Normalize must get exactly one
    # mean/std value. The previous 3-channel tuples make the in-place
    # normalization fail with a broadcast error on a (1, 28, 28) tensor.
    preprocess = transforms.Compose([
        transforms.ToTensor(),                 # PIL image -> float tensor in [0, 1]
        transforms.Normalize((0.5,), (0.5,)),  # (x - 0.5) / 0.5 -> [-1, 1], matching the generator's tanh output
    ])
    return datasets.MNIST(root=root, train=train, transform=preprocess, download=download)
    

In [3]:
# MNIST training split (data() loads train=True with download=True).
mnist_data=data()

In [4]:
# Shuffled mini-batches of 100 images for training.
dataloader = torch.utils.data.DataLoader(
    dataset=mnist_data,
    batch_size=100,
    shuffle=True,
)
# Number of batches per epoch; passed to the logger in the training loop.
num_batches = len(dataloader)

In [5]:
#Discriminator Net
class Discriminator(torch.nn.Module):
    """Fully connected discriminator for flattened 28x28 images.

    Layer widths are 784 -> 1024 -> 512 -> 256 -> 1. Each hidden layer is
    followed by LeakyReLU(0.2) and Dropout(p=0.3). The final layer goes through
    a sigmoid, so each output is a score in (0, 1) for "real vs fake".
    """

    def __init__(self):
        super().__init__()
        in_dim = 28 * 28            # one flattened MNIST image
        widths = (1024, 512, 256)   # hidden layer sizes
        # Submodules are registered in the same order as before so parameter
        # initialization and state_dict keys are unchanged.
        self.lin0 = nn.Linear(in_dim, widths[0])
        self.lin1 = nn.Linear(widths[0], widths[1])
        self.lin2 = nn.Linear(widths[1], widths[2])
        self.out = nn.Linear(widths[2], 1)  # single real/fake score per sample
        self.dropout = nn.Dropout(p=0.3)
        self.leaky = nn.LeakyReLU(0.2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Same activation + dropout pattern after every hidden layer.
        for hidden_layer in (self.lin0, self.lin1, self.lin2):
            x = self.dropout(self.leaky(hidden_layer(x)))
        return self.sigmoid(self.out(x))

In [6]:
#object creation
# Single discriminator instance; train_disc and train_gen reference this global directly.
disc=Discriminator()

In [7]:
#additional functionality
def image_to_vec(image):
    """Flatten a batch of 28x28 images into rows of 784 values (a view, no copy)."""
    batch_size = image.size(0)
    pixels_per_image = 28 * 28
    return image.view(batch_size, pixels_per_image)

def vec_to_image(vector):
    """Reshape (batch, 784) rows back into (batch, 1, 28, 28) single-channel images."""
    rows = vector.size(0)
    height = width = 28
    return vector.view(rows, 1, height, width)


In [8]:
#Generator net
class Generator(torch.nn.Module):
    """Fully connected generator mapping latent noise to flattened 28x28 images.

    Layer widths are 100 -> 256 -> 512 -> 1024 -> 784. Each hidden layer is
    followed by LeakyReLU(0.2). The output goes through tanh, so pixel values
    lie in [-1, 1], the same range data() produces with Normalize(0.5, 0.5).
    """

    def __init__(self):
        super(Generator,self).__init__()
        n_features=100  # latent dimension; must match the width of noise()
        n_out=784       # 28*28 flattened image
        self.lin0=nn.Linear(n_features,256)
        self.lin1=nn.Linear(256,512)
        self.lin2=nn.Linear(512,1024)
        self.out=nn.Linear(1024,n_out)
        self.leaky=nn.LeakyReLU(0.2)
        self.tanh=nn.Tanh()

    # forward must be a method of the class, not a function nested inside
    # __init__. Otherwise nn.Module falls back to its unimplemented forward and
    # gen(x) raises NotImplementedError.
    def forward(self,x):
        x=self.leaky(self.lin0(x))
        x=self.leaky(self.lin1(x))
        x=self.leaky(self.lin2(x))
        x=self.tanh(self.out(x))
        return x
        
        
        

In [11]:
# Single generator instance; the training loop samples fakes from it via gen(noise(N)).
gen=Generator()

In [9]:
#functionality for random noise
def noise(size, dim=100):
    """Sample a batch of standard-normal latent vectors for the generator.

    Args:
        size: number of latent vectors to draw (the batch size).
        dim: latent dimension. The default of 100 is the input width of
            Generator.lin0 (not the batch size, as an earlier comment said).

    Returns:
        A (size, dim) float tensor drawn from N(0, 1).
    """
    return torch.randn(size, dim)

In [12]:
# Separate Adam optimizers for each network, both with learning rate 2e-4.
disc_optimizer = optim.Adam(params=disc.parameters(), lr=2e-4)
gen_optimizer = optim.Adam(params=gen.parameters(), lr=2e-4)

In [13]:
# Binary cross-entropy on the discriminator's sigmoid outputs; targets are 1 for real and 0 for fake.
loss=nn.BCELoss()

In [14]:
#more functionality for image targets
def ones_target(size):
    """Return a (size, 1) float tensor of ones, the "real" label used with BCELoss."""
    return torch.ones((size, 1))

def zeros_target(size):
    """Return a (size, 1) float tensor of zeros, the "fake" label used with BCELoss."""
    return torch.zeros((size, 1))
    

In [15]:
def train_disc(opt, real_data, fake_data):
    """Run one optimizer step on the global discriminator `disc`.

    Gradients from the real batch (target 1) and the fake batch (target 0) are
    both accumulated before a single `opt.step()`.

    Args:
        opt: optimizer over the discriminator's parameters.
        real_data: flattened real images; its first dim is the batch size.
        fake_data: flattened generated images, same batch size as real_data
            (the training loop passes them already detached).

    Returns:
        (total_loss, predictions_on_real, predictions_on_fake)
    """
    batch_size = real_data.size(0)
    opt.zero_grad()  # clear gradients left over from the previous step

    def _score(samples, targets):
        # Forward pass, BCE against targets, then accumulate gradients.
        predictions = disc(samples)
        batch_loss = loss(predictions, targets)
        batch_loss.backward()
        return batch_loss, predictions

    real_loss, real_preds = _score(real_data, ones_target(batch_size))
    fake_loss, fake_preds = _score(fake_data, zeros_target(batch_size))

    opt.step()
    return real_loss + fake_loss, real_preds, fake_preds

    

In [16]:
def train_gen(opt, fake_data):
    """Run one optimizer step on the generator.

    The generator is rewarded when the global discriminator `disc` labels its
    samples as real, so the BCE target is all ones.

    Args:
        opt: optimizer over the generator's parameters.
        fake_data: generator output that is still attached to the generator's
            graph, so gradients can flow back into it.

    Returns:
        The generator's BCE loss tensor.
    """
    opt.zero_grad()
    verdict = disc(fake_data)
    gen_loss = loss(verdict, ones_target(fake_data.size(0)))
    gen_loss.backward()
    opt.step()
    return gen_loss


In [17]:
# Fixed latent batch, sampled once and reused at each logging step so generated images can be compared across training.
num_test_samples=16
test_noise=noise(num_test_samples)

In [None]:
# Training loop: alternate one discriminator step and one generator step per batch.
logger = Logger(model_name='VGAN', data_name='MNIST')

num_epochs = 200
for epoch in range(num_epochs):
    for n_batch, (real_batch, _) in enumerate(dataloader):
        N = real_batch.size(0)

        # 1) Discriminator step: real images plus fakes detached from the
        #    generator, so this update does not backprop into the generator.
        real_data = image_to_vec(real_batch)
        fake_data = gen(noise(N)).detach()
        d_error, d_pred_real, d_pred_fake = train_disc(disc_optimizer, real_data, fake_data)

        # 2) Generator step: fresh fakes that stay attached to the generator graph.
        fake_data = gen(noise(N))
        g_error = train_gen(gen_optimizer, fake_data)

        logger.log(d_error, g_error, epoch, n_batch, num_batches)

        # Every 10 batches, render the fixed test noise and report progress.
        # Image logging and status display belong inside this branch;
        # otherwise they run on every batch.
        if n_batch % 10 == 0:
            with torch.no_grad():  # sampling only; no autograd graph needed
                test_images = vec_to_image(gen(test_noise))
            logger.log_images(test_images, num_test_samples, epoch, n_batch, num_batches)
            logger.display_status(epoch, num_epochs, n_batch, num_batches, d_error, g_error, d_pred_real, d_pred_fake)
            