# In [None]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, TerminateOnNaN
import cv2
import matplotlib.pyplot as plt
import itertools
import gc
import tensorflow.keras.backend as K
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import LearningRateSchedule
from tensorflow.keras.losses import CategoricalCrossentropy
import segmentation_models as sm
from collections import defaultdict




# --- Sanity ---
def test_training_sanity():
    """Print a marker line showing that this training module is loaded."""
    marker = "✅ from training.ipynb"
    print(marker)



# Supported input modalities, keyed by the `input_type` name accepted by
# train_model(). Each entry carries a human-readable description and the
# number of channels the model input has for that modality.
INPUT_TYPE_CONFIG = {
    "1ch": {"description": "grayscale only", "channels": 1},
    "2ch": {"description": "grayscale + elevation", "channels": 2},
    "rgb": {"description": "RGB only", "channels": 3},
    "rgb_elev": {"description": "RGB + elevation", "channels": 4}
}

# Number of real (trainable) classes; index 6 below is the ignore marker.
NUM_CLASSES = 6

# RGB colour used in the label images -> class index.
COLOR_TO_CLASS = {
    (230, 25, 75): 0,
    (145, 30, 180): 1,
    (60, 180, 75): 2,
    (245, 130, 48): 3,
    (255, 255, 255): 4,
    (0, 130, 200): 5,
    (255, 0, 255): 6,  # Ignore pixel for visualisation
}

# Reverse mapping (class index -> RGB colour) without the ignore class.
CLASS_TO_COLOR = {
    class_id: rgb for rgb, class_id in COLOR_TO_CLASS.items() if class_id < 6
}

# All label colours (ignore colour included) as a (7, 3) uint8 array, in
# the same order as COLOR_TO_CLASS.
COLOR_PALETTE = np.array(list(COLOR_TO_CLASS), dtype=np.uint8)

# Independent copy of the colour -> class mapping used by decode_label_image.
COLOR_LOOKUP = dict(COLOR_TO_CLASS)

def decode_label_image(label_img, color_lookup=None):
    """Convert a colour-coded label image into a 2-D map of class indices.

    Args:
        label_img: Array of shape (height, width, channels) holding one
            colour per pixel.
        color_lookup: Optional mapping from colour tuple to class index.
            Defaults to the module-level ``COLOR_LOOKUP``.

    Returns:
        A ``uint8`` array of shape (height, width) with the class index of
        every pixel.

    Raises:
        ValueError: If the input is not 3-dimensional, or if any pixel has a
            colour that is missing from the lookup. The first offending
            pixel in row-major order is reported.
    """
    if color_lookup is None:
        color_lookup = COLOR_LOOKUP

    label_img = np.asarray(label_img)
    if label_img.ndim != 3:
        raise ValueError(
            f"❌ Expected a (height, width, channels) label image, got shape {label_img.shape}"
        )
    h, w, channels = label_img.shape

    label_map = np.zeros((h, w), dtype=np.uint8)
    matched = np.zeros((h, w), dtype=bool)

    # One vectorised comparison per known colour instead of a Python loop
    # over every pixel.
    for colour, class_id in color_lookup.items():
        if len(colour) != channels:
            continue  # a colour with a different channel count can never match
        hits = np.all(label_img == np.asarray(colour), axis=-1)
        label_map[hits] = class_id
        matched |= hits

    if not matched.all():
        # argwhere is row-major, so this is the pixel a scan would hit first.
        y, x = (int(v) for v in np.argwhere(~matched)[0])
        pixel = tuple(int(v) for v in label_img[y, x])
        raise ValueError(f"❌ Unknown label colour {pixel} at ({y}, {x})")

    return label_map


def filter_tile_ids_by_substring(image_dir, base_names):
    """List tile ids in ``image_dir`` whose file name contains any base name.

    Every directory entry containing at least one of ``base_names`` as a
    substring is kept, with the ``-ortho.png`` suffix stripped from it.
    Entries are returned in the order ``os.listdir`` yields them.
    """
    tile_ids = []
    for filename in os.listdir(image_dir):
        for base in base_names:
            if base in filename:
                tile_ids.append(filename.replace('-ortho.png', ''))
                break
    return tile_ids


