# I’m Something of a Painter Myself – Monet GAN
# Author: James Coffey   Date: 2025‑07‑30

* Mixed‑precision + TPU‑friendly throughput
* ResNet‑9 CycleGAN w/ 70×70 PatchGAN discriminator
*  Augmentation pipeline loaded from TFRecords (fast)
*  Replay buffer, TTUR, EMA of generators
*  LSGAN losses, identity & cycle consistency
*  MiFID callback + early stop on memorization term
*  Direct streaming of 7 500 monetified images → images.zip

# Imports & Environment Setup

In [None]:
import os, io, zipfile, time, math, random, itertools, shutil, subprocess, glob
from functools import partial
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from kaggle_datasets import KaggleDatasets
import matplotlib.pyplot as plt

class InstanceNorm(layers.Layer):
    """Instance normalization with a learnable per-channel scale and offset.

    Every sample is normalized independently over axes 1 and 2 (the spatial
    axes for channels-last inputs), then multiplied by ``gamma`` and shifted
    by ``beta``. Both weights have one entry per element of the last axis.

    Args:
        epsilon: Small constant added to the variance before taking the
            reciprocal square root, to avoid division by zero.
        **kwargs: Standard ``keras.layers.Layer`` arguments such as ``name``
            or ``dtype``; forwarded unchanged to the base class.
    """

    def __init__(self, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        """Create ``gamma`` (scale, ~N(1, 0.02)) and ``beta`` (offset, zeros)."""
        channels = input_shape[-1]
        self.gamma = self.add_weight(
            name="gamma",
            shape=(channels,),
            initializer=tf.random_normal_initializer(1.0, 0.02),
            trainable=True,
        )
        self.beta = self.add_weight(
            name="beta",
            shape=(channels,),
            initializer="zeros",
            trainable=True,
        )

    def call(self, x):
        """Return ``gamma * (x - mean) / sqrt(var + epsilon) + beta``."""
        mean, var = tf.nn.moments(x, axes=[1, 2], keepdims=True)
        inv = tf.math.rsqrt(var + self.epsilon)
        normalized = (x - mean) * inv
        return self.gamma * normalized + self.beta

    def get_config(self):
        """Return the layer config including ``epsilon`` so it can be re-created."""
        config = super().get_config()
        config.update({"epsilon": self.epsilon})
        return config

# TPU detection --------------------------------------------------

def detect_tpu():
    """Return a ``tf.distribute`` strategy for the best available hardware.

    Attempts, in order:
      1. ``TPUStrategy()`` with no arguments.
      2. A ``TPUClusterResolver("local")`` that is connected and initialized
         explicitly before building a ``TPUStrategy`` from it.
      3. ``MirroredStrategy()`` when the second attempt raises anything.

    Only ``ValueError`` / ``NotImplementedError`` from the first attempt lead
    to the second one; any other exception from it propagates to the caller.
    """

    def _strategy_from_local_resolver():
        # Explicit resolver -> connect -> initialize -> strategy.
        tpu_resolver = tf.distribute.cluster_resolver.TPUClusterResolver("local")
        tf.config.experimental_connect_to_cluster(tpu_resolver)
        tf.tpu.experimental.initialize_tpu_system(tpu_resolver)
        return tf.distribute.TPUStrategy(tpu_resolver)

    try:
        tpu_strategy = tf.distribute.TPUStrategy()
        print("TPU‑VM detected – replicas:", tpu_strategy.num_replicas_in_sync)
        return tpu_strategy
    except (ValueError, NotImplementedError):
        pass

    try:
        tpu_strategy = _strategy_from_local_resolver()
        print("Remote TPU detected – replicas:", tpu_strategy.num_replicas_in_sync)
        return tpu_strategy
    except Exception as err:
        # Deliberate catch-all: any TPU failure means "train without a TPU".
        print("TPU not available (", err, ") – falling back to CPU/GPU")
        return tf.distribute.MirroredStrategy()

strategy = detect_tpu()

# Mixed precision is not enabled in this notebook; only the (major, minor)
# TensorFlow version is recorded here.
TF_VERSION = tuple(map(int, tf.__version__.split(".")[:2]))

AUTO      = tf.data.AUTOTUNE
IMG_SIZE  = 256
# Batch size applied by `Dataset.batch` in `make_dataset`.
# NOTE(review): when a batched tf.data dataset is passed to Keras `fit` under a
# tf.distribute strategy, this is normally the *global* batch that gets split
# across replicas, not a per-replica size — confirm before scaling it.
BATCH     = 8
SEED      = 42
BUFFER    = 1024

# Seed every RNG in play. The TensorFlow seed matters most here because the
# augmentations use `tf.image.random_*` ops, which the NumPy / `random` seeds
# do not affect.
np.random.seed(SEED)
random.seed(SEED)
tf.random.set_seed(SEED)

E0000 00:00:1754519838.464572      10 common_lib.cc:612] Could not set metric server port: INVALID_ARGUMENT: Could not find SliceBuilder port 8471 in any of the 0 ports provided in `tpu_process_addresses`="local"
=== Source Location Trace: === 
learning/45eac/tfrc/runtime/common_lib.cc:230


