# Enhanced GAN Training on Jetstream2
## CS650 - Neural Networks & Deep Learning
### Student: Matthew Walker (ID: 901635710)
---
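This notebook implements an enhanced Generative Adversarial Network (GAN) with fully connected generator and discriminator networks. It trains the GAN on MNIST using a Jetstream2 GPU and saves a grid of generated samples every 5 epochs plus a final model checkpoint.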

In [None]:
# Identification for this submission; also embedded in the saved checkpoint.
STUDENT_NAME = "Matthew Walker"
STUDENT_ID = "901635710"
COURSE = "CS650"
TASK = "Jetstream2 GPU Training Task"

# Emit the banner as one joined string: a rule, one "Label: value" row per field, a rule.
print("\n".join([
    "=" * 60,
    *(f"{label}: {value}" for label, value in (
        ("Student", STUDENT_NAME),
        ("ID", STUDENT_ID),
        ("Course", COURSE),
        ("Task", TASK),
    )),
    "=" * 60,
]))

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from matplotlib import gridspec
from torchvision.utils import save_image, make_grid
import numpy as np
from IPython.display import clear_output
import time
from datetime import datetime

# Dark-grid theme for all later figures. Matplotlib 3.6 renamed its bundled
# seaborn styles to 'seaborn-v0_8-*'; try the new name first, then the old one,
# and keep the default style if neither exists (styling is cosmetic, not worth a crash).
for _style_name in ('seaborn-v0_8-darkgrid', 'seaborn-darkgrid'):
    if _style_name in plt.style.available:
        plt.style.use(_style_name)
        break
print("✓ All libraries imported successfully")

In [None]:
# Seed torch's global RNG before any data loading or model construction so that
# weight init, DataLoader shuffling and sampled noise repeat across runs.
# torch.manual_seed seeds every device, CUDA included. Some GPU kernels may
# still be nondeterministic.
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("\n" + "="*60)
print("GPU CONFIGURATION")
print("="*60)
print(f"Device: {device}")
# Version info is useful on CPU-only runs too, so it is printed unconditionally.
print(f"PyTorch Version: {torch.__version__}")

if torch.cuda.is_available():
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print(f"CUDA Version: {torch.version.cuda}")
    print("✓ GPU is available and will be used for training")
else:
    print("⚠ WARNING: GPU not available")
print("="*60)

In [None]:
# ToTensor scales pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then maps them
# to [-1, 1], the same range as the generator's final Tanh.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
)

# MNIST training split, downloaded into ./data on first run.
dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
dataloader = DataLoader(
    dataset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)
print(f"✓ Dataset loaded: {len(dataset)} images")

In [None]:
class Generator(nn.Module):
    """Fully connected generator mapping latent noise to a flattened image.

    Three Linear -> BatchNorm1d -> ReLU stages widen the latent code through
    256, 512 and 1024 units. A final Linear + Tanh then projects to
    ``img_dim`` values in [-1, 1].

    Args:
        z_dim: length of the input noise vector.
        img_dim: number of output values per sample (28*28 by default).
    """

    def __init__(self, z_dim=100, img_dim=28*28):
        super().__init__()
        layers = []
        fan_in = z_dim
        for width in (256, 512, 1024):
            layers += [nn.Linear(fan_in, width), nn.BatchNorm1d(width), nn.ReLU(True)]
            fan_in = width
        layers += [nn.Linear(fan_in, img_dim), nn.Tanh()]
        # The attribute name and layer order determine state_dict keys
        # (model.0.weight, ...); keep them stable for saved checkpoints.
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Map a (batch, z_dim) noise tensor to (batch, img_dim) outputs."""
        return self.model(z)

class Discriminator(nn.Module):
    """Fully connected discriminator producing one probability per flattened image.

    The image is narrowed through 1024, 512 and 256 units. Each stage is
    Linear -> LeakyReLU(0.2) -> Dropout(0.3). A final Linear + Sigmoid emits
    a score in (0, 1); training targets 1 for real and 0 for generated images.

    Args:
        img_dim: number of input values per sample (28*28 by default).
    """

    def __init__(self, img_dim=28*28):
        super().__init__()
        blocks = []
        prev = img_dim
        for width in (1024, 512, 256):
            blocks.extend((nn.Linear(prev, width), nn.LeakyReLU(0.2), nn.Dropout(0.3)))
            prev = width
        blocks.extend((nn.Linear(prev, 1), nn.Sigmoid()))
        # Same attribute name and layer order as before, so state_dict keys match.
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        """Score a (batch, img_dim) tensor, returning (batch, 1) probabilities."""
        return self.model(x)

# Build both networks with default sizes (generator first, then discriminator)
# and move their parameters to the selected device.
generator, discriminator = Generator().to(device), Discriminator().to(device)
print("✓ Models initialized")

In [None]:
# The discriminator ends in Sigmoid, so plain BCE on probabilities serves both players.
criterion = nn.BCELoss()
# Adam with lr=2e-4 and beta1=0.5 (DCGAN-style settings) for both networks.
optim_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optim_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

z_dim = 100  # latent size; must match the generator's first Linear layer
epochs = 30
# The generator is constructed in an earlier cell. Fail fast here if z_dim is
# edited without rebuilding it, rather than erroring inside the training loop.
assert generator.model[0].in_features == z_dim, (
    f"z_dim={z_dim} does not match generator input size {generator.model[0].in_features}"
)
os.makedirs("samples", exist_ok=True)
os.makedirs("checkpoints", exist_ok=True)
# The same 64 latent vectors are reused for every sample grid, so grids are
# comparable across epochs. They are allocated directly on the target device.
fixed_noise = torch.randn(64, z_dim, device=device)
print(f"✓ Training config: {epochs} epochs")

In [None]:
print(f"STARTING TRAINING - {STUDENT_NAME} (ID: {STUDENT_ID})")
print(f"Device: {device}")

# Per-batch loss history (one entry per optimizer step), e.g. for plotting.
g_losses = []
d_losses = []
start_time = time.time()

for epoch in range(epochs):
    # Enter training mode explicitly, so a re-run of this cell (or an eval()
    # elsewhere) cannot leave BatchNorm/Dropout in inference mode.
    generator.train()
    discriminator.train()
    first_step = len(g_losses)  # index where this epoch's losses begin

    for real_imgs, _ in dataloader:
        # Flatten each image batch to (batch, 784) for the fully connected networks.
        real_imgs = real_imgs.view(-1, 28*28).to(device)
        batch_size = real_imgs.size(0)  # the final batch can be smaller than 128
        # Build targets on the device directly instead of allocating on CPU and copying.
        real_labels = torch.ones(batch_size, 1, device=device)
        fake_labels = torch.zeros(batch_size, 1, device=device)

        # Discriminator step: score real images as 1 and generated ones as 0.
        # detach() keeps this loss from back-propagating into the generator.
        z = torch.randn(batch_size, z_dim, device=device)
        fake_imgs = generator(z).detach()
        d_loss_real = criterion(discriminator(real_imgs), real_labels)
        d_loss_fake = criterion(discriminator(fake_imgs), fake_labels)
        d_loss = d_loss_real + d_loss_fake
        optim_D.zero_grad()
        d_loss.backward()
        optim_D.step()

        # Generator step: use fresh noise and push the discriminator's output
        # on fakes toward 1 (the non-saturating GAN loss).
        z = torch.randn(batch_size, z_dim, device=device)
        fake_imgs = generator(z)
        g_loss = criterion(discriminator(fake_imgs), real_labels)
        optim_G.zero_grad()
        g_loss.backward()
        optim_G.step()

        g_losses.append(g_loss.item())
        d_losses.append(d_loss.item())

    n_steps = len(g_losses) - first_step
    if n_steps == 0:
        raise RuntimeError("dataloader produced no batches; nothing to train on")
    # Report epoch means: a single last-batch loss is too noisy to read trends from.
    g_epoch = sum(g_losses[first_step:]) / n_steps
    d_epoch = sum(d_losses[first_step:]) / n_steps
    print(f"Epoch [{epoch+1}/{epochs}] | G Loss (avg): {g_epoch:.4f} | D Loss (avg): {d_epoch:.4f}")

    if (epoch + 1) % 5 == 0:
        # Sample in eval mode. In train mode, BatchNorm would normalize with the
        # fixed_noise batch statistics and fold them into its running averages
        # (no_grad does not prevent that).
        generator.eval()
        try:
            with torch.no_grad():
                fake = generator(fixed_noise).view(-1, 1, 28, 28)
            save_image(fake, f"samples/fake_epoch_{epoch+1:03d}.png", normalize=True)
        finally:
            generator.train()

print(f"\nTRAINING COMPLETE! Time: {(time.time()-start_time)/60:.2f} min")
if g_losses:  # empty when epochs == 0
    print(f"Final G Loss: {g_losses[-1]:.4f} | D Loss: {d_losses[-1]:.4f}")

In [None]:
# Save everything needed to reload the generator for sampling or to resume
# training. Optimizer states, epoch count and latent size are included because
# model weights alone cannot restore Adam's moment estimates.
torch.save({
    'generator': generator.state_dict(),
    'discriminator': discriminator.state_dict(),
    'optim_G': optim_G.state_dict(),
    'optim_D': optim_D.state_dict(),
    'epochs': epochs,
    'z_dim': z_dim,
    'g_losses': g_losses,
    'd_losses': d_losses,
    'student_name': STUDENT_NAME,
    'student_id': STUDENT_ID
}, 'checkpoints/gan_checkpoint.pth')
print("✓ Model saved")