In [None]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, TerminateOnNaN
import cv2
import matplotlib.pyplot as plt
import itertools
import gc
import tensorflow.keras.backend as K

# --- Sanity ---
def test_training_sanity():
    """Print a marker showing the training notebook's definitions are loaded."""
    marker = "✅ from training.ipynb"
    print(marker)

# --- Visualisation ---
def visualise_prediction(rgb, true_mask, pred_mask):
    """Plot an input tile beside its ground-truth and predicted segmentation.

    Args:
        rgb: image passed directly to ``imshow``.
        true_mask: array of class indices; colourised by indexing the
            module-level ``COLOR_PALETTE``.
        pred_mask: array of class indices; colourised the same way.
    """
    _, axes = plt.subplots(1, 3, figsize=(16, 5))
    panels = [
        (rgb, "RGB Image"),
        (COLOR_PALETTE[true_mask], "True Mask"),
        (COLOR_PALETTE[pred_mask], "Predicted Mask"),
    ]
    for ax, (image, title) in zip(axes, panels):
        ax.imshow(image)
        ax.set_title(title)
        ax.axis("off")
    plt.tight_layout()
    plt.show()

# Input-type presets: ``input_type`` key -> description and number of input
# channels the model is built with.
INPUT_TYPE_CONFIG = {
    "1ch": {"description": "grayscale only", "channels": 1},
    "2ch": {"description": "grayscale + elevation", "channels": 2},
    "rgb": {"description": "RGB only", "channels": 3},
    "rgb_elevation": {"description": "RGB + elevation", "channels": 4},
}

# RGB colour used for each class in the label PNGs -> class index.
COLOR_TO_CLASS = {
    (230, 25, 75): 0,
    (145, 30, 180): 1,
    (60, 180, 75): 2,
    (245, 130, 48): 3,
    (255, 255, 255): 4,
    (0, 130, 200): 5,
}

CLASS_TO_COLOR = {v: k for k, v in COLOR_TO_CLASS.items()}
NUM_CLASSES = len(COLOR_TO_CLASS)

# Class indices double as one-hot positions and as row indices into
# COLOR_PALETTE, so they must be exactly 0..NUM_CLASSES-1 with no duplicates.
if sorted(CLASS_TO_COLOR) != list(range(NUM_CLASSES)):
    raise ValueError(
        f"COLOR_TO_CLASS must map to unique class indices 0..{NUM_CLASSES - 1}, "
        f"got {sorted(COLOR_TO_CLASS.values())}"
    )

# COLOR_PALETTE[class_idx] -> RGB colour (uint8), so ``COLOR_PALETTE[mask]``
# colourises a mask of class indices. Built by class index rather than dict
# order so it stays correct if COLOR_TO_CLASS entries are reordered.
COLOR_PALETTE = np.array([CLASS_TO_COLOR[i] for i in range(NUM_CLASSES)], dtype=np.uint8)

# Colour tuple -> class index, used by decode_label_image.
COLOR_LOOKUP = dict(COLOR_TO_CLASS)

class ClearMemory(tf.keras.callbacks.Callback):
    """Keras callback that runs Python garbage collection after each epoch.

    Collecting once per epoch reclaims host memory held by Python temporaries
    without paying ``gc.collect()`` on every batch. ``K.clear_session()`` is
    intentionally not called here: it resets Keras' global state and is meant
    for use between model builds, not while the current model is training.
    """

    def on_epoch_end(self, epoch, logs=None):
        gc.collect()