# TFRecord Data Load

In [None]:
# Locate the competition data and list the TFRecord shards of both domains.
GCS_PATH = KaggleDatasets().get_gcs_path()
MONET_TFREC = tf.io.gfile.glob(f"{GCS_PATH}/monet_tfrec/*.tfrec")
PHOTO_TFREC = tf.io.gfile.glob(f"{GCS_PATH}/photo_tfrec/*.tfrec")
print(f"Monet TFRecords: {len(MONET_TFREC)}  |  Photo TFRecords: {len(PHOTO_TFREC)}")

# Feature spec handed to `tf.io.parse_single_example`: every record carries
# three scalar string features.
IMAGE_FEATURE_DESCRIPTION = {
    feature_key: tf.io.FixedLenFeature([], tf.string)
    for feature_key in ("image_name", "image", "target")
}

def _parse_example(proto):
    """Decode one serialized example into a float32 RGB image scaled to [-1, 1].

    Args:
        proto: Serialized example matching ``IMAGE_FEATURE_DESCRIPTION``.

    Returns:
        Tensor with static shape ``(IMG_SIZE, IMG_SIZE, 3)``; only the
        ``"image"`` feature is used.
    """
    features = tf.io.parse_single_example(proto, IMAGE_FEATURE_DESCRIPTION)
    pixels = tf.cast(tf.image.decode_jpeg(features["image"], channels=3), tf.float32)
    scaled = pixels / 127.5 - 1.0  # uint8 range [0, 255] -> [-1, 1]
    # JPEG decoding leaves height/width unknown; pin them for downstream layers.
    scaled.set_shape([IMG_SIZE, IMG_SIZE, 3])
    return scaled

# --- Augmentations --------------------------------------------

def _augment(img):
    """Apply random flip, colour jitter and resize-then-crop to one image.

    Args:
        img: Float image of shape ``(IMG_SIZE, IMG_SIZE, 3)`` with values in
            ``[-1, 1]`` (the range produced by ``_parse_example``).

    Returns:
        Augmented float image of shape ``(IMG_SIZE, IMG_SIZE, 3)`` whose
        values are again in ``[-1, 1]``.
    """
    img = tf.image.random_flip_left_right(img)

    # The hue adjustment goes through an RGB<->HSV conversion that is only well
    # defined for float images in [0, 1], so do the colour jitter in that range.
    unit = (img + 1.0) * 0.5
    # A brightness delta of 0.2 on the [-1, 1] scale equals 0.1 on [0, 1].
    unit = tf.image.random_brightness(unit, 0.1)
    unit = tf.image.random_contrast(unit, 0.8, 1.2)
    # Brightness/contrast can overshoot; clip before the HSV round trip.
    unit = tf.clip_by_value(unit, 0.0, 1.0)
    unit = tf.image.random_hue(unit, 0.05)
    img = unit * 2.0 - 1.0

    # Upscale slightly, then take a random crop back at the training resolution.
    img = tf.image.resize(img, [286, 286])
    img = tf.image.random_crop(img, [IMG_SIZE, IMG_SIZE, 3])
    return tf.clip_by_value(img, -1.0, 1.0)

# ----------------------------------------------------------------

