# I’m Something of a Painter Myself – Monet GAN
#### Author: James Coffey   
#### Date: 2025‑07‑30
#### Challenge URL: [I’m Something of a Painter Myself](https://www.kaggle.com/competitions/gan-getting-started)
# Imports & Seeds

In [None]:
import random
import shutil
from pathlib import Path

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from kaggle_datasets import KaggleDatasets
import matplotlib.pyplot as plt


# Make randomness more repeatable across runs: seed Python, NumPy *and*
# TensorFlow. TF ops (weight initializers, Dropout, the random augmentations
# in the data pipeline) draw from TF's own global generator, so seeding only
# NumPy / `random` would leave them unseeded. (Bit-exact reproducibility on
# TPU/GPU is still not guaranteed.)
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
tf.random.set_seed(SEED)

# Global constants
IMG_SIZE: int = 256  # Height & width of images
BATCH: int = 8  # Batch size passed to `Dataset.batch`, i.e. the *global* batch (split across replicas)
BUFFER: int = 1024  # Shuffle buffer size for TFRecords
AUTO = tf.data.AUTOTUNE  # Let TF tune prefetch & parallel calls
OUTPUT_CHANNELS = 3  # RGB

E0000 00:00:1754519838.464572      10 common_lib.cc:612] Could not set metric server port: INVALID_ARGUMENT: Could not find SliceBuilder port 8471 in any of the 0 ports provided in `tpu_process_addresses`="local"
=== Source Location Trace: === 
learning/45eac/tfrc/runtime/common_lib.cc:230


# TPU Detection

In [None]:
def detect_tpu() -> tf.distribute.Strategy:
    """Detects an available TPU and returns a matching distribution strategy.

    TPU candidates are probed in order:

    1. ``TPUClusterResolver("local")`` – a TPU attached to this very VM
       (TPU-VM setup).
    2. ``TPUClusterResolver()`` – a remote TPU whose address the resolver
       discovers from the environment (legacy TPU-node setups).

    Any exception raised while probing a candidate is treated as "this TPU is
    not available" and the next candidate is tried. If all of them fail, the
    function falls back to `MirroredStrategy` (visible GPUs, or the CPU).  The
    rest of the notebook can therefore use the same `strategy` object
    regardless of hardware.

    Returns:
        A ready-to-use `tf.distribute.Strategy`.
    """
    last_error = None
    for tpu_address, label in (("local", "TPU‑VM"), (None, "Remote TPU")):
        try:
            resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu_address)
            tf.config.experimental_connect_to_cluster(resolver)
            tf.tpu.experimental.initialize_tpu_system(resolver)
            strategy = tf.distribute.TPUStrategy(resolver)
        except Exception as err:  # Any probe failure just means "try the next option".
            last_error = err
            continue
        print(f"{label} detected – replicas:", strategy.num_replicas_in_sync)
        return strategy

    print("TPU not available (", last_error, ") – falling back to CPU/GPU")
    return tf.distribute.MirroredStrategy()


strategy = detect_tpu()

# TFRecord Data Pipeline

In [None]:
GCS_PATH = KaggleDatasets().get_gcs_path()

# Resolve the shard lists for both domains (Monet first, then photos).
MONET_TFREC, PHOTO_TFREC = (
    tf.io.gfile.glob(f"{GCS_PATH}/{domain}_tfrec/*.tfrec")
    for domain in ("monet", "photo")
)
print(f"Monet TFRecords: {len(MONET_TFREC)}  |  Photo TFRecords: {len(PHOTO_TFREC)}")

# Parsing spec for Kaggle's TFRecord schema: every field is read as a scalar
# byte string. Only "image" is consumed downstream; "target" is an unused
# placeholder.
_IMAGE_FEATURE_DESCRIPTION = {
    key: tf.io.FixedLenFeature([], tf.string)
    for key in ("image_name", "image", "target")
}


