# In [None]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, TerminateOnNaN
import cv2
import matplotlib.pyplot as plt
import itertools
import gc
import tensorflow.keras.backend as K
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import LearningRateSchedule
from tensorflow.keras.losses import CategoricalCrossentropy
import segmentation_models as sm
from collections import defaultdict




# --- Sanity ---
def test_training_sanity():
    """Print a marker line showing that this notebook's definitions were loaded."""
    marker = "✅ from training.ipynb"
    print(marker)


# Identifier substrings of the scenes held out for validation.  train_model
# passes this list to StreamingDataGenerator and to
# filter_tile_ids_by_substring, which keeps every file whose name contains one
# of these strings.
val_files = [
    "1476907971_CHADGRISMOPENPIPELINE",
    "dabec5e872_E8AD935CEDINSPIRE",
    "c6d131e346_536DE05ED2OPENPIPELINE",
    "57426ebe1e_84B52814D2OPENPIPELINE",
    "1726eb08ef_60693DB04DINSPIRE",
    "9170479165_625EDFBAB6OPENPIPELINE",
    "520947aa07_8FCB044F58OPENPIPELINE",
    "cc4b443c7d_A9CBEF2C97INSPIRE",
    "12fa5e614f_53197F206FOPENPIPELINE",
    "2ef3a4994a_0CCD105428INSPIRE",
]

# Identifier substrings of the scenes held out for the final test evaluation;
# used the same way as val_files.
test_files = [
    "1d4fbe33f3_F1BE1D4184INSPIRE",
    "f9f43e5144_1DB9E6F68BINSPIRE",
    "25f1c24f30_EB81FE6E2BOPENPIPELINE",
    "a1af86939f_F1BE1D4184OPENPIPELINE",
    "1553541487_APIGENERATED",
    "74d7796531_EB81FE6E2BOPENPIPELINE",
    "8710b98ea0_06E6522D6DINSPIRE",
    "c644f91210_27E21B7F30OPENPIPELINE",
    "d9161f7e18_C05BA1BC72OPENPIPELINE", 
]


# Supported network input layouts, keyed by the ``input_type`` argument of
# train_model.  "channels" is the size of the last axis of the model input.
INPUT_TYPE_CONFIG = dict([
    ("1ch", {"description": "grayscale only", "channels": 1}),
    ("2ch", {"description": "grayscale + elevation", "channels": 2}),
    ("rgb", {"description": "RGB only", "channels": 3}),
    ("rgb_elevation", {"description": "RGB + elevation", "channels": 4}),
])

# Label-image colours listed in class-id order: the position of a colour in
# this tuple is its class id.
_LABEL_COLOURS = (
    (230, 25, 75),
    (145, 30, 180),
    (60, 180, 75),
    (245, 130, 48),
    (255, 255, 255),
    (0, 130, 200),
)

# Colour triple -> class id.
COLOR_TO_CLASS = {colour: class_id for class_id, colour in enumerate(_LABEL_COLOURS)}

# Class id -> colour triple (inverse of COLOR_TO_CLASS).
CLASS_TO_COLOR = dict(zip(COLOR_TO_CLASS.values(), COLOR_TO_CLASS.keys()))
NUM_CLASSES = len(COLOR_TO_CLASS)
# One row per class, in class-id order.
COLOR_PALETTE = np.array(list(COLOR_TO_CLASS), dtype=np.uint8)
# Plain-tuple lookup table used by decode_label_image.
COLOR_LOOKUP = dict(COLOR_TO_CLASS)

