In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.optim as optim
import torchvision.datasets as datasets
import imageio
import numpy as np
import matplotlib
from torchvision.utils import make_grid, save_image
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from tqdm import tqdm
matplotlib.style.use('ggplot')

In [2]:
# hyper parameters
batch_size = 512
epochs = 200
sample_size = 64 # fixed sample size #noise vector to be sampled and then fed to the generator
nz = 128 # latent vector size
k = 1 # number of steps to apply to the discriminator
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  return torch._C._cuda_getDeviceCount() > 0


In [3]:
transform = transforms.Compose([
                                transforms.ToTensor(),
                                transforms.Normalize((0.5,),(0.5,)),
])
to_pil_image = transforms.ToPILImage()

In [4]:
from ipywidgets import IntProgress

In [5]:
train_data = datasets.MNIST(
    root='../input/data',
    train=True,
    download=True,
    transform=transform
)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [6]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.n_input = 784
        self.main = nn.Sequential(
            nn.Linear(self.n_input, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )
    def forward(self, x):
        x = x.view(-1, 784)
        return self.main(x)

In [7]:
class Generator(nn.Module):
    def __init__(self, nz):
        super(Generator, self).__init__()
        self.nz = nz
        self.main = nn.Sequential(
            nn.Linear(self.nz, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 784),
            nn.Tanh(),
        )
    def forward(self, x):
        return self.main(x).view(-1, 1, 28, 28)

In [8]:
generator = Generator(nz).to(device)
discriminator = Discriminator().to(device)
print('##### GENERATOR #####')
print(generator)
print('######################')
print('\n##### DISCRIMINATOR #####')
print(discriminator)
print('######################')

##### GENERATOR #####
Generator(
  (main): Sequential(
    (0): Linear(in_features=128, out_features=256, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
    (2): Linear(in_features=256, out_features=512, bias=True)
    (3): LeakyReLU(negative_slope=0.2)
    (4): Linear(in_features=512, out_features=1024, bias=True)
    (5): LeakyReLU(negative_slope=0.2)
    (6): Linear(in_features=1024, out_features=784, bias=True)
    (7): Tanh()
  )
)
######################

##### DISCRIMINATOR #####
Discriminator(
  (main): Sequential(
    (0): Linear(in_features=784, out_features=1024, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
    (2): Dropout(p=0.3, inplace=False)
    (3): Linear(in_features=1024, out_features=512, bias=True)
    (4): LeakyReLU(negative_slope=0.2)
    (5): Dropout(p=0.3, inplace=False)
    (6): Linear(in_features=512, out_features=256, bias=True)
    (7): LeakyReLU(negative_slope=0.2)
    (8): Dropout(p=0.3, inplace=False)
    (9): Linear(in_features=256, out_features=1

In [9]:
# optimizers
optim_g = optim.Adam(generator.parameters(), lr=0.0002)
optim_d = optim.Adam(discriminator.parameters(), lr=0.0002)

# loss function
criterion = nn.BCELoss()

In [10]:
losses_g = [] # to store generator loss after each epoch
losses_d = [] # to store discriminator loss after each epoch
images = [] # to store images generatd by the generator

In [11]:
#utility functions to create 1s and 0s tensors

# to create real labels (1s)
def label_real(size):
    data = torch.ones(size, 1)
    return data.to(device)
# to create fake labels (0s)
def label_fake(size):
    data = torch.zeros(size, 1)
    return data.to(device)


#utility function for creating nosie

def create_noise(sample_size, nz):
    return torch.randn(sample_size, nz).to(device)


#utility function to save the images
def save_generator_image(image, path):
    save_image(image, path)

In [12]:
#trainig the discriminator

# function to train the discriminator network
def train_discriminator(optimizer, data_real, data_fake):
    b_size = data_real.size(0)
    real_label = label_real(b_size)
    fake_label = label_fake(b_size)
    optimizer.zero_grad()
    output_real = discriminator(data_real)
    loss_real = criterion(output_real, real_label)
    output_fake = discriminator(data_fake)
    loss_fake = criterion(output_fake, fake_label)
    loss_real.backward()
    loss_fake.backward()
    optimizer.step()
    return loss_real + loss_fake






In [13]:
# function to train the generator network
def train_generator(optimizer, data_fake):
    b_size = data_fake.size(0)
    real_label = label_real(b_size)
    optimizer.zero_grad()
    output = discriminator(data_fake)
    loss = criterion(output, real_label)
    loss.backward()
    optimizer.step()
    return loss

In [14]:
# create the noise vector
noise = create_noise(sample_size, nz)

In [15]:
generator.train()
discriminator.train()

Discriminator(
  (main): Sequential(
    (0): Linear(in_features=784, out_features=1024, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
    (2): Dropout(p=0.3, inplace=False)
    (3): Linear(in_features=1024, out_features=512, bias=True)
    (4): LeakyReLU(negative_slope=0.2)
    (5): Dropout(p=0.3, inplace=False)
    (6): Linear(in_features=512, out_features=256, bias=True)
    (7): LeakyReLU(negative_slope=0.2)
    (8): Dropout(p=0.3, inplace=False)
    (9): Linear(in_features=256, out_features=1, bias=True)
    (10): Sigmoid()
  )
)

In [21]:
for epoch in range(epochs):
    loss_g = 0.0
    loss_d = 0.0
    for bi, data in tqdm(enumerate(train_loader), total=int(len(train_data)/train_loader.batch_size)):
        image, _ = data
        image = image.to(device)
        b_size = len(image)
        # run the discriminator for k number of steps
        for step in range(k):
            data_fake = generator(create_noise(b_size, nz)).detach()
            data_real = image
            # train the discriminator network
            loss_d += train_discriminator(optim_d, data_real, data_fake)
        data_fake = generator(create_noise(b_size, nz))
        # train the generator network
        loss_g += train_generator(optim_g, data_fake)
    # create the final fake image for the epoch
    generated_img = generator(noise).cpu().detach()
    # make the images as grid
    generated_img = make_grid(generated_img)
    # save the generated torch tensor models to disk
    #save_generator_image(generated_img, gen_img{epoch}.png")
    save_generator_image(generated_img, f"{epoch}.png")
    images.append(generated_img)
    epoch_loss_g = loss_g / bi # total generator loss for the epoch
    epoch_loss_d = loss_d / bi # total discriminator loss for the epoch
    losses_g.append(epoch_loss_g)
    losses_d.append(epoch_loss_d)
    
    print(f"Epoch {epoch} of {epochs}")
    print(f"Generator loss: {epoch_loss_g:.8f}, Discriminator loss: {epoch_loss_d:.8f}")

118it [00:27,  4.34it/s]                         


Epoch 0 of 200
Generator loss: 3.04846168, Discriminator loss: 0.51752239


118it [00:25,  4.54it/s]                         


Epoch 1 of 200
Generator loss: 2.82308340, Discriminator loss: 0.52435666


118it [00:25,  4.69it/s]                         


Epoch 2 of 200
Generator loss: 2.88246775, Discriminator loss: 0.54591036


118it [00:27,  4.24it/s]                         


Epoch 3 of 200
Generator loss: 2.51740265, Discriminator loss: 0.59869552


118it [00:28,  4.21it/s]                         


Epoch 4 of 200
Generator loss: 2.75420189, Discriminator loss: 0.60543346


118it [00:25,  4.57it/s]                         


Epoch 5 of 200
Generator loss: 2.83259845, Discriminator loss: 0.55235434


118it [00:26,  4.39it/s]                         


Epoch 6 of 200
Generator loss: 2.80146718, Discriminator loss: 0.52395427


118it [00:27,  4.37it/s]                         


Epoch 7 of 200
Generator loss: 2.65280032, Discriminator loss: 0.59798837


118it [00:26,  4.53it/s]                         


Epoch 8 of 200
Generator loss: 2.85673928, Discriminator loss: 0.49899334


118it [00:28,  4.14it/s]                         


Epoch 9 of 200
Generator loss: 3.18620133, Discriminator loss: 0.43610138


118it [00:30,  3.93it/s]                         


Epoch 10 of 200
Generator loss: 2.77474785, Discriminator loss: 0.54569989


118it [00:29,  3.97it/s]                         


Epoch 11 of 200
Generator loss: 2.71807504, Discriminator loss: 0.50940245


118it [00:26,  4.41it/s]                         


Epoch 12 of 200
Generator loss: 2.86842179, Discriminator loss: 0.49809438


118it [00:27,  4.33it/s]                         


Epoch 13 of 200
Generator loss: 2.74201894, Discriminator loss: 0.57633585


118it [00:29,  3.96it/s]                         


Epoch 14 of 200
Generator loss: 2.70346904, Discriminator loss: 0.56185722


118it [00:28,  4.07it/s]                         


Epoch 15 of 200
Generator loss: 2.76393366, Discriminator loss: 0.52073294


118it [00:27,  4.29it/s]                         


Epoch 16 of 200
Generator loss: 2.84721899, Discriminator loss: 0.58344960


118it [00:26,  4.51it/s]                         


Epoch 17 of 200
Generator loss: 2.51780105, Discriminator loss: 0.64954013


118it [00:28,  4.11it/s]                         


Epoch 18 of 200
Generator loss: 2.61252928, Discriminator loss: 0.60877568


118it [00:29,  4.00it/s]                         


Epoch 19 of 200
Generator loss: 2.68753171, Discriminator loss: 0.57316256


118it [00:26,  4.38it/s]                         


Epoch 20 of 200
Generator loss: 2.67391109, Discriminator loss: 0.56103575


118it [00:25,  4.54it/s]                         


Epoch 21 of 200
Generator loss: 2.96209669, Discriminator loss: 0.50920504


118it [00:28,  4.18it/s]                         


Epoch 22 of 200
Generator loss: 3.00987625, Discriminator loss: 0.48139390


118it [00:26,  4.43it/s]                         


Epoch 23 of 200
Generator loss: 2.76035571, Discriminator loss: 0.58778316


118it [00:26,  4.45it/s]                         


Epoch 24 of 200
Generator loss: 2.75864744, Discriminator loss: 0.61037421


118it [00:28,  4.08it/s]                         


Epoch 25 of 200
Generator loss: 2.76375675, Discriminator loss: 0.57587492


118it [00:28,  4.08it/s]                         


Epoch 26 of 200
Generator loss: 2.91757655, Discriminator loss: 0.49474776


118it [00:30,  3.90it/s]                         


Epoch 27 of 200
Generator loss: 2.64715242, Discriminator loss: 0.60932481


118it [00:27,  4.25it/s]                         


Epoch 28 of 200
Generator loss: 2.56502223, Discriminator loss: 0.60816979


118it [00:27,  4.26it/s]                         


Epoch 29 of 200
Generator loss: 2.62114429, Discriminator loss: 0.58445793


118it [00:28,  4.19it/s]                         


Epoch 30 of 200
Generator loss: 2.63752699, Discriminator loss: 0.57450879


118it [00:27,  4.31it/s]                         


Epoch 31 of 200
Generator loss: 2.72104216, Discriminator loss: 0.54187608


118it [00:25,  4.69it/s]                         


Epoch 32 of 200
Generator loss: 2.74122739, Discriminator loss: 0.54461300


118it [00:27,  4.24it/s]                         


Epoch 33 of 200
Generator loss: 2.82945085, Discriminator loss: 0.57193810


118it [00:28,  4.12it/s]                         


Epoch 34 of 200
Generator loss: 2.57086658, Discriminator loss: 0.59077978


118it [00:27,  4.34it/s]                         


Epoch 35 of 200
Generator loss: 2.66312289, Discriminator loss: 0.58763647


118it [00:27,  4.37it/s]                         


Epoch 36 of 200
Generator loss: 2.66125751, Discriminator loss: 0.57345116


118it [00:28,  4.18it/s]                         


Epoch 37 of 200
Generator loss: 2.65303540, Discriminator loss: 0.59089327


118it [00:26,  4.40it/s]                         


Epoch 38 of 200
Generator loss: 2.51141143, Discriminator loss: 0.60319018


118it [00:25,  4.54it/s]                         


Epoch 39 of 200
Generator loss: 2.37673044, Discriminator loss: 0.67372757


118it [00:28,  4.12it/s]                         


Epoch 40 of 200
Generator loss: 2.39768124, Discriminator loss: 0.66524142


118it [00:28,  4.15it/s]                         


Epoch 41 of 200
Generator loss: 2.31458282, Discriminator loss: 0.66520280


118it [00:26,  4.47it/s]                         


Epoch 42 of 200
Generator loss: 2.31992579, Discriminator loss: 0.65908498


118it [00:28,  4.10it/s]                         


Epoch 43 of 200
Generator loss: 2.26091981, Discriminator loss: 0.68147171


118it [00:29,  4.05it/s]                         


Epoch 44 of 200
Generator loss: 2.24501371, Discriminator loss: 0.66655374


118it [00:28,  4.08it/s]                         


Epoch 45 of 200
Generator loss: 2.28224444, Discriminator loss: 0.67739493


118it [00:30,  3.86it/s]                         


Epoch 46 of 200
Generator loss: 2.21204686, Discriminator loss: 0.71844560


118it [00:30,  3.92it/s]                         


Epoch 47 of 200
Generator loss: 2.24132252, Discriminator loss: 0.69020802


118it [00:26,  4.44it/s]                         


Epoch 48 of 200
Generator loss: 2.36464334, Discriminator loss: 0.66671282


118it [00:28,  4.14it/s]                         


Epoch 49 of 200
Generator loss: 2.28156281, Discriminator loss: 0.69833612


118it [00:27,  4.26it/s]                         


Epoch 50 of 200
Generator loss: 2.26731205, Discriminator loss: 0.71987665


118it [00:29,  4.03it/s]                         


Epoch 51 of 200
Generator loss: 2.08398724, Discriminator loss: 0.77726763


118it [00:27,  4.26it/s]                         


Epoch 52 of 200
Generator loss: 2.12056923, Discriminator loss: 0.74591416


118it [00:28,  4.13it/s]                         


Epoch 53 of 200
Generator loss: 2.14242482, Discriminator loss: 0.72816819


118it [00:28,  4.15it/s]                         


Epoch 54 of 200
Generator loss: 2.06659102, Discriminator loss: 0.75327575


118it [00:30,  3.86it/s]                         


Epoch 55 of 200
Generator loss: 2.13898349, Discriminator loss: 0.74791378


118it [00:30,  3.83it/s]                         


Epoch 56 of 200
Generator loss: 2.03845215, Discriminator loss: 0.74903470


118it [00:29,  4.05it/s]                         


Epoch 57 of 200
Generator loss: 2.08324885, Discriminator loss: 0.74177045


118it [00:26,  4.43it/s]                         


Epoch 58 of 200
Generator loss: 2.11664319, Discriminator loss: 0.70147371


118it [00:27,  4.36it/s]                         


Epoch 59 of 200
Generator loss: 2.13879824, Discriminator loss: 0.73725516


118it [00:28,  4.18it/s]                         


Epoch 60 of 200
Generator loss: 1.96105087, Discriminator loss: 0.80817300


118it [00:27,  4.23it/s]                         


Epoch 61 of 200
Generator loss: 1.93791866, Discriminator loss: 0.79482174


118it [00:27,  4.29it/s]                         


Epoch 62 of 200
Generator loss: 2.07849503, Discriminator loss: 0.76635957


118it [00:28,  4.15it/s]                         


Epoch 63 of 200
Generator loss: 1.89291215, Discriminator loss: 0.81127334


118it [00:30,  3.82it/s]                         


Epoch 64 of 200
Generator loss: 2.04916906, Discriminator loss: 0.76509708


118it [00:29,  4.04it/s]                         


Epoch 65 of 200
Generator loss: 2.02758360, Discriminator loss: 0.76841575


118it [00:28,  4.20it/s]                         


Epoch 66 of 200
Generator loss: 1.96963811, Discriminator loss: 0.79463100


118it [00:27,  4.29it/s]                         


Epoch 67 of 200
Generator loss: 1.83641458, Discriminator loss: 0.82452697


118it [00:30,  3.88it/s]                         


Epoch 68 of 200
Generator loss: 1.91185856, Discriminator loss: 0.78890103


118it [00:28,  4.10it/s]                         


Epoch 69 of 200
Generator loss: 1.90041459, Discriminator loss: 0.80294174


118it [00:30,  3.88it/s]                         


Epoch 70 of 200
Generator loss: 1.96452320, Discriminator loss: 0.81383944


118it [00:31,  3.69it/s]                         


Epoch 71 of 200
Generator loss: 1.88288963, Discriminator loss: 0.81748706


118it [00:28,  4.12it/s]                         


Epoch 72 of 200
Generator loss: 1.94979990, Discriminator loss: 0.79821557


118it [00:26,  4.52it/s]                         


Epoch 73 of 200
Generator loss: 2.04881477, Discriminator loss: 0.77630770


118it [00:29,  3.99it/s]                         


Epoch 74 of 200
Generator loss: 1.87866545, Discriminator loss: 0.83594459


118it [00:28,  4.14it/s]                         


Epoch 75 of 200
Generator loss: 1.96832871, Discriminator loss: 0.79457718


118it [00:26,  4.41it/s]                         


Epoch 76 of 200
Generator loss: 1.95998299, Discriminator loss: 0.81319106


118it [00:28,  4.10it/s]                         


Epoch 77 of 200
Generator loss: 1.86723018, Discriminator loss: 0.81119728


118it [00:31,  3.70it/s]                         


Epoch 78 of 200
Generator loss: 1.80346870, Discriminator loss: 0.87606412


118it [00:30,  3.89it/s]                         


Epoch 79 of 200
Generator loss: 1.80034745, Discriminator loss: 0.86702377


118it [00:31,  3.78it/s]                         


Epoch 80 of 200
Generator loss: 1.82504296, Discriminator loss: 0.83328652


118it [00:30,  3.81it/s]                         


Epoch 81 of 200
Generator loss: 1.79615235, Discriminator loss: 0.85813463


118it [00:30,  3.84it/s]                         


Epoch 82 of 200
Generator loss: 1.76210737, Discriminator loss: 0.86599737


118it [00:30,  3.83it/s]                         


Epoch 83 of 200
Generator loss: 1.74667609, Discriminator loss: 0.87377632


118it [00:30,  3.85it/s]                         


Epoch 84 of 200
Generator loss: 1.84695995, Discriminator loss: 0.85108989


118it [00:31,  3.76it/s]                         


Epoch 85 of 200
Generator loss: 1.73250115, Discriminator loss: 0.87937552


118it [00:30,  3.84it/s]                         


Epoch 86 of 200
Generator loss: 1.70000422, Discriminator loss: 0.87900198


118it [00:27,  4.33it/s]                         


Epoch 87 of 200
Generator loss: 1.69401610, Discriminator loss: 0.87544608


118it [00:27,  4.27it/s]                         


Epoch 88 of 200
Generator loss: 1.71298862, Discriminator loss: 0.88358635


118it [00:30,  3.89it/s]                         


Epoch 89 of 200
Generator loss: 1.70234716, Discriminator loss: 0.89479047


118it [00:35,  3.36it/s]                         


Epoch 90 of 200
Generator loss: 1.69333327, Discriminator loss: 0.89325964


118it [00:34,  3.46it/s]                         


Epoch 91 of 200
Generator loss: 1.60485291, Discriminator loss: 0.91290957


118it [00:28,  4.12it/s]                         


Epoch 92 of 200
Generator loss: 1.69904721, Discriminator loss: 0.89330679


118it [00:28,  4.20it/s]                         


Epoch 93 of 200
Generator loss: 1.60676515, Discriminator loss: 0.92593604


118it [00:29,  3.98it/s]                         


Epoch 94 of 200
Generator loss: 1.69016874, Discriminator loss: 0.92441529


118it [00:33,  3.56it/s]                         


Epoch 95 of 200
Generator loss: 1.63463509, Discriminator loss: 0.90540272


118it [00:30,  3.85it/s]                         


Epoch 96 of 200
Generator loss: 1.67825675, Discriminator loss: 0.90868545


118it [00:30,  3.82it/s]                         


Epoch 97 of 200
Generator loss: 1.64502358, Discriminator loss: 0.90880036


118it [00:28,  4.15it/s]                         


Epoch 98 of 200
Generator loss: 1.65667307, Discriminator loss: 0.91496849


118it [00:26,  4.39it/s]                         


Epoch 99 of 200
Generator loss: 1.75477087, Discriminator loss: 0.89275020


118it [00:29,  4.00it/s]                         


Epoch 100 of 200
Generator loss: 1.65249681, Discriminator loss: 0.89590102


118it [00:29,  4.05it/s]                         


Epoch 101 of 200
Generator loss: 1.65984845, Discriminator loss: 0.92799407


 13%|█▎        | 15/117 [00:03<00:23,  4.37it/s]


KeyboardInterrupt: 