In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import Dataset
from torchvision import datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed PyTorch's RNGs (torch.manual_seed seeds every device) so weight init,
# DataLoader shuffling and sampled noise repeat across Restart & Run All.
# NOTE: bitwise-identical GPU results are still not guaranteed by seeding alone.
SEED = 42
torch.manual_seed(SEED)

In [None]:
from torch.utils.data import DataLoader

# Normalize(mean=0.5, std=0.5) maps each pixel x to (x - 0.5) / 0.5, i.e. the
# [0, 1] range produced by ToTensor onto [-1, 1] — the same range as the
# generator's Tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# MNIST training split, downloaded into ./data the first time this runs.
training_data = torchvision.datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=transform,
)

# Reshuffled every epoch; 128 images per batch.
train_loader = DataLoader(dataset=training_data, batch_size=128, shuffle=True)

In [None]:
class Generator(nn.Module):
  """MLP generator mapping a noise vector to a flattened image.

  Architecture: Linear(z_dim -> 256) -> ReLU -> Linear(256 -> 1024) -> ReLU ->
  Linear(1024 -> img_dim) -> Tanh, so every output value lies in [-1, 1].

  Args:
    z_dim: size of the input noise vector (last dimension of ``z``).
    img_dim: number of output features per sample (e.g. 28 * 28).
  """

  def __init__(self, z_dim, img_dim):
    super().__init__()
    hidden_small, hidden_large = 256, 1024
    # The container stays under the attribute name `gen` so that state_dict
    # keys (gen.0.weight, gen.2.weight, gen.4.weight, ...) are unchanged.
    self.gen = nn.Sequential(
        nn.Linear(in_features=z_dim, out_features=hidden_small),
        nn.ReLU(inplace=True),
        nn.Linear(in_features=hidden_small, out_features=hidden_large),
        nn.ReLU(inplace=True),
        nn.Linear(in_features=hidden_large, out_features=img_dim),
        nn.Tanh(),
    )

  def forward(self, z):
    """Return generated samples of shape (..., img_dim) for noise ``z`` of shape (..., z_dim)."""
    samples = self.gen(z)
    return samples

In [None]:
class Discriminator(nn.Module):
  """MLP discriminator that scores flattened images as real vs. fake.

  Architecture: Linear(img_dim -> 512) -> LeakyReLU(0.2) -> Linear(512 -> 256) ->
  LeakyReLU(0.2) -> Linear(256 -> 1) -> Sigmoid, so each input row yields a
  single probability in [0, 1].

  Args:
    img_dim: number of input features per image (e.g. 28 * 28).
  """

  def __init__(self, img_dim):
    super().__init__()
    slope = 0.2  # negative slope shared by both LeakyReLU activations
    layers = [
        nn.Linear(in_features=img_dim, out_features=512),
        nn.LeakyReLU(negative_slope=slope, inplace=True),
        nn.Linear(in_features=512, out_features=256),
        nn.LeakyReLU(negative_slope=slope, inplace=True),
        nn.Linear(in_features=256, out_features=1),
        nn.Sigmoid(),
    ]
    # Attribute name `disc` is kept so state_dict keys are unchanged.
    self.disc = nn.Sequential(*layers)

  def forward(self, img):
    """Return probabilities of shape (..., 1) for ``img`` of shape (..., img_dim)."""
    return self.disc(img)


In [None]:
noise_dim = 100        # length of the latent vector z fed to the generator
img_dim = 28 * 28      # flattened size of one 28x28 single-channel MNIST image

generator = Generator(noise_dim,img_dim).to(device)
discriminator = Discriminator(img_dim).to(device)

# Adam settings recommended for GAN training by the DCGAN paper (Radford et al.):
# lr = 2e-4 and beta1 = 0.5. The paper reports that the default beta1 = 0.9
# led to oscillation/instability of the adversarial training.
LR = 0.0002
BETAS = (0.5, 0.999)
g_optimizer = optim.Adam(generator.parameters(), lr=LR, betas=BETAS)
d_optimizer = optim.Adam(discriminator.parameters(), lr=LR, betas=BETAS)

# Binary cross-entropy on the discriminator's sigmoid output
# (targets: 1 for real images, 0 for generated ones).
criterion = nn.BCELoss()

In [None]:
# The models were already moved to `device` when they were constructed, and
# nn.Module.to() works in place, so moving them again would be a no-op.
# Verify the placement instead.
assert all(p.device.type == device.type for p in generator.parameters()), "generator is not on `device`"
assert all(p.device.type == device.type for p in discriminator.parameters()), "discriminator is not on `device`"

