In [None]:
#!/usr/bin/env python3
# Imports and setup
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
import torchvision
from torch.utils.data import DataLoader
import wandb
import matplotlib.pyplot as plt
import numpy as np

# Start the W&B run and record the static experiment description
wandb.init(
    project="mnist_dcgan",
    config=dict(dataset="MNIST", framework="PyTorch", model="DCGAN"),
)

# Seed PyTorch's global RNG so shuffling and noise sampling are repeatable
torch.manual_seed(42)

# Preprocessing: convert to tensor, then shift/scale with mean 0.5 and std 0.5
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

# MNIST training split, downloaded into ./data when not already present
dataset = datasets.MNIST('./data', train=True, transform=transform, download=True)

# Shuffled mini-batches of `batch_size` images
batch_size = 128
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=batch_size)

# Pull a single batch and confirm the tensor shapes
images, labels = real_batch = next(iter(dataloader))
print(images.shape)
print(labels.shape)

# Reverse the (0.5, 0.5) normalization so pixel values are displayable
images = 0.5 * images + 0.5

# Show the first 16 images of the batch on a 4x4 grid
plt.figure(figsize=(8, 8))
for cell in range(1, 17):
    plt.subplot(4, 4, cell)
    plt.imshow(images[cell - 1].squeeze(), cmap='gray')
    plt.axis('off')
plt.show()