def decode_label_image(label_img, color_lookup=None):
    """Convert a colour-coded label image into a 2-D map of class ids.

    Args:
        label_img: array of shape (height, width, channels) whose pixels are
            label colours.
        color_lookup: optional mapping from colour tuple to class id.  Defaults
            to the module-level ``COLOR_LOOKUP``.

    Returns:
        ``np.uint8`` array of shape (height, width) holding the class id of
        every pixel.

    Raises:
        ValueError: if any pixel colour is missing from the lookup table.  The
            message reports the first offending pixel in row-major order.
    """
    if color_lookup is None:
        color_lookup = COLOR_LOOKUP

    label_img = np.asarray(label_img)
    h, w, channels = label_img.shape
    label_map = np.zeros((h, w), dtype=np.uint8)
    recognised = np.zeros((h, w), dtype=bool)

    # One vectorised comparison per class instead of one dict lookup per pixel.
    for colour, class_id in color_lookup.items():
        if len(colour) != channels:
            # A colour with a different channel count can never match a pixel.
            continue
        hits = np.all(label_img == np.asarray(colour), axis=-1)
        label_map[hits] = class_id
        recognised |= hits

    if not recognised.all():
        # argwhere scans in row-major order, so this is the same pixel a
        # nested y/x loop would have reported first.
        y, x = (int(v) for v in np.argwhere(~recognised)[0])
        pixel = tuple(int(v) for v in label_img[y, x])
        raise ValueError(f"❌ Unknown label colour {pixel} at ({y}, {x})")

    return label_map


def filter_tile_ids_by_substring(image_dir, base_names):
    """Return the tile ids in *image_dir* whose filename contains a base name.

    Args:
        image_dir: directory whose entries are scanned.
        base_names: iterable of substrings; a file is kept when its name
            contains at least one of them.

    Returns:
        Sorted list of matching filenames with ``'-ortho.png'`` removed.
        ``os.listdir`` returns entries in arbitrary order, so the result is
        sorted to keep non-shuffled evaluation runs reproducible.
    """
    # Materialise once so a one-shot iterator is not exhausted by the first file.
    base_names = tuple(base_names)
    return sorted(
        filename.replace('-ortho.png', '')
        for filename in os.listdir(image_dir)
        if any(base in filename for base in base_names)
    )


# Batch size used by the (non-shuffled) validation and test generators.
_EVAL_BATCH_SIZE = 8

# Display names indexed by class id; used only in the distribution printouts.
_CLASS_NAMES = ['Building', 'Clutter', 'Vegetation', 'Water', 'Background', 'Car']


def _to_displayable(image):
    """Return a uint8 array for one input tile that ``plt.imshow`` can render.

    Tiles with three or more channels are rendered from their first three
    channels; 1- and 2-channel tiles are rendered from channel 0 only.
    """
    # assumes tiles are scaled to [0, 1] (hence the * 255) — TODO confirm
    # presumably the first three channels of 4-channel tiles are RGB; verify
    # against StreamingDataGenerator
    pixels = (np.asarray(image) * 255).astype(np.uint8)
    if pixels.ndim == 2:
        return pixels
    if pixels.shape[-1] >= 3:
        return pixels[..., :3]
    return pixels[..., 0]


def _label_to_rgb(label):
    """Colourise a 2-D class-id map using ``CLASS_TO_COLOR``."""
    rgb = np.zeros((*label.shape, 3), dtype=np.uint8)
    for class_id, colour in CLASS_TO_COLOR.items():
        rgb[label == class_id] = colour
    return rgb


def _build_generators(base_dir, input_type, batch_size, steps):
    """Create the train, validation and test ``StreamingDataGenerator`` objects.

    Training reads from ``<base_dir>/train``; validation and test both read
    from ``<base_dir>/raw`` and are restricted to the tiles whose filenames
    contain an entry of ``val_files`` / ``test_files``.
    """
    train_root = os.path.join(base_dir, "train")
    raw_root = os.path.join(base_dir, "raw")
    raw_images = os.path.join(raw_root, "images")
    raw_elev = os.path.join(raw_root, "elevations")
    raw_labels = os.path.join(raw_root, "labels")

    train_gen = StreamingDataGenerator(
        os.path.join(train_root, "images"),
        os.path.join(train_root, "elevations"),
        os.path.join(train_root, "labels"),
        split='train',
        val_files=val_files,
        test_files=test_files,
        batch_size=batch_size,
        input_type=input_type,
        shuffle=True,
        steps=steps,
        fixed=False,
        augment=True,
        # Derived from base_dir so a non-default data root stays consistent;
        # identical to the previously hard-coded path for the default base_dir.
        metadata_csv_path=os.path.join(base_dir, "train_metadata.csv"),
    )

    val_tile_ids = filter_tile_ids_by_substring(raw_images, val_files)
    test_tile_ids = filter_tile_ids_by_substring(raw_images, test_files)

    val_gen = StreamingDataGenerator(
        raw_images, raw_elev, raw_labels,
        split='val', val_files=val_tile_ids, test_files=test_files,
        batch_size=_EVAL_BATCH_SIZE,
        steps=len(val_tile_ids) // _EVAL_BATCH_SIZE + 1,
        input_type=input_type, shuffle=False, fixed=True, augment=False,
    )
    test_gen = StreamingDataGenerator(
        raw_images, raw_elev, raw_labels,
        split='test', val_files=val_files, test_files=test_tile_ids,
        batch_size=_EVAL_BATCH_SIZE,
        steps=len(test_tile_ids) // _EVAL_BATCH_SIZE + 1,
        input_type=input_type, shuffle=False, fixed=True, augment=False,
    )
    return train_gen, val_gen, test_gen


