In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from tqdm import tqdm
import imageio as iio
from torchvision.utils import make_grid
from mpl_toolkits.axes_grid1 import ImageGrid
from torchsummary import summary
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
from IPython.display import display, Image

In [None]:
# Define the display_image_grid function
def display_image_grid(images, rows, cols, title):
    """Show up to ``rows * cols`` images on a ``rows`` x ``cols`` grid.

    Args:
        images: Sized, indexable collection whose items are accepted by
            ``transforms.ToPILImage`` (the training loop passes a uint8
            tensor holding one image per entry).
        rows: Number of grid rows.
        cols: Number of grid columns.
        title: Figure-level title.

    Grid cells that have no matching image are left blank instead of
    raising ``IndexError``.
    """
    # squeeze=False always returns a 2-D array of Axes, so a 1x1 grid (where
    # plt.subplots would otherwise hand back a bare Axes) can be flattened too.
    fig, axs = plt.subplots(rows, cols, figsize=(10, 10), squeeze=False)
    to_pil = transforms.ToPILImage()  # build the converter once, not per image
    n_images = len(images)

    for i, ax in enumerate(axs.ravel()):
        ax.axis('off')
        if i < n_images:
            ax.imshow(to_pil(images[i]))

    fig.suptitle(title)
    plt.show()
    # This is called once per epoch; release the figure so they do not pile up.
    plt.close(fig)

## Data Preprocessing and Dataset

In [None]:
# Set the device to GPU if available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Download and extract the dataset
!wget https://www.dropbox.com/s/g0w7a3x1aw3oonf/SimpsonFaces.zip?dl=0
!unzip -o SimpsonFaces.zip?dl=0 -d extracted_data -x '__MACOSX/*'

# Custom Dataset Class
class MyDataset(Dataset):
    """Unlabelled dataset of the readable image files directly inside a directory.

    Each item is the file ``image_path/<name>`` decoded with
    ``iio.v2.imread``, coerced to three channels, and passed through
    ``transform``.

    Args:
        image_path: Directory that contains the image files.
        transform: Callable applied to each decoded image array.
        remove_invalid: When True (the default, which keeps the original
            behaviour) files that cannot be decoded are deleted from disk;
            when False they are only left out of the dataset.
    """

    def __init__(self, image_path, transform, remove_invalid=True):
        self.image_path = image_path
        self.transform = transform
        self.remove_invalid = remove_invalid
        # os.listdir returns entries in arbitrary order; sort so that an index
        # refers to the same file on every run.
        self.images = sorted(os.listdir(image_path))
        self._check_images()

    def _check_images(self):
        """Reduce ``self.images`` to regular files that decode without error."""
        valid_images = []
        for img_name in self.images:
            img_path = os.path.join(self.image_path, img_name)
            # Sub-directories are not images, and os.remove() on a directory
            # raises, so skip anything that is not a regular file.
            if not os.path.isfile(img_path):
                continue
            try:
                iio.v2.imread(img_path)
            except Exception as e:
                print(f"Error loading image {img_path}: {e}. Skipping...")
                if self.remove_invalid:
                    try:
                        os.remove(img_path)
                    except OSError as remove_error:
                        # An undeletable file must not abort dataset construction.
                        print(f"Could not remove {img_path}: {remove_error}")
            else:
                valid_images.append(img_name)

        self.images = valid_images

    def __len__(self):
        """Number of usable images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load image ``idx``, force it to 3 channels, and apply the transform."""
        im_path = os.path.join(self.image_path, self.images[idx])
        im = iio.v2.imread(im_path)
        # Give every sample the same channel count so batches can be stacked:
        # grayscale is replicated to 3 channels and an alpha channel is dropped.
        # Images that are already 3-channel pass through untouched.
        if im.ndim == 2:
            im = np.stack([im, im, im], axis=-1)
        elif im.ndim == 3 and im.shape[-1] == 1:
            im = np.repeat(im, 3, axis=-1)
        elif im.ndim == 3 and im.shape[-1] == 4:
            im = np.ascontiguousarray(im[..., :3])
        return self.transform(im)

