# In [None]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, TerminateOnNaN
import cv2
import matplotlib.pyplot as plt
import itertools
import gc
import tensorflow.keras.backend as K
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import LearningRateSchedule
from tensorflow.keras.losses import CategoricalCrossentropy
import segmentation_models as sm
from collections import defaultdict



# Supported input types and the number of channels each one provides.
INPUT_TYPE_CONFIG = {
    "rgb": {"description": "RGB only", "channels": 3},
    "rgb_elev": {"description": "RGB + elevation", "channels": 4},
}

# Number of class indices kept in CLASS_TO_COLOR (indices 0..5).
NUM_CLASSES = 6

# Label-image colour triplet -> class index.
COLOR_TO_CLASS = {
    (230, 25, 75): 0,
    (145, 30, 180): 1,
    (60, 180, 75): 2,
    (245, 130, 48): 3,
    (255, 255, 255): 4,
    (0, 130, 200): 5,
    (255, 0, 255): 6,  # Ignore pixel for visualisation
}

# Reverse mapping, limited to indices below NUM_CLASSES (drops the ignore entry).
CLASS_TO_COLOR = {
    class_id: colour
    for colour, class_id in COLOR_TO_CLASS.items()
    if class_id < NUM_CLASSES
}

# All colours (ignore colour included) as a (7, 3) uint8 array, in dict order.
COLOR_PALETTE = np.array(list(COLOR_TO_CLASS), dtype=np.uint8)

# Tuple-keyed lookup used when decoding label images.
COLOR_LOOKUP = dict(COLOR_TO_CLASS)

def decode_label_image(label_img):
    """Convert a colour-coded label image into a 2-D map of class indices.

    Args:
        label_img: Array of shape (height, width, channels) whose pixels are
            colours listed as keys of ``COLOR_LOOKUP``.

    Returns:
        A ``np.uint8`` array of shape (height, width) holding the class index
        of every pixel.

    Raises:
        ValueError: If any pixel's colour is not a key of ``COLOR_LOOKUP``.
            The message names the first such pixel in row-major order.
    """
    h, w, channels = label_img.shape
    label_map = np.zeros((h, w), dtype=np.uint8)
    known = np.zeros((h, w), dtype=bool)

    # One vectorised comparison per palette colour instead of a Python-level
    # loop over every pixel.
    for colour, class_id in COLOR_LOOKUP.items():
        if len(colour) != channels:
            # A colour with a different channel count can never match.
            continue
        mask = np.all(label_img == np.asarray(colour), axis=-1)
        label_map[mask] = class_id
        known |= mask

    if not known.all():
        # argwhere scans in row-major order, i.e. the same pixel the nested
        # y/x loop would have reached first.
        y, x = (int(v) for v in np.argwhere(~known)[0])
        # .tolist() yields plain Python numbers so the message stays readable
        # regardless of how NumPy prints its scalar types.
        pixel = tuple(label_img[y, x].tolist())
        raise ValueError(f"❌ Unknown label colour {pixel} at ({y}, {x})")
    return label_map


def filter_tile_ids_by_substring(image_dir, base_names):
    """List entries of ``image_dir`` whose name contains any of ``base_names``.

    Each matching name is returned with every occurrence of ``'-ortho.png'``
    removed; names that match but do not contain that suffix are returned
    unchanged. Order follows ``os.listdir``.
    """
    tile_ids = []
    for filename in os.listdir(image_dir):
        for base in base_names:
            if base in filename:
                tile_ids.append(filename.replace('-ortho.png', ''))
                break  # one match is enough; avoid adding the file twice
    return tile_ids



