<a href="https://colab.research.google.com/github/Galatonic-rebel/GenAI/blob/main/GenAI_Week_3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Deep Convolutional GANs

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
import os

# ---- Training configuration ----
batch_size = 128  # samples per mini-batch handed to the DataLoader
z_dim = 100       # number of channels in the latent noise tensor
image_size = 28   # target size passed to transforms.Resize
channels = 1      # number of image channels
epochs = 50       # full passes over the training set
lr = 2e-4         # Adam learning rate (same value as 0.0002)
beta1 = 0.5       # first Adam beta coefficient

# Run on the GPU when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Directory that receives the sample grids written during training.
os.makedirs("generated_images", exist_ok=True)

In [None]:
# Preprocessing pipeline: resize to `image_size`, convert to a tensor, then
# apply (x - 0.5) / 0.5 to the single channel.
transform = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

# MNIST training split stored under the current directory (download enabled).
mnist_train = datasets.MNIST(root=".", train=True, download=True, transform=transform)

# Shuffled mini-batches of `batch_size` samples.
train_loader = torch.utils.data.DataLoader(
    mnist_train,
    shuffle=True,
    batch_size=batch_size,
)

In [None]:
class Generator(nn.Module):
    """DCGAN-style generator: latent noise in, 28x28 image in [-1, 1] out.

    Args:
        z_dim: Number of channels of the latent input tensor.
        channels: Number of channels of the generated image. Defaults to 1,
            the value that used to be hard-coded, so existing callers that
            pass only ``z_dim`` get exactly the same architecture.
    """

    def __init__(self, z_dim, channels=1):
        super().__init__()
        # Layer order and positions inside ``self.net`` are kept unchanged so
        # state_dict keys (``net.0.weight`` ...) stay compatible.
        self.net = nn.Sequential(
            # (N, z_dim, 1, 1) -> (N, 256, 7, 7)
            nn.ConvTranspose2d(z_dim, 256, 7, 1, 0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),

            # (N, 256, 7, 7) -> (N, 128, 14, 14)
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),

            # (N, 128, 14, 14) -> (N, channels, 28, 28); Tanh bounds the
            # output to [-1, 1].
            nn.ConvTranspose2d(128, channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, z):
        """Generate images from latent noise.

        Args:
            z: Latent tensor of shape ``(N, z_dim, 1, 1)``.

        Returns:
            Tensor of shape ``(N, channels, 28, 28)`` with values in [-1, 1].
        """
        return self.net(z)

In [None]:
class Discriminator(nn.Module):
    """DCGAN-style discriminator for 28x28 images.

    Args:
        channels: Number of channels of the input images. Defaults to 1, the
            value that used to be hard-coded, so ``Discriminator()`` builds
            exactly the same architecture as before.
    """

    def __init__(self, channels=1):
        super().__init__()
        # Layer order and positions inside ``self.net`` are kept unchanged so
        # state_dict keys stay compatible.
        self.net = nn.Sequential(
            # (N, channels, 28, 28) -> (N, 64, 14, 14)
            nn.Conv2d(channels, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            # (N, 64, 14, 14) -> (N, 128, 7, 7)
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            # (N, 128, 7, 7) -> (N, 128*7*7) -> (N, 1); the linear layer's
            # input size fixes the expected image size to 28x28.
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 1),
            # Sigmoid squashes the score into (0, 1).
            nn.Sigmoid(),
        )

    def forward(self, img):
        """Score a batch of images.

        Args:
            img: Tensor of shape ``(N, channels, 28, 28)``.

        Returns:
            Tensor of shape ``(N, 1)`` with values in [0, 1].
        """
        return self.net(img)

In [None]:
# Build both networks and move them to the selected device. The generator is
# constructed before the discriminator, as before, so parameter
# initialisation consumes the random number generator in the same order.
generator = Generator(z_dim)
generator.to(device)
discriminator = Discriminator()
discriminator.to(device)

# Binary cross-entropy between the discriminator's outputs and the 0/1 labels.
criterion = nn.BCELoss()

