In [None]:
import torch
import torchvision
from torch import nn
from torch import optim
from torchvision.datasets import FashionMNIST
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

# Seed PyTorch's global random number generator (reproducibility aid).
torch.manual_seed(42)

# --- Optimisation / data-loading hyperparameters ---
LEARNING_RATE = 0.0005
BATCH_SIZE = 1024
IMAGE_SIZE = 64
EPOCHS = 50

# --- Model dimensions ---
image_channels = 1      # channels of the images fed to / produced by the GAN
noise_channels = 256    # channels of the latent noise map given to the generator
gen_features = 64       # base channel width of the generator
disc_features = 64      # base channel width of the discriminator

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Preprocessing pipeline: resize to IMAGE_SIZE, convert to a tensor, then
# normalise with mean 0.5 / std 0.5 (maps ToTensor's [0, 1] range onto
# [-1, 1], the range of the generator's Tanh output).
data_transforms = transforms.Compose(
    [
        transforms.Resize(size=IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# FashionMNIST training split, downloaded into "dataset/" when missing,
# served in shuffled mini-batches by two worker processes.
dataset = FashionMNIST(
    root="dataset/",
    train=True,
    transform=data_transforms,
    download=True,
)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)

# Generator
class Generator(nn.Module):
    """DCGAN-style generator: turns a latent noise map into an image.

    The network is a stack of five ``ConvTranspose2d`` layers:

    * a projection layer (kernel 4, stride 1, no padding) followed by ReLU
      only -- this first block has no BatchNorm;
    * three upsampling blocks (kernel 4, stride 2, padding 1), each followed
      by ``BatchNorm2d`` and ReLU, that double the spatial size;
    * a final upsampling layer followed by Tanh, so outputs lie in [-1, 1].

    For a ``(N, noise_channels, 1, 1)`` input the output has shape
    ``(N, image_channels, 64, 64)``.

    Args:
        noise_channels: Number of channels of the latent noise input.
        image_channels: Number of channels of the generated image.
        features: Base channel width; the hidden layers use 16x, 8x, 4x and
            2x this value.
    """

    def __init__(self, noise_channels: int, image_channels: int, features: int):
        super().__init__()
        # Spatial sizes in the comments below assume a 1x1 noise map.
        self.model = nn.Sequential(
            # Transpose block 1 (projection, no BatchNorm): 1x1 -> 4x4
            nn.ConvTranspose2d(noise_channels, features*16, kernel_size=4, stride=1, padding=0),
            nn.ReLU(),
            # Transpose block 2: 4x4 -> 8x8
            nn.ConvTranspose2d(features*16, features*8, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(features*8),
            nn.ReLU(),
            # Transpose block 3: 8x8 -> 16x16
            nn.ConvTranspose2d(features*8, features*4, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(features*4),
            nn.ReLU(),
            # Transpose block 4: 16x16 -> 32x32
            nn.ConvTranspose2d(features*4, features*2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(features*2),
            nn.ReLU(),
            # Last transpose block: 32x32 -> 64x64, squashed to [-1, 1] by Tanh
            nn.ConvTranspose2d(features*2, image_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the noise tensor ``x`` through the layer stack and return the images."""
        return self.model(x)

# Discriminator
class Discriminator(nn.Module):
    """DCGAN-style discriminator: scores images as real (near 1) or fake (near 0).

    The network is a stack of five ``Conv2d`` layers:

    * block 1 (kernel 4, stride 2, padding 1) is followed by LeakyReLU(0.2)
      only -- it has no BatchNorm;
    * blocks 2-4 (kernel 4, stride 2, padding 1) are each followed by
      ``BatchNorm2d`` and LeakyReLU(0.2) and halve the spatial size;
    * block 5 (kernel 4, stride 1, no padding) reduces the 4x4 map to a
      single value, which Sigmoid maps into [0, 1].

    Args:
        image_channels: Number of channels of the input images.
        features: Base channel width; the hidden layers use 1x, 2x, 4x and
            8x this value.
    """

    def __init__(self, image_channels: int, features: int):
        super().__init__()
        # Spatial sizes in the comments below assume 64x64 input images.
        self.model = nn.Sequential(
            # Conv block 1 (no BatchNorm)
            nn.Conv2d(image_channels, features, kernel_size=4, stride=2, padding=1),  # 64x64 -> 32x32
            nn.LeakyReLU(0.2),
            # Conv block 2
            nn.Conv2d(features, features*2, kernel_size=4, stride=2, padding=1),  # 32x32 -> 16x16
            nn.BatchNorm2d(features*2),
            nn.LeakyReLU(0.2),
            # Conv block 3
            nn.Conv2d(features*2, features*4, kernel_size=4, stride=2, padding=1),  # 16x16 -> 8x8
            nn.BatchNorm2d(features*4),
            nn.LeakyReLU(0.2),
            # Conv block 4
            nn.Conv2d(features*4, features*8, kernel_size=4, stride=2, padding=1),  # 8x8 -> 4x4
            nn.BatchNorm2d(features*8),
            nn.LeakyReLU(0.2),
            # Conv block 5 (output): one probability per image
            nn.Conv2d(features*8, 1, kernel_size=4, stride=1, padding=0),  # 4x4 -> 1x1
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Score a batch of images.

        Returns a 1-D tensor; for 64x64 inputs it holds one probability per
        image, i.e. it has shape ``[batch_size]``.
        """
        output = self.model(x)
        return output.view(-1)  # Flatten to [batch_size]

# Instantiate both networks and move them to the training device.
# (Generator first, then discriminator: weight init consumes the seeded RNG.)
gen_model = Generator(
    noise_channels=noise_channels,
    image_channels=image_channels,
    features=gen_features,
).to(device)
disc_model = Discriminator(
    image_channels=image_channels,
    features=disc_features,
).to(device)

# One Adam optimizer per network, both with beta1 = 0.5 and beta2 = 0.999.
gen_optimizer = optim.Adam(
    gen_model.parameters(),
    lr=LEARNING_RATE,
    betas=(0.5, 0.999),
)
disc_optimizer = optim.Adam(
    disc_model.parameters(),
    lr=LEARNING_RATE,
    betas=(0.5, 0.999),
)

# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()

# Put both networks in training mode.
gen_model.train()
disc_model.train()

# Smoothed targets used for the discriminator (instead of hard 0 / 1).
fake_label = 0.1
real_label = 0.9

# A fixed batch of 64 noise vectors, sampled on the CPU and then moved to
# `device`, reused to render comparable sample images over time.
fixed_noise = torch.randn(64, noise_channels, 1, 1).to(device)

# Separate TensorBoard writers for real and generated image grids.
writer_real = SummaryWriter("runs/fashion/test_real")
writer_fake = SummaryWriter("runs/fashion/test_fake")

# Per-iteration loss history (plain Python floats).
gen_losses, disc_losses = [], []

# Training loop: alternate one discriminator update and one generator update
# per mini-batch, log to TensorBoard every 50 batches, and save a grid of
# generated samples after epoch 1 and after every 10th epoch.
print("Start training...")
step = 0  # TensorBoard global step; advanced once per logging event

for epoch in range(EPOCHS):
    # The dataset's class labels are not used by the GAN, hence `_`.
    for batch_idx, (data, _) in enumerate(tqdm(dataloader, desc=f"Epoch {epoch}")):
        data = data.to(device)
        batch_size = data.shape[0]

        # Train discriminator on real data
        disc_model.zero_grad()
        # Targets are created directly on the training device instead of
        # being built on the CPU and copied over every iteration.
        label = torch.full((batch_size,), float(real_label), device=device)
        output = disc_model(data).reshape(-1)
        real_disc_loss = criterion(output, label)

        # Train discriminator on fake data
        noise = torch.randn(batch_size, noise_channels, 1, 1, device=device)
        fake = gen_model(noise)
        label = torch.full((batch_size,), float(fake_label), device=device)
        # detach() keeps this backward pass from reaching the generator.
        output = disc_model(fake.detach()).reshape(-1)
        fake_disc_loss = criterion(output, label)

        # Calculate discriminator loss
        disc_loss = real_disc_loss + fake_disc_loss
        disc_loss.backward()
        disc_optimizer.step()

        # Train generator: it wants the (just updated) discriminator to
        # label its samples as real, so the target is an un-smoothed 1.
        gen_model.zero_grad()
        label = torch.ones(batch_size, device=device)
        output = disc_model(fake).reshape(-1)
        gen_loss = criterion(output, label)
        gen_loss.backward()
        gen_optimizer.step()

        # Store losses as Python floats so no autograd graph is retained
        gen_losses.append(gen_loss.item())
        disc_losses.append(disc_loss.item())

        # Log losses and images
        if batch_idx % 50 == 0:
            step += 1
            print(
                f"Epoch: {epoch} ===== Batch: {batch_idx}/{len(dataloader)} ===== "
                f"Disc loss: {disc_loss:.4f} ===== Gen loss: {gen_loss:.4f}"
            )
            with torch.no_grad():
                fake_images = gen_model(fixed_noise)
                img_grid_real = torchvision.utils.make_grid(data[:40], normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake_images[:40], normalize=True)
                writer_real.add_image("Real images", img_grid_real, global_step=step)
                writer_fake.add_image("Generated images", img_grid_fake, global_step=step)

    # Save generated images every 10 epochs or at epoch 1
    if (epoch + 1) % 10 == 0 or epoch == 0:
        with torch.no_grad():
            fake_images = gen_model(fixed_noise).detach().cpu()
        plt.figure(figsize=(10, 10))
        # fixed_noise holds 64 samples -> 8x8 grid
        for j in range(64):
            plt.subplot(8, 8, j+1)
            plt.imshow(fake_images[j].squeeze(), cmap='gray')
            plt.axis('off')
        plt.savefig(f'generated_images_epoch_{epoch+1}.png')
        plt.close()

# Flush and close the TensorBoard event files so the last logged steps
# are written to disk.
writer_real.close()
writer_fake.close()

# Draw both per-iteration loss curves on one axes and write the figure to
# 'loss_plot.png'.
loss_fig, loss_ax = plt.subplots(figsize=(10, 5))
loss_ax.plot(gen_losses, label='Generator Loss')
loss_ax.plot(disc_losses, label='Discriminator Loss')
loss_ax.set_title('Generator and Discriminator Losses During Training')
loss_ax.set_xlabel('Iterations')
loss_ax.set_ylabel('Loss')
loss_ax.legend()
loss_ax.grid(True)
loss_fig.savefig('loss_plot.png')
# Release the figure once it has been saved.
plt.close(loss_fig)