In [None]:
import math

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Training hyperparameters.
batch_size = 64  # images per mini-batch
z_dim = 100  # length of the generator's input noise vector
num_epochs = 50  # passes over the training set
learning_rate = 2e-4  # Adam step size for the generator and every discriminator

# Load the MNIST training split. Normalizing with mean 0.5 / std 0.5 maps the
# ToTensor output to [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
mnist_dataset = MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)
dataloader = DataLoader(dataset=mnist_dataset, batch_size=batch_size, shuffle=True)

# Data augmentation functions
def rotate_images(images, angle):
    """Rotate every image in a batch counter-clockwise by ``angle`` degrees.

    Args:
        images: Tensor whose last two dimensions are (height, width), e.g. a
            ``(batch, channels, H, W)`` batch. Gradients flow through.
        angle: Rotation in degrees, counter-clockwise.

    Returns:
        Tensor of rotated images.
        - Any multiple of 180 degrees, and any multiple of 90 degrees on square
          images, uses ``torch.rot90``. That path is exact (no interpolation)
          and returns a contiguous tensor.
        - Other angles fall back to ``transforms.functional.rotate``, applied
          to the whole batch in one call. This path keeps the input's (H, W)
          size.
    """
    height, width = images.shape[-2], images.shape[-1]
    if angle % 90 == 0:
        quarter_turns = int(angle // 90) % 4
        # A 90/270-degree turn swaps H and W; take the exact path only when
        # that cannot change the output shape.
        if quarter_turns % 2 == 0 or height == width:
            # rot90 returns a strided view; make it contiguous so callers can
            # still .view() the result.
            return torch.rot90(images, quarter_turns, dims=(-2, -1)).contiguous()
    # functional.rotate accepts arbitrary leading dimensions, so there is no
    # need to loop over the batch image by image.
    return transforms.functional.rotate(images, angle)


def augment_data(images):
    """Return the batch followed by its 90-, 180- and 270-degree rotations.

    Element ``k`` of the returned 4-element list is ``images`` rotated by
    ``90 * k`` degrees via ``rotate_images``. Element 0 is the input object
    itself, not a copy.
    """
    rotated = [rotate_images(images, quarter_turns * 90) for quarter_turns in (1, 2, 3)]
    return [images, *rotated]

# Generator architecture
class Generator(nn.Module):
    """Fully connected generator mapping noise vectors to images in [-1, 1].

    Args:
        latent_dim: Length of the input noise vector. When omitted, the
            module-level ``z_dim`` is used.
        img_shape: ``(channels, height, width)`` of the generated images.
            Defaults to single-channel 28x28, i.e. MNIST.
    """

    def __init__(self, latent_dim=None, img_shape=(1, 28, 28)):
        super().__init__()
        if latent_dim is None:
            # Looked up at construction time, so existing ``Generator()`` calls
            # keep using the global hyperparameter.
            latent_dim = z_dim
        self.img_shape = tuple(img_shape)
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, math.prod(self.img_shape)),
            # Tanh bounds outputs to [-1, 1], the range real images are normalized to.
            nn.Tanh(),
        )

    def forward(self, z):
        """Map a ``(batch, latent_dim)`` noise tensor to ``(batch, *img_shape)`` images."""
        return self.net(z).view(-1, *self.img_shape)

# Discriminator architecture
class Discriminator(nn.Module):
    """Fully connected discriminator scoring images as real (-> 1) or fake (-> 0).

    Args:
        img_shape: ``(channels, height, width)`` of the input images.
            Defaults to single-channel 28x28, i.e. MNIST.
    """

    def __init__(self, img_shape=(1, 28, 28)):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(math.prod(img_shape), 512),
            # Leaky slope 0.2 (as in DCGAN). Note that nn.ReLU(0.2) would NOT
            # give this: ReLU's only argument is ``inplace``.
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of probabilities that each image is real."""
        # flatten (unlike view) also accepts non-contiguous inputs such as
        # transposed/rotated views.
        return self.net(x.flatten(start_dim=1))

# Binary cross-entropy is used for every real/fake decision below.
criterion = nn.BCELoss()

# One generator plus four discriminators; in the training loop discriminator i
# is paired with images rotated by 90 * i degrees.
generator = Generator().to(device)
discriminators = list(map(lambda _: Discriminator().to(device), range(4)))

# Independent Adam optimizers, all using the same learning rate.
optimizer_G = optim.Adam(params=generator.parameters(), lr=learning_rate)
optimizers_D = [optim.Adam(params=disc.parameters(), lr=learning_rate) for disc in discriminators]

# Training loop.
# Discriminator i only ever sees images rotated by 90 * i degrees (real and
# fake alike); the generator is trained to fool all four at once.
generator.train()
for epoch in range(num_epochs):
    for real_data, _ in dataloader:
        real_data = real_data.to(device)
        # The last MNIST batch is smaller than batch_size (60000 % 64 == 32), so
        # size the fake batch from the real one to keep real/fake balanced.
        cur_batch = real_data.size(0)
        augmented_data = augment_data(real_data)

        # Train discriminators.
        # One fake batch shared by all four, detached so the discriminator
        # losses do not backpropagate through the generator.
        z = torch.randn(cur_batch, z_dim, device=device)
        fake_augmented = augment_data(generator(z).detach())
        # Summed over the four discriminators, mirroring how g_loss is summed.
        d_loss = torch.zeros((), device=device)
        for D, opt_D, real_view, fake_view in zip(
            discriminators, optimizers_D, augmented_data, fake_augmented
        ):
            opt_D.zero_grad()
            real_preds = D(real_view)
            fake_preds = D(fake_view)
            real_loss = criterion(real_preds, torch.ones_like(real_preds))
            fake_loss = criterion(fake_preds, torch.zeros_like(fake_preds))
            loss_D = real_loss + fake_loss
            loss_D.backward()
            opt_D.step()
            d_loss += loss_D.detach()

        # Train the generator against all four discriminators using one fresh
        # fake batch and its rotations.
        # Grads this leaves on D parameters are cleared by opt_D.zero_grad()
        # before each discriminator update.
        optimizer_G.zero_grad()
        z = torch.randn(cur_batch, z_dim, device=device)
        fake_views = augment_data(generator(z))
        g_loss = torch.zeros((), device=device)
        for D, fake_view in zip(discriminators, fake_views):
            preds = D(fake_view)
            g_loss = g_loss + criterion(preds, torch.ones_like(preds))
        g_loss.backward()
        optimizer_G.step()

    # Losses from the final batch of the epoch.
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss D: {d_loss.item():.4f}, Loss G: {g_loss.item():.4f}")

# Generate 16 samples and show them as a 4x4 grid.
generator.eval()
# Inference only: no autograd graph is needed for sampling.
with torch.no_grad():
    z = torch.randn(16, z_dim, device=device)
    samples = generator(z).cpu()
# normalize=True rescales the Tanh outputs into [0, 1] for display;
# make_grid returns a (C, H, W) image, which imshow needs as (H, W, C).
grid = torchvision.utils.make_grid(samples, nrow=4, normalize=True)
plt.imshow(grid.permute(1, 2, 0))
plt.show()


Epoch [1/50], Loss D: 0.2568, Loss G: 14.6054
Epoch [2/50], Loss D: 0.8554, Loss G: 5.0883
Epoch [3/50], Loss D: 0.2426, Loss G: 10.7112
Epoch [4/50], Loss D: 0.2688, Loss G: 7.5371
Epoch [5/50], Loss D: 0.1414, Loss G: 15.5320
Epoch [6/50], Loss D: 0.3862, Loss G: 15.1802
