The first two cells are currently required to run on a Kaggle TPU v5e-8, because the preinstalled TensorFlow version does not support TPUs. Once Kaggle ships a TPU-enabled TensorFlow, these two cells can be skipped.

In [None]:
# Put ~/.local/bin on PATH (presumably where `uv` lives — confirm) and remove the
# preinstalled JAX package from the system environment before installing the
# TPU build of TensorFlow in the next cell.
!export PATH="${HOME}/.local/bin:${PATH}" && uv pip uninstall --system jax

In [None]:
# Install tensorflow-tpu 2.18.0 into the system environment, resolving the
# libtpu wheels from Google's libtpu-tf-releases index.
!export PATH="${HOME}/.local/bin:${PATH}" && uv pip install --system tensorflow-tpu=="2.18.0" --find-links https://storage.googleapis.com/libtpu-tf-releases/index.html

### Code begins

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import mixed_precision
from kaggle_datasets import KaggleDatasets
import matplotlib.pyplot as plt
import numpy as np
import os
from tqdm import tqdm

In [None]:
# TPU Setup: use the local TPU when one is available, otherwise fall back to
# the default (single-device) distribution strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='local')
    print('Device:', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.TPUStrategy(tpu)
except Exception as err:
    # A bare `except:` would also swallow KeyboardInterrupt/SystemExit and hide
    # *why* the TPU could not be used; report the reason, then fall back.
    print(f'TPU not available ({err!r}); falling back to the default strategy.')
    strategy = tf.distribute.get_strategy()

print('Number of replicas:', strategy.num_replicas_in_sync)

# `tf.data.AUTOTUNE` is the stable alias of the deprecated experimental constant.
AUTOTUNE = tf.data.AUTOTUNE
print(tf.__version__)

# bfloat16 compute under TPUStrategy, float16 compute everywhere else.
with strategy.scope():
    policy_type = 'mixed_bfloat16' if isinstance(strategy, tf.distribute.TPUStrategy) else 'mixed_float16'
    mixed_precision.set_global_policy(policy_type)
    print(f"Mixed Precision Policy: {policy_type}")

In [None]:
# Hyperparameters
IMG_SIZE = 512                                  # images are resized to IMG_SIZE x IMG_SIZE
BATCH_SIZE = 8 * strategy.num_replicas_in_sync  # global batch: 8 images per replica
EPOCHS = 50
LAMBDA_GAN = 1.0   # weight of the adversarial (hinge) generator term
LAMBDA_FM = 10.0   # weight of the discriminator feature-matching term
LAMBDA_VGG = 2     # weight of the VGG perceptual term
GEN_LR = 0.00005
DISC_LR = 0.0001

# Load dataset paths
GCS_PATH = KaggleDatasets().get_gcs_path()


def _film_set_pngs(split, folder):
    """Return the PNG paths found under FilmSet/<split>/<folder> of the dataset."""
    return tf.io.gfile.glob(f'{GCS_PATH}/FilmSet/{split}/{folder}/*.png')


cinematic_train = _film_set_pngs('train', 'ClassNeg')
input_train = _film_set_pngs('train', 'input')
cinematic_test = _film_set_pngs('test', 'ClassNeg')
input_test = _film_set_pngs('test', 'input')

for _label, _paths in (
    ('Training - Cinematic images', cinematic_train),
    ('Training - Input images', input_train),
    ('Test - Cinematic images', cinematic_test),
    ('Test - Input images', input_test),
):
    print(f'{_label}: {len(_paths)}')

In [None]:
# Check that all image pairs exist
def verify_pairs(input_paths, target_paths):
    """Pair input and target files that share the same basename.

    Prints a warning with the count of basenames found on only one side.
    Returns ``(paired_inputs, paired_targets)``: two lists aligned
    index-by-index and ordered by basename.
    """
    input_by_name = {os.path.basename(p): p for p in input_paths}
    target_by_name = {os.path.basename(p): p for p in target_paths}
    input_keys, target_keys = input_by_name.keys(), target_by_name.keys()

    only_in_input = input_keys - target_keys
    only_in_target = target_keys - input_keys
    if only_in_input:
        print(f"Warning: {len(only_in_input)} files missing in target folder")
    if only_in_target:
        print(f"Warning: {len(only_in_target)} files missing in input folder")

    shared = sorted(input_keys & target_keys)
    return [input_by_name[n] for n in shared], [target_by_name[n] for n in shared]

# Keep only files whose basename exists in both folders (aligned and sorted).
input_train, cinematic_train = verify_pairs(input_train, cinematic_train)
input_test, cinematic_test = verify_pairs(input_test, cinematic_test)
for _split_name, _paired in (('training', input_train), ('test', input_test)):
    print(f'Verified paired {_split_name} samples: {len(_paired)}')

In [None]:
# Data loading and preprocessing
def load_image(image_path):
    """Decode a PNG into a float32 RGB tensor of size IMG_SIZE x IMG_SIZE in [-1, 1]."""
    raw_bytes = tf.io.read_file(image_path)
    pixels = tf.cast(tf.image.decode_png(raw_bytes, channels=3), tf.float32)
    pixels = tf.image.resize(pixels, [IMG_SIZE, IMG_SIZE])
    # Map [0, 255] -> [-1, 1] to match the generator's tanh output range.
    return (pixels / 127.5) - 1.0

def load_paired_images(input_path, target_path):
    """Load an (input, target) pair of image paths as two preprocessed tensors."""
    return load_image(input_path), load_image(target_path)

# Paired augmentation: the same random flip is applied to input and target.
def augment(input_img, target_img):
    """Randomly mirror an (input, target) pair left-right as a unit.

    One seed is drawn per call and passed to both stateless flip ops, so the
    two images always get the same flip decision and stay pixel-aligned.
    The seed itself comes from ``tf.random.uniform``, which is not seeded
    here, so flips vary between epochs and between runs. This is NOT a
    deterministic augmentation.
    """
    seed = tf.random.uniform(shape=[2], minval=0, maxval=2**31 - 1, dtype=tf.int32)
    input_img = tf.image.stateless_random_flip_left_right(input_img, seed)
    target_img = tf.image.stateless_random_flip_left_right(target_img, seed)
    return input_img, target_img

# Training pipeline: decode once, cache, THEN shuffle.
# `Dataset.cache()` replays elements in exactly the order it first produced
# them, so a shuffle placed *before* the cache would freeze the epoch-1 order
# for every later epoch. Shuffling after the cache reshuffles each epoch.
# NOTE(review): the shuffle buffer now holds decoded image pairs rather than
# file paths; lower the buffer size if host memory becomes a concern.
train_dataset = tf.data.Dataset.from_tensor_slices((input_train, cinematic_train))
train_dataset = train_dataset.map(load_paired_images, num_parallel_calls=AUTOTUNE)
train_dataset = train_dataset.cache()  # Cache decoded images in RAM
train_dataset = train_dataset.shuffle(1000, reshuffle_each_iteration=True)
train_dataset = train_dataset.map(augment, num_parallel_calls=AUTOTUNE)  # fresh flips each epoch
train_dataset = train_dataset.batch(BATCH_SIZE, drop_remainder=True)  # static shapes for TPU
train_dataset = train_dataset.prefetch(AUTOTUNE)

# Test pipeline: fixed order, one image pair per batch.
test_dataset = tf.data.Dataset.from_tensor_slices((input_test, cinematic_test))
test_dataset = test_dataset.map(load_paired_images, num_parallel_calls=AUTOTUNE)
test_dataset = test_dataset.cache()
test_dataset = test_dataset.batch(1).prefetch(AUTOTUNE)

# Calculate steps per epoch (matches drop_remainder=True above)
steps_per_epoch = len(input_train) // BATCH_SIZE

# Custom Layers
class SEBlock(layers.Layer):
    """Squeeze-and-excitation channel gate.

    Global-average-pools the input, passes the pooled vector through a
    bottleneck Dense (ReLU) and an expanding Dense (sigmoid), and multiplies
    each input channel by the resulting weight.

    Args:
        filters: channel count of the gate; must equal the input's last
            dimension for the broadcast multiply in ``call``.
        ratio: bottleneck reduction factor. The hidden Dense has
            ``max(1, filters // ratio)`` units.
    """

    def __init__(self, filters, ratio=16, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.ratio = ratio

    def build(self, input_shape):
        self.global_pool = layers.GlobalAveragePooling2D()
        # Guard against a zero-unit Dense layer when filters < ratio.
        self.dense1 = layers.Dense(max(1, self.filters // self.ratio), activation='relu')
        self.dense2 = layers.Dense(self.filters, activation='sigmoid')
        self.reshape = layers.Reshape((1, 1, self.filters))

    def call(self, inputs):
        se = self.global_pool(inputs)
        se = self.dense1(se)
        se = self.dense2(se)
        se = self.reshape(se)  # (batch, 1, 1, filters) broadcasts over H and W
        return inputs * se

    def get_config(self):
        """Return constructor arguments so the layer can be rebuilt from a saved config."""
        config = super().get_config()
        config.update({'filters': self.filters, 'ratio': self.ratio})
        return config

class ResidualBlock(layers.Layer):
    """Residual block: conv3x3 -> GroupNorm -> ReLU -> conv3x3 -> GroupNorm, plus identity.

    ``GroupNormalization(groups=-1)`` normalizes each channel separately
    (instance-norm style).

    Args:
        filters: output channels of both convolutions. Must equal the input's
            channel count so the identity shortcut can be added.
    """

    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters

    def build(self, input_shape):
        self.conv1 = layers.Conv2D(self.filters, 3, padding='same')
        self.norm1 = tf.keras.layers.GroupNormalization(groups=-1)
        self.conv2 = layers.Conv2D(self.filters, 3, padding='same')
        self.norm2 = tf.keras.layers.GroupNormalization(groups=-1)

    def call(self, inputs):
        x = self.conv1(inputs)
        x = self.norm1(x)
        x = tf.nn.relu(x)
        x = self.conv2(x)
        x = self.norm2(x)
        return inputs + x

    def get_config(self):
        """Return constructor arguments so the layer can be rebuilt from a saved config."""
        config = super().get_config()
        config.update({'filters': self.filters})
        return config

def pixel_shuffle(x, scale=2):
    """Sub-pixel upsampling: move channel blocks into space via depth_to_space.

    With the default NHWC layout, (N, H, W, C*scale*scale) -> (N, H*scale, W*scale, C).
    """
    upscaled = tf.nn.depth_to_space(x, block_size=scale)
    return upscaled

# Generator: encoder / residual bottleneck / pixel-shuffle decoder
def build_generator():
    """Build the encoder-decoder generator with SE-gated skip connections.

    Encoder: a 7x7 conv (64 ch) followed by three stride-2 spectral-norm convs
    (128, 256, 512 ch; 1/8 resolution), then 9 residual blocks at 512 ch.
    Decoder: three pixel-shuffle upsampling stages; each is concatenated with
    the SE-gated encoder feature of the same resolution and fused by a 3x3 conv.
    Output: a 3-channel tanh image of size IMG_SIZE x IMG_SIZE.
    """
    def norm_relu(t):
        t = tf.keras.layers.GroupNormalization(groups=-1)(t)
        return layers.ReLU()(t)

    def sn_downsample(t, filters):
        conv = layers.Conv2D(filters, 3, strides=2, padding='same')
        return norm_relu(tf.keras.layers.SpectralNormalization(conv)(t))

    def upsample(t, filters):
        # Conv to 4*filters channels; depth_to_space(2) then yields `filters`
        # channels at twice the spatial resolution.
        t = layers.Conv2D(filters * 4, 3, padding='same')(t)
        t = layers.Lambda(lambda z: pixel_shuffle(z, 2))(t)
        return norm_relu(t)

    def fuse_skip(t, skip, filters):
        gated_skip = SEBlock(filters)(skip)
        t = layers.Concatenate()([t, gated_skip])
        t = layers.Conv2D(filters, 3, padding='same')(t)
        return norm_relu(t)

    inputs = layers.Input(shape=(IMG_SIZE, IMG_SIZE, 3))

    # Encoder
    e1 = norm_relu(layers.Conv2D(64, 7, padding='same')(inputs))
    e2 = sn_downsample(e1, 128)
    e3 = sn_downsample(e2, 256)
    x = sn_downsample(e3, 512)

    # Residual bottleneck
    for _ in range(9):
        x = ResidualBlock(512)(x)

    # Decoder with attention skip connections (deepest skip first)
    for skip, filters in ((e3, 256), (e2, 128), (e1, 64)):
        x = fuse_skip(upsample(x, filters), skip, filters)

    # Output
    outputs = layers.Conv2D(3, 7, padding='same', activation='tanh')(x)
    return keras.Model(inputs, outputs, name='generator')

# Discriminator: Multi-Scale Spectral PatchGAN
def build_discriminator(name='discriminator'):
    """Build a spectrally-normalized PatchGAN discriminator.

    Input: an (input, candidate) pair concatenated on channels (6 channels),
    with any spatial size. Output: ``[patch_logits, f1, f2, f3]``, where
    f1..f3 are the LeakyReLU activations of conv stages 2-4 (used for the
    feature-matching loss).
    """
    def sn_conv(filters, strides):
        return tf.keras.layers.SpectralNormalization(
            layers.Conv2D(filters, 4, strides=strides, padding='same'))

    inputs = layers.Input(shape=(None, None, 6))

    # (filters, strides, apply GroupNorm) for each feature stage.
    stages = ((64, 2, False), (128, 2, True), (256, 2, True), (512, 1, True))
    activations = []
    x = inputs
    for filters, strides, use_norm in stages:
        x = sn_conv(filters, strides)(x)
        if use_norm:
            x = tf.keras.layers.GroupNormalization(groups=-1)(x)
        x = layers.LeakyReLU(0.2)(x)
        activations.append(x)

    patch_logits = sn_conv(1, 1)(x)
    # The first stage's activation is not exposed for feature matching.
    return keras.Model(inputs, [patch_logits, *activations[1:]], name=name)

# VGG for Perceptual Loss with float32 output
def build_vgg():
    """Return a frozen ImageNet VGG19 feature extractor ending at ``block4_conv2``.

    The trailing linear Activation with ``dtype='float32'`` makes the features
    come out in float32 regardless of the global mixed-precision policy.
    """
    backbone = keras.applications.VGG19(include_top=False, weights='imagenet')
    backbone.trainable = False
    features = backbone.get_layer('block4_conv2').output
    features_fp32 = layers.Activation('linear', dtype='float32')(features)
    return keras.Model(backbone.input, features_fp32, name='vgg_features')

# Loss Functions
def discriminator_loss(real_output, fake_output):
    """Hinge loss: mean(max(0, 1 - D(real))) + mean(max(0, 1 + D(fake))), in float32."""
    real_logits = tf.cast(real_output, tf.float32)
    fake_logits = tf.cast(fake_output, tf.float32)
    real_term = tf.reduce_mean(tf.maximum(0.0, 1.0 - real_logits))
    fake_term = tf.reduce_mean(tf.maximum(0.0, 1.0 + fake_logits))
    return real_term + fake_term

def generator_loss(fake_output):
    """Hinge-GAN generator loss: the negated mean discriminator score on fakes (float32)."""
    mean_score = tf.reduce_mean(tf.cast(fake_output, tf.float32))
    return tf.negative(mean_score)

def feature_matching_loss(real_features, fake_features):
    """Sum over paired feature maps of their mean absolute difference (float32).

    Returns the integer 0 when no feature pairs are given.
    """
    return sum(
        (tf.reduce_mean(tf.abs(tf.cast(real_f, tf.float32) - tf.cast(fake_f, tf.float32)))
         for real_f, fake_f in zip(real_features, fake_features)),
        0,
    )


def perceptual_loss(vgg_model, real_img, fake_img):
    """Mean L1 distance between VGG features of real and generated images.

    Both batches go through ``vgg_model`` in a single forward pass. Images in
    [-1, 1] are mapped to [0, 255] and then given VGG19 preprocessing.
    """
    num_real = tf.shape(real_img)[0]
    stacked = tf.concat([tf.cast(real_img, tf.float32), tf.cast(fake_img, tf.float32)], axis=0)

    # [-1, 1] -> [0, 255], then RGB->BGR and ImageNet mean subtraction.
    stacked = tf.keras.applications.vgg19.preprocess_input((stacked + 1) * 127.5)

    feats = vgg_model(stacked)
    return tf.reduce_mean(tf.abs(feats[:num_real] - feats[num_real:]))

@tf.function
def train_step(input_image, target_image, generator, disc1, disc2, vgg,
               gen_optimizer, disc_optimizer):
    """Run one per-replica generator and discriminator update.

    The generator is trained with hinge GAN + feature matching + VGG
    perceptual losses against a full-resolution (``disc1``) and a
    half-resolution (``disc2``) discriminator. Both discriminators share
    ``disc_optimizer``.

    Returns:
        ``(total_gen_loss, total_disc_loss, gen_gan_loss, gen_fm_loss,
        gen_vgg_loss)`` for this replica's batch (unscaled, for logging).
    """
    disc_variables = disc1.trainable_variables + disc2.trainable_variables

    # Each loss below is a mean over this replica's local batch, and the
    # optimizer SUMs gradients across replicas. Dividing by the replica count
    # makes the applied update the mean over the global batch.
    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync

    # A Keras 3 LossScaleOptimizer divides gradients by its dynamic loss scale
    # inside apply_gradients, so the loss must be multiplied by that scale
    # (scale_loss) before differentiating. Build the optimizer first so that
    # scale_loss reads the live scale variable rather than tracing a constant.
    for optimizer, variables in ((gen_optimizer, generator.trainable_variables),
                                 (disc_optimizer, disc_variables)):
        if hasattr(optimizer, 'scale_loss') and not optimizer.built:
            optimizer.build(variables)

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        # Generate fake images
        fake_image = generator(input_image, training=True)
        fake_image = tf.cast(fake_image, tf.float32)
        # Discriminator 1 (full resolution)
        real_pair1 = tf.concat([input_image, target_image], axis=-1)
        fake_pair1 = tf.concat([input_image, fake_image], axis=-1)
        disc1_real, *disc1_real_features = disc1(real_pair1, training=True)
        disc1_fake, *disc1_fake_features = disc1(fake_pair1, training=True)

        # Discriminator 2 (half resolution) - Batch resize
        combined_images = tf.concat([input_image, target_image, fake_image], axis=0)
        combined_half = tf.image.resize(combined_images, [IMG_SIZE // 2, IMG_SIZE // 2])
        input_half, target_half, fake_half = tf.split(combined_half, 3, axis=0)

        real_pair2 = tf.concat([input_half, target_half], axis=-1)
        fake_pair2 = tf.concat([input_half, fake_half], axis=-1)
        disc2_real, *disc2_real_features = disc2(real_pair2, training=True)
        disc2_fake, *disc2_fake_features = disc2(fake_pair2, training=True)

        # Discriminator losses
        disc1_loss = discriminator_loss(disc1_real, disc1_fake)
        disc2_loss = discriminator_loss(disc2_real, disc2_fake)
        total_disc_loss = disc1_loss + disc2_loss

        # Generator losses
        gen_gan_loss = (generator_loss(disc1_fake) + generator_loss(disc2_fake)) / 2
        gen_fm_loss = (feature_matching_loss(disc1_real_features, disc1_fake_features) +
                       feature_matching_loss(disc2_real_features, disc2_fake_features)) / 2

        # VGG perceptual loss with batched forward pass
        gen_vgg_loss = perceptual_loss(vgg, target_image, fake_image)

        total_gen_loss = (LAMBDA_GAN * gen_gan_loss +
                          LAMBDA_FM * gen_fm_loss +
                          LAMBDA_VGG * gen_vgg_loss)

        # Objectives actually differentiated. They must be computed while the
        # tapes are still recording.
        gen_objective = total_gen_loss / num_replicas
        disc_objective = total_disc_loss / num_replicas
        if hasattr(gen_optimizer, 'scale_loss'):
            gen_objective = gen_optimizer.scale_loss(gen_objective)
        if hasattr(disc_optimizer, 'scale_loss'):
            disc_objective = disc_optimizer.scale_loss(disc_objective)

    gen_gradients = gen_tape.gradient(gen_objective, generator.trainable_variables)
    disc_gradients = disc_tape.gradient(disc_objective, disc_variables)

    gen_optimizer.apply_gradients(zip(gen_gradients, generator.trainable_variables))
    disc_optimizer.apply_gradients(zip(disc_gradients, disc_variables))

    return total_gen_loss, total_disc_loss, gen_gan_loss, gen_fm_loss, gen_vgg_loss

# Distributed training step
@tf.function
def distributed_train_step(input_image, target_image):
    """Run ``train_step`` on every replica and average its five losses.

    Each returned value is the cross-replica SUM divided by the replica count,
    in the order (total_gen, total_disc, gan, fm, vgg).
    """
    def replica_step(input_img, target_img):
        return train_step(input_img, target_img, generator, discriminator1,
                          discriminator2, vgg_model, gen_optimizer, disc_optimizer)

    per_replica = strategy.run(replica_step, args=(input_image, target_image))
    replica_count = tf.cast(strategy.num_replicas_in_sync, tf.float32)

    def average(value):
        return strategy.reduce(tf.distribute.ReduceOp.SUM, value, axis=None) / replica_count

    return tuple(average(per_replica[i]) for i in range(5))

# Visualization function
def generate_and_save_images(generator, test_input, target, epoch, num_examples=3,
                             filename_prefix='generated_epoch'):
    """Save an Input | Generated | Target grid to ``<filename_prefix>_<epoch>.png``.

    Args:
        generator: model mapping input images to generated images.
        test_input: batch of input images in [-1, 1].
        target: batch of target images in [-1, 1], aligned with ``test_input``.
        epoch: epoch number used in the title and the file name.
        num_examples: rows to plot; capped at the available batch size.
        filename_prefix: prefix of the saved PNG file name.
    """
    predictions = generator(test_input, training=False)
    predictions = predictions.numpy().astype(np.float32)
    test_input = test_input.numpy().astype(np.float32)
    target = target.numpy().astype(np.float32)

    # Never index past the batch, even if more rows are requested.
    num_examples = min(num_examples, len(predictions), len(test_input), len(target))
    if num_examples <= 0:
        return

    fig = plt.figure(figsize=(15, 5 * num_examples))
    columns = ((test_input, 'Input'), (predictions, 'Generated'), (target, 'Target (Cinematic)'))

    for i in range(num_examples):
        for col, (images, title) in enumerate(columns):
            plt.subplot(num_examples, 3, i * 3 + col + 1)
            # [-1, 1] -> [0, 1]; clip explicitly so slight overshoot does not
            # trigger imshow's "Clipping input data" warning.
            plt.imshow(np.clip(images[i] * 0.5 + 0.5, 0.0, 1.0))
            plt.title(title)
            plt.axis('off')

    plt.suptitle(f'Epoch {epoch}')
    plt.tight_layout()
    plt.savefig(f'{filename_prefix}_{epoch}.png')
    plt.close(fig)

with strategy.scope():
    generator = build_generator()
    discriminator1 = build_discriminator(name='discriminator_full')
    discriminator2 = build_discriminator(name='discriminator_half')
    vgg_model = build_vgg()

    # Adam with beta_1=0.0 / beta_2=0.9 for both networks.
    gen_optimizer = keras.optimizers.Adam(learning_rate=GEN_LR, beta_1=0.0, beta_2=0.9)
    disc_optimizer = keras.optimizers.Adam(learning_rate=DISC_LR, beta_1=0.0, beta_2=0.9)

    # Loss scaling only protects float16 gradients from underflow. The
    # 'mixed_bfloat16' policy used on TPU does not need it, and a Keras 3
    # LossScaleOptimizer divides every gradient by its (large) loss scale in
    # apply_gradients, which distorts updates unless the loss was scaled first.
    if mixed_precision.global_policy().compute_dtype == 'float16':
        gen_optimizer = mixed_precision.LossScaleOptimizer(gen_optimizer)
        disc_optimizer = mixed_precision.LossScaleOptimizer(disc_optimizer)

    # Metrics
    gen_loss_metric = keras.metrics.Mean(name='gen_loss', dtype=tf.float32)
    disc_loss_metric = keras.metrics.Mean(name='disc_loss', dtype=tf.float32)
    gan_loss_metric = keras.metrics.Mean(name='gan_loss', dtype=tf.float32)
    fm_loss_metric = keras.metrics.Mean(name='fm_loss', dtype=tf.float32)
    vgg_loss_metric = keras.metrics.Mean(name='vgg_loss', dtype=tf.float32)

    print("Generator parameters:", generator.count_params())
    print("Discriminator 1 parameters:", discriminator1.count_params())
    print("Discriminator 2 parameters:", discriminator2.count_params())

In [None]:
# Prepare test samples for visualization: stack the first three single-image
# test batches so the same examples are plotted at every checkpoint.
test_samples = list(test_dataset.take(3))
test_inputs = tf.concat([inp for inp, _tgt in test_samples], axis=0)
test_targets = tf.concat([tgt for _inp, tgt in test_samples], axis=0)

### Training Begins

In [None]:
print("\nStarting Optimized Training...")
train_dataset_dist = strategy.experimental_distribute_dataset(train_dataset)

# Running means, in the order distributed_train_step returns the losses.
epoch_metrics = (gen_loss_metric, disc_loss_metric, gan_loss_metric,
                 fm_loss_metric, vgg_loss_metric)
postfix_keys = ('G_Loss', 'D_Loss', 'GAN', 'FM', 'VGG')
summary_labels = ('  Avg Gen Loss:  ', '  Avg Disc Loss: ', '  Avg GAN Loss:  ',
                  '  Avg FM Loss:   ', '  Avg VGG Loss:  ')
separator = '=' * 60

for epoch in range(EPOCHS):
    print(f'\n{separator}')
    print(f'Epoch {epoch + 1}/{EPOCHS}')
    print(separator)

    # Per-epoch averages: start every epoch from a clean slate.
    for metric in epoch_metrics:
        metric.reset_state()

    pbar = tqdm(train_dataset_dist, total=steps_per_epoch, desc=f'Epoch {epoch + 1}')
    for input_img, target_img in pbar:
        step_losses = distributed_train_step(input_img, target_img)
        for metric, value in zip(epoch_metrics, step_losses):
            metric.update_state(value)
        pbar.set_postfix({key: f'{metric.result():.4f}'
                          for key, metric in zip(postfix_keys, epoch_metrics)})

    print(f'\nEpoch {epoch + 1} Summary:')
    for label, metric in zip(summary_labels, epoch_metrics):
        print(f'{label}{metric.result():.4f}')

    # Sample images every 10 epochs.
    if (epoch + 1) % 10 == 0:
        generate_and_save_images(generator, test_inputs, test_targets, epoch + 1)
        print(f'Saved sample images for epoch {epoch + 1}')

In [None]:
print("Saving models for fine-tuning...")
model_dir = '/kaggle/working/cinematic_gan_model'
os.makedirs(model_dir, exist_ok=True)

generator_path, discriminator1_path, discriminator2_path = (
    os.path.join(model_dir, f'{stem}.keras')
    for stem in ('generator', 'discriminator1', 'discriminator2')
)

# Save each network as a .keras archive and log where it went.
for _path, _model, _label in (
    (generator_path, generator, 'Generator'),
    (discriminator1_path, discriminator1, 'Discriminator 1'),
    (discriminator2_path, discriminator2, 'Discriminator 2'),
):
    _model.save(_path)
    print(f'{_label} saved to: {_path}')

### Fine-tuning to remove artifacts (if any)

In [None]:
def total_variation_loss(image):
    """Anisotropic L1 total variation, computed in float32.

    Returns the mean |horizontal neighbor difference| plus the mean
    |vertical neighbor difference|, indexing the tensor as (batch, H, W, C).
    """
    pixels = tf.cast(image, tf.float32)
    horizontal = pixels[:, :, 1:, :] - pixels[:, :, :-1, :]
    vertical = pixels[:, 1:, :, :] - pixels[:, :-1, :, :]
    return tf.reduce_mean(tf.abs(horizontal)) + tf.reduce_mean(tf.abs(vertical))

#### Updated hyperparameters for fine-tuning

In [None]:
# Fine-tuning settings; these rebind the training-phase globals of the same name.
LAMBDA_GAN = 1.0
LAMBDA_FM = 10.0
LAMBDA_VGG = 1.0        # perceptual weight, lower than in the training phase
LAMBDA_TV = 20.0        # total-variation weight, used only by the fine-tune step
GEN_LR = 2.5e-5         # generator learning rate, lower than in the training phase
DISC_LR = 1e-4
FINE_TUNE_EPOCHS = 15

In [None]:

@tf.function
def train_step(input_image, target_image, generator, disc1, disc2, vgg,
               gen_optimizer, disc_optimizer):
    """Run one per-replica fine-tuning update (adds a total-variation term).

    Same as the training-phase step, plus ``LAMBDA_TV * total_variation_loss``
    on the generated image in the generator objective.

    Returns:
        ``(total_gen_loss, total_disc_loss, gen_gan_loss, gen_fm_loss,
        gen_vgg_loss, gen_tv_loss)`` for this replica's batch (unscaled).
    """
    disc_variables = disc1.trainable_variables + disc2.trainable_variables

    # Each loss below is a mean over this replica's local batch, and the
    # optimizer SUMs gradients across replicas. Dividing by the replica count
    # makes the applied update the mean over the global batch.
    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync

    # A Keras 3 LossScaleOptimizer divides gradients by its dynamic loss scale
    # inside apply_gradients, so the loss must be multiplied by that scale
    # (scale_loss) before differentiating. Build the optimizer first so that
    # scale_loss reads the live scale variable rather than tracing a constant.
    for optimizer, variables in ((gen_optimizer, generator.trainable_variables),
                                 (disc_optimizer, disc_variables)):
        if hasattr(optimizer, 'scale_loss') and not optimizer.built:
            optimizer.build(variables)

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        # Generate fake images
        fake_image = generator(input_image, training=True)
        fake_image = tf.cast(fake_image, tf.float32)

        # Discriminator 1 (full resolution)
        real_pair1 = tf.concat([input_image, target_image], axis=-1)
        fake_pair1 = tf.concat([input_image, fake_image], axis=-1)
        disc1_real, *disc1_real_features = disc1(real_pair1, training=True)
        disc1_fake, *disc1_fake_features = disc1(fake_pair1, training=True)

        # Discriminator 2 (half resolution) - Batch resize
        combined_images = tf.concat([input_image, target_image, fake_image], axis=0)
        combined_half = tf.image.resize(combined_images, [IMG_SIZE // 2, IMG_SIZE // 2])
        input_half, target_half, fake_half = tf.split(combined_half, 3, axis=0)

        real_pair2 = tf.concat([input_half, target_half], axis=-1)
        fake_pair2 = tf.concat([input_half, fake_half], axis=-1)
        disc2_real, *disc2_real_features = disc2(real_pair2, training=True)
        disc2_fake, *disc2_fake_features = disc2(fake_pair2, training=True)

        # Discriminator losses
        disc1_loss = discriminator_loss(disc1_real, disc1_fake)
        disc2_loss = discriminator_loss(disc2_real, disc2_fake)
        total_disc_loss = disc1_loss + disc2_loss

        # Generator losses
        gen_gan_loss = (generator_loss(disc1_fake) + generator_loss(disc2_fake)) / 2
        gen_fm_loss = (feature_matching_loss(disc1_real_features, disc1_fake_features) +
                       feature_matching_loss(disc2_real_features, disc2_fake_features)) / 2

        # VGG perceptual loss + total-variation smoothing of the output
        gen_vgg_loss = perceptual_loss(vgg, target_image, fake_image)
        gen_tv_loss = total_variation_loss(fake_image)
        total_gen_loss = (LAMBDA_GAN * gen_gan_loss +
                          LAMBDA_FM * gen_fm_loss +
                          LAMBDA_VGG * gen_vgg_loss +
                          LAMBDA_TV * gen_tv_loss)

        # Objectives actually differentiated. They must be computed while the
        # tapes are still recording.
        gen_objective = total_gen_loss / num_replicas
        disc_objective = total_disc_loss / num_replicas
        if hasattr(gen_optimizer, 'scale_loss'):
            gen_objective = gen_optimizer.scale_loss(gen_objective)
        if hasattr(disc_optimizer, 'scale_loss'):
            disc_objective = disc_optimizer.scale_loss(disc_objective)

    # Apply gradients
    gen_gradients = gen_tape.gradient(gen_objective, generator.trainable_variables)
    disc_gradients = disc_tape.gradient(disc_objective, disc_variables)

    gen_optimizer.apply_gradients(zip(gen_gradients, generator.trainable_variables))
    disc_optimizer.apply_gradients(zip(disc_gradients, disc_variables))

    return total_gen_loss, total_disc_loss, gen_gan_loss, gen_fm_loss, gen_vgg_loss, gen_tv_loss

@tf.function
def distributed_train_step(input_image, target_image):
    """Run the fine-tuning ``train_step`` on every replica and average its six losses.

    Each returned value is the cross-replica SUM divided by the replica count,
    in the order (total_gen, total_disc, gan, fm, vgg, tv).
    """
    def replica_step(input_img, target_img):
        return train_step(input_img, target_img, generator, discriminator1,
                          discriminator2, vgg_model, gen_optimizer, disc_optimizer)

    per_replica = strategy.run(replica_step, args=(input_image, target_image))
    replica_count = tf.cast(strategy.num_replicas_in_sync, tf.float32)

    def average(value):
        return strategy.reduce(tf.distribute.ReduceOp.SUM, value, axis=None) / replica_count

    return tuple(average(per_replica[i]) for i in range(6))

def generate_and_save_images(generator, test_input, target, epoch, num_examples=3,
                             filename_prefix='Finetune_epoch'):
    """Save an Input | Generated | Target grid to ``<filename_prefix>_<epoch>.png``.

    Args:
        generator: model mapping input images to generated images.
        test_input: batch of input images in [-1, 1].
        target: batch of target images in [-1, 1], aligned with ``test_input``.
        epoch: epoch number used in the title and the file name.
        num_examples: rows to plot; capped at the available batch size.
        filename_prefix: prefix of the saved PNG file name.
    """
    predictions = generator(test_input, training=False)
    predictions = predictions.numpy().astype(np.float32)
    test_input = test_input.numpy().astype(np.float32)
    target = target.numpy().astype(np.float32)

    # Never index past the batch, even if more rows are requested.
    num_examples = min(num_examples, len(predictions), len(test_input), len(target))
    if num_examples <= 0:
        return

    fig = plt.figure(figsize=(15, 5 * num_examples))
    columns = ((test_input, 'Input'), (predictions, 'Generated'), (target, 'Target (Cinematic)'))

    for i in range(num_examples):
        for col, (images, title) in enumerate(columns):
            plt.subplot(num_examples, 3, i * 3 + col + 1)
            # [-1, 1] -> [0, 1]; clip explicitly so slight overshoot does not
            # trigger imshow's "Clipping input data" warning.
            plt.imshow(np.clip(images[i] * 0.5 + 0.5, 0.0, 1.0))
            plt.title(title)
            plt.axis('off')

    plt.suptitle(f'Epoch {epoch}')
    plt.tight_layout()
    plt.savefig(f'{filename_prefix}_{epoch}.png')
    plt.close(fig)


with strategy.scope():
    generator = build_generator()
    discriminator1 = build_discriminator(name='discriminator_full')
    discriminator2 = build_discriminator(name='discriminator_half')
    vgg_model = build_vgg()

    pretrained_dir = '/kaggle/working/cinematic_gan_model'

    # This cell exists to continue from the pretrained generator, so a load
    # failure must be visible. Try the .keras archive, then the legacy path,
    # and if both fail raise an error naming both attempts. The original
    # exception is no longer masked by the fallback's.
    keras_weights = os.path.join(pretrained_dir, 'generator.keras')
    legacy_weights = os.path.join(pretrained_dir, 'generator')
    try:
        generator.load_weights(keras_weights)
        print("Generator weights loaded")
    except Exception as keras_err:
        try:
            generator.load_weights(legacy_weights)
        except Exception as legacy_err:
            raise RuntimeError(
                f'Could not load generator weights from {keras_weights} ({keras_err!r}) '
                f'or {legacy_weights} ({legacy_err!r})'
            ) from legacy_err
        print("Generator weights loaded from SavedModel")

    # Discriminator weights are best-effort. On failure the discriminators
    # keep their fresh initialization; report why instead of silently passing.
    try:
        discriminator1.load_weights(os.path.join(pretrained_dir, 'discriminator1.keras'))
        discriminator2.load_weights(os.path.join(pretrained_dir, 'discriminator2.keras'))
        print("Discriminator weights loaded")
    except Exception as err:
        print("Discriminator weights not found (optional for inference)")
        print(f"  reason: {err!r}")

    # Create NEW optimizers with LOWER learning rates
    gen_optimizer = keras.optimizers.Adam(learning_rate=GEN_LR, beta_1=0.0, beta_2=0.9)
    disc_optimizer = keras.optimizers.Adam(learning_rate=DISC_LR, beta_1=0.0, beta_2=0.9)

    # Loss scaling only protects float16 gradients from underflow. The
    # 'mixed_bfloat16' policy used on TPU does not need it, and a Keras 3
    # LossScaleOptimizer divides every gradient by its (large) loss scale in
    # apply_gradients, which distorts updates unless the loss was scaled first.
    if mixed_precision.global_policy().compute_dtype == 'float16':
        gen_optimizer = mixed_precision.LossScaleOptimizer(gen_optimizer)
        disc_optimizer = mixed_precision.LossScaleOptimizer(disc_optimizer)

    # Create metrics
    gen_loss_metric = keras.metrics.Mean(name='gen_loss', dtype=tf.float32)
    disc_loss_metric = keras.metrics.Mean(name='disc_loss', dtype=tf.float32)
    gan_loss_metric = keras.metrics.Mean(name='gan_loss', dtype=tf.float32)
    fm_loss_metric = keras.metrics.Mean(name='fm_loss', dtype=tf.float32)
    vgg_loss_metric = keras.metrics.Mean(name='vgg_loss', dtype=tf.float32)
    tv_loss_metric = keras.metrics.Mean(name='tv_loss', dtype=tf.float32)

print("Weights loaded successfully!")
print(f"New learning rates - Gen: {GEN_LR}, Disc: {DISC_LR}")

print("Starting FINE-TUNING...")


train_dataset_dist = strategy.experimental_distribute_dataset(train_dataset)

# Metrics in the order the fine-tuning distributed_train_step returns its losses.
finetune_metrics = (gen_loss_metric, disc_loss_metric, gan_loss_metric,
                    fm_loss_metric, vgg_loss_metric, tv_loss_metric)
# The progress bar shows only a subset of the running means.
finetune_postfix = (('G_Loss', gen_loss_metric), ('D_Loss', disc_loss_metric),
                    ('TV', tv_loss_metric), ('VGG', vgg_loss_metric))
# (prefix, metric, suffix) for each summary line.
finetune_summary = (
    ('  Avg Gen Loss:  ', gen_loss_metric, ''),
    ('  Avg Disc Loss: ', disc_loss_metric, ''),
    ('  Avg GAN Loss:  ', gan_loss_metric, ''),
    ('  Avg FM Loss:   ', fm_loss_metric, ''),
    ('  Avg VGG Loss:  ', vgg_loss_metric, ''),
    ('  Avg TV Loss:   ', tv_loss_metric, ' '),
)

for epoch in range(FINE_TUNE_EPOCHS):
    print('\n' + '=' * 60)
    print(f'Fine-Tune Epoch {epoch + 1}/{FINE_TUNE_EPOCHS}')

    for metric in finetune_metrics:
        metric.reset_state()

    pbar = tqdm(train_dataset_dist, total=steps_per_epoch, desc=f'Fine-Tune {epoch + 1}')
    for input_img, target_img in pbar:
        step_losses = distributed_train_step(input_img, target_img)
        for metric, value in zip(finetune_metrics, step_losses):
            metric.update_state(value)
        pbar.set_postfix({key: f'{metric.result():.4f}' for key, metric in finetune_postfix})

    print(f'\nFine-Tune Epoch {epoch + 1} Summary:')
    for prefix, metric, suffix in finetune_summary:
        print(f'{prefix}{metric.result():.4f}{suffix}')

    # Sample images every 5 epochs.
    if (epoch + 1) % 5 == 0:
        generate_and_save_images(generator, test_inputs, test_targets, epoch + 1)
        print(f'Saved fine-tuned sample images for epoch {epoch + 1}')

# Save the fine-tuned generator under a separate directory (only the
# generator is written here).
finetuned_dir = 'cinematic_gan_model_finetuned'
os.makedirs(finetuned_dir, exist_ok=True)
generator.save(os.path.join(finetuned_dir, 'generator.keras'))

for _message in (f'Fine-tuned model saved to: {finetuned_dir}/', "\nFine-tuning complete!"):
    print(_message)