def _parse_example(proto: tf.Tensor) -> tf.Tensor:
    """Turns one serialized `tf.train.Example` into a float image in `[-1, 1]`.

    Only the `"image"` feature is used: its JPEG bytes are decoded as RGB,
    cast to `tf.float32` and mapped linearly from `[0, 255]` onto `[-1, 1]`,
    the same range as the generator's `tanh` output.

    Args:
        proto: Scalar `tf.string` tensor holding one serialized example, as
            yielded by a `TFRecordDataset`.

    Returns:
        A `tf.float32` tensor with static shape `(IMG_SIZE, IMG_SIZE, 3)` and
        values in `[-1, 1]`.
    """
    features = tf.io.parse_single_example(proto, _IMAGE_FEATURE_DESCRIPTION)
    pixels = tf.cast(tf.image.decode_jpeg(features["image"], channels=3), tf.float32)
    # 0 -> -1.0 and 255 -> +1.0
    pixels = pixels / 127.5 - 1.0
    # decode_jpeg has a dynamic spatial shape; pin it so later ops see static dims.
    pixels.set_shape([IMG_SIZE, IMG_SIZE, 3])
    return pixels


def _augment(img: tf.Tensor) -> tf.Tensor:
    """Applies lightweight colour and spatial augmentations.

    The goal is to diversify the **Monet** domain without breaking its
    characteristic palette or brushwork.  Augmentations are purposefully
    mild compared to large-scale classification pipelines.

    The colour ops operate on RGB values in `[0, 1]`. `random_hue` in
    particular round-trips through HSV, which is only well defined for
    `[0, 1]` input. The image is therefore mapped `[-1, 1] -> [0, 1]` for the
    colour jitter, clipped back into range, and mapped back to `[-1, 1]`.

    Augmentations performed (in order):
      1. Horizontal flip with 50 % probability.
      2. Random brightness shift in `[-0.1, 0.1]` on the `[0, 1]` scale
         (the same magnitude as `[-0.2, 0.2]` on the `[-1, 1]` scale).
      3. Random contrast factor in `[0.8, 1.2]`.
      4. Random hue shift in `[-0.05, 0.05]`.
      5. Clip to the valid range, so no pixel falls outside what the
         generator's `tanh` output can reproduce.
      6. Resize to `(286, 286)` followed by a random crop back to
         `(256, 256)`. This mirrors the augmentation in the TensorFlow
         CycleGAN tutorial.

    Args:
        img: A tensor of shape `(256, 256, 3)` in `[-1, 1]`.

    Returns:
        A tensor with the same shape and dtype, still in `[-1, 1]`. It may be
        flipped, colour-jittered and crop-shifted.
    """
    img = tf.image.random_flip_left_right(img)

    # Colour jitter in [0, 1] space, which is what these ops are defined for.
    img = (img + 1.0) / 2.0
    img = tf.image.random_brightness(img, 0.1)
    img = tf.image.random_contrast(img, 0.8, 1.2)
    img = tf.image.random_hue(img, 0.05)
    img = tf.clip_by_value(img, 0.0, 1.0)
    img = img * 2.0 - 1.0

    # Slight zoom-crop: upscale to 286 and take a random 256 patch.
    img = tf.image.resize(img, [286, 286])
    img = tf.image.random_crop(img, [IMG_SIZE, IMG_SIZE, 3])
    return img


def make_dataset(
    tfrec_files,
    *,
    augment: bool = False,
    shuffle: bool = True,
    drop_remainder: bool = True,
) -> tf.data.Dataset:
    """Creates a batched, prefetched (and by default shuffled) image dataset.

    Args:
        tfrec_files: Sequence of TFRecord file paths, e.g. the list returned
            by `tf.io.gfile.glob`. Glob *patterns* are not expanded here.
        augment: Whether to apply `_augment` after parsing.
        shuffle: Whether to shuffle the serialized records. Shuffling happens
            before parsing, which is cheaper than shuffling decoded images.
            Defaults to `True` (training behaviour).
        drop_remainder: Whether to drop the final, smaller batch. Keep `True`
            for training, where it gives a static batch shape. Pass `False`
            for inference so no image is skipped.

    Returns:
        A `tf.data.Dataset` of `(B, 256, 256, 3)` float images in `[-1, 1]`.
        `B` is `BATCH`; the last batch may be smaller when
        `drop_remainder=False`.

    Raises:
        ValueError: If `tfrec_files` is empty, e.g. because a wrong GCS path
            made the glob match nothing. Without this check the failure only
            shows up later as an opaque `StopIteration`.
    """
    if not tfrec_files:
        raise ValueError(
            "make_dataset received no TFRecord files; check the GCS path / glob."
        )
    ds = tf.data.TFRecordDataset(tfrec_files, num_parallel_reads=AUTO)
    if shuffle:
        ds = ds.shuffle(BUFFER, seed=SEED, reshuffle_each_iteration=True)
    ds = ds.map(_parse_example, num_parallel_calls=AUTO)
    if augment:
        ds = ds.map(_augment, num_parallel_calls=AUTO)
    ds = ds.batch(BATCH, drop_remainder=drop_remainder)
    return ds.prefetch(AUTO)