def make_dataset(tfrec_files, augment=False):
    """Build a shuffled, batched and prefetched image dataset from TFRecords.

    Args:
        tfrec_files: TFRecord file path(s) to read.
        augment: When truthy, ``_augment`` is applied after parsing.

    Returns:
        ``tf.data.Dataset`` yielding batches of ``BATCH`` images; the last
        incomplete batch is dropped.
    """
    per_image_stages = [_parse_example]
    if augment:
        per_image_stages.append(_augment)

    records = tf.data.TFRecordDataset(tfrec_files, num_parallel_reads=AUTO)
    # Shuffle the raw serialized records, before any decoding work is done.
    pipeline = records.shuffle(BUFFER, seed=SEED, reshuffle_each_iteration=True)
    for stage in per_image_stages:
        pipeline = pipeline.map(stage, num_parallel_calls=AUTO)
    return pipeline.batch(BATCH, drop_remainder=True).prefetch(AUTO)

# Only the Monet side is augmented.
monet_ds  = make_dataset(MONET_TFREC, augment=True)
photo_ds  = make_dataset(PHOTO_TFREC, augment=False)

# Grab the first image of the first batch from each domain for a quick look.
sample_monet = next(iter(monet_ds.take(1)))[0]
sample_photo = next(iter(photo_ds.take(1)))[0]

plt.figure(figsize=(6,3))
for _slot, (_title, _sample) in enumerate(
    [("Photo", sample_photo), ("Monet", sample_monet)], start=1
):
    plt.subplot(1, 2, _slot)
    plt.imshow((_sample + 1) / 2)  # [-1, 1] -> [0, 1] for display
    plt.title(_title)
    plt.axis(False)
plt.show()

# Build the generator

In [None]:
# Number of channels produced by the generator's final layer (RGB).
OUTPUT_CHANNELS = 3

def downsample(filters, size, apply_instancenorm=True):
    """Build a stride-2 convolutional block that halves height and width.

    The block is ``Conv2D`` (no bias) -> optional instance normalization ->
    ``LeakyReLU``.

    Args:
        filters: Number of convolution output channels.
        size: Convolution kernel size.
        apply_instancenorm: When true, add the normalization layer.

    Returns:
        A ``keras.Sequential`` containing the block.
    """
    initializer = tf.random_normal_initializer(0., 0.02)
    # The normalization scale starts around 1 (not 0) so the block passes its
    # normalized activations through at initialization; this matches the
    # N(1.0, 0.02) gamma used by this file's `InstanceNorm` layer.
    gamma_init = keras.initializers.RandomNormal(mean=1.0, stddev=0.02)

    result = keras.Sequential()
    result.add(layers.Conv2D(filters, size, strides=2, padding='same',
                             kernel_initializer=initializer, use_bias=False))

    if apply_instancenorm:
        # groups=-1 puts every channel in its own group, i.e. instance norm.
        result.add(keras.layers.GroupNormalization(groups=-1, gamma_initializer=gamma_init))

    result.add(layers.LeakyReLU())

    return result

In [None]:
def upsample(filters, size, apply_dropout=False):
    """Build a stride-2 transposed-convolution block that doubles height and width.

    The block is ``Conv2DTranspose`` (no bias) -> instance normalization ->
    optional ``Dropout(0.5)`` -> ``ReLU``.

    Args:
        filters: Number of output channels of the transposed convolution.
        size: Kernel size.
        apply_dropout: When true, insert the dropout layer.

    Returns:
        A ``keras.Sequential`` containing the block.
    """
    initializer = tf.random_normal_initializer(0., 0.02)
    # The normalization scale starts around 1 (not 0) so the block passes its
    # normalized activations through at initialization; this matches the
    # N(1.0, 0.02) gamma used by this file's `InstanceNorm` layer.
    gamma_init = keras.initializers.RandomNormal(mean=1.0, stddev=0.02)

    result = keras.Sequential()
    result.add(layers.Conv2DTranspose(filters, size, strides=2,
                                      padding='same',
                                      kernel_initializer=initializer,
                                      use_bias=False))

    # groups=-1 puts every channel in its own group, i.e. instance norm.
    result.add(keras.layers.GroupNormalization(groups=-1, gamma_initializer=gamma_init))

    if apply_dropout:
        result.add(layers.Dropout(0.5))

    result.add(layers.ReLU())

    return result

