GAN with model summary

In [2]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.optim as optim
import torchvision.datasets as datasets
import imageio
import numpy as np
import matplotlib

from torchvision.utils import make_grid, save_image
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from tqdm import tqdm
from data.data import load_mnist

matplotlib.style.use('ggplot')

# p
BATCH_SIZE = 512
epochs = 200
sample_size = 64 # fixed sample size
nz = 128 # latent vector size
k = 1 # number of steps to apply to the discriminator
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# MNIST Dataset
(xtrain, ytrain), (xval, yval), num_cls = load_mnist(final=True)

transform = transforms.Compose([
                                transforms.ToTensor(),
                                transforms.Normalize((0.5,),(0.5,)),
])

to_pil_image = transforms.ToPILImage()
train_data = datasets.MNIST(
    root='input/data',
    train=True,
    download=True,
    transform=transform
)

train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

In [3]:
class Generator(nn.Module):
    def __init__(self, nz):
        super(Generator, self).__init__()
        self.nz = nz
        self.main = nn.Sequential(
            nn.Linear(self.nz, 256),
            nn.LeakyReLU(0.2),

            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),

            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),

            nn.Linear(1024, 784),
            nn.Tanh(),
        )

    def forward(self, x):
        return self.main(x).view(-1, 1, 28, 28)