# Build datasets — only the Monet side gets augmentations.
monet_ds = make_dataset(MONET_TFREC, augment=True)
photo_ds = make_dataset(PHOTO_TFREC, augment=False)

# Quick sanity-check visual: first image of one batch from each domain,
# mapped from [-1, 1] back to [0, 1] for display.
sample_monet = next(iter(monet_ds.take(1)))[0]
sample_photo = next(iter(photo_ds.take(1)))[0]
plt.figure(figsize=(6, 3))
for _pos, (_title, _img) in enumerate(
    (("Photo", sample_photo), ("Monet", sample_monet)), start=1
):
    plt.subplot(1, 2, _pos)
    plt.imshow((_img + 1) / 2)
    plt.title(_title)
    plt.axis(False)
plt.show()

# Architecture – downsample & upsample blocks

In [None]:
def downsample(
    filters: int, size: int, *, apply_instancenorm: bool = True
) -> keras.Sequential:
    """Creates a down-sampling block used in the encoder path.

    The block performs **Conv ⇒ (InstanceNorm) ⇒ LeakyReLU** with a stride of 2,
    halving the spatial resolution and producing `filters` channels.

    Args:
        filters: Number of convolution filters to apply.
        size: Side length of the square convolution kernel (e.g. 4 for a 4×4
            kernel).
        apply_instancenorm: If `True`, inserts `GroupNormalization(groups=-1)`,
            i.e. one group per channel (InstanceNorm).  Set to `False` for the
            very first layer, mirroring common CycleGAN practice.

    Returns:
        A `keras.Sequential` block that maps tensors of shape
        `(B, H, W, C)` → `(B, H/2, W/2, filters)`.
    """
    init = tf.random_normal_initializer(0.0, 0.02)
    # The normalisation scale (gamma) is drawn from N(1, 0.02), as in the
    # reference pix2pix/CycleGAN implementation, so the layer starts out as a
    # plain InstanceNorm. A mean-0 init would multiply every normalised
    # activation by ~0.02 at the start of training.
    gamma_init = keras.initializers.RandomNormal(mean=1.0, stddev=0.02)

    block = keras.Sequential()
    block.add(
        layers.Conv2D(
            filters,
            size,
            strides=2,
            padding="same",
            kernel_initializer=init,
            use_bias=False,  # The following norm layer has its own shift (beta).
        )
    )
    if apply_instancenorm:
        block.add(layers.GroupNormalization(groups=-1, gamma_initializer=gamma_init))
    block.add(layers.LeakyReLU())
    return block


def upsample(
    filters: int, size: int, *, apply_dropout: bool = False
) -> keras.Sequential:
    """Creates an up-sampling block used in the decoder path.

    The block performs **Transposed Conv ⇒ InstanceNorm ⇒ (Dropout) ⇒ ReLU**
    with a stride of 2, doubling the spatial resolution. Dropout is optional
    and introduces stochasticity near the bottleneck.

    Args:
        filters: Number of transposed-convolution filters to apply.
        size: Side length of the square transposed-convolution kernel
            (e.g. 4 for a 4×4 kernel).
        apply_dropout: If `True`, inserts a `Dropout(0.5)` layer after
            InstanceNorm.  The pix2pix U-Net generator uses this on its three
            innermost decoder layers.

    Returns:
        A `keras.Sequential` block that maps tensors of shape
        `(B, H, W, C)` → `(B, 2 H, 2 W, filters)`.
    """
    init = tf.random_normal_initializer(0.0, 0.02)
    # The normalisation scale (gamma) is drawn from N(1, 0.02), as in the
    # reference pix2pix/CycleGAN implementation. A mean-0 init would start
    # every block with near-zero activations.
    gamma_init = keras.initializers.RandomNormal(mean=1.0, stddev=0.02)

    block = keras.Sequential()
    block.add(
        layers.Conv2DTranspose(
            filters,
            size,
            strides=2,
            padding="same",
            kernel_initializer=init,
            use_bias=False,  # The following norm layer has its own shift (beta).
        )
    )
    # groups=-1 → one group per channel, i.e. InstanceNorm.
    block.add(layers.GroupNormalization(groups=-1, gamma_initializer=gamma_init))
    if apply_dropout:
        block.add(layers.Dropout(0.5))
    block.add(layers.ReLU())
    return block

