In [83]:
# prerequisites
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [84]:
bs = 100 ##　batch size

# MNIST Dataset
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
])

train_dataset = datasets.MNIST(root='./mnist_data/', train=True, transform=transform, download=True)
##test_dataset = datasets.MNIST(root='./mnist_data/', train=False, transform=transform, download=False)

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True,drop_last=True)
##test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=bs, shuffle=False)

In [85]:
class Generator(nn.Module):
    def __init__(self,z_dim,out_dim):
        super(Generator,self).__init__()
        self.fc1 = nn.Linear(z_dim,256)
        self.fc2 = nn.Linear(256,512)
        self.fc3 = nn.Linear(512,out_dim)
    def forward(self,x):
        x = F.leaky_relu(self.fc1(x),0.25)
        x = F.leaky_relu(self.fc2(x),0.25)
        return F.tanh(self.fc3(x))
class Discriminator(nn.Module):
    def __init__(self,input_dim):
        super(Discriminator,self).__init__()
        self.fc1 = nn.Linear(input_dim,512)
        self.fc2 = nn.Linear(512,256)
        self.fc3 = nn.Linear(256,1)
    def forward(self,x):
        x = F.leaky_relu(self.fc1(x),0.25)
        x = F.dropout(x,0.3)
        x = F.leaky_relu(self.fc2(x),0.25)
        x = F.dropout(x,0.3)
        return F.sigmoid(self.fc3(x))

In [86]:
batch_size = 100
mnist_size = 784 ## 28 * 28
latent_space_size = 100
lr = 0.00025
G = Generator(z_dim= latent_space_size,out_dim= mnist_size).to(device)
D = Discriminator(input_dim= mnist_size).to(device)
G_optim = optim.Adam(G.parameters(),lr = lr)
D_optim = optim.Adam(D.parameters(),lr = lr)
loss = nn.BCELoss()

In [87]:
def D_train(x):
    D.zero_grad()
    x_real , y_real = x.view(-1,784).to(device) ,torch.ones(batch_size,1,device=device)
    real_loss = loss(D(x_real),y_real)

    z , y_fake = torch.randn(batch_size,latent_space_size,device = device) , torch.zeros(batch_size,1,device=device)
    fake_loss = loss(D(G(z)),y_fake)

    D_loss = real_loss + fake_loss
    D_loss.backward()
    D_optim.step()
    return D_loss.item()

def G_train(x):
    G.zero_grad()
    z , y = torch.randn(batch_size,latent_space_size,device=device) , torch.ones(batch_size,1,device=device)
    G_loss = loss(D(G(z)),y)
    G_loss.backward()
    G_optim.step()
    return G_loss.item()

In [88]:
n_epoch = 200
D_losses , G_losses = [] , []
for epoch in range(1 , n_epoch+1):
    for batch_idx , (x,_) in enumerate(train_loader):
        D_losses.append(D_train(x))
        G_losses.append(G_train(x))
    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
            (epoch), n_epoch, torch.mean(torch.FloatTensor(D_losses)), torch.mean(torch.FloatTensor(G_losses))))
    if(epoch % 50 == 0):
        with torch.no_grad():
            test_z = torch.randn(bs, 100).to(device)
            generated = G(test_z)
        save_image(generated.view(generated.size(0), 1, 28, 28), './samples/Stage_'+str(epoch/50) + '.png')
        torch.save(G,'./trained/Generator_s'+str(epoch/50)+'.pth')
        torch.save(D,'./trained/Discriminator_s'+str(epoch/50)+'.pth')

    

[1/200]: loss_d: 0.590, loss_g: 3.547
[2/200]: loss_d: 0.685, loss_g: 2.830
[3/200]: loss_d: 0.693, loss_g: 2.640
[4/200]: loss_d: 0.685, loss_g: 2.627
[5/200]: loss_d: 0.700, loss_g: 2.549
[6/200]: loss_d: 0.724, loss_g: 2.484
[7/200]: loss_d: 0.739, loss_g: 2.427
[8/200]: loss_d: 0.754, loss_g: 2.363
[9/200]: loss_d: 0.774, loss_g: 2.302
[10/200]: loss_d: 0.794, loss_g: 2.239
[11/200]: loss_d: 0.807, loss_g: 2.200
[12/200]: loss_d: 0.820, loss_g: 2.161
[13/200]: loss_d: 0.838, loss_g: 2.109
[14/200]: loss_d: 0.852, loss_g: 2.064
[15/200]: loss_d: 0.861, loss_g: 2.021
[16/200]: loss_d: 0.872, loss_g: 1.984
[17/200]: loss_d: 0.885, loss_g: 1.944
[18/200]: loss_d: 0.896, loss_g: 1.914
[19/200]: loss_d: 0.905, loss_g: 1.881
[20/200]: loss_d: 0.915, loss_g: 1.853
[21/200]: loss_d: 0.921, loss_g: 1.829
[22/200]: loss_d: 0.929, loss_g: 1.806
[23/200]: loss_d: 0.934, loss_g: 1.787
[24/200]: loss_d: 0.939, loss_g: 1.769
[25/200]: loss_d: 0.944, loss_g: 1.752
[26/200]: loss_d: 0.948, loss_g: 1