In [4]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.n_input = 784
        self.main = nn.Sequential(
            nn.Linear(self.n_input, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = x.view(-1, 784)
        return self.main(x)

In [5]:
# to create real labels (1s)
def label_real(size):
    data = torch.ones(size, 1)
    return data.to(device)

In [6]:
# to create fake labels (0s)
def label_fake(size):
    data = torch.zeros(size, 1)
    return data.to(device)

In [7]:

# function to create the noise vector
def create_noise(sample_size, nz):
    return torch.randn(sample_size, nz).to(device)

In [8]:
# function to train the discriminator network
def train_discriminator(optimizer, data_real, data_fake):
    b_size = data_real.size(0)
    real_label = label_real(b_size)
    fake_label = label_fake(b_size)

    optimizer.zero_grad()

    output_real = discriminator(data_real)
    loss_real = criterion(output_real, real_label)

    output_fake = discriminator(data_fake)
    loss_fake = criterion(output_fake, fake_label)


    loss_real.backward()
    loss_fake.backward()
    optimizer.step()

    return loss_real + loss_fake

In [9]:
# function to train the generator network
def train_generator(optimizer, data_fake):
    b_size = data_fake.size(0)
    real_label = label_real(b_size)

    optimizer.zero_grad()

    output = discriminator(data_fake)
    loss = criterion(output, real_label)

    loss.backward()
    optimizer.step()

    return loss   

In [10]:
import time
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [11]:
from pytorch_model_summary import summary 
generator = Generator(nz).to(device)
discriminator = Discriminator().to(device)

# show output shape and hierarchical view of net
print(summary(generator,torch.zeros((1, 1, 128, 128)), show_input=False, show_hierarchical=True))
print(summary(discriminator, torch.zeros((1, 1, 28, 28)), show_input=False, show_hierarchical=True))
#print(generator)
#print(discriminator)


# optimizers
optim_g = optim.Adam(generator.parameters(), lr=0.0002)
optim_d = optim.Adam(discriminator.parameters(), lr=0.0002)

# loss function
criterion = nn.BCELoss()

losses_g = [] # to store generator loss after each epoch
losses_d = [] # to store discriminator loss after each epoch
images = [] # to store images generatd by the generator

def train():
    # create the noise vector
    noise = create_noise(sample_size, nz)

    generator.train()
    discriminator.train()

    for epoch in range(epochs):
        start_time = time.time()
        loss_g = 0.0
        loss_d = 0.0
        for bi, data in tqdm(enumerate(train_loader), total=int(len(train_data)/train_loader.batch_size)):
            image, _ = data
            image = image.to(device)
            b_size = len(image)
            # run the discriminator for k number of steps
            for step in range(k):
                data_fake = generator(create_noise(b_size, nz)).detach()
                data_real = image
                # train the discriminator network
                loss_d += train_discriminator(optim_d, data_real, data_fake)
            data_fake = generator(create_noise(b_size, nz))
            # train the generator network
            loss_g += train_generator(optim_g, data_fake)

        # create the final fake image for the epoch
        generated_img = generator(noise).cpu().detach()
        # make the images as grid
        generated_img = make_grid(generated_img)
        images.append(generated_img)
        epoch_loss_g = loss_g / bi # loss generator 
        epoch_loss_d = loss_d / bi # loss discriminator
        losses_g.append(epoch_loss_g)
        losses_d.append(epoch_loss_d)

        #SAVE IMAGE
        imgs = [np.array(to_pil_image(img)) for img in images]
        from PIL import Image
        im = Image.fromarray(imgs[epoch])
        im.save("Images_output/epoch_"+str(epoch)+".jpeg")

        end_time = time.time()
        epoch_mins, epoch_secs = epoch_time(start_time, end_time)
        print(f'Epoch: {epoch+1:02} of {epochs}| Epoch Time: {epoch_mins}m {epoch_secs}s')
        print(f"Generator loss: {epoch_loss_g:.8f}, Discriminator loss: {epoch_loss_d:.8f}")
        print(25*'==')

    return images

-------------------------------------------------------------------------
      Layer (type)          Output Shape         Param #     Tr. Param #
          Linear-1      [1, 1, 128, 256]          33,024          33,024
       LeakyReLU-2      [1, 1, 128, 256]               0               0
          Linear-3      [1, 1, 128, 512]         131,584         131,584
       LeakyReLU-4      [1, 1, 128, 512]               0               0
          Linear-5     [1, 1, 128, 1024]         525,312         525,312
       LeakyReLU-6     [1, 1, 128, 1024]               0               0
          Linear-7      [1, 1, 128, 784]         803,600         803,600
            Tanh-8      [1, 1, 128, 784]               0               0
Total params: 1,493,520
Trainable params: 1,493,520
Non-trainable params: 0
-------------------------------------------------------------------------



Generator(
  (main): Sequential(
    (0): Linear(in_features=128, out_features=256, bias=True), 33,024 params
    (1

In [12]:
images = train()

118it [02:08,  1.09s/it]                         
  0%|          | 0/117 [00:00<?, ?it/s]

Epoch: 01 of 200| Epoch Time: 2m 8s
Generator loss: 1.25310433, Discriminator loss: 0.94914639


118it [01:09,  1.71it/s]                         
  0%|          | 0/117 [00:00<?, ?it/s]

Epoch: 02 of 200| Epoch Time: 1m 9s
Generator loss: 2.20587516, Discriminator loss: 1.36936140


118it [01:10,  1.68it/s]                         
  0%|          | 0/117 [00:00<?, ?it/s]

Epoch: 03 of 200| Epoch Time: 1m 10s
Generator loss: 3.73245740, Discriminator loss: 0.72168392


118it [02:25,  1.23s/it]                         
  0%|          | 0/117 [00:00<?, ?it/s]

Epoch: 04 of 200| Epoch Time: 2m 25s
Generator loss: 1.23195577, Discriminator loss: 1.23472238


118it [01:46,  1.11it/s]                         
  0%|          | 0/117 [00:00<?, ?it/s]

Epoch: 05 of 200| Epoch Time: 1m 46s
Generator loss: 4.06655645, Discriminator loss: 1.06901169


118it [02:05,  1.06s/it]                         
  0%|          | 0/117 [00:00<?, ?it/s]

Epoch: 06 of 200| Epoch Time: 2m 5s
Generator loss: 1.41932285, Discriminator loss: 1.35428214


 91%|█████████ | 106/117 [01:29<00:09,  1.18it/s]


KeyboardInterrupt: 

In [None]:
imgs = [np.array(to_pil_image(img)) for img in images]

imageio.mimsave('Images_output/generator_images.gif', imgs)
from PIL import Image
im = Image.fromarray(imgs[0])
im.save("Images_output/epoch_0.jpeg")