In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import VGG19
from tensorflow.keras.models import Model
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import os
from tqdm.notebook import tqdm
import glob
import random

# Project layout: data, checkpoints and logs all live under ~/sar_colorization
base_path = Path.home() / "sar_colorization"
data_path = base_path / "data" / "v_2"
models_path = base_path / "wgan_model2"

# Checkpoints are written into models_path, so create it up front
models_path.mkdir(parents=True, exist_ok=True)

# TensorBoard summaries are written to <base_path>/logs
log_dir = os.path.join(base_path, "logs")
train_summary_writer = tf.summary.create_file_writer(log_dir)

# Sub-directory names looked up under the dataset path, one per land type
land_types = ['agri', 'barrenland', 'grassland', 'urban']

# Collect the paired SAR / optical PNG paths for every land type
def load_image_paths(dataset_path, land_types):
    """Gather sorted SAR (``s1``) and optical (``s2``) PNG paths per land type.

    Args:
        dataset_path: ``pathlib.Path`` of the dataset root; each land type is
            expected at ``dataset_path / land_type / {'s1', 's2'}``.
        land_types: Iterable of land-type directory names, processed in order.

    Returns:
        A ``(sar_image_paths, opt_image_paths)`` tuple of lists of path
        strings. Within each land type both lists are sorted by path, and the
        land types are concatenated in the order given.

    Raises:
        AssertionError: If a land type has a different number of SAR and
            optical PNG files. Only the counts are compared; pairing relies on
            both folders sorting into the same order.
    """
    def list_pngs(folder):
        return sorted(glob.glob(str(folder / '*.png')))

    sar_image_paths, opt_image_paths = [], []

    for land_type in land_types:
        land_dir = dataset_path / land_type
        sar_images = list_pngs(land_dir / 's1')
        opt_images = list_pngs(land_dir / 's2')

        assert len(sar_images) == len(opt_images), f"Mismatch between SAR and optical images in {land_type}."

        sar_image_paths += sar_images
        opt_image_paths += opt_images

    return sar_image_paths, opt_image_paths

# Decode one SAR / optical PNG pair into normalized float tensors
def load_and_preprocess_image(sar_image_path, opt_image_path):
    """Read a SAR/optical image pair from disk and scale both to [0, 1].

    Args:
        sar_image_path: Path of the SAR PNG; decoded with a single channel.
        opt_image_path: Path of the optical PNG; decoded with three channels.

    Returns:
        A ``(sar_image, opt_image)`` tuple of float32 tensors, each resized to
        256x256 and divided by 255.
    """
    def read_png(path, channels):
        raw = tf.io.read_file(path)
        pixels = tf.io.decode_png(raw, channels=channels)
        pixels = tf.image.resize(pixels, [256, 256])
        return tf.cast(pixels, tf.float32) / 255.0

    sar_image = read_png(sar_image_path, 1)
    opt_image = read_png(opt_image_path, 3)
    return sar_image, opt_image

# Build the shuffled, batched tf.data pipeline from paired path lists
def create_dataset(sar_image_paths, opt_image_paths, batch_size=1):
    """Create a ``tf.data.Dataset`` of preprocessed (SAR, optical) batches.

    The path pairs are shuffled with a buffer as large as the path list,
    decoded by ``load_and_preprocess_image``, batched, and prefetched.

    Args:
        sar_image_paths: SAR image paths.
        opt_image_paths: Optical image paths, aligned index-by-index with
            ``sar_image_paths``.
        batch_size: Number of pairs per batch.

    Returns:
        The resulting dataset yielding ``(sar_batch, opt_batch)`` tuples.
    """
    path_pairs = tf.data.Dataset.from_tensor_slices((sar_image_paths, opt_image_paths))
    return (
        path_pairs
        .shuffle(len(sar_image_paths))  # buffer covers the whole path list
        .map(load_and_preprocess_image)
        .batch(batch_size)
        .prefetch(buffer_size=tf.data.AUTOTUNE)
    )

