In [None]:
# prerequisites
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import Subset
from torchvision import datasets, transforms
from torchvision.utils import save_image

# Device configuration: use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Report which device ('cuda' or 'cpu') was selected in the configuration cell
print(device)

In [None]:
# Mini-batch size (paper: 64, blog: 256 -- ideal batch size ranges from 32 to 128)
bs = 100

# MNIST Dataset
# Preprocessing: convert to tensor, then normalize with mean 0.5 and std dev 0.5
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Only the training split requests a download; the test split reads from the same root
train_dataset = datasets.MNIST('./mnist_data/', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST('./mnist_data/', train=False, transform=transform, download=False)
print(len(train_dataset), len(test_dataset))

# Data Loader (Input Pipeline): both loaders reshuffle and use the same batch size
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=bs, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=bs, shuffle=True)
print(len(train_loader), len(test_loader))

In [None]:
class Generator(nn.Module):
    """Fully connected generator mapping a noise vector to a flat sample.

    Three hidden layers of width 256, 512 and 1024 with LeakyReLU(0.2)
    activations are followed by a linear output layer squashed by tanh, so
    every output component lies in [-1, 1].

    Args:
        g_input_dim: size of the input (noise) vector.
        g_output_dim: size of the generated output vector.
    """

    def __init__(self, g_input_dim, g_output_dim):
        super().__init__()
        base_width = 256
        # Each hidden layer doubles the width of the previous one
        self.fc1 = nn.Linear(g_input_dim, base_width)
        self.fc2 = nn.Linear(base_width, base_width * 2)
        self.fc3 = nn.Linear(base_width * 2, base_width * 4)
        self.fc4 = nn.Linear(base_width * 4, g_output_dim)

    def forward(self, x):
        """Return tanh(fc4(h)), where h is x passed through the three hidden layers."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.leaky_relu(layer(hidden), negative_slope=0.2)
        return torch.tanh(self.fc4(hidden))
    
class Discriminator(nn.Module):
    """Fully connected discriminator producing one probability per input row.

    Three hidden layers of width 1024, 512 and 256 use LeakyReLU(0.2)
    activations, each followed by dropout with p=0.3; the final linear layer
    is squashed by a sigmoid, so the output lies in [0, 1].

    Args:
        d_input_dim: size of each flattened input vector.
    """

    def __init__(self, d_input_dim):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(d_input_dim, 1024)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features//2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features//2)
        self.fc4 = nn.Linear(self.fc3.out_features, 1)

    # forward method
    def forward(self, x):
        """Return the sigmoid score of shape (N, 1) for a batch ``x`` of shape (N, d_input_dim)."""
        # F.dropout defaults to training=True, which would keep dropout active
        # even after D.eval(); tie it to the module's mode so that evaluation
        # is deterministic while training behaviour is unchanged.
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training) #if we are overfitting
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        return torch.sigmoid(self.fc4(x))

In [None]:
# Show the sizes of the first three dimensions of the training data tensor
print(train_dataset.data.size(0),train_dataset.data.size(1),train_dataset.data.size(2))

In [None]:
# Size of the noise vector fed to the generator
z_dim = 100
# Flattened sample size: product of dimensions 1 and 2 of the training data tensor
mnist_dim = train_dataset.data.shape[1] * train_dataset.data.shape[2]

# Instantiate both networks (generator first) and move them to the selected device
G = Generator(z_dim, mnist_dim).to(device)
D = Discriminator(d_input_dim=mnist_dim).to(device)

In [None]:
# Echo the generator so the notebook displays its layer summary
G

In [None]:
# Echo the discriminator so the notebook displays its layer summary
D

In [None]:
# loss: binary cross-entropy, applied to the discriminator's sigmoid outputs
criterion = nn.BCELoss()

# optimizer: one Adam instance per network, both using the same learning rate
lr = 2e-4
G_optimizer, D_optimizer = (optim.Adam(net.parameters(), lr=lr) for net in (G, D))

In [None]:
def G_train(x):
    """Perform one generator update and return its loss as a Python float.

    A batch of ``bs`` Gaussian noise vectors is passed through ``G`` and then
    ``D``; the BCE loss against an all-ones target is back-propagated and only
    ``G_optimizer`` is stepped. The ``x`` argument is accepted but not used.
    """
    #=======================Train the generator=======================#
    G.zero_grad()

    noise = torch.randn(bs, z_dim).to(device)
    real_labels = torch.ones(bs, 1).to(device)

    # The loss is small when D scores the generated samples close to 1
    G_loss = criterion(D(G(noise)), real_labels)

    # gradient backprop & optimize ONLY G's parameters
    G_loss.backward()
    G_optimizer.step()

    return G_loss.item()

In [None]:
def D_train(x):
    """Run one optimisation step of the discriminator and return its loss.

    The discriminator is scored on the real batch ``x`` (target 1) and on an
    equally sized batch generated by ``G`` from Gaussian noise (target 0).
    The two BCE losses are summed and back-propagated, and only
    ``D_optimizer`` is stepped.

    Args:
        x: batch of real samples; it is flattened to ``(-1, mnist_dim)``.

    Returns:
        The summed real + fake discriminator loss as a Python float.
    """
    #=======================Train the discriminator=======================#
    D.zero_grad()

    # train discriminator on real
    x_real = x.view(-1, mnist_dim).to(device)
    # Size targets and noise from the batch itself rather than the global `bs`,
    # so a final batch smaller than `bs` does not cause a shape mismatch
    # between the discriminator output and the BCE target.
    n_samples = x_real.size(0)
    y_real = torch.ones(n_samples, 1).to(device)

    D_real_loss = criterion(D(x_real), y_real)

    # train discriminator on fake
    z = torch.randn(n_samples, z_dim).to(device)
    # G is not updated in this step, so generate the fakes without recording
    # an autograd graph; D_loss.backward() then stops at the discriminator.
    with torch.no_grad():
        x_fake = G(z)
    y_fake = torch.zeros(n_samples, 1).to(device)

    D_fake_loss = criterion(D(x_fake), y_fake)

    # gradient backprop & optimize ONLY D's parameters
    D_loss = D_real_loss + D_fake_loss
    D_loss.backward()
    D_optimizer.step()

    return D_loss.item()

In [None]:
n_epoch = 1000
output_dir = './output_bs100'
# Make sure the output directory exists before any image is written into it
os.makedirs(output_dir, exist_ok=True)

st_losses_g = [] #store per-epoch mean generator loss for plotting
st_losses_d = [] #store per-epoch mean discriminator loss for plotting
for epoch in range(1, n_epoch+1):
    D_losses, G_losses = [], []
    # Alternate one discriminator step and one generator step per batch
    for x, _ in train_loader:
        D_losses.append(D_train(x))
        G_losses.append(G_train(x))

    # Compute each epoch mean once and keep it as a plain float
    mean_loss_d = torch.FloatTensor(D_losses).mean().item()
    mean_loss_g = torch.FloatTensor(G_losses).mean().item()
    st_losses_g.append(mean_loss_g)
    st_losses_d.append(mean_loss_d)

    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (epoch, n_epoch, mean_loss_d, mean_loss_g))

    # Save a grid of generated samples after the first epoch and every 10th epoch
    if epoch == 1 or epoch % 10 == 0:
        with torch.no_grad():
            test_z = torch.randn(bs, z_dim).to(device) #generate noise
            generated = G(test_z)
            # G ends in tanh, so samples lie in [-1, 1]; map them back to [0, 1]
            # (the inverse of the Normalize((0.5,), (0.5,)) applied to the
            # training data) so negative values are not clipped when saved.
            images = generated.view(-1, 1, 28, 28) * 0.5 + 0.5
            filename = '%s/GAN_bs100_%04d_epoch.png' % (output_dir, epoch)
            save_image(images, filename)

[1/100]: loss_d: 0.784, loss_g: 4.191
[2/100]: loss_d: 0.823, loss_g: 3.266
[3/100]: loss_d: 0.729, loss_g: 2.635
[4/100]: loss_d: 0.312, loss_g: 4.321
[5/100]: loss_d: 0.539, loss_g: 3.450
[6/100]: loss_d: 0.519, loss_g: 2.883
[7/100]: loss_d: 0.551, loss_g: 2.844
[8/100]: loss_d: 0.591, loss_g: 2.619
[9/100]: loss_d: 0.616, loss_g: 2.493
[10/100]: loss_d: 0.666, loss_g: 2.291
[11/100]: loss_d: 0.690, loss_g: 2.180
[12/100]: loss_d: 0.786, loss_g: 1.956
[13/100]: loss_d: 0.774, loss_g: 2.015
[14/100]: loss_d: 0.789, loss_g: 2.023
[15/100]: loss_d: 0.804, loss_g: 1.935
[16/100]: loss_d: 0.859, loss_g: 1.808
[17/100]: loss_d: 0.848, loss_g: 1.821
[18/100]: loss_d: 0.901, loss_g: 1.722
[19/100]: loss_d: 0.908, loss_g: 1.629
[20/100]: loss_d: 0.933, loss_g: 1.589
[21/100]: loss_d: 0.965, loss_g: 1.515
[22/100]: loss_d: 0.964, loss_g: 1.494
[23/100]: loss_d: 0.959, loss_g: 1.526
[24/100]: loss_d: 0.987, loss_g: 1.469
[25/100]: loss_d: 0.975, loss_g: 1.455
[26/100]: loss_d: 0.996, loss_g: 1.439
[27/100]: loss_d: 1.026, loss_g: 1.348
[28/100]: loss_d: 1.019, loss_g: 1.352
[29/100]: loss_d: 1.026, loss_g: 1.385
[30/100]: loss_d: 1.024, loss_g: 1.364
[31/100]: loss_d: 1.059, loss_g: 1.305
[32/100]: loss_d: 1.065, loss_g: 1.263
[33/100]: loss_d: 1.065, loss_g: 1.278
[34/100]: loss_d: 1.067, loss_g: 1.282
[35/100]: loss_d: 1.074, loss_g: 1.270
[36/100]: loss_d: 1.096, loss_g: 1.218
[37/100]: loss_d: 1.112, loss_g: 1.182
[38/100]: loss_d: 1.110, loss_g: 1.205
[39/100]: loss_d: 1.128, loss_g: 1.152
[40/100]: loss_d: 1.123, loss_g: 1.155
[41/100]: loss_d: 1.131, loss_g: 1.144
[42/100]: loss_d: 1.144, loss_g: 1.129
[43/100]: loss_d: 1.145, loss_g: 1.118
[44/100]: loss_d: 1.141, loss_g: 1.125
[45/100]: loss_d: 1.143, loss_g: 1.144
[46/100]: loss_d: 1.148, loss_g: 1.118
[47/100]: loss_d: 1.156, loss_g: 1.120
[48/100]: loss_d: 1.141, loss_g: 1.132
[49/100]: loss_d: 1.164, loss_g: 1.093
[50/100]: loss_d: 1.172, loss_g: 1.075
[51/100]: loss_d: 1.165, loss_g: 1.099
[52/100]: loss_d: 1.172, loss_g: 1.085
[53/100]: loss_d: 1.161, loss_g: 1.097
[54/100]: loss_d: 1.173, loss_g: 1.077
[55/100]: loss_d: 1.180, loss_g: 1.067
[56/100]: loss_d: 1.192, loss_g: 1.053
[57/100]: loss_d: 1.188, loss_g: 1.044
[58/100]: loss_d: 1.186, loss_g: 1.065
[59/100]: loss_d: 1.200, loss_g: 1.026
[60/100]: loss_d: 1.205, loss_g: 1.022
[61/100]: loss_d: 1.204, loss_g: 1.055
[62/100]: loss_d: 1.200, loss_g: 1.033
[63/100]: loss_d: 1.212, loss_g: 0.999
[64/100]: loss_d: 1.210, loss_g: 1.004
[65/100]: loss_d: 1.209, loss_g: 1.019
[66/100]: loss_d: 1.198, loss_g: 1.031
[67/100]: loss_d: 1.210, loss_g: 1.012
[68/100]: loss_d: 1.208, loss_g: 1.014
[69/100]: loss_d: 1.214, loss_g: 1.005
[70/100]: loss_d: 1.217, loss_g: 1.004
[71/100]: loss_d: 1.217, loss_g: 0.994
[72/100]: loss_d: 1.212, loss_g: 1.009
[73/100]: loss_d: 1.229, loss_g: 0.991
[74/100]: loss_d: 1.223, loss_g: 0.981
[75/100]: loss_d: 1.223, loss_g: 0.979
[76/100]: loss_d: 1.223, loss_g: 0.977
[77/100]: loss_d: 1.223, loss_g: 0.996
[78/100]: loss_d: 1.217, loss_g: 0.994
[79/100]: loss_d: 1.229, loss_g: 0.986
[80/100]: loss_d: 1.228, loss_g: 0.984
[81/100]: loss_d: 1.224, loss_g: 0.980
[82/100]: loss_d: 1.230, loss_g: 0.982
[83/100]: loss_d: 1.235, loss_g: 0.969
[84/100]: loss_d: 1.231, loss_g: 0.969
[85/100]: loss_d: 1.241, loss_g: 0.955
[86/100]: loss_d: 1.231, loss_g: 0.983
[87/100]: loss_d: 1.231, loss_g: 0.973
[88/100]: loss_d: 1.236, loss_g: 0.974
[89/100]: loss_d: 1.230, loss_g: 0.973
[90/100]: loss_d: 1.234, loss_g: 0.964
[91/100]: loss_d: 1.245, loss_g: 0.944
[92/100]: loss_d: 1.237, loss_g: 0.958
[93/100]: loss_d: 1.236, loss_g: 0.968
[94/100]: loss_d: 1.241, loss_g: 0.954
[95/100]: loss_d: 1.238, loss_g: 0.961
[96/100]: loss_d: 1.241, loss_g: 0.962
[97/100]: loss_d: 1.251, loss_g: 0.930
[98/100]: loss_d: 1.245, loss_g: 0.959
[99/100]: loss_d: 1.243, loss_g: 0.954
[100/100]: loss_d: 1.252, loss_g: 0.939

[101/200]: loss_d: 1.247, loss_g: 0.953
[102/200]: loss_d: 1.250, loss_g: 0.936
[103/200]: loss_d: 1.245, loss_g: 0.958
[104/200]: loss_d: 1.242, loss_g: 0.962
[105/200]: loss_d: 1.236, loss_g: 0.973
[106/200]: loss_d: 1.237, loss_g: 0.965
[107/200]: loss_d: 1.239, loss_g: 0.968
[108/200]: loss_d: 1.241, loss_g: 0.950
[109/200]: loss_d: 1.247, loss_g: 0.952
[110/200]: loss_d: 1.244, loss_g: 0.963
[111/200]: loss_d: 1.246, loss_g: 0.946
[112/200]: loss_d: 1.253, loss_g: 0.929
[113/200]: loss_d: 1.248, loss_g: 0.943
[114/200]: loss_d: 1.243, loss_g: 0.952
[115/200]: loss_d: 1.247, loss_g: 0.945
[116/200]: loss_d: 1.251, loss_g: 0.947
[117/200]: loss_d: 1.249, loss_g: 0.942
[118/200]: loss_d: 1.244, loss_g: 0.943
[119/200]: loss_d: 1.249, loss_g: 0.946
[120/200]: loss_d: 1.250, loss_g: 0.951
[121/200]: loss_d: 1.241, loss_g: 0.956
[122/200]: loss_d: 1.254, loss_g: 0.939
[123/200]: loss_d: 1.257, loss_g: 0.925
[124/200]: loss_d: 1.251, loss_g: 0.945
[125/200]: loss_d: 1.244, loss_g: 0.953
[126/200]: loss_d: 1.252, loss_g: 0.937
[127/200]: loss_d: 1.251, loss_g: 0.937
[128/200]: loss_d: 1.246, loss_g: 0.959
[129/200]: loss_d: 1.251, loss_g: 0.945
[130/200]: loss_d: 1.249, loss_g: 0.934
[131/200]: loss_d: 1.255, loss_g: 0.933
[132/200]: loss_d: 1.249, loss_g: 0.937
[133/200]: loss_d: 1.253, loss_g: 0.944
[134/200]: loss_d: 1.256, loss_g: 0.928
[135/200]: loss_d: 1.255, loss_g: 0.935
[136/200]: loss_d: 1.244, loss_g: 0.953
[137/200]: loss_d: 1.252, loss_g: 0.939
[138/200]: loss_d: 1.249, loss_g: 0.941
[139/200]: loss_d: 1.249, loss_g: 0.940
[140/200]: loss_d: 1.255, loss_g: 0.934
[141/200]: loss_d: 1.249, loss_g: 0.952
[142/200]: loss_d: 1.245, loss_g: 0.949
[143/200]: loss_d: 1.255, loss_g: 0.931
[144/200]: loss_d: 1.254, loss_g: 0.944
[145/200]: loss_d: 1.250, loss_g: 0.953
[146/200]: loss_d: 1.254, loss_g: 0.926
[147/200]: loss_d: 1.257, loss_g: 0.930
[148/200]: loss_d: 1.260, loss_g: 0.923
[149/200]: loss_d: 1.256, loss_g: 0.928
[150/200]: loss_d: 1.260, loss_g: 0.927
[151/200]: loss_d: 1.253, loss_g: 0.941
[152/200]: loss_d: 1.252, loss_g: 0.932
[153/200]: loss_d: 1.252, loss_g: 0.942
[154/200]: loss_d: 1.251, loss_g: 0.943
[155/200]: loss_d: 1.255, loss_g: 0.933
[156/200]: loss_d: 1.261, loss_g: 0.922
[157/200]: loss_d: 1.259, loss_g: 0.927
[158/200]: loss_d: 1.253, loss_g: 0.933
[159/200]: loss_d: 1.252, loss_g: 0.931
[160/200]: loss_d: 1.255, loss_g: 0.933
[161/200]: loss_d: 1.245, loss_g: 0.948
[162/200]: loss_d: 1.254, loss_g: 0.938
[163/200]: loss_d: 1.246, loss_g: 0.944
[164/200]: loss_d: 1.259, loss_g: 0.935
[165/200]: loss_d: 1.259, loss_g: 0.918
[166/200]: loss_d: 1.253, loss_g: 0.934
[167/200]: loss_d: 1.264, loss_g: 0.912
[168/200]: loss_d: 1.250, loss_g: 0.938
[169/200]: loss_d: 1.259, loss_g: 0.918
[170/200]: loss_d: 1.256, loss_g: 0.940
[171/200]: loss_d: 1.255, loss_g: 0.940
[172/200]: loss_d: 1.261, loss_g: 0.918
[173/200]: loss_d: 1.252, loss_g: 0.937
[174/200]: loss_d: 1.251, loss_g: 0.937
[175/200]: loss_d: 1.255, loss_g: 0.922
[176/200]: loss_d: 1.259, loss_g: 0.924
[177/200]: loss_d: 1.257, loss_g: 0.940
[178/200]: loss_d: 1.261, loss_g: 0.923
[179/200]: loss_d: 1.257, loss_g: 0.933
[180/200]: loss_d: 1.255, loss_g: 0.935
[181/200]: loss_d: 1.253, loss_g: 0.931
[182/200]: loss_d: 1.260, loss_g: 0.929
[183/200]: loss_d: 1.262, loss_g: 0.922
[184/200]: loss_d: 1.255, loss_g: 0.946
[185/200]: loss_d: 1.252, loss_g: 0.930
[186/200]: loss_d: 1.259, loss_g: 0.926
[187/200]: loss_d: 1.248, loss_g: 0.943
[188/200]: loss_d: 1.253, loss_g: 0.933
[189/200]: loss_d: 1.257, loss_g: 0.916
[190/200]: loss_d: 1.252, loss_g: 0.943
[191/200]: loss_d: 1.251, loss_g: 0.937
[192/200]: loss_d: 1.253, loss_g: 0.926
[193/200]: loss_d: 1.262, loss_g: 0.920
[194/200]: loss_d: 1.258, loss_g: 0.930
[195/200]: loss_d: 1.255, loss_g: 0.932
[196/200]: loss_d: 1.257, loss_g: 0.921
[197/200]: loss_d: 1.251, loss_g: 0.930
[198/200]: loss_d: 1.257, loss_g: 0.925
[199/200]: loss_d: 1.246, loss_g: 0.955
[200/200]: loss_d: 1.247, loss_g: 0.949

In [None]:
# Draw the stored generator and discriminator loss curves on one axes and save the figure
import matplotlib
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
ax.plot(st_losses_g, label='Generator loss')
ax.plot(st_losses_d, label='Discriminator Loss')
ax.legend()
fig.savefig('./output_bs100/GAN_loss_bs100_1000epoch.png')

In [None]:
# Shell command: recursively archive /content/output_bs100 into /content/output.zip
# NOTE(review): the training cell writes to the relative path ./output_bs100 --
# this assumes the notebook's working directory is /content; confirm.
!zip -r /content/output.zip /content/output_bs100

In [None]:
# Colab-only step: hand the archive created by the zip cell to the browser for download
from google.colab import files
files.download("/content/output.zip")