# Generator & Discriminator builders

In [None]:
def build_generator() -> keras.Model:
    """Constructs a U-Net generator (pix2pix / TF-CycleGAN-tutorial style).

    Note: this is *not* the ResNet-9 generator from the CycleGAN paper. It is
    the U-Net variant commonly used in TensorFlow CycleGAN tutorials:

    * **Encoder** – Eight stride-2 `downsample` blocks that halve the spatial
      resolution at each step (256 → 1) while growing the depth to 512
      channels.  The first block has no InstanceNorm.
    * **Decoder** – Seven stride-2 `upsample` blocks (dropout on the three
      innermost) that restore the resolution 1 → 128. After each one, the
      encoder activation of matching resolution is concatenated as a skip
      connection. The skips carry fine spatial detail past the bottleneck.
    * **Output** – A final stride-2 `Conv2DTranspose` with `tanh` maps
      128 → 256 and 3 channels in the `[-1, 1]` range expected by the losses.

    Returns:
        keras.Model: Functional model that maps a batch of RGB images
        (`shape=(B, 256, 256, 3)`, scaled to `[-1, 1]`) to equally-sized
        stylised images in the same range.
    """
    inputs = layers.Input(shape=[IMG_SIZE, IMG_SIZE, 3])

    # Encoder (downsample): size halves each step.
    down_stack = [
        downsample(64, 4, apply_instancenorm=False),  # (128×128)
        downsample(128, 4),  # (64×64)
        downsample(256, 4),  # (32×32)
        downsample(512, 4),  # (16×16)
        downsample(512, 4),  # (8×8)
        downsample(512, 4),  # (4×4)
        downsample(512, 4),  # (2×2)
        downsample(512, 4),  # (1×1)
    ]

    # Decoder (upsample): mirrors the encoder.
    up_stack = [
        upsample(512, 4, apply_dropout=True),  # (2×2)
        upsample(512, 4, apply_dropout=True),  # (4×4)
        upsample(512, 4, apply_dropout=True),  # (8×8)
        upsample(512, 4),  # (16×16)
        upsample(256, 4),  # (32×32)
        upsample(128, 4),  # (64×64)
        upsample(64, 4),   # (128×128)
    ]

    last = layers.Conv2DTranspose(
        OUTPUT_CHANNELS,
        4,
        strides=2,
        padding="same",
        kernel_initializer=tf.random_normal_initializer(0.0, 0.02),
        activation="tanh",
    )  # (256×256×3)

    x = inputs
    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)

    # The innermost (1×1) activation is the bottleneck itself, so it has no
    # decoder partner. Pair the remaining ones deepest-first with the decoder.
    skips = reversed(skips[:-1])

    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = layers.Concatenate()([x, skip])

    outputs = last(x)
    return keras.Model(inputs, outputs, name="generator")


def build_discriminator() -> keras.Model:
    """Creates a 70 × 70 PatchGAN discriminator.

    The network classifies overlapping 70 × 70 patches of an input image as
    *real* or *fake*, yielding a `(B, 30, 30, 1)` map of logits. PatchGANs focus
    on high-frequency texture, which encourages generators to produce locally
    consistent brush-strokes while remaining lightweight.

    Returns:
        keras.Model: A discriminator that accepts images of shape
        `(B, 256, 256, 3)` scaled to `[-1, 1]` and outputs patch-level logits in
        the same batch order.
    """
    init = tf.random_normal_initializer(0.0, 0.02)
    # Normalisation scale (gamma) ~ N(1, 0.02), as in the reference
    # pix2pix/CycleGAN implementation, rather than ~0 at initialisation.
    gamma_init = keras.initializers.RandomNormal(mean=1.0, stddev=0.02)

    inp = layers.Input(shape=[IMG_SIZE, IMG_SIZE, 3], name="input_image")
    x = inp

    x = downsample(64, 4, apply_instancenorm=False)(x)  # (128×128)
    x = downsample(128, 4)(x)  # (64×64)
    x = downsample(256, 4)(x)  # (32×32)

    x = layers.ZeroPadding2D()(x)  # (34×34)
    x = layers.Conv2D(512, 4, strides=1, kernel_initializer=init, use_bias=False)(x)  # (31×31)
    x = layers.GroupNormalization(groups=-1, gamma_initializer=gamma_init)(x)
    x = layers.LeakyReLU()(x)

    x = layers.ZeroPadding2D()(x)  # (33×33)
    x = layers.Conv2D(1, 4, strides=1, kernel_initializer=init)(x)  # (30×30×1) logits

    return keras.Model(inp, x, name="discriminator")