# Split the paired path lists into train and test subsets
def split_dataset(sar_image_paths, opt_image_paths, split_ratio=0.8, *, shuffle=False, seed=None):
    """Split aligned SAR/optical path lists into train and test portions.

    By default the split is positional: the first ``split_ratio`` share of the
    lists becomes the training set and the remainder the test set. When the
    inputs are grouped (for example concatenated one land type after another),
    a positional split leaves whole groups out of one of the subsets; pass
    ``shuffle=True`` to shuffle the pairs together before splitting.

    Args:
        sar_image_paths: List of SAR image paths.
        opt_image_paths: List of optical image paths, aligned index-by-index
            with ``sar_image_paths``.
        split_ratio: Fraction of the pairs assigned to the training set.
        shuffle: If True, shuffle the (SAR, optical) pairs jointly before
            splitting. The input lists are not modified.
        seed: Seed for the shuffle; the same seed gives the same split.
            Ignored when ``shuffle`` is False.

    Returns:
        ``(sar_train, opt_train, sar_test, opt_test)``.

    Raises:
        ValueError: If ``shuffle`` is True and the two lists differ in length,
            since the pairs could not be kept aligned.
    """
    if shuffle:
        if len(sar_image_paths) != len(opt_image_paths):
            raise ValueError(
                "sar_image_paths and opt_image_paths must have the same length "
                f"to be shuffled as pairs (got {len(sar_image_paths)} and {len(opt_image_paths)})."
            )
        pairs = list(zip(sar_image_paths, opt_image_paths))
        # A private RNG keeps the global `random` state untouched.
        random.Random(seed).shuffle(pairs)
        sar_image_paths = [sar for sar, _ in pairs]
        opt_image_paths = [opt for _, opt in pairs]

    train_size = int(len(sar_image_paths) * split_ratio)

    return (
        sar_image_paths[:train_size],
        opt_image_paths[:train_size],
        sar_image_paths[train_size:],
        opt_image_paths[train_size:],
    )

# Define the generator (Pix2Pix-style U-Net: separable-conv encoder, transposed-conv decoder)
def build_generator():
    """Build the U-Net generator mapping a 256x256x1 input to a 256x256x3 output.

    The encoder is eight stride-2 ``SeparableConv2D`` stages; the decoder is
    seven stride-2 ``Conv2DTranspose`` stages, each concatenated with the
    matching encoder output, followed by a final stride-2 ``Conv2DTranspose``
    with a ``tanh`` activation.

    Returns:
        A Keras ``Model`` taking a ``[256, 256, 1]`` input.
    """
    initializer = tf.random_normal_initializer(0., 0.02)

    def downsample(filters, size, apply_batchnorm=True):
        """Stride-2 SeparableConv2D -> optional BatchNorm -> LeakyReLU."""
        stages = [
            layers.SeparableConv2D(filters, size, strides=2, padding='same',
                                   depthwise_initializer=initializer,
                                   pointwise_initializer=initializer,
                                   use_bias=False),
        ]
        if apply_batchnorm:
            stages.append(layers.BatchNormalization())
        stages.append(layers.LeakyReLU())
        return tf.keras.Sequential(stages)

    def upsample(filters, size, apply_dropout=False):
        """Stride-2 Conv2DTranspose -> BatchNorm -> optional Dropout(0.5) -> ReLU."""
        stages = [
            layers.Conv2DTranspose(filters, size, strides=2, padding='same',
                                   kernel_initializer=initializer,
                                   use_bias=False),
            layers.BatchNormalization(),
        ]
        if apply_dropout:
            stages.append(layers.Dropout(0.5))
        stages.append(layers.ReLU())
        return tf.keras.Sequential(stages)

    inputs = layers.Input(shape=[256, 256, 1])

    # Only the first encoder stage skips batch normalization.
    encoder_filters = (64, 128, 256, 512, 512, 512, 512, 512)
    down_stack = [
        downsample(filters, 4, apply_batchnorm=(index > 0))
        for index, filters in enumerate(encoder_filters)
    ]

    # Only the first three decoder stages apply dropout.
    decoder_filters = (512, 512, 512, 512, 256, 128, 64)
    up_stack = [
        upsample(filters, 4, apply_dropout=(index < 3))
        for index, filters in enumerate(decoder_filters)
    ]

    last = layers.Conv2DTranspose(3, 4, strides=2, padding='same',
                                  kernel_initializer=initializer,
                                  activation='tanh')

    # Encoder pass, remembering every stage output for the skip connections.
    x = inputs
    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)

    # The bottleneck (last encoder output) is not a skip; the remaining
    # outputs are consumed deepest-first.
    for up, skip in zip(up_stack, skips[-2::-1]):
        x = up(x)
        x = layers.Concatenate()([x, skip])

    x = last(x)

    return models.Model(inputs=inputs, outputs=x)