# One Adam optimizer per network, sharing learning rate and beta coefficients.
adam_betas = (beta1, 0.999)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=adam_betas)
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=adam_betas)

In [None]:
def generate_and_save_images(epoch):
    """Sample 64 images from the global ``generator`` and save them as a grid.

    The grid is written to ``generated_images/sample_epoch_{epoch}.png`` with
    8 images per row.

    Args:
        epoch: Value interpolated into the output file name.
    """
    # Remember the current mode so it can be restored afterwards instead of
    # unconditionally forcing the generator back into training mode.
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            # Sample directly on the target device (no CPU -> device copy).
            z = torch.randn(64, z_dim, 1, 1, device=device)
            fake_images = generator(z)
            # Undo the (x - 0.5) / 0.5 scaling: [-1, 1] -> [0, 1].
            fake_images = fake_images * 0.5 + 0.5
            save_image(fake_images, f"generated_images/sample_epoch_{epoch}.png", nrow=8)
    finally:
        # Restore the previous mode even if generation or saving fails.
        generator.train(was_training)


In [None]:
# Number of optimizer steps each network takes per mini-batch.
p = 1  # discriminator steps
k = 3  # generator steps

In [26]:

# Adversarial training loop: for every mini-batch the discriminator takes `p`
# optimizer steps and the generator then takes `k`; after each epoch a grid of
# 64 generated samples is written to generated_images/.
for epoch in range(1, epochs + 1):
    for i, (real_imgs, _) in enumerate(train_loader):
        batch_size_curr = real_imgs.size(0)
        real_imgs = real_imgs.to(device)

        # Target labels: 1 for "real", 0 for "fake".
        real = torch.ones(batch_size_curr, 1, device=device)
        fake = torch.zeros(batch_size_curr, 1, device=device)

        # Reset per batch so the log line below never hits an unbound name or
        # reports a stale value when `p` or `k` is 0.
        d_loss = None
        g_loss = None

        ### ---- Train Discriminator p times ---- ###
        for _ in range(p):
            z = torch.randn(batch_size_curr, z_dim, 1, 1, device=device)
            # The discriminator step needs no gradients through the generator,
            # so skip building that autograd graph; the produced values are
            # the same as generating normally and detaching.
            with torch.no_grad():
                fake_imgs = generator(z)

            # Real
            real_validity = discriminator(real_imgs)
            d_real_loss = criterion(real_validity, real)

            # Fake
            fake_validity = discriminator(fake_imgs)
            d_fake_loss = criterion(fake_validity, fake)

            d_loss = d_real_loss + d_fake_loss

            optimizer_D.zero_grad()
            d_loss.backward()
            optimizer_D.step()

        ### ---- Train Generator k times ---- ###
        for _ in range(k):
            z = torch.randn(batch_size_curr, z_dim, 1, 1, device=device)
            fake_imgs = generator(z)

            validity = discriminator(fake_imgs)
            g_loss = criterion(validity, real)  # fool D → label as real

            # backward() also leaves gradients on the discriminator's
            # parameters; they are cleared by optimizer_D.zero_grad() before
            # the next discriminator step, and only optimizer_G steps here.
            optimizer_G.zero_grad()
            g_loss.backward()
            optimizer_G.step()

        if i % 200 == 0:
            d_value = d_loss.item() if d_loss is not None else float("nan")
            g_value = g_loss.item() if g_loss is not None else float("nan")
            print(f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(train_loader)}] "
                  f"D Loss: {d_value:.4f} | G Loss: {g_value:.4f}")

    # Save sample images
    generator.eval()
    with torch.no_grad():
        z = torch.randn(64, z_dim, 1, 1, device=device)
        samples = generator(z)
        samples = samples * 0.5 + 0.5  # Denormalize
        save_image(samples, f"generated_images/epoch_{epoch}.png", nrow=8)
    generator.train()


[Epoch 1/50] [Batch 0/469] D Loss: 1.3528 | G Loss: 0.4326


KeyboardInterrupt: 