# I'm Something of a Painter Myself - CycleGAN Implementation

## 1. Introduction
This notebook implements a CycleGAN to translate photos into Monet-style paintings.
Competition: [I'm Something of a Painter Myself](https://www.kaggle.com/competitions/gan-getting-started/overview)

**Strategy**:
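1.  **Data**: Keep the dataset as a zip file on Google Drive, then copy it to local Colab storage and unzip it there (much faster IO than reading images directly from Drive).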
2.  **Model**: CycleGAN (ResNet Generator + PatchGAN Discriminator).
3.  **Framework**: PyTorch.

In [None]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.utils import make_grid
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

# Set random seed for reproducibility
def seed_everything(seed=42):
    """Seed every random number generator this notebook relies on.

    Sets the ``PYTHONHASHSEED`` environment variable and then seeds, in
    order, Python's ``random`` module, NumPy, PyTorch and PyTorch's CUDA
    generator with the same value. No cuDNN determinism flags are touched.

    Args:
        seed: Integer seed applied to all generators (default 42).
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Same seeding calls, in the same order, expressed as data.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

# Fix all RNG seeds before any model or data object is created.
seed_everything(42)

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

## 2. Data Pipeline

### Data Loading Strategy
We will use **Google Drive** to store the dataset.
1.  Download `gan-getting-started.zip` from Kaggle to your local machine.
2.  Upload the zip file to your Google Drive (e.g., in a folder named `kaggle_gan`).
3.  Mount Drive in Colab.
4.  Copy the zip to the local Colab environment (`/content/`) and unzip it. This is much faster than reading files directly from Drive.

In [None]:
from google.colab import drive
import shutil

# Step 1: expose the Drive contents under /content/drive.
drive.mount('/content/drive')

# Step 2: where things live.
# `drive_zip_path` must point at the zip you uploaded to your Drive --
# CHANGE IT if you used a different folder.
dataset_dir = '/content/dataset'
local_zip_path = '/content/gan-getting-started.zip'
drive_zip_path = '/content/drive/MyDrive/kaggle_gan/gan-getting-started.zip'

# Step 3: copy + unzip only when the dataset folder is not there yet, so
# re-running this cell is cheap.
if os.path.exists(dataset_dir):
    print("Dataset already exists.")
else:
    print("Copying zip file...")
    shutil.copy(drive_zip_path, local_zip_path)

    print("Unzipping...")
    shutil.unpack_archive(local_zip_path, dataset_dir)
    print("Done!")

# Step 4: sanity check -- report how many entries each image folder holds.
print(f"Photos: {len(os.listdir(os.path.join(dataset_dir, 'photo_jpg')))}")
print(f"Monet: {len(os.listdir(os.path.join(dataset_dir, 'monet_jpg')))}")

In [None]:
class ImageDataset(Dataset):
    """Unpaired Monet/photo dataset for keyword-style (paired) transforms.

    Every entry of ``root_monet`` and ``root_photo`` is treated as an image
    file. The dataset length is the size of the larger folder; the smaller
    folder is cycled through with a modulo so every index yields one image
    from each domain.

    Args:
        root_monet: Directory containing the Monet paintings.
        root_photo: Directory containing the photos.
        transform: Optional callable invoked as
            ``transform(image=<monet array>, image0=<photo array>)`` that
            returns a mapping with the keys ``"image"`` and ``"image0"``
            (this looks like the albumentations calling convention --
            confirm before wiring in a transform).
    """

    def __init__(self, root_monet, root_photo, transform=None):
        self.root_monet = root_monet
        self.root_photo = root_photo
        self.transform = transform

        # os.listdir() returns entries in arbitrary order. Sorting makes the
        # index -> file mapping stable across runs and machines, which the
        # fixed random seed alone cannot guarantee.
        self.monet_images = sorted(os.listdir(root_monet))
        self.photo_images = sorted(os.listdir(root_photo))
        self.monet_len = len(self.monet_images)
        self.photo_len = len(self.photo_images)
        self.length_dataset = max(self.monet_len, self.photo_len)

    def __len__(self):
        """Return the number of entries in the larger of the two folders."""
        return self.length_dataset

    def __getitem__(self, index):
        """Return ``(monet_img, photo_img)`` for ``index``.

        Without a transform both items are RGB ``numpy`` arrays. With a
        transform they are whatever it returns under ``"image"`` and
        ``"image0"``. Raises ``ZeroDivisionError`` if either folder is empty
        (the wrap-around modulo divides by the folder size).
        """
        # Wrap around the smaller domain so both domains always have a sample.
        monet_name = self.monet_images[index % self.monet_len]
        photo_name = self.photo_images[index % self.photo_len]

        monet_path = os.path.join(self.root_monet, monet_name)
        photo_path = os.path.join(self.root_photo, photo_name)

        monet_img = np.array(Image.open(monet_path).convert("RGB"))
        photo_img = np.array(Image.open(photo_path).convert("RGB"))

        if self.transform:
            augmentations = self.transform(image=monet_img, image0=photo_img)
            monet_img = augmentations["image"]
            photo_img = augmentations["image0"]

        return monet_img, photo_img

# Note: ImageDataset above calls its transform with `image=` / `image0=`
# keywords, whereas a torchvision pipeline takes one image at a time.
# This simple torchvision pipeline is therefore the starter option for
# SimpleDataset below:
transform = transforms.Compose([
    # Bring every image to a fixed 256x256 resolution.
    transforms.Resize(size=(256, 256)),
    # PIL image -> float tensor.
    transforms.ToTensor(),
    # Per-channel (x - 0.5) / 0.5.
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

# Simplified Dataset for torchvision (without albumentations for now)
class SimpleDataset(Dataset):
    """Unpaired Monet/photo dataset for single-image (torchvision) transforms.

    Every entry of ``root_monet`` and ``root_photo`` is treated as an image
    file. The dataset length is the size of the larger folder; the smaller
    folder is cycled through with a modulo.

    Args:
        root_monet: Directory containing the Monet paintings.
        root_photo: Directory containing the photos.
        transform: Optional callable applied separately to each PIL image.
    """

    def __init__(self, root_monet, root_photo, transform=None):
        # os.listdir() returns entries in arbitrary order. Sorting makes the
        # index -> file mapping stable across runs and machines, which the
        # fixed random seed alone cannot guarantee.
        self.monet_files = [
            os.path.join(root_monet, name) for name in sorted(os.listdir(root_monet))
        ]
        self.photo_files = [
            os.path.join(root_photo, name) for name in sorted(os.listdir(root_photo))
        ]
        self.transform = transform

    def __len__(self):
        """Return the number of entries in the larger of the two folders."""
        return max(len(self.monet_files), len(self.photo_files))

    def __getitem__(self, idx):
        """Return ``(monet_img, photo_img)`` for ``idx``.

        Both images are opened as RGB and, if a transform was given, passed
        through it one at a time. Raises ``ZeroDivisionError`` if either
        folder is empty (the wrap-around modulo divides by the folder size).
        """
        # Wrap around the smaller domain so both domains always have a sample.
        monet_img = Image.open(self.monet_files[idx % len(self.monet_files)]).convert("RGB")
        photo_img = Image.open(self.photo_files[idx % len(self.photo_files)]).convert("RGB")

        if self.transform:
            monet_img = self.transform(monet_img)
            photo_img = self.transform(photo_img)

        return monet_img, photo_img

# Create DataLoader
# dataset = SimpleDataset(
#     root_monet=os.path.join(dataset_dir, 'monet_jpg'),
#     root_photo=os.path.join(dataset_dir, 'photo_jpg'),
#     transform=transform
# )
# loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)

## 3. Model Architecture (CycleGAN)

We need:
1.  **Generator**: ResNet-based (9 blocks for 256x256 images).
2.  **Discriminator**: 70x70 PatchGAN (each output score judges one 70x70 patch of the input image).
3.  **Weights Initialization**: Normal distribution with mean 0.0, std 0.02.

In [None]:
class ResidualBlock(nn.Module):
    """Residual unit: two reflection-padded 3x3 convolutions plus a skip.

    The inner stack is pad -> conv -> InstanceNorm -> ReLU -> pad -> conv ->
    InstanceNorm; its output is added to the unmodified input. Channel count
    and spatial size are unchanged.

    Args:
        in_features: Number of input (and output) channels.
    """

    def __init__(self, in_features):
        super().__init__()

        stages = []
        for needs_activation in (True, False):
            # Reflection padding of 1 keeps the 3x3 conv size-preserving.
            stages.append(nn.ReflectionPad2d(1))
            stages.append(nn.Conv2d(in_features, in_features, 3))
            stages.append(nn.InstanceNorm2d(in_features))
            # Only the first conv stage is followed by a non-linearity.
            if needs_activation:
                stages.append(nn.ReLU(inplace=True))

        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        residual = self.block(x)
        return x + residual

class Generator(nn.Module):
    """ResNet-style image-to-image generator.

    Layout: a 7x7 stem convolution to 64 channels, two stride-2
    down-sampling convolutions (64 -> 128 -> 256), ``num_residuals``
    residual blocks at 256 channels, two stride-2 transposed convolutions
    back down to 64 channels, and a 7x7 output convolution followed by
    ``Tanh`` (so outputs lie in [-1, 1]).

    Args:
        img_channels: Channels of the input and output images.
        num_residuals: Number of residual blocks in the bottleneck.
    """

    def __init__(self, img_channels=3, num_residuals=9):
        super().__init__()

        # Stem: reflection-padded 7x7 conv keeps the spatial size.
        width = 64
        stem = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(img_channels, width, 7),
            nn.InstanceNorm2d(width),
            nn.ReLU(inplace=True),
        ]

        # Encoder: each step halves the resolution and doubles the channels.
        encoder = []
        for _ in range(2):
            doubled = width * 2
            encoder.extend([
                nn.Conv2d(width, doubled, 3, stride=2, padding=1),
                nn.InstanceNorm2d(doubled),
                nn.ReLU(inplace=True),
            ])
            width = doubled

        # Bottleneck: residual blocks at the widest channel count.
        bottleneck = [ResidualBlock(width) for _ in range(num_residuals)]

        # Decoder: each step doubles the resolution and halves the channels.
        decoder = []
        for _ in range(2):
            halved = width // 2
            decoder.extend([
                nn.ConvTranspose2d(width, halved, 3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(halved),
                nn.ReLU(inplace=True),
            ])
            width = halved

        # Head: back to image channels (width is 64 again at this point).
        head = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(width, img_channels, 7),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*stem, *encoder, *bottleneck, *decoder, *head)

    def forward(self, x):
        """Run ``x`` through the full stem/encoder/bottleneck/decoder/head stack."""
        translated = self.model(x)
        return translated

class Discriminator(nn.Module):
    """PatchGAN discriminator producing a grid of per-patch scores.

    Layout: an initial stride-2 4x4 convolution with LeakyReLU (no
    normalisation), then one conv -> InstanceNorm -> LeakyReLU block per
    remaining entry of ``features`` (stride 2, except the last block which
    uses stride 1), then a final stride-1 4x4 convolution down to a single
    channel. All convolutions use reflect padding of 1.

    No sigmoid is applied: ``forward`` returns raw scores.

    Args:
        in_channels: Channels of the input image.
        features: Channel widths of the successive blocks; must contain at
            least one entry. Any indexable sequence (list or tuple) works.
    """

    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        # The default is an immutable tuple so it cannot be shared and
        # mutated across instances.
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels, features[0], 4, stride=2, padding=1, padding_mode="reflect"),
            nn.LeakyReLU(0.2, inplace=True),
        )

        layers = []
        in_channels = features[0]
        last_index = len(features) - 1
        for index, feature in enumerate(features[1:], start=1):
            # Pick the stride-1 block by position, not by comparing widths:
            # an earlier block that merely has the same width as the last
            # one must still down-sample.
            stride = 1 if index == last_index else 2
            layers.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, feature, 4, stride=stride, padding=1, padding_mode="reflect"),
                    nn.InstanceNorm2d(feature),
                    nn.LeakyReLU(0.2, inplace=True),
                )
            )
            in_channels = feature

        layers.append(nn.Conv2d(in_channels, 1, 4, stride=1, padding=1, padding_mode="reflect"))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the single-channel patch score map for ``x`` (raw, no sigmoid)."""
        return self.model(self.initial(x))

def test():
    """Smoke-test both networks on a random batch and print output shapes.

    Builds a batch of two random 3x256x256 images, a 9-block ``Generator``
    and a ``Discriminator``, then prints the shape of the generator output
    followed by the shape of the discriminator output.
    """
    channels, side = 3, 256
    batch = torch.randn((2, channels, side, side))

    generator = Generator(channels, 9)
    discriminator = Discriminator(channels)

    # Generator first, discriminator second, both fed the same batch.
    for network in (generator, discriminator):
        print(network(batch).shape)

# test()

## 4. Training Loop
This section defines the training loop.
*   **Losses**: Adversarial (MSE), Cycle (L1), Identity (L1).
*   **Optimizers**: Adam.
*   **Loop**: Iterate through epochs, update G and D.

In [None]:
# Training configuration
LEARNING_RATE = 0.0002   # learning rate passed to both Adam optimizers
BATCH_SIZE = 1           # presumably the DataLoader batch size -- not referenced in the visible code
NUM_EPOCHS = 30          # Increase this for better results
LAMBDA_CYCLE = 10        # presumably the cycle-consistency loss weight -- not used yet
LAMBDA_IDENTITY = 0.5    # presumably the identity loss weight -- not used yet

# Networks: one generator and one discriminator per image domain.
# (Creation order is kept as-is so seeded weight initialisation is unchanged.)
gen_Z = Generator(img_channels=3, num_residuals=9).to(device)  # Photo -> Monet
gen_P = Generator(img_channels=3, num_residuals=9).to(device)  # Monet -> Photo
disc_Z = Discriminator(in_channels=3).to(device)  # Classify Monet
disc_P = Discriminator(in_channels=3).to(device)  # Classify Photo

# Loss functions: L1 and mean-squared error.
L1 = nn.L1Loss()
mse = nn.MSELoss()

# Optimizers: a single Adam instance drives both generators, and a second
# one drives both discriminators.
opt_gen = optim.Adam(
    [*gen_Z.parameters(), *gen_P.parameters()],
    lr=LEARNING_RATE,
    betas=(0.5, 0.999),
)
opt_disc = optim.Adam(
    [*disc_Z.parameters(), *disc_P.parameters()],
    lr=LEARNING_RATE,
    betas=(0.5, 0.999),
)

# Training Loop Skeleton
def train_fn(disc_Z, disc_P, gen_Z, gen_P, loader, opt_disc, opt_gen, L1, mse):
    """Iterate once over ``loader`` (skeleton: no losses or updates yet).

    Each batch is unpacked as ``(monet, photo)`` and both tensors are moved
    to the module-level ``device``. The discriminator and generator update
    steps are still to be written, so the networks, optimizers and loss
    functions passed in are currently unused.

    Args:
        disc_Z, disc_P: Discriminator networks (unused for now).
        gen_Z, gen_P: Generator networks (unused for now).
        loader: Iterable yielding ``(monet, photo)`` batches.
        opt_disc, opt_gen: Optimizers (unused for now).
        L1, mse: Loss functions (unused for now).
    """
    progress = tqdm(loader, leave=True)

    for idx, batch in enumerate(progress):
        monet, photo = batch
        monet = monet.to(device)
        photo = photo.to(device)

        # TODO: train the discriminators (disc_Z and disc_P)
        # ... (Implementation needed) ...

        # TODO: train the generators (gen_Z and gen_P)
        # ... (Implementation needed) ...

        # TODO: report running averages on the progress bar, e.g.
        # progress.set_postfix(H_real=H_reals / (idx + 1), H_fake=H_fakes / (idx + 1))
        # (H_reals / H_fakes are placeholders that do not exist yet.)

# Run Training
# for epoch in range(NUM_EPOCHS):
#     train_fn(disc_Z, disc_P, gen_Z, gen_P, loader, opt_disc, opt_gen, L1, mse)
#     # Save Model Checkpoints
#     # Save Sample Images