# Define the PatchGAN-style critic used for WGAN-GP
def build_multiscale_patchgan_critic():
    """Build the conditional critic scoring (SAR, optical) image pairs.

    The 256x256x1 input image and the 256x256x3 target image are concatenated
    along the channel axis and passed through four Conv2D blocks (64, 128, 256
    filters at stride 2, then 512 at stride 1) and a final one-filter Conv2D
    that produces a map of patch scores with no activation.

    NOTE(review): despite the name, the model returns a single ``patch_out``
    tensor; no additional scales are built here.
    NOTE(review): the blocks use BatchNormalization while the critic is trained
    with a per-sample gradient penalty -- confirm this combination is intended.

    Returns:
        A Keras ``Model`` with inputs ``[input_image, target_image]`` and one
        output.
    """
    initializer = tf.random_normal_initializer(0., 0.02)
    inp = layers.Input(shape=[256, 256, 1], name='input_image')
    tar = layers.Input(shape=[256, 256, 3], name='target_image')
    features = layers.concatenate([inp, tar])

    def critic_block(x, filters, size, strides=2, apply_batchnorm=True):
        conv = layers.Conv2D(filters, size, strides=strides, padding='same',
                             kernel_initializer=initializer, use_bias=False)
        x = conv(x)
        if apply_batchnorm:
            x = layers.BatchNormalization()(x)
        return layers.LeakyReLU()(x)

    # (filters, strides, apply_batchnorm) for each block, in order.
    block_specs = (
        (64, 2, False),
        (128, 2, True),
        (256, 2, True),
        (512, 1, True),
    )
    for filters, strides, use_batchnorm in block_specs:
        features = critic_block(features, filters, 4, strides=strides, apply_batchnorm=use_batchnorm)

    patch_out = layers.Conv2D(1, 4, strides=1, kernel_initializer=initializer)(features)

    return tf.keras.Model(inputs=[inp, tar], outputs=[patch_out])

# Gradient penalty for WGAN-GP (SAR and optical images are passed to the critic separately)
def gradient_penalty(critic, real_images, fake_images, sar_images):
    """Compute the WGAN-GP gradient penalty on interpolated optical images.

    Args:
        critic: Model called as ``critic([sar, optical], training=True)``.
        real_images: Batch of real optical images.
        fake_images: Batch of generated optical images, same shape as
            ``real_images``.
        sar_images: Batch of SAR images the optical images are conditioned on.

    Returns:
        Scalar tensor: the batch mean of ``(||grad|| - 1) ** 2``, where the
        gradient of the critic output is taken with respect to the
        interpolated optical images and its norm is computed per sample.
    """
    n_samples = tf.shape(real_images)[0]
    # One mixing coefficient per sample, broadcast over height/width/channels.
    mix = tf.random.uniform([n_samples, 1, 1, 1], 0.0, 1.0)
    blended = mix * real_images + (1 - mix) * fake_images

    with tf.GradientTape() as tape:
        # `blended` is not a variable, so it must be watched explicitly.
        tape.watch(blended)
        scores = critic([sar_images, blended], training=True)

    (grads,) = tape.gradient(scores, [blended])
    squared_norm = tf.reduce_sum(tf.square(grads), axis=[1, 2, 3])
    norm = tf.sqrt(squared_norm)

    return tf.reduce_mean((norm - 1.0) ** 2)

# Perceptual loss built on frozen VGG19 features
def build_vgg19_perceptual_loss():
    """Create a perceptual-loss function based on ImageNet-pretrained VGG19.

    A frozen feature extractor exposing ``block3_conv3`` and ``block4_conv3``
    is built once and captured by the returned function.

    NOTE(review): images are fed to VGG19 exactly as given; no
    ``preprocess_input`` step is applied here -- confirm that is intended.

    Returns:
        ``perceptual_loss(y_true, y_pred)``: the mean, over the two feature
        layers, of the mean absolute difference between the features of
        ``y_true`` and ``y_pred``.
    """
    vgg = VGG19(include_top=False, weights='imagenet')
    vgg.trainable = False

    content_layers = ['block3_conv3', 'block4_conv3']
    feature_maps = [vgg.get_layer(name).output for name in content_layers]
    content_model = Model(inputs=vgg.input, outputs=feature_maps)
    content_model.trainable = False

    def perceptual_loss(y_true, y_pred):
        true_features = content_model(y_true)
        pred_features = content_model(y_pred)
        per_layer = [
            tf.reduce_mean(tf.abs(true_map - pred_map))
            for true_map, pred_map in zip(true_features, pred_features)
        ]
        return tf.reduce_mean(per_layer)

    return perceptual_loss