# Instantiate models under distribution scope

In [None]:
with strategy.scope():
    # Variables created here are placed/mirrored by the distribution strategy.
    # Generators: photo → Monet, then Monet → photo.
    monet_generator, photo_generator = build_generator(), build_generator()
    # Discriminators: one per domain (Monet first, then photo).
    monet_discriminator, photo_discriminator = (
        build_discriminator(),
        build_discriminator(),
    )

## Demo a single forward pass for sanity

In [None]:
# Run the (untrained) photo → Monet generator on one sample and show both
# images side by side, mapped from [-1, 1] to [0, 1].
_to_monet = monet_generator(sample_photo[None, ...])
for _pos, (_title, _img) in enumerate(
    (("Original Photo", sample_photo), ("Monet-esque Photo", _to_monet[0])),
    start=1,
):
    plt.subplot(1, 2, _pos)
    plt.title(_title)
    plt.imshow(_img * 0.5 + 0.5)
    plt.axis(False)
plt.show()

# CycleGAN model

In [None]:
class CycleGan(keras.Model):
    """CycleGAN composite model with custom training loop.

    This class bundles two generators and two discriminators into a single
    `keras.Model`. That lets us use Keras fit/compile semantics while keeping
    full control over the adversarial, cycle-consistency and identity losses
    described in the original *Unpaired Image-to-Image Translation using
    Cycle-Consistent Adversarial Networks* paper (Zhu et al., 2017).

    During each training step the model performs the following sub-steps:

    1. **Forward translation** – photo → Monet (fake) and Monet → photo (fake).
    2. **Cycle translation** – translate fakes back to their original domain.
    3. **Identity mapping** – pass real images through their own-domain
       generator to discourage colour shifts.
    4. **Adversarial updates** – compute generator and discriminator losses
       with the callables supplied to `compile`, which decide the loss form
       (e.g. BCE or least-squares). Then apply gradients with four independent
       optimizers.

    Attributes:
        m_gen: Generator `keras.Model` mapping *photos ➔ Monet*.
        p_gen: Generator `keras.Model` mapping *Monet ➔ photos*.
        m_disc: Discriminator judging real vs. fake Monet images.
        p_disc: Discriminator judging real vs. fake photo images.
        lambda_cycle: Weight passed to both the cycle-consistency and the
            identity loss callables.
    """

    def __init__(
        self,
        monet_generator: keras.Model,
        photo_generator: keras.Model,
        monet_discriminator: keras.Model,
        photo_discriminator: keras.Model,
        lambda_cycle: float = 10,
    ) -> None:
        """Initializes the composite CycleGAN.

        Args:
            monet_generator: Pre-built generator that converts photos to the
                Monet style.
            photo_generator: Generator that converts Monet paintings back to
                photo style (inverse mapping).
            monet_discriminator: Discriminator that classifies real vs fake
                Monet images.
            photo_discriminator: Discriminator that classifies real vs fake
                photo images.
            lambda_cycle: Weight forwarded as the last argument to both
                `cycle_loss_fn` and `identity_loss_fn`.
        """
        super().__init__()
        self.m_gen = monet_generator
        self.p_gen = photo_generator
        self.m_disc = monet_discriminator
        self.p_disc = photo_discriminator
        self.lambda_cycle = lambda_cycle

    def compile(
        self,
        *,
        m_gen_optimizer: keras.optimizers.Optimizer,
        p_gen_optimizer: keras.optimizers.Optimizer,
        m_disc_optimizer: keras.optimizers.Optimizer,
        p_disc_optimizer: keras.optimizers.Optimizer,
        gen_loss_fn,
        disc_loss_fn,
        cycle_loss_fn,
        identity_loss_fn,
    ) -> None:
        """Configures optimizers and loss callables.

        **All arguments must be passed by keyword** to avoid accidental mixing
        of optimizers.

        Args:
            m_gen_optimizer: Optimizer for the photo➔Monet generator.
            p_gen_optimizer: Optimizer for the Monet➔photo generator.
            m_disc_optimizer: Optimizer for the Monet discriminator.
            p_disc_optimizer: Optimizer for the photo discriminator.
            gen_loss_fn: `fn(fake_logits)` implementing the adversarial
                generator loss.
            disc_loss_fn: `fn(real_logits, fake_logits)` implementing the
                discriminator loss.
            cycle_loss_fn: `fn(real, cycled, lambda_cycle)` computing the
                weighted cycle-consistency loss.
            identity_loss_fn: `fn(real, same, lambda_cycle)` computing the
                weighted identity-mapping loss.
        """
        super().compile()
        self.m_gen_optimizer = m_gen_optimizer
        self.p_gen_optimizer = p_gen_optimizer
        self.m_disc_optimizer = m_disc_optimizer
        self.p_disc_optimizer = p_disc_optimizer
        self.gen_loss_fn = gen_loss_fn
        self.disc_loss_fn = disc_loss_fn
        self.cycle_loss_fn = cycle_loss_fn
        self.identity_loss_fn = identity_loss_fn

    def call(self, inputs, training: bool = False):
        """Runs a full forward-and-cycle pass (no gradient side-effects).

        Used mainly for **inference/visualization** and to build the model
        before `fit`.  It translates each domain, cycles back, and returns all
        intermediate tensors so callers can inspect generator outputs.

        Args:
            inputs: Tuple `(real_monet, real_photo)` where each element is a
                batch of images scaled to `[-1, 1]`.
            training: Forwarded to both generators, so training-dependent
                layers (e.g. Dropout) switch mode accordingly.

        Returns:
            Tuple `(fake_monet, cycled_photo, fake_photo, cycled_monet)`.
        """
        real_monet, real_photo = inputs
        fake_monet = self.m_gen(real_photo, training=training)
        cycled_photo = self.p_gen(fake_monet, training=training)
        fake_photo = self.p_gen(real_monet, training=training)
        cycled_monet = self.m_gen(fake_photo, training=training)
        return fake_monet, cycled_photo, fake_photo, cycled_monet

    def train_step(self, batch_data):
        """Executes one training step comprising G & D updates for both domains.

        The method follows the standard Keras `train_step` contract so the
        model can be trained via `model.fit`.  Internally it computes generator
        and discriminator losses, applies gradients with the optimizers
        provided in `compile`, and returns a dictionary of metrics.

        Args:
            batch_data: Tuple `(real_monet, real_photo)` drawn from the zipped
                training dataset.

        Returns:
            A `dict` mapping metric names to loss tensors; Keras logs them and
            shows them in the training progress bar.
        """
        real_monet, real_photo = batch_data

        # Persistent: four separate gradient calls are made from one tape.
        with tf.GradientTape(persistent=True) as tape:
            # ---------- Forward pass (G) ---------- #
            fake_monet = self.m_gen(real_photo, training=True)
            cycled_photo = self.p_gen(fake_monet, training=True)
            fake_photo = self.p_gen(real_monet, training=True)
            cycled_monet = self.m_gen(fake_photo, training=True)

            # ---------- Identity mapping ---------- #
            same_monet = self.m_gen(real_monet, training=True)
            same_photo = self.p_gen(real_photo, training=True)

            # ---------- Discriminator logits ------- #
            disc_real_monet = self.m_disc(real_monet, training=True)
            disc_real_photo = self.p_disc(real_photo, training=True)
            disc_fake_monet = self.m_disc(fake_monet, training=True)
            disc_fake_photo = self.p_disc(fake_photo, training=True)

            # ---------- Losses --------------------- #
            monet_gen_loss = self.gen_loss_fn(disc_fake_monet)
            photo_gen_loss = self.gen_loss_fn(disc_fake_photo)
            # Both cycles constrain both generators, so the summed cycle loss
            # enters each generator's objective.
            total_cycle_loss = (
                self.cycle_loss_fn(real_monet, cycled_monet, self.lambda_cycle) +
                self.cycle_loss_fn(real_photo, cycled_photo, self.lambda_cycle)
            )
            total_monet_gen_loss = (
                monet_gen_loss + total_cycle_loss +
                self.identity_loss_fn(real_monet, same_monet, self.lambda_cycle)
            )
            total_photo_gen_loss = (
                photo_gen_loss + total_cycle_loss +
                self.identity_loss_fn(real_photo, same_photo, self.lambda_cycle)
            )

            monet_disc_loss = self.disc_loss_fn(disc_real_monet, disc_fake_monet)
            photo_disc_loss = self.disc_loss_fn(disc_real_photo, disc_fake_photo)

        # ---------- Gradients & optimiser steps ---- #
        # Each network only receives gradients w.r.t. its own variables, so
        # e.g. the discriminator loss never updates a generator.
        self.m_gen_optimizer.apply_gradients(
            zip(tape.gradient(total_monet_gen_loss, self.m_gen.trainable_variables),
                self.m_gen.trainable_variables))
        self.p_gen_optimizer.apply_gradients(
            zip(tape.gradient(total_photo_gen_loss, self.p_gen.trainable_variables),
                self.p_gen.trainable_variables))
        self.m_disc_optimizer.apply_gradients(
            zip(tape.gradient(monet_disc_loss, self.m_disc.trainable_variables),
                self.m_disc.trainable_variables))
        self.p_disc_optimizer.apply_gradients(
            zip(tape.gradient(photo_disc_loss, self.p_disc.trainable_variables),
                self.p_disc.trainable_variables))
        # A persistent tape holds on to its recorded resources until dropped.
        del tape

        return {
            "monet_gen_loss": total_monet_gen_loss,
            "photo_gen_loss": total_photo_gen_loss,
            "monet_disc_loss": monet_disc_loss,
            "photo_disc_loss": photo_disc_loss,
        }

