In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os

In [None]:
# --- Step 1: Hyperparameters and Setup ---
# Every tunable knob for the experiment lives in this cell.
BATCH_SIZE = 64
LATENT_DIM = 100        # length of the random noise vector fed to the Generator
IMAGE_SIZE = 28 * 28    # flattened pixel count of one 28x28 single-channel image
NUM_EPOCHS = 50
LEARNING_RATE = 0.0002
SEED = 42               # fixed seed so a Restart & Run All repeats the same run

# Seed PyTorch's RNGs (weight init, noise sampling, DataLoader shuffling).
# Note: bitwise-identical results on a GPU are still not guaranteed, because some
# CUDA kernels are nondeterministic.
torch.manual_seed(SEED)

In [None]:

# Pick the compute device once: a CUDA GPU when PyTorch can see one, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Create the directory that the training loop writes sample grids into.
# exist_ok=True keeps the cell idempotent (safe to re-run) and avoids the
# check-then-create race of a separate os.path.exists() guard.
os.makedirs('generated_images', exist_ok=True)

In [None]:
# --- Step 2: Define the Generator and Discriminator Networks ---

class Generator(nn.Module):
    """
    Map a batch of random noise vectors to a batch of fake images.

    Architecture: a three-layer MLP (latent_dim -> 256 -> 512 -> C*H*W) with
    LeakyReLU(0.2) hidden activations and a Tanh output, so every generated
    pixel lies in [-1, 1].

    Args:
        latent_dim: Length of each input noise vector. When omitted, the
            notebook-level ``LATENT_DIM`` is used.
        img_shape: ``(channels, height, width)`` of the generated images.
            Defaults to ``(1, 28, 28)``. The output layer width is derived from
            it, so the Linear layer and the reshape in ``forward`` cannot
            disagree.
    """
    def __init__(self, latent_dim=None, img_shape=(1, 28, 28)):
        super().__init__()
        if latent_dim is None:
            latent_dim = LATENT_DIM
        self.img_shape = tuple(img_shape)
        # torch.Size(...).numel() is the product of the dims, i.e. C*H*W.
        n_pixels = torch.Size(self.img_shape).numel()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, n_pixels),
            nn.Tanh(),  # squash outputs into [-1, 1]
        )

    def forward(self, z):
        """
        Generate images from noise.

        Args:
            z: Noise tensor of shape ``(batch, latent_dim)``.

        Returns:
            Tensor of shape ``(batch, *img_shape)`` with values in [-1, 1].
        """
        flat = self.model(z)
        # The Linear output is freshly allocated and contiguous, so view is safe here.
        return flat.view(flat.size(0), *self.img_shape)

In [None]:
class Discriminator(nn.Module):
    """
    Score a batch of images with the probability that each one is real.

    Architecture: a three-layer MLP (C*H*W -> 512 -> 256 -> 1) with
    LeakyReLU(0.2) hidden activations and a Sigmoid output, so every score lies
    in [0, 1], which is the input range ``nn.BCELoss`` requires.

    Args:
        img_shape: ``(channels, height, width)`` of the input images. When
            omitted, the input width is the notebook-level ``IMAGE_SIZE``
            (flattened pixel count).
    """
    def __init__(self, img_shape=None):
        super().__init__()
        in_features = IMAGE_SIZE if img_shape is None else torch.Size(img_shape).numel()
        self.model = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # probability that the image is real
        )

    def forward(self, img):
        """
        Score images.

        Args:
            img: Image batch whose first dim is the batch; all remaining dims
                are flattened and must multiply to the configured input width.

        Returns:
            Tensor of shape ``(batch, 1)`` with values in [0, 1].
        """
        # Unlike .view, torch.flatten also accepts non-contiguous inputs
        # (e.g. after transpose/permute), copying only when it has to.
        img_flat = torch.flatten(img, start_dim=1)
        return self.model(img_flat)

In [None]:
# --- Step 3: Load and Prepare Data ---
# ToTensor() yields pixel values in [0, 1]; Normalize with mean 0.5 and std 0.5
# then computes (x - 0.5) / 0.5, mapping them to [-1, 1] so real images live in
# the same range as the Generator's Tanh output.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

