In [None]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, TerminateOnNaN
import cv2
import matplotlib.pyplot as plt
import itertools
import gc
import tensorflow.keras.backend as K
from scoring import evaluate_predictions

# --- Sanity ---
def test_training_sanity():
    """Print a fixed marker line naming training.ipynb as its source."""
    banner = "✅ from training.ipynb"
    print(banner)
    
class ClearMemory(tf.keras.callbacks.Callback):
    """Keras callback that runs Python garbage collection during training.

    Only ``gc.collect()`` is called. ``K.clear_session()`` is deliberately
    NOT called here: it resets Keras' global state and is intended for use
    between building models, not while ``model.fit`` is running on the model
    being trained.

    Args:
        every_n_batches: run ``gc.collect()`` after every N-th training batch
            of an epoch. The default of 1 collects after every batch; larger
            values reduce the overhead of a full collection on every step.

    Raises:
        ValueError: if ``every_n_batches`` is smaller than 1.
    """

    def __init__(self, every_n_batches=1):
        super().__init__()
        if every_n_batches < 1:
            raise ValueError(f"every_n_batches must be >= 1, got {every_n_batches}")
        self.every_n_batches = int(every_n_batches)

    def on_train_batch_end(self, batch, logs=None):
        # `batch` is the 0-based batch index within the current epoch.
        if (batch + 1) % self.every_n_batches == 0:
            gc.collect()

# --- Visualisation ---
def visualise_prediction(rgb, true_mask, pred_mask):
    """Plot an input image beside its ground-truth and predicted masks.

    Draws a single 1x3 figure (image, true mask, predicted mask) with axes
    hidden, then shows it. Both masks use the ``tab10`` colormap with a fixed
    0-5 value range, so a given class id gets the same colour in both panels.
    """
    mask_style = {"cmap": 'tab10', "vmin": 0, "vmax": 5}
    panels = (
        ("RGB Image", rgb, {}),
        ("True Mask", true_mask, mask_style),
        ("Predicted Mask", pred_mask, mask_style),
    )

    _, axes = plt.subplots(1, 3, figsize=(16, 5))
    for ax, (title, image, imshow_kwargs) in zip(axes, panels):
        ax.imshow(image, **imshow_kwargs)
        ax.set_title(title)
        ax.axis("off")

    plt.tight_layout()
    plt.show()

# --- Input config ---
# Supported model-input layouts: name -> human description and channel count.
INPUT_TYPE_CONFIG = {
    name: {"description": description, "channels": channels}
    for name, description, channels in (
        ("1ch", "grayscale only", 1),
        ("2ch", "grayscale + elevation", 2),
        ("rgb", "RGB only", 3),
        ("rgb_elevation", "RGB + elevation", 4),
    )
}

# Label-image colours in class-id order: the position in this tuple is the id.
_CLASS_COLOURS = (
    (230, 25, 75),    # class 0
    (145, 30, 180),   # class 1
    (60, 180, 75),    # class 2
    (245, 130, 48),   # class 3
    (255, 255, 255),  # class 4
    (0, 130, 200),    # class 5
)

# (r, g, b) -> class id.
COLOR_TO_CLASS = {colour: class_id for class_id, colour in enumerate(_CLASS_COLOURS)}

# (num_classes, 3) uint8 array; row i is the colour of class i.
COLOR_PALETTE = np.array(_CLASS_COLOURS, dtype=np.uint8)
# Independent copy of the colour -> class mapping used for decoding labels.
COLOR_LOOKUP = dict(COLOR_TO_CLASS)

def decode_label_image(label_img, color_lookup=None):
    """Convert a colour-coded label image into a 2-D map of class ids.

    Every pixel must exactly equal one of the colour keys in ``color_lookup``.
    Matching is vectorised (one NumPy comparison per class) instead of a
    per-pixel Python loop.

    Args:
        label_img: array of shape (H, W, C). Only keys whose length equals C
            can match, so with the default 3-tuple RGB keys any C != 3 is
            reported as an unknown colour at (0, 0) (for non-empty images).
        color_lookup: mapping from colour tuples to integer class ids.
            Defaults to the module-level ``COLOR_LOOKUP``.

    Returns:
        np.ndarray of shape (H, W), dtype uint8, with each pixel's class id.

    Raises:
        ValueError: if any pixel's colour is not a key of ``color_lookup``.
            The message names the first such pixel in row-major order.
    """
    if color_lookup is None:
        color_lookup = COLOR_LOOKUP

    label_img = np.asarray(label_img)
    h, w, channels = label_img.shape
    label_map = np.zeros((h, w), dtype=np.uint8)
    known = np.zeros((h, w), dtype=bool)

    for colour, class_id in color_lookup.items():
        colour_arr = np.asarray(colour)
        if colour_arr.shape != (channels,):
            # A key of a different length can never equal a pixel tuple.
            continue
        hits = np.all(label_img == colour_arr, axis=-1)
        label_map[hits] = class_id
        known |= hits

    if not known.all():
        # argwhere returns indices in row-major order, so [0] is the same
        # pixel a y-then-x scan would have hit first.
        y, x = (int(i) for i in np.argwhere(~known)[0])
        pixel = tuple(label_img[y, x])
        raise ValueError(f"❌ Unknown label colour {pixel} at ({y}, {x})")

    return label_map

