In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torchvision.utils as vutils
import matplotlib.pyplot as plt 
from torchvision.datasets import MNIST
import numpy as np

In [None]:
# Define a simple generator and discriminator for CIFAR-10
class Generator(nn.Module):
    """Fully connected generator for CIFAR-10 sized images.

    Maps a 100-dimensional latent vector to a 3x32x32 image. The final
    Tanh keeps every pixel in [-1, 1].
    """

    def __init__(self):
        super(Generator, self).__init__()

        # Widths of the latent input followed by each hidden layer
        widths = (100, 256, 512, 1024)

        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            # Every hidden stage is Linear -> BatchNorm -> ReLU
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.BatchNorm1d(fan_out))
            layers.append(nn.ReLU(inplace=True))

        # Output stage: one value per pixel (3*32*32)
        layers.append(nn.Linear(widths[-1], 3 * 32 * 32))
        layers.append(nn.Tanh())

        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Turn latent vectors of shape (batch, 100) into (batch, 3, 32, 32) images."""
        flat_pixels = self.fc(x)
        return flat_pixels.view(x.size(0), 3, 32, 32)

class Discriminator(nn.Module):
    """Convolutional discriminator for CIFAR-10 sized images.

    Takes a batch of images, either shaped (batch, 3, 32, 32) or flattened
    to (batch, 3*32*32), and returns a (batch, 1) tensor holding the
    probability that each image is real.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.fc = nn.Sequential(
            # Two stride-2 convolutions halve the resolution twice: 32 -> 16 -> 8
            nn.Conv2d(3, 64, 3, 2, 1),
            # Without a non-linearity the two convolutions collapse into one linear map
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 3, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Dropout2d(),
            # Collapse (128, 8, 8) feature maps into one vector per image so the
            # Linear layer receives the 128*8*8 features it was sized for
            nn.Flatten(),
            nn.Linear(128*8*8, 50),
            nn.ReLU(),
            nn.Linear(50, 1),
            # Squash the score to a probability in [0, 1], as nn.BCELoss requires
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Score a batch of images; accepts 4-D or flattened input."""
        # Restore the image layout in case the caller passed flattened images
        return self.fc(x.view(x.size(0), 3, 32, 32))




In [None]:
# Reproducibility: seed PyTorch before the models are built so that weight
# initialization, batch shuffling and latent noise repeat across runs
SEED = 42
torch.manual_seed(SEED)

# Hyperparameters a reader is likely to tune
LEARNING_RATE = 0.0002
BATCH_SIZE = 64

# Initialize the models
generator = Generator()
discriminator = Discriminator()

print(generator)
print(discriminator)

# Define loss function and optimizers
# BCELoss expects inputs in [0, 1], i.e. the discriminator must output probabilities
criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=LEARNING_RATE)
optimizer_D = optim.Adam(discriminator.parameters(), lr=LEARNING_RATE)

# Lists to store losses for plotting (one entry per training batch)
d_losses = []
g_losses = []

# Data loading and preprocessing (using CIFAR-10 dataset)
# Normalize maps each channel from [0, 1] to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)

dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

print(len(dataset), len(dataloader))




In [None]:
def show_generated_images(images, nrow=4):
    """
    Displays generated images in a grid.

    Args:
        images (torch.Tensor): Batch of images to display, assumed to be
            normalized to [-1, 1]. May live on any device and may still be
            attached to the autograd graph.
        nrow (int): Number of images per row in the grid.
    """
    # Detach first: .numpy() raises on tensors that require grad
    images = images.detach().cpu()

    # Denormalize from [-1, 1] to [0, 1]; clamp so rounding error or
    # out-of-range inputs cannot fall outside the range imshow accepts
    images = (images * 0.5 + 0.5).clamp(0, 1)

    # Tile the batch into a single image and convert to a numpy array
    grid_img = vutils.make_grid(images, nrow=nrow)
    npimg = grid_img.numpy()

    # Plot the images; move the channel axis last, as imshow expects
    plt.figure(figsize=(4, 4))
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.axis('off')
    plt.show()

# Training loop
def train_gan(generator, discriminator, dataloader, num_epochs):
    """
    Trains the GAN by alternating discriminator and generator updates.

    Relies on objects defined at notebook level: `criterion`, `optimizer_G`,
    `optimizer_D`, and the `d_losses` / `g_losses` lists, which receive one
    entry per batch. Every second epoch a grid of samples is saved and shown,
    and after every epoch the loss curves are saved and shown.

    Args:
        generator (nn.Module): Maps latent vectors of size 100 to images.
        discriminator (nn.Module): Scores images; its output is fed to `criterion`.
        dataloader (DataLoader): Yields (images, labels) batches; labels are ignored.
        num_epochs (int): Number of passes over `dataloader`.
    """
    # Create every tensor on the device the model lives on
    device = next(generator.parameters()).device

    for epoch in range(num_epochs):
        for i, data in enumerate(dataloader):
            real_images, _ = data
            real_images = real_images.to(device)
            batch_size = real_images.size(0)

            # Train the discriminator
            optimizer_D.zero_grad()
            outputs = discriminator(real_images)

            # Build the targets from the discriminator's own output so their
            # shape always matches it (BCELoss rejects mismatched sizes)
            real_labels = torch.ones_like(outputs)
            fake_labels = torch.zeros_like(outputs)

            d_loss_real = criterion(outputs, real_labels)
            d_loss_real.backward()

            z = torch.randn(batch_size, 100, device=device)
            fake_images = generator(z)
            # detach() keeps the discriminator update from reaching the generator
            outputs = discriminator(fake_images.detach())
            d_loss_fake = criterion(outputs, fake_labels)
            d_loss_fake.backward()
            d_loss = d_loss_real + d_loss_fake  # logged only; gradients are already accumulated
            optimizer_D.step()

            # Train the generator: it wants its fakes to be labelled as real
            optimizer_G.zero_grad()
            outputs = discriminator(fake_images)
            g_loss = criterion(outputs, real_labels)
            g_loss.backward()
            optimizer_G.step()

            d_losses.append(d_loss.item())
            g_losses.append(g_loss.item())

            if (i + 1) % 100 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{i+1}/{len(dataloader)}], D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}')

        # Generate and save a sample of fake images
        if (epoch + 1) % 2 == 0:
            # Sample in eval mode so BatchNorm uses its running statistics and
            # the sampling batch does not alter them; restore the mode afterwards
            was_training = generator.training
            generator.eval()
            with torch.no_grad():
                z = torch.randn(32, 100, device=device)
                fake_samples = generator(z)
            generator.train(was_training)

            vutils.save_image(fake_samples, f'fake_cifar_samples_epoch_{epoch+1}.png', normalize=True)
            show_generated_images(fake_samples)

        # Plot the loss curves accumulated so far
        fig, ax = plt.subplots(figsize=(10, 5))
        ax.set_title("Generator and Discriminator Loss")
        ax.plot(g_losses, label="G Loss")
        ax.plot(d_losses, label="D Loss")
        ax.set_xlabel("Iterations")
        ax.set_ylabel("Loss")
        ax.legend()
        fig.savefig(f'loss_plot_epoch_{epoch+1}.png')
        plt.show()
        # Release the figure so one does not pile up per epoch
        plt.close(fig)

# Run the training; the epoch count is named so it is easy to find and tune
NUM_EPOCHS = 4
train_gan(generator, discriminator, dataloader, num_epochs=NUM_EPOCHS)