In [None]:
# Fashion-MNIST training split stored under ./data. With download=True,
# torchvision fetches the files if they are not already present. Every sample
# passes through the [-1, 1] normalization pipeline from the previous cell.
train_dataset = datasets.FashionMNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

In [None]:
# Serve the training set in mini-batches of BATCH_SIZE, reshuffled every epoch.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# --- Step 4: Instantiate Models, Loss, and Optimizers ---
# Build the Generator first and then the Discriminator (tuple elements are
# evaluated left to right), and move each network's parameters to `device`.
generator, discriminator = Generator().to(device), Discriminator().to(device)

In [None]:
# The Discriminator ends in a Sigmoid, so plain binary cross-entropy on its
# probabilities is the matching adversarial objective, averaged over the batch.
adversarial_loss = nn.BCELoss(reduction="mean")

In [None]:
# Adam for both networks. The DCGAN training recipe pairs lr = 2e-4 with
# beta1 = 0.5 instead of Adam's default 0.9; the higher default momentum is
# reported to make adversarial training oscillate.
ADAM_BETAS = (0.5, 0.999)
optimizer_G = optim.Adam(generator.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)
optimizer_D = optim.Adam(discriminator.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)


In [None]:
# --- Step 5: Training Loop ---
LOG_EVERY = 100    # print losses every N batches
SAMPLE_EVERY = 10  # save a grid of generated samples every N epochs
NUM_SAMPLES = 16   # images per saved grid

# One fixed batch of noise reused for every saved grid, so grids from
# different epochs show how the *same* latent points evolve during training.
fixed_z = torch.randn(NUM_SAMPLES, LATENT_DIM, device=device)

print("Starting GAN training...")

for epoch in range(NUM_EPOCHS):
    for i, (imgs, _) in enumerate(train_loader):
        # The last batch of an epoch can be smaller than BATCH_SIZE, so size
        # everything from the actual batch.
        batch_size = imgs.size(0)
        real_imgs = imgs.to(device)

        # BCE targets: 1 = real, 0 = fake.
        real_labels = torch.ones(batch_size, 1, device=device)
        fake_labels = torch.zeros(batch_size, 1, device=device)

        # --- Train the Discriminator ---
        optimizer_D.zero_grad()

        # 1. Real images should be scored as real.
        d_loss_real = adversarial_loss(discriminator(real_imgs), real_labels)

        # 2. Fake images should be scored as fake. detach() keeps this loss
        #    from back-propagating into the Generator.
        z = torch.randn(batch_size, LATENT_DIM, device=device)
        fake_imgs = generator(z).detach()
        d_loss_fake = adversarial_loss(discriminator(fake_imgs), fake_labels)

        d_loss = (d_loss_real + d_loss_fake) / 2
        d_loss.backward()
        optimizer_D.step()

        # --- Train the Generator ---
        optimizer_G.zero_grad()

        # The Generator's loss is low when the Discriminator labels its output
        # as real. This backward pass also writes gradients into the
        # Discriminator's parameters. They are cleared by optimizer_D.zero_grad()
        # before the next Discriminator step and are never applied.
        z = torch.randn(batch_size, LATENT_DIM, device=device)
        gen_imgs = generator(z)
        g_loss = adversarial_loss(discriminator(gen_imgs), real_labels)
        g_loss.backward()
        optimizer_G.step()

        # Print progress
        if i % LOG_EVERY == 0:
            print(f"[Epoch {epoch+1}/{NUM_EPOCHS}] [Batch {i}/{len(train_loader)}] "
                  f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")

    # Save a grid of samples every SAMPLE_EVERY epochs.
    if (epoch + 1) % SAMPLE_EVERY == 0:
        # eval()/train() is a no-op for this all-Linear/LeakyReLU model, but keeps
        # sampling correct if dropout or batch-norm layers are added later.
        # no_grad() avoids building an autograd graph just to save images.
        generator.eval()
        with torch.no_grad():
            samples = generator(fixed_z)
        generator.train()
        # Map the Tanh range [-1, 1] to [0, 1] with a fixed affine transform,
        # instead of per-grid min/max normalization, so intensities are
        # comparable across epochs.
        save_image((samples + 1) / 2, f'generated_images/epoch_{epoch+1}.png')

print("\nGAN training finished.")
print("Generated images are saved in the 'generated_images' directory.")

            