def measure_inference_time(model, generator, num_batches=5):
    """Time ``model.predict`` over batches from ``generator`` and print a summary.

    Args:
        model: Object exposing ``predict(x_batch, verbose=0)``.
        generator: Batched dataset yielding ``(inputs, targets)`` pairs. It
            must provide ``take(count)`` unless ``num_batches`` is ``None``.
        num_batches: Number of batches to time. ``None`` times every batch
            the generator yields, so do not pass ``None`` together with an
            endlessly repeating dataset.

    Returns:
        None. Total and per-image timings are printed; when no images were
        processed the per-image line is replaced by a warning.
    """
    import time

    batches = generator if num_batches is None else generator.take(num_batches)

    total_time = 0.0
    total_images = 0
    for x_batch, _ in batches:
        # perf_counter is monotonic, so clock adjustments cannot skew timings.
        start = time.perf_counter()
        model.predict(x_batch, verbose=0)
        total_time += time.perf_counter() - start
        total_images += x_batch.shape[0]

    print(f"🧠 Inference time: {total_time:.2f} sec for {total_images} images")
    if total_images == 0:
        # Guard against dividing by zero when the generator produced nothing.
        print("⚠️ No images were processed — average inference time unavailable.")
        return
    print(f"⏱️ Avg inference time per image: {total_time / total_images:.4f} sec")




import matplotlib.pyplot as plt
import os

def plot_training_curves(history, out_dir):
    """Save one train-vs-validation line plot per metric found in ``history``.

    Args:
        history: Object with a ``history`` mapping from metric name to a
            per-epoch sequence of values (e.g. the return value of
            ``model.fit``).
        out_dir: Directory that receives ``<metric>_plot.png`` files; it is
            created, including parents, when missing.

    A metric pair is skipped with a printed warning when either its training
    or its validation series is absent from ``history.history``.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(out_dir, exist_ok=True)

    # Define metric pairs to plot
    metric_pairs = [
        ("loss", "val_loss"),
        ("f1-score", "val_f1-score"),
        ("iou_score", "val_iou_score"),
        ("categorical_accuracy", "val_categorical_accuracy"),
    ]

    for train_metric, val_metric in metric_pairs:
        if train_metric not in history.history or val_metric not in history.history:
            print(f"⚠️ Skipping {train_metric} — missing in history.")
            continue

        fig = plt.figure()
        try:
            plt.plot(history.history[train_metric], label=f"Train {train_metric}")
            plt.plot(history.history[val_metric], label=f"Val {val_metric}")
            plt.title(f"{train_metric} over Epochs")
            plt.xlabel("Epochs")
            plt.ylabel(train_metric)
            plt.legend()
            filename = os.path.join(out_dir, f"{train_metric}_plot.png")
            plt.savefig(filename)
        finally:
            # Always release the figure, even if plotting or saving failed.
            plt.close(fig)
        print(f"✅ Saved {filename}")



# --- Learning-rate schedule and loss ---
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, TerminateOnNaN

# Custom learning rate schedule
class TransformerLRSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Learning-rate schedule with linear warm-up then inverse-sqrt decay.

    The rate at a given step is::

        d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

    so it grows linearly until ``warmup_steps`` and decays afterwards.
    At ``step == 0`` the second term is zero, so the rate is zero.

    Args:
        d_model: Scale factor; the rate is proportional to ``d_model ** -0.5``.
        warmup_steps: Step at which warm-up ends and decay begins.
    """

    def __init__(self, d_model, warmup_steps=4000):
        super().__init__()
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = tf.cast(warmup_steps, tf.float32)

    def __call__(self, step):
        """Return the learning rate for optimizer step ``step``."""
        step = tf.cast(step, tf.float32)
        arg1 = tf.math.rsqrt(step)
        arg2 = step * tf.pow(self.warmup_steps, -1.5)
        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

    def get_config(self):
        """Return constructor arguments as plain Python floats.

        ``.numpy()`` alone yields NumPy scalars; wrapping in ``float`` gives
        built-in values, which the standard ``json`` module can encode.
        """
        return {
            "d_model": float(self.d_model.numpy()),
            "warmup_steps": float(self.warmup_steps.numpy()),
        }



# Loss function with label smoothing
def apply_label_smoothing(y_true, smoothing=0.1):
    """Blend ``y_true`` with a uniform distribution over its last axis.

    Every entry is scaled by ``1 - smoothing`` and then increased by
    ``smoothing / C``, where ``C`` is the size of the last axis of ``y_true``.
    """
    class_count = tf.cast(tf.shape(y_true)[-1], tf.float32)
    keep_fraction = 1.0 - smoothing
    uniform_share = smoothing / class_count
    return y_true * keep_fraction + uniform_share