# Loss functions

In [None]:
with strategy.scope():

    # One shared BCE-on-logits loss object, built once instead of on every
    # call. `Reduction.NONE` is kept on purpose. Keras 2 rejects the
    # AUTO / SUM_OVER_BATCH_SIZE reductions when a loss object is called under
    # a `tf.distribute` strategy outside its built-in compile/fit loss
    # handling. The functions below therefore reduce explicitly with
    # `tf.reduce_mean`.
    _bce = tf.keras.losses.BinaryCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE
    )

    def discriminator_loss(real, generated):
        """Calculates the binary cross-entropy loss for the discriminator.

        The discriminator is trained to output logits that map to **1** for
        real images and **0** for generated (fake) images. The loss is the
        average of the two BCE terms, each averaged over every patch logit in
        the batch.

        Args:
            real: Tensor of discriminator logits for real images.
            generated: Tensor of discriminator logits for generated images.

        Returns:
            A scalar tensor containing the discriminator loss for the batch.
        """
        real_loss = tf.reduce_mean(_bce(tf.ones_like(real), real))
        gen_loss = tf.reduce_mean(_bce(tf.zeros_like(generated), generated))
        return 0.5 * (real_loss + gen_loss)

    def generator_loss(generated):
        """Computes the generator's adversarial (non-saturating BCE) loss.

        The generator aims to fool the discriminator, so it is rewarded when the
        discriminator predicts **1** (real) for generated images. The loss is the
        binary cross-entropy between the discriminator logits for generated
        images and a target tensor of ones, averaged over every patch logit.

        Args:
            generated: Tensor of discriminator logits for generated images.

        Returns:
            A scalar tensor representing the generator loss.
        """
        return tf.reduce_mean(_bce(tf.ones_like(generated), generated))

    def calc_cycle_loss(real_image, cycled_image, lam):
        """Measures cycle-consistency error between original and cycled images.

        After translating an image to the opposite domain and back again, the
        output should closely match the original. This function computes the
        mean absolute error (L1 distance) between `real_image` and
        `cycled_image`, then scales it by `lam`.

        Args:
            real_image: Batch of source-domain images.
            cycled_image: Images obtained after forward-and-back translation.
            lam: Weighting factor for the cycle-loss term.

        Returns:
            A scalar tensor with the weighted cycle-consistency loss.
        """
        return lam * tf.reduce_mean(tf.abs(real_image - cycled_image))

    def identity_loss(real_image, same_image, lam):
        """Enforces identity mapping for images already in the target domain.

        Passing a target-domain image through the generator should ideally
        leave it unchanged. This loss is the mean absolute error between
        `real_image` and the generator output `same_image`, weighted by
        `0.5 * lam`.

        Args:
            real_image: Images that already belong to the generator's target style.
            same_image: Generator output for `real_image`.
            lam: The cycle-loss weight; this function applies half of it.

        Returns:
            A scalar tensor containing the identity-mapping loss.
        """
        return 0.5 * lam * tf.reduce_mean(tf.abs(real_image - same_image))