def decode_label_image(label_img, color_lookup=None):
    """Convert a colour-coded label image into a 2-D map of class indices.

    Args:
        label_img: 3-D array indexed as ``label_img[y, x]`` -> colour
            (height, width, channels).
        color_lookup: optional mapping from colour tuple to class index;
            defaults to the module-level ``COLOR_LOOKUP``.

    Returns:
        ``np.uint8`` array of shape (height, width) holding the class index
        of every pixel.

    Raises:
        ValueError: if ``label_img`` is not 3-D, or if any pixel colour is not
            in the lookup (the message names the first such pixel in
            row-major order).
    """
    if color_lookup is None:
        color_lookup = COLOR_LOOKUP
    label_img = np.asarray(label_img)
    h, w, channels = label_img.shape

    label_map = np.zeros((h, w), dtype=np.uint8)
    known = np.zeros((h, w), dtype=bool)
    for colour, class_idx in color_lookup.items():
        if len(colour) != channels:
            continue  # a colour of a different length can never match a pixel
        mask = np.all(label_img == np.asarray(colour), axis=-1)
        label_map[mask] = class_idx
        known |= mask

    if not known.all():
        # argwhere is row-major, so this is the first unknown pixel a
        # y-then-x scan would hit.
        y, x = (int(v) for v in np.argwhere(~known)[0])
        pixel = tuple(label_img[y, x].tolist())
        raise ValueError(f"❌ Unknown label colour {pixel} at ({y}, {x})")
    return label_map

def load_rgb_pair_batch(image_dir, label_dir, batch_size=4):
    """Load up to ``batch_size`` RGB tiles and their one-hot label masks.

    Candidate tiles are the non-empty ``*-ortho.png`` files in ``image_dir``,
    taken in sorted filename order; each tile's label is expected at
    ``<label_dir>/<base>-label.png``. Pairs that are missing, unreadable,
    shape-mismatched, or contain colours not in ``COLOR_TO_CLASS`` are
    reported and skipped, so fewer than ``batch_size`` pairs may be returned.

    Args:
        image_dir: directory holding ``*-ortho.png`` RGB tiles.
        label_dir: directory holding matching ``*-label.png`` colour masks.
        batch_size: maximum number of candidate tiles to consider.

    Returns:
        ``(x, y)``: ``x`` is float32 RGB scaled to [0, 1], stacked along a new
        leading axis; ``y`` is the float32 one-hot encoding of each label
        (``NUM_CLASSES`` values on the last axis). Both are empty arrays when
        nothing could be loaded.
    """
    suffix = '-ortho.png'
    # Sorted so the same tiles are chosen on every call (os.listdir order is
    # arbitrary); zero-byte files are dropped before truncating.
    filenames = sorted(
        f for f in os.listdir(image_dir)
        if f.endswith(suffix) and os.path.getsize(os.path.join(image_dir, f)) > 0
    )[:batch_size]

    if not filenames:
        print(f"⚠️ No valid image files found in {image_dir}.")
        return np.array([]), np.array([])

    x_batch, y_batch = [], []
    for f in filenames:
        base = f[:-len(suffix)]
        rgb_path = os.path.join(image_dir, f)
        label_path = os.path.join(label_dir, base + '-label.png')

        if not os.path.exists(rgb_path) or not os.path.exists(label_path):
            print(f"Skipping {base}: Missing files.")
            continue

        # cv2.imread returns None on failure; check before cvtColor, which
        # would raise on a None input instead of letting us skip the pair.
        img_bgr = cv2.imread(rgb_path)
        label_bgr = cv2.imread(label_path)
        if img_bgr is None or label_bgr is None or img_bgr.shape != label_bgr.shape:
            print(f"Skipping {base}: Failed to load or mismatched shape.")
            continue
        img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        label_rgb = cv2.cvtColor(label_bgr, cv2.COLOR_BGR2RGB)

        # int32 with a -1 sentinel so unmapped pixels are detectable (a uint8
        # map would wrap -1 around to 255).
        label = np.full(label_rgb.shape[:2], -1, dtype=np.int32)
        for color, idx in COLOR_TO_CLASS.items():
            label[np.all(label_rgb == color, axis=-1)] = idx

        if np.any(label == -1):
            print(f"⚠️ Unknown colours found in label at {label_path}. Skipping this image.")
            continue

        # Defensive: can only fire if COLOR_TO_CLASS maps a colour to an index
        # outside [0, NUM_CLASSES), which would not one-hot encode correctly.
        if np.any(label < 0) or np.any(label >= NUM_CLASSES):
            print(f"⚠️ Invalid label values found in label at {label_path} after mapping. Skipping this image.")
            continue

        label_onehot = tf.keras.utils.to_categorical(label, num_classes=NUM_CLASSES)

        print(f"Unique label indicies before one-hot for {base}: {np.unique(label)}")

        x_batch.append(img.astype(np.float32) / 255.0)
        y_batch.append(label_onehot.astype(np.float32))

    if not x_batch:
        print("⚠️ No valid data loaded for the batch.")
        return np.array([]), np.array([])

    try:
        # np.stack raises ValueError on any shape mismatch regardless of NumPy
        # version (np.array only does so on newer releases).
        return np.stack(x_batch), np.stack(y_batch)
    except ValueError as e:
        print(f"Error creating numpy arrays: {e}")
        print(f"Shapes of loaded images: {[x.shape for x in x_batch]}")
        print(f"Shapes of loaded labels: {[y.shape for y in y_batch]}")
        return np.array([]), np.array([])



