# In [None]:
import os
import shutil
import time
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, UpSampling2D, Concatenate
from tensorflow.keras.applications import VGG16
from tensorflow.keras.models import Model

# Parameters
INPUT_SIZE = (256, 256)  # (height, width) every image is loaded at
BS = 16                  # batch size
ROOT_DIR = "/kaggle/"

# Terrain types to include in this run.
# Alternative setting covering more terrains:
#DATASETS = ["agri", "barrenland", "grassland", "urban"]
DATASETS = ["agri"]
ROOT_DATASET_PATH = os.path.join(ROOT_DIR, 'input/sentinel12-image-pairs-segregated-by-terrain/v_2')

# Recreate the working directory from scratch so links from an earlier run
# cannot collide with the ones created below.
DATA_GEN_INPUT = os.path.join(ROOT_DIR, 'DATASET')
if os.path.exists(DATA_GEN_INPUT):
    shutil.rmtree(DATA_GEN_INPUT)
os.mkdir(DATA_GEN_INPUT)

# Expose each terrain's s1/s2 folders inside DATASET through symlinks.
for terrain in DATASETS:
    terrain_root = os.path.join(ROOT_DATASET_PATH, terrain)
    src_s1, dst_s1 = os.path.join(terrain_root, "s1"), os.path.join(DATA_GEN_INPUT, f"DATA_{terrain}_s1")
    src_s2, dst_s2 = os.path.join(terrain_root, "s2"), os.path.join(DATA_GEN_INPUT, f"DATA_{terrain}_s2")
    link_pairs = ((src_s1, dst_s1), (src_s2, dst_s2))

    # Create both links first, then report them, so nothing is printed
    # unless both links for the terrain were created.
    for source, target in link_pairs:
        os.symlink(source, target)
    for source, target in link_pairs:
        print(f"Linked {source} to {target}")

# Image preprocessing function
def preprocessing_function(img):
    """Divide pixel values by 255 and return the result as float32."""
    scaled = img / 255.0
    return np.float32(scaled)

# Custom data generator
class PairedImageDataGenerator(tf.keras.utils.Sequence):
    """Batch generator pairing s1 (grayscale) images with s2 (RGB) images.

    For every terrain, the files in ``DATA_<terrain>_s1`` and
    ``DATA_<terrain>_s2`` under ``base_dir`` are sorted by name and paired
    position by position.

    Args:
        base_dir: Directory holding the ``DATA_<terrain>_s1`` / ``_s2`` folders.
        terrain_types: Iterable of terrain names to include.
        batch_size: Number of image pairs per batch (must be >= 1).
        input_size: ``(height, width)`` passed to ``load_img`` as ``target_size``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """

    def __init__(self, base_dir, terrain_types, batch_size, input_size):
        # Let the Keras base class initialise its own state.
        # NOTE(review): Keras 3's PyDataset appears to warn when this call is
        # skipped; confirm against the installed Keras version.
        super().__init__()
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.base_dir = base_dir
        self.terrain_types = terrain_types
        self.batch_size = batch_size
        self.input_size = input_size
        self.image_files = self._get_image_files()
        print(f"Total image pairs: {len(self.image_files)}")

    def _get_image_files(self):
        """Return a list of ``(s1_path, s2_path)`` tuples for all terrains."""
        files = []
        for terrain in self.terrain_types:
            s1_dir = os.path.join(self.base_dir, f"DATA_{terrain}_s1")
            s2_dir = os.path.join(self.base_dir, f"DATA_{terrain}_s2")
            s1_files = sorted(os.listdir(s1_dir))
            s2_files = sorted(os.listdir(s2_dir))
            print(f"Found {len(s1_files)} images for {terrain}")
            if len(s1_files) != len(s2_files):
                # zip() below stops at the shorter listing; make the
                # truncation (and possible mis-pairing) visible.
                print(f"Warning: {terrain} has {len(s1_files)} s1 files but "
                      f"{len(s2_files)} s2 files; unpaired files are ignored")
            files.extend(
                (os.path.join(s1_dir, s1), os.path.join(s2_dir, s2))
                for s1, s2 in zip(s1_files, s2_files)
            )
        return files

    def __len__(self):
        """Number of full batches; a trailing partial batch is dropped."""
        return len(self.image_files) // self.batch_size

    def __getitem__(self, idx):
        """Load batch ``idx`` and return ``(x_batch, y_batch)``.

        ``x_batch`` has shape ``(batch_size, *input_size, 1)`` and ``y_batch``
        has shape ``(batch_size, *input_size, 3)``; both are float32 and
        scaled by ``preprocessing_function``. A pair that fails to load is
        reported and its slot is left filled with zeros.

        Raises:
            IndexError: If ``idx`` is outside ``range(len(self))``.
        """
        # Raising IndexError keeps the sequence contract, so plain iteration
        # over the generator terminates instead of yielding all-zero batches.
        if not 0 <= idx < len(self):
            raise IndexError(f"batch index {idx} out of range for {len(self)} batches")

        batch_files = self.image_files[idx * self.batch_size:(idx + 1) * self.batch_size]
        # float32 matches the dtype produced by preprocessing_function and
        # halves the memory of NumPy's default float64.
        x_batch = np.zeros((self.batch_size, *self.input_size, 1), dtype=np.float32)
        y_batch = np.zeros((self.batch_size, *self.input_size, 3), dtype=np.float32)

        for i, (s1_file, s2_file) in enumerate(batch_files):
            try:
                s1_img = tf.keras.preprocessing.image.load_img(s1_file, color_mode='grayscale', target_size=self.input_size)
                s2_img = tf.keras.preprocessing.image.load_img(s2_file, color_mode='rgb', target_size=self.input_size)

                s1_array = tf.keras.preprocessing.image.img_to_array(s1_img)
                s2_array = tf.keras.preprocessing.image.img_to_array(s2_img)

                x_batch[i, :, :, 0] = preprocessing_function(s1_array[:, :, 0])
                y_batch[i] = preprocessing_function(s2_array)
            except Exception as e:
                # Deliberate best-effort: a bad file must not abort the batch.
                print(f"Error loading image pair: {s1_file}, {s2_file}")
                print(f"Error message: {str(e)}")

        return x_batch, y_batch