In [None]:
def Generator():
    """Build the encoder-decoder generator with skip connections.

    Eight ``downsample`` blocks take a 256x256x3 image down to 1x1, seven
    ``upsample`` blocks bring it back to 128x128, and after each upsample the
    result is concatenated with the encoder output of matching resolution. A
    final tanh transposed convolution restores 256x256 with
    ``OUTPUT_CHANNELS`` channels.

    Returns:
        ``keras.Model`` mapping ``(bs, 256, 256, 3)`` to
        ``(bs, 256, 256, OUTPUT_CHANNELS)``.
    """
    image_in = layers.Input(shape=[256,256,3])

    # Encoder: the first block has no normalization; spatial size halves each
    # step, 256 -> 128 -> 64 -> 32 -> 16 -> 8 -> 4 -> 2 -> 1.
    encoder = [downsample(64, 4, apply_instancenorm=False)]
    encoder += [downsample(width, 4) for width in (128, 256, 512, 512, 512, 512, 512)]

    # Decoder: the three deepest blocks use dropout; spatial size doubles each
    # step, 1 -> 2 -> 4 -> 8 -> 16 -> 32 -> 64 -> 128.
    decoder = [upsample(512, 4, apply_dropout=True) for _ in range(3)]
    decoder += [upsample(width, 4) for width in (512, 256, 128, 64)]

    kernel_init = tf.random_normal_initializer(0., 0.02)
    to_image = layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,
                                      strides=2,
                                      padding='same',
                                      kernel_initializer=kernel_init,
                                      activation='tanh')  # (bs, 256, 256, 3)

    features = image_in

    # Run the encoder, remembering every intermediate output.
    encoder_outputs = []
    for block in encoder:
        features = block(features)
        encoder_outputs.append(features)

    # Pair decoder blocks with encoder outputs from deepest to shallowest,
    # leaving out the bottleneck itself (the last encoder output).
    for block, skip in zip(decoder, encoder_outputs[-2::-1]):
        features = block(features)
        features = layers.Concatenate()([features, skip])

    generated = to_image(features)

    return keras.Model(inputs=image_in, outputs=generated)

# Build the discriminator

In [None]:
def Discriminator():
    """Build the patch discriminator.

    Three ``downsample`` blocks reduce a 256x256x3 image to 32x32x256. Two
    zero-padded 4x4 stride-1 convolutions then produce a 30x30x1 map of
    unnormalized scores (no activation on the last layer), one per image
    patch.

    Returns:
        ``keras.Model`` mapping ``(bs, 256, 256, 3)`` to ``(bs, 30, 30, 1)``.
    """
    initializer = tf.random_normal_initializer(0., 0.02)
    # The normalization scale starts around 1 (not 0) so normalized
    # activations pass through at initialization; this matches the
    # N(1.0, 0.02) gamma used by this file's `InstanceNorm` layer.
    gamma_init = keras.initializers.RandomNormal(mean=1.0, stddev=0.02)

    inp = layers.Input(shape=[256, 256, 3], name='input_image')

    x = inp

    down1 = downsample(64, 4, False)(x) # (bs, 128, 128, 64)
    down2 = downsample(128, 4)(down1) # (bs, 64, 64, 128)
    down3 = downsample(256, 4)(down2) # (bs, 32, 32, 256)

    zero_pad1 = layers.ZeroPadding2D()(down3) # (bs, 34, 34, 256)
    conv = layers.Conv2D(512, 4, strides=1,
                         kernel_initializer=initializer,
                         use_bias=False)(zero_pad1) # (bs, 31, 31, 512)

    # groups=-1 puts every channel in its own group, i.e. instance norm.
    norm1 = keras.layers.GroupNormalization(groups=-1, gamma_initializer=gamma_init)(conv)

    leaky_relu = layers.LeakyReLU()(norm1)

    zero_pad2 = layers.ZeroPadding2D()(leaky_relu) # (bs, 33, 33, 512)

    last = layers.Conv2D(1, 4, strides=1,
                         kernel_initializer=initializer)(zero_pad2) # (bs, 30, 30, 1)

    return keras.Model(inputs=inp, outputs=last)

In [None]:
# Create all four networks inside the strategy scope so their variables are
# placed according to the active distribution strategy.
with strategy.scope():
    # Generators: photo -> Monet-esque painting, and Monet painting -> photo-like image.
    monet_generator, photo_generator = Generator(), Generator()

    # Discriminators: real vs. generated Monet paintings, and real vs. generated photos.
    monet_discriminator, photo_discriminator = Discriminator(), Discriminator()

In [None]:
# Sanity check: push one photo through the still-untrained Monet generator.
to_monet    = monet_generator(sample_photo[None, ...])  # add a batch axis