In [None]:
def show_generated_images(epoch,generator,fixed_noise):
  """Plot an 8-column grid of generator samples for a fixed noise batch.

  The generator runs in eval mode under ``torch.no_grad()``. Its previous
  train/eval mode is restored afterwards, even if plotting raises.

  Args:
    epoch: epoch number shown in the figure title.
    generator: model whose outputs are reshaped to (N, 1, 28, 28); presumably
      Tanh-activated values in [-1, 1] — TODO confirm for other generators.
    fixed_noise: noise batch on the generator's device; reusing the same
      tensor across calls makes samples comparable between epochs.
  """
  was_training = generator.training
  generator.eval()
  try:
    with torch.no_grad():
      fake_imgs = generator(fixed_noise).reshape(-1,1,28,28)
      # Map [-1, 1] back to [0, 1] for display; the clamp is purely defensive
      # so imshow never receives floats outside its valid RGB range.
      fake_imgs = (fake_imgs * 0.5 + 0.5).clamp(0, 1)
    grid = torchvision.utils.make_grid(fake_imgs,nrow=8)
    fig, ax = plt.subplots(figsize=(8,8))
    ax.imshow(grid.permute(1,2,0).cpu().numpy())
    ax.set_title(f'Generated Images at epoch {epoch}')
    ax.axis('off')
    plt.show()
    plt.close(fig)  # avoid accumulating open figures when called in a loop
  finally:
    # Restore the caller's mode instead of unconditionally forcing train().
    generator.train(was_training)


In [None]:
def train_gan(train_loader,num_epochs,num_gen=1,num_disc=1,log_every=10):
  """Adversarially train the module-level ``generator`` / ``discriminator``.

  For every real batch the discriminator takes ``num_disc`` optimizer steps,
  then the generator takes ``num_gen`` steps. Every ``log_every`` epochs the
  epoch-averaged losses are printed, and samples from one fixed noise batch
  are plotted so progress is comparable across epochs.

  Relies on globals defined in earlier cells: ``generator``, ``discriminator``,
  ``g_optimizer``, ``d_optimizer``, ``criterion``, ``noise_dim``, ``device``
  and ``show_generated_images``.

  Args:
    train_loader: iterable of ``(images, labels)`` batches. Images are
      flattened to ``(batch, -1)``; labels are ignored.
    num_epochs: number of passes over ``train_loader``.
    num_gen: generator updates per real batch.
    num_disc: discriminator updates per real batch.
    log_every: print/plot interval in epochs (positive int, default 10).
  """
  fixed_noise = torch.randn(64,noise_dim,device=device)
  for epoch in range(num_epochs):
    # Sum detached loss tensors and call .item() once per epoch, so the loop
    # does not force a host/device sync on every step.
    d_loss_sum = torch.zeros((),device=device)
    g_loss_sum = torch.zeros((),device=device)
    d_steps = g_steps = 0

    for real,_ in train_loader:
      batch_size = real.shape[0]
      real = real.view(batch_size,-1).to(device)
      real_labels = torch.ones(batch_size,1,device=device)
      fake_labels = torch.zeros(batch_size,1,device=device)

      for _ in range(num_disc):
        #  1. Train Discriminator: push D(real) -> 1 and D(G(z)) -> 0.
        d_real_loss = criterion(discriminator(real),real_labels)

        z = torch.randn(batch_size,noise_dim,device=device)
        # The generator is not updated in this step, so don't build its graph.
        with torch.no_grad():
          fake_imgs = generator(z)
        d_fake_loss = criterion(discriminator(fake_imgs),fake_labels)

        d_loss = d_real_loss + d_fake_loss
        discriminator.zero_grad()
        d_loss.backward()
        d_optimizer.step()
        d_loss_sum += d_loss.detach()
        d_steps += 1

      for _ in range(num_gen):
        #  2. Train Generator with the non-saturating loss: BCE against
        #  real_labels (1) equals -log D(G(z)). Its gradient stays large when
        #  D confidently rejects the fakes, unlike minimising log(1 - D(G(z))).
        z = torch.randn(batch_size,noise_dim,device=device)
        outputs = discriminator(generator(z))
        g_loss = criterion(outputs,real_labels)
        # backward() also writes gradients into the discriminator; those are
        # cleared by discriminator.zero_grad() before its next update.
        generator.zero_grad()
        g_loss.backward()
        g_optimizer.step()
        g_loss_sum += g_loss.detach()
        g_steps += 1

    if (epoch + 1) % log_every == 0:
      # NaN (instead of a NameError) when no step of that kind was taken.
      d_avg = (d_loss_sum / d_steps).item() if d_steps else float('nan')
      g_avg = (g_loss_sum / g_steps).item() if g_steps else float('nan')
      print(f"Epoch : [{epoch+1}/{num_epochs}] , D_Loss (avg) : {d_avg:.4f}, G_Loss (avg) : {g_avg:.4f}")
      show_generated_images(epoch+1,generator,fixed_noise)

