In [1]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torchvision.utils as vutils
import matplotlib.pyplot as plt 
import numpy as np

In [2]:
# Define a simple generator and discriminator for CIFAR-10
class Generator(nn.Module):
    """Fully connected GAN generator producing 3x32x32 images from noise.

    The network is three Linear -> BatchNorm1d -> ReLU stages that widen the
    features (1024 -> 2048 -> 4096), followed by a Linear projection to
    3*32*32 values that Tanh squashes into [-1, 1].

    Because of the BatchNorm1d layers, a forward pass in training mode needs
    more than one sample per batch.

    Args:
        latent_dim (int): Length of each input noise vector. Defaults to 100,
            the size this architecture previously hard-coded.
    """

    def __init__(self, latent_dim=100):
        super().__init__()
        # Kept on the instance so callers can size their noise tensors from it.
        self.latent_dim = latent_dim
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, 1024),   # Noise vector -> first hidden layer
            nn.BatchNorm1d(1024),          # Batch normalization
            nn.ReLU(),

            nn.Linear(1024, 2048),         # Larger hidden layer
            nn.BatchNorm1d(2048),
            nn.ReLU(),

            nn.Linear(2048, 4096),         # Larger hidden layer
            nn.BatchNorm1d(4096),
            nn.ReLU(),

            nn.Linear(4096, 3*32*32),      # Match output to image size
            nn.Tanh()                      # Bound the output to [-1, 1]
        )

    def forward(self, x):
        """Map a batch of noise vectors to a batch of images.

        Args:
            x (torch.Tensor): Noise of shape (batch, latent_dim).

        Returns:
            torch.Tensor: Images of shape (batch, 3, 32, 32) with values in [-1, 1].
        """
        return self.fc(x).view(x.size(0), 3, 32, 32)

class Discriminator(nn.Module):
    """Fully connected GAN discriminator for 3x32x32 images.

    Each input image is flattened to 3*32*32 features and passed through three
    Linear -> LeakyReLU(0.2) -> Dropout(0.3) stages (the last two also apply
    BatchNorm1d) that narrow the features 4096 -> 2048 -> 1024. A final
    Linear + Sigmoid yields one value in [0, 1] per image.

    Because of the BatchNorm1d layers, a forward pass in training mode needs
    more than one sample per batch.
    """

    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(3*32*32, 4096),     # Flattened image -> first hidden layer
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),              # Add dropout to prevent overfitting

            nn.Linear(4096, 2048),        # Decrease layer size
            nn.BatchNorm1d(2048),         # Batch normalization
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),

            nn.Linear(2048, 1024),        # Smaller hidden layer
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),

            nn.Linear(1024, 1),           # Output a single value (real/fake score)
            nn.Sigmoid()
        )

    def forward(self, x):
        """Score a batch of images.

        Args:
            x (torch.Tensor): Images whose first dimension is the batch and
                whose remaining dimensions hold 3*32*32 values in total,
                e.g. (batch, 3, 32, 32) or an already flattened (batch, 3072).

        Returns:
            torch.Tensor: Scores of shape (batch, 1) in [0, 1].
        """
        # reshape (unlike view) also accepts non-contiguous input, such as a
        # permuted tensor, and is a no-copy view whenever the layout allows it.
        return self.fc(x.reshape(x.size(0), -1))


In [3]:
# Tunable settings, gathered in one place
SEED = 42
BATCH_SIZE = 64
LEARNING_RATE = 0.0002

# Seed before building the models so weight initialization, data shuffling and
# noise sampling start from the same random state on every run.
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the models
generator = Generator().to(device)
discriminator = Discriminator().to(device)

print(generator)
print(discriminator)

# Define loss function and optimizers
criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=LEARNING_RATE)
optimizer_D = optim.Adam(discriminator.parameters(), lr=LEARNING_RATE)

# Lists to store losses for plotting (the training loop appends to these)
d_losses = []
g_losses = []

# Data loading and preprocessing (using CIFAR-10 dataset).
# Normalize with mean 0.5 / std 0.5 per channel maps pixels from [0, 1] to
# [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)

dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Number of training images and number of batches per epoch
print(len(dataset), len(dataloader))