def _build_model(model_type, input_shape):
    """Instantiate the network selected by *model_type*.

    Raises:
        ValueError: if *model_type* is not one of the supported names.
    """
    # An if-chain (not a dispatch dict) so only the selected builder name has
    # to exist at call time.
    if model_type == "unet":
        print("🧪 Calling build_unet...")
        return build_unet(input_shape=input_shape, num_classes=NUM_CLASSES)
    if model_type == "multi_unet":
        print("🧪 Calling build_multi_unet...")
        return build_multi_unet(input_shape=input_shape, num_classes=NUM_CLASSES)
    if model_type == "unet_aux":
        print("🧪 Calling build_unet_aux...")
        return build_unet_aux(input_shape=input_shape, num_classes=NUM_CLASSES)
    if model_type == "segformer":
        return build_segformer(input_shape=input_shape, num_classes=NUM_CLASSES)
    if model_type == "CRF":
        return build_crf(input_shape=input_shape, num_classes=NUM_CLASSES)
    if model_type == "resnet34":
        # NOTE(review): ImageNet encoder weights normally expect 3 input
        # channels — confirm this works for the 1/2/4-channel input types.
        return sm.Unet(
            backbone_name="resnet34",          # or 'efficientnetb0', 'mobilenetv2', etc.
            input_shape=input_shape,
            classes=NUM_CLASSES,
            activation='softmax',
            encoder_weights='imagenet',        # Load ImageNet pre-trained weights
        )
    if model_type == "model1":
        return build_model_1(input_shape=input_shape, num_classes=NUM_CLASSES)
    if model_type == "model2":
        return build_model_2(input_shape=input_shape, num_classes=NUM_CLASSES)
    raise ValueError(f"Unknown model_type: {model_type}")


def _build_optimizer(d_model):
    """Return Adam driven by a warm-up / inverse-square-root LR schedule."""

    class TransformerLRSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
        """lr(step) = d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)."""

        def __init__(self, d_model, warmup_steps=4000):
            super().__init__()
            # Plain Python floats keep get_config() serialisable.
            self.d_model = float(d_model)
            self.warmup_steps = float(warmup_steps)

        def __call__(self, step):
            step = tf.cast(step, tf.float32)
            decay = tf.math.rsqrt(step)  # inf at step 0; the minimum picks warm-up
            warmup = step * self.warmup_steps ** -1.5
            return self.d_model ** -0.5 * tf.math.minimum(decay, warmup)

        def get_config(self):
            return {"d_model": self.d_model, "warmup_steps": self.warmup_steps}

    lr_schedule = TransformerLRSchedule(d_model=d_model)
    return Adam(learning_rate=lr_schedule, beta_1=0.9, beta_2=0.98, epsilon=1e-9)


def _build_loss(smoothing=0.1):
    """Return Dice + categorical focal loss evaluated on label-smoothed targets."""
    # Uniform class weights.
    class_weights = [0.1666] * NUM_CLASSES
    raw_dice = sm.losses.DiceLoss(class_weights=class_weights)
    raw_focal = sm.losses.CategoricalFocalLoss()

    def total_loss_with_smoothing(y_true, y_pred):
        # Cast first so integer one-hot labels can be mixed with float constants.
        y_true = tf.cast(y_true, y_pred.dtype)
        num_classes = tf.cast(tf.shape(y_true)[-1], y_true.dtype)
        y_true_smoothed = y_true * (1.0 - smoothing) + (smoothing / num_classes)
        return raw_dice(y_true_smoothed, y_pred) + raw_focal(y_true_smoothed, y_pred)

    return total_loss_with_smoothing