def measure_inference_time(model, generator, num_batches=5):
    """Time ``model.predict`` over the first batches of ``generator``.

    Args:
        model: Object exposing ``predict(x, verbose=0)``.
        generator: Iterable yielding ``(x_batch, y_batch)`` pairs; the
            number of images in a batch is read from ``x_batch.shape[0]``.
        num_batches: Maximum number of batches to time. ``None`` means every
            batch the generator yields.

    Prints the total and per-image prediction time. Nothing is returned.
    """
    import time

    total_time = 0.0
    total_images = 0

    for i, (x_batch, _) in enumerate(generator):
        # Comparing an int with None raises TypeError, so None is handled
        # explicitly as "no limit".
        if num_batches is not None and i >= num_batches:
            break
        start = time.perf_counter()  # monotonic, unlike time.time()
        model.predict(x_batch, verbose=0)
        total_time += time.perf_counter() - start
        total_images += x_batch.shape[0]

    if total_images == 0:
        # Nothing was predicted, so a per-image average is undefined.
        print("🧠 Inference time: no images were processed")
        return

    print(f"🧠 Inference time: {total_time:.2f} sec for {total_images} images")
    print(f"⏱️ Avg inference time per image: {total_time / total_images:.4f} sec")


from tqdm import trange
def plot_training_curves(history, out_dir):
    """Save one training/validation curve per recorded metric.

    Args:
        history: Object whose ``history`` attribute maps metric names to
            per-epoch value lists (e.g. the return value of ``model.fit``).
        out_dir: Directory receiving ``<metric>_plot.png`` files. It is
            created when missing.

    Metrics that were not recorded are skipped with a message. The
    validation curve is drawn only when the ``val_<metric>`` key exists.
    """
    import os

    os.makedirs(out_dir, exist_ok=True)

    # Define metrics to plot
    metrics = ['loss', 'accuracy', 'masked_mean_iou', 'f1-score', 'iou_score', 'categorical_accuracy']

    records = history.history
    available = [metric for metric in metrics if metric in records]
    missing = [metric for metric in metrics if metric not in records]
    if missing:
        print(f"⚠️ Skipping metrics not present in history: {missing}")
    if not available:
        return

    # Imported only once there is something to draw.
    import matplotlib.pyplot as plt

    # 📈 Plotting Training Curves
    for metric in available:
        fig = plt.figure()
        plt.plot(records[metric], label="Training " + metric)
        val_key = "val_" + metric
        if val_key in records:
            plt.plot(records[val_key], label="Validation " + metric)
        plt.title("Training and Validation " + metric)
        plt.xlabel("Epochs")
        plt.ylabel(metric)
        plt.legend()
        plt.savefig(os.path.join(out_dir, metric + "_plot.png"))
        plt.close(fig)

    # NOTE: no plotting after the loop. Saving the (empty) current figure as
    # "loss_plot.png" here would overwrite the loss curve written above.



# --- Metrics ---
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, TerminateOnNaN

