# In [None]:
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, TerminateOnNaN
import cv2
import itertools
import gc
import tensorflow.keras.backend as K
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import LearningRateSchedule
from tensorflow.keras.losses import CategoricalCrossentropy
import segmentation_models as sm
from collections import defaultdict
from tensorflow.keras.metrics import MeanIoU
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras import mixed_precision
# Switch Keras to the 'mixed_float16' dtype policy globally. This runs at import
# time, so it is in effect before any model in this file is built.
mixed_precision.set_global_policy('mixed_float16')



# Model input variants: a human-readable description plus the channel count used
# for the network's input shape.
# NOTE(review): "rgb_elev" declares 5 channels — presumably RGB + elevation +
# slope; confirm against the dataset builder.
INPUT_TYPE_CONFIG = {
    "rgb": dict(description="RGB only", channels=3),
    "rgb_elev": dict(description="RGB + elevation", channels=5),
}

# Number of classes the models predict (ids 0-5); id 6 below is not one of them.
NUM_CLASSES = 6

# Label-image colour (3-tuple) -> class id.
COLOR_TO_CLASS = {
    (230, 25, 75): 0,
    (145, 30, 180): 1,
    (60, 180, 75): 2,
    (245, 130, 48): 3,
    (255, 255, 255): 4,
    (0, 130, 200): 5,
    (255, 0, 255): 6,  # Ignore pixel for visualisation
}

# Inverse mapping restricted to the predictable classes (the ignore id is left out).
CLASS_TO_COLOR = {
    class_id: colour
    for colour, class_id in COLOR_TO_CLASS.items()
    if class_id < NUM_CLASSES
}

# Every colour above as a uint8 array, one row per colour, in insertion order.
COLOR_PALETTE = np.array(list(COLOR_TO_CLASS), dtype=np.uint8)

# Colour -> class id table consulted when decoding label images.
COLOR_LOOKUP = dict(COLOR_TO_CLASS)

def decode_label_image(label_img, color_lookup=None):
    """Convert a colour-coded label image into a 2-D map of class ids.

    Args:
        label_img: Array of shape (height, width, channels) whose pixels are
            colours from the lookup table.
        color_lookup: Optional ``{colour tuple: class id}`` mapping. Defaults to
            the module-level ``COLOR_LOOKUP``.

    Returns:
        A ``(height, width)`` ``uint8`` array of class ids.

    Raises:
        ValueError: If any pixel's colour is missing from the lookup. The message
            names the first offending pixel in row-major order.
    """
    lookup = COLOR_LOOKUP if color_lookup is None else color_lookup
    label_img = np.asarray(label_img)
    h, w, channels = label_img.shape

    label_map = np.zeros((h, w), dtype=np.uint8)
    known = np.zeros((h, w), dtype=bool)

    # One vectorised comparison per colour instead of a Python loop per pixel.
    for colour, class_id in lookup.items():
        if len(colour) != channels:
            # A colour of a different length can never equal a pixel.
            continue
        hits = np.all(label_img == np.asarray(colour), axis=-1)
        label_map[hits] = class_id
        known |= hits

    if not known.all():
        # argwhere scans in row-major order, so this is the first unknown pixel.
        y, x = (int(v) for v in np.argwhere(~known)[0])
        pixel = tuple(label_img[y, x].tolist())
        raise ValueError(f"Unknown label colour {pixel} at ({y}, {x})")

    return label_map


def filter_tile_ids_by_substring(image_dir, base_names):
    """List tile ids in ``image_dir`` whose file name contains any of ``base_names``.

    The tile id is the file name with every ``'-ortho.png'`` occurrence removed.
    Results follow the order returned by ``os.listdir``.
    """
    tile_ids = []
    for filename in os.listdir(image_dir):
        for base in base_names:
            if base in filename:
                tile_ids.append(filename.replace('-ortho.png', ''))
                break  # one match is enough for this file
    return tile_ids