def _build_metrics(model_type):
    """Return the ``metrics`` argument for ``model.compile``."""

    class MaskedMeanIoU(tf.keras.metrics.Metric):
        """Mean IoU from an accumulated confusion matrix of arg-max class ids.

        Pixels whose true class id equals ``ignore_class`` (default:
        ``num_classes``) are excluded from the confusion matrix.
        """

        def __init__(self, num_classes, name="masked_mean_iou", ignore_class=None, **kwargs):
            super().__init__(name=name, **kwargs)
            self.num_classes = num_classes
            self.ignore_class = num_classes if ignore_class is None else ignore_class
            self.total_cm = self.add_weight(
                name="total_confusion_matrix",
                shape=(num_classes, num_classes),
                initializer="zeros",
                dtype=tf.float32,
            )

        def update_state(self, y_true, y_pred, sample_weight=None):
            # sample_weight is accepted for API compatibility but not used.
            true_ids = tf.argmax(y_true, axis=-1)
            pred_ids = tf.argmax(y_pred, axis=-1)

            keep = tf.not_equal(true_ids, self.ignore_class)
            true_ids = tf.boolean_mask(true_ids, keep)
            pred_ids = tf.boolean_mask(pred_ids, keep)

            batch_cm = tf.math.confusion_matrix(
                true_ids, pred_ids, num_classes=self.num_classes, dtype=tf.float32
            )
            self.total_cm.assign_add(batch_cm)

        def result(self):
            predicted_per_class = tf.reduce_sum(self.total_cm, axis=0)
            actual_per_class = tf.reduce_sum(self.total_cm, axis=1)
            true_positives = tf.linalg.diag_part(self.total_cm)
            union = predicted_per_class + actual_per_class - true_positives
            # Classes that never occur contribute an IoU of 0 to the mean.
            return tf.reduce_mean(tf.math.divide_no_nan(true_positives, union))

        def reset_state(self):
            self.total_cm.assign(
                tf.zeros((self.num_classes, self.num_classes), dtype=tf.float32)
            )

        def reset_states(self):
            # Older Keras releases call this name instead of reset_state().
            self.reset_state()

        def get_config(self):
            config = super().get_config()
            config.update(num_classes=self.num_classes, ignore_class=self.ignore_class)
            return config

    if model_type == "unet_aux":
        # One metric object per output: a shared instance would accumulate both
        # heads into a single confusion matrix.
        return {
            'main_output': MaskedMeanIoU(num_classes=NUM_CLASSES),
            'aux_output': MaskedMeanIoU(num_classes=NUM_CLASSES),
        }
    return [
        MaskedMeanIoU(num_classes=NUM_CLASSES),
        sm.metrics.IOUScore(threshold=None),
        sm.metrics.FScore(threshold=None),
        'categorical_accuracy',
        'accuracy',
    ]