# Data Transformation: array -> PIL image -> 128x128 -> light augmentation -> tensor
trans = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(size=[128, 128]),
    transforms.RandomRotation(degrees=5),
    transforms.RandomHorizontalFlip(p=0.1),
    transforms.ToTensor(),
])

# Create DataLoader (reshuffled every epoch)
batch_size = 32
dataset = MyDataset(image_path="extracted_data/cropped/", transform=trans)
train_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

## Generator and Discriminator Architecture

In [None]:
# Generator Model
class Reshape(nn.Module):
    """Module form of ``Tensor.view`` so a reshape can be a layer in ``nn.Sequential``.

    Args:
        shape: Target dimensions; they are unpacked into ``view`` on every
            forward call.
    """

    def __init__(self, shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        """Return ``x`` viewed with the configured dimensions."""
        target_dims = self.shape
        return x.view(*target_dims)

class Generator(nn.Module):
    """Decode latent vectors of length ``Z`` into 3-channel images squashed by tanh.

    The latent vector is projected to a 1024-channel 8x8 map, upsampled by
    four stride-2 transposed convolutions (8 -> 17 -> 33 -> 65 -> 130) and
    reduced by a final 5x5 convolution to 3 channels at 128x128.

    Args:
        Z: Length of the latent noise vector.
    """

    def __init__(self, Z):
        super().__init__()
        self.Z = Z

        # Latent projection followed by a reshape into a feature map.
        layers = [
            nn.Linear(self.Z, 1024 * 8 * 8),
            nn.BatchNorm1d(1024 * 8 * 8),
            nn.LeakyReLU(0.2),
            Reshape((-1, 1024, 8, 8)),
        ]

        # (in_channels, out_channels, padding, output_padding) of each
        # upsampling stage; every stage uses kernel size 5 and stride 2.
        upsampling_stages = (
            (1024, 512, 1, 0),
            (512, 256, 2, 0),
            (256, 128, 2, 0),
            (128, 64, 2, 1),
        )
        for in_ch, out_ch, pad, out_pad in upsampling_stages:
            layers.extend([
                nn.ConvTranspose2d(
                    in_ch, out_ch,
                    kernel_size=5, stride=2, padding=pad, output_padding=out_pad,
                ),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2),
            ])

        # Final projection to RGB.
        layers.append(nn.Conv2d(64, 3, kernel_size=5, stride=1, padding=1))

        self.gen_model = nn.Sequential(*layers)

    def forward(self, noise):
        """Generate images for a batch of latent vectors."""
        return torch.tanh(self.gen_model(noise))

# Discriminator Model
class Discriminator(nn.Module):
    """Score 3-channel images with a single sigmoid probability each.

    Five 5x5 convolution blocks (Conv -> BatchNorm -> LeakyReLU) feed a linear
    layer that expects a 1024-channel 8x8 feature map, which is what the
    stack produces for 128x128 inputs (128 -> 63 -> 31 -> 15 -> 15 -> 8).
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels, stride, padding); kernel size is 5 throughout.
        conv_specs = (
            (3, 64, 2, 1),
            (64, 128, 2, 1),
            (128, 256, 2, 1),
            (256, 512, 1, 2),
            (512, 1024, 2, 2),
        )
        blocks = []
        for in_ch, out_ch, stride, pad in conv_specs:
            blocks.extend([
                nn.Conv2d(in_ch, out_ch, 5, stride, pad),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2),
            ])
        self.disc_model = nn.Sequential(*blocks)

        self.linearization = nn.Sequential(
            nn.Flatten(start_dim=1, end_dim=-1),
            nn.Linear(1024 * 8 * 8, 1),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return one probability per input image, shape (batch, 1)."""
        features = self.disc_model(x)
        score = self.linearization(features)
        return self.sigmoid(score)