# Optimizers

In [None]:
with strategy.scope():
    # Four independent Adam instances (lr 2e-4, beta_1 0.5), created in the
    # order Monet-G, photo-G, Monet-D, photo-D; each network keeps its own
    # moment estimates.
    monet_gen_opt, photo_gen_opt, monet_disc_opt, photo_disc_opt = (
        keras.optimizers.Adam(2e-4, beta_1=0.5) for _ in range(4)
    )

    cycle_gan_model = CycleGan(
        monet_generator=monet_generator,
        photo_generator=photo_generator,
        monet_discriminator=monet_discriminator,
        photo_discriminator=photo_discriminator,
    )
    cycle_gan_model.compile(
        m_gen_optimizer=monet_gen_opt,
        p_gen_optimizer=photo_gen_opt,
        m_disc_optimizer=monet_disc_opt,
        p_disc_optimizer=photo_disc_opt,
        gen_loss_fn=generator_loss,
        disc_loss_fn=discriminator_loss,
        cycle_loss_fn=calc_cycle_loss,
        identity_loss_fn=identity_loss,
    )

# Training loop

In [None]:
# Each step pairs one Monet batch with one photo batch; `zip` ends the epoch
# as soon as the shorter of the two datasets is exhausted.
train_ds = tf.data.Dataset.zip((monet_ds, photo_ds)).prefetch(AUTO)