def _build_callbacks(train_gen, val_gen, model_type, steps, train_time):
    """Create the callback list passed to ``model.fit``.

    Also creates the local and Google Drive checkpoint directories.
    """
    import random
    import time

    # The metric is named "masked_mean_iou", so that is the key Keras logs.
    # NOTE(review): for the two-output model Keras is expected to prefix the
    # key with the output name — confirm against the training logs.
    if model_type == "unet_aux":
        monitor = 'val_main_output_masked_mean_iou'
    else:
        monitor = 'val_masked_mean_iou'

    class ClearMemory(tf.keras.callbacks.Callback):
        """Run the garbage collector once per epoch."""

        # K.clear_session() is deliberately not called: it resets Keras global
        # state and is not meant to run while a model is being trained.
        def on_epoch_end(self, epoch, logs=None):
            gc.collect()

    class LearningRateLogger(tf.keras.callbacks.Callback):
        """Print the optimizer's current learning rate after every epoch."""

        def on_epoch_end(self, epoch, logs=None):
            optimizer = self.model.optimizer
            lr = optimizer.learning_rate
            # Some Keras versions hand back the schedule object itself.
            if isinstance(lr, LearningRateSchedule):
                lr = lr(optimizer.iterations)
            lr = float(np.asarray(lr))
            print(f"📉 Learning Rate at epoch {epoch + 1}: {lr:.6f}")

    class TimeLimitCallback(tf.keras.callbacks.Callback):
        """Stop training at the end of the first epoch past the time budget."""

        def __init__(self, max_minutes=20):
            super().__init__()
            self.max_duration = max_minutes * 60
            self.start_time = None

        def on_train_begin(self, logs=None):
            self.start_time = time.monotonic()

        def on_epoch_end(self, epoch, logs=None):
            if time.monotonic() - self.start_time > self.max_duration:
                print(f"⏱️ Training time exceeded {self.max_duration // 60} minutes. Stopping early.")
                self.model.stop_training = True

    class DistributionLogger(tf.keras.callbacks.Callback):
        """Print the label class distribution of sampled batches each epoch."""

        def __init__(self, generator, name="Training", max_batches=16, visualise_samples=2):
            super().__init__()
            self.generator = generator
            self.name = name
            # None (e.g. steps was not given) falls back to the default cap.
            self.max_batches = 16 if max_batches is None else max_batches
            self.visualise_samples = visualise_samples
            self.cumulative_class_counts = defaultdict(int)

        def on_epoch_end(self, epoch, logs=None):
            epoch_class_counts = defaultdict(int)
            all_samples = []
            batches_seen = 0

            # islice stops before fetching a batch beyond the cap.
            for batch_images, batch_labels in itertools.islice(self.generator, self.max_batches):
                class_maps = np.argmax(batch_labels, axis=-1)
                class_ids, pixel_counts = np.unique(class_maps, return_counts=True)
                for class_id, count in zip(class_ids, pixel_counts):
                    epoch_class_counts[int(class_id)] += int(count)
                    self.cumulative_class_counts[int(class_id)] += int(count)
                all_samples.extend(zip(batch_images, class_maps))
                batches_seen += 1

            total_pixels = sum(epoch_class_counts.values())
            total_images = batches_seen * self.generator.batch_size

            print(f"📊 {self.name} Distribution After Epoch {epoch + 1} ({total_images:,} images):")
            for cls in sorted(epoch_class_counts):
                count = epoch_class_counts[cls]
                percent = 100.0 * count / total_pixels
                print(f"  Class {cls} ({_CLASS_NAMES[cls]}): {count:,} px ({percent:.2f}%)")

            self._plot_random_samples(all_samples, epoch)

        def _plot_random_samples(self, all_samples, epoch):
            if self.visualise_samples <= 0 or len(all_samples) < self.visualise_samples:
                return

            samples_to_show = random.sample(all_samples, self.visualise_samples)
            # squeeze=False keeps axs 2-D even when a single row is requested.
            fig, axs = plt.subplots(
                self.visualise_samples, 2,
                figsize=(8, 4 * self.visualise_samples), squeeze=False,
            )

            for (image, label), (image_ax, label_ax) in zip(samples_to_show, axs):
                image_ax.imshow(_to_displayable(image), cmap='gray')
                image_ax.set_title("Input Image")
                image_ax.axis('off')

                label_ax.imshow(_label_to_rgb(label))
                label_ax.set_title("Label")
                label_ax.axis('off')

            plt.tight_layout()
            plt.suptitle(f"🔍 Sample Batches from Epoch {epoch + 1}", y=1.02)
            plt.show()

        def on_train_end(self, logs=None):
            total_pixels = sum(self.cumulative_class_counts.values())
            print("📊 Final Cumulative Training Class Distribution:")
            print(f"Total pixels: {total_pixels:,} px")
            for cls in sorted(self.cumulative_class_counts):
                count = self.cumulative_class_counts[cls]
                percent = 100.0 * count / total_pixels
                print(f"  Class {cls} ({_CLASS_NAMES[cls]}): {count:,} px ({percent:.2f}%)")

            plot_class_distribution(self.cumulative_class_counts)

    class ValidationPredictionLogger(tf.keras.callbacks.Callback):
        """Plot input / ground truth / prediction for validation batches."""

        def __init__(self, val_gen, model=None, max_batches=1):
            super().__init__()
            self.val_gen = val_gen
            # ``model`` is accepted for backward compatibility only: Keras
            # attaches the model itself via set_model(), and assigning
            # self.model fails on versions where it is a read-only property.
            self.max_batches = max_batches

        def on_epoch_end(self, epoch, logs=None):
            for batch_images, batch_labels in itertools.islice(self.val_gen, self.max_batches):
                n_rows = len(batch_images)
                if n_rows == 0:
                    continue
                pred_maps = np.argmax(self.model.predict(batch_images), axis=-1)
                true_maps = np.argmax(batch_labels, axis=-1)

                fig, axs = plt.subplots(n_rows, 3, figsize=(10, 3 * n_rows), squeeze=False)
                for row in range(n_rows):
                    input_ax, truth_ax, pred_ax = axs[row]

                    input_ax.imshow(_to_displayable(batch_images[row]), cmap='gray')
                    input_ax.set_title("Input")
                    input_ax.axis('off')

                    truth_ax.imshow(_label_to_rgb(true_maps[row]))
                    truth_ax.set_title("Ground Truth")
                    truth_ax.axis('off')

                    pred_ax.imshow(_label_to_rgb(pred_maps[row]))
                    pred_ax.set_title("Prediction")
                    pred_ax.axis('off')

                plt.tight_layout()
                plt.suptitle(f"🔍 Validation Predictions After Epoch {epoch + 1}", y=1.02)
                plt.show()

    class StepTimer(tf.keras.callbacks.Callback):
        """Report the average wall-clock time of a training step."""

        def on_train_begin(self, logs=None):
            self.total_time = 0.0
            self.total_steps = 0

        def on_train_batch_begin(self, batch, logs=None):
            self.start_time = time.perf_counter()

        def on_train_batch_end(self, batch, logs=None):
            self.total_time += time.perf_counter() - self.start_time
            self.total_steps += 1

        def on_train_end(self, logs=None):
            if self.total_steps == 0:
                return  # nothing was trained; avoid dividing by zero
            avg_step_time = self.total_time / self.total_steps
            print(f"🕒 Average training step time: {avg_step_time:.4f} sec")

    # Create both output folders
    os.makedirs("checkpoints", exist_ok=True)
    os.makedirs("/content/drive/MyDrive/segmentation_checkpoints", exist_ok=True)

    # mode='max' is explicit: IoU must be maximised, and automatic mode
    # detection is not guaranteed to recognise this metric name.
    checkpoint_local = ModelCheckpoint(
        "checkpoints/best_model.h5",
        monitor=monitor,
        mode='max',
        save_best_only=True,
    )
    checkpoint_drive = ModelCheckpoint(
        "/content/drive/MyDrive/segmentation_checkpoints/best_model.h5",
        monitor=monitor,
        mode='max',
        save_best_only=True,
    )
    early_stop = EarlyStopping(monitor=monitor, patience=12, restore_best_weights=True, mode='max')

    return [
        checkpoint_local, checkpoint_drive,
        early_stop, TerminateOnNaN(), TimeLimitCallback(max_minutes=train_time),
        ClearMemory(),
        LearningRateLogger(),
        StepTimer(),
        DistributionLogger(train_gen, name="Training", max_batches=steps, visualise_samples=4),
        ValidationPredictionLogger(val_gen, max_batches=1),
    ]