# Weight Initialization
def weights_init(m):
    """Re-initialise one module in place; intended for ``nn.Module.apply``.

    Modules whose class name contains ``Conv`` get weights drawn from
    N(0, 0.02). Modules whose class name contains ``BatchNorm`` get weights
    drawn from N(1, 0.02) and a zero bias. Every other module is left as is.

    Args:
        m: The module to (possibly) re-initialise.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return
    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, val=0)

# Generator and Discriminator Instances
Z = 100  # length of the latent noise vector
generator = Generator(Z=Z).to(device)
discriminator = Discriminator().to(device)
# nn.Module.apply visits every sub-module, so weights_init reaches each layer.
generator.apply(weights_init)
discriminator.apply(weights_init)

# Hyperparameters and Training
EPOCHS = 50
lrg = lrd = 2e-4  # Adam learning rates: generator / discriminator

# Training Function
def train_GAN(EPOCHS, lrg, lrd, discriminator, generator, save_path='models/', data_loader=None):
    """Train the generator/discriminator pair with a binary cross-entropy GAN loss.

    For every batch the discriminator is updated first (real batch with
    target 1, detached fake batch with target 0), then the generator is
    updated so that the discriminator scores its fakes as real. After each
    epoch the same fixed latent vectors are rendered, displayed, and both
    state dicts are written to ``save_path``.

    Args:
        EPOCHS: Number of passes over the data loader.
        lrg: Adam learning rate for the generator.
        lrd: Adam learning rate for the discriminator.
        discriminator: Module returning one probability per image, shape
            (batch, 1).
        generator: Module mapping (batch, latent_dim) noise to images. The
            latent size is read from ``generator.Z`` when present, otherwise
            from the notebook-level ``Z``.
        save_path: Directory for the per-epoch checkpoints; created if missing.
        data_loader: Iterable of image batches, expected in [0, 1] (they are
            rescaled to [-1, 1] here). Defaults to the notebook-level
            ``train_loader``.

    Returns:
        ``(generator_losses, discriminator_losses, generated_images)``: the
        per-epoch mean losses (averaged over samples) and, per epoch, a uint8
        tensor with the 10 images rendered from the fixed latent vectors.

    Raises:
        ValueError: If the data loader yields no samples in an epoch.
    """
    real_label = 1
    fake_label = 0

    if data_loader is None:
        data_loader = train_loader  # notebook-level loader
    latent_dim = getattr(generator, "Z", None)
    if latent_dim is None:
        latent_dim = Z  # notebook-level latent size
    # Run on whatever device the models already live on instead of a global.
    run_device = next(generator.parameters()).device

    os.makedirs(save_path, exist_ok=True)

    discriminator.train()
    generator.train()

    optimizer_gen = torch.optim.Adam(generator.parameters(), lr=lrg, betas=(0.5, 0.999))
    optimizer_disc = torch.optim.Adam(discriminator.parameters(), lr=lrd, betas=(0.5, 0.999))
    loss_fn = nn.BCELoss()

    generator_losses = []
    discriminator_losses = []
    generated_images = []

    # Drawn once: rendering the same latent vectors after every epoch is what
    # makes the per-epoch snapshots (and GIFs built from them) comparable.
    sample_noise = torch.randn(10, latent_dim, device=run_device)

    for epoch in range(1, EPOCHS + 1):
        pbar = tqdm(data_loader, desc=f"Epoch {epoch}/{EPOCHS}")

        total_gen_loss = 0.0
        total_disc_loss = 0.0
        num_samples = 0

        for batch in pbar:
            inputs = batch.to(run_device)
            inputs = (inputs - 0.5) * 2  # [0, 1] -> [-1, 1], the generator's tanh range
            batch_n = inputs.shape[0]

            # --- Discriminator step: real batch, then detached fake batch ---
            optimizer_disc.zero_grad()
            label = torch.full((batch_n, 1), real_label, dtype=torch.float, device=run_device)
            output_real = discriminator(inputs)
            errD_real = loss_fn(output_real, label)
            errD_real.backward()

            D_x = output_real.mean().item()

            noise = torch.randn(batch_n, latent_dim, device=run_device)
            fake = generator(noise)
            label.fill_(fake_label)
            # detach(): the discriminator update must not backpropagate into the generator.
            output_fake = discriminator(fake.detach())
            errD_fake = loss_fn(output_fake, label)
            errD_fake.backward()

            D_G_z1 = output_fake.mean().item()
            errD = errD_real + errD_fake
            optimizer_disc.step()

            # --- Generator step: make the updated discriminator call fakes real ---
            optimizer_gen.zero_grad()
            label.fill_(real_label)
            output_fake = discriminator(fake)
            errG = loss_fn(output_fake, label)
            errG.backward()
            D_G_z2 = output_fake.mean().item()
            optimizer_gen.step()

            gen_loss = errG.item()
            disc_loss = errD.item()
            # BCELoss returns a batch mean; weight it by the batch size so that
            # dividing by num_samples below yields a true per-sample average.
            total_gen_loss += gen_loss * batch_n
            total_disc_loss += disc_loss * batch_n
            num_samples += batch_n

            pbar.set_postfix({
                "generator_loss": gen_loss,
                "discriminator_loss": disc_loss,
                "D(x)": D_x,
                "D(G(z1))": D_G_z1,
                "D(G(z2))": D_G_z2
            })

        if num_samples == 0:
            raise ValueError("data_loader yielded no samples; cannot train the GAN")

        generator_losses.append(total_gen_loss / num_samples)
        discriminator_losses.append(total_disc_loss / num_samples)

        # Snapshot only: no gradients are needed, so skip building the graph.
        with torch.no_grad():
            generations = generator(sample_noise).cpu()
        generations = (generations + 1) / 2
        generations = (generations * 255).clamp(0, 255).to(torch.uint8)
        generated_images.append(generations)

        display_image_grid(generations, 1, 10, f"Generated images at epoch {epoch}")

        torch.save(generator.state_dict(), os.path.join(save_path, f'generator_epoch_{epoch}.pth'))
        torch.save(discriminator.state_dict(), os.path.join(save_path, f'discriminator_epoch_{epoch}.pth'))

    return generator_losses, discriminator_losses, generated_images


# Specify the path where you want to save the models
save_path = 'models/'

# Ensure the directory exists before saving models
os.makedirs(save_path, exist_ok=True)

# Train GAN with model saving; returns per-epoch losses and image snapshots
generator_losses, discriminator_losses, generated_images = train_GAN(
    EPOCHS=EPOCHS,
    lrg=lrg,
    lrd=lrd,
    discriminator=discriminator,
    generator=generator,
    save_path=save_path,
)

## Display Generated Images

In [None]:
# Display generated images collage
# Sampling needs no gradients; no_grad avoids building an autograd graph for
# a 200-image forward pass.
with torch.no_grad():
    test_images = (generator(torch.randn(200, Z).to(device)).cpu() + 1) / 2
test_images = (test_images * 255).clamp(0, 255).to(torch.uint8)

fig = plt.figure(figsize=(30, 60))
for i in range(200):
    ax = fig.add_subplot(20, 10, i + 1)
    # Values are already uint8 in [0, 255]; only convert CHW -> HWC for imshow.
    ax.imshow(test_images[i].permute(1, 2, 0).numpy(), interpolation='nearest', aspect='auto')
    ax.axis("off")

fig.subplots_adjust(wspace=0, hspace=0)
fig.savefig('collage_images.png', bbox_inches='tight', pad_inches=0)
# Close the live figure so only the saved PNG below is shown (an open figure
# would otherwise also be rendered by the inline backend).
plt.close(fig)
with open('collage_images.png', 'rb') as f:
    display(Image(data=f.read(), format='png'))

## Visualize Training Progress

In [None]:
# Plot loss curve
def plot_loss(generator_losses, discriminator_losses):
    """Plot both loss series against their index (one point per epoch) and show the figure.

    Args:
        generator_losses: Sequence of generator loss values.
        discriminator_losses: Sequence of discriminator loss values.
    """
    fig, ax = plt.subplots(figsize=(10, 5))
    curves = (
        (generator_losses, 'Generator Loss'),
        (discriminator_losses, 'Discriminator Loss'),
    )
    for values, legend_name in curves:
        ax.plot(values, label=legend_name)
    ax.set_title('Generator and Discriminator Losses')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.legend()
    plt.show()

# Loss curves of the finished training run
plot_loss(
    generator_losses=generator_losses,
    discriminator_losses=discriminator_losses,
)

In [None]:
# Save generated images as GIF
# One frame per epoch: that epoch's snapshots laid side by side, left to right.
frames = []
for epoch_images in generated_images:
    # Each image is (C, H, W); convert to (H, W, C) and join along the width.
    # The strip is sized from the data instead of assuming ten 128x128 images.
    tiles = [image.permute(1, 2, 0).numpy() for image in epoch_images]
    epoch_grid = np.concatenate(tiles, axis=1).astype(np.uint8, copy=False)
    frames.append(epoch_grid)

iio.mimsave('generated_images.gif', frames, duration=0.1)
with open('generated_images.gif', 'rb') as f:
    display(Image(data=f.read(), format='gif'))

## Create and Display GIFs

In [None]:
# Function to create a GIF from a list of images.
# It is defined in this cell because the calls below need it: the notebook's
# other definition of create_gif sits in a later cell, so running the cells
# top to bottom on a fresh kernel raised NameError here.
def create_gif(images, filename, duration=0.1):
    """Write ``images`` to ``filename`` as an animated GIF, one frame per image.

    Args:
        images: Iterable of (C, H, W) tensors, either uint8 in [0, 255] or
            floating point in [0, 1].
        filename: Output path of the GIF.
        duration: Frame duration forwarded unchanged to ``iio.mimsave``.
    """
    frames = []
    for image in images:
        # imageio expects the channel dimension last: (C, H, W) -> (H, W, C).
        frame = image.detach().cpu().permute(1, 2, 0).numpy()
        if np.issubdtype(frame.dtype, np.floating):
            # Only float images need scaling; multiplying uint8 data by 255
            # wraps around and corrupts the picture.
            frame = np.clip(frame * 255, 0, 255)
        frames.append(frame.astype(np.uint8))

    iio.mimsave(filename, frames, duration=duration)

# Extract the 5th image from each epoch and create a GIF for them
images_at_index_4 = [epoch[4] for epoch in generated_images]
create_gif(images_at_index_4, 'generated_image_index_4.gif')

# Extract the 4th image (index 3) from each epoch and create a GIF for them
images_at_index_3 = [epoch[3] for epoch in generated_images]
create_gif(images_at_index_3, 'generated_image_index_3.gif')

# Display the GIFs
for gif_name in ('generated_image_index_4.gif', 'generated_image_index_3.gif'):
    with open(gif_name, 'rb') as f:
        display(Image(data=f.read(), format='gif'))

In [None]:
# Function to Create GIF
def create_gif(images, filename, duration=0.1):
    """Write ``images`` to ``filename`` as an animated GIF, one frame per image.

    Args:
        images: Iterable of (C, H, W) tensors, either uint8 in [0, 255] (what
            the training loop stores) or floating point in [0, 1].
        filename: Output path of the GIF.
        duration: Frame duration forwarded unchanged to ``iio.mimsave``.
    """
    frames = []
    for image in images:
        # imageio expects the channel dimension last: (C, H, W) -> (H, W, C).
        frame = image.detach().cpu().permute(1, 2, 0).numpy()
        if np.issubdtype(frame.dtype, np.floating):
            # Only float images need scaling to [0, 255]; clip so out-of-range
            # values saturate instead of wrapping in the uint8 cast.
            frame = np.clip(frame * 255, 0, 255)
        # Integer images are already in [0, 255]; multiplying uint8 data by 255
        # would wrap around and corrupt the picture.
        frames.append(frame.astype(np.uint8))

    iio.mimsave(filename, frames, duration=duration)

# Display the GIFs
def display_gif(filename):
    """Show the GIF stored at ``filename`` in the notebook output.

    The bytes are embedded with ``format='gif'``, matching the file's actual
    type and the other GIF display calls in this notebook.

    Args:
        filename: Path of the GIF file to display.
    """
    with open(filename, 'rb') as f:
        display(Image(data=f.read(), format='gif'))

# Show the GIF of per-epoch image strips
display_gif(filename='generated_images.gif')