# Define WGAN-GP losses
def generator_loss(gen_output, target, perceptual_loss_fn):
    """Compute the generator objective: L1 loss plus perceptual loss.

    NOTE(review): no critic output enters this loss, so the generator is
    trained on reconstruction terms only -- confirm that is intended.

    Args:
        gen_output: Generated optical images.
        target: Ground-truth optical images.
        perceptual_loss_fn: Callable invoked as
            ``perceptual_loss_fn(target, gen_output)``.

    Returns:
        ``(total_gen_loss, l1_loss, perceptual_loss)``.
    """
    l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
    perceptual_loss = perceptual_loss_fn(target, gen_output)
    return l1_loss + perceptual_loss, l1_loss, perceptual_loss

def critic_loss(real_output, fake_output, gp, lambda_gp=10):
    """Return the critic loss: mean(fake) - mean(real) + lambda_gp * gp.

    Args:
        real_output: Critic scores for real pairs.
        fake_output: Critic scores for generated pairs.
        gp: Gradient-penalty term.
        lambda_gp: Weight applied to ``gp``.
    """
    wasserstein_term = tf.reduce_mean(fake_output) - tf.reduce_mean(real_output)
    return wasserstein_term + lambda_gp * gp

# Training step with WGAN-GP
@tf.function
def train_step(input_image, target, generator, critic, generator_optimizer, critic_optimizer, perceptual_loss_fn, lambda_gp=10):
    """Run one optimization step for both the generator and the critic.

    The forward passes and losses are recorded on two tapes; the gradients are
    then computed and applied after the tapes stop recording, so the backward
    pass itself is not recorded.

    NOTE(review): every call updates BOTH networks, and the generator loss
    contains no critic term -- confirm this matches the intended WGAN-GP
    schedule.

    Args:
        input_image: Batch of SAR images fed to the generator and critic.
        target: Batch of ground-truth optical images.
        generator: Generator model.
        critic: Critic model called as ``critic([sar, optical])``.
        generator_optimizer: Optimizer applied to the generator variables.
        critic_optimizer: Optimizer applied to the critic variables.
        perceptual_loss_fn: Callable forwarded to ``generator_loss``.
        lambda_gp: Weight of the gradient penalty in the critic loss.

    Returns:
        ``(gen_total_loss, crit_loss, l1_loss, perceptual_loss, gen_output)``.
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as critic_tape:
        gen_output = generator(input_image, training=True)  # Generated optical image

        # Critic scores for the real and the generated pair
        real_output = critic([input_image, target], training=True)
        fake_output = critic([input_image, gen_output], training=True)

        # The penalty must be computed while critic_tape is recording so that
        # it contributes to the critic gradients.
        gp = gradient_penalty(critic, target, gen_output, input_image)

        gen_total_loss, l1_loss, perceptual_loss = generator_loss(gen_output, target, perceptual_loss_fn)
        crit_loss = critic_loss(real_output, fake_output, gp, lambda_gp)

    # Both gradient sets are computed before either optimizer changes any
    # variable, matching the original update order.
    generator_gradients = gen_tape.gradient(gen_total_loss, generator.trainable_variables)
    critic_gradients = critic_tape.gradient(crit_loss, critic.trainable_variables)

    generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))
    critic_optimizer.apply_gradients(zip(critic_gradients, critic.trainable_variables))

    return gen_total_loss, crit_loss, l1_loss, perceptual_loss, gen_output

# Image-quality metrics: PSNR and SSIM
def calculate_accuracy(target, gen_output):
    """Return the batch-mean PSNR and SSIM between target and generated images.

    Both metrics are computed with ``max_val=1.0``.

    Returns:
        ``(psnr, ssim)`` as scalar tensors.
    """
    per_image_psnr = tf.image.psnr(target, gen_output, max_val=1.0)
    per_image_ssim = tf.image.ssim(target, gen_output, max_val=1.0)
    return tf.reduce_mean(per_image_psnr), tf.reduce_mean(per_image_ssim)

# Write a side-by-side SAR / true / generated image grid to TensorBoard
def _log_image_grid(input_image, target, gen_output, step, max_images=10):
    """Log up to ``max_images`` comparison rows as one TensorBoard image.

    Each row is ``[SAR repeated to 3 channels | true optical | generated]``.
    SAR and target images are written as-is; this assumes the dataset yields
    images already scaled to [0, 1], as ``load_and_preprocess_image`` does.
    The generated image is clipped to [0, 1] because the generator's ``tanh``
    output can fall below zero.
    """
    num_images = min(max_images, input_image.shape[0])

    image_rows = []
    for i in range(num_images):
        sar_rgb = tf.repeat(input_image[i, :, :, :1], 3, axis=2)
        generated = tf.clip_by_value(gen_output[i], 0.0, 1.0)
        image_rows.append(tf.concat([sar_rgb, target[i], generated], axis=1))

    image_grid = tf.concat(image_rows, axis=0)
    with train_summary_writer.as_default():
        tf.summary.image("SAR_True_Generated", image_grid[tf.newaxis, :, :, :], step=step)


# Training loop with TensorBoard logging
def train(dataset, epochs, checkpoint_dir, n_critic=5, log_image_interval=1500):
    """Train the generator and critic, logging metrics and saving checkpoints.

    Models, optimizers and the perceptual loss are built here. If
    ``checkpoint_dir`` holds a checkpoint, it is restored and training resumes
    from the epoch number encoded in the checkpoint name. Per-epoch averages
    of the losses, PSNR and SSIM are written to TensorBoard, and a checkpoint
    is saved every 5 epochs.

    Args:
        dataset: Iterable of ``(input_image, target)`` batches.
        epochs: Total number of epochs to reach (not the number of additional
            epochs after a restore).
        checkpoint_dir: Directory used by the checkpoint manager.
        n_critic: Number of extra ``train_step`` calls per batch before the
            final one whose outputs are used for metrics.
            NOTE(review): ``train_step`` updates both networks on every call,
            so these extra calls also update the generator -- confirm intent.
        log_image_interval: Visualize and log images every this many batches
            within an epoch.

    Returns:
        The trained generator model.
    """
    # Build generator and critic models
    generator = build_generator()
    critic = build_multiscale_patchgan_critic()

    # Optimizers for both models
    generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    critic_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

    # Perceptual loss function using VGG19
    perceptual_loss_fn = build_vgg19_perceptual_loss()

    # Checkpoint manager for saving and restoring models and optimizers
    checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                     critic_optimizer=critic_optimizer,
                                     generator=generator,
                                     critic=critic)
    checkpoint_manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=5)

    # Restore the latest checkpoint if available
    if checkpoint_manager.latest_checkpoint:
        print(f"Restoring from checkpoint: {checkpoint_manager.latest_checkpoint}")
        checkpoint.restore(checkpoint_manager.latest_checkpoint)
        # Checkpoints are saved with checkpoint_number=epoch, so the suffix
        # after the last '-' is the number of completed epochs.
        initial_epoch = int(checkpoint_manager.latest_checkpoint.split('-')[-1])
    else:
        print("Starting training from scratch.")
        initial_epoch = 0

    step = 0
    for epoch in range(initial_epoch, epochs):
        print(f"Epoch {epoch + 1}/{epochs}")
        progress_bar = tqdm(dataset, desc=f"Epoch {epoch + 1}/{epochs}", unit="batch")

        total_gen_loss, total_crit_loss, total_l1_loss, total_perceptual_loss = 0, 0, 0, 0
        total_psnr, total_ssim, batch_count = 0, 0, 0

        for n, (input_image, target) in enumerate(progress_bar):
            # Extra optimization steps on this batch; only the critic loss of
            # the last one is kept for reporting.
            for _ in range(n_critic):
                _, crit_loss, _, _, _ = train_step(input_image, target, generator, critic, generator_optimizer, critic_optimizer, perceptual_loss_fn, lambda_gp=10)

            # Final step on this batch; its generator outputs feed the metrics
            gen_total_loss, final_crit_loss, l1_loss, perceptual_loss, gen_output = train_step(
                input_image, target, generator, critic, generator_optimizer, critic_optimizer, perceptual_loss_fn
            )
            if n_critic <= 0:
                # No extra steps ran, so report the critic loss of the final step.
                crit_loss = final_crit_loss

            # Calculate PSNR and SSIM
            psnr, ssim = calculate_accuracy(target, gen_output)

            # Accumulate metrics
            total_gen_loss += gen_total_loss
            total_crit_loss += crit_loss
            total_l1_loss += l1_loss
            total_perceptual_loss += perceptual_loss
            total_psnr += psnr
            total_ssim += ssim
            batch_count += 1

            progress_bar.set_postfix(gen_loss=f"{gen_total_loss:.4f}", crit_loss=f"{crit_loss:.4f}", psnr=f"{psnr:.4f}", ssim=f"{ssim:.4f}")

            # Visualize and log images every `log_image_interval` batches
            if (n + 1) % log_image_interval == 0:
                visualize_generated_images(input_image, target, gen_output)
                _log_image_grid(input_image, target, gen_output, step)

            step += 1

        # Log average metrics for the entire epoch
        with train_summary_writer.as_default():
            tf.summary.scalar('epoch_generator_loss', total_gen_loss / batch_count, step=epoch)
            tf.summary.scalar('epoch_critic_loss', total_crit_loss / batch_count, step=epoch)
            tf.summary.scalar('epoch_l1_loss', total_l1_loss / batch_count, step=epoch)
            tf.summary.scalar('epoch_perceptual_loss', total_perceptual_loss / batch_count, step=epoch)
            tf.summary.scalar('epoch_psnr', total_psnr / batch_count, step=epoch)
            tf.summary.scalar('epoch_ssim', total_ssim / batch_count, step=epoch)

        # Save the model every 5 epochs
        if (epoch + 1) % 5 == 0:
            checkpoint_manager.save(checkpoint_number=epoch + 1)
            print(f"Checkpoint saved at epoch {epoch + 1}")

    return generator  # Return the trained generator




# Show the first SAR / optical pair of a dataset
def visualize_images(dataset):
    """Plot the first sample of the first batch: SAR (gray) next to optical.

    Args:
        dataset: Dataset yielding ``(sar_image, opt_image)`` batches; only one
            batch is taken.
    """
    for sar_image, opt_image in dataset.take(1):
        fig, axes = plt.subplots(1, 2, figsize=(12, 6))

        panels = (
            (sar_image[0, :, :, 0], 'SAR Image', {'cmap': 'gray'}),
            (opt_image[0, :, :, :], 'Optical Image', {}),
        )
        for ax, (pixels, title, imshow_kwargs) in zip(axes, panels):
            ax.imshow(pixels, **imshow_kwargs)
            ax.set_title(title)
            ax.axis('off')

        plt.show()

# Show SAR input, true optical image and generated image side by side
def visualize_generated_images(sar_image, opt_image, gen_output):
    """Plot the first sample of a batch as three panels.

    Args:
        sar_image: Batch of SAR images; channel 0 of sample 0 is shown in gray.
        opt_image: Batch of ground-truth optical images; sample 0 is shown.
        gen_output: Batch of generated optical images; sample 0 is shown.
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    panels = (
        (sar_image[0, :, :, 0], 'SAR Image', {'cmap': 'gray'}),
        (opt_image[0, :, :, :], 'Optical Image (True)', {}),
        (gen_output[0, :, :, :], 'Optical Image (Generated)', {}),
    )
    for ax, (pixels, title, imshow_kwargs) in zip(axes, panels):
        ax.imshow(pixels, **imshow_kwargs)
        ax.set_title(title)
        ax.axis('off')

    plt.show()