# Instantiate the paired-image generator over the linked dataset folders
train_generator = PairedImageDataGenerator(
    base_dir=DATA_GEN_INPUT,
    terrain_types=DATASETS,
    batch_size=BS,
    input_size=INPUT_SIZE,
)

# Define the model using VGG16 for transfer learning
def create_model(input_shape=(256, 256, 1)):
    """Build the colorization network: frozen VGG16 encoder + upsampling decoder.

    The input is mapped to 3 channels by a 1x1 convolution, passed through an
    ImageNet-pretrained VGG16 (no top, not trainable), and decoded by five
    Conv2D + 2x UpSampling2D stages (32x total upsampling) into a 3-channel
    image with a tanh activation.

    Args:
        input_shape: ``(height, width, channels)`` of the model input. The
            encoder is built for the same height and width.
            NOTE(review): the output only matches the input resolution if the
            encoder downsamples by exactly 32, so height and width are
            presumably expected to be multiples of 32; confirm.

    Returns:
        An uncompiled ``Model`` mapping the input image to a 3-channel image.
    """
    height, width = input_shape[0], input_shape[1]

    # Encoder (VGG16), sized from input_shape instead of a fixed 256x256 so
    # the function honours its own parameter.
    vgg16 = VGG16(weights='imagenet', include_top=False, input_shape=(height, width, 3))
    vgg16.trainable = False

    # Input layer
    input_layer = Input(shape=tuple(input_shape))

    # Convert the input channels to the 3 channels VGG16 was built for
    x = Conv2D(3, (1, 1), padding='same')(input_layer)

    # Extract features from VGG16
    x = vgg16(x)

    # Decoder: each stage halves-or-keeps the filter count and doubles the resolution
    for filters in (256, 128, 64, 32, 16):
        x = Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
        x = UpSampling2D((2, 2))(x)

    # Output layer
    # NOTE(review): tanh yields values in [-1, 1] while preprocessing_function
    # scales the targets by 1/255; a sigmoid may fit better, confirm before changing.
    output_layer = Conv2D(3, (3, 3), activation='tanh', padding='same')(x)

    return Model(inputs=input_layer, outputs=output_layer)

# Define custom loss function
def custom_loss(input_shape=(256, 256, 3), perceptual_weight=0.1):
    """Create a loss combining pixel MSE with a VGG16 feature-space MSE.

    Args:
        input_shape: ``(height, width, 3)`` shape the VGG16 feature extractor
            is built for; it should match the images being compared.
        perceptual_weight: Multiplier applied to the perceptual term before it
            is added to the pixel MSE.

    Returns:
        A function ``loss(y_true, y_pred)`` returning
        ``mse + perceptual_weight * perceptual``.
    """
    vgg16 = VGG16(weights='imagenet', include_top=False, input_shape=input_shape)
    vgg16.trainable = False
    # Build the MSE loss object once instead of on every loss evaluation.
    mse_fn = tf.keras.losses.MeanSquaredError()

    def loss(y_true, y_pred):
        # Pixel-space term
        mse = mse_fn(y_true, y_pred)
        # Perceptual term: mean squared difference of VGG16 feature maps.
        # NOTE(review): images are fed to VGG16 as-is, without the
        # vgg16.preprocess_input scaling; confirm this is intended.
        feature_gap = vgg16(y_true) - vgg16(y_pred)
        perceptual = tf.reduce_mean(tf.square(feature_gap))
        return mse + perceptual_weight * perceptual
    return loss