# Custom learning rate schedule
class TransformerLRSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Warm-up then inverse-square-root decay learning rate schedule.

    The rate at a given step is
    ``rsqrt(d_model) * min(rsqrt(step), step * warmup_steps ** -1.5)``.
    At step 0 the second term is 0, so the returned rate is 0.

    Args:
        d_model: Scale factor; the rate is proportional to ``1/sqrt(d_model)``.
        warmup_steps: Step count that sets the slope of the linear term.
    """

    def __init__(self, d_model, warmup_steps=4000):
        super().__init__()
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = tf.cast(warmup_steps, tf.float32)

    def __call__(self, step):
        """Return the learning rate for optimizer step ``step``."""
        step = tf.cast(step, tf.float32)
        arg1 = tf.math.rsqrt(step)
        arg2 = step * tf.pow(self.warmup_steps, -1.5)
        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

    def get_config(self):
        """Return the constructor arguments as plain Python floats."""
        # .numpy() alone yields NumPy scalars; float() makes them builtin
        # floats, which any JSON encoder can serialize.
        return {
            "d_model": float(self.d_model.numpy()),
            "warmup_steps": float(self.warmup_steps.numpy())
        }


# Module-level default learning rate schedule and Adam optimizer (d_model=256).
# train_model() creates its own schedule and optimizer locally before it
# compiles the model.
lr_schedule = TransformerLRSchedule(d_model=256)
optimizer = Adam(learning_rate=lr_schedule, beta_1=0.9, beta_2=0.98, epsilon=1e-9)

# Loss function with label smoothing
def apply_label_smoothing(y_true, smoothing=0.1):
    """Blend one-hot targets with a uniform distribution over the classes.

    Each target value ``t`` becomes ``t * (1 - smoothing) + smoothing / C``,
    where ``C`` is the size of the last axis of ``y_true``.
    """
    # NOTE(review): an all-zero target vector also receives smoothing / C
    # in every channel - confirm this is intended for ignore pixels.
    class_count = tf.cast(tf.shape(y_true)[-1], tf.float32)
    keep_fraction = 1.0 - smoothing
    uniform_share = smoothing / class_count
    return y_true * keep_fraction + uniform_share


# segmentation_models reads SM_FRAMEWORK when it is first imported. It is
# imported at the top of this file, so the environment variable set here
# only affects later (re-)imports. set_framework() switches the backend of
# the module that is already loaded.
os.environ["SM_FRAMEWORK"] = "tf.keras"
sm.set_framework("tf.keras")

# Set class weights (equal weight for each of the six classes)
weights = [0.1666, 0.1666, 0.1666, 0.1666, 0.1666, 0.1666]

# Raw losses from segmentation_models
raw_dice = sm.losses.DiceLoss(class_weights=weights)
raw_focal = sm.losses.CategoricalFocalLoss()

# Final loss function with label smoothing applied
def total_loss_with_smoothing(y_true, y_pred):
    """Return Dice loss plus categorical focal loss on smoothed targets.

    The targets are label-smoothed with a factor of 0.1 before both
    module-level losses (``raw_dice`` and ``raw_focal``) are evaluated.
    """
    smoothed_targets = apply_label_smoothing(y_true, smoothing=0.1)
    dice_term = raw_dice(smoothed_targets, y_pred)
    focal_term = raw_focal(smoothed_targets, y_pred)
    return dice_term + focal_term



import tensorflow as tf
from tensorflow.keras.metrics import MeanIoU

# --- Jaccard Index ---
'''
class MeanIoUMetric(tf.keras.metrics.MeanIoU):
    def __init__(self, num_classes, name="mean_iou", dtype=None):
        super().__init__(num_classes=num_classes, name=name, dtype=dtype)

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.argmax(y_true, axis=-1)
        y_pred = tf.argmax(y_pred, axis=-1)
        return super().update_state(y_true, y_pred, sample_weight)

miou_metric = MeanIoUMetric(num_classes=6)
'''

class MaskedMeanIoU(tf.keras.metrics.Metric):
    """Mean intersection-over-union that leaves out ignore pixels.

    A confusion matrix is accumulated over all batches; ``result`` derives
    the per-class IoU from it and averages over the classes that occurred
    in the labels or the predictions.

    Args:
        num_classes: Number of classes, i.e. the size of the confusion matrix.
        name: Metric name shown in the training logs.
        ignore_class: Label index treated as "ignore" (default 6). Pass
            ``None`` to disable the index-based mask.
        **kwargs: Forwarded to ``tf.keras.metrics.Metric``.
    """

    def __init__(self, num_classes, name="masked_mean_iou", ignore_class=6, **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_classes = num_classes
        self.ignore_class = ignore_class
        self.total_cm = self.add_weight(
            name="total_confusion_matrix",
            shape=(num_classes, num_classes),
            initializer="zeros",
            dtype=tf.float32,
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Add one batch to the confusion matrix (``sample_weight`` is unused)."""
        # An all-zero target vector carries no class. argmax would turn it
        # into class 0, so such pixels are detected before taking argmax.
        # (Elsewhere in this file ignore pixels are identified as all-zero
        # one-hot vectors.)
        mask = tf.reduce_any(tf.not_equal(y_true, 0), axis=-1)

        y_true = tf.argmax(y_true, axis=-1)
        y_pred = tf.argmax(y_pred, axis=-1)

        # MASK: also drop pixels explicitly labelled with the ignore index
        # (only possible when y_true has a channel for it).
        if self.ignore_class is not None:
            mask = tf.logical_and(mask, tf.not_equal(y_true, self.ignore_class))

        y_true = tf.boolean_mask(y_true, mask)
        y_pred = tf.boolean_mask(y_pred, mask)

        current_cm = tf.math.confusion_matrix(y_true, y_pred, num_classes=self.num_classes, dtype=tf.float32)
        self.total_cm.assign_add(current_cm)

    def result(self):
        """Return the mean IoU over the classes seen so far."""
        predicted_per_class = tf.reduce_sum(self.total_cm, axis=0)
        labelled_per_class = tf.reduce_sum(self.total_cm, axis=1)
        true_positives = tf.linalg.diag_part(self.total_cm)
        denominator = predicted_per_class + labelled_per_class - true_positives

        iou = tf.math.divide_no_nan(true_positives, denominator)
        # Classes that never occurred have an undefined IoU; leave them out
        # of the average instead of counting them as 0.
        num_valid = tf.reduce_sum(tf.cast(tf.not_equal(denominator, 0), tf.float32))
        return tf.math.divide_no_nan(tf.reduce_sum(iou), num_valid)

    def reset_state(self):
        """Clear the accumulated confusion matrix."""
        self.total_cm.assign(tf.zeros((self.num_classes, self.num_classes), dtype=tf.float32))

    def reset_states(self):
        """Legacy alias of ``reset_state`` used by older Keras versions."""
        self.reset_state()

    def get_config(self):
        """Return the constructor arguments so the metric can be re-created."""
        config = super().get_config()
        config.update({"num_classes": self.num_classes, "ignore_class": self.ignore_class})
        return config


