# I'm Something of a Painter Myself - CycleGAN Implementation

## 1. Introduction

### The Problem: Style Transfer with Unpaired Data
The goal of this project is to perform **Image-to-Image Translation**, specifically transforming ordinary landscape photos into paintings in the style of **Claude Monet**. 

In many computer vision tasks, we have "paired" data (e.g., a photo of a shoe and a sketch of the exact same shoe). However, for this task, we do not have pairs of "a photo of a place" and "a Monet painting of that exact same place". We only have two independent collections:
*   **Domain A**: A set of ~7000 landscape photos.
*   **Domain B**: A set of ~300 Monet paintings.

This is where **CycleGAN** comes in. It is designed to learn a mapping between two domains without requiring paired training data. It achieves this by enforcing **Cycle Consistency**: if we translate a photo to a painting and then back to a photo, we should get the original photo back.

### The Competition: "I'm Something of a Painter Myself"
This notebook is designed for the Kaggle competition [I'm Something of a Painter Myself](https://www.kaggle.com/competitions/gan-getting-started/overview).

*   **Objective**: Build a GAN model that generates 7,000 to 10,000 Monet-style images from the provided photo dataset.
*   **Evaluation Metric**: **MiFID (Memorization-informed Fréchet Inception Distance)**. 
    *   Lower MiFID is better.
    *   It measures both the quality/diversity of generated images (FID) and penalizes the model if it simply memorizes the training set (Memorization distance).
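
For reference, the FID part of the metric compares Gaussian fits to Inception-network features of real and generated images. The following is a sketch of the standard definition; the symbols $\mu_r, \Sigma_r$ and $\mu_g, \Sigma_g$ are introduced here (they do not appear in the competition text above) for the feature mean and covariance of real and generated images respectively:

$$
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2 + \mathrm{Tr}\left(\Sigma_r + \Sigma_g - 2\left(\Sigma_r \Sigma_g\right)^{1/2}\right)
$$

MiFID then adjusts this value with a memorization term based on how close generated features are to training features. The exact thresholding is specified on the competition's evaluation page and is not reproduced here.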

### Project Overview & Strategy
We will implement a CycleGAN from scratch using **PyTorch**. The workflow includes:
1.  **Data Pipeline**: Efficiently loading data from Google Drive (for Colab) or Kaggle input.
2.  **Model Architecture**: 
    *   **Generator**: ResNet-based architecture to transform images.
    *   **Discriminator**: PatchGAN to classify real vs. fake image patches.
3.  **Training**: Using Adversarial Loss, Cycle Consistency Loss, and Identity Loss.
4.  **Inference**: Generating the final set of images for submission.

**Environment**: This notebook is designed to run seamlessly on **Google Colab (Pro)**, **Kaggle Kernels**, or a **Local Machine**. It automatically detects the environment to configure paths and hyperparameters (like batch size).

In [None]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.utils import make_grid
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import shutil
import sys
import gdown

# Set random seed for reproducibility
def seed_everything(seed=42):
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

seed_everything(42)

# Check device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

## 2. Data Loading & Visual EDA

### Dataset Overview
The dataset for this project is sourced from the Kaggle competition. It consists of two distinct domains:

| Domain | Folder Name | Description | Count |
| :--- | :--- | :--- | :--- |
| **Source (Domain A)** | `photo_jpg` | Real-world landscape photos. | 7,038 images |
| **Target (Domain B)** | `monet_jpg` | Paintings by Claude Monet. | 300 images |

**Key Characteristics**:
*   **Unpaired**: There is no one-to-one mapping between a specific photo and a specific painting.
*   **Imbalanced**: We have significantly more photos than paintings (Ratio ~23:1). This imbalance makes data augmentation and stable training crucial.
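
One simple way to cope with the imbalance, and the one the dataset class in this notebook uses, is to index the smaller domain modulo its length so that paintings are recycled within an epoch. The sketch below only counts indices (no images are loaded); `paired_indices` is a hypothetical helper written for illustration.

```python
from collections import Counter

def paired_indices(n_monet, n_photo):
    """Yield (monet_idx, photo_idx) for one epoch of modulo-wrapped sampling."""
    length = max(n_monet, n_photo)
    for i in range(length):
        yield i % n_monet, i % n_photo

# How often is each painting visited in one epoch with 300 paintings and 7,038 photos?
counts = Counter(m for m, _ in paired_indices(n_monet=300, n_photo=7038))
print(min(counts.values()), max(counts.values()))  # 23 24
```

Each painting is therefore seen 23 or 24 times per epoch while each photo is seen once, which is one reason augmenting the Monet side can help.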

### Data Loading Strategy
We implement an environment-aware loading strategy:
1.  **Colab**: Downloads the dataset from a shared Google Drive link (via `gdown`) and unzips it locally for high-speed I/O.
2.  **Kaggle**: Directly accesses the read-only dataset provided by the platform.
3.  **Local**: Downloads and unzips to the local directory.

In [None]:
# 1. Detect Environment
def get_env():
    if 'google.colab' in sys.modules:
        return 'colab'
    elif 'kaggle_web_client' in sys.modules or os.path.exists('/kaggle'):
        return 'kaggle'
    else:
        return 'local'

ENV = get_env()
print(f"Running in {ENV} environment.")

# 2. Define paths based on environment
if ENV == 'colab':
    # Shared Link: https://drive.google.com/file/d/1Wf8cZM1QboZamZDoL9hcuA8yIFUQtgEb/view?usp=sharing
    file_id = '1Wf8cZM1QboZamZDoL9hcuA8yIFUQtgEb'
    url = f'https://drive.google.com/uc?id={file_id}'
    
    base_dir = '/content'
    local_zip_path = os.path.join(base_dir, 'gan-getting-started.zip')
    dataset_dir = os.path.join(base_dir, 'dataset')
    
    # Download and Unzip if needed
    if not os.path.exists(dataset_dir):
        print(f"Downloading zip file to {local_zip_path}...")
        gdown.download(url, local_zip_path, quiet=False)
        
        print("Unzipping...")
        shutil.unpack_archive(local_zip_path, dataset_dir)
        print("Done!")
    else:
        print("Dataset already exists.")

elif ENV == 'kaggle':
    # Kaggle standard input directory
    dataset_dir = '/kaggle/input/gan-getting-started'
    print(f"Using Kaggle dataset at {dataset_dir}")

else: # Local
    # Assume local setup or download
    file_id = '1Wf8cZM1QboZamZDoL9hcuA8yIFUQtgEb'
    url = f'https://drive.google.com/uc?id={file_id}'
    
    base_dir = '.'
    local_zip_path = os.path.join(base_dir, 'gan-getting-started.zip')
    dataset_dir = os.path.join(base_dir, 'dataset')
    
    if not os.path.exists(dataset_dir):
        print(f"Downloading zip file to {local_zip_path}...")
        gdown.download(url, local_zip_path, quiet=False)
        print("Unzipping...")
        shutil.unpack_archive(local_zip_path, dataset_dir)
        print("Done!")

# 3. Verify and Visualize Distribution
if os.path.exists(dataset_dir):
    photo_dir = os.path.join(dataset_dir, 'photo_jpg')
    monet_dir = os.path.join(dataset_dir, 'monet_jpg')
    
    if os.path.exists(photo_dir) and os.path.exists(monet_dir):
        n_photo = len(os.listdir(photo_dir))
        n_monet = len(os.listdir(monet_dir))
        
        print(f"Photos: {n_photo}")
        print(f"Monet Paintings: {n_monet}")
        
        # Bar Chart Visualization
        plt.figure(figsize=(8, 5))
        bars = plt.bar(['Photos (Source)', 'Monet (Target)'], [n_photo, n_monet], color=['#4a90e2', '#f5a623'])
        plt.title('Dataset Distribution: Photos vs. Paintings', fontsize=14)
        plt.ylabel('Number of Images')
        
        # Add counts on top of bars
        for bar in bars:
            height = bar.get_height()
            plt.text(bar.get_x() + bar.get_width()/2., height + 100,
                     f'{height}', ha='center', va='bottom', fontsize=12)
        
        plt.grid(axis='y', linestyle='--', alpha=0.7)
        plt.show()
    else:
        print("Directories not found. Check path structure.")
else:
    print(f"Dataset directory {dataset_dir} not found!")

In [None]:
class ImageDataset(Dataset):
    def __init__(self, root_monet, root_photo, transform=None):
        self.root_monet = root_monet
        self.root_photo = root_photo
        self.transform = transform

        # Sort so the file order (and therefore seeded shuffling) is reproducible
        self.monet_images = sorted(os.listdir(root_monet))
        self.photo_images = sorted(os.listdir(root_photo))
        self.length_dataset = max(len(self.monet_images), len(self.photo_images))
        self.monet_len = len(self.monet_images)
        self.photo_len = len(self.photo_images)

    def __len__(self):
        return self.length_dataset

    def __getitem__(self, index):
        monet_img = self.monet_images[index % self.monet_len]
        photo_img = self.photo_images[index % self.photo_len]

        monet_path = os.path.join(self.root_monet, monet_img)
        photo_path = os.path.join(self.root_photo, photo_img)

        # Convert to RGB so grayscale or CMYK files still yield 3 channels
        monet_img = Image.open(monet_path).convert("RGB")
        photo_img = Image.open(photo_path).convert("RGB")

        if self.transform:
            monet_img = self.transform(monet_img)
            photo_img = self.transform(photo_img)
        else:
            # Without a transform, return arrays so the default collate function still works
            monet_img = np.array(monet_img)
            photo_img = np.array(photo_img)

        return monet_img, photo_img

# Transforms
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Create DataLoader
dataset = ImageDataset(
    root_monet=os.path.join(dataset_dir, 'monet_jpg'),
    root_photo=os.path.join(dataset_dir, 'photo_jpg'),
    transform=transform
)
loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0, pin_memory=True)

# Visualization
def show_sample(loader):
    monet, photo = next(iter(loader))
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    # Denormalize
    monet = monet * 0.5 + 0.5
    photo = photo * 0.5 + 0.5

    ax[0].imshow(monet[0].permute(1, 2, 0))
    ax[0].set_title("Monet Style")
    ax[0].axis("off")

    ax[1].imshow(photo[0].permute(1, 2, 0))
    ax[1].set_title("Photo")
    ax[1].axis("off")
    plt.show()

show_sample(loader)

## 3. Model Architecture (CycleGAN)

### What is CycleGAN?
CycleGAN (Cycle-Consistent Adversarial Networks) is a technique for training unsupervised image-to-image translation models. Unlike paired translation methods such as pix2pix, which need aligned training pairs (e.g., a specific photo and its corresponding painting), CycleGAN learns to translate between two domains (e.g., Domain X: Photos, Domain Y: Monet Paintings) using **unpaired** data.

It achieves this by training two sets of Generators and Discriminators simultaneously:
*   **Generator $G$**: Translates $X \rightarrow Y$ (Photo to Monet).
*   **Generator $F$**: Translates $Y \rightarrow X$ (Monet to Photo).
*   **Discriminator $D_Y$**: Distinguishes real images in $Y$ from generated images $G(X)$.
*   **Discriminator $D_X$**: Distinguishes real images in $X$ from generated images $F(Y)$.

### Why CycleGAN for this Project?
1.  **Unpaired Data**: The competition provides a set of photos and a separate set of Monet paintings, but no one-to-one mapping exists. CycleGAN was designed for exactly this constraint.
2.  **Cycle Consistency**: To prevent the model from hallucinating or ignoring the input image entirely, CycleGAN enforces that $F(G(x)) \approx x$. This ensures that if we turn a photo into a painting and back, we recover the original photo, preserving the structural content (trees, buildings, mountains) while only changing the style.
3.  **Identity Loss**: We also use identity mapping ($G(y) \approx y$) to preserve color composition when the input already looks like the target domain.

### Architecture Components
We implement the standard architecture proposed in the original paper:
1.  **Generator**: ResNet-based architecture (9 residual blocks for 256x256 images). ResNet is preferred over U-Net here because we want to transform the style while keeping the spatial structure intact, and residual connections help preserve information across deep layers.
2.  **Discriminator**: PatchGAN (70x70). Instead of classifying the whole image as real/fake, it classifies overlapping $70 \times 70$ patches. This encourages high-frequency "texture" realism.
3.  **Replay Buffer**: A history of generated images is used to update the discriminator, preventing it from overfitting to the most recent generator output (reducing oscillation).
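
To see where the "70x70" figure comes from, the receptive field can be computed from the kernel sizes and strides alone. The sketch below assumes the five-layer layout used by the discriminator in this notebook (4x4 kernels, strides 2, 2, 2, 1, 1, padding 1); `receptive_field` and `output_size` are hypothetical helper names.

```python
# (kernel_size, stride) for each conv layer of the discriminator; padding is 1 throughout
PATCHGAN_LAYERS = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]

def receptive_field(layers):
    """Side length of the input patch that influences a single output value."""
    rf = 1
    # Walk backwards from one output pixel to the input
    for kernel, stride in reversed(layers):
        rf = (rf - 1) * stride + kernel
    return rf

def output_size(size, layers, padding=1):
    """Spatial size of the output map for a square input of side `size`."""
    for kernel, stride in layers:
        size = (size + 2 * padding - kernel) // stride + 1
    return size

print(receptive_field(PATCHGAN_LAYERS))   # 70
print(output_size(256, PATCHGAN_LAYERS))  # 30
```

So for a 256x256 input the discriminator emits a 30x30 grid of real/fake scores, each looking at a 70x70 region of the image.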

In [None]:
class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
        )

    def forward(self, x):
        return x + self.block(x)

class Generator(nn.Module):
    def __init__(self, img_channels=3, num_residuals=9):
        super(Generator, self).__init__()
        # Initial Convolution
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(img_channels, 64, 7),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
        ]
        # Downsampling
        in_features = 64
        out_features = in_features * 2
        for _ in range(2):
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features
            out_features = in_features * 2

        # Residual Blocks
        for _ in range(num_residuals):
            model += [ResidualBlock(in_features)]

        # Upsampling
        out_features = in_features // 2
        for _ in range(2):
            model += [
                nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features
            out_features = in_features // 2

        # Output Layer
        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, img_channels, 7),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

class Discriminator(nn.Module):
    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super(Discriminator, self).__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels, features[0], 4, stride=2, padding=1, padding_mode="reflect"),
            nn.LeakyReLU(0.2, inplace=True),
        )

        layers = []
        in_channels = features[0]
        for feature in features[1:]:
            layers.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, feature, 4, stride=1 if feature == features[-1] else 2, padding=1, padding_mode="reflect"),
                    nn.InstanceNorm2d(feature),
                    nn.LeakyReLU(0.2, inplace=True),
                )
            )
            in_channels = feature

        layers.append(nn.Conv2d(in_channels, 1, 4, stride=1, padding=1, padding_mode="reflect"))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(self.initial(x))

## 4. Training Configuration

This section sets up the critical hyperparameters and initialization strategies for training. 

### Configuration Guide for Developers
If you are forking this notebook or running it on different hardware, here is how you should adjust the settings:

1.  **Batch Size (`BATCH_SIZE`)**:
    *   **High-End GPUs (A100, V100)**: You can likely use `BATCH_SIZE = 32` or higher.
    *   **Mid-Range GPUs (T4, P100 - Kaggle/Colab Standard)**: Use `BATCH_SIZE = 4` or `1`. CycleGAN is memory-intensive because it keeps 4 models (2 Generators, 2 Discriminators) in memory.
    *   **CPU**: Set `BATCH_SIZE = 1`. Training will be extremely slow.

2.  **Learning Rate (`LEARNING_RATE`)**:
    *   Standard CycleGAN uses `2e-4`.
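    *   We use a linear decay schedule: the rate stays constant for the first half of training and then decays linearly toward zero over the second half (because of the `+ 1` in the denominator of `lambda_rule`, it ends slightly above zero rather than exactly at zero).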

3.  **Epochs (`NUM_EPOCHS`)**:
    *   **Debugging**: Set to `1-5` to ensure code runs.
    *   **Good Results**: `30-50` epochs.
    *   **Best Results**: `100-200` epochs (requires several hours on a GPU).

4.  **Loss Weights**:
    *   `LAMBDA_CYCLE = 10`: Controls how strictly the model enforces $F(G(x)) \approx x$. Higher values preserve more structure (edges, shapes) but might limit style transfer.
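    *   `LAMBDA_IDENTITY = 0.5`: Controls color preservation. In this notebook it multiplies the L1 identity term directly, so the identity loss carries far less weight than the cycle loss.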
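
To make the learning-rate schedule in item 2 concrete, here is a small pure-Python sketch of the multiplier for a 30-epoch run (15 constant epochs, 15 decay epochs). The function name `lr_multiplier` and its parameters are illustrative assumptions; the formula mirrors the `lambda_rule` used in the configuration cell.

```python
def lr_multiplier(epoch, n_const=15, n_decay=15):
    """Factor applied to the base learning rate at a given 0-based epoch."""
    return 1.0 - max(0, epoch + 1 - n_const) / float(n_decay + 1)

BASE_LR = 2e-4
for epoch in (0, 14, 15, 29):
    # Multiplier is 1.0 through epoch 14, then shrinks by 1/16 per epoch
    print(epoch, BASE_LR * lr_multiplier(epoch))
```

The multiplier is 1.0 for epochs 0 to 14, 0.9375 at epoch 15, and 0.0625 at epoch 29. Note that with `n_const` and `n_decay` fixed at 15, running more than 30 epochs would push the multiplier negative, so both values should scale with the total epoch count.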

### Components
*   **Weights Initialization**: We use a Normal distribution ($\mu=0, \sigma=0.02$) as recommended in the paper.
*   **Replay Buffer**: Stores the last 50 generated images. This is crucial for stabilizing the Discriminator, preventing it from oscillating by "forgetting" previous fake images.

In [None]:
def init_weights(net, init_type='normal', init_gain=0.02):
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                nn.init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                nn.init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            nn.init.normal_(m.weight.data, 1.0, init_gain)
            nn.init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)

class ReplayBuffer:
    def __init__(self, max_size=50):
        assert max_size > 0, "Empty buffer or trying to create a black hole. Be careful."
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        to_return = []
        for element in data.detach():
            element = torch.unsqueeze(element, 0)
            if len(self.data) < self.max_size:
                self.data.append(element)
                to_return.append(element)
            else:
                if random.uniform(0, 1) > 0.5:
                    i = random.randint(0, self.max_size - 1)
                    to_return.append(self.data[i].clone())
                    self.data[i] = element
                else:
                    to_return.append(element)
        return torch.cat(to_return)

In [None]:
# Hyperparameters
LEARNING_RATE = 2e-4

# Configure Batch Size based on Environment
if ENV == 'colab':
    BATCH_SIZE = 32 # Assumes an A100 (e.g., Colab Pro); drop to 4 or 1 on a T4
    NUM_WORKERS = 4
elif ENV == 'kaggle':
    BATCH_SIZE = 4 # P100/T4 on Kaggle might struggle with 32. 4 is safe.
    NUM_WORKERS = 2
else:
    BATCH_SIZE = 1
    NUM_WORKERS = 0

print(f"Configuration: ENV={ENV}, BATCH_SIZE={BATCH_SIZE}, NUM_WORKERS={NUM_WORKERS}")

NUM_EPOCHS = 30 # Increase this for better results (e.g., 100-200)
LAMBDA_CYCLE = 10
LAMBDA_IDENTITY = 0.5

# Re-create DataLoader to apply BATCH_SIZE and parallel loading
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)

# Initialize Models
gen_Z = Generator(img_channels=3, num_residuals=9).to(device) # Photo -> Monet
gen_P = Generator(img_channels=3, num_residuals=9).to(device) # Monet -> Photo
disc_Z = Discriminator(in_channels=3).to(device) # Classify Monet
disc_P = Discriminator(in_channels=3).to(device) # Classify Photo

# Initialize Weights
init_weights(gen_Z)
init_weights(gen_P)
init_weights(disc_Z)
init_weights(disc_P)

# Optimizers
opt_gen = optim.Adam(
    list(gen_Z.parameters()) + list(gen_P.parameters()),
    lr=LEARNING_RATE,
    betas=(0.5, 0.999),
)
opt_disc = optim.Adam(
    list(disc_Z.parameters()) + list(disc_P.parameters()),
    lr=LEARNING_RATE,
    betas=(0.5, 0.999),
)

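# Schedulers
# Constant rate for the first half of training, then linear decay toward zero.
# Derived from NUM_EPOCHS so the multiplier never goes negative on longer runs.
N_EPOCHS_CONST = NUM_EPOCHS // 2
N_EPOCHS_DECAY = NUM_EPOCHS - N_EPOCHS_CONST
def lambda_rule(epoch):
    lr_l = 1.0 - max(0, epoch + 1 - N_EPOCHS_CONST) / float(N_EPOCHS_DECAY + 1)
    return lr_l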
scheduler_gen = optim.lr_scheduler.LambdaLR(opt_gen, lr_lambda=lambda_rule)
scheduler_disc = optim.lr_scheduler.LambdaLR(opt_disc, lr_lambda=lambda_rule)

# Losses
L1 = nn.L1Loss()
mse = nn.MSELoss()

# Buffers
fake_monet_buffer = ReplayBuffer()
fake_photo_buffer = ReplayBuffer()

## 5. Training & Monitoring

### Training Strategy
The training loop for CycleGAN is complex because we are training **four networks** simultaneously ($G$, $F$, $D_X$, $D_Y$).

**Step 1: Train Generators ($G$ and $F$)**
We update the generators to minimize a weighted sum of three losses:
1.  **Adversarial Loss**: $G$ tries to fool $D_Y$ (make generated Monet look real). $F$ tries to fool $D_X$.
2.  **Cycle Consistency Loss**: $F(G(x)) \approx x$ and $G(F(y)) \approx y$. This ensures the image content is preserved.
3.  **Identity Loss**: $G(y) \approx y$ and $F(x) \approx x$. This preserves color and prevents the model from making unnecessary changes if the image is already in the target domain.
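
Written out, the generator objective combines these three terms. This is a sketch in the least-squares form that the training cell uses (`nn.MSELoss` for the adversarial part, `nn.L1Loss` for the rest); the symbols $\lambda_{\text{cyc}}$ and $\lambda_{\text{id}}$ stand for `LAMBDA_CYCLE` and `LAMBDA_IDENTITY`, and $\lVert \cdot \rVert_1$ denotes the mean absolute difference over pixels:

$$
\mathcal{L}_{G,F} =
\underbrace{\mathbb{E}_x\big[(D_Y(G(x)) - 1)^2\big] + \mathbb{E}_y\big[(D_X(F(y)) - 1)^2\big]}_{\text{adversarial}}
+ \lambda_{\text{cyc}} \underbrace{\big(\lVert F(G(x)) - x \rVert_1 + \lVert G(F(y)) - y \rVert_1\big)}_{\text{cycle}}
+ \lambda_{\text{id}} \underbrace{\big(\lVert G(y) - y \rVert_1 + \lVert F(x) - x \rVert_1\big)}_{\text{identity}}
$$

In the code, $G$ is `gen_Z`, $F$ is `gen_P`, $D_Y$ is `disc_Z`, and $D_X$ is `disc_P`.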

**Step 2: Train Discriminators ($D_X$ and $D_Y$)**
We update the discriminators to correctly classify:
*   Real images as **Real** (label 1).
*   Generated images as **Fake** (label 0).
*   *Note*: We use the **Replay Buffer** here. Instead of always showing the discriminator the *latest* generated image, we sometimes show it an image generated a few steps ago. This stabilizes training.
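
In the same least-squares form, the discriminator loss for the Monet domain can be sketched as follows, where $\hat{y}$ (a symbol introduced here) is a generated Monet image drawn from the replay buffer rather than always the latest $G(x)$:

$$
\mathcal{L}_{D_Y} = \tfrac{1}{2}\Big(\mathbb{E}_y\big[(D_Y(y) - 1)^2\big] + \mathbb{E}_{\hat{y}}\big[D_Y(\hat{y})^2\big]\Big),
\qquad
\mathcal{L}_{D} = \tfrac{1}{2}\big(\mathcal{L}_{D_X} + \mathcal{L}_{D_Y}\big)
$$

$\mathcal{L}_{D_X}$ is defined analogously with real photos $x$ and buffered fake photos.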

### Performance Optimization
*   **Mixed Precision (`torch.amp`)**: We use Automatic Mixed Precision. This casts some operations to `float16` instead of `float32`, significantly reducing memory usage and speeding up training on modern GPUs (T4, P100, V100, A100).
*   **GradScaler**: Manages gradient scaling to prevent underflow when using mixed precision.
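
The underflow problem that gradient scaling addresses can be reproduced without a GPU. The sketch below uses Python's `struct` module to round values to IEEE half precision; `to_float16` is a hypothetical helper, and the scale of $2^{16}$ is chosen for illustration (it is also the documented default initial scale of PyTorch's `GradScaler`).

```python
import struct

def to_float16(x):
    """Round a Python float to IEEE half precision and back."""
    return struct.unpack('e', struct.pack('e', x))[0]

grad = 1e-8          # a small gradient value, representable in float32
scale = 2.0 ** 16    # illustrative loss scale

# Stored directly in float16, the gradient is flushed to zero
print(to_float16(grad))

# Scaled before the float16 round trip and unscaled afterwards, it survives
recovered = to_float16(grad * scale) / scale
print(recovered)
```

The first print shows `0.0`, while the second is within a fraction of a percent of the original value. Scaling the loss before `backward()` and unscaling before the optimizer step applies the same idea to every gradient.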

### Monitoring
*   **Tqdm Progress Bar**: Shows real-time progress for each epoch.
*   **Loss Logging**: Every 200 batches, we update the progress bar with `loss_G` (Generator total loss) and `loss_D` (Discriminator total loss).
*   **Visual Validation**: At the end of every epoch, we run a fixed "test photo" through the generator and display the result. This allows us to visually confirm that the model is learning the Monet style and not just collapsing to noise.

In [None]:
def train_fn(disc_Z, disc_P, gen_Z, gen_P, loader, opt_disc, opt_gen, L1, mse, d_scaler, g_scaler):
    loop = tqdm(loader, leave=True)

    for idx, (monet, photo) in enumerate(loop):
        monet = monet.to(device)
        photo = photo.to(device)

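        # Train Generators Z (Photo -> Monet) and P (Monet -> Photo)
        with torch.amp.autocast('cuda'):
            # Identity loss: a generator fed an image already in its target domain should leave it unchanged
            identity_monet = gen_Z(monet)
            loss_identity_monet = L1(identity_monet, monet) * LAMBDA_IDENTITY

            identity_photo = gen_P(photo)
            loss_identity_photo = L1(identity_photo, photo) * LAMBDA_IDENTITY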

            # GAN loss
            fake_monet = gen_Z(photo)
            D_Z_fake = disc_Z(fake_monet)
            loss_GAN_Z = mse(D_Z_fake, torch.ones_like(D_Z_fake))

            fake_photo = gen_P(monet)
            D_P_fake = disc_P(fake_photo)
            loss_GAN_P = mse(D_P_fake, torch.ones_like(D_P_fake))

            # Cycle loss
            cycle_monet = gen_Z(fake_photo)
            loss_cycle_monet = L1(cycle_monet, monet) * LAMBDA_CYCLE

            cycle_photo = gen_P(fake_monet)
            loss_cycle_photo = L1(cycle_photo, photo) * LAMBDA_CYCLE

            # Total loss
            loss_G = (
                loss_GAN_Z
                + loss_GAN_P
                + loss_cycle_monet
                + loss_cycle_photo
                + loss_identity_monet
                + loss_identity_photo
            )

        opt_gen.zero_grad()
        g_scaler.scale(loss_G).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        # Train Discriminators P and Z
        with torch.amp.autocast('cuda'):
            # Discriminator P
            D_P_real = disc_P(photo)
            loss_D_P_real = mse(D_P_real, torch.ones_like(D_P_real))
            
            fake_photo_ = fake_photo_buffer.push_and_pop(fake_photo)
            D_P_fake = disc_P(fake_photo_.detach())
            loss_D_P_fake = mse(D_P_fake, torch.zeros_like(D_P_fake))
            loss_D_P = (loss_D_P_real + loss_D_P_fake) / 2

            # Discriminator Z
            D_Z_real = disc_Z(monet)
            loss_D_Z_real = mse(D_Z_real, torch.ones_like(D_Z_real))
            
            fake_monet_ = fake_monet_buffer.push_and_pop(fake_monet)
            D_Z_fake = disc_Z(fake_monet_.detach())
            loss_D_Z_fake = mse(D_Z_fake, torch.zeros_like(D_Z_fake))
            loss_D_Z = (loss_D_Z_real + loss_D_Z_fake) / 2

            # Total loss
            loss_D = (loss_D_P + loss_D_Z) / 2

        opt_disc.zero_grad()
        d_scaler.scale(loss_D).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        if idx % 200 == 0:
            loop.set_postfix(loss_G=loss_G.item(), loss_D=loss_D.item())

def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, filename)

def visualize_progress(gen, test_input, epoch):
    gen.eval()
    with torch.no_grad():
        fake = gen(test_input)
        fake = fake * 0.5 + 0.5 # Denormalize
        fake = fake.cpu()
        
        plt.figure(figsize=(6, 6))
        plt.imshow(fake[0].permute(1, 2, 0))
        plt.title(f"Epoch {epoch+1} Generated Monet")
        plt.axis("off")
        plt.show()
    gen.train()

# Get a fixed sample for visualization
fixed_monet, fixed_photo = next(iter(loader))
fixed_photo = fixed_photo.to(device)

# Training Loop
g_scaler = torch.amp.GradScaler('cuda')
d_scaler = torch.amp.GradScaler('cuda')

for epoch in range(NUM_EPOCHS):
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}")
    train_fn(disc_Z, disc_P, gen_Z, gen_P, loader, opt_disc, opt_gen, L1, mse, d_scaler, g_scaler)
    
    # Visualization
    visualize_progress(gen_Z, fixed_photo, epoch)
    
    # Step schedulers
    scheduler_gen.step()
    scheduler_disc.step()
    
    # Save checkpoint every 5 epochs
    if (epoch + 1) % 5 == 0:
        save_checkpoint(gen_Z, opt_gen, filename=f"gen_Z_{epoch+1}.pth.tar")
        save_checkpoint(gen_P, opt_gen, filename=f"gen_P_{epoch+1}.pth.tar")

## 6. Inference & Submission

Once training is complete, we use the trained Generator `gen_Z` (Photo -> Monet) to transform all images in the `photo_jpg` directory.
The images are saved to a directory and then zipped for submission.
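
Before uploading, it can be worth checking that the archive holds an image count in the 7,000 to 10,000 range stated in the objective. The following is a minimal sketch using only the standard library; the helper names and the extension list are assumptions made for illustration, and `images.zip` is the archive name produced by the cell below.

```python
import os
import zipfile

IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

def count_images_in_zip(zip_path):
    """Count archive entries whose names end with a known image extension."""
    with zipfile.ZipFile(zip_path) as zf:
        return sum(1 for name in zf.namelist() if name.lower().endswith(IMAGE_EXTENSIONS))

def count_is_acceptable(n_images, low=7000, high=10000):
    """True if the image count lies in the range the competition asks for."""
    return low <= n_images <= high

# Only runs the check if the archive has already been created
if os.path.exists('images.zip'):
    n = count_images_in_zip('images.zip')
    print(f"{n} images in archive, acceptable: {count_is_acceptable(n)}")
```

Since the generator is applied to every file in `photo_jpg`, the count should match the number of photos in the dataset.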

In [None]:
def generate_images(gen_Z, photo_dir, output_dir):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    gen_Z.eval()
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    print(f"Generating images from {photo_dir}...")
    photo_files = os.listdir(photo_dir)

    with torch.no_grad():
        for i, file in enumerate(tqdm(photo_files)):
            img_path = os.path.join(photo_dir, file)
            img = Image.open(img_path).convert("RGB")
            img = transform(img).unsqueeze(0).to(device)

            fake_monet = gen_Z(img)

            # Denormalize
            fake_monet = fake_monet * 0.5 + 0.5
            fake_monet = fake_monet.squeeze(0).cpu().detach()

            # Save
            save_path = os.path.join(output_dir, file)
            # Convert to PIL and save
            transforms.ToPILImage()(fake_monet).save(save_path)

    print("Generation complete.")

# Generate
output_dir = 'images'
photo_dir = os.path.join(dataset_dir, 'photo_jpg')
generate_images(gen_Z, photo_dir, output_dir)

# Zip
shutil.make_archive('images', 'zip', output_dir)
print("Images zipped successfully. Ready for submission!")

## 7. Conclusion & Future Work

### Conclusion
We implemented a CycleGAN that translates landscape photos into Monet-style paintings. 
*   The adversarial loss pushes the generator toward the texture and color palette of Monet's work; the epoch-end previews are the main check that this is happening.
*   The Cycle Consistency Loss is there to preserve the structural content of the original photos.

### Future Work
To improve the MiFID score, we could consider:
1.  **Data Augmentation**: Adding augmentations (e.g., random crops, flips) to increase data diversity. The current pipeline only resizes and normalizes, which matters given that there are only 300 paintings.
2.  **Hyperparameter Tuning**: Experimenting with different weights for the Identity and Cycle losses.
3.  **Architecture Variants**: Trying U-Net based generators or different discriminator architectures.
4.  **Longer Training**: Training for more epochs (e.g., 100+) with a decaying learning rate.

## 8. References

1.  **CycleGAN Paper**: Zhu, J. Y., Park, T., Isola, P., & Efros, A. A. (2017). Unpaired image-to-image translation using cycle-consistent adversarial networks. *ICCV*.
2.  **Kaggle Competition**: [I'm Something of a Painter Myself](https://www.kaggle.com/competitions/gan-getting-started)
3.  **PyTorch Documentation**: [https://pytorch.org/docs/stable/index.html](https://pytorch.org/docs/stable/index.html)