def _weighted_categorical_crossentropy(class_weights):
    """Build a loss weighting each pixel's cross-entropy by its true class.

    ``class_weights`` is multiplied against the one-hot last axis of
    ``y_true``, so it is presumably a length-NUM_CLASSES sequence/array —
    TODO confirm against compute_class_weights.
    """
    weights_per_class = tf.cast(tf.convert_to_tensor(class_weights), tf.float32)

    def weighted_loss(y_true, y_pred):
        # Weight of each pixel = weight of its ground-truth class.
        pixel_weights = tf.reduce_sum(weights_per_class * y_true, axis=-1)
        # Unreduced per-pixel cross-entropy (one value per one-hot vector).
        # A default-reduced loss object returns a single scalar, which would
        # collapse the weighting to mean(weights) * mean(CE).
        pixel_ce = tf.keras.losses.categorical_crossentropy(y_true, y_pred, from_logits=False)
        return tf.reduce_mean(pixel_weights * pixel_ce)

    return weighted_loss


class _TimeLimitCallback(tf.keras.callbacks.Callback):
    """Request a stop at the first epoch end after ``max_minutes`` of training."""

    def __init__(self, max_minutes=20):
        super().__init__()
        self.max_duration = max_minutes * 60  # seconds
        self.start_time = None

    def on_train_begin(self, logs=None):
        self.start_time = tf.timestamp()

    def on_epoch_end(self, epoch, logs=None):
        elapsed = tf.timestamp() - self.start_time
        if elapsed > self.max_duration:
            print(f"⏱️ Training time exceeded {self.max_duration // 60} minutes. Stopping early.")
            self.model.stop_training = True


def _evaluate_on_validation(model, val_x, val_y, batch_size, max_shown=5):
    """Predict every validation tile, plot up to ``max_shown`` of them, and
    return flattened ``(predicted, true)`` uint8 class-index arrays."""
    pred = model.predict(val_x, batch_size=batch_size)
    pred_mask = np.argmax(pred, axis=-1).astype(np.uint8)
    true_mask = np.argmax(val_y, axis=-1).astype(np.uint8)

    for i in range(min(max_shown, len(val_x))):
        # Inputs were scaled to [0, 1]; the first three channels are RGB.
        rgb_tile = (val_x[i][:, :, :3] * 255).astype(np.uint8)
        visualise_prediction(rgb_tile, true_mask[i], pred_mask[i])

    return pred_mask.reshape(-1), true_mask.reshape(-1)