# --- Callbacks ---
class ClearMemory(tf.keras.callbacks.Callback):
    """Run garbage collection and reset Keras global state after each epoch."""

    def on_epoch_end(self, epoch, logs=None):
        # Done once per epoch rather than after every batch: a full
        # gc.collect() plus a Keras session reset on each training step
        # slows the training loop down considerably.
        gc.collect()
        K.clear_session()

class LearningRateLogger(tf.keras.callbacks.Callback):
    """Print the optimizer's current learning rate at the end of each epoch."""

    def on_epoch_end(self, epoch, logs=None):
        optimizer = self.model.optimizer
        lr = optimizer.learning_rate
        # Depending on the Keras version, `learning_rate` is either the
        # current value or the schedule object itself. A schedule has to be
        # evaluated at the current optimizer step.
        if callable(lr):
            lr = lr(optimizer.iterations)
        if hasattr(lr, "numpy"):
            lr = lr.numpy()
        print(f"📉 Learning Rate at epoch {epoch + 1}: {float(lr):.6f}")

class TimeLimitCallback(tf.keras.callbacks.Callback):
    """Stop training once a wall-clock budget has been used up.

    The budget is checked at the end of every epoch only, so the epoch that
    is running when the limit is reached still completes.
    """

    def __init__(self, max_minutes=20):
        super().__init__()
        # Budget in seconds.
        self.max_duration = max_minutes * 60

    def on_train_begin(self, logs=None):
        # Reference point for the elapsed-time check.
        self.start_time = tf.timestamp()

    def on_epoch_end(self, epoch, logs=None):
        seconds_used = tf.timestamp() - self.start_time
        over_budget = seconds_used > self.max_duration
        if not over_budget:
            return
        print(f"⏱️ Training time exceeded {self.max_duration // 60} minutes. Stopping early.")
        self.model.stop_training = True

            
# Display names indexed by class id (0-5); DistributionLogger uses them when
# printing per-class pixel counts.
CLASS_NAMES = ['Building', 'Clutter', 'Vegetation', 'Water', 'Background', 'Car']
import matplotlib.pyplot as plt

class DistributionLogger(tf.keras.callbacks.Callback):
    """Print the per-class pixel distribution of batches drawn from a generator.

    After every epoch up to ``max_batches`` batches are read from
    ``generator`` and their labelled pixels are counted per class. The
    counts are also accumulated over all epochs and reported when training
    ends.

    Args:
        generator: Iterable yielding ``(images, one_hot_labels)`` batches.
        name: Label used in the printed report.
        max_batches: Number of batches inspected per epoch; ``None`` reads
            every batch the generator yields.
        visualise_samples: Stored on the instance; not used by this class.
    """

    def __init__(self, generator, name="Training", max_batches=16, visualise_samples=2):
        super().__init__()
        self.generator = generator
        self.name = name
        self.max_batches = max_batches
        self.visualise_samples = visualise_samples
        self.cumulative_class_counts = defaultdict(int)

    @staticmethod
    def _class_label(cls):
        """Return the display name of class ``cls`` or a placeholder."""
        return CLASS_NAMES[cls] if 0 <= cls < len(CLASS_NAMES) else "Unknown"

    def on_epoch_end(self, epoch, logs=None):
        batch_class_counts = defaultdict(int)
        total_images = 0

        for batch_index, (batch_images, batch_labels) in enumerate(self.generator):
            # None means "no limit"; comparing an int with None would raise.
            if self.max_batches is not None and batch_index >= self.max_batches:
                break

            batch_labels = np.asarray(batch_labels)
            class_ids = np.argmax(batch_labels, axis=-1)
            # All-zero label vectors carry no class (they are treated as
            # ignore pixels elsewhere in this file); argmax would report
            # them as class 0, so they are left out of the counts.
            labelled = np.any(batch_labels != 0, axis=-1)
            unique, counts = np.unique(class_ids[labelled], return_counts=True)

            for u, c in zip(unique, counts):
                batch_class_counts[int(u)] += int(c)
                self.cumulative_class_counts[int(u)] += int(c)

            # Count real images so a short final batch is not over-counted.
            total_images += len(batch_images)

        total_pixels = sum(batch_class_counts.values())

        print(f"📊 {self.name} Distribution After Epoch {epoch + 1} ({total_images:,} images):")
        for cls in sorted(batch_class_counts):
            count = batch_class_counts[cls]
            percent = 100.0 * count / total_pixels
            print(f"  Class {cls} ({self._class_label(cls)}): {count:,} px ({percent:.2f}%)")

    def on_train_end(self, logs=None):
        total_pixels = sum(self.cumulative_class_counts.values())
        print("📊 Final Cumulative Training Class Distribution:")
        print(f"Total pixels: {total_pixels:,} px")
        for cls in sorted(self.cumulative_class_counts):
            count = self.cumulative_class_counts[cls]
            percent = 100.0 * count / total_pixels
            print(f"  Class {cls} ({self._class_label(cls)}): {count:,} px ({percent:.2f}%)")

        plot_class_distribution(self.cumulative_class_counts)