def _build_training_callbacks(max_minutes=60):
    """Create the Keras callbacks used by ``train_model``.

    Args:
        max_minutes: wall-clock budget; training stops at the end of the
            first epoch that finishes after this many minutes.

    Returns:
        list: checkpointing on best ``val_accuracy`` to
        ``checkpoints/best_model.h5``, early stopping and LR reduction on
        ``val_loss``, NaN termination, the time limit and ``ClearMemory``.
    """
    os.makedirs("checkpoints", exist_ok=True)
    # NOTE(review): the checkpoint selects on val_accuracy while EarlyStopping
    # (restore_best_weights=True) selects on val_loss, so the saved .h5 and the
    # weights left in memory can come from different epochs. Confirm intended.
    checkpoint = ModelCheckpoint("checkpoints/best_model.h5", monitor='val_accuracy', save_best_only=True)
    early_stop = EarlyStopping(monitor='val_loss', patience=12, restore_best_weights=True)
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=6, verbose=1)
    nan_terminate = TerminateOnNaN()

    class TimeLimitCallback(tf.keras.callbacks.Callback):
        """Stop training once elapsed time exceeds ``max_minutes``.

        The check runs at epoch end, so the final epoch may overrun the budget.
        """

        def __init__(self, max_minutes=20):
            super().__init__()
            self.max_duration = max_minutes * 60  # seconds

        def on_train_begin(self, logs=None):
            self.start_time = tf.timestamp()

        def on_epoch_end(self, epoch, logs=None):
            elapsed = tf.timestamp() - self.start_time
            if elapsed > self.max_duration:
                print(f"⏱️ Training time exceeded {self.max_duration // 60} minutes. Stopping early.")
                self.model.stop_training = True

    time_limit = TimeLimitCallback(max_minutes=max_minutes)
    return [checkpoint, early_stop, reduce_lr, nan_terminate, time_limit, ClearMemory()]


def _tile_to_display_rgb(tile):
    """Turn one model-input tile into a uint8 RGB image for plotting.

    Tiles with at least three channels use channels 0-2 as RGB. Tiles with
    fewer channels (the "1ch"/"2ch" input types) show channel 0 replicated
    to gray RGB, because ``imshow`` rejects an (H, W, 2) array.

    NOTE(review): assumes channel 0 of 1ch/2ch tiles is the grayscale band and
    that values are scaled to [0, 1]. Confirm against StreamingDataGenerator.
    Values are clipped so out-of-range inputs saturate instead of wrapping
    around in the uint8 cast.
    """
    tile = np.asarray(tile)
    if tile.shape[-1] >= 3:
        rgb = tile[:, :, :3]
    else:
        rgb = np.repeat(tile[:, :, :1], 3, axis=-1)
    return np.clip(rgb * 255, 0, 255).astype(np.uint8)


def _evaluate_model(model, val_gen, max_visualisations=5):
    """Predict every batch of ``val_gen``, plot a few examples and score them.

    The first tile of each of the first ``max_visualisations`` batches is
    plotted. Predicted and true masks are argmax'ed over the last axis,
    flattened, concatenated and passed to ``evaluate_predictions``.

    Raises:
        ValueError: if ``val_gen`` yields no batches.
    """
    print("🧪 Running evaluation on full validation set...")
    # NOTE(review): val_gen is built with steps_per_epoch=5, so "full" may be
    # only 5 batches. This loop also ends only if iterating val_gen is finite.
    # Confirm both against StreamingDataGenerator.
    all_preds = []
    all_trues = []
    shown = 0

    for val_imgs, val_lbls in val_gen:
        # predict_on_batch runs one batch directly, without predict()'s
        # per-call setup and progress bar, which add up inside a Python loop.
        pred = model.predict_on_batch(val_imgs)
        pred_mask = np.argmax(pred, axis=-1).astype(np.uint8)
        true_mask = np.argmax(val_lbls, axis=-1).astype(np.uint8)

        all_preds.append(pred_mask.ravel())
        all_trues.append(true_mask.ravel())

        if shown < max_visualisations:
            visualise_prediction(_tile_to_display_rgb(val_imgs[0]), true_mask[0], pred_mask[0])
            shown += 1

    print("\n📊 Evaluation Results:")
    if not all_preds:
        raise ValueError("Validation generator yielded no batches; nothing to evaluate.")
    y_pred = np.concatenate(all_preds)
    y_true = np.concatenate(all_trues)

    evaluate_predictions(y_pred, y_true)


