# In [None]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, TerminateOnNaN
import cv2
import matplotlib.pyplot as plt
import itertools
import gc
import tensorflow.keras.backend as K

from keras.models import Model
from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Conv2DTranspose, concatenate, BatchNormalization, Dropout, Lambda

from tensorflow.keras.callbacks import Callback


from tensorflow.keras.optimizers import Adam
import segmentation_models as sm


# --- Sanity ---
def test_training_sanity():
    """Print a fixed marker line showing that this training cell was loaded."""
    marker = "✅ from training.ipynb"
    print(marker)

# --- Visualisation ---
def visualise_prediction(rgb, true_mask, pred_mask):
    """Show an image next to its true and predicted masks in one figure.

    Args:
        rgb: Image passed straight to ``imshow`` for the left panel.
        true_mask: Used to index ``COLOR_PALETTE`` for the middle panel
            (presumably an integer class-index array -- confirm with caller).
        pred_mask: Used to index ``COLOR_PALETTE`` for the right panel
            (same assumption as ``true_mask``).
    """

    def draw_panel(axis, image, title):
        # Every panel gets the same treatment: image, caption, no ticks.
        axis.imshow(image)
        axis.set_title(title)
        axis.axis("off")

    _, (rgb_axis, true_axis, pred_axis) = plt.subplots(1, 3, figsize=(16, 5))
    draw_panel(rgb_axis, rgb, "RGB Image")
    draw_panel(true_axis, COLOR_PALETTE[true_mask], "True Mask")
    draw_panel(pred_axis, COLOR_PALETTE[pred_mask], "Predicted Mask")

    plt.tight_layout()
    plt.show()

# Supported model inputs: key -> human-readable description and channel count.
INPUT_TYPE_CONFIG = {
    "1ch": {"description": "grayscale only", "channels": 1},
    "2ch": {"description": "grayscale + elevation", "channels": 2},
    "rgb": {"description": "RGB only", "channels": 3},
    "rgb_elevation": {"description": "RGB + elevation", "channels": 4},
}

# Label colour (3-tuple, as stored in the label images) -> class index.
COLOR_TO_CLASS = {
    (230, 25, 75): 0,
    (145, 30, 180): 1,
    (60, 180, 75): 2,
    (245, 130, 48): 3,
    (255, 255, 255): 4,
    (0, 130, 200): 5,
}

# Inverse mapping: class index -> label colour.
CLASS_TO_COLOR = {class_idx: colour for colour, class_idx in COLOR_TO_CLASS.items()}
NUM_CLASSES = len(COLOR_TO_CLASS)

# Row i holds the colour of class i. Built by class index (not by dict
# insertion order) so that reordering COLOR_TO_CLASS cannot silently scramble
# the palette; a missing or duplicated class index fails here with KeyError.
COLOR_PALETTE = np.array(
    [CLASS_TO_COLOR[class_idx] for class_idx in range(NUM_CLASSES)],
    dtype=np.uint8,
)

# Colour-tuple lookup used when decoding label images.
COLOR_LOOKUP = dict(COLOR_TO_CLASS)

class ClearMemory(Callback):
    """Keras callback that runs a garbage collection and clears the Keras
    session at the end of every training batch.

    NOTE(review): doing this once per batch looks expensive; confirm that the
    per-batch frequency is intentional rather than a per-epoch cleanup.
    """

    def on_train_batch_end(self, batch, logs=None):
        # Collect unreachable Python objects first, then reset Keras' state.
        gc.collect()
        tf.keras.backend.clear_session()

def decode_label_image(label_img, color_lookup=None):
    """Convert a colour-coded label image into a 2-D map of class indices.

    Args:
        label_img: Array-like with three dimensions ``(height, width,
            channels)`` whose pixels are label colours.
        color_lookup: Optional mapping ``colour tuple -> class index``.
            Defaults to the module-level ``COLOR_LOOKUP``.

    Returns:
        ``np.uint8`` array of shape ``(height, width)`` holding the class
        index of every pixel.

    Raises:
        ValueError: If a pixel's colour is not a key of the lookup. The first
            offending pixel in row-major order is reported.
    """
    if color_lookup is None:
        color_lookup = COLOR_LOOKUP

    label_img = np.asarray(label_img)
    h, w, channels = label_img.shape
    label_map = np.zeros((h, w), dtype=np.uint8)
    matched = np.zeros((h, w), dtype=bool)

    # One vectorised comparison per known colour instead of a Python-level
    # loop over every pixel.
    for colour, class_idx in color_lookup.items():
        if len(colour) != channels:
            # A colour with a different channel count can never equal a pixel.
            continue
        hit = np.all(label_img == np.asarray(colour), axis=-1)
        label_map[hit] = class_idx
        matched |= hit

    if not matched.all():
        # argwhere scans in row-major order, so index 0 is the first bad pixel.
        y, x = (int(v) for v in np.argwhere(~matched)[0])
        pixel = tuple(label_img[y, x])
        raise ValueError(f"❌ Unknown label colour {pixel} at ({y}, {x})")

    return label_map