def train_model(input_type="rgb_elevation", model_type="unet", batch_size=8, epochs=10, tile_size=512, steps_per_epoch=None, verbose=1):
    """Train a segmentation model on the chipped tiles and report results.

    Data currently comes from ``load_rgb_pair_batch``, which loads up to 16
    RGB tiles per split into memory. Elevation inputs are not wired up yet,
    so ``input_type`` must resolve to the loader's channel count (3, i.e.
    ``"rgb"``). ``steps_per_epoch`` is accepted for interface compatibility
    but is unused by this in-memory path.

    Args:
        input_type: key of ``INPUT_TYPE_CONFIG`` selecting the channel count.
        model_type: ``"unet"``; ``"segformer"`` is not implemented yet.
        batch_size: mini-batch size for training and prediction.
        epochs: maximum number of training epochs.
        tile_size: spatial size used for the model's input shape.
        steps_per_epoch: unused (see above).
        verbose: forwarded to ``model.fit``.

    Returns:
        The Keras ``History`` object from ``model.fit``.

    Raises:
        AssertionError: ``input_type`` is not in ``INPUT_TYPE_CONFIG``.
        RuntimeError: no valid training or validation tiles could be loaded.
        ValueError: the loaded tiles' channel count does not match
            ``input_type``, or ``model_type`` is unknown.
        NotImplementedError: ``model_type == "segformer"``.
    """
    assert input_type in INPUT_TYPE_CONFIG, f"Unknown input type: {input_type}"
    num_channels = INPUT_TYPE_CONFIG[input_type]["channels"]

    print(f"\n🔧 Training {model_type.upper()} with input type: {input_type} ({num_channels} channels)")
    print(f"🧪 Computed input shape: ({tile_size}, {tile_size}, {num_channels})")

    base_dir = "/content/chipped_data/chipped"
    train_images = os.path.join(base_dir, "train", "images")
    train_labels = os.path.join(base_dir, "train", "labels")
    val_images = os.path.join(base_dir, "val", "images")
    val_labels = os.path.join(base_dir, "val", "labels")

    # --- Data ---
    # TODO: switch to StreamingDataGenerator (which also reads the
    # train/val "elevations" dirs) once it is available in this notebook.
    train_data = load_rgb_pair_batch(train_images, train_labels, batch_size=16)
    val_data = load_rgb_pair_batch(val_images, val_labels, batch_size=16)
    train_x, train_y = train_data
    val_x, val_y = val_data

    if train_x.size == 0 or val_x.size == 0:
        raise RuntimeError(f"No valid training/validation tiles could be loaded from {base_dir}.")
    if train_x.shape[-1] != num_channels:
        raise ValueError(
            f"input_type {input_type!r} expects {num_channels} channels, but the loaded "
            f"tiles have {train_x.shape[-1]}; the current loader only supports 'rgb'."
        )

    # --- Model ---
    if model_type == "unet":
        print("🧪 Calling build_unet...")
        model = build_unet(input_shape=(tile_size, tile_size, num_channels), num_classes=NUM_CLASSES)
    elif model_type == "segformer":
        raise NotImplementedError("SegFormer support is coming soon.")
    else:
        raise ValueError(f"Unknown model_type: {model_type}")

    print(f"🧪 Final model input shape: {model.input_shape}")

    class_weights = compute_class_weights(train_data)
    model.compile(optimizer='adam', loss=_weighted_categorical_crossentropy(class_weights), metrics=['accuracy'])
    print("✅ Model compiled with class weights")

    os.makedirs("checkpoints", exist_ok=True)
    callbacks = [
        ModelCheckpoint("checkpoints/best_model.h5", monitor='val_accuracy', save_best_only=True),
        EarlyStopping(monitor='val_loss', patience=12, restore_best_weights=True),
        ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=6, verbose=1),
        TerminateOnNaN(),
        _TimeLimitCallback(max_minutes=120),
        ClearMemory(),
    ]

    print("Training data shapes:", train_x.shape, train_y.shape)

    print("🚀 Starting training...")
    history = model.fit(
        train_x, train_y,
        batch_size=batch_size,
        validation_data=(val_x, val_y),
        epochs=epochs,
        callbacks=callbacks,
        verbose=verbose,
    )

    # Metrics cover every validation tile; only the plots are capped.
    print("🧪 Running evaluation on full validation set...")
    y_pred, y_true = _evaluate_on_validation(model, val_x, val_y, batch_size)

    print("\n📊 Evaluation Results:")
    evaluate_predictions(y_pred, y_true)

    print("\n📊 Class distribution in training set:")
    plot_class_distribution(train_data, title="Training Class Distribution")

    print("\n📊 Class distribution in validation set:")
    plot_class_distribution(val_data, title="Validation Class Distribution")

    return history