# --- Learning-rate schedule ---
class TransformerLRSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Linear warm-up followed by inverse-square-root decay.

    The learning rate at ``step`` is
    ``d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)``.

    Args:
        d_model: Scale factor; the rate is multiplied by ``d_model ** -0.5``.
        warmup_steps: Step at which the two branches of the ``min`` meet, i.e.
            where the rate stops growing and starts decaying.
    """

    def __init__(self, d_model, warmup_steps=4000):
        super().__init__()
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = tf.cast(warmup_steps, tf.float32)

    def __call__(self, step):
        """Return the learning rate for optimizer step ``step``."""
        step = tf.cast(step, tf.float32)
        # At step 0 the decay term is +inf and the warm-up term is 0, so the
        # minimum (and therefore the learning rate) is 0 rather than NaN.
        decay = tf.math.rsqrt(step)
        warmup = step * tf.pow(self.warmup_steps, -1.5)
        return tf.math.rsqrt(self.d_model) * tf.math.minimum(decay, warmup)

    def get_config(self):
        """Return the constructor arguments as plain Python floats."""
        # ``.numpy()`` alone yields numpy float32 scalars; wrap them in float()
        # so the config contains only built-in, JSON-serialisable types.
        return {
            "d_model": float(self.d_model.numpy()),
            "warmup_steps": float(self.warmup_steps.numpy()),
        }



# Label-smoothing helper applied to the targets before the losses are computed
def apply_label_smoothing(y_true, smoothing=0.1):
    """Blend targets with a uniform distribution over the last axis.

    Each value is scaled by ``1 - smoothing`` and then shifted by
    ``smoothing / C``, where ``C`` is the size of ``y_true``'s last axis.
    """
    class_count = tf.shape(y_true)[-1]
    uniform_share = smoothing / tf.cast(class_count, tf.float32)
    kept_share = 1.0 - smoothing
    return y_true * kept_share + uniform_share

# Make segmentation_models use tf.keras. The SM_FRAMEWORK variable is only read
# when the package is first imported, and that import has already happened at
# the top of this file, so the backend is also switched explicitly here.
os.environ["SM_FRAMEWORK"] = "tf.keras"
sm.set_framework("tf.keras")

# --- Jaccard Index ---
class MeanIoUMetric(tf.keras.metrics.MeanIoU):
    """Mean IoU for targets and predictions that carry a trailing class axis.

    ``tf.keras.metrics.MeanIoU`` is fed class ids here, so both inputs are
    reduced with ``argmax`` over the last axis before the confusion matrix is
    updated.
    """

    def __init__(self, num_classes, name="mean_iou", dtype=None, **kwargs):
        # Forward any extra keyword arguments: newer Keras versions include
        # additional keys in MeanIoU.get_config(), and from_config() passes
        # them all back to this constructor when a saved model is reloaded.
        super().__init__(num_classes=num_classes, name=name, dtype=dtype, **kwargs)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate one batch after collapsing the class axis to class ids."""
        y_true = tf.argmax(y_true, axis=-1)
        y_pred = tf.argmax(y_pred, axis=-1)
        return super().update_state(y_true, y_pred, sample_weight)


miou_metric = MeanIoUMetric(num_classes=NUM_CLASSES)




