# In [None]:
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import pandas as pd
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, TerminateOnNaN
import cv2
import itertools
import gc
import tensorflow.keras.backend as K
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import LearningRateSchedule
from tensorflow.keras.losses import CategoricalCrossentropy
import segmentation_models as sm
from collections import defaultdict
from tensorflow.keras.metrics import MeanIoU
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras import backend as K
# Reset the global Keras state before anything is built in this session.
K.clear_session()

# Make segmentation_models use the tf.keras backend.
# The package reads SM_FRAMEWORK when it is first imported, and that import
# already happened above, so the environment variable alone would be ignored
# here. It is still set (for any later re-import / subprocess), and the
# framework is switched explicitly for the already-imported module.
os.environ["SM_FRAMEWORK"] = "tf.keras"
sm.set_framework("tf.keras")



# --- Model Building ---
# Module-level MeanIoUMetric instance configured for 6 classes.
# NOTE(review): MeanIoUMetric is not defined in this section of the file, so it
# must already be in scope when this line runs. The training functions visible
# in this section do not reference `miou_metric`; presumably it is used
# elsewhere — confirm before removing.
miou_metric = MeanIoUMetric(num_classes=6)

# --- U-Net Model ---

def train_unet(
    base_dir: str = "/content/chipped_data/content/chipped_data",
    out_dir: str = "/content/figs",
    input_type: str = "rgb_elev",
    model_type: str = "enhanced_unet",
    tile_size: int = 256,
    batch_size: int = 8,
    epochs: int = 50,
    train_time: int = 20,
    verbose: int = 1,
    yummy: bool = False,
    model_path: str = None,
):
    """Train a U-Net-style segmentation model, save it, and evaluate it.

    The function builds train/val/test datasets, builds a new model (or loads
    one from `model_path`), compiles it with a weighted sum of categorical
    cross-entropy, Dice and focal losses, fits it, saves it to `out_dir` and
    to Google Drive, prints a run summary, and then runs inference timing,
    training-curve plotting and test-set evaluation.

    Args:
        base_dir (str): Path to the base data directory. Tiles for every split
            are read from its `train/{images,elevations,slopes,labels}`
            sub-directories.
        out_dir (str): Output directory for the saved model, plots and
            evaluation figures.
        input_type (str): Input configuration; must be a key of
            INPUT_TYPE_CONFIG (e.g. 'rgb' or 'rgb_elev').
        model_type (str): Model to build when no checkpoint is loaded. One of
            'unet', 'new_unet', 'multi_unet', 'enhanced_unet', 'resnet34'.
        tile_size (int): Width and height of each input tile in pixels.
        batch_size (int): Batch size for training.
        epochs (int): Maximum number of training epochs.
        train_time (int): Maximum training time in minutes.
        verbose (int): Verbosity level for training output.
        yummy (bool): Whether to plot full-tile predictions after training.
        model_path (str): Optional path to a saved model to resume from. If it
            is None or does not exist, a new model is built instead.

    Raises:
        AssertionError: If `input_type` is not a key of INPUT_TYPE_CONFIG.
        ValueError: If a new model has to be built and `model_type` is unknown.
    """
    import time
    from datetime import datetime

    import segmentation_models as sm

    assert input_type in INPUT_TYPE_CONFIG, f"Unknown input type: {input_type}"
    num_channels = INPUT_TYPE_CONFIG[input_type]["channels"]
    input_shape = (tile_size, tile_size, num_channels)

    # NOTE(review): val and test datasets are also pointed at the "train"
    # sub-directories; presumably the split is decided by the dataframes.
    train_root = os.path.join(base_dir, "train")
    img_dir = os.path.join(train_root, "images")
    elev_dir = os.path.join(train_root, "elevations")
    slope_dir = os.path.join(train_root, "slopes")
    label_dir = os.path.join(train_root, "labels")

    # Load metadata
    train_df = csv_to_df('train', subset=1.0)
    val_df = csv_to_df('val')
    test_df = csv_to_df('test')

    # --- Streaming Data Generators ---
    def _make_dataset(df, split, *, augment, shuffle):
        """Build the dataset for one split with the shared dirs/settings."""
        return build_tf_dataset(df, img_dir, elev_dir, slope_dir, label_dir,
                                input_type=input_type, split=split,
                                augment=augment, shuffle=shuffle,
                                batch_size=batch_size)

    train_gen = _make_dataset(train_df, 'train', augment=True, shuffle=True)
    val_gen = _make_dataset(val_df, 'val', augment=False, shuffle=False)
    test_gen = _make_dataset(test_df, 'test', augment=False, shuffle=False)

    # Visualize augmented training examples.
    # NOTE(review): plotted with input_type='rgb' regardless of `input_type`;
    # presumably only the RGB channels are shown — confirm.
    plot_augmented_grid_from_dataset(
        tf_dataset=train_gen,
        input_type='rgb',
    )

    # --- Model ---
    if model_path is None or not os.path.exists(model_path):
        if model_type == "unet":
            model = build_unet(input_shape=input_shape, num_classes=NUM_CLASSES)
        elif model_type == "new_unet":
            # The second return value (base model) is not needed here.
            model, _ = build_flexible_unet(input_shape=input_shape, num_classes=NUM_CLASSES, freeze_rgb_encoder=False)
        elif model_type == "multi_unet":
            model = build_multi_unet(input_shape=input_shape, num_classes=NUM_CLASSES)
        elif model_type == "enhanced_unet":
            model = enhanced_unet(input_shape=input_shape, num_classes=NUM_CLASSES)
        elif model_type == "resnet34":
            model = sm.Unet(
                backbone_name="resnet34",               # or 'efficientnetb0', 'mobilenetv2', etc.
                input_shape=input_shape,
                classes=NUM_CLASSES,
                activation='softmax',
                encoder_weights='imagenet'              # Load ImageNet pre-trained weights
            )
        else:
            raise ValueError(f"Unknown model_type: {model_type}")
    else:
        custom_objects = {
            'DiceLoss': sm.losses.DiceLoss,
            'CategoricalFocalLoss': sm.losses.CategoricalFocalLoss,
            'MeanIoU': MeanIoUMetric,
        }

        # compile=False: the model is always re-compiled below with a fresh
        # optimizer, loss and metrics, so the saved compile state would be
        # thrown away anyway. Skipping it also avoids having to deserialize
        # the closure-based loss, which is not part of `custom_objects`.
        model = tf.keras.models.load_model(
            model_path,
            custom_objects=custom_objects,
            compile=False
        )

        # Resume with every layer unfrozen.
        for layer in model.layers:
            layer.trainable = True

    model.summary()
    print(f"Number of Parameters: {model.count_params()}\n"
          f"Number of Layers: {len(model.layers)}\n")

    # --- Callbacks ---
    monitor = "val_iou_score"       # "val_loss"
    callbacks = [
        ReduceLROnPlateau(monitor=monitor, mode="max", patience=5, min_lr=5e-7, factor=0.5, verbose=1, min_delta=1e-4),
        TimeLimitCallback(max_minutes=train_time),
        EarlyStopping(monitor=monitor, mode="max", patience=10, restore_best_weights=True, verbose=1),
        TerminateOnNaN(),
        StepTimer(),
    ]

    metrics = [
        sm.metrics.IOUScore(threshold=None, name="iou_score"),   # fast, approximated mIoU per batch
        sm.metrics.FScore(threshold=None, name="f1-score"),
        tf.keras.metrics.CategoricalAccuracy(name="categorical_accuracy"),
    ]

    # --- Loss / optimizer hyper-parameters ---
    learning_rate = 5.2e-4
    label_smoothing = 0.075
    loss_weights = [0.25, 1.5, 2.25]                    # [cce, dice, focal]
    class_weights = [6.95, 3.3, 0.3, 12.5, 4.0, 2.6]    # [building, clutter, vegetation, water, background, car]
    focal_gamma = 4.0

    # Normalize class weights so they sum to 1.
    total = sum(class_weights)
    norm_class_weights = [w / total for w in class_weights]

    raw_dice = sm.losses.DiceLoss(class_weights=norm_class_weights)
    raw_focal = sm.losses.CategoricalFocalLoss(
        alpha=norm_class_weights,
        gamma=focal_gamma,
    )
    raw_cce = CategoricalCrossentropy()

    # Referenced through `tf` (imported at file level) so this function does
    # not depend on a separate `mixed_precision` name being in scope.
    optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
        Adam(learning_rate=learning_rate), dynamic=True
    )

    def total_loss(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """Computes the total weighted loss from CCE, Dice, and Focal losses."""
        y_true_smoothed = apply_label_smoothing(y_true, smoothing=label_smoothing)

        dice = raw_dice(y_true_smoothed, y_pred)
        focal = raw_focal(y_true_smoothed, y_pred)
        cce = raw_cce(y_true_smoothed, y_pred)

        return (
            loss_weights[0] * cce +
            loss_weights[1] * dice +
            loss_weights[2] * focal
        )

    # --- Train Model (single stage) ---
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    print("🕒 Start Time:", timestamp)
    start_time = time.time()

    model.compile(
        optimizer=optimizer,
        loss=total_loss,
        metrics=metrics
    )

    history = model.fit(
        train_gen, validation_data=val_gen,
        epochs=epochs,
        callbacks=callbacks,
        verbose=verbose
    )

    # Stop the clock right after fit() so the reported duration covers
    # training only, not the model-saving I/O below.
    duration_sec = time.time() - start_time

    # NOTE: an optional second training stage on hard examples
    # (csv_to_hard_df, Adam lr=1e-5, 15 extra epochs) was sketched here as a
    # disabled string block; it referenced an undefined `fine_tune` flag.

    # --- Save Model ---
    os.makedirs(out_dir, exist_ok=True)
    model.save(os.path.join(out_dir, "segmentation_model.keras"))

    # Second copy on Google Drive (Colab mount point).
    os.makedirs("/content/drive/MyDrive", exist_ok=True)
    model.save("/content/drive/MyDrive/segmentation_model.keras")

    duration_str = time.strftime('%H:%M:%S', time.gmtime(duration_sec))
    print(f"\n✅ Training complete in {duration_str} ({duration_sec:.2f} seconds)")

    # --- Run Summary ---
    def _fmt_last(history_dict, key):
        """Return the last value logged under `key` as '0.0000', or 'N/A'."""
        values = history_dict.get(key)
        if not values:
            return "N/A"
        return f"{values[-1]:.4f}"

    logs = history.history
    print(f"Initial Learning Rate: {learning_rate}\n"
          f"Loss Weights: {loss_weights}, "
          f"Class Weights: {class_weights}\n"
          f"Focal Loss Gamma: {focal_gamma}\n"
          f"Label Smoothing: {label_smoothing}\n"
          f"Input Type: {input_type}, "
          f"Model Type: {model_type}\n"
          f"Batch Size: {batch_size}, "
          f"Epochs: {len(history.epoch)}\n"
          f"Number of Parameters: {model.count_params()}, "
          f"Number of Layers: {len(model.layers)}\n"
          f"Final Validation Loss: {_fmt_last(logs, 'val_loss')}\n"
          f"Final Validation mIoU: {_fmt_last(logs, 'val_iou_score')}\n"
          f"Final Validation F1 Score: {_fmt_last(logs, 'val_f1-score')}\n"
          f"Final Validation Categorical Accuracy: {_fmt_last(logs, 'val_categorical_accuracy')}\n")

    # --- Evaluate Model ---
    measure_inference_time(model, test_gen, num_batches=5)
    plot_training_curves(history, out_dir)
    evaluate_on_test(model, test_gen, test_df, out_dir, img_dir, label_dir, tile_size, n_rows=8, n_cols=3)

    # --- Full-Tile Reconstruction (Optional) ---
    if yummy:
        # NOTE(review): `test_files` is not defined in this function; it must
        # exist as a module-level name when yummy=True — confirm.
        for tile_prefix in test_files:
            img, label, pred = reconstruct_canvas(
                model,
                test_df,
                tile_prefix,
                build_tf_dataset,
                img_dir,
                elev_dir,
                slope_dir,
                label_dir
            )
            plot_reconstruction(img, label, pred, tile_prefix)





# --- SegFormer Training ---

def train_segformer(
    base_dir: str = "/content/chipped_data/content/chipped_data",
    out_dir: str = "/content/figs",
    input_type: str = "rgb",
    model_type: str = "B1",
    tile_size: int = 256,
    batch_size: int = 8,
    epochs: int = 50,
    train_time: int = 60,
    verbose: int = 1,
    yummy: bool = False,
    model_path: str = None,
):
    """Train a SegFormer segmentation model and evaluate it on the test split.

    The function builds train/val/test datasets, builds a SegFormer variant
    (or loads a model from `model_path`), compiles it with Adam driven by a
    TransformerLRSchedule and a weighted sum of categorical cross-entropy,
    Dice and focal losses, fits it, and then plots training curves, evaluates
    on the test set and measures inference time.

    Args:
        base_dir (str): Path to the base data directory. Tiles for every split
            are read from its `train/{images,elevations,slopes,labels}`
            sub-directories.
        out_dir (str): Output directory for plots and evaluation figures.
        input_type (str): Input configuration; must be a key of
            INPUT_TYPE_CONFIG (e.g. 'rgb' or 'rgb_elev').
        model_type (str): SegFormer variant to build: 'B0', 'B1', 'B2', 'B3',
            'B4' or 'B5'.
        tile_size (int): Width and height of each input tile in pixels. Also
            passed as `d_model` to the learning-rate schedule.
        batch_size (int): Batch size for training.
        epochs (int): Maximum number of training epochs.
        train_time (int): Maximum training time in minutes.
        verbose (int): Verbosity level for training output.
        yummy (bool): Whether to plot full-tile predictions after training.
        model_path (str): Optional path to a saved model to resume from. If it
            is None or does not exist, a new model is built instead.

    Raises:
        AssertionError: If `model_type` or `input_type` is not recognised.
    """
    import segmentation_models as sm

    assert model_type in ["B0", "B1", "B2", "B3", "B4", "B5"], f"Unknown model type: {model_type}"
    assert input_type in INPUT_TYPE_CONFIG, f"Unknown input type: {input_type}"
    num_channels = INPUT_TYPE_CONFIG[input_type]["channels"]
    input_shape = (tile_size, tile_size, num_channels)

    # NOTE(review): val and test datasets are also pointed at the "train"
    # sub-directories; presumably the split is decided by the dataframes.
    train_root = os.path.join(base_dir, "train")
    img_dir = os.path.join(train_root, "images")
    elev_dir = os.path.join(train_root, "elevations")
    slope_dir = os.path.join(train_root, "slopes")
    label_dir = os.path.join(train_root, "labels")

    # Load metadata
    train_df = csv_to_df('train', 0.4)
    val_df = csv_to_df('val')
    test_df = csv_to_df('test')

    # --- Streaming Data Generators ---
    def _make_dataset(df, split, *, augment, shuffle):
        """Build the dataset for one split with the shared dirs/settings."""
        return build_tf_dataset(df, img_dir, elev_dir, slope_dir, label_dir,
                                input_type=input_type, split=split,
                                augment=augment, shuffle=shuffle,
                                batch_size=batch_size)

    train_gen = _make_dataset(train_df, 'train', augment=True, shuffle=True)
    val_gen = _make_dataset(val_df, 'val', augment=False, shuffle=False)
    test_gen = _make_dataset(test_df, 'test', augment=False, shuffle=False)

    # Sanity check: print the class ids present in one test batch
    # (labels are arg-maxed over the last axis).
    for x_batch, y_batch in test_gen.take(1):
        y_np = np.argmax(y_batch.numpy(), axis=-1)
        print("Unique labels in y batch:", np.unique(y_np))

    # --- Model ---
    if model_path is None or not os.path.exists(model_path):
        # Kept as an if/elif chain so only the selected constructor name has
        # to be defined at call time.
        if model_type == "B0":
            model = SegFormer_B0(input_shape=input_shape, num_classes=NUM_CLASSES)
        elif model_type == "B1":
            model = SegFormer_B1(input_shape=input_shape, num_classes=NUM_CLASSES)
        elif model_type == "B2":
            model = SegFormer_B2(input_shape=input_shape, num_classes=NUM_CLASSES)
        elif model_type == "B3":
            model = SegFormer_B3(input_shape=input_shape, num_classes=NUM_CLASSES)
        elif model_type == "B4":
            model = SegFormer_B4(input_shape=input_shape, num_classes=NUM_CLASSES)
        elif model_type == "B5":
            model = SegFormer_B5(input_shape=input_shape, num_classes=NUM_CLASSES)
        else:
            # Only reachable when assertions are disabled (python -O).
            raise ValueError(f"Unknown model_type: {model_type}")
    else:
        custom_objects = {
            'DiceLoss': sm.losses.DiceLoss,
            'CategoricalFocalLoss': sm.losses.CategoricalFocalLoss,
            'MeanIoU': MeanIoUMetric,
        }

        # compile=False: the model is always re-compiled below with a fresh
        # optimizer, loss and metrics, so the saved compile state would be
        # thrown away anyway. Skipping it also avoids having to deserialize
        # the closure-based loss, which is not part of `custom_objects`.
        model = tf.keras.models.load_model(
            model_path,
            custom_objects=custom_objects,
            compile=False
        )

        # Resume with every layer unfrozen.
        for layer in model.layers:
            layer.trainable = True

    model.summary()

    for i, layer in enumerate(model.layers):
        print(i, layer.name)

    # --- Callbacks ---
    # val_loss must be minimised. With mode="max" early stopping would treat a
    # rising loss as an improvement and restore_best_weights would bring back
    # the highest-loss epoch.
    monitor = "val_loss"            # alternative: "val_iou_score" with mode="max"
    callbacks = [
        TimeLimitCallback(max_minutes=train_time),
        EarlyStopping(monitor=monitor, mode="min", patience=20, restore_best_weights=True, verbose=1),
        TerminateOnNaN(),
        StepTimer(),
    ]

    # The learning rate is driven by this schedule, so no ReduceLROnPlateau
    # callback is registered for this model.
    lr_schedule = TransformerLRSchedule(d_model=tile_size, warmup_steps=2048)

    metrics = [
        sm.metrics.IOUScore(threshold=None, name="iou_score"),   # fast, approximated mIoU per batch
        sm.metrics.FScore(threshold=None, name="f1-score"),
        tf.keras.metrics.CategoricalAccuracy(name="categorical_accuracy"),
    ]

    # --- Loss hyper-parameters ---
    loss_weights = [0.25, 1.5, 1.0]                 # [cce, dice, focal]
    class_weights = [1.1, 1.0, 1.0, 2.0, 0.0, 1.3]  # [building, clutter, vegetation, water, background, car]
    label_smoothing = 0.1

    raw_dice = sm.losses.DiceLoss(class_weights=class_weights)
    raw_focal = sm.losses.CategoricalFocalLoss()
    raw_cce = CategoricalCrossentropy()

    def total_loss_with_smoothing(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        """Computes the total weighted loss from CCE, Dice, and Focal losses."""
        y_true_smoothed = apply_label_smoothing(y_true, smoothing=label_smoothing)

        dice = raw_dice(y_true_smoothed, y_pred)
        focal = raw_focal(y_true_smoothed, y_pred)
        cce = raw_cce(y_true_smoothed, y_pred)

        return loss_weights[0] * cce + loss_weights[1] * dice + loss_weights[2] * focal

    # --- Train Model ---
    model.compile(
        optimizer=Adam(learning_rate=lr_schedule, beta_1=0.9, beta_2=0.98, epsilon=1e-9),
        loss=total_loss_with_smoothing,
        metrics=metrics
    )

    history = model.fit(
        train_gen, validation_data=val_gen,
        epochs=epochs,
        callbacks=callbacks,
        verbose=verbose
    )

    # --- Evaluate Model ---
    plot_training_curves(history, out_dir)
    evaluate_on_test(model, test_gen, test_df, out_dir, img_dir, label_dir, tile_size, n_rows=4, n_cols=3)
    measure_inference_time(model, test_gen, num_batches=5)

    # --- Full-Tile Reconstruction (Optional) ---
    if yummy:
        # NOTE(review): `test_files` is not defined in this function; it must
        # exist as a module-level name when yummy=True — confirm.
        for tile_prefix in test_files:
            img, label, pred = reconstruct_canvas(
                model,
                test_df,
                tile_prefix,
                build_tf_dataset,
                img_dir,
                elev_dir,
                slope_dir,
                label_dir
            )
            plot_reconstruction(img, label, pred, tile_prefix)


