## Implementation of a vanilla GAN (Generative Adversarial Network)

In [21]:
import torch 
import torch.nn as nn
import torch.functional as f
import torch.optim as optim
from torch.utils.data import DataLoader 
import torchvision
import torchvision.transforms as transformers
import torchvision.datasets as datasets

import matplotlib.pyplot as plt

In [2]:
# Pick the compute device: use the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Available device is: {device}')

Available device is: cuda


In [4]:
# Preprocessing applied to every MNIST image: convert to a tensor, then
# normalise the single channel with mean 0.5 and std 0.5.
transform = transformers.Compose([
    transformers.ToTensor(),
    # Both mean and std are written as 1-tuples (one entry per channel).
    transformers.Normalize((0.5,), (0.5,)),
])

# loading dataset
# NOTE(review): absolute, machine-specific path kept as-is so the already
# downloaded files are reused; consider a project-relative data directory.
DATA_ROOT = '/home/aniket-kalra/BS/GenAI_repo/data'

MNIST_train = datasets.MNIST(root=DATA_ROOT, train=True, download=True, transform=transform)
# The test split gets the same transform as the train split; without it the
# dataset yields raw PIL images instead of normalised tensors.
MNIST_test = datasets.MNIST(root=DATA_ROOT, train=False, download=True, transform=transform)

100%|██████████| 9.91M/9.91M [00:06<00:00, 1.52MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 113kB/s]
100%|██████████| 1.65M/1.65M [00:04<00:00, 379kB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 7.53MB/s]


In [10]:
# Serve the training split in shuffled mini-batches of 64 samples.
train_loader = DataLoader(
    dataset=MNIST_train,
    batch_size=64,
    shuffle=True,
)
train_loader  # bare last expression so the notebook displays the loader


<torch.utils.data.dataloader.DataLoader at 0x78d125fefcb0>

In [14]:
class Generator(nn.Module):
    """Fully connected generator: noise vector in, flattened image out.

    The network is a stack of ``Linear`` + ``ReLU`` pairs whose widths grow
    from 128 up to 4096 and shrink back to 1024, followed by a final
    ``Linear`` to ``img_dim`` and a ``Tanh`` that bounds outputs to [-1, 1].

    Args:
        noise_dim: Size of the input noise vector.
        img_dim: Size of the flattened output image.
    """

    # Output widths of the hidden Linear layers, in order.
    _HIDDEN_WIDTHS = (128, 256, 512, 1024, 2048, 4096, 2048, 1024)

    def __init__(self, noise_dim, img_dim):
        super().__init__()
        stack = []
        fan_in = noise_dim
        for width in self._HIDDEN_WIDTHS:
            stack.append(nn.Linear(fan_in, width))
            stack.append(nn.ReLU())
            fan_in = width
        # Output head: project to the image size and squash into [-1, 1].
        stack.append(nn.Linear(fan_in, img_dim))
        stack.append(nn.Tanh())
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        return self.model(x)
    

class Discriminator(nn.Module):
    """Fully connected discriminator: flattened image in, one probability out.

    The network is a stack of ``Linear`` + ``ReLU`` pairs that halves the
    width from 1024 down to 4, followed by a single-unit ``Linear`` and a
    ``Sigmoid``. The output has shape ``(N, 1)`` so it lines up with the
    ``(batch_size, 1)`` real/fake label tensors used with ``nn.BCELoss``,
    which requires input and target to have the same shape.

    Args:
        img_dim: Size of the flattened input image.
    """

    def __init__(self, img_dim):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(img_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Linear(16, 8),
            nn.ReLU(),
            nn.Linear(8, 4),
            nn.ReLU(),
            # One output unit: a single real-vs-fake score per sample
            # (was 2, which cannot be compared against (N, 1) labels).
            nn.Linear(4, 1),
            nn.Sigmoid(),
        )

    def forward(self, z):
        """Return the ``(N, 1)`` sigmoid score for each flattened image in ``z``."""
        return self.model(z)

In [19]:
# Sizes: length of the noise vector and of a flattened 28x28 image.
noise_dim = 100
img_dim = 28 * 28

# Build both networks on the selected device (generator first, then discriminator).
generator = Generator(noise_dim, img_dim).to(device)
discriminator = Discriminator(img_dim).to(device)

# loss
criterion = nn.BCELoss()

# optimizers: one Adam instance per network, same learning rate
d_optim = optim.Adam(discriminator.parameters(), lr=2e-4)
g_optim = optim.Adam(generator.parameters(), lr=2e-4)


In [22]:
def show_generated_images(epoch, generator, fixed_noise):
    """Plot a grid of images produced by ``generator`` from ``fixed_noise``.

    Args:
        epoch: Epoch identifier; shown in the figure title.
        generator: Module whose output for ``fixed_noise`` can be reshaped to
            ``(-1, 1, 28, 28)``.
        fixed_noise: Batch of noise vectors fed to ``generator``.

    The generator is switched to eval mode while sampling and is put back in
    whatever mode (train or eval) it was in on entry, even if plotting fails.
    """
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            fake_imgs = generator(fixed_noise).reshape(-1, 1, 28, 28)
            # Map values from [-1, 1] to [0, 1] for display.
            fake_imgs = fake_imgs * 0.5 + 0.5
        grid = torchvision.utils.make_grid(fake_imgs, nrow=8)
        plt.figure(figsize=(8, 8))
        # Channels-last layout on the host as a NumPy array, as imshow expects.
        # (`.cpu()`; the original `.cup()` raised AttributeError.)
        plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
        plt.title(f'Generated samples - epoch {epoch}')
        plt.axis('off')
        plt.show()
    finally:
        # Restore the caller's mode instead of unconditionally forcing train mode.
        generator.train(was_training)

In [None]:

def train_gan(train_loader, epochs, mode='one_on_one'):
    """Training-loop scaffold for the GAN (unfinished).

    As written, this only iterates over ``train_loader`` for ``epochs``
    passes, flattens each batch of real images, and builds the real/fake
    label tensors. No discriminator or generator update is performed yet,
    and nothing is returned.

    Args:
        train_loader: Iterable yielding ``(images, labels)`` batches; the
            labels are discarded.
        epochs: Number of passes over ``train_loader``.
        mode: Not read anywhere in the current body.
            NOTE(review): the intended meaning of 'one_on_one' is not
            visible here — confirm once the loop is completed.
    """
    # Batch of 64 noise vectors drawn once, before the epoch loop.
    # NOTE(review): not used below yet — presumably meant for
    # show_generated_images; confirm.
    fixed_noise = torch.randn(64, noise_dim).to(device)
    for epoch in range(epochs):
        for batch_idx , (real, _) in enumerate(train_loader):
            batch_size = real.size(0)
            # Flatten every sample to a single row and move it to `device`.
            real = real.view(batch_size, -1).to(device)

            # Create real and fake labels: ones for real, zeros for fake,
            # each of shape (batch_size, 1).
            real_labels = torch.ones(batch_size, 1, device= device)
            fake_labels = torch.zeros(batch_size, 1, device= device)

            ### Train Discriminator
            # TODO: discriminator update step is not implemented yet.

## TODO: finish `train_gan` — the discriminator and generator update steps are still to be written.