for _position, (_heading, _picture) in enumerate(
    [("Original Photo", sample_photo), ("Monet-esque Photo", to_monet[0])], start=1
):
    plt.subplot(1, 2, _position)
    plt.title(_heading)
    plt.imshow(_picture * 0.5 + 0.5)  # [-1, 1] -> [0, 1] for display
plt.show()

# Build the CycleGAN model

In [None]:
class CycleGan(keras.Model):
    """CycleGAN wrapper that trains two generators and two discriminators jointly.

    Args:
        monet_generator: Model translating photos into Monet-style images.
        photo_generator: Model translating Monet paintings into photo-like images.
        monet_discriminator: Model scoring images as real/generated Monet.
        photo_discriminator: Model scoring images as real/generated photo.
        lambda_cycle: Weight passed to the cycle-consistency and identity
            loss functions.
    """

    def __init__(
        self,
        monet_generator,
        photo_generator,
        monet_discriminator,
        photo_discriminator,
        lambda_cycle=10,
    ):
        super().__init__()
        self.m_gen = monet_generator
        self.p_gen = photo_generator
        self.m_disc = monet_discriminator
        self.p_disc = photo_discriminator
        self.lambda_cycle = lambda_cycle

    def compile(
        self,
        m_gen_optimizer,
        p_gen_optimizer,
        m_disc_optimizer,
        p_disc_optimizer,
        gen_loss_fn,
        disc_loss_fn,
        cycle_loss_fn,
        identity_loss_fn
    ):
        """Store one optimizer per network and the four loss callables.

        Expected call shapes, as used in ``train_step``:
            gen_loss_fn(disc_output_on_fake)
            disc_loss_fn(disc_output_on_real, disc_output_on_fake)
            cycle_loss_fn(real_image, cycled_image, lambda_cycle)
            identity_loss_fn(real_image, same_image, lambda_cycle)
        """
        super().compile()
        self.m_gen_optimizer = m_gen_optimizer
        self.p_gen_optimizer = p_gen_optimizer
        self.m_disc_optimizer = m_disc_optimizer
        self.p_disc_optimizer = p_disc_optimizer
        self.gen_loss_fn = gen_loss_fn
        self.disc_loss_fn = disc_loss_fn
        self.cycle_loss_fn = cycle_loss_fn
        self.identity_loss_fn = identity_loss_fn

    def call(self, inputs, training=False):
        """Run one full cycle and return the intermediate tensors.

        Args
        ----
        inputs : Tuple[tf.Tensor, tf.Tensor]
            • real_monet : batch of Monet images (B, 256, 256, 3)
            • real_photo : batch of Photo images (B, 256, 256, 3)
        training : bool
            Propagated to both generators.

        Returns
        -------
        Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]
            fake_monet, cycled_photo, fake_photo, cycled_monet
        """
        real_monet, real_photo = inputs

        # Photo → Monet → Photo
        fake_monet    = self.m_gen(real_photo, training=training)
        cycled_photo  = self.p_gen(fake_monet, training=training)

        # Monet → Photo → Monet
        fake_photo    = self.p_gen(real_monet, training=training)
        cycled_monet  = self.m_gen(fake_photo, training=training)

        return fake_monet, cycled_photo, fake_photo, cycled_monet

    def train_step(self, batch_data):
        """Run one optimization step on all four networks.

        Args:
            batch_data: Pair ``(real_monet, real_photo)`` of image batches.

        Returns:
            Dict with the scalar mean of each total generator loss and each
            discriminator loss for this step.
        """
        real_monet, real_photo = batch_data

        # Persistent because four separate gradients are taken from one tape.
        with tf.GradientTape(persistent=True) as tape:
            # photo to monet back to photo
            fake_monet = self.m_gen(real_photo, training=True)
            cycled_photo = self.p_gen(fake_monet, training=True)

            # monet to photo back to monet
            fake_photo = self.p_gen(real_monet, training=True)
            cycled_monet = self.m_gen(fake_photo, training=True)

            # each generator applied to an image already in its target domain
            same_monet = self.m_gen(real_monet, training=True)
            same_photo = self.p_gen(real_photo, training=True)

            # discriminator outputs on real images
            disc_real_monet = self.m_disc(real_monet, training=True)
            disc_real_photo = self.p_disc(real_photo, training=True)

            # discriminator outputs on generated images
            disc_fake_monet = self.m_disc(fake_monet, training=True)
            disc_fake_photo = self.p_disc(fake_photo, training=True)

            # adversarial generator losses
            monet_gen_loss = self.gen_loss_fn(disc_fake_monet)
            photo_gen_loss = self.gen_loss_fn(disc_fake_photo)

            # cycle-consistency loss over both directions, shared by both generators
            total_cycle_loss = (
                self.cycle_loss_fn(real_monet, cycled_monet, self.lambda_cycle)
                + self.cycle_loss_fn(real_photo, cycled_photo, self.lambda_cycle)
            )

            # total generator losses: adversarial + cycle + identity
            total_monet_gen_loss = (
                monet_gen_loss
                + total_cycle_loss
                + self.identity_loss_fn(real_monet, same_monet, self.lambda_cycle)
            )
            total_photo_gen_loss = (
                photo_gen_loss
                + total_cycle_loss
                + self.identity_loss_fn(real_photo, same_photo, self.lambda_cycle)
            )

            # discriminator losses
            monet_disc_loss = self.disc_loss_fn(disc_real_monet, disc_fake_monet)
            photo_disc_loss = self.disc_loss_fn(disc_real_photo, disc_fake_photo)

        # Compute every gradient before any variable is updated.
        monet_generator_gradients = tape.gradient(total_monet_gen_loss,
                                                  self.m_gen.trainable_variables)
        photo_generator_gradients = tape.gradient(total_photo_gen_loss,
                                                  self.p_gen.trainable_variables)

        monet_discriminator_gradients = tape.gradient(monet_disc_loss,
                                                      self.m_disc.trainable_variables)
        photo_discriminator_gradients = tape.gradient(photo_disc_loss,
                                                      self.p_disc.trainable_variables)

        # A persistent tape keeps its recorded tensors alive until released.
        del tape

        # Apply each set of gradients with its own optimizer.
        self.m_gen_optimizer.apply_gradients(zip(monet_generator_gradients,
                                                 self.m_gen.trainable_variables))

        self.p_gen_optimizer.apply_gradients(zip(photo_generator_gradients,
                                                 self.p_gen.trainable_variables))

        self.m_disc_optimizer.apply_gradients(zip(monet_discriminator_gradients,
                                                  self.m_disc.trainable_variables))

        self.p_disc_optimizer.apply_gradients(zip(photo_discriminator_gradients,
                                                  self.p_disc.trainable_variables))

        # The loss callables may return non-scalar tensors; report their mean so
        # the progress bar and History receive one number per metric. The
        # gradients above are still taken from the unreduced losses.
        return {
            "monet_gen_loss": tf.reduce_mean(total_monet_gen_loss),
            "photo_gen_loss": tf.reduce_mean(total_photo_gen_loss),
            "monet_disc_loss": tf.reduce_mean(monet_disc_loss),
            "photo_disc_loss": tf.reduce_mean(photo_disc_loss)
        }