class ValidationPredictionLogger(tf.keras.callbacks.Callback):
    """Show input, ground truth and prediction for validation batches.

    Every fifth epoch up to ``max_batches`` batches are read from
    ``val_gen``, predicted with ``user_model`` and displayed as one figure
    per batch with three columns.

    Args:
        val_gen: Iterable yielding ``(images, one_hot_labels)`` batches.
        user_model: Model whose ``predict`` is used for the visualisation.
        max_batches: Number of batches to display; ``None`` means all.
    """

    def __init__(self, val_gen, user_model, max_batches=1):
        super().__init__()
        self.val_gen = val_gen
        self.user_model = user_model
        self.max_batches = max_batches
        self.ignore_color = (255, 0, 255)
        self.class_to_color = {
            0: (230, 25, 75),    # Building
            1: (145, 30, 180),   # Clutter
            2: (60, 180, 75),    # Vegetation
            3: (245, 130, 48),   # Water
            4: (255, 255, 255),  # Background
            5: (0, 130, 200),    # Car
        }

    @staticmethod
    def _to_displayable(image):
        """Return ``(array, cmap)`` that ``imshow`` can render for one input."""
        # The original code scales by 255, i.e. inputs are taken to be in
        # [0, 1]; clipping avoids uint8 wrap-around for stray values.
        scaled = np.clip(image * 255, 0, 255).astype(np.uint8)
        if scaled.ndim == 3 and scaled.shape[-1] >= 3:
            # NOTE(review): assumes the first three channels are RGB and any
            # further channel (e.g. elevation) follows - confirm against the
            # data generator.
            return scaled[..., :3], None
        if scaled.ndim == 3:
            # 1- or 2-channel input: show the first channel in grayscale.
            return scaled[..., 0], "gray"
        return scaled, "gray"

    def on_epoch_end(self, epoch, logs=None):
        if (epoch + 1) % 5 != 0:
            return

        for batch_index, (batch_images, batch_labels) in enumerate(self.val_gen):
            if self.max_batches is not None and batch_index >= self.max_batches:
                break

            batch_images = np.asarray(batch_images)
            batch_labels = np.asarray(batch_labels)
            preds = self.user_model.predict(batch_images, verbose=0)
            preds_argmax = np.argmax(preds, axis=-1)
            true_argmax = np.argmax(batch_labels, axis=-1)

            n_rows = len(batch_images)
            # squeeze=False keeps `axs` two-dimensional for a batch of one.
            fig, axs = plt.subplots(n_rows, 3, figsize=(10, 3 * n_rows), squeeze=False)
            for i in range(n_rows):
                shown, cmap = self._to_displayable(batch_images[i])
                axs[i, 0].imshow(shown, cmap=cmap)
                axs[i, 0].set_title("Input")
                axs[i, 0].axis('off')

                h, w = true_argmax[i].shape
                true_rgb = np.zeros((h, w, 3), dtype=np.uint8)
                pred_rgb = np.zeros((h, w, 3), dtype=np.uint8)

                # Decode classes
                for cid, col in self.class_to_color.items():
                    true_rgb[true_argmax[i] == cid] = col
                    pred_rgb[preds_argmax[i] == cid] = col

                # 🟣 Highlight ignore pixels (zero-vectors in one-hot)
                ignore_mask = np.all(batch_labels[i] == 0, axis=-1)
                true_rgb[ignore_mask] = self.ignore_color
                pred_rgb[ignore_mask] = self.ignore_color

                axs[i, 1].imshow(true_rgb)
                axs[i, 1].set_title("Ground Truth")
                axs[i, 1].axis('off')

                axs[i, 2].imshow(pred_rgb)
                axs[i, 2].set_title("Prediction")
                axs[i, 2].axis('off')

            plt.tight_layout()
            plt.suptitle(f"Validation Predictions After Epoch {epoch + 1}", y=1.02)
            plt.show()
            # Release the figure so figures do not pile up over the epochs.
            plt.close(fig)


class StepTimer(tf.keras.callbacks.Callback):
    """Measure and report the average duration of a training step."""

    def on_train_begin(self, logs=None):
        self.total_time = 0.0
        self.total_steps = 0

    def on_train_batch_begin(self, batch, logs=None):
        self.start_time = tf.timestamp()

    def on_train_batch_end(self, batch, logs=None):
        # Converted to a Python float so the running total stays a plain number.
        self.total_time += float(tf.timestamp() - self.start_time)
        self.total_steps += 1

    def on_train_end(self, logs=None):
        if self.total_steps == 0:
            # Training ended before any step ran; an average is undefined.
            print("🕒 Average training step time: n/a (no training steps ran)")
            return
        avg_step_time = self.total_time / self.total_steps
        print(f"🕒 Average training step time: {avg_step_time:.4f} sec")