# Metrics calculation
def compute_metrics(y_true, y_pred):
    """Return ``(mse, ssim, psnr)`` between ``y_true`` and ``y_pred``.

    SSIM and PSNR are computed with ``max_val=1.0`` and averaged with
    ``tf.reduce_mean``; MSE comes from ``tf.keras.losses.MeanSquaredError``.
    """
    mse_fn = tf.keras.losses.MeanSquaredError()
    mse_value = mse_fn(y_true, y_pred)

    ssim_values = tf.image.ssim(y_true, y_pred, max_val=1.0)
    mean_ssim = tf.reduce_mean(ssim_values)

    psnr_values = tf.image.psnr(y_true, y_pred, max_val=1.0)
    mean_psnr = tf.reduce_mean(psnr_values)

    return mse_value, mean_ssim, mean_psnr

# Function to visualize and test model performance
def test_model(data_generator, model):
    """Plot the first sample of the first batch: SAR input, prediction, ground truth.

    Args:
        data_generator: Iterable yielding ``(sar_batch, rgb_batch)`` pairs.
        model: Model whose ``predict`` maps a SAR batch to colorized images.
    """
    sar_batch, rgb_batch = next(iter(data_generator))

    # Run the model on the whole batch; only sample 0 is displayed.
    predicted_batch = model.predict(sar_batch)

    _, axes = plt.subplots(1, 3, figsize=(18, 6))
    # (image, title, extra imshow keyword arguments) for each panel, left to right.
    panels = (
        (sar_batch[0, :, :, 0], 'Original SAR Image (s1)', {'cmap': 'gray'}),
        # Clip to the displayable [0, 1] range before showing the prediction.
        (np.clip(predicted_batch[0], 0, 1), 'Colorized SAR Image', {}),
        (rgb_batch[0], 'True Color Image (s2)', {}),
    )
    for axis, (image, title, imshow_kwargs) in zip(axes, panels):
        axis.imshow(image, **imshow_kwargs)
        axis.set_title(title)
    plt.show()

# Build and compile the model
model = create_model(list(INPUT_SIZE) + [1])
opt = tf.keras.optimizers.Adam(learning_rate=1e-4)


# A Keras metric function is expected to return a single tensor, so the
# 3-tuple returned by compute_metrics cannot be reported as three values when
# registered as one metric. Each component is therefore registered on its own,
# which makes train_on_batch return [loss, mse, ssim, psnr] in the order
# assumed by loss_names below.
def mse_metric(y_true, y_pred):
    """MSE component of compute_metrics as a single-valued metric."""
    return compute_metrics(y_true, y_pred)[0]


def ssim_metric(y_true, y_pred):
    """SSIM component of compute_metrics as a single-valued metric."""
    return compute_metrics(y_true, y_pred)[1]


def psnr_metric(y_true, y_pred):
    """PSNR component of compute_metrics as a single-valued metric."""
    return compute_metrics(y_true, y_pred)[2]


model.compile(optimizer=opt, loss=custom_loss(), metrics=[mse_metric, ssim_metric, psnr_metric])

# Early stopping criteria
patience = 10       # epochs without improvement tolerated before stopping
min_delta = 0.001   # minimum decrease of the loss that counts as an improvement
# NOTE: despite the name, this tracks the *training* loss; no validation data is used here.
best_val_loss = float('inf')
epochs_without_improvement = 0

EPOCHS = 100
train_loss = []
loss_names = ['Total Loss', 'MSE', 'SSIM', 'PSNR']

# Training loop with early stopping and visual testing
for epoch in range(EPOCHS):
    print(f"[INFO] Starting epoch {epoch + 1}/{EPOCHS}")
    epoch_start = time.time()
    loss_batch = []

    # Training loop for one epoch
    for i in tqdm(range(len(train_generator))):
        try:
            sar_data, true_color = train_generator[i]

            # Train model; the result holds the loss followed by the metrics
            loss = model.train_on_batch(sar_data, true_color)
            loss_batch.append(loss)
        except Exception as e:
            # Deliberate best-effort: report the failing batch and keep training.
            print(f"Error during training on batch {i}")
            print(f"Error message: {str(e)}")

    # Calculate average loss for this epoch
    if loss_batch:
        # atleast_1d keeps the indexing below valid even if a single scalar is returned per batch
        avg_loss = np.atleast_1d(np.mean(loss_batch, axis=0))
        train_loss.append(avg_loss)
        for name, value in zip(loss_names, avg_loss):
            print(f'{name}: {value:.4f}')

        # Early stopping logic
        if avg_loss[0] < best_val_loss - min_delta:
            best_val_loss = avg_loss[0]
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        # ">=" stops after exactly `patience` stagnant epochs (">" waited one extra epoch)
        if epochs_without_improvement >= patience:
            print(f"Early stopping triggered at epoch {epoch + 1}")
            break

        # Visualize colorization results every 5 epochs
        if epoch % 5 == 0:
            plt.figure(figsize=(10, 6))
            plt.plot(train_loss)
            plt.legend(loss_names)
            plt.title(f'Training Loss - Epoch {epoch + 1}')
            plt.show()
            test_model(train_generator, model)
    else:
        print("No valid batches in this epoch")

    epoch_end = time.time()
    elapsed = (epoch_end - epoch_start) / 60.0
    print(f"Epoch took {elapsed:.4f} minutes")

# Save the model
model.save('sar_colorization_model.h5')
print("Model saved successfully.")