# GAN ( Generative Adversarial Networks)

### Importing packages

In [1]:
#!pip install torchvision

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

#set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### Loading the Dataset

In [2]:
train_dataset = datasets.CIFAR10(root='./data',train=True, download=True, transform=ToTensor())
dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True,num_workers=4,\
                                             pin_memory=True)

Files already downloaded and verified


In [3]:
classes = train_dataset.classes
classes

['airplane',
 'automobile',
 'bird',
 'cat',
 'deer',
 'dog',
 'frog',
 'horse',
 'ship',
 'truck']

In [4]:
class_count = {}
for _, index in train_dataset:
    label = classes[index]
    if label not in class_count:
        class_count[label] = 0
    class_count[label] += 1
class_count

{'frog': 5000,
 'truck': 5000,
 'deer': 5000,
 'automobile': 5000,
 'bird': 5000,
 'horse': 5000,
 'ship': 5000,
 'cat': 5000,
 'dog': 5000,
 'airplane': 5000}

In [None]:
for images, _ in dataloader:
    print('images.shape:', images.shape)
    plt.figure(figsize=(16,8))
    plt.axis('off')
    plt.imshow(make_grid(images, nrow=16).permute((1, 2, 0)))
    break

images.shape: torch.Size([32, 3, 32, 32])


### Defining parameters to be used(hyperparameters)

In [5]:
latent_dim =100
lr = 0.0002
beta1= 0.5
beta2 = 0.9999
num_epochs = 10

### Utility class for Building the Generator

In [6]:
class Generator(nn.Module):
    def __init__(self, latent_dim):
        super(Generator, self).__init__()
        
        self.model = nn.Sequential(
        nn.Linear(latent_dim, 128*8*8),
            nn.ReLU(),
            nn.Unflatten(1, (128,8,8)),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128, momentum=0.78),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64, momentum=0.78),
            nn.ReLU(),
            nn.Conv2d(64, 3, kernel_size=3, padding=1),
            nn.Tanh()       
        )

    def forward(self, z):
        img = self.model(z)
        return img

### Utility class for Building the Discriminator

In [7]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
 
        self.model = nn.Sequential(
        nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
        nn.LeakyReLU(0.2),
        nn.Dropout(0.25),
        nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
        nn.ZeroPad2d((0, 1, 0, 1)),
        nn.BatchNorm2d(64, momentum=0.82),
        nn.LeakyReLU(0.25),
        nn.Dropout(0.25),
        nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
        nn.BatchNorm2d(128, momentum=0.82),
        nn.LeakyReLU(0.2),
        nn.Dropout(0.25),
        nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(256, momentum=0.8),
        nn.LeakyReLU(0.25),
        nn.Dropout(0.25),
        nn.Flatten(),
        nn.Linear(256 * 5 * 5, 1),
        nn.Sigmoid()
    )
 
    def forward(self, img):
        validity = self.model(img)
        return validity

### Building the GAN

In [8]:
#Define and initialize Generator and Discriminator
generator = Generator(latent_dim).to(device)
discriminator = Discriminator().to(device)

#Loss function
adversarial_loss = nn.BCELoss()

#Optimizers
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, beta2))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, beta2))

### Training the Generative Adversarial Network

In [None]:
for epoch in range(num_epochs):
    for i, batch in enumerate(dataloader):
        #Convert list to tensor
        real_images = batch[0].to(device)
        
        #Adversarial ground truths
        valid = torch.ones(real_images.size(0), 1, device=device)
        fake = torch.zeros(real_images.size(0), 1, device=device)
        
        #Configure input
        real_images = real_images.to(device)
        
         # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()
        
        #sample noise as generator input
        z = torch.randn(real_images.size(0), latent_dim, device=device)
        
               # Generate a batch of images
        fake_images = generator(z)
 
        # Measure discriminator's ability
        # to classify real and fake images
        real_loss = adversarial_loss(discriminator\
                                     (real_images), valid)
        fake_loss = adversarial_loss(discriminator\
                                     (fake_images.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
 
        # Backward pass and optimize
        d_loss.backward()
        optimizer_D.step()
 
        # -----------------
        #  Train Generator
        # -----------------
 
        optimizer_G.zero_grad()
 
        # Generate a batch of images
        gen_images = generator(z)
 
        # Adversarial loss
        g_loss = adversarial_loss(discriminator(gen_images), valid)
 
        # Backward pass and optimize
        g_loss.backward()
        optimizer_G.step()
 
        # ---------------------
        #  Progress Monitoring
        # ---------------------
 
        if (i + 1) % 100 == 0:
            print(
                f"Epoch [{epoch+1}/{num_epochs}]\
                        Batch {i+1}/{len(dataloader)} "
                f"Discriminator Loss: {d_loss.item():.4f} "
                f"Generator Loss: {g_loss.item():.4f}"
            )
 
    # Save generated images for every epoch
    if (epoch + 1) % 10 == 0:
        with torch.no_grad():
            z = torch.randn(16, latent_dim, device=device)
            generated = generator(z).detach().cpu()
            grid = torchvision.utils.make_grid(generated,\
                                        nrow=4, normalize=True)
            plt.imshow(np.transpose(grid, (1, 2, 0)))
            plt.axis("off")
            plt.show()
        

Epoch [1/10]                        Batch 100/1563 Discriminator Loss: 0.6232 Generator Loss: 1.1653
Epoch [1/10]                        Batch 200/1563 Discriminator Loss: 0.6602 Generator Loss: 0.9345
Epoch [1/10]                        Batch 300/1563 Discriminator Loss: 0.7164 Generator Loss: 0.8172
Epoch [1/10]                        Batch 400/1563 Discriminator Loss: 0.6349 Generator Loss: 0.8141
Epoch [1/10]                        Batch 500/1563 Discriminator Loss: 0.7524 Generator Loss: 0.8042
Epoch [1/10]                        Batch 600/1563 Discriminator Loss: 0.6341 Generator Loss: 0.9826
Epoch [1/10]                        Batch 700/1563 Discriminator Loss: 0.6478 Generator Loss: 1.1971
Epoch [1/10]                        Batch 800/1563 Discriminator Loss: 0.5715 Generator Loss: 0.9058
Epoch [1/10]                        Batch 900/1563 Discriminator Loss: 0.5660 Generator Loss: 0.9817
Epoch [1/10]                        Batch 1000/1563 Discriminator Loss: 0.5913 Generator Lo

Epoch [6/10]                        Batch 700/1563 Discriminator Loss: 0.5434 Generator Loss: 0.7162
Epoch [6/10]                        Batch 800/1563 Discriminator Loss: 0.7455 Generator Loss: 0.7814
Epoch [6/10]                        Batch 900/1563 Discriminator Loss: 0.5010 Generator Loss: 1.1024
Epoch [6/10]                        Batch 1000/1563 Discriminator Loss: 0.8211 Generator Loss: 0.7927
Epoch [6/10]                        Batch 1100/1563 Discriminator Loss: 0.6854 Generator Loss: 1.2134
Epoch [6/10]                        Batch 1200/1563 Discriminator Loss: 0.6860 Generator Loss: 0.9473
Epoch [6/10]                        Batch 1300/1563 Discriminator Loss: 0.6295 Generator Loss: 1.0821
Epoch [6/10]                        Batch 1400/1563 Discriminator Loss: 0.5734 Generator Loss: 1.1597
Epoch [6/10]                        Batch 1500/1563 Discriminator Loss: 0.7155 Generator Loss: 1.0721
Epoch [7/10]                        Batch 100/1563 Discriminator Loss: 0.5986 Generat