In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# === Changes Applied ===
# 1. Switched to BCEWithLogitsLoss and removed Sigmoid from Discriminator
# 2. Moved fixed_noise outside the training loop for consistent epoch-to-epoch comparison


# 1. Environment setup and hyperparameters
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

batch_size = 128   # images per optimization step
lr = 2e-4          # learning rate handed to both Adam optimizers
num_epochs = 30
z_dim = 100        # length of the latent noise vector fed to the generator

# 2. FashionMNIST loading and preprocessing
# ToTensor yields pixels in [0, 1]; Normalize with mean 0.5 / std 0.5 rescales
# them to [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
dataset = dsets.FashionMNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 3. DCGAN generator built from transposed convolutions
class Generator(nn.Module):
    """Map a latent noise tensor to an image via three transposed convolutions.

    Args:
        z_dim: number of channels of the input noise tensor.
        img_channels: number of channels of the produced image.
        feature_map_g: base width; the hidden layers use 2x and 1x this many maps.

    With an input of shape (batch, z_dim, 1, 1) the output has shape
    (batch, img_channels, 28, 28) and values in [-1, 1] (Tanh).
    """

    def __init__(self, z_dim=100, img_channels=1, feature_map_g=64):
        super().__init__()
        wide = feature_map_g * 2
        narrow = feature_map_g
        layers = [
            # (z_dim, 1, 1) -> (wide, 7, 7): a 7x7 kernel, stride 1, no padding
            nn.ConvTranspose2d(z_dim, wide, kernel_size=7, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(wide),
            nn.ReLU(inplace=True),
            # (wide, 7, 7) -> (narrow, 14, 14): stride 2 doubles the spatial size
            nn.ConvTranspose2d(wide, narrow, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(narrow),
            nn.ReLU(inplace=True),
            # (narrow, 14, 14) -> (img_channels, 28, 28)
            nn.ConvTranspose2d(narrow, img_channels, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),  # squash to [-1, 1] to match the normalized training images
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

# 4. DCGAN discriminator built from strided convolutions
class Discriminator(nn.Module):
    """Score images with one raw (un-squashed) logit each.

    There is deliberately no Sigmoid at the end: the logits are consumed by
    nn.BCEWithLogitsLoss, which applies the sigmoid internally.

    Args:
        img_channels: number of channels of the input image.
        feature_map_d: base width; the hidden layers use 1x and 2x this many maps.

    With an input of shape (batch, img_channels, 28, 28) the output has shape
    (batch, 1).
    """

    def __init__(self, img_channels=1, feature_map_d=64):
        super().__init__()
        base = feature_map_d
        doubled = feature_map_d * 2
        stages = [
            # (img_channels, 28, 28) -> (base, 14, 14); this first stage has no BatchNorm
            nn.Conv2d(img_channels, base, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # (base, 14, 14) -> (doubled, 7, 7)
            nn.Conv2d(base, doubled, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(doubled),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # (doubled, 7, 7) -> (1, 1, 1): a 7x7 kernel collapses the map to a single logit
            nn.Conv2d(doubled, 1, kernel_size=7, stride=1, padding=0, bias=False),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        logits = self.net(x)
        # (batch, 1, 1, 1) -> (batch, 1)
        return logits.view(-1, 1)

# 5./6. Build each network on the chosen device together with its Adam optimizer.
# Generator is constructed before Discriminator, as before; creating an
# optimizer consumes no random numbers, so parameter initialization is unchanged.
generator = Generator(z_dim=z_dim).to(device)
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))

discriminator = Discriminator().to(device)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

# The discriminator emits raw logits, so the sigmoid lives inside the loss;
# BCEWithLogitsLoss is more numerically stable than Sigmoid followed by BCELoss.
adversarial_loss = nn.BCEWithLogitsLoss()

# 7. Weight initialization for convolution layers
def weights_init_normal(m):
    """Re-initialize a convolution module in place; meant for ``nn.Module.apply``.

    A module counts as a convolution when its class name contains "Conv"
    (e.g. Conv2d, ConvTranspose2d). Its weights are drawn from N(0, 0.02), and
    its bias, if present, is zeroed. Every other module is left untouched.
    """
    if 'Conv' not in type(m).__name__:
        return
    nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
    if m.bias is not None:
        nn.init.zeros_(m.bias.data)

# Replace the default convolution initialization in both networks.
# Module.apply mutates the module in place and returns it, so these rebinds
# keep the very same objects the optimizers already hold.
generator = generator.apply(weights_init_normal)
discriminator = discriminator.apply(weights_init_normal)

# One latent batch drawn once and reused at the end of every epoch, so the
# 5x5 sample grids can be compared from epoch to epoch.
fixed_noise = torch.randn((25, z_dim, 1, 1), device=device)

# 8. Training Process
# Per-iteration loss histories, plotted after training.
G_losses = []
D_losses = []

for epoch in range(num_epochs):
    for i, (imgs, _) in enumerate(data_loader):
        # The final batch of an epoch may be smaller than batch_size.
        batch_size_i = imgs.size(0)

        # Target labels for the logit loss: 1 = real, 0 = fake (one logit per image).
        valid = torch.ones((batch_size_i, 1), device=device)
        fake = torch.zeros((batch_size_i, 1), device=device)

        real_imgs = imgs.to(device)

        # ----- Train Generator -----
        # Push the generator toward samples the discriminator labels as real.
        optimizer_G.zero_grad()
        # Latent noise of shape (batch, z_dim, 1, 1)
        z = torch.randn(batch_size_i, z_dim, 1, 1, device=device)
        gen_imgs = generator(z)
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        # This backward pass also fills the discriminator's .grad buffers;
        # optimizer_D.zero_grad() below clears them before D's own update.
        g_loss.backward()
        optimizer_G.step()

        # ----- Train Discriminator -----
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        # detach() stops the discriminator loss from backpropagating into the generator.
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        G_losses.append(g_loss.item())
        D_losses.append(d_loss.item())

        if i % 100 == 0:
            print(f"[Epoch {epoch+1}/{num_epochs}] [Batch {i}/{len(data_loader)}] [D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")

    # Visualize generated images at the end of each epoch, always from the same fixed_noise.
    generator.eval()  # BatchNorm uses its running statistics while sampling
    with torch.no_grad():
        gen_imgs = generator(fixed_noise)
    generator.train()

    gen_imgs = (gen_imgs + 1) / 2  # Convert from [-1,1] to [0,1]
    grid = torchvision.utils.make_grid(gen_imgs, nrow=5, padding=2, normalize=False)
    np_grid = grid.cpu().numpy().transpose((1, 2, 0))

    fig = plt.figure(figsize=(5,5))
    plt.imshow(np_grid)
    plt.title(f"Epoch {epoch+1}")
    plt.axis('off')
    plt.show()
    # pyplot keeps every figure alive until it is explicitly closed. On
    # non-inline backends the per-epoch figures would otherwise pile up
    # (matplotlib warns after 20 open figures). Closing one that is already
    # closed is a no-op.
    plt.close(fig)

# Plot the per-iteration generator and discriminator losses recorded during training.
loss_fig, loss_ax = plt.subplots(figsize=(10,5))
loss_ax.set_title("Generator and Discriminator Loss During Training")
loss_ax.plot(G_losses, label="Generator Loss")
loss_ax.plot(D_losses, label="Discriminator Loss")
loss_ax.set_xlabel("Iterations")
loss_ax.set_ylabel("Loss")
loss_ax.legend()
plt.show()