# segmentation_models reads SM_FRAMEWORK while it is being imported. The
# import at the top of this file has already run, so the environment variable
# alone comes too late; select the backend explicitly as well.
os.environ["SM_FRAMEWORK"] = "tf.keras"
sm.set_framework("tf.keras")

# Set class weights (one equal weight per class)
weights = [0.1666, 0.1666, 0.1666, 0.1666, 0.1666, 0.1666]

# Raw losses from segmentation_models
raw_dice = sm.losses.DiceLoss(class_weights=weights)
raw_focal = sm.losses.CategoricalFocalLoss()

# Final loss function with label smoothing applied
def total_loss_with_smoothing(y_true, y_pred):
    """Return Dice loss plus categorical focal loss on smoothed targets.

    ``y_true`` is passed through ``apply_label_smoothing`` (smoothing 0.1)
    and the smoothed targets are fed to both ``raw_dice`` and ``raw_focal``.
    """
    smoothed_targets = apply_label_smoothing(y_true, smoothing=0.1)
    dice_term = raw_dice(smoothed_targets, y_pred)
    focal_term = raw_focal(smoothed_targets, y_pred)
    return dice_term + focal_term



import tensorflow as tf
from tensorflow.keras.metrics import MeanIoU

# --- Jaccard Index ---

class MeanIoUMetric(tf.keras.metrics.MeanIoU):
    """Mean IoU that accepts per-class score vectors instead of class ids.

    ``update_state`` takes the argmax over the last axis of both ``y_true``
    and ``y_pred`` and hands the resulting class indices to
    ``tf.keras.metrics.MeanIoU``.

    Args:
        num_classes: Number of classes tracked by the confusion matrix.
        name: Metric name shown in logs and history.
        dtype: Optional dtype forwarded to the base metric.
        **kwargs: Any further base-class options. Newer TensorFlow releases
            include extra keys in ``get_config()``; accepting them here keeps
            ``from_config`` (and therefore model loading) working.
    """

    def __init__(self, num_classes, name="mean_iou", dtype=None, **kwargs):
        super().__init__(num_classes=num_classes, name=name, dtype=dtype, **kwargs)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Reduce both inputs to class indices, then update the base metric."""
        y_true = tf.argmax(y_true, axis=-1)
        y_pred = tf.argmax(y_pred, axis=-1)
        return super().update_state(y_true, y_pred, sample_weight)


miou_metric = MeanIoUMetric(num_classes=NUM_CLASSES)




# --- Model Building ---
def _build_new_model(model_type, input_shape):
    """Construct a fresh model for ``model_type`` with ``NUM_CLASSES`` outputs.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    if model_type == "unet":
        print("🧪 Calling build_unet...")
        return build_unet(input_shape=input_shape, num_classes=NUM_CLASSES)
    if model_type == "multi_unet":
        print("🧪 Calling build_multi_unet...")
        return build_multi_unet(input_shape=input_shape, num_classes=NUM_CLASSES)
    if model_type == "unet_aux":
        print("🧪 Calling build_unet_aux...")
        return build_unet_aux(input_shape=input_shape, num_classes=NUM_CLASSES)
    if model_type == "B0":
        return SegFormer_B0(input_shape=input_shape, num_classes=NUM_CLASSES)
    if model_type == "B1":
        return SegFormer_B1(input_shape=input_shape, num_classes=NUM_CLASSES)
    if model_type == "B2":
        return SegFormer_B2(input_shape=input_shape, num_classes=NUM_CLASSES)
    if model_type == "B4":
        return SegFormer_B4(input_shape=input_shape, num_classes=NUM_CLASSES)
    if model_type == "B5":
        return SegFormer_B5(input_shape=input_shape, num_classes=NUM_CLASSES)
    if model_type == "resnet34":
        return sm.Unet(
            backbone_name="resnet34",               # or 'efficientnetb0', 'mobilenetv2', etc.
            input_shape=input_shape,
            classes=NUM_CLASSES,
            activation='softmax',
            encoder_weights='imagenet'              # Load ImageNet pre-trained weights
        )
    raise ValueError(f"Unknown model_type: {model_type}")