# Define loss functions

In [None]:
with strategy.scope():
    def discriminator_loss(real, generated):
        """Return half the sum of the real-vs-ones and generated-vs-zeros BCE.

        Both inputs are treated as logits. The loss is created with
        ``Reduction.NONE``, so the result is not reduced to a scalar.
        """
        bce = tf.keras.losses.BinaryCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        loss_on_real = bce(tf.ones_like(real), real)
        loss_on_generated = bce(tf.zeros_like(generated), generated)
        return (loss_on_real + loss_on_generated) * 0.5

In [None]:
with strategy.scope():
    def generator_loss(generated):
        """Return the BCE between the discriminator's logits and an all-ones target.

        The loss is created with ``Reduction.NONE``, so the result is not
        reduced to a scalar.
        """
        bce = tf.keras.losses.BinaryCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        all_real = tf.ones_like(generated)
        return bce(all_real, generated)

In [None]:
with strategy.scope():
    def calc_cycle_loss(real_image, cycled_image, LAMBDA):
        """Return ``LAMBDA`` times the mean absolute error between the two images."""
        mean_abs_error = tf.reduce_mean(tf.abs(real_image - cycled_image))
        return LAMBDA * mean_abs_error

In [None]:
with strategy.scope():
    def identity_loss(real_image, same_image, LAMBDA):
        """Return ``LAMBDA * 0.5`` times the mean absolute error between the two images."""
        mean_abs_error = tf.reduce_mean(tf.abs(real_image - same_image))
        weight = LAMBDA * 0.5  # identity term gets half the cycle weight
        return weight * mean_abs_error