def train_model(base_dir="/content/chipped_data/content/chipped_data", out_dir="/content/figs", 
                input_type="rgb_elevation", model_type="unet", tile_size=256,
                batch_size=8, steps=None, epochs=10, train_time=20, verbose=1
                ):
    """Build, train and evaluate a segmentation model end to end.

    The function creates the train/val generators, builds the requested
    model, compiles it with a Dice + focal loss (with label smoothing) and a
    mean-IoU metric, fits it with checkpointing / early stopping / a
    wall-clock limit, saves loss and mIoU curves, and finally calls
    ``evaluate_on_test`` on a test generator.

    Args:
        base_dir: Dataset root. ``<base_dir>/<split>/{images,elevations,labels}``
            is read for ``split`` in ``train``, ``val`` and ``test``.
        out_dir: Directory where ``loss_plot.png`` and ``mIoU_plot.png`` are
            written. Created if it does not exist.
        input_type: Key of ``INPUT_TYPE_CONFIG``; selects the channel count.
        model_type: One of ``unet``, ``multi_unet``, ``unet_aux``,
            ``segformer``, ``model0``, ``model1``, ``model2``, ``model3``.
        tile_size: Height and width of the model input. It is also passed as
            ``d_model`` to the learning-rate schedule.
        batch_size: Forwarded to every ``StreamingDataGenerator``.
        steps: Forwarded to every ``StreamingDataGenerator`` and used as the
            batch cap of the class-distribution logger (16 when ``None``).
        epochs: Maximum number of epochs passed to ``model.fit``.
        train_time: Wall-clock budget in minutes, checked at each epoch end.
        verbose: Forwarded to ``model.fit``.

    Returns:
        None.

    Raises:
        AssertionError: If ``input_type`` is not a key of ``INPUT_TYPE_CONFIG``.
        ValueError: If ``model_type`` is not one of the supported names.

    Side effects:
        Creates ``checkpoints/`` in the working directory and writes
        ``checkpoints/best_model.h5``; writes two PNG plots to ``out_dir``.
    """
    import time
    from collections import defaultdict
    from tensorflow.keras.optimizers.schedules import LearningRateSchedule

    assert input_type in INPUT_TYPE_CONFIG, f"Unknown input type: {input_type}"
    num_channels = INPUT_TYPE_CONFIG[input_type]["channels"]
    input_shape = (tile_size, tile_size, num_channels)

    print(f"\n🔧 Training {model_type.upper()} with input type: {input_type} ({num_channels} channels)")
    print(f"🧪 Computed input shape: ({tile_size}, {tile_size}, {num_channels})")

    # Create the output directory up front so a bad path fails before
    # training rather than when the plots are saved afterwards.
    os.makedirs(out_dir, exist_ok=True)

    train_images = os.path.join(base_dir, "train", "images")
    train_elev = os.path.join(base_dir, "train", "elevations")
    train_labels = os.path.join(base_dir, "train", "labels")

    val_images = os.path.join(base_dir, "val", "images")
    val_elev = os.path.join(base_dir, "val", "elevations")
    val_labels = os.path.join(base_dir, "val", "labels")

    test_images = os.path.join(base_dir, "test", "images")
    test_elev = os.path.join(base_dir, "test", "elevations")
    test_labels = os.path.join(base_dir, "test", "labels")

    # --- Generators ---
    train_gen = StreamingDataGenerator(
        train_images, train_elev, train_labels,
        batch_size=batch_size, input_type=input_type,
        shuffle=True, steps=steps, fixed=False, augment=True,
        background_threshold=0.95
    )

    val_gen = StreamingDataGenerator(
        val_images, val_elev, val_labels,
        batch_size=batch_size, steps=steps, input_type=input_type,
        shuffle=False, fixed=True, augment=False
    )

    # --- Model ---
    # An if/elif chain (rather than a dict of builders) so that only the
    # builder actually requested has to be defined.
    if model_type == "unet":
        print("🧪 Calling build_unet...")
        model = build_unet(input_shape=input_shape, num_classes=NUM_CLASSES)
    elif model_type == "multi_unet":
        print("🧪 Calling build_multi_unet...")
        model = build_multi_unet(input_shape=input_shape, num_classes=NUM_CLASSES)
    elif model_type == "unet_aux":
        print("🧪 Calling build_unet_aux...")
        model = build_unet_aux(input_shape=input_shape, num_classes=NUM_CLASSES)
    elif model_type == "segformer":
        model = build_segformer(input_shape=input_shape, num_classes=NUM_CLASSES)
    elif model_type == "model3":
        model = build_model_3(input_shape=input_shape, num_classes=NUM_CLASSES)
    elif model_type == "model0":
        model = build_model_0(input_shape=input_shape, num_classes=NUM_CLASSES)
    elif model_type == "model1":
        model = build_model_1(input_shape=input_shape, num_classes=NUM_CLASSES)
    elif model_type == "model2":
        model = build_model_2(input_shape=input_shape, num_classes=NUM_CLASSES)
    else:
        raise ValueError(f"Unknown model_type: {model_type}")

    print(f"🧪 Final model input shape: {model.input_shape}")

    # --- Optimizer ---
    class TransformerLRSchedule(LearningRateSchedule):
        """Warm-up then inverse-square-root decay:
        ``rsqrt(d_model) * min(rsqrt(step), step * warmup_steps ** -1.5)``.
        """

        def __init__(self, d_model, warmup_steps=4000):
            super().__init__()
            # Stored as plain floats so get_config() is trivially serialisable.
            self.d_model = float(d_model)
            self.warmup_steps = float(warmup_steps)

        def __call__(self, step):
            step = tf.cast(step, tf.float32)
            d_model = tf.cast(self.d_model, tf.float32)
            warmup_steps = tf.cast(self.warmup_steps, tf.float32)
            arg1 = tf.math.rsqrt(step)
            arg2 = step * tf.pow(warmup_steps, -1.5)
            return tf.math.rsqrt(d_model) * tf.math.minimum(arg1, arg2)

        def get_config(self):
            return {
                "d_model": self.d_model,
                "warmup_steps": self.warmup_steps
            }

    lr_schedule = TransformerLRSchedule(d_model=tile_size)
    optimizer = Adam(learning_rate=lr_schedule, beta_1=0.9, beta_2=0.98, epsilon=1e-9)

    # --- Loss ---
    def apply_label_smoothing(y_true, smoothing=0.1):
        """Blend one-hot targets with a uniform distribution over the classes."""
        num_classes = tf.cast(tf.shape(y_true)[-1], tf.float32)
        return y_true * (1.0 - smoothing) + (smoothing / num_classes)

    # The environment variable is only read when segmentation_models is
    # imported, which has already happened at the top of this file; select the
    # framework explicitly as well so the setting actually takes effect.
    os.environ["SM_FRAMEWORK"] = "tf.keras"
    sm.set_framework("tf.keras")

    # Equal weight for every class.
    weights = [0.1666] * NUM_CLASSES

    # Raw losses from segmentation_models
    raw_dice = sm.losses.DiceLoss(class_weights=weights)
    raw_focal = sm.losses.CategoricalFocalLoss()

    def total_loss_with_smoothing(y_true, y_pred):
        """Dice + categorical focal loss evaluated on smoothed targets."""
        y_true_smoothed = apply_label_smoothing(y_true, smoothing=0.1)
        return raw_dice(y_true_smoothed, y_pred) + raw_focal(y_true_smoothed, y_pred)

    # --- Jaccard Index ---
    class MeanIoUMetric(tf.keras.metrics.MeanIoU):
        """MeanIoU that takes the argmax over the last axis of both inputs
        before updating the confusion matrix."""

        def __init__(self, num_classes, name="mean_iou", dtype=None):
            super().__init__(num_classes=num_classes, name=name, dtype=dtype)

        def update_state(self, y_true, y_pred, sample_weight=None):
            y_true = tf.argmax(y_true, axis=-1)
            y_pred = tf.argmax(y_pred, axis=-1)
            return super().update_state(y_true, y_pred, sample_weight)

    if model_type == "unet_aux":
        # Each output needs its own metric object (a shared instance would mix
        # both outputs' statistics). Keras prefixes per-output metric names
        # with the output name, so the logged key differs from the
        # single-output case.
        miou_key = "main_output_mean_iou"
        model.compile(
            optimizer=optimizer,
            loss={'main_output': total_loss_with_smoothing, 'aux_output': total_loss_with_smoothing},
            loss_weights={'main_output': 1.0, 'aux_output': 0.4},
            metrics={
                'main_output': MeanIoUMetric(num_classes=NUM_CLASSES),
                'aux_output': MeanIoUMetric(num_classes=NUM_CLASSES),
            }
        )
    else:
        miou_key = "mean_iou"
        model.compile(
            optimizer=optimizer,
            loss=total_loss_with_smoothing,
            metrics=[MeanIoUMetric(num_classes=NUM_CLASSES), 'categorical_accuracy', 'accuracy']
        )
    val_miou_key = f"val_{miou_key}"

    model.summary()

    # --- Callbacks ---
    early_stop = EarlyStopping(monitor=val_miou_key, patience=16, restore_best_weights=True, mode='max')

    os.makedirs("checkpoints", exist_ok=True)
    # mode='max' is required: with the default 'auto' mode a monitor name that
    # Keras does not recognise as an accuracy is minimised, which would keep
    # the worst-mIoU weights.
    checkpoint = ModelCheckpoint("checkpoints/best_model.h5", monitor=val_miou_key,
                                 save_best_only=True, mode='max')
    nan_terminate = TerminateOnNaN()

    class TimeLimitCallback(tf.keras.callbacks.Callback):
        """Stop training at the first epoch end after ``max_minutes`` elapsed."""

        def __init__(self, max_minutes=20):
            super().__init__()
            self.max_duration = max_minutes * 60

        def on_train_begin(self, logs=None):
            self.start_time = time.monotonic()

        def on_epoch_end(self, epoch, logs=None):
            elapsed = time.monotonic() - self.start_time
            if elapsed > self.max_duration:
                print(f"⏱️ Training time exceeded {self.max_duration // 60} minutes. Stopping early.")
                self.model.stop_training = True

    CLASS_NAMES = ['Building', 'Clutter', 'Vegetation', 'Water', 'Background', 'Car']

    class DistributionLogger(tf.keras.callbacks.Callback):
        """Print the class distribution of batches drawn from ``generator``.

        After every epoch up to ``max_batches`` batches are drawn from the
        generator (``None`` means no cap) and their per-class pixel counts are
        printed; the counts are also accumulated and reported when training
        ends.
        """

        def __init__(self, generator, name="Training", max_batches=16):
            super().__init__()
            self.generator = generator
            self.name = name
            self.max_batches = max_batches
            self.cumulative_class_counts = defaultdict(int)

        def on_epoch_end(self, epoch, logs=None):
            batch_class_counts = defaultdict(int)
            batches_seen = 0

            # islice stops after max_batches without fetching an extra batch
            # and treats None as "no limit".
            for _, batch_labels in itertools.islice(self.generator, self.max_batches):
                unique, counts = np.unique(np.argmax(batch_labels, axis=-1), return_counts=True)
                for u, c in zip(unique, counts):
                    batch_class_counts[u] += c
                    self.cumulative_class_counts[u] += c
                batches_seen += 1

            total_pixels = sum(batch_class_counts.values())
            total_images = batches_seen * self.generator.batch_size

            print(f"\n📊 {self.name} Distribution After Epoch {epoch + 1} ({total_images:,} images):")
            for cls in sorted(batch_class_counts):
                count = batch_class_counts[cls]
                percent = 100.0 * count / total_pixels
                print(f"  Class {cls} ({CLASS_NAMES[cls]}): {count:,} px ({percent:.2f}%)")

        def on_train_end(self, logs=None):
            total_pixels = sum(self.cumulative_class_counts.values())

            print("\n📊 Final Cumulative Training Class Distribution:")
            print(f"  Total pixels: {total_pixels:,} px")
            for cls in sorted(self.cumulative_class_counts):
                count = self.cumulative_class_counts[cls]
                percent = 100.0 * count / total_pixels
                print(f"  Class {cls} ({CLASS_NAMES[cls]}): {count:,} px ({percent:.2f}%)")

            plot_class_distribution(self.cumulative_class_counts)

    class LearningRateLogger(tf.keras.callbacks.Callback):
        """Print the optimizer's current learning rate after every epoch."""

        def on_epoch_end(self, epoch, logs=None):
            # Uses only public optimizer attributes; the private `_decayed_lr`
            # helper does not exist on newer Keras optimizers.
            opt = self.model.optimizer
            lr = opt.learning_rate
            if isinstance(lr, LearningRateSchedule):
                lr = lr(opt.iterations)
            lr = float(tf.convert_to_tensor(lr).numpy())
            print(f"Learning Rate at epoch {epoch + 1}: {lr:.6f}")

    time_limit = TimeLimitCallback(max_minutes=train_time)

    # `steps` defaults to None; fall back to the logger's own default cap so
    # the per-epoch distribution pass stays bounded.
    distribution_batches = steps if steps is not None else 16

    callbacks = [checkpoint, early_stop, nan_terminate, time_limit, ClearMemory(),
                 DistributionLogger(train_gen, name="Training", max_batches=distribution_batches),
                 LearningRateLogger()]

    print("🚀 Starting training...")
    history = model.fit(
        train_gen, validation_data=val_gen,
        epochs=epochs,
        callbacks=callbacks,
        verbose=verbose
    )

    # --- Evaluate Model ---

    # 📈 Plotting Training Curves for Mean IoU and val loss
    loss = history.history['loss']
    val_loss = history.history['val_loss']
    loss_epochs = range(1, len(loss) + 1)
    # Explicit figure per plot so the two charts never share axes, whatever
    # the matplotlib backend does on show().
    plt.figure()
    plt.plot(loss_epochs, loss, 'y', label="Training Loss")
    plt.plot(loss_epochs, val_loss, 'r', label="Validation Loss")
    plt.title("Training Vs Validation Loss")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig(os.path.join(out_dir, "loss_plot.png"))
    plt.show()
    plt.close()

    mean_iou = history.history[miou_key]
    val_mean_iou = history.history[val_miou_key]
    miou_epochs = range(1, len(mean_iou) + 1)

    plt.figure()
    plt.plot(miou_epochs, mean_iou, 'b', label="Training mIoU")
    plt.plot(miou_epochs, val_mean_iou, 'g', label="Validation mIoU")
    plt.title("Training vs Validation Mean IoU")
    plt.xlabel("Epochs")
    plt.ylabel("mIoU")
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(out_dir, "mIoU_plot.png"))
    plt.show()
    plt.close()

    # 📈 Plotting Random Outputs, Confusion Matrix, Classification Report and mIoU
    test_gen = StreamingDataGenerator(test_images, test_elev, test_labels,
                                      batch_size=batch_size, steps=steps, input_type=input_type,
                                      shuffle=False,
                                      fixed=True, augment=False)

    evaluate_on_test(model, test_gen, n_vis=30)

    # --- Archive (disabled experiments, kept for reference) ---
    #
    # Alternative learning-rate / stopping callbacks:
    #   reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=8, min_lr=1e-6, verbose=1)
    #   early_stop = EarlyStopping(monitor=total_loss, patience=16, restore_best_weights=True)
    #
    # Plain label-smoothed cross-entropy:
    #   loss_fn = CategoricalCrossentropy(label_smoothing=0.1)
    #
    # Dice + focal loss without label smoothing:
    #   dice_loss = sm.losses.DiceLoss(class_weights = weights)
    #   focal_loss = sm.losses.CategoricalFocalLoss()
    #   total_loss = dice_loss + (1 * focal_loss)
    #
    # Class-weighted cross-entropy:
    #   class_weights = compute_class_weights(train_gen)
    #   loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=False)
    #
    #   def weighted_loss(y_true, y_pred):
    #       weights = tf.reduce_sum(class_weights * y_true, axis=-1)
    #       return tf.reduce_mean(weights * loss_fn(y_true, y_pred))
    #
    #   model.compile(optimizer='adam', loss=weighted_loss, metrics=['accuracy'])
    #   print("✅ Model compiled with class weights")
    #
    # segmentation_models IoU score as the metric:
    #   from segmentation_models.metrics import iou_score as jaccard_coef
    #   metrics = ["accuracy", jaccard_coef]
    #   weights = [0.1666, 0.1666, 0.1666, 0.1666, 0.1666, 0.1666]
    #
    #   dice_loss = dice_loss(class_weights = weights)
    #   focal_loss = categorical_focal_loss()
    #   total_loss = dice_loss + (1 * focal_loss)
    #
    #   model.compile(optimizer="adam", loss="total_loss", metrics=metrics)
    #
    # Hand-written Jaccard coefficient:
    #   def jaccard_coef(y_true, y_pred):
    #       y_true_flatten = K.flatten(y_true)
    #       y_pred_flatten = K.flatten(y_pred)
    #       intersection = K.sum(y_true_flatten * y_pred_flatten)
    #       final_coef_value = (intersection + 1.0) / (K.sum(y_true_flatten) + K.sum(y_pred_flatten) - intersection + 1.0)
    #       return final_coef_value