def train_model(input_type="rgb_elevation", model_type="unet", batch_size=8, epochs=10, tile_size=512, steps_per_epoch=None, verbose=1):
    """Train a segmentation model on the chipped tiles, then evaluate it.

    Reads train/val tiles under ``/content/chipped_data/chipped`` through
    ``StreamingDataGenerator`` and builds a 6-class model. It fits the model
    with checkpointing, early stopping, LR reduction, NaN termination and a
    60-minute time limit. It then plots and scores predictions on the
    validation generator.

    Args:
        input_type: key of ``INPUT_TYPE_CONFIG``; selects the channel count.
        model_type: "unet" (supported) or "segformer" (not yet implemented).
        batch_size: batch size for both generators.
        epochs: maximum number of training epochs.
        tile_size: tile height/width used for the model input shape.
        steps_per_epoch: forwarded to the training generator.
        verbose: verbosity passed to ``model.fit``.

    Returns:
        tuple: ``(model, history)``, the trained model and the ``History``
        object returned by ``model.fit``.

    Raises:
        AssertionError: if ``input_type`` is not in ``INPUT_TYPE_CONFIG``.
        NotImplementedError: if ``model_type`` is "segformer".
        ValueError: for any other unknown ``model_type``.
    """
    assert input_type in INPUT_TYPE_CONFIG, f"Unknown input type: {input_type}"
    num_channels = INPUT_TYPE_CONFIG[input_type]["channels"]

    print(f"\n🔧 Training {model_type.upper()} with input type: {input_type} ({num_channels} channels)")
    print(f"🧪 Computed input shape: ({tile_size}, {tile_size}, {num_channels})")

    # --- Paths ---
    base_dir = "/content/chipped_data/chipped"
    train_images = os.path.join(base_dir, "train", "images")
    train_elev = os.path.join(base_dir, "train", "elevations")
    train_labels = os.path.join(base_dir, "train", "labels")

    val_images = os.path.join(base_dir, "val", "images")
    val_elev = os.path.join(base_dir, "val", "elevations")
    val_labels = os.path.join(base_dir, "val", "labels")

    # --- Generators ---
    train_gen = StreamingDataGenerator(train_images, train_elev, train_labels, batch_size=batch_size, input_type=input_type, shuffle=True, steps_per_epoch=steps_per_epoch)
    val_gen = StreamingDataGenerator(val_images, val_elev, val_labels, batch_size=batch_size, input_type=input_type, shuffle=False, steps_per_epoch=5, validation=True)

    # --- Model ---
    if model_type == "unet":
        print("🧪 Calling build_unet...")
        model = build_unet(input_shape=(tile_size, tile_size, num_channels), num_classes=6)
    elif model_type == "segformer":
        raise NotImplementedError("SegFormer support is coming soon.")
    else:
        raise ValueError(f"Unknown model_type: {model_type}")

    print(f"🧪 Final model input shape: {model.input_shape}")
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    print("✅ Model compiled")

    # --- Callbacks ---
    callbacks = _build_training_callbacks(max_minutes=60)

    # Fetch one batch and reuse it for both prints. Calling next(iter(...))
    # twice would load two batches, and an iterator-style generator would
    # also consume them.
    sample_inputs = next(iter(train_gen))[0]
    print("Shape of data from train_gen:", sample_inputs.shape)
    print("Type of data from train_gen:", type(sample_inputs))
    del sample_inputs

    print("🚀 Starting training...")
    history = model.fit(
        train_gen, validation_data=val_gen,
        epochs=epochs,
        callbacks=callbacks,
        verbose=verbose
    )

    # --- Evaluation ---
    _evaluate_model(model, val_gen)

    return model, history