In [None]:
# Full training run: 50 epochs with one discriminator update and one
# generator update per real batch.
train_gan(
    train_loader,
    num_epochs=50,
    num_disc=1,
    num_gen=1,
)

In [None]:
# Visualizing Saturating vs Non-Saturating GAN Losses and Their Gradients
import numpy as np
import matplotlib.pyplot as plt

# Grid of discriminator outputs D(G(z)); the endpoints 0 and 1 are excluded so
# both logarithms stay finite.
d = np.linspace(1e-5, 1 - 1e-5, 500)

# Generator losses as functions of the discriminator's probability for a fake
saturating_loss = np.log(1 - d)                 # minimax loss: minimise log(1 - D(G(z)))
non_saturating_loss = -np.log(d)                # non-saturating loss: minimise -log D(G(z))

# |dL/dD|: slope with respect to the probability D itself
grad_saturating = 1 / (1 - d)
grad_non_saturating = 1 / d

# |dL/da|: slope with respect to the discriminator's pre-sigmoid logit a, where
# D = sigmoid(a) and dD/da = D(1 - D). Because the discriminator ends in a
# Sigmoid, this is the gradient that actually flows back to the generator:
#   d/da  log(1 - sigmoid(a)) = -D        -> vanishes as D(G(z)) -> 0 ("saturates")
#   d/da -log(sigmoid(a))     = -(1 - D)  -> stays near 1 as D(G(z)) -> 0
grad_logit_saturating = d
grad_logit_non_saturating = 1 - d

# ---- Plot 1: Loss curves ----
fig, ax = plt.subplots()
ax.plot(d, saturating_loss, label="Saturating Loss: log(1 - D(G(z)))")
ax.plot(d, non_saturating_loss, label="Non-Saturating Loss: -log(D(G(z)))")
ax.set(xlabel="D(G(z))", ylabel="Loss", title="Generator Loss Functions")
ax.legend()
plt.show()

# ---- Plot 2: Gradient magnitude w.r.t. D(G(z)) ----
fig, ax = plt.subplots()
ax.plot(d, grad_saturating, label="Gradient (Saturating Loss)")
ax.plot(d, grad_non_saturating, label="Gradient (Non-Saturating Loss)")
ax.set(xlabel="D(G(z))", ylabel="Gradient Magnitude",
       title="Gradient Behavior of Generator Losses")
ax.set_ylim(0, 20)
ax.legend()
plt.show()

# ---- Plot 3: Gradient magnitude w.r.t. the discriminator logit ----
fig, ax = plt.subplots()
ax.plot(d, grad_logit_saturating, label="Saturating: |dL/da| = D(G(z))")
ax.plot(d, grad_logit_non_saturating, label="Non-Saturating: |dL/da| = 1 - D(G(z))")
ax.set(xlabel="D(G(z))", ylabel="Gradient Magnitude w.r.t. logit",
       title="Gradient Through the Discriminator's Sigmoid")
ax.legend()
plt.show()



In [None]:
# Tangent lines of the saturating vs non-saturating generator losses at a
# small D(G(z)), i.e. where the discriminator confidently rejects the fakes.
import numpy as np
import matplotlib.pyplot as plt

# Grid of discriminator outputs D(G(z)); endpoints excluded so log stays finite
d = np.linspace(1e-5, 1 - 1e-5, 500)

saturating_loss = np.log(1 - d)          # minimax loss log(1 - D)
non_saturating_loss = -np.log(d)         # non-saturating loss -log(D)

# Evaluation point: the discriminator gives the fake only 5% "real" probability
d0 = 0.05

# Loss value and exact slope at d0 for each curve:
#   d/dx  log(1 - x) = -1 / (1 - x)
#   d/dx -log(x)     = -1 / x
loss0_sat, grad_sat = np.log(1 - d0), -1 / (1 - d0)
loss0_ns, grad_ns = -np.log(d0), -1 / d0

# Draw each tangent over a small window of half-width 0.03 around d0
tangent_x = np.linspace(d0 - 0.03, d0 + 0.03, 100)
tangent_sat = loss0_sat + grad_sat * (tangent_x - d0)
tangent_ns = loss0_ns + grad_ns * (tangent_x - d0)

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(d, saturating_loss, label="Saturating Loss", linewidth=2)
ax.plot(d, non_saturating_loss, label="Non-Saturating Loss", linewidth=2)
ax.plot(tangent_x, tangent_sat, "--", label="Tangent (Saturating)", linewidth=2)
ax.plot(tangent_x, tangent_ns, "--", label="Tangent (Non-Saturating)", linewidth=2)

# Mark the points of tangency
ax.scatter([d0], [loss0_sat], s=60)
ax.scatter([d0], [loss0_ns], s=60)

ax.set_xlabel("D(G(z))")
ax.set_ylabel("Loss")
ax.set_title("Slope (Gradient) Comparison at Low D(G(z))")
ax.legend()
plt.show()