def _plot_loss_curves(history, out_dir):
    """Plot training vs validation loss and save it as ``loss_plot.png``."""
    loss = history.history['loss']
    val_loss = history.history['val_loss']
    epoch_axis = range(1, len(loss) + 1)

    # savefig does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)

    plt.figure()
    plt.plot(epoch_axis, loss, 'y', label="Training Loss")
    plt.plot(epoch_axis, val_loss, 'r', label="Validation Loss")
    plt.title("Training Vs Validation Loss")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig(os.path.join(out_dir, "loss_plot.png"))
    plt.show()


def _measure_inference_time(model, generator, num_batches=5):
    """Time ``model.predict`` on up to *num_batches* batches and print a summary.

    Returns:
        ``(total_seconds, total_images)``.
    """
    import time

    if num_batches is None:
        num_batches = 5

    total_time = 0.0
    total_images = 0
    for x_batch, _ in itertools.islice(generator, num_batches):
        start = time.perf_counter()
        model.predict(x_batch, verbose=0)
        total_time += time.perf_counter() - start
        total_images += x_batch.shape[0]

    if total_images == 0:
        print("⚠️ No batches available for inference timing.")
        return total_time, total_images

    print(f"🧠 Inference time: {total_time:.2f} sec for {total_images} images")
    print(f"⏱️ Avg inference time per image: {total_time / total_images:.4f} sec")
    return total_time, total_images