# Run the generator over the test data and plot every batch
def generate_optical_images(generator, test_dataset):
    """Generate an optical image for each test batch and visualize it.

    Args:
        generator: Generator model, called with ``training=False``.
        test_dataset: Iterable of ``(sar_image, opt_image)`` batches.
    """
    for sar_batch, optical_batch in test_dataset:
        generated_batch = generator(sar_batch, training=False)
        visualize_generated_images(sar_batch, optical_batch, generated_batch)

# Main execution
# NOTE: every pipeline step below is commented out, so running this cell only
# defines BATCH_SIZE and EPOCHS. Uncomment the steps, in order, to load the
# data, train the model and visualize predictions on the test split.
if __name__ == "__main__":
    BATCH_SIZE = 4  # Optimal batch size based on memory constraints
    EPOCHS = 150    # Increase the epochs for better results

    # Load image paths
    # sar_image_paths, opt_image_paths = load_image_paths(data_path, land_types)

    # Train-test split
    # sar_train, opt_train, sar_test, opt_test = split_dataset(sar_image_paths, opt_image_paths, split_ratio=0.8)

    # Create TensorFlow datasets
    # train_dataset = create_dataset(sar_train, opt_train, BATCH_SIZE)
    # test_dataset = create_dataset(sar_test, opt_test, BATCH_SIZE)

    # Visualize the first pair of SAR and Optical images
    # visualize_images(train_dataset)

    # Train the model
    # trained_generator = train(train_dataset, EPOCHS, models_path, n_critic=5)

    # Generate optical images for all the test data
    # generate_optical_images(trained_generator, test_dataset)