In [None]:
# Training configuration read by the model and training cells
epochs = 50          # passes over the training set
batch_size = 128     # images per mini-batch
lr = 2e-4            # Adam learning rate
Beta1 = 0.5          # first Adam beta coefficient
latent_dim = 100     # length of the generator's input noise vector
img_ch = 1           # image channels
img_size = 28        # image side length in pixels

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Generator Network
class Generator(nn.Module):
    """Turn latent vectors into images via a linear projection and two 2x upsampling stages.

    The latent vector is projected to a 128-channel 7x7 feature map, upsampled
    to 14x14 and then 28x28, and finally reduced to ``img_ch`` channels with a
    Tanh activation, so outputs lie in [-1, 1].

    Args:
        latent_dim: Length of each input noise vector.
        img_ch: Number of channels in the generated image.
    """

    def __init__(self, latent_dim, img_ch):
        super().__init__()

        # Side length of the feature map produced from the latent vector
        self.init_size = 7
        feat_ch = 128

        # Projects (batch, latent_dim) -> (batch, feat_ch * 7 * 7)
        self.fc = nn.Linear(latent_dim, feat_ch * self.init_size * self.init_size)

        stages = [nn.BatchNorm2d(feat_ch)]
        # Two upsampling stages: 7x7 -> 14x14 (128 ch), then 14x14 -> 28x28 (64 ch)
        for in_ch, out_ch in ((feat_ch, feat_ch), (feat_ch, feat_ch // 2)):
            stages += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(inplace=True),
            ]
        # Final projection to image channels, squashed into [-1, 1]
        stages += [
            nn.Conv2d(feat_ch // 2, img_ch, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        ]
        self.conv_blocks = nn.Sequential(*stages)

    def forward(self, z):
        """Generate a batch of images from latent vectors ``z`` of shape (batch, latent_dim)."""
        side = self.init_size
        features = self.fc(z)
        # Reshape the flat projection into a 128-channel square feature map
        features = features.view(features.size(0), 128, side, side)
        return self.conv_blocks(features)


In [None]:
# Discriminator Network
class Discriminator(nn.Module):
    """Score images with a probability-like value in [0, 1] via three strided convolutions.

    Each 4x4 stride-2 convolution halves the spatial size (28 -> 14 -> 7 -> 3);
    the resulting 256x3x3 features are flattened and mapped to a single
    sigmoid output per image.

    Args:
        img_ch: Number of channels in the input images.
    """

    def __init__(self, img_ch):
        super().__init__()

        layers = []
        in_ch = img_ch
        # (output channels, whether batch norm follows the convolution)
        for out_ch, with_norm in ((64, False), (128, True), (256, True)):
            layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1))
            if with_norm:
                layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            in_ch = out_ch
        self.model = nn.Sequential(*layers)

        # Flatten the 256x3x3 features and reduce them to one sigmoid score
        self.adv_layer = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * 3 * 3, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        """Return a (batch, 1) tensor of scores for the images in ``img``."""
        return self.adv_layer(self.model(img))

In [None]:
# Training function
# Weight initialization (important for GANs)
def weights_init(m):
  """Re-initialize a module in place based on its class name.

  Modules whose class name contains 'Conv' get weights drawn from N(0, 0.02);
  modules whose class name contains 'BatchNorm' get weights drawn from
  N(1, 0.02) and zero bias. Any other module is left untouched.
  Intended to be passed to ``nn.Module.apply``.
  """
  layer_kind = type(m).__name__
  if 'Conv' in layer_kind:
    nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
    return
  if 'BatchNorm' in layer_kind:
    nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
    nn.init.constant_(m.bias.data, 0)

def train_dcgan(dataloader, num_epochs, batch_size):
  """Train a DCGAN generator/discriminator pair and return both models.

  For every batch the discriminator is updated once (real batch + detached
  fakes) and then the generator is updated once (fresh fakes scored against
  "real" labels). Metrics from the last batch of each epoch are logged to
  W&B, sample grids are logged and saved every 5 epochs, and both state
  dicts are saved at the end.

  Reads the module-level settings ``latent_dim``, ``img_ch``, ``lr``,
  ``Beta1`` and ``device``, and logs to the active W&B run.

  Args:
    dataloader: Iterable of ``(images, labels)`` batches that also supports
      ``len()``; the labels are ignored.
    num_epochs: Number of passes over ``dataloader``.
    batch_size: Recorded in the W&B config only; the size actually used is
      read from each batch, so a smaller final batch is handled.

  Returns:
    Tuple ``(generator, discriminator)`` of the trained modules.
  """
  # Initialize models
  generator = Generator(latent_dim, img_ch).to(device)
  discriminator = Discriminator(img_ch).to(device)

  # Apply weight initialization
  generator.apply(weights_init)
  discriminator.apply(weights_init)

  # Loss function: the discriminator ends in a Sigmoid, so plain BCE applies
  bce_loss = nn.BCELoss()

  # Optimizers
  optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(Beta1, 0.999))
  optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(Beta1, 0.999))

  # Record the run configuration in W&B
  wandb.config.update({
      "latent_dim": latent_dim,
      "batch_size": batch_size,
      "learning_rate": lr,
      "num_epochs": num_epochs,
      "architecture": "Baseline DCGAN"
      })

  # Fixed noise so the sample grids are comparable across epochs
  fixed_noise = torch.randn(64, latent_dim).to(device)
  print(f"Generator parameters: {sum(p.numel() for p in generator.parameters()):,}")
  print(f"Discriminator parameters: {sum(p.numel() for p in discriminator.parameters()):,}")

  for epoch in range(num_epochs):
    for i, (real_imgs, _) in enumerate(dataloader):
      # The last batch may be smaller than the nominal batch size
      curr_batch_size = real_imgs.size(0)
      real_imgs = real_imgs.to(device)

      # Targets for real and fake images, sized to the current batch
      real_labels = torch.ones(curr_batch_size, 1).to(device)
      fake_labels = torch.zeros(curr_batch_size, 1).to(device)

      # ---------------------
      #  Train Discriminator
      # ---------------------
      optimizer_D.zero_grad()

      # Loss on real images
      real_out = discriminator(real_imgs)
      d_real_loss = bce_loss(real_out, real_labels)

      # Generate fake images
      z = torch.randn(curr_batch_size, latent_dim).to(device)
      fake_imgs = generator(z)

      # Loss on fake images; detach so this step does not backprop into the generator
      fake_out = discriminator(fake_imgs.detach())
      d_fake_loss = bce_loss(fake_out, fake_labels)

      # Total discriminator loss
      d_loss = d_real_loss + d_fake_loss
      d_loss.backward()
      optimizer_D.step()

      # -----------------
      #  Train Generator
      # -----------------
      # This step must run once per batch (inside the batch loop) so the
      # generator is updated at the same rate as the discriminator.
      optimizer_G.zero_grad()

      # Generate a fresh batch of fake images
      z = torch.randn(curr_batch_size, latent_dim).to(device)
      fake_imgs = generator(z)

      # Generator tries to fool the discriminator: fakes are scored against real labels
      fake_out = discriminator(fake_imgs)
      g_loss = bce_loss(fake_out, real_labels)
      g_loss.backward()
      optimizer_G.step()

      # ---------------------
      #  Log progress
      # ---------------------
      if i % 100 == 0:
        print(f"[Epoch {epoch}/{num_epochs}],"
        f"[Batch {i}/{len(dataloader)}],"
        f"D_loss: {d_loss.item():.4f},"
        f"G_loss: {g_loss.item():.4f}")

    # Log to WandB at end of epoch, using the values from the last batch.
    # d_fake_acc is computed from the generator-step discriminator outputs.
    wandb.log({
        "epoch": epoch,
        "d_loss": d_loss.item(),
        "g_loss": g_loss.item(),
        "d_real_acc": (real_out >= 0.5).float().mean().item(),
        "d_fake_acc": (fake_out < 0.5).float().mean().item()
        })

    # Generate and log images every 5 epochs
    if epoch % 5 == 0:
      with torch.no_grad():
        sample_imgs = generator(fixed_noise).cpu()

        # Denormalize images from [-1, 1] to [0, 1]
        sample_imgs = (sample_imgs + 1) / 2
        # Create grid of images
        grid = torchvision.utils.make_grid(sample_imgs,
                                           nrow=8,
                                           padding=2,
                                           normalize=False)
        # Log to WandB and keep a local copy
        wandb.log({"generated_images": wandb.Image(grid)})
        save_image(sample_imgs[:64], f'generated_epoch_{epoch}.png', nrow=8, padding=2, normalize=False)

  # Save final models
  torch.save(generator.state_dict(), "generator.pth")
  torch.save(discriminator.state_dict(), "discriminator.pth")
  wandb.save("generator.pth")
  wandb.save("discriminator.pth")

  return generator, discriminator

In [None]:
# Main execution: train the DCGAN, then close the W&B run
generator, discriminator = train_dcgan(
    dataloader=dataloader,
    num_epochs=epochs,
    batch_size=batch_size,
)
wandb.finish()

Generator parameters: 856,065
Discriminator parameters: 659,905


In [None]:
# Report which torch device was selected for training
print(f"Using device: {device}")