def train_model(base_dir="/content/chipped_data/content/chipped_data", out_dir="/content/figs", 
                input_type="rgb_elevation", model_type="unet", tile_size=256,
                batch_size=8, steps=None, epochs=10, train_time=20, verbose=1
                ):
    """Build, train and evaluate a segmentation model.

    Args:
        base_dir: data root containing ``train/`` and ``raw/`` (each with
            ``images``, ``elevations`` and ``labels``) and ``train_metadata.csv``.
        out_dir: directory where ``loss_plot.png`` is written (created if missing).
        input_type: key of ``INPUT_TYPE_CONFIG``; selects the input channel count.
        model_type: one of ``unet``, ``multi_unet``, ``unet_aux``, ``segformer``,
            ``CRF``, ``resnet34``, ``model1``, ``model2``.
        tile_size: used only as ``d_model`` of the learning-rate schedule; the
            network itself is built with a size-agnostic ``(None, None, C)`` input.
        batch_size: training batch size.
        steps: forwarded to the training generator; also caps the number of
            batches used for distribution logging and inference timing
            (defaults are used when ``None``).
        epochs: maximum number of training epochs.
        train_time: wall-clock training budget in minutes.
        verbose: verbosity passed to ``model.fit``.

    Returns:
        None.

    Raises:
        AssertionError: if *input_type* is unknown.
        ValueError: if *model_type* is unknown.
    """
    assert input_type in INPUT_TYPE_CONFIG, f"Unknown input type: {input_type}"
    num_channels = INPUT_TYPE_CONFIG[input_type]["channels"]

    # NOTE: only affects segmentation_models imports that happen after this
    # line; it cannot change a module that is already imported.
    os.environ["SM_FRAMEWORK"] = "tf.keras"

    # --- Streaming Data Generators ---
    train_gen, val_gen, test_gen = _build_generators(base_dir, input_type, batch_size, steps)

    # --- Model ---
    model = _build_model(model_type, input_shape=(None, None, num_channels))

    # --- Compile Model ---
    optimizer = _build_optimizer(d_model=tile_size)
    loss_fn = _build_loss(smoothing=0.1)
    metrics = _build_metrics(model_type)

    if model_type == "unet_aux":
        model.compile(
            optimizer=optimizer,
            loss={'main_output': loss_fn, 'aux_output': loss_fn},
            loss_weights={'main_output': 1.0, 'aux_output': 0.4},
            metrics=metrics,
        )
    else:
        model.compile(optimizer=optimizer, loss=loss_fn, metrics=metrics)

    model.summary()

    # --- Callbacks ---
    callbacks = _build_callbacks(train_gen, val_gen, model_type, steps, train_time)

    # --- Train ---
    print("🚀 Starting training...")
    history = model.fit(
        train_gen, validation_data=val_gen,
        epochs=epochs,
        callbacks=callbacks,
        verbose=verbose
    )

    # 📈 Plotting Training Curves for training and validation loss
    _plot_loss_curves(history, out_dir)

    # 📈 Plotting Random Outputs, Confusion Matrix, Classification Report and mIoU
    evaluate_on_test(model, test_gen, n_vis=9)

    _measure_inference_time(model, test_gen, num_batches=steps)
    print("🚀 Training complete!")