<a href="https://colab.research.google.com/github/Jeevan008/GANs/blob/main/GANs_G1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [28]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, utils
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import os

import warnings
warnings.filterwarnings('ignore')

In [29]:
# Generator: turns noise vectors into flattened fake images.
class Generator(nn.Module):
    """Two-layer MLP mapping a noise vector to a flattened fake image.

    Args:
        noise_dim: Length of the noise vector passed to ``forward``. It is the
            input width of the first layer unless ``input_size`` is given.
        input_size: Optional explicit input width, kept for backward
            compatibility. When ``None`` (the default) it equals ``noise_dim``.
        output_size: Length of the flattened output image (28 * 28 = 784).
    """

    def __init__(self, noise_dim, input_size=None, output_size=784):
        super().__init__()
        # Previously ``noise_dim`` was ignored and the input width was a
        # hard-coded 100, so any noise_dim other than 100 broke ``forward``.
        if input_size is None:
            input_size = noise_dim
        self.model = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.ReLU(),
            nn.Linear(256, output_size),
            nn.Tanh(),  # outputs lie in [-1, 1]
        )

    def forward(self, z):
        """Map ``z`` (last dim = input width) to images of length ``output_size``."""
        return self.model(z)




In [30]:
# Discriminator: scores a flattened image with the probability it is real.
class Discriminator(nn.Module):
    """MLP that maps a flattened image to a single probability in (0, 1).

    Args:
        input_size: Length of each flattened input image (28 * 28 = 784).
    """

    def __init__(self, input_size=784):
        super().__init__()
        hidden_width = 256
        layers = [
            nn.Linear(input_size, hidden_width),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_width, 1),
            nn.Sigmoid(),  # squashes the single logit into a probability
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return one probability per row of ``x``."""
        return self.model(x)

In [31]:
# Noise generator
def generator_noise(batch_size, z_dim):
    """Draw a ``(batch_size, z_dim)`` tensor of standard-normal noise."""
    shape = (batch_size, z_dim)
    return torch.randn(shape)


In [32]:
# Load the MNIST training split. ToTensor yields pixels in [0, 1];
# Normalize(0.5, 0.5) computes (x - 0.5) / 0.5, mapping them to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

dataset = datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
dataloader = DataLoader(dataset, shuffle=True, batch_size=64)

In [33]:
# Initialize the models, optimizers and loss function.
z_dim = 100
G = Generator(z_dim)
D = Discriminator()
# The training loop updates G and D, so the optimizers must own *their*
# parameters. Previously two separate models were built here and the
# optimizers were attached to those, so G and D never learned.
# The old names are kept as aliases of the trained models.
generator = G
discriminator = D
criterion = nn.BCELoss()  # D ends in Sigmoid, so it already outputs probabilities
G_opt = optim.Adam(G.parameters(), lr=0.0002)
D_opt = optim.Adam(D.parameters(), lr=0.0002)

In [35]:
# Training loop: alternate one discriminator step and one generator step per
# batch, report average losses, and save a sample grid after every epoch.
epochs = 5
# Same noise every epoch so the saved snapshots are directly comparable.
fixed_noise = generator_noise(16, z_dim)
os.makedirs('gan_outputs', exist_ok=True)

for epoch in range(epochs):
    D_loss_total = 0.0
    G_loss_total = 0.0
    num_batches = 0

    for real_images, _ in dataloader:
        batch_size = real_images.size(0)
        real_images = real_images.view(batch_size, -1)  # flatten to (batch, 784)

        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)

        # Train the discriminator; detach so this step does not backprop into G.
        noise = generator_noise(batch_size, z_dim)
        fake_images = G(noise).detach()
        D_loss_real = criterion(D(real_images), real_labels)
        D_loss_fake = criterion(D(fake_images), fake_labels)
        D_loss = D_loss_real + D_loss_fake

        D.zero_grad()
        D_loss.backward()
        D_opt.step()

        # Train the generator: reward fakes that D classifies as real.
        noise = generator_noise(batch_size, z_dim)
        fake_images = G(noise)
        G_loss = criterion(D(fake_images), real_labels)

        G.zero_grad()
        G_loss.backward()
        G_opt.step()

        D_loss_total += D_loss.item()
        G_loss_total += G_loss.item()
        num_batches += 1

    # Report per-epoch averages so progress (or a stalled model) is visible.
    if num_batches:
        print(
            f"Epoch {epoch+1}/{epochs}  "
            f"D_loss: {D_loss_total / num_batches:.4f}  "
            f"G_loss: {G_loss_total / num_batches:.4f}"
        )

    with torch.no_grad():
        samples = G(fixed_noise).view(-1, 1, 28, 28)
        grid = utils.make_grid(samples, nrow=4, normalize=True)

    # Draw on a dedicated figure and close it by handle rather than relying on
    # pyplot's implicit "current figure" state.
    fig, ax = plt.subplots()
    ax.imshow(grid.permute(1, 2, 0).cpu())
    ax.axis('off')
    ax.set_title(f"Epoch {epoch+1}")
    fig.savefig(f"gan_outputs/epoch_{epoch+1}.png")
    plt.close(fig)