In [None]:
# Directory expected to hold the new SAR images to colorize
new_images_dir = Path.home().joinpath("sar_colorization", "new_images")

# Directory containing the saved model checkpoints
checkpoint_dir = Path.home().joinpath("sar_colorization", "wgan_model2")

# Load the new SAR images and stack them into one inference batch
def load_new_sar_images(image_paths):
    """Read, resize and normalize SAR images, returning them as one batch.

    Args:
        image_paths: Iterable of image file paths.

    Returns:
        A float32 tensor stacking the images along axis 0; each image is
        decoded with one channel, resized to 256x256 and divided by 255.
    """
    def preprocess(img_path):
        raw = tf.io.read_file(img_path)
        gray = tf.io.decode_png(raw, channels=1)       # single channel
        resized = tf.image.resize(gray, [256, 256])
        return tf.cast(resized, tf.float32) / 255.0     # scale to [0, 1]

    return tf.stack([preprocess(img_path) for img_path in image_paths], axis=0)

# Visualize the generated optical images
def visualize_generated_images(sar_images, gen_images, gen_output=None):
    """Plot SAR inputs next to generated optical images.

    This definition replaces the earlier three-argument function of the same
    name once the cell is run, so it accepts both call forms to keep the
    training-time callers working:

    * ``visualize_generated_images(sar_images, gen_images)`` -- one figure per
      sample in the batch with two panels (SAR, generated).
    * ``visualize_generated_images(sar_image, opt_image, gen_output)`` -- one
      figure for the first sample only, with three panels (SAR, true optical,
      generated), as the earlier definition did.

    Args:
        sar_images: Batch of SAR images; channel 0 is shown in gray.
        gen_images: Batch of generated optical images in the two-argument
            form, or the ground-truth optical batch in the three-argument form.
        gen_output: Batch of generated optical images (three-argument form
            only).
    """
    if gen_output is not None:
        fig, axes = plt.subplots(1, 3, figsize=(18, 6))
        panels = (
            (sar_images[0, :, :, 0], 'SAR Image', {'cmap': 'gray'}),
            (gen_images[0, :, :, :], 'Optical Image (True)', {}),
            (gen_output[0, :, :, :], 'Optical Image (Generated)', {}),
        )
        for ax, (pixels, title, imshow_kwargs) in zip(axes, panels):
            ax.imshow(pixels, **imshow_kwargs)
            ax.set_title(title)
            ax.axis('off')
        plt.show()
        return

    num_images = sar_images.shape[0]
    for i in range(num_images):
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

        ax1.imshow(sar_images[i, :, :, 0], cmap='gray')
        ax1.set_title('SAR Image')
        ax1.axis('off')

        ax2.imshow(gen_images[i])
        ax2.set_title('Generated Optical Image')
        ax2.axis('off')

        plt.show()