# Train the CycleGAN

In [None]:
with strategy.scope():
    def _make_adam():
        # All four networks use the same Adam settings: lr 2e-4, beta_1 0.5.
        return tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

    # One independent optimizer per network.
    monet_generator_optimizer = _make_adam()
    photo_generator_optimizer = _make_adam()

    monet_discriminator_optimizer = _make_adam()
    photo_discriminator_optimizer = _make_adam()

In [None]:
with strategy.scope():
    # Wrap the four networks; lambda_cycle keeps its default.
    cycle_gan_model = CycleGan(
        monet_generator=monet_generator,
        photo_generator=photo_generator,
        monet_discriminator=monet_discriminator,
        photo_discriminator=photo_discriminator,
    )

    # Attach one optimizer per network and the four loss functions.
    _compile_settings = dict(
        m_gen_optimizer=monet_generator_optimizer,
        p_gen_optimizer=photo_generator_optimizer,
        m_disc_optimizer=monet_discriminator_optimizer,
        p_disc_optimizer=photo_discriminator_optimizer,
        gen_loss_fn=generator_loss,
        disc_loss_fn=discriminator_loss,
        cycle_loss_fn=calc_cycle_loss,
        identity_loss_fn=identity_loss,
    )
    cycle_gan_model.compile(**_compile_settings)

In [None]:
# ────────────────────────────────────────────────────────────────
# Dataset for training — every element is the pair (monet_batch, photo_batch),
# which `CycleGan.train_step` and `CycleGan.call` unpack directly; no labels.
# --------------------------------------------------------------------
train_ds = tf.data.Dataset.zip((monet_ds, photo_ds)).prefetch(AUTO)

# Build model variables once using a dummy batch
_ = cycle_gan_model(next(iter(train_ds.take(1))), training=False)

# ────────────────────────────────────────────────────────────────
# Train
# --------------------------------------------------------------------
cycle_gan_model.fit(train_ds, epochs=25)

# Visualize our Monet-esque photos

In [None]:
# Show five photos (left column) next to their generated Monet versions (right).
_, ax = plt.subplots(5, 2, figsize=(12, 12))
for row, photo_batch in enumerate(photo_ds.take(5)):
    # Only the first image of each batch is displayed.
    monet_img = monet_generator(photo_batch, training=False)[0].numpy()
    monet_img = (monet_img * 127.5 + 127.5).astype(np.uint8)  # [-1, 1] -> [0, 255]
    photo_img = (photo_batch[0] * 127.5 + 127.5).numpy().astype(np.uint8)

    panels = [(photo_img, "Input Photo"), (monet_img, "Monet-esque")]
    for col, (picture, heading) in enumerate(panels):
        ax[row, col].imshow(picture)
        ax[row, col].set_title(heading)
        ax[row, col].axis("off")
plt.show()

# Create submission file

In [None]:
import pathlib
import shutil

import numpy as np
import PIL.Image

# Folder where submission images will live
out_dir = pathlib.Path("/kaggle/working/images")
out_dir.mkdir(parents=True, exist_ok=True)

# `photo_ds` is shuffled and built with drop_remainder=True, so iterating it
# would skip the photos of the final partial batch. Read the photo records
# through a dedicated pipeline that keeps every image, in file order.
submission_ds = (
    tf.data.TFRecordDataset(PHOTO_TFREC, num_parallel_reads=AUTO)
    .map(_parse_example, num_parallel_calls=AUTO)
    .batch(BATCH)
    .prefetch(AUTO)
)

idx = 0
for batch in submission_ds:                  # batch.shape == (<=BATCH, 256, 256, 3)
    fake_batch = monet_generator(batch, training=False).numpy()
    for img in fake_batch:                   # iterate over *all* images
        # [-1, 1] -> [0, 255]; clip so rounding noise cannot wrap in the uint8 cast
        arr = np.clip(img * 127.5 + 127.5, 0, 255).astype(np.uint8)
        PIL.Image.fromarray(arr).save(out_dir / f"{idx}.jpg")
        idx += 1

print(f"✅ wrote {idx} images")              # one per record in PHOTO_TFREC

# Create images.zip in the working directory for Kaggle submission
shutil.make_archive("/kaggle/working/images", "zip", out_dir)