# --- Model Building ---
def train_unet(
        base_dir="/content/chipped_data/content/chipped_data", out_dir="/content/figs", 
        input_type="rgb_elev", model_type="ENHANCED_unet", tile_size=256,
        batch_size=8, epochs=50, train_time=20, verbose=1, yummy=False, model_path=None,
    ):
    """Train a U-Net style segmentation model and evaluate it on the test split.

    Args:
        base_dir: Dataset root; tiles are read from
            ``<base_dir>/train/{images,elevations,slopes,labels}``.
        out_dir: Directory handed to the plotting and evaluation helpers.
        input_type: Key of ``INPUT_TYPE_CONFIG``; selects the channel count.
        model_type: ``"unet"``, ``"new_unet"``, ``"multi_unet"``,
            ``"enhanced_unet"`` or ``"resnet34"``, matched case-insensitively.
            Ignored when ``model_path`` points at an existing file.
        tile_size: Height and width of the model input.
        batch_size: Batch size for all three datasets.
        epochs: Maximum number of epochs passed to ``model.fit``.
        train_time: Passed to ``TimeLimitCallback`` as ``max_minutes``.
        verbose: Keras verbosity for ``model.fit``.
        yummy: When true, also reconstruct and plot four hard-coded test scenes.
        model_path: Optional path of a saved model to continue training from.

    Raises:
        AssertionError: If ``input_type`` is not a key of ``INPUT_TYPE_CONFIG``.
        ValueError: If a new model must be built and ``model_type`` is unknown.
    """
    assert input_type in INPUT_TYPE_CONFIG, f"Unknown input type: {input_type}"
    num_channels = INPUT_TYPE_CONFIG[input_type]["channels"]

    # NOTE(review): the val and test datasets also read from the "train" folders —
    # presumably every tile lives there and the CSVs define the split; confirm.
    img_dir = os.path.join(base_dir, "train", "images")
    elev_dir = os.path.join(base_dir, "train", "elevations")
    slope_dir = os.path.join(base_dir, "train", "slopes")
    label_dir = os.path.join(base_dir, "train", "labels")

    # Load metadata and define input shape
    input_shape = (tile_size, tile_size, num_channels)
    train_df = csv_to_df('train', subset=0.5)
    val_df = csv_to_df('val')
    test_df = csv_to_df('test')

    # --- Streaming Data Generator ---
    def make_dataset(df, split, training):
        # Augmentation and shuffling are enabled for the training split only.
        return build_tf_dataset(df, img_dir, elev_dir, slope_dir, label_dir,
                                input_type=input_type, split=split,
                                augment=training, shuffle=training, batch_size=batch_size)

    train_gen = make_dataset(train_df, 'train', True)
    val_gen = make_dataset(val_df, 'val', False)
    test_gen = make_dataset(test_df, 'test', False)

    # Sanity check: print the label shape and the class ids found in one test batch.
    for _, y_batch in test_gen.take(1):
        y_np = np.argmax(y_batch.numpy(), axis=-1)
        print("Shape of y_batch:", y_batch.shape)
        print("Unique labels in y batch:", np.unique(y_np))

    # --- Model ---
    if model_path is None or not os.path.exists(model_path):
        # Case-insensitive match, so the default "ENHANCED_unet" selects the
        # "enhanced_unet" builder instead of raising.
        model_key = str(model_type).lower()

        if model_key == "unet":
            model = build_unet(input_shape=input_shape, num_classes=NUM_CLASSES)

        elif model_key == "new_unet":
            model, _ = build_flexible_unet(input_shape=input_shape, num_classes=NUM_CLASSES, freeze_rgb_encoder=False)

        elif model_key == "multi_unet":
            model = build_multi_unet(input_shape=input_shape, num_classes=NUM_CLASSES)

        elif model_key == "enhanced_unet":
            model = enhanced_unet(input_shape=input_shape, num_classes=NUM_CLASSES)

        elif model_key == "resnet34":
            model = sm.Unet(
                backbone_name="resnet34",               # or 'efficientnetb0', 'mobilenetv2', etc.
                input_shape=input_shape,
                classes=NUM_CLASSES,
                activation='softmax',
                encoder_weights='imagenet'              # Load ImageNet pre-trained weights
            )

        else:
            raise ValueError(f"Unknown model_type: {model_type}")

    else:
        custom_objects = {
            'DiceLoss': sm.losses.DiceLoss,
            'CategoricalFocalLoss': sm.losses.CategoricalFocalLoss,
            'MeanIoU': MeanIoUMetric
        }

        # compile=False: the model is recompiled below with a fresh optimizer,
        # loss and metrics, so the saved training configuration is not needed
        # and does not have to be deserialisable.
        model = tf.keras.models.load_model(
            model_path,
            custom_objects=custom_objects,
            compile=False
        )

        for layer in model.layers:
            layer.trainable = True

    model.summary()

    # --- Callbacks ---
    # NOTE(review): a DynamicClassWeightUpdater callback used to be defined here
    # but was never added to `callbacks`, and it could not have run (it called
    # `.assign` on a plain list and used an un-imported `f1_score`). It has been
    # removed; reintroduce it only together with a tf.Variable of class weights.
    monitor = "val_iou_score"
    callbacks = [
        ReduceLROnPlateau(monitor=monitor, mode="max", patience=6, min_lr=5e-7, factor=0.5, verbose=1, min_delta=1e-4),
        TimeLimitCallback(max_minutes=train_time),
        EarlyStopping(monitor=monitor, mode="max", patience=25, restore_best_weights=True, verbose=1),
        TerminateOnNaN(),
        StepTimer(),
    ]

    metrics = [
        sm.metrics.IOUScore(threshold=None, name="iou_score"),   # fast, approximated mIoU per batch
        sm.metrics.FScore(threshold=None, name="f1-score"),
        tf.keras.metrics.CategoricalAccuracy(name="categorical_accuracy"),
    ]

    optimizer = mixed_precision.LossScaleOptimizer(
        Adam(learning_rate=1e-4), dynamic=True
    )

    # --- Loss ---
    label_smoothing = 0.1
    loss_weights = [0.05, 1.8, 1.0]                     # cross-entropy, dice, focal
    class_weights = [1.05, 0.9, 1.0, 1.3, 0.0, 1.0]     # per-class Dice weights; class 4 is weighted 0

    raw_dice = sm.losses.DiceLoss(class_weights=class_weights)
    raw_focal = sm.losses.CategoricalFocalLoss()
    raw_cce = CategoricalCrossentropy()

    def total_loss(y_true, y_pred):
        """Weighted sum of cross-entropy, Dice and focal loss on smoothed targets."""
        y_true_smoothed = apply_label_smoothing(y_true, smoothing=label_smoothing)

        dice = raw_dice(y_true_smoothed, y_pred)
        focal = raw_focal(y_true_smoothed, y_pred)
        cce = raw_cce(y_true_smoothed, y_pred)

        # NOTE(review): this sum used to be passed through an ignore-class mask
        # helper, but that helper returned any loss with fewer than 4 dimensions
        # unchanged and the sum is such a value, so no pixel was ever masked.
        # The no-op wrapper is gone and the value is identical. Class 4 is
        # therefore only excluded through its zero Dice weight; masking it in
        # the cross-entropy and focal terms needs per-pixel losses.
        return (
            loss_weights[0] * cce +
            loss_weights[1] * dice +
            loss_weights[2] * focal
        )

    model.compile(
        optimizer=optimizer,
        loss=total_loss,
        metrics=metrics
    )

    history = model.fit(
        train_gen, validation_data=val_gen,
        epochs=epochs,
        callbacks=callbacks,
        verbose=verbose
    )

    # --- Evaluate Model ---
    plot_training_curves(history, out_dir)
    # Write the evaluation figures to the caller's out_dir, like the curves above.
    evaluate_on_test(model, test_gen, test_df, out_dir, img_dir, label_dir, tile_size, n_rows=12, n_cols=3)
    measure_inference_time(model, test_gen, num_batches=5)

    # --- Full-scene reconstructions ---
    if yummy:
        scene_ids = (
            "25f1c24f30_EB81FE6E2BOPENPIPELINE",
            "84410645db_8D20F02042OPENPIPELINE",
            "8710b98ea0_06E6522D6DINSPIRE",
            "a1af86939f_F1BE1D4184OPENPIPELINE",
        )
        for scene_id in scene_ids:
            img, label, pred = reconstruct_canvas(model, test_df, scene_id, build_tf_dataset, img_dir, elev_dir, slope_dir, label_dir)
            plot_reconstruction(img, label, pred, scene_id)