# Rebuild the generator and load its weights from the latest checkpoint
def load_trained_model(checkpoint_dir):
    """Return a generator restored from the newest checkpoint, if any.

    Args:
        checkpoint_dir: Directory searched for the latest checkpoint.

    Returns:
        The generator model. When no checkpoint is found a message is printed
        and the freshly built (unrestored) generator is returned.
    """
    generator = build_generator()  # architecture must match the saved one
    checkpoint = tf.train.Checkpoint(generator=generator)

    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
    if not latest_checkpoint:
        print("No checkpoint found, ensure the path is correct.")
        return generator

    # Only the generator is restored here; expect_partial() silences warnings
    # about the other objects stored in the checkpoint.
    checkpoint.restore(latest_checkpoint).expect_partial()
    print(f"Restored model from {latest_checkpoint}")
    return generator

# Main testing function
def test_on_new_images(image_paths, checkpoint_dir):
    """Colorize new SAR images with the trained generator and display them.

    Args:
        image_paths: Sequence of SAR image file paths. When it is empty a
            message is printed and nothing else happens.
        checkpoint_dir: Directory holding the generator checkpoint.
    """
    if len(image_paths) == 0:
        print("No SAR images to process; check the image directory and file pattern.")
        return

    # Load the SAR images
    sar_images = load_new_sar_images(image_paths)

    # Load the trained generator model
    generator = load_trained_model(checkpoint_dir)

    # Generate optical images
    gen_images = generator(sar_images, training=False)

    # The training targets in this file are scaled to [0, 1], so the generator
    # output is already on that scale; only clip stray values for display
    # instead of remapping from [-1, 1].
    gen_images = tf.clip_by_value(gen_images, 0.0, 1.0)

    # Visualize results
    visualize_generated_images(sar_images, gen_images)

