In [None]:
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import pandas as pd
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, TerminateOnNaN
import cv2
import itertools
import gc
import tensorflow.keras.backend as K
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import LearningRateSchedule
from tensorflow.keras.losses import CategoricalCrossentropy
import segmentation_models as sm
from collections import defaultdict
from tensorflow.keras.metrics import MeanIoU
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras import backend as K
# Reset Keras' global state (models/layers left over from earlier runs in this
# notebook session) before anything new is constructed below.
K.clear_session()




class TransformerLRSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Custom learning rate schedule based on the Transformer paper.

    The schedule increases the learning rate linearly for the first `warmup_steps`,
    and then decreases it proportionally to the inverse square root of the step number:

        lr(step) = d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

    This is commonly used in training Transformer models.

    Attributes:
        d_model (tf.Tensor): The dimensionality of the model (float32 scalar).
        warmup_steps (tf.Tensor): Number of steps to linearly increase the learning rate
            (float32 scalar).
    """

    def __init__(self, d_model: int, warmup_steps: int = 4000):
        """Initialises the TransformerLRSchedule.

        Args:
            d_model (int): The model dimensionality (e.g., 512).
            warmup_steps (int): Number of warm-up steps. Default is 4000.
        """
        super().__init__()
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = tf.cast(warmup_steps, tf.float32)

    def __call__(self, step: tf.Tensor) -> tf.Tensor:
        """Computes the learning rate at a given training step.

        Args:
            step (tf.Tensor): The current training step (the optimizer's iteration count).

        Returns:
            tf.Tensor: The calculated learning rate for this step.
        """
        # The formula's step number is 1-based, while the optimizer's iteration counter
        # starts at 0; at step 0 the formula evaluates to min(inf, 0) = 0, i.e. the
        # first update would use a zero learning rate. Clamp to 1 so that update uses
        # the step-1 rate; every step >= 1 is unaffected.
        step = tf.maximum(tf.cast(step, tf.float32), 1.0)

        # Inverse square root decay and linear warmup terms.
        decay = tf.math.rsqrt(step)
        warmup = step * tf.pow(self.warmup_steps, -1.5)

        # The smaller of the two terms wins: warmup early on, decay afterwards.
        return tf.math.rsqrt(self.d_model) * tf.math.minimum(decay, warmup)

    def get_config(self) -> dict:
        """Returns the config of the learning rate schedule for serialization.

        Returns:
            dict: Configuration dictionary with d_model and warmup_steps as plain
            Python floats (JSON-serialisable, and accepted back by ``__init__``).
        """
        return {
            "d_model": float(self.d_model.numpy()),
            "warmup_steps": float(self.warmup_steps.numpy()),
        }



# Channel count and human-readable description for each supported model input.
# NOTE(review): "rgb_elev" has 5 channels although the description only mentions
# elevation; the dataset builders also receive a slope directory, so presumably
# the two extra channels are elevation + slope — confirm and update the text.
INPUT_TYPE_CONFIG = {
    "rgb": {"description": "RGB only", "channels": 3},
    "rgb_elev": {"description": "RGB + elevation", "channels": 5},
}

# Label-image colour (R, G, B) -> class id. Class 6 (magenta) is the "ignore" colour.
COLOR_TO_CLASS = {
    (230, 25, 75): 0,
    (145, 30, 180): 1,
    (60, 180, 75): 2,
    (245, 130, 48): 3,
    (255, 255, 255): 4,
    (0, 130, 200): 5,
    (255, 0, 255): 6,  # Ignore pixel for visualisation
}

# Number of trainable classes; every id >= NUM_CLASSES (i.e. the ignore id) is excluded.
NUM_CLASSES = 6

# class id -> colour for the trainable classes only (used when rendering masks).
CLASS_TO_COLOR = {
    class_id: colour
    for colour, class_id in COLOR_TO_CLASS.items()
    if class_id < NUM_CLASSES
}

# All colours (ignore colour last) as a (7, 3) uint8 array.
COLOR_PALETTE = np.array([*COLOR_TO_CLASS], dtype=np.uint8)

# Colour tuple -> class id lookup used to decode label images.
COLOR_LOOKUP = dict(COLOR_TO_CLASS)

def decode_label_image(label_img: np.ndarray, color_lookup=None) -> np.ndarray:
    """Converts a colour-coded label image into a class ID map.

    Args:
        label_img (np.ndarray): A (H, W, C) label image where each unique colour
            represents a class. With the default lookup, C must be 3 (RGB) for any
            pixel to match.
        color_lookup (dict, optional): Mapping from colour tuple to class ID.
            Defaults to the module-level COLOR_LOOKUP.

    Returns:
        np.ndarray: A (H, W) uint8 array of class IDs corresponding to each pixel.

    Raises:
        ValueError: If an unknown colour is encountered in the label image. The first
            offending pixel in row-major order is reported.
    """
    if color_lookup is None:
        color_lookup = COLOR_LOOKUP

    h, w, n_channels = label_img.shape
    label_map = np.zeros((h, w), dtype=np.uint8)
    matched = np.zeros((h, w), dtype=bool)

    # One vectorised comparison per palette colour instead of a Python loop per pixel.
    for colour, class_id in color_lookup.items():
        if len(colour) != n_channels:
            continue  # a colour of a different length can never equal a pixel
        hits = np.all(label_img == np.asarray(colour), axis=-1)
        label_map[hits] = class_id
        matched |= hits

    if not matched.all():
        # np.argwhere returns indices in row-major order, so this is the first
        # unknown pixel a top-to-bottom, left-to-right scan would meet.
        y, x = (int(v) for v in np.argwhere(~matched)[0])
        pixel = tuple(label_img[y, x].tolist())
        raise ValueError(f"Unknown label colour {pixel} at ({y}, {x})")

    return label_map



def plot_augmented_grid_from_dataset(
    tf_dataset: tf.data.Dataset,
    input_type: str,  # 'rgb' or 'rgb_elev'
    n_rows: int = 3,
    n_cols: int = 4,
    title: str = "Augmented Training Chips"
):
    """Plot RGB chips next to their colour-coded label masks from one dataset batch.

    Layout matches the style of visualise_prediction_grid, but without predictions:
    every grid cell uses two axes, the RGB chip followed by its label mask coloured
    with CLASS_TO_COLOR. Pixels whose one-hot label is all zeros are drawn magenta.

    Args:
        tf_dataset: Dataset yielding ``(images, one_hot_labels)`` batches, with images
            indexed as (batch, H, W, channels).
        input_type: 'rgb' or 'rgb_elev'. Only the first three image channels are
            plotted in either case, so a mismatch with the dataset no longer breaks
            the plot.
        n_rows: Number of grid rows.
        n_cols: Number of RGB + label pairs per row.
        title: Currently unused (the suptitle call is disabled).

    Plotting is best-effort: any error is printed instead of raised so it cannot
    interrupt a training run.
    """
    print(f"Fetching one batch for {n_rows * n_cols} RGB + label pairs...")

    fig = None
    try:
        batch = next(iter(tf_dataset.take(1)), None)

        if batch is None:
            print("Warning: Dataset is empty. Skipping plot.")
            return

        images, labels_one_hot = batch
        # The first three channels are RGB; extra channels (elevation etc.) are dropped
        # so imshow always gets an (H, W, 3) image regardless of input_type.
        rgb_images = np.clip(images[..., :3].numpy(), 0.0, 1.0)

        label_ids = tf.argmax(labels_one_hot, axis=-1).numpy()
        ignore_mask = tf.reduce_all(labels_one_hot == 0, axis=-1).numpy()

        total = min(n_rows * n_cols, rgb_images.shape[0])

        # squeeze=False keeps `axs` 2-D even for a single row, so axs[row, col] is valid.
        fig, axs = plt.subplots(
            n_rows, n_cols * 2, figsize=(n_cols * 5.6, n_rows * 2.6), squeeze=False
        )

        for i in range(total):
            mask = label_ids[i]

            # Create RGB mask image from class IDs
            label_rgb = np.zeros((*mask.shape, 3), dtype=np.uint8)
            for class_id, color in CLASS_TO_COLOR.items():
                label_rgb[mask == class_id] = color
            label_rgb[ignore_mask[i]] = (255, 0, 255)  # Magenta for ignored pixels

            row, cell = divmod(i, n_cols)
            col = cell * 2

            axs[row, col].imshow(rgb_images[i])
            axs[row, col].set_title("RGB")
            axs[row, col].axis("off")

            axs[row, col + 1].imshow(label_rgb)
            axs[row, col + 1].set_title("Label")
            axs[row, col + 1].axis("off")

        # Hide any unused axes
        for j in range(total, n_rows * n_cols):
            row, cell = divmod(j, n_cols)
            col = cell * 2
            axs[row, col].axis("off")
            axs[row, col + 1].axis("off")

        plt.tight_layout()
        plt.show()

    except Exception as e:
        print(f"Error during plotting: {e}")

    finally:
        # Release the figure on success and on error alike.
        if fig is not None:
            plt.close(fig)



# --- Metrics ---
# Loss function with label smoothing
def apply_label_smoothing(y_true, smoothing=0.1):
    """Apply label smoothing to one-hot targets without giving ignored pixels a target.

    A one-hot row becomes ``y * (1 - smoothing) + smoothing / num_classes`` (the
    standard formulation). Rows that are entirely zero (the dataset appears to encode
    ignored pixels this way — confirm against build_tf_dataset) stay zero, so they do
    not acquire a uniform target that the loss would then fit.

    Args:
        y_true: Float tensor of targets whose last axis is the class axis.
        smoothing: Fraction of probability mass spread uniformly over the classes.

    Returns:
        Tensor with the same shape and dtype as ``y_true``.
    """
    num_classes = tf.cast(tf.shape(y_true)[-1], y_true.dtype)
    # Scale the uniform term by each row's mass: 1 for a one-hot row (identical to the
    # classic formula) and 0 for an all-zero row (left untouched).
    row_mass = tf.reduce_sum(y_true, axis=-1, keepdims=True)
    return y_true * (1.0 - smoothing) + row_mass * (smoothing / num_classes)

# segmentation_models reads SM_FRAMEWORK only when it is first imported, and that import
# happens at the top of this file before this line runs. Keep the variable for any later
# re-import, and also switch the already-imported module to tf.keras explicitly.
os.environ["SM_FRAMEWORK"] = "tf.keras"
sm.set_framework("tf.keras")

# --- Jaccard Index ---
class MeanIoUMetric(tf.keras.metrics.MeanIoU):
    """Mean IoU computed from one-hot targets and per-class scores.

    ``tf.keras.metrics.MeanIoU`` accumulates integer class ids, so both inputs are
    reduced with ``argmax`` over the last axis before being passed on.
    """

    def __init__(self, num_classes, name="mean_iou", dtype=None, **kwargs):
        # Forward any extra config keys: depending on the TF/Keras version,
        # MeanIoU.get_config() may also emit e.g. ignore_class / sparse_y_true /
        # sparse_y_pred, and from_config() passes them all back to __init__ when a
        # saved model using this metric is reloaded.
        super().__init__(num_classes=num_classes, name=name, dtype=dtype, **kwargs)

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.argmax(y_true, axis=-1)
        y_pred = tf.argmax(y_pred, axis=-1)
        return super().update_state(y_true, y_pred, sample_weight)


miou_metric = MeanIoUMetric(num_classes=NUM_CLASSES)




# --- Model Building ---
def train_unet(
        base_dir="/content/chipped_data/content/chipped_data", out_dir="/content/figs", 
        input_type="rgb_elev", model_type="enhanced_unet", tile_size=256,
        batch_size=8, epochs=50, train_time=20, verbose=1, yummy=False, model_path=None,
    ):
    """Train (or resume) a U-Net-style segmentation model, then save, report and evaluate it.

    Builds streaming train/val/test datasets, constructs the model selected by
    ``model_type`` (or reloads ``model_path`` when that file exists), trains it with a
    weighted CCE + Dice + focal loss, saves it to ``<out_dir>/segmentation_model.keras``
    and runs the evaluation helpers.

    Args:
        base_dir: Root of the chipped dataset. The same ``<base_dir>/train/{images,
            elevations,slopes,labels}`` directories are passed to build_tf_dataset for
            every split (presumably the split CSVs select the tiles — confirm).
        out_dir: Directory for the saved model, training curves and test figures.
        input_type: Key of ``INPUT_TYPE_CONFIG`` ('rgb' or 'rgb_elev').
        model_type: 'unet', 'new_unet', 'multi_unet', 'enhanced_unet' or 'resnet34'.
        tile_size: Height/width of the square input chips.
        batch_size: Batch size for all three datasets.
        epochs: Maximum number of training epochs.
        train_time: Training budget in minutes (passed to TimeLimitCallback).
        verbose: Verbosity passed to ``model.fit``.
        yummy: If True, reconstruct and plot the full scenes listed in ``test_files``.
        model_path: Optional saved model to resume from; ignored if the path does not exist.

    Raises:
        AssertionError: If ``input_type`` is not in ``INPUT_TYPE_CONFIG``.
        ValueError: If ``model_type`` is unknown and no model is loaded.
    """
    import time
    from datetime import datetime

    import segmentation_models as sm
    # LossScaleOptimizer is used below; `mixed_precision` is not among the visible
    # module-level imports, so import it here rather than rely on another cell.
    from tensorflow.keras import mixed_precision

    assert input_type in INPUT_TYPE_CONFIG, f"Unknown input type: {input_type}"
    num_channels = INPUT_TYPE_CONFIG[input_type]["channels"]

    img_dir = os.path.join(base_dir, "train", "images")
    elev_dir = os.path.join(base_dir, "train", "elevations")
    slope_dir = os.path.join(base_dir, "train", "slopes")
    label_dir = os.path.join(base_dir, "train", "labels")

    # Load metadata and define input shape
    input_shape = (tile_size, tile_size, num_channels)
    train_df = csv_to_df('train', subset=1.0)
    val_df = csv_to_df('val')
    test_df = csv_to_df('test')

    # --- Streaming Data Generator ---
    train_gen = build_tf_dataset(train_df, img_dir, elev_dir, slope_dir, label_dir,
                                 input_type=input_type, split='train',
                                 augment=True, shuffle=True, batch_size=batch_size)

    val_gen = build_tf_dataset(val_df, img_dir, elev_dir, slope_dir, label_dir,
                               input_type=input_type, split='val',
                               augment=False, shuffle=False, batch_size=batch_size)

    test_gen = build_tf_dataset(test_df, img_dir, elev_dir, slope_dir, label_dir,
                                input_type=input_type, split='test',
                                augment=False, shuffle=False, batch_size=batch_size)

    # Preview augmented training examples. Pass the dataset's real input_type; a
    # hard-coded 'rgb' made the preview fail on 5-channel 'rgb_elev' batches.
    plot_augmented_grid_from_dataset(
        tf_dataset=train_gen,
        input_type=input_type,
    )

    # --- Model ---
    if model_path is None or not os.path.exists(model_path):
        if model_type == "unet":
            model = build_unet(input_shape=input_shape, num_classes=NUM_CLASSES)

        elif model_type == "new_unet":
            model, _ = build_flexible_unet(input_shape=input_shape, num_classes=NUM_CLASSES, freeze_rgb_encoder=False)

        elif model_type == "multi_unet":
            model = build_multi_unet(input_shape=input_shape, num_classes=NUM_CLASSES)

        elif model_type == "enhanced_unet":
            model = enhanced_unet(input_shape=input_shape, num_classes=NUM_CLASSES)

        elif model_type == "resnet34":
            model = sm.Unet(
                backbone_name="resnet34",               # or 'efficientnetb0', 'mobilenetv2', etc.
                input_shape=input_shape,
                classes=NUM_CLASSES,
                activation='softmax',
                encoder_weights='imagenet'              # Load ImageNet pre-trained weights
            )

        else:
            raise ValueError(f"Unknown model_type: {model_type}")
    else:
        custom_objects = {
            'DiceLoss': sm.losses.DiceLoss,
            'CategoricalFocalLoss': sm.losses.CategoricalFocalLoss,
            'MeanIoU': MeanIoUMetric
        }

        # compile=False: the model is recompiled below with a fresh optimizer and loss,
        # and a model saved by this function was compiled with the local `total_loss`
        # closure, which cannot be deserialised from custom_objects.
        model = tf.keras.models.load_model(
            model_path,
            custom_objects=custom_objects,
            compile=False
        )

        for layer in model.layers:
            layer.trainable = True

    model.summary()
    print(f"Number of Parameters: {model.count_params()}\n"
          f"Number of Layers: {len(model.layers)}\n")

    '''
    for i, layer in enumerate(model.layers):
        print(i, layer.name)
    '''

    # --- Callbacks ---
    monitor = "val_iou_score"       # "val_loss"
    nan_terminate = TerminateOnNaN()
    time_limit = TimeLimitCallback(max_minutes=train_time)
    early_stop = EarlyStopping(monitor=monitor, mode="max", patience=10, restore_best_weights=True, verbose=1)
    reduce_lr = ReduceLROnPlateau(monitor=monitor, mode="max", patience=5, min_lr=5e-7, factor=0.5, verbose=1, min_delta=1e-4)
    # NOTE(review): constructed but not registered in `callbacks` below.
    weight_callback = DynamicClassWeightUpdater(val_data=val_gen, update_every=5, target='iou', ignore_class=4)
    #LearningRateLogger()

    callbacks = [
        reduce_lr,
        time_limit,
        early_stop,
        nan_terminate,
        StepTimer(),
    ]

    metrics = [
        sm.metrics.IOUScore(threshold=None, name="iou_score"),   # fast, approximated mIoU per batch
        sm.metrics.FScore(threshold=None, name="f1-score"),
        tf.keras.metrics.CategoricalAccuracy(name="categorical_accuracy"),
    ]

    learning_rate = 5.2e-4
    label_smoothing = 0.01
    loss_weights = [0.25, 1.5, 2.2]                    # [cce, dice, focal]
    class_weights = [4.95, 3.2, 1.0, 11.6, 3.0, 2.6]   # [building, clutter, vegetation, water, background, car]

    # Normalise class weights so they sum to 1 (used as Dice weights and focal alpha).
    total = sum(class_weights)
    norm_class_weights = [w / total for w in class_weights]

    #focal_alpha = 0.25
    focal_gamma = 3.5
    raw_dice = sm.losses.DiceLoss(class_weights=norm_class_weights)
    raw_focal = sm.losses.CategoricalFocalLoss(
        alpha=norm_class_weights,
        gamma=focal_gamma,
    )
    raw_cce = CategoricalCrossentropy()

    optimizer = mixed_precision.LossScaleOptimizer(
        Adam(learning_rate=learning_rate), dynamic=True
    )

    def total_loss(y_true, y_pred):
        """Weighted sum of CCE, Dice and focal loss on label-smoothed targets."""
        y_true_smoothed = apply_label_smoothing(y_true, smoothing=label_smoothing)

        dice = raw_dice(y_true_smoothed, y_pred)
        focal = raw_focal(y_true_smoothed, y_pred)
        cce = raw_cce(y_true_smoothed, y_pred)

        base_loss = (
            loss_weights[0] * cce +
            loss_weights[1] * dice +
            loss_weights[2] * focal
        )

        #return apply_ignore_class_mask(y_true_smoothed, y_pred, ignore_class=4, loss_fn=lambda yt, yp: base_loss)
        return base_loss

    # --- Train Model ---
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    print("🕒 Start Time:", timestamp)

    # ⏱️ Start training timer
    start_time = time.time()

    # Single Stage Training
    model.compile(
        optimizer=optimizer,
        loss=total_loss,
        metrics=metrics
    )

    history = model.fit(
        train_gen, validation_data=val_gen,
        epochs=epochs,
        callbacks=callbacks,
        verbose=verbose
    )

    # --- Save Model ---
    os.makedirs(out_dir, exist_ok=True)
    model.save(os.path.join(out_dir, "segmentation_model.keras"))

    # ⏱️ End training timer
    end_time = time.time()
    duration_sec = end_time - start_time
    duration_str = time.strftime('%H:%M:%S', time.gmtime(duration_sec))
    print(f"\n✅ Training complete in {duration_str} ({duration_sec:.2f} seconds)")

    '''
    # Two Stage Training
    if fine_tune:
        hard_df = csv_to_hard_df()
        train_hard = build_tf_dataset(hard_df, img_dir, elev_dir, slope_dir, label_dir,
                                    input_type=input_type, split='train',
                                    augment=True, shuffle=True, batch_size=batch_size)


        print("Training on hard examples...")
        optimizer = mixed_precision.LossScaleOptimizer(
            Adam(learning_rate=1e-5), dynamic=True
        )

        model.compile(
            optimizer=optimizer,
            loss=total_loss,
            metrics=metrics
        )

        history = model.fit(
            train_hard, validation_data=val_gen,
            epochs=history.epoch[-1] + 16,
            initial_epoch=history.epoch[-1] + 1,
            callbacks=callbacks,
            verbose=verbose
        )

    '''

    def _last_metric(key):
        """Final-epoch value of ``key`` from the fit history, to 4 d.p., or "N/A"."""
        values = history.history.get(key)
        return f"{values[-1]:.4f}" if values else "N/A"

    # --- Evaluate Model ---
    # The run settings are locals of this function, so report them directly (looking
    # them up in globals() reported every one of them as undefined).
    print(f"Initial Learning Rate: {learning_rate}\n"
          f"Loss Weights: {loss_weights}, Class Weights: {class_weights}\n"
          f"Focal Loss Gamma: {focal_gamma}\n"
          f"Label Smoothing: {label_smoothing}\n"
          f"Input Type: {input_type}, Model Type: {model_type}\n"
          f"Batch Size: {batch_size}, "
          f"Epochs: {len(history.epoch)}\n"
          f"Number of Parameters: {model.count_params()}, "
          f"Number of Layers: {len(model.layers)}\n"
          f"Final Validation Loss: {_last_metric('val_loss')}\n"
          f"Final Validation mIoU: {_last_metric('val_iou_score')}\n"
          f"Final Validation F1 Score: {_last_metric('val_f1-score')}\n"
          f"Final Validation Categorical Accuracy: {_last_metric('val_categorical_accuracy')}\n")

    measure_inference_time(model, test_gen, num_batches=5)
    plot_training_curves(history, out_dir)
    evaluate_on_test(model, test_gen, test_df, out_dir, img_dir, label_dir, tile_size, n_rows=4, n_cols=3)

    # Scene prefixes reconstructed into full canvases when `yummy` is set.
    test_files = [
        "25f1c24f30_EB81FE6E2BOPENPIPELINE",
        "1d4fbe33f3_F1BE1D4184INSPIRE",
        "15efe45820_D95DF0B1F4INSPIRE",
        "c6d131e346_536DE05ED2OPENPIPELINE",
        "12fa5e614f_53197F206FOPENPIPELINE",
        "5fa39d6378_DB9FF730D9OPENPIPELINE",
        "ebffe540d0_7BA042D858OPENPIPELINE",
        "8710b98ea0_06E6522D6DINSPIRE",
        "84410645db_8D20F02042OPENPIPELINE",
        "a1af86939f_F1BE1D4184OPENPIPELINE"
    ]

    # --- Full-scene reconstruction ---
    if yummy:
        for tile_prefix in test_files:
            img, label, pred = reconstruct_canvas(
                model,
                test_df,
                tile_prefix,
                build_tf_dataset,
                img_dir,
                elev_dir,
                slope_dir,
                label_dir
            )
            plot_reconstruction(img, label, pred, tile_prefix)





def train_segformer(
        base_dir="/content/chipped_data/content/chipped_data", out_dir="/content/figs", 
        input_type="rgb_elev", model_type="ENHANCED_unet", tile_size=256,
        batch_size=8, epochs=50, train_time=20, verbose=1, yummy=False, model_path=None,
    ):
    """Train (or resume) a SegFormer model with a Transformer-style LR schedule and evaluate it.

    Args:
        base_dir: Root of the chipped dataset. The same ``<base_dir>/train/{images,
            elevations,slopes,labels}`` directories are passed to build_tf_dataset for
            every split (presumably the split CSVs select the tiles — confirm).
        out_dir: Directory for training curves and test figures.
        input_type: Key of ``INPUT_TYPE_CONFIG`` ('rgb' or 'rgb_elev').
        model_type: SegFormer variant: "B0", "B1", "B2", "B3", "B4" or "B5".
            NOTE(review): the default "ENHANCED_unet" matches none of these, so it
            raises ValueError unless a model is loaded from ``model_path``.
        tile_size: Height/width of the square input chips; also used as ``d_model``
            for the learning-rate schedule.
        batch_size: Batch size for all three datasets.
        epochs: Maximum number of training epochs.
        train_time: Training budget in minutes (passed to TimeLimitCallback).
        verbose: Verbosity passed to ``model.fit``.
        yummy: If True, reconstruct and plot a fixed list of full test scenes.
        model_path: Optional saved model to resume from; ignored if the path does not exist.

    Raises:
        AssertionError: If ``input_type`` is not in ``INPUT_TYPE_CONFIG``.
        ValueError: If ``model_type`` is unknown and no model is loaded.
    """
    import segmentation_models as sm

    assert input_type in INPUT_TYPE_CONFIG, f"Unknown input type: {input_type}"
    num_channels = INPUT_TYPE_CONFIG[input_type]["channels"]

    img_dir = os.path.join(base_dir, "train", "images")
    elev_dir = os.path.join(base_dir, "train", "elevations")
    slope_dir = os.path.join(base_dir, "train", "slopes")
    label_dir = os.path.join(base_dir, "train", "labels")

    # Load metadata and define input shape
    input_shape = (tile_size, tile_size, num_channels)
    train_df = csv_to_df('train', 0.4)
    val_df = csv_to_df('val')
    test_df = csv_to_df('test')

    # --- Streaming Data Generator ---
    train_gen = build_tf_dataset(train_df, img_dir, elev_dir, slope_dir, label_dir,
                                 input_type=input_type, split='train',
                                 augment=True, shuffle=True, batch_size=batch_size)

    val_gen = build_tf_dataset(val_df, img_dir, elev_dir, slope_dir, label_dir,
                               input_type=input_type, split='val',
                               augment=False, shuffle=False, batch_size=batch_size)

    test_gen = build_tf_dataset(test_df, img_dir, elev_dir, slope_dir, label_dir,
                                input_type=input_type, split='test',
                                augment=False, shuffle=False, batch_size=batch_size)

    # Sanity check: which class ids occur in the first test batch.
    for x_batch, y_batch in test_gen.take(1):
        y_np = np.argmax(y_batch.numpy(), axis=-1)
        print("Unique labels in y batch:", np.unique(y_np))

    # --- Model ---
    if model_path is None or not os.path.exists(model_path):

        if model_type == "B2":
            model = SegFormer_B2(input_shape=input_shape, num_classes=NUM_CLASSES)

        elif model_type == "B0":
            model = SegFormer_B0(input_shape=input_shape, num_classes=NUM_CLASSES)

        elif model_type == "B5":
            model = SegFormer_B5(input_shape=input_shape, num_classes=NUM_CLASSES)

        elif model_type == "B4":
            model = SegFormer_B4(input_shape=input_shape, num_classes=NUM_CLASSES)

        elif model_type == "B1":
            model = SegFormer_B1(input_shape=input_shape, num_classes=NUM_CLASSES)

        elif model_type == "B3":
            model = SegFormer_B3(input_shape=input_shape, num_classes=NUM_CLASSES)

        else:
            raise ValueError(f"Unknown model_type: {model_type}")

    else:
        custom_objects = {
            'DiceLoss': sm.losses.DiceLoss,
            'CategoricalFocalLoss': sm.losses.CategoricalFocalLoss,
            'MeanIoU': MeanIoUMetric
        }
        # compile=False: the model is recompiled below anyway, and a custom closure
        # loss it may have been compiled with cannot be deserialised from custom_objects.
        model = tf.keras.models.load_model(
            model_path,
            custom_objects=custom_objects,
            compile=False
        )
        for layer in model.layers:
            layer.trainable = True

    model.summary()

    for i, layer in enumerate(model.layers):
        print(i, layer.name)

    # --- Callbacks ---
    #LearningRateLogger()
    #monitor = "val_iou_score"
    monitor = "val_loss"
    # A loss must be minimised, a score maximised. With mode="max" on val_loss,
    # EarlyStopping tracked the wrong direction and restore_best_weights restored
    # the epoch with the *highest* validation loss.
    monitor_mode = "min" if monitor.endswith("loss") else "max"
    nan_terminate = TerminateOnNaN()
    time_limit = TimeLimitCallback(max_minutes=train_time)
    early_stop = EarlyStopping(monitor=monitor, mode=monitor_mode, patience=20, restore_best_weights=True, verbose=1)

    # No ReduceLROnPlateau here: the optimizer's learning rate comes from this schedule.
    lr_schedule = TransformerLRSchedule(d_model=tile_size, warmup_steps=2048)
    callbacks = [
        time_limit,
        early_stop,
        nan_terminate,
        StepTimer(),
    ]

    metrics = [
        sm.metrics.IOUScore(threshold=None, name="iou_score"),   # fast, approximated mIoU per batch
        sm.metrics.FScore(threshold=None, name="f1-score"),
        tf.keras.metrics.CategoricalAccuracy(name="categorical_accuracy"),
    ]

    loss_weights = [0.25, 1.5, 1.0]                 # [cce, dice, focal]
    class_weights = [1.1, 1.0, 1.0, 2.0, 0.0, 1.3]  # [building, clutter, vegetation, water, background, car]
    label_smoothing = 0.1

    raw_dice = sm.losses.DiceLoss(class_weights=class_weights)
    raw_focal = sm.losses.CategoricalFocalLoss()
    raw_cce = CategoricalCrossentropy()

    def total_loss_with_smoothing(y_true, y_pred):
        """Weighted sum of CCE, Dice and focal loss on label-smoothed targets."""
        y_true_smoothed = apply_label_smoothing(y_true, smoothing=label_smoothing)

        dice = raw_dice(y_true_smoothed, y_pred)
        focal = raw_focal(y_true_smoothed, y_pred)
        cce = raw_cce(y_true_smoothed, y_pred)

        return loss_weights[0] * cce + loss_weights[1] * dice + loss_weights[2] * focal

    model.compile(
        optimizer=Adam(learning_rate=lr_schedule, beta_1=0.9, beta_2=0.98, epsilon=1e-9),
        loss=total_loss_with_smoothing,
        metrics=metrics
    )

    history = model.fit(
        train_gen, validation_data=val_gen,
        epochs=epochs,
        callbacks=callbacks,
        verbose=verbose
    )

    # --- Evaluate Model ---
    plot_training_curves(history, out_dir)
    evaluate_on_test(model, test_gen, test_df, out_dir, img_dir, label_dir, tile_size, n_rows=4, n_cols=3)
    measure_inference_time(model, test_gen, num_batches=5)

    # --- Full-scene reconstruction ---
    if yummy:
        # Defined locally: this function previously referenced a `test_files` list
        # that only exists inside train_unet, raising NameError here.
        test_files = [
            "25f1c24f30_EB81FE6E2BOPENPIPELINE",
            "1d4fbe33f3_F1BE1D4184INSPIRE",
            "15efe45820_D95DF0B1F4INSPIRE",
            "c6d131e346_536DE05ED2OPENPIPELINE",
            "12fa5e614f_53197F206FOPENPIPELINE",
            "5fa39d6378_DB9FF730D9OPENPIPELINE",
            "ebffe540d0_7BA042D858OPENPIPELINE",
            "8710b98ea0_06E6522D6DINSPIRE",
            "84410645db_8D20F02042OPENPIPELINE",
            "a1af86939f_F1BE1D4184OPENPIPELINE"
        ]
        for tile_prefix in test_files:
            img, label, pred = reconstruct_canvas(
                model,
                test_df,
                tile_prefix,
                build_tf_dataset,
                img_dir,
                elev_dir,
                slope_dir,
                label_dir
            )
            plot_reconstruction(img, label, pred, tile_prefix)