# Create both output folders
os.makedirs("checkpoints", exist_ok=True)
os.makedirs("/content/drive/MyDrive/segmentation_checkpoints", exist_ok=True)

# Both checkpoints track the validation value of the MaskedMeanIoU metric,
# whose name is "masked_mean_iou". mode='max' is required because the
# default 'auto' mode would treat this metric as something to minimise.

# Checkpoint in Colab workspace
checkpoint_local = ModelCheckpoint(
    "checkpoints/best_model.h5",
    monitor='val_masked_mean_iou',
    mode='max',
    save_best_only=True
)

# Checkpoint in Google Drive
checkpoint_drive = ModelCheckpoint(
    "/content/drive/MyDrive/segmentation_checkpoints/best_model.h5",
    monitor='val_masked_mean_iou',
    mode='max',
    save_best_only=True
)



# --- Model Building ---
#from segmentation_models import Unet, MultiUnet, UnetAux, Segformer, CRF

def train_model(base_dir="/content/chipped_data/content/chipped_data", out_dir="/content/figs", 
                freeze=False, model_path=None,
                input_type="rgb_elev", model_type="unet", tile_size=256,
                batch_size=8, steps=None, epochs=10, train_time=20, verbose=1
                ):
    """Build or load a segmentation model, train it and evaluate it.

    Args:
        base_dir: Dataset root. Training tiles are read from ``train/`` and
            validation/test tiles from ``raw/`` (each with ``images``,
            ``elevations`` and ``labels`` sub-folders).
        out_dir: Directory receiving the training-curve plots.
        freeze: When true, only the last 10 layers are trainable and Adam
            with a fixed rate of 1e-3 is used.
        model_path: Saved model to continue from. A new model is built when
            this is ``None`` or the path does not exist.
        input_type: Key of ``INPUT_TYPE_CONFIG``; determines the channel count.
        model_type: One of ``"unet"``, ``"multi_unet"``, ``"unet_aux"``,
            ``"segformer"`` or ``"resnet34"``.
        tile_size: Used as ``d_model`` of the learning rate schedule.
        batch_size: Training batch size; evaluation uses half of it.
        steps: Forwarded to the training generator. When given it also
            limits the batches used by the distribution logger and by the
            inference timing.
        epochs: Maximum number of epochs.
        train_time: Training time budget in minutes.
        verbose: Verbosity passed to ``model.fit``.

    Raises:
        AssertionError: If ``input_type`` is unknown.
        ValueError: If ``model_type`` is unknown.
    """
    assert input_type in INPUT_TYPE_CONFIG, f"Unknown input type: {input_type}"
    num_channels = INPUT_TYPE_CONFIG[input_type]["channels"]

    train_images = os.path.join(base_dir, "train", "images")
    train_elev = os.path.join(base_dir, "train", "elevations")
    train_labels = os.path.join(base_dir, "train", "labels")

    # Validation and test tiles share the "raw" folders; the split is
    # selected through the metadata frame given to each generator.
    eval_images = os.path.join(base_dir, "raw", "images")
    eval_elev = os.path.join(base_dir, "raw", "elevations")
    eval_labels = os.path.join(base_dir, "raw", "labels")

    # Load metadata
    train_df = csv_to_df('train')
    val_df = csv_to_df('val')
    test_df = csv_to_df('test')

    # Compute steps (ceiling division). The evaluation generators use the
    # same batch size the step counts are derived from; max(1, ...) keeps
    # batch_size=1 from producing a zero batch size.
    eval_batch_size = max(1, batch_size // 2)
    val_steps = -(-len(val_df) // eval_batch_size)
    test_steps = -(-len(test_df) // eval_batch_size)

    # --- Streaming Data Generator ---
    train_gen = StreamingDataGenerator(
        train_images, train_elev, train_labels,
        split='train', df=train_df,
        batch_size=batch_size,
        input_type=input_type,
        shuffle=True,
        steps=steps,
        fixed=False,
        augment=True,
    )

    val_gen = StreamingDataGenerator(eval_images, eval_elev, eval_labels,
                                     split='val', df=val_df,
                                     batch_size=eval_batch_size, steps=val_steps,
                                     input_type=input_type, shuffle=False, fixed=True, augment=False)

    test_gen = StreamingDataGenerator(eval_images, eval_elev, eval_labels,
                                      split='test', df=test_df,
                                      batch_size=eval_batch_size, steps=test_steps,
                                      input_type=input_type, shuffle=False, fixed=True, augment=False)

    #input_shape = (tile_size, tile_size, num_channels)
    input_shape = (None, None, num_channels)

    # --- Model ---
    # NOTE: segmentation_models is used through the module-level `sm`. A
    # function-local `import ... as sm` would make `sm` a local name and
    # leave it unbound on every branch that does not execute the import.
    if model_path is None or not os.path.exists(model_path):
        if model_type == "unet":
            print("🧪 Calling build_unet...")
            model = build_unet(input_shape=input_shape, num_classes=NUM_CLASSES)

        elif model_type == "multi_unet":
            print("🧪 Calling build_multi_unet...")
            model = build_multi_unet(input_shape=input_shape, num_classes=NUM_CLASSES)

        elif model_type == "unet_aux":
            print("🧪 Calling build_unet_aux...")
            model = build_unet_aux(input_shape=input_shape, num_classes=NUM_CLASSES)

        elif model_type == "segformer":
            model = SegFormer_B2(input_shape=input_shape, num_classes=NUM_CLASSES)

        elif model_type == "resnet34":
            # ImageNet encoder weights are defined for 3-channel input only;
            # other channel counts have to start from random weights.
            encoder_weights = 'imagenet' if num_channels == 3 else None
            model = sm.Unet(
                backbone_name="resnet34",          # or 'efficientnetb0', 'mobilenetv2', etc.
                input_shape=input_shape,
                classes=NUM_CLASSES,
                activation='softmax',
                encoder_weights=encoder_weights
            )
        else:
            raise ValueError(f"Unknown model_type: {model_type}")

    else:
        print(f"Loading Stage 1 model from {model_path} and unfreezing all layers...")

        custom_objects={
            'DiceLoss': sm.losses.DiceLoss,
            'CategoricalFocalLoss': sm.losses.CategoricalFocalLoss,
            'MaskedMeanIoU': MaskedMeanIoU
        }

        # compile=False: the model is compiled again below, and restoring
        # the saved compile state would additionally need the custom loss
        # function and learning rate schedule as custom objects.
        model = tf.keras.models.load_model(
            model_path,
            custom_objects=custom_objects,
            compile=False
        )

        for layer in model.layers:
            layer.trainable = True

    lr_schedule = TransformerLRSchedule(d_model=tile_size)
    optimizer = Adam(learning_rate=lr_schedule, beta_1=0.9, beta_2=0.98, epsilon=1e-9)
    miou_metric = MaskedMeanIoU(num_classes=NUM_CLASSES)
    metrics=[
        miou_metric,
        sm.metrics.IOUScore(threshold=None),
        sm.metrics.FScore(threshold=None),
        'categorical_accuracy',
        'accuracy'
    ]

    if freeze:
        print("Stage 1: Freezing all layers except head...")
        for layer in model.layers:
            layer.trainable = False
        for layer in model.layers[-10:]:
            layer.trainable = True

        optimizer = Adam(learning_rate=1e-3)

    # --- Compile Model ---
    model.compile(
        optimizer=optimizer,
        loss=total_loss_with_smoothing,
        metrics=metrics
    )

    model.summary()

    # --- Callbacks ---
    time_limit = TimeLimitCallback(max_minutes=train_time)
    # The monitored key is "val_" + the name of the MaskedMeanIoU metric.
    early_stop = EarlyStopping(monitor='val_masked_mean_iou', patience=12, restore_best_weights=True, mode='max')
    nan_terminate = TerminateOnNaN()

    # With steps=None the logger keeps its own default batch limit.
    distribution_kwargs = {} if steps is None else {"max_batches": steps}

    callbacks = [
        checkpoint_local, checkpoint_drive,
        early_stop, nan_terminate, time_limit,
        ClearMemory(),
        LearningRateLogger(),
        StepTimer(),
        DistributionLogger(train_gen, name="Training", **distribution_kwargs),
        ValidationPredictionLogger(val_gen, model, max_batches=1)
    ]

    # --- Train Model ---
    print("🚀 Starting training...")

    history = model.fit(
        train_gen, validation_data=val_gen,
        epochs=epochs,
        callbacks=callbacks,
        verbose=verbose
    )

    from datetime import datetime
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # NOTE(review): the file name is fixed and does not reflect model_type
    # or the freeze stage - confirm whether that is intended.
    model_name = f"unet_resnet34_stage1_{timestamp}"

    os.makedirs("/content/checkpoints", exist_ok=True)
    model.save(f"/content/checkpoints/{model_name}.keras")
    print(f"model saved to /content/checkpoints/{model_name}.keras")

    # --- Evaluate Model ---
    plot_training_curves(history, out_dir)
    evaluate_on_test(model, test_gen, n_vis=9)
    if steps is None:
        # Fall back to the helper's default batch count.
        measure_inference_time(model, test_gen)
    else:
        measure_inference_time(model, test_gen, num_batches=steps)
    print("🚀 Training complete!")





def _paste_tile(canvas, tile_path, y_offset, x_offset, tile_size):
    """Copy the image at ``tile_path`` into ``canvas`` as RGB, if it exists.

    Returns True when a tile was pasted and False when the file is absent.
    Raises ValueError when the file exists but cannot be decoded.
    """
    import os

    if not os.path.exists(tile_path):
        return False

    import cv2

    bgr = cv2.imread(tile_path)
    if bgr is None:
        # cv2.imread signals an unreadable file by returning None.
        raise ValueError(f"could not decode image {tile_path}")
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    # Tiles may be smaller or larger than tile_size; paste the part that
    # fits into the tile's slot.
    height = min(tile_size, rgb.shape[0])
    width = min(tile_size, rgb.shape[1])
    canvas[y_offset:y_offset + height, x_offset:x_offset + width] = rgb[:height, :width]
    return True


def reconstruct_prediction_canvas(df, tile_size, image_dir, label_dir, pred_dir):
    """Stitch image, label and prediction tiles back into full canvases.

    Args:
        df: Table with the columns ``tile_id``, ``x`` and ``y``; ``x``/``y``
            are used as the pixel offset of each tile.
        tile_size: Edge length of one tile in pixels.
        image_dir: Folder with ``<tile_id>-ortho.png`` files.
        label_dir: Folder with ``<tile_id>-label.png`` files.
        pred_dir: Folder with ``<tile_id>.png`` files.

    Returns:
        ``(img_canvas, label_canvas, pred_canvas)``: three uint8 RGB arrays
        of identical shape. Areas without a tile stay magenta (255, 0, 255).

    Raises:
        ValueError: If ``df`` contains no rows.
    """
    import numpy as np
    import os

    if len(df) == 0:
        raise ValueError("Cannot reconstruct a canvas from an empty tile table")

    # Determine canvas size from tile coordinates
    x_coords = df['x'].values
    y_coords = df['y'].values
    min_x = int(x_coords.min())
    min_y = int(y_coords.min())
    max_x = int(x_coords.max()) + tile_size
    max_y = int(y_coords.max()) + tile_size

    canvas_shape = (max_y - min_y, max_x - min_x, 3)
    magenta = (255, 0, 255)  # Magenta default
    img_canvas = np.full(canvas_shape, magenta, dtype=np.uint8)
    label_canvas = np.full(canvas_shape, magenta, dtype=np.uint8)
    pred_canvas = np.full(canvas_shape, magenta, dtype=np.uint8)

    for _, row in df.iterrows():
        tile_id = row['tile_id']
        x_offset = int(row['x']) - min_x
        y_offset = int(row['y']) - min_y

        layers = (
            (img_canvas, os.path.join(image_dir, f"{tile_id}-ortho.png")),
            (label_canvas, os.path.join(label_dir, f"{tile_id}-label.png")),
            (pred_canvas, os.path.join(pred_dir, f"{tile_id}.png")),
        )

        # Each layer is handled on its own so that one broken file does not
        # prevent the remaining layers of the same tile from being pasted.
        for canvas, tile_path in layers:
            try:
                _paste_tile(canvas, tile_path, y_offset, x_offset, tile_size)
            except Exception as e:
                # Best effort: report the tile and carry on with the rest.
                print(f"⚠️ Failed to load tile {tile_id}: {e}")

    return img_canvas, label_canvas, pred_canvas








    # --- Archive ---

    '''
    class_weights = compute_class_weights(train_gen)
    loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=False)

    def weighted_loss(y_true, y_pred):
        weights = tf.reduce_sum(class_weights * y_true, axis=-1)
        return tf.reduce_mean(weights * loss_fn(y_true, y_pred))

    model.compile(optimizer='adam', loss=weighted_loss, metrics=['accuracy'])
    print("✅ Model compiled with class weights")
    '''

    '''
    from segmentation_models.metrics import iou_score as jaccard_coef
    metrics = ["accuracy", jaccard_coef]
    weights = [0.1666, 0.1666, 0.1666, 0.1666, 0.1666, 0.1666]
    
    dice_loss = dice_loss(class_weights = weights)
    focal_loss = categorical_focal_loss()
    total_loss = dice_loss + (1 * focal_loss)
    
    model.compile(optimizer="adam", loss="total_loss", metrics=metrics)
    '''

    '''
    def jaccard_coef(y_true, y_pred):
    y_true_flatten = K.flatten(y_true)
    y_pred_flatten = K.flatten(y_pred)
    intersection = K.sum(y_true_flatten * y_pred_flatten)
    final_coef_value = (intersection + 1.0) / (K.sum(y_true_flatten) + K.sum(y_pred_flatten) - intersection + 1.0)
    return final_coef_value
    '''