# Cycle-GAN for Horse-to-Cat Style Transfer

In this exercise you will try to make horses look like cats. Because this is one of the more difficult exercises, most of the code is already provided. However, you still need to add some code and carry out the training and evaluation.

## Tasks
1. Add all code necessary to train a Cycle GAN. This especially includes the data loaders.
2. Train and optimize the Cycle GAN. This step is not trivial; start with very small image resolutions to get a feeling for the hyperparameters.
3. Visualize generated images. For this task no quantifiable measure apart from the loss can be given, so show qualitative examples of horse images that have been turned into cats.

**Important**: At the end you should write a report of adequate size, which will probably mean at least half a page. In the report you should describe how you approached the task. You should describe:
- Encountered difficulties (due to the method, e.g. "not enough training samples to converge", not technical like "I could not install a package over pip")
- Steps taken to alleviate difficulties
- General description of what you did, explain how you understood the task and what you did to solve it in general language, no code.
- Potential limitations of your approach: what could cause issues, and why it might be harder on different data or under slightly different conditions
- If you have an idea how this could be extended in an interesting way, describe it.

### Some Information 
#### GAN (Generative Adversarial Network)
In a generative adversarial network (GAN), two neural networks contest with each other: one generates data (the Generator) and the other guesses whether a sample is generated or real (the Discriminator). One network's gain is the other network's loss.

Given a training set, this technique learns to generate new data with the same statistics as the training set. For example, a GAN trained on photographs can generate new photographs that look at least superficially authentic to human observers, having many realistic characteristics. Though originally proposed as a form of generative model for unsupervised learning, GANs have also proved useful for semi-supervised learning, fully supervised learning and reinforcement learning.

The core idea of a GAN is "indirect" training through the discriminator, a neural network that judges how "realistic" its input seems and that is itself updated dynamically during training. This means that the generator is not trained to minimize the distance to a specific image, but rather to fool the discriminator. This enables the model to learn in an unsupervised manner.
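To make the contest concrete, the standard formulation (Goodfellow et al., 2014) writes it as a minimax game between a generator $G$, which maps noise $z$ to samples, and a discriminator $D$, which outputs the probability that its input is real:

$$
\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
$$

In practice the generator is usually trained with the "non-saturating" variant, i.e. it maximizes $\log D(G(z))$ instead of minimizing $\log(1 - D(G(z)))$. With binary cross-entropy this amounts to labelling generated samples as "real" (target 1) in the generator update, which is what the training loop below does with `torch.ones_like(...)`.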

#### Cycle GAN
A CycleGAN is an architecture for performing translations between two domains, such as between photos of horses and photos of zebras, or photos of cities at night and photos of cities by day. Unlike earlier work such as pix2pix, which requires paired training data, CycleGAN requires no paired data. For example, to train a pix2pix model to turn a summer landscape photo into a winter landscape photo and back, the dataset must contain pairs of the same place in summer and winter, shot from the same angle; CycleGAN only needs a set of summer landscape photos and an unrelated set of winter landscape photos.
                                                                                                                                            
Modified after https://en.wikipedia.org/wiki/Generative_adversarial_network
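What replaces the missing pairs is a *cycle-consistency* constraint. In the notation of the original CycleGAN paper (Zhu et al., 2017), $G: X \to Y$ and $F: Y \to X$ are the two generators and $D_X$, $D_Y$ the two discriminators. Mapped onto this notebook (our reading, not stated in the exercise), $X$ = horses, $Y$ = cats, $G$ = `G_A_to_B`, $F$ = `G_B_to_A`, $D_X$ = `D_A`, $D_Y$ = `D_B`. The full objective combines one adversarial term per direction with an L1 reconstruction term:

$$
\begin{aligned}
\mathcal{L}_{\text{GAN}}(G, D_Y, X, Y) &= \mathbb{E}_{y}\big[\log D_Y(y)\big] + \mathbb{E}_{x}\big[\log\big(1 - D_Y(G(x))\big)\big] \\
\mathcal{L}_{\text{cyc}}(G, F) &= \mathbb{E}_{x}\big[\lVert F(G(x)) - x \rVert_1\big] + \mathbb{E}_{y}\big[\lVert G(F(y)) - y \rVert_1\big] \\
\mathcal{L}(G, F, D_X, D_Y) &= \mathcal{L}_{\text{GAN}}(G, D_Y, X, Y) + \mathcal{L}_{\text{GAN}}(F, D_X, Y, X) + \lambda\, \mathcal{L}_{\text{cyc}}(G, F)
\end{aligned}
$$

The generators minimize $\mathcal{L}$ while the discriminators maximize it. The paper uses $\lambda = 10$; without a weight like this, the cycle term can be dominated by the adversarial terms.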

In [None]:
# 1. Dataset Loading and Preprocessing
# -------------------
# Instructions:
# - Download and prepare a dataset that contains images of horses and cats. You can choose datasets like MS COCO, Pascal VOC, or Tiny ImageNet.
# - Use the pycocotools library for loading the MS COCO dataset, or implement a custom data loader if you use another dataset.
# - Preprocess the images: resize them to a smaller size (e.g., 32x32) for faster training and normalize the pixel values.
# - Create a custom Dataset class for loading horse and cat images from the dataset.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from pycocotools.coco import COCO
import os
import random

In [None]:
# Define the path to your dataset (adjust for your chosen dataset)
dataDir = './datasets/coco/'
dataType = 'train2017'
coco = COCO(os.path.join(dataDir, 'annotations', 'instances_' + dataType + '.json'))

device = 'cuda' if torch.cuda.is_available() else 'cpu'


# Custom Dataset class that loads all images of one COCO category and resizes them to 32x32.
# This small resolution keeps training fast enough.

class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, coco, category_id, transform=None):
        self.coco = coco
        self.category_id = category_id
        self.transform = transform
        self.img_ids = list(coco.getImgIds(catIds=[category_id]))
    
    def __len__(self):
        return len(self.img_ids)
    
    def __getitem__(self, idx):
        img_id = self.img_ids[idx]
        img_info = self.coco.loadImgs([img_id])[0]
        img_path = os.path.join(dataDir, dataType, img_info['file_name'])
        img = Image.open(img_path).convert('RGB')
        
        # Resize the image to 32x32
        img = img.resize((32, 32))

        if self.transform:
            img = self.transform(img)
        
        return img

# Define the transformation pipeline for the images
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

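# Create datasets for horses and cats. Look the category IDs up by name instead of hard-coding them:
# in COCO 2017, 'horse' is 19 and 'cat' is 17 (24 is 'zebra', not 'horse').
horse_category_id = coco.getCatIds(catNms=['horse'])[0]
cat_category_id = coco.getCatIds(catNms=['cat'])[0]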

horse_dataset = ImageDataset(coco, horse_category_id, transform=transform)
cat_dataset = ImageDataset(coco, cat_category_id, transform=transform)

# Create dataloaders
horse_loader = DataLoader(horse_dataset, batch_size=32, shuffle=True)
cat_loader = DataLoader(cat_dataset, batch_size=32, shuffle=True)

In [None]:
# 2. Cycle-GAN Model Architecture
# -------------------
# Instructions:
# - Implement the Generator and Discriminator models for Cycle-GAN using PyTorch.
# - The Generator should use a U-Net like architecture for image-to-image translation.
# - The Discriminator should classify whether an image is real or fake.

In [None]:
# Generator Model
# Define the architecture for the Generator. This should include downsampling and upsampling blocks to transform images from one domain (horse) to another (cat).

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # Build a simple U-Net like architecture with convolutional layers
        pass  # Replace this with your implementation

    def forward(self, x):
        # Implement the forward pass through the Generator network
        pass  # Replace this with your implementation

# Discriminator Model
# Define the architecture for the Discriminator. This model will classify images as real or fake.

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # Build a series of convolutional layers for the Discriminator
        pass  # Replace this with your implementation

    def forward(self, x):
        # Implement the forward pass through the Discriminator network
        pass  # Replace this with your implementation
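As one possible starting point for the stubs above, here is a minimal sketch for 32x32 RGB inputs. It is not the exercise's reference solution, and the names `SketchUNetGenerator`, `SketchPatchDiscriminator`, `_down_block` and `_up_block` are hypothetical. The generator downsamples three times and upsamples back, concatenating encoder features into the decoder (U-Net skip connections), and ends in `tanh` to match inputs normalized to `[-1, 1]`. The discriminator is a PatchGAN-style design (the CycleGAN paper also uses a PatchGAN discriminator, though with a ResNet-style generator): it outputs a 4x4 grid of raw logits, one real/fake score per image patch. The use of instance normalization is an assumption.

```python
import torch
import torch.nn as nn


def _down_block(in_ch, out_ch, norm=True):
    # Strided conv halves the spatial size
    layers = [nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)]
    if norm:
        layers.append(nn.InstanceNorm2d(out_ch))
    layers.append(nn.LeakyReLU(0.2))
    return nn.Sequential(*layers)


def _up_block(in_ch, out_ch):
    # Transposed conv doubles the spatial size
    return nn.Sequential(
        nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
        nn.InstanceNorm2d(out_ch),
        nn.ReLU(),
    )


class SketchUNetGenerator(nn.Module):
    def __init__(self, channels=3, base=64):
        super().__init__()
        self.d1 = _down_block(channels, base, norm=False)  # 32 -> 16
        self.d2 = _down_block(base, base * 2)              # 16 -> 8
        self.d3 = _down_block(base * 2, base * 4)          # 8 -> 4 (bottleneck)
        self.u1 = _up_block(base * 4, base * 2)            # 4 -> 8
        self.u2 = _up_block(base * 4, base)                # 8 -> 16, input = u1 output + skip from d2
        self.out = nn.Sequential(                          # 16 -> 32, input = u2 output + skip from d1
            nn.ConvTranspose2d(base * 2, channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),                                     # output in [-1, 1]
        )

    def forward(self, x):
        s1 = self.d1(x)
        s2 = self.d2(s1)
        u = self.u1(self.d3(s2))
        u = self.u2(torch.cat([u, s2], dim=1))
        return self.out(torch.cat([u, s1], dim=1))


class SketchPatchDiscriminator(nn.Module):
    def __init__(self, channels=3, base=64):
        super().__init__()
        self.model = nn.Sequential(
            _down_block(channels, base, norm=False),          # 32 -> 16
            _down_block(base, base * 2),                      # 16 -> 8
            _down_block(base * 2, base * 4),                  # 8 -> 4
            nn.Conv2d(base * 4, 1, kernel_size=3, padding=1), # one raw logit per patch
        )

    def forward(self, x):
        return self.model(x)  # no sigmoid: pair with BCEWithLogitsLoss


# Quick shape check on random input
_x = torch.randn(2, 3, 32, 32)
gen_out = SketchUNetGenerator()(_x)
disc_out = SketchPatchDiscriminator()(_x)
print(gen_out.shape)   # torch.Size([2, 3, 32, 32])
print(disc_out.shape)  # torch.Size([2, 1, 4, 4])
```

Because the training loop builds its targets with `torch.ones_like(...)`/`torch.zeros_like(...)`, a patch-shaped discriminator output works without changes to the loop.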

In [None]:
# 3. Loss Functions
# -------------------
# Instructions:
# - Implement the adversarial loss and cycle consistency loss.
# - The adversarial loss (used to train both the generators and the discriminators) should be based on Binary Cross-Entropy (BCE).
# - The cycle consistency loss should be L1 loss, ensuring that after translating from one domain to the other and back again, the image is the same.

In [None]:
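def adversarial_loss(prediction, target):
    # Implement the adversarial loss function: compare the discriminator output `prediction`
    # with a target tensor of ones (real) or zeros (fake), as it is called in the training loop
    pass  # Replace this with your implementation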

def cycle_loss(real, reconstructed):
    # Implement the cycle consistency loss function
    pass  # Replace this with your implementation
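A minimal sketch of how the two stubs could be filled in, assuming the discriminator returns raw logits (no final `Sigmoid`). `nn.BCEWithLogitsLoss` applies the sigmoid internally, which is numerically more stable than `Sigmoid` followed by `nn.BCELoss`; if your discriminator does end in a sigmoid, use `nn.BCELoss` instead. `nn.L1Loss` gives the mean absolute difference required for the cycle term.

```python
import torch
import torch.nn as nn

# Assumption: discriminator outputs are logits, not probabilities
_bce_with_logits = nn.BCEWithLogitsLoss()
_l1 = nn.L1Loss()


def adversarial_loss(prediction, target):
    """BCE between discriminator logits and a target of ones (real) or zeros (fake)."""
    return _bce_with_logits(prediction, target)


def cycle_loss(real, reconstructed):
    """Mean absolute (L1) difference between an image and its round-trip reconstruction."""
    return _l1(reconstructed, real)
```

As a sanity check, a logit of 0 (the discriminator is undecided) with target 1 gives a BCE of ln 2 ≈ 0.693. The CycleGAN paper itself replaces the log-likelihood adversarial loss with a least-squares loss (`nn.MSELoss` on the discriminator outputs), which you could try as an alternative if BCE training turns out to be unstable.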

In [None]:
# 4. Training Loop
# -------------------
# Instructions:
# - Initialize the Generator and Discriminator models.
# - Define optimizers (Adam optimizer with a learning rate of 0.0002 and betas (0.5, 0.999)).
# - Implement the training loop for both Generators and Discriminators.
# - Train the Generators using adversarial loss and cycle loss.
# - Train the Discriminators using real and fake images.
# - Print the loss values during the training for monitoring the progress.

In [None]:
# Initialize the models
G_A_to_B = Generator().to(device)
G_B_to_A = Generator().to(device)
D_A = Discriminator().to(device)
D_B = Discriminator().to(device)

# Optimizers
optimizer_G = optim.Adam(list(G_A_to_B.parameters()) + list(G_B_to_A.parameters()), lr=0.0002, betas=(0.5, 0.999))
optimizer_D_A = optim.Adam(D_A.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D_B = optim.Adam(D_B.parameters(), lr=0.0002, betas=(0.5, 0.999))

In [None]:
# Training loop
# Instructions: Implement the loop where the model trains for a number of epochs.
# - In each epoch, generate fake images using the Generator models.
# - Calculate adversarial and cycle losses.
# - Update the weights of the Generators and Discriminators using backpropagation.

In [None]:
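num_epochs = 50
lambda_cycle = 10.0  # weight of the cycle-consistency term (10 in the original CycleGAN paper)

# Naming convention: domain A = horses, domain B = cats.
# G_A_to_B turns horses into cats, G_B_to_A turns cats into horses.
# D_A judges whether an image is a real horse, D_B whether it is a real cat.
for model in (G_A_to_B, G_B_to_A, D_A, D_B):
    model.train()  # the visualization code switches the generators to eval mode

for epoch in range(num_epochs):
    # zip() stops after the shorter of the two loaders
    for i, (horse_real, cat_real) in enumerate(zip(horse_loader, cat_loader)):
        # Move data to device
        horse_real = horse_real.to(device)
        cat_real = cat_real.to(device)

        # ---- Train Generators ----
        optimizer_G.zero_grad()

        # Translate each real image into the other domain
        cat_fake = G_A_to_B(horse_real)    # horse -> fake cat
        horse_fake = G_B_to_A(cat_real)    # cat -> fake horse

        # Adversarial loss: the generators want the discriminators to output "real" (1)
        pred_cat_fake = D_B(cat_fake)
        pred_horse_fake = D_A(horse_fake)
        loss_G_A_to_B = adversarial_loss(pred_cat_fake, torch.ones_like(pred_cat_fake))
        loss_G_B_to_A = adversarial_loss(pred_horse_fake, torch.ones_like(pred_horse_fake))

        # Cycle loss: translating to the other domain and back should recover the input
        horse_reconstructed = G_B_to_A(cat_fake)
        cat_reconstructed = G_A_to_B(horse_fake)
        loss_cycle_A = cycle_loss(horse_real, horse_reconstructed)
        loss_cycle_B = cycle_loss(cat_real, cat_reconstructed)

        # Total generator loss
        loss_G = loss_G_A_to_B + loss_G_B_to_A + lambda_cycle * (loss_cycle_A + loss_cycle_B)
        loss_G.backward()
        optimizer_G.step()

        # ---- Train Discriminators ----
        optimizer_D_A.zero_grad()
        optimizer_D_B.zero_grad()

        # D_A: real horses -> 1, generated horses -> 0 (detach so no gradient reaches the generators)
        pred_real = D_A(horse_real)
        pred_fake = D_A(horse_fake.detach())
        loss_D_A = (adversarial_loss(pred_real, torch.ones_like(pred_real))
                    + adversarial_loss(pred_fake, torch.zeros_like(pred_fake))) / 2

        # D_B: real cats -> 1, generated cats -> 0
        pred_real = D_B(cat_real)
        pred_fake = D_B(cat_fake.detach())
        loss_D_B = (adversarial_loss(pred_real, torch.ones_like(pred_real))
                    + adversarial_loss(pred_fake, torch.zeros_like(pred_fake))) / 2

        loss_D_A.backward()
        loss_D_B.backward()
        optimizer_D_A.step()
        optimizer_D_B.step()

    # Print progress (losses of the last batch of the epoch)
    print(f"Epoch [{epoch + 1}/{num_epochs}] | Loss_G: {loss_G.item():.4f} | "
          f"Loss_D_A: {loss_D_A.item():.4f} | Loss_D_B: {loss_D_B.item():.4f}")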
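The loop above prints only the losses of the last batch, which are noisy. Since the loss is the only quantitative signal available for this task, it may help to average per epoch and plot the curves. `LossHistory` below is a hypothetical helper, not part of any library. Keep in mind that GAN losses usually oscillate rather than decrease steadily, so the curves indicate stability (e.g. a discriminator loss collapsing towards 0) more than image quality.

```python
from collections import defaultdict


class LossHistory:
    """Hypothetical helper: collect per-batch losses and keep one mean per epoch."""

    def __init__(self):
        self.batch = defaultdict(list)   # losses of the current epoch
        self.epochs = defaultdict(list)  # one mean value per finished epoch

    def add(self, **losses):
        for name, value in losses.items():
            self.batch[name].append(float(value))

    def end_epoch(self):
        # Average the current epoch, store the means, and start a new epoch
        for name, values in self.batch.items():
            self.epochs[name].append(sum(values) / len(values))
        self.batch.clear()
        return {name: vals[-1] for name, vals in self.epochs.items()}


# Demo with made-up numbers for two batches
history = LossHistory()
history.add(G=1.0, D_A=0.5)
history.add(G=3.0, D_A=1.5)
epoch_means = history.end_epoch()
```

Inside the training loop you would call `history.add(G=loss_G.item(), D_A=loss_D_A.item(), D_B=loss_D_B.item())` after each batch, `history.end_epoch()` after each epoch, and finally plot e.g. `plt.plot(history.epochs["G"])`.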

In [None]:
# 5. Visualization
# -------------------
# Instructions:
# - After training, visualize some generated images.
# - Display the real and generated images for both the horse-to-cat and cat-to-horse transformations.
# - You should visualize at least 15-20 examples to evaluate the performance of your model.

In [None]:
def visualize_real_and_fake_examples(horse_loader, cat_loader, generator_horse_to_cat, generator_cat_to_horse, n_examples=5, device='cuda'):
    """
    Visualizes n_examples of real and fake animals using the provided model.
    
    Args:
    - horse_loader (DataLoader): Dataloader for horse images.
    - cat_loader (DataLoader): Dataloader for cat images.
    - generator_horse_to_cat (nn.Module): Generator model to convert horses to cats.
    - generator_cat_to_horse (nn.Module): Generator model to convert cats to horses.
    - n_examples (int): Number of examples to visualize.
    - device (str): Device to run the models on ('cuda' or 'cpu').
    """
    # Set both generators to evaluation mode
    generator_horse_to_cat.eval()
    generator_cat_to_horse.eval()
    
    # Create the figure for visualization
    fig, axs = plt.subplots(n_examples, 4, figsize=(12, 3 * n_examples))
    
    # Loop through the number of examples
    for i in range(n_examples):
        # Get a batch of horse and cat images
        horse_real = next(iter(horse_loader)).to(device)
        cat_real = next(iter(cat_loader)).to(device)
        
        # Generate fake images (horse -> fake cat, cat -> fake horse)
        cat_fake = generator_horse_to_cat(horse_real)
        horse_fake = generator_cat_to_horse(cat_real)
        
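        # Plot real images (Horse -> Cat); `* 0.5 + 0.5` undoes Normalize(0.5, 0.5) so imshow gets values in [0, 1]
        axs[i, 0].imshow((horse_real[0] * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0).cpu().detach().numpy())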
        axs[i, 0].set_title("Real Horse")
        axs[i, 0].axis('off')
        
        axs[i, 1].imshow((cat_fake[0] * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0).cpu().detach().numpy())
        axs[i, 1].set_title("