# One eager forward pass so the composite model is built before `fit`.
_first_batch = next(iter(train_ds.take(1)))
_ = cycle_gan_model(_first_batch, training=False)
cycle_gan_model.fit(train_ds, epochs=25)

# Visualize a few results

In [None]:
fig, ax = plt.subplots(5, 2, figsize=(12, 12))
for row, photo_batch in enumerate(photo_ds.take(5)):
    # Stylize the whole batch, but show only its first image per row.
    monet_img = monet_generator(photo_batch, training=False)[0].numpy()
    panels = (
        ((photo_batch[0] * 127.5 + 127.5).numpy().astype(np.uint8), "Input Photo"),
        ((monet_img * 127.5 + 127.5).astype(np.uint8), "Monet Output"),
    )
    for col, (pixels, title) in enumerate(panels):
        ax[row, col].imshow(pixels)
        ax[row, col].set_title(title)
        ax[row, col].axis("off")
plt.show()

# Generate submission zip

In [None]:
# Stage the individual JPEGs outside /kaggle/working; only images.zip needs to
# be a notebook output. NOTE(review): Kaggle has historically capped the
# number of files a notebook may leave in /kaggle/working. Confirm the current
# limit if you move this back.
SUB_DIR = Path("/kaggle/images")
SUB_DIR.mkdir(parents=True, exist_ok=True)

# Dedicated inference pipeline. The training `photo_ds` uses
# `drop_remainder=True`, which would silently skip the last partial batch of
# photos. This one keeps every image, and shuffling is unnecessary here.
submission_ds = (
    tf.data.TFRecordDataset(PHOTO_TFREC, num_parallel_reads=AUTO)
    .map(_parse_example, num_parallel_calls=AUTO)
    .batch(BATCH)
    .prefetch(AUTO)
)

idx = 0
for batch in submission_ds:
    fake_batch = monet_generator(batch, training=False).numpy()
    for img in fake_batch:
        # Map tanh output [-1, 1] to [0, 255]. Round to the nearest level
        # instead of truncating, and clip to guard against float overshoot.
        arr = np.clip(np.rint(img * 127.5 + 127.5), 0, 255).astype(np.uint8)
        keras.utils.save_img(SUB_DIR / f"{idx}.jpg", arr, scale=False)
        idx += 1

print(f"✅ wrote {idx} images")
shutil.make_archive("/kaggle/working/images", "zip", SUB_DIR)