Generator(
  (fc): Sequential(
    (0): Linear(in_features=100, out_features=1024, bias=True)
    (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=1024, out_features=2048, bias=True)
    (4): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Linear(in_features=2048, out_features=4096, bias=True)
    (7): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
    (9): Linear(in_features=4096, out_features=3072, bias=True)
    (10): Tanh()
  )
)
Discriminator(
  (fc): Sequential(
    (0): Linear(in_features=3072, out_features=4096, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Dropout(p=0.3, inplace=False)
    (3): Linear(in_features=4096, out_features=2048, bias=True)
    (4): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): LeakyReLU(negat

In [4]:
def show_generated_images(images, nrow=8):
    """
    Displays generated images in a grid.

    Args:
        images (torch.Tensor): Batch of images to display, assumed to be
            normalized to [-1, 1]. May live on any device and may still be
            attached to the autograd graph.
        nrow (int): Number of images per row in the grid.
    """
    # detach() lets this work on tensors that require grad (numpy() refuses
    # those); the arithmetic then maps [-1, 1] back to [0, 1]. clamp() keeps
    # any out-of-range value inside what imshow accepts for float RGB data.
    images = (images.detach().cpu() * 0.5 + 0.5).clamp(0, 1)

    # Tile the batch into one image and convert to a numpy array
    grid_img = vutils.make_grid(images, nrow=nrow)
    npimg = grid_img.numpy()

    # Plot the images; imshow wants channels last, so move axis 0 to the end
    fig = plt.figure(figsize=(8, 8))
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.axis('off')
    plt.show()
    # This is called repeatedly during training; release the figure each time.
    plt.close(fig)

# Training loop
def train_gan(generator, discriminator, dataloader, num_epochs,
              latent_dim=100, log_every=100, sample_every=5, num_samples=32,
              loss_fn=None, gen_optimizer=None, disc_optimizer=None):
    """Train a generator/discriminator pair with alternating GAN updates.

    For every batch the discriminator is updated on real images (target 1)
    and on detached generated images (target 0); the generator is then updated
    so that the discriminator scores its images as real (target 1).

    Side effects: appends one value per batch to the module-level ``d_losses``
    and ``g_losses`` lists, prints progress, and writes
    ``fake_cifar_samples_epoch_{n}.png`` and ``loss_plot_epoch_{n}.png`` using
    relative paths.

    Args:
        generator (nn.Module): Maps (batch, latent_dim) noise to images.
        discriminator (nn.Module): Maps images to (batch, 1) scores.
        dataloader: Iterable yielding ``(images, labels)`` batches; the labels
            are ignored.
        num_epochs (int): Number of passes over ``dataloader``.
        latent_dim (int): Length of the noise vectors fed to the generator.
        log_every (int): Print the losses every this many batches; 0 or None
            disables the printout.
        sample_every (int): Save and show generated samples every this many
            epochs; 0 or None disables sampling.
        num_samples (int): Number of images generated for each sample grid.
        loss_fn: Loss applied to discriminator outputs. Defaults to the
            module-level ``criterion``.
        gen_optimizer: Optimizer for the generator. Defaults to the
            module-level ``optimizer_G``.
        disc_optimizer: Optimizer for the discriminator. Defaults to the
            module-level ``optimizer_D``.

    Note:
        The optimizers must have been built over the parameters of the models
        passed in; otherwise ``step()`` updates some other model.
    """
    # Fall back to the notebook-level objects so existing calls keep working.
    loss_fn = criterion if loss_fn is None else loss_fn
    gen_optimizer = optimizer_G if gen_optimizer is None else gen_optimizer
    disc_optimizer = optimizer_D if disc_optimizer is None else disc_optimizer

    # Follow the generator's own device rather than a global, so the tensors
    # created below always land next to the model's weights.
    model_device = next(generator.parameters()).device

    for epoch in range(num_epochs):
        for i, (real_images, _) in enumerate(dataloader):
            batch_size = real_images.size(0)
            # Real images are passed in their original shape, exactly like the
            # generated ones below; the discriminator does its own flattening.
            real_images = real_images.to(model_device)
            # Allocate targets directly on the device (no CPU -> device copy).
            real_labels = torch.ones(batch_size, 1, device=model_device)
            fake_labels = torch.zeros(batch_size, 1, device=model_device)

            # Train the discriminator
            disc_optimizer.zero_grad()
            outputs = discriminator(real_images)
            d_loss_real = loss_fn(outputs, real_labels)
            d_loss_real.backward()

            z = torch.randn(batch_size, latent_dim, device=model_device)
            fake_images = generator(z)
            # detach(): the discriminator step must not backpropagate into the generator.
            outputs = discriminator(fake_images.detach())
            d_loss_fake = loss_fn(outputs, fake_labels)
            d_loss_fake.backward()
            # Only used for reporting; both gradients were already accumulated above.
            d_loss = d_loss_real.detach() + d_loss_fake.detach()
            disc_optimizer.step()

            # Train the generator: reward images the discriminator calls real
            gen_optimizer.zero_grad()
            outputs = discriminator(fake_images)
            g_loss = loss_fn(outputs, real_labels)
            g_loss.backward()
            gen_optimizer.step()

            d_losses.append(d_loss.item())
            g_losses.append(g_loss.item())

            if log_every and (i + 1) % log_every == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{i+1}/{len(dataloader)}], D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}')

        # Generate and save a sample of fake images
        if sample_every and (epoch + 1) % sample_every == 0:
            with torch.no_grad():
                z = torch.randn(num_samples, latent_dim, device=model_device)
                fake_samples = generator(z)
            vutils.save_image(fake_samples, f'fake_cifar_samples_epoch_{epoch+1}.png', normalize=True)
            show_generated_images(fake_samples)

        # Plot the loss curves accumulated so far (once per epoch)
        fig, ax = plt.subplots(figsize=(10, 5))
        ax.set_title("Generator and Discriminator Loss")
        ax.plot(g_losses, label="G Loss")
        ax.plot(d_losses, label="D Loss")
        ax.set_xlabel("Iterations")
        ax.set_ylabel("Loss")
        ax.legend()
        fig.savefig(f'loss_plot_epoch_{epoch+1}.png')
        plt.show()
        # One figure is created per epoch; close it so they do not pile up.
        plt.close(fig)

# Run the full training: 50 epochs over the CIFAR-10 loader
train_gan(
    generator=generator,
    discriminator=discriminator,
    dataloader=dataloader,
    num_epochs=50,
)


KeyboardInterrupt: 