def _load_stage1_model(model_path):
    """Load a saved model from ``model_path`` and mark every layer trainable.

    The model is returned uncompiled; the caller is expected to compile it.
    """
    print(f"Loading Stage 1 model from {model_path} and unfreezing all layers...")

    custom_objects = {
        'DiceLoss': sm.losses.DiceLoss,
        'CategoricalFocalLoss': sm.losses.CategoricalFocalLoss,
        'MeanIoU': MeanIoUMetric
    }

    # compile=False: the saved optimizer/loss/metric state is not restored, so
    # loading does not depend on deserializing the custom loss function.
    model = tf.keras.models.load_model(
        model_path,
        custom_objects=custom_objects,
        compile=False
    )

    for layer in model.layers:
        layer.trainable = True
    return model


def train_model(base_dir="/content/chipped_data/content/chipped_data", out_dir="/content/figs", 
                freeze=False, model_path=None,
                input_type="rgb_elev", model_type="unet", tile_size=256,
                batch_size=8, steps=None, epochs=10, train_time=20, verbose=1
                ):
    """Build or load a segmentation model, train it, then evaluate it.

    Args:
        base_dir: Dataset root; images, elevations and labels are read from
            ``<base_dir>/train/{images,elevations,labels}``.
        out_dir: Directory for training-curve plots and test-set evaluation
            output.
        freeze: When true, only the last 10 layers stay trainable and the
            learning rate is 1e-3 instead of 5e-3.
        model_path: Path of a saved model to resume from. When ``None`` or
            non-existent, a new model of ``model_type`` is built.
        input_type: Key of ``INPUT_TYPE_CONFIG``; sets the channel count.
        model_type: Architecture name: ``"unet"``, ``"multi_unet"``,
            ``"unet_aux"``, ``"B0"``, ``"B1"``, ``"B2"``, ``"B4"``, ``"B5"``
            or ``"resnet34"``.
        tile_size: Height and width of the model input.
        batch_size: Batch size of all three datasets.
        steps: Number of test batches timed after training; ``None`` uses
            the default of ``measure_inference_time``.
        epochs: Maximum number of training epochs.
        train_time: Passed to ``TimeLimitCallback`` as ``max_minutes``.
        verbose: Verbosity passed to ``model.fit``.

    Raises:
        ValueError: If ``input_type`` or ``model_type`` is unknown.
    """
    # Explicit raise rather than assert: asserts vanish under ``python -O``.
    if input_type not in INPUT_TYPE_CONFIG:
        raise ValueError(f"Unknown input type: {input_type}")
    num_channels = INPUT_TYPE_CONFIG[input_type]["channels"]

    # NOTE(review): all three splits read from the "train" sub-directory;
    # presumably build_tf_dataset resolves the split itself — confirm.
    img_dir = os.path.join(base_dir, "train", "images")
    elev_dir = os.path.join(base_dir, "train", "elevations")
    label_dir = os.path.join(base_dir, "train", "labels")

    # Load metadata and define input shape
    input_shape = (tile_size, tile_size, num_channels)
    # NOTE(review): 256 looks like a tile size; if so it should probably be
    # ``tile_size`` — confirm against csv_to_df before changing.
    train_df = csv_to_df('train', 256)
    val_df = csv_to_df('val', 256)
    test_df = csv_to_df('test', 256)

    # --- Streaming Data Generator ---
    train_gen = build_tf_dataset(train_df, img_dir, elev_dir, label_dir,
                                 input_type=input_type, split='train',
                                 augment=True, shuffle=True, batch_size=batch_size)

    val_gen = build_tf_dataset(val_df, img_dir, elev_dir, label_dir,
                               input_type=input_type, split='val',
                               augment=False, shuffle=False, batch_size=batch_size)

    test_gen = build_tf_dataset(test_df, img_dir, elev_dir, label_dir,
                                input_type=input_type, split='test',
                                augment=False, shuffle=False, batch_size=batch_size)

    # Sanity check: show which class indices occur in the first test batch.
    for _, y_batch in test_gen.take(1):
        y_np = np.argmax(y_batch.numpy(), axis=-1)
        print("Unique labels in batch:", np.unique(y_np))

    # --- Model ---
    if model_path is None or not os.path.exists(model_path):
        model = _build_new_model(model_type, input_shape)
    else:
        model = _load_stage1_model(model_path)

    # Alternative kept for reference:
    #lr_schedule = TransformerLRSchedule(d_model=tile_size)
    #optimizer = Adam(learning_rate=lr_schedule, beta_1=0.9, beta_2=0.98, epsilon=1e-9)
    learning_rate = 5e-3

    metrics = [
        MeanIoUMetric(num_classes=NUM_CLASSES),
        sm.metrics.IOUScore(threshold=None),
        sm.metrics.FScore(threshold=None),
        'categorical_accuracy',
    ]

    if freeze:
        print("Stage 1: Freezing all layers except head...")
        for layer in model.layers:
            layer.trainable = False
        for layer in model.layers[-10:]:
            layer.trainable = True

        learning_rate = 1e-3

    optimizer = Adam(learning_rate=learning_rate)

    # --- Compile Model ---
    model.compile(
        optimizer=optimizer,
        loss=total_loss_with_smoothing,
        metrics=metrics
    )

    model.summary()

    # --- Callbacks ---
    time_limit = TimeLimitCallback(max_minutes=train_time)
    early_stop = EarlyStopping(monitor='val_iou_score', patience=32, restore_best_weights=True, mode='max')
    reduce_lr = ReduceLROnPlateau(
        monitor='val_iou_score',      # or 'val_mean_iou' or any other metric you use
        mode='max',                   # IoU is better when higher; 'auto' would treat it as a loss
        factor=0.5,
        patience=12,
        verbose=1,
        min_lr=5e-6                   # don't go below 5e-6
    )

    nan_terminate = TerminateOnNaN()
    dual_ckpt = DualCheckpointSaver(
        base_model=model,
        monitor='val_iou_score',
        mode='max'
    )

    callbacks = [
        reduce_lr,
        time_limit,
        early_stop,
        nan_terminate,
        StepTimer(),
        dual_ckpt
    ]

    #LearningRateLogger(),
    #ValidationPredictionLogger(val_gen, model, max_batches=1),

    # --- Train Model ---
    print("🚀 Starting training...")
    history = model.fit(
        train_gen, validation_data=val_gen,
        epochs=epochs,
        callbacks=callbacks,
        verbose=verbose
    )

    # --- Evaluate Model ---
    plot_training_curves(history, out_dir)
    # Evaluation output goes to the same out_dir as the training curves.
    evaluate_on_test(model, test_gen, test_df, out_dir, img_dir, label_dir, tile_size, n_vis=5)
    if steps is None:
        # No explicit batch count: rely on measure_inference_time's default.
        measure_inference_time(model, test_gen)
    else:
        measure_inference_time(model, test_gen, num_batches=steps)
    print("🚀 Training complete!")

    # --- Archive (earlier experiments, kept for reference only) ---
    #
    # class_weights = compute_class_weights(train_gen)
    # loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=False)
    #
    # def weighted_loss(y_true, y_pred):
    #     weights = tf.reduce_sum(class_weights * y_true, axis=-1)
    #     return tf.reduce_mean(weights * loss_fn(y_true, y_pred))
    #
    # model.compile(optimizer='adam', loss=weighted_loss, metrics=['accuracy'])
    # print("✅ Model compiled with class weights")
    #
    # from segmentation_models.metrics import iou_score as jaccard_coef
    # metrics = ["accuracy", jaccard_coef]
    # weights = [0.1666, 0.1666, 0.1666, 0.1666, 0.1666, 0.1666]
    #
    # dice_loss = dice_loss(class_weights = weights)
    # focal_loss = categorical_focal_loss()
    # total_loss = dice_loss + (1 * focal_loss)
    #
    # model.compile(optimizer="adam", loss="total_loss", metrics=metrics)
    #
    # def jaccard_coef(y_true, y_pred):
    #     y_true_flatten = K.flatten(y_true)
    #     y_pred_flatten = K.flatten(y_pred)
    #     intersection = K.sum(y_true_flatten * y_pred_flatten)
    #     final_coef_value = (intersection + 1.0) / (K.sum(y_true_flatten) + K.sum(y_pred_flatten) - intersection + 1.0)
    #     return final_coef_value