def train_segformer(
        base_dir="/content/chipped_data/content/chipped_data", out_dir="/content/figs", 
        input_type="rgb_elev", model_type="ENHANCED_unet", tile_size=256,
        batch_size=8, epochs=50, train_time=20, verbose=1, fine_tune=False, yummy=False, model_path=None,
    ):
    """Train a SegFormer segmentation model and evaluate it on the test split.

    Args:
        base_dir: Dataset root; tiles are read from
            ``<base_dir>/train/{images,elevations,slopes,labels}``.
        out_dir: Directory handed to the plotting and evaluation helpers.
        input_type: Key of ``INPUT_TYPE_CONFIG``; selects the channel count.
        model_type: SegFormer variant ``"B0"`` … ``"B5"``, matched
            case-insensitively. The default value is not a SegFormer variant,
            so it must be given explicitly unless ``model_path`` points at an
            existing file.
        tile_size: Height and width of the model input; also used as
            ``d_model`` of the learning-rate schedule.
        batch_size: Batch size for all datasets.
        epochs: Final epoch index passed to ``model.fit``. With ``fine_tune``
            the second phase starts at epoch 40, so it only trains when
            ``epochs`` is greater than 40.
        train_time: Passed to ``TimeLimitCallback`` as ``max_minutes``.
        verbose: Keras verbosity for ``model.fit``.
        fine_tune: When true, first train with the first 40 layers frozen,
            then unfreeze everything and continue.
        yummy: When true, also reconstruct and plot four hard-coded test scenes.
        model_path: Optional path of a saved model to continue training from.

    Raises:
        AssertionError: If ``input_type`` is not a key of ``INPUT_TYPE_CONFIG``.
        ValueError: If a new model must be built and ``model_type`` is unknown.
    """
    assert input_type in INPUT_TYPE_CONFIG, f"Unknown input type: {input_type}"
    num_channels = INPUT_TYPE_CONFIG[input_type]["channels"]

    # NOTE(review): the val and test datasets also read from the "train" folders —
    # presumably every tile lives there and the CSVs define the split; confirm.
    img_dir = os.path.join(base_dir, "train", "images")
    elev_dir = os.path.join(base_dir, "train", "elevations")
    slope_dir = os.path.join(base_dir, "train", "slopes")
    label_dir = os.path.join(base_dir, "train", "labels")

    # Load metadata and define input shape
    input_shape = (tile_size, tile_size, num_channels)
    train_df = csv_to_df('train')
    val_df = csv_to_df('val')
    test_df = csv_to_df('test')

    # --- Streaming Data Generator ---
    def make_dataset(df, split, training):
        # Augmentation and shuffling are enabled for training data only.
        return build_tf_dataset(df, img_dir, elev_dir, slope_dir, label_dir,
                                input_type=input_type, split=split,
                                augment=training, shuffle=training, batch_size=batch_size)

    train_gen = make_dataset(train_df, 'train', True)
    val_gen = make_dataset(val_df, 'val', False)
    test_gen = make_dataset(test_df, 'test', False)

    # Sanity check: print the class ids found in one test batch.
    for _, y_batch in test_gen.take(1):
        y_np = np.argmax(y_batch.numpy(), axis=-1)
        print("Unique labels in y batch:", np.unique(y_np))

    # --- Model ---
    if model_path is None or not os.path.exists(model_path):
        variant = str(model_type).upper()

        if variant == "B0":
            model = SegFormer_B0(input_shape=input_shape, num_classes=NUM_CLASSES)

        elif variant == "B1":
            model = SegFormer_B1(input_shape=input_shape, num_classes=NUM_CLASSES)

        elif variant == "B2":
            model = SegFormer_B2(input_shape=input_shape, num_classes=NUM_CLASSES)

        elif variant == "B3":
            model = SegFormer_B3(input_shape=input_shape, num_classes=NUM_CLASSES)

        elif variant == "B4":
            model = SegFormer_B4(input_shape=input_shape, num_classes=NUM_CLASSES)

        elif variant == "B5":
            model = SegFormer_B5(input_shape=input_shape, num_classes=NUM_CLASSES)

        else:
            raise ValueError(f"Unknown model_type: {model_type}")

    else:
        custom_objects = {
            'DiceLoss': sm.losses.DiceLoss,
            'CategoricalFocalLoss': sm.losses.CategoricalFocalLoss,
            'MeanIoU': MeanIoUMetric
        }
        # compile=False: every path below recompiles the model, so the saved
        # training configuration is not needed and does not have to be
        # deserialisable.
        model = tf.keras.models.load_model(
            model_path,
            custom_objects=custom_objects,
            compile=False
        )
        for layer in model.layers:
            layer.trainable = True

    model.summary()

    # The indices printed here are what the fine-tune freeze boundary refers to.
    for i, layer in enumerate(model.layers):
        print(i, layer.name)

    # --- Callbacks ---
    # val_loss must decrease to count as an improvement, hence mode="min".
    monitor = "val_loss"
    nan_terminate = TerminateOnNaN()
    time_limit = TimeLimitCallback(max_minutes=train_time)
    early_stop = EarlyStopping(monitor=monitor, mode="min", patience=20, restore_best_weights=True, verbose=1)
    # Only used in fine-tune phase 1: the other fits drive the learning rate
    # through `lr_schedule` instead.
    reduce_lr = ReduceLROnPlateau(monitor=monitor, mode="min", patience=6, min_lr=5e-7, factor=0.5, verbose=1, min_delta=1e-4)

    lr_schedule = TransformerLRSchedule(d_model=tile_size, warmup_steps=2048)
    callbacks = [
        time_limit,
        early_stop,
        nan_terminate,
        StepTimer(),
    ]

    metrics = [
        sm.metrics.IOUScore(threshold=None, name="iou_score"),   # fast, approximated mIoU per batch
        sm.metrics.FScore(threshold=None, name="f1-score"),
        tf.keras.metrics.CategoricalAccuracy(name="categorical_accuracy"),
    ]

    # --- Loss ---
    loss_weights = [0.25, 1.5, 1.0]                    # cross-entropy, dice, focal
    class_weights = [1.1, 1.0, 1.0, 2.0, 0.0, 1.3]     # per-class Dice weights; class 4 is weighted 0
    label_smoothing = 0.1

    raw_dice = sm.losses.DiceLoss(class_weights=class_weights)
    raw_focal = sm.losses.CategoricalFocalLoss()
    raw_cce = CategoricalCrossentropy()

    def total_loss_with_smoothing(y_true, y_pred):
        """Weighted sum of cross-entropy, Dice and focal loss on smoothed targets."""
        y_true_smoothed = apply_label_smoothing(y_true, smoothing=label_smoothing)

        dice = raw_dice(y_true_smoothed, y_pred)
        focal = raw_focal(y_true_smoothed, y_pred)
        cce = raw_cce(y_true_smoothed, y_pred)

        return loss_weights[0] * cce + loss_weights[1] * dice + loss_weights[2] * focal

    def compile_with_schedule():
        # Compile for full training with the warm-up/decay learning-rate schedule.
        model.compile(
            optimizer=Adam(learning_rate=lr_schedule, beta_1=0.9, beta_2=0.98, epsilon=1e-9),
            loss=total_loss_with_smoothing,
            metrics=metrics
        )

    # --- Compile Model and Train ---
    if fine_tune:
        stage1_epochs = 40
        frozen_layer_count = 40

        # Stage-1 subset is drawn from the TRAINING metadata. It was previously
        # built from test_df, which trained on test tiles before evaluating on
        # them.
        stage1_df = balanced_stage1_filter(train_df)
        stage1_gen = make_dataset(stage1_df, 'train', True)

        # Phase 1: freeze the first `frozen_layer_count` layers, train the rest.
        for index, layer in enumerate(model.layers):
            layer.trainable = index >= frozen_layer_count

        # Compile after changing layer.trainable so the change takes effect.
        model.compile(
            optimizer=Adam(learning_rate=1e-3, beta_1=0.9, beta_2=0.98, epsilon=1e-7),
            loss=total_loss_with_smoothing,
            metrics=metrics
        )

        model.fit(
            stage1_gen,
            validation_data=val_gen,
            epochs=stage1_epochs,
            callbacks=[reduce_lr, nan_terminate],
            verbose=verbose
        )

        # Phase 2: unfreeze everything and recompile.
        for layer in model.layers:
            layer.trainable = True

        compile_with_schedule()

        # Epoch numbering continues from phase 1, so this fit runs
        # `epochs - stage1_epochs` epochs (none when epochs <= stage1_epochs).
        history = model.fit(
            train_gen,
            validation_data=val_gen,
            initial_epoch=stage1_epochs,
            epochs=epochs,
            callbacks=callbacks,
            verbose=verbose
        )

    else:
        compile_with_schedule()

        history = model.fit(
            train_gen, validation_data=val_gen,
            epochs=epochs,
            callbacks=callbacks,
            verbose=verbose
        )

    # --- Evaluate Model ---
    plot_training_curves(history, out_dir)
    # Write the evaluation figures to the caller's out_dir, like the curves above.
    evaluate_on_test(model, test_gen, test_df, out_dir, img_dir, label_dir, tile_size, n_rows=12, n_cols=3)
    measure_inference_time(model, test_gen, num_batches=5)

    # --- Full-scene reconstructions ---
    if yummy:
        scene_ids = (
            "25f1c24f30_EB81FE6E2BOPENPIPELINE",
            "84410645db_8D20F02042OPENPIPELINE",
            "8710b98ea0_06E6522D6DINSPIRE",
            "a1af86939f_F1BE1D4184OPENPIPELINE",
        )
        for scene_id in scene_ids:
            img, label, pred = reconstruct_canvas(model, test_df, scene_id, build_tf_dataset, img_dir, elev_dir, slope_dir, label_dir)
            plot_reconstruction(img, label, pred, scene_id)