# Collect every new SAR image that follows the SAR-Image-*.jpg naming convention
new_sar_image_paths = glob.glob(os.path.join(new_images_dir, "SAR-Image-*.jpg"))

# Run the trained generator on the collected images
test_on_new_images(new_sar_image_paths, checkpoint_dir)


In [None]:
# Restore the latest checkpoint into `generator` and report its number
def load_latest_checkpoint(checkpoint_dir, generator):
    """Restore the newest checkpoint in ``checkpoint_dir`` into ``generator``.

    Args:
        checkpoint_dir: Directory managed by the checkpoint manager.
        generator: Generator model that receives the restored weights.

    Returns:
        The integer parsed from the text after the last ``-`` in the
        checkpoint path, or ``None`` when no checkpoint exists.
    """
    checkpoint = tf.train.Checkpoint(generator=generator)
    checkpoint_manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=5)

    latest = checkpoint_manager.latest_checkpoint
    if not latest:
        print("No checkpoint found.")
        return None

    # Only the generator is tracked here, hence expect_partial().
    checkpoint.restore(latest).expect_partial()
    checkpoint_number = int(latest.split('-')[-1])
    print(f"Restored from checkpoint: {latest}")
    print(f"Checkpoint number: {checkpoint_number}")
    return checkpoint_number

# Plot SAR, true optical and generated images in one row
def visualize_sar_opt_gen(sar_image, opt_image, gen_image):
    """Show the first sample of each batch as three side-by-side panels.

    Args:
        sar_image: Batch of SAR images; channel 0 of sample 0 is shown in gray.
        opt_image: Batch of ground-truth optical images; sample 0 is shown.
        gen_image: Batch of generated optical images; sample 0 is shown.
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    panels = (
        (sar_image[0, :, :, 0], 'SAR Image', {'cmap': 'gray'}),
        (opt_image[0, :, :, :], 'Optical Image (True)', {}),
        (gen_image[0, :, :, :], 'Optical Image (Generated)', {}),
    )
    for ax, (pixels, title, imshow_kwargs) in zip(axes, panels):
        ax.imshow(pixels, **imshow_kwargs)
        ax.set_title(title)
        ax.axis('off')

    plt.show()

# Generate and visualize images for up to `num_images` batches of the test dataset
def generate_and_visualize_images(generator, test_dataset, num_images=150):
    """Run the generator on test batches and plot each result.

    Args:
        generator: Generator model, called with ``training=False``.
        test_dataset: Dataset of ``(sar_image, opt_image)`` batches; must
            provide ``take``.
        num_images: Maximum number of batches to visualize (default 150).
    """
    limited = test_dataset.take(num_images)
    for shown, (sar_image, opt_image) in enumerate(limited, start=1):
        gen_image = generator(sar_image, training=False)
        visualize_sar_opt_gen(sar_image, opt_image, gen_image)
        if shown >= num_images:
            break

# Main block to load checkpoint and visualize images
if __name__ == "__main__":
    # Define paths and load the generator
    checkpoint_dir = models_path  # Ensure this path matches the one used during training

    # Rebuild the generator model (must match the architecture used in training)
    generator = build_generator()

    # Load the latest checkpoint
    checkpoint_number = load_latest_checkpoint(checkpoint_dir, generator)

    if checkpoint_number is not None:
        # Build the test split here rather than relying on a dataset variable
        # left over from another cell, so this cell runs on its own.
        sar_image_paths, opt_image_paths = load_image_paths(data_path, land_types)
        _, _, sar_test, opt_test = split_dataset(sar_image_paths, opt_image_paths, split_ratio=0.8)

        # Only the first sample of each batch is plotted, so a batch size of 1
        # avoids decoding images that would never be shown.
        test_dataset = create_dataset(sar_test, opt_test, batch_size=1)

        # Generate and visualize up to 150 images from the test dataset
        generate_and_visualize_images(generator, test_dataset, num_images=150)
