In [None]:
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import os
from tqdm import trange

# Directory that evaluation figures (e.g. the confusion matrix) are saved to.
out_dir = "/content/figs"

# Inverse of COLOR_TO_CLASS (class id -> colour), restricted to ids below 6
# so that the ignore colour is left out.
CLASS_TO_COLOR = {
    class_id: color
    for color, class_id in COLOR_TO_CLASS.items()
    if class_id < 6
}


def visualise_prediction(rgb, true_mask_onehot, pred_mask):
    """Plot the input tile, its ground-truth mask and the predicted mask side by side.

    Args:
        rgb: Image handed straight to ``imshow`` for the "Input" panel.
        true_mask_onehot: Ground-truth mask whose last axis is reduced with
            ``argmax``; the result must be 2-D (height, width). Pixels whose
            last-axis vector is entirely zero are drawn in magenta.
        pred_mask: Array of predicted class ids, compared element-wise against
            the ids in ``CLASS_TO_COLOR`` (assumed to be (height, width) —
            TODO confirm with the caller).

    The figure is displayed with ``plt.show()`` and closed afterwards;
    nothing is returned.
    """
    gt_ids = np.argmax(true_mask_onehot, axis=-1)
    height, width = gt_ids.shape

    gt_colored = np.zeros((height, width, 3), dtype=np.uint8)
    pred_colored = np.zeros((height, width, 3), dtype=np.uint8)

    # Paint every known class with its colour in both masks.
    for cls, rgb_value in CLASS_TO_COLOR.items():
        gt_colored[gt_ids == cls] = rgb_value
        pred_colored[pred_mask == cls] = rgb_value

    # Pixels without any ground-truth class are shown in magenta in both
    # panels, overriding whatever colour was painted above.
    unlabeled = np.all(true_mask_onehot == 0, axis=-1)
    for colored in (gt_colored, pred_colored):
        colored[unlabeled] = (255, 0, 255)

    fig, axs = plt.subplots(1, 3, figsize=(14, 4))
    panels = (
        (rgb, "Input"),
        (gt_colored, "Ground Truth"),
        (pred_colored, "Prediction"),
    )
    for ax, (picture, title) in zip(axs, panels):
        ax.imshow(picture)
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()
    plt.close(fig)



def evaluate_predictions(pred_mask, true_mask, num_classes=6, ignore_label=6):
    """Print a classification report, confusion matrix and per-class IoU.

    Both inputs are flattened, pixels whose ground-truth label equals
    ``ignore_label`` are dropped, and the remaining pixels are scored for the
    labels ``0 .. num_classes - 1``. A row-normalised confusion matrix is
    drawn as a heatmap, saved to ``<out_dir>/conf_matrix.png`` and shown.

    Args:
        pred_mask: Array-like of predicted class ids.
        true_mask: Array-like of ground-truth class ids with the same number
            of elements as ``pred_mask``.
        num_classes: Number of classes to report on.
        ignore_label: Ground-truth label that marks pixels to exclude.

    Raises:
        ValueError: If the flattened inputs differ in length.
    """
    print("🔎 Evaluating predictions...")

    pred_flat = np.asarray(pred_mask).ravel()
    true_flat = np.asarray(true_mask).ravel()

    # Validate before masking: indexing with a boolean mask of another length
    # would fail with an IndexError and this check would never be reached.
    if pred_flat.shape != true_flat.shape:
        raise ValueError(f"Prediction and ground truth shapes don't match: {pred_flat.shape} vs {true_flat.shape}")

    # Remove ignored pixels
    valid = true_flat != ignore_label
    pred_flat = pred_flat[valid]
    true_flat = true_flat[valid]

    if true_flat.size == 0:
        print("⚠️ No valid (non-ignored) pixels to evaluate.")
        return

    labels = np.arange(num_classes)

    print("📈 Classification Report:")
    print(classification_report(
        true_flat,
        pred_flat,
        labels=labels,
        zero_division=0,
        digits=3
    ))

    # Known names for the first six classes; any extra class gets a generic
    # name so that num_classes > 6 does not index past the end of the list.
    default_names = ['Building', 'Clutter', 'Vegetation', 'Water', 'Background', 'Car']
    class_names = [
        default_names[i] if i < len(default_names) else f"Class {i}"
        for i in range(num_classes)
    ]

    print("\n🌀 Confusion Matrix:")
    cm = confusion_matrix(true_flat, pred_flat, labels=labels)
    # A class without ground-truth pixels has a zero row sum; clamp the
    # divisor so that its row shows 0.00 instead of NaN.
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_normalised = cm.astype(np.float32) / np.maximum(row_totals, 1)

    fig = plt.figure(figsize=(8, 6))
    sns.heatmap(cm_normalised, annot=True, fmt='.2f', cmap='Blues',
                xticklabels=class_names,
                yticklabels=class_names)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Normalised Confusion Matrix")
    plt.yticks(rotation=0)
    plt.tight_layout()
    # savefig does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(os.path.join(out_dir, "conf_matrix.png"))
    plt.show()
    plt.close(fig)

    print("\n📊 IoU per Class:")
    intersection = np.diag(cm)
    ground_truth_set = cm.sum(axis=1)
    predicted_set = cm.sum(axis=0)
    union = ground_truth_set + predicted_set - intersection

    # Classes absent from both prediction and ground truth have union 0; the
    # clamp gives them an IoU of 0, and they still count towards the mean.
    iou_per_class = intersection / np.maximum(union, 1)

    for i, iou in enumerate(iou_per_class):
        print(f"  Class {i} ({class_names[i]}): {iou:.4f}")

    mean_iou = np.mean(iou_per_class)
    print(f"\n📈 Mean IoU: {mean_iou:.4f}")




import os
import gc
import numpy as np
import matplotlib.pyplot as plt
from tqdm import trange
from sklearn.metrics import classification_report, confusion_matrix
from PIL import Image
from datetime import datetime

def _save_reconstructed_canvases(test_df, tile_size, image_dir, label_dir, pred_dir):
    """Stitch the tiles of every base file together and save the three canvases."""
    target_dirs = ["/content/figs", "/content/drive/MyDrive/figs"]
    for target_dir in target_dirs:
        os.makedirs(target_dir, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # The base name is the first three underscore-separated fields of the
    # tile id; ids that do not match the pattern are skipped.
    base_files = (
        test_df['tile_id'].str.extract(r"(^[^_]+_[^_]+_[^_]+)")[0].dropna().unique()
    )

    for base in base_files:
        base_df = test_df[test_df['tile_id'].str.startswith(base)]
        img_canvas, label_canvas, pred_canvas = reconstruct_prediction_canvas(
            df=base_df,
            tile_size=tile_size,
            image_dir=image_dir,
            label_dir=label_dir,
            pred_dir=pred_dir
        )

        for name, canvas in zip(['img', 'label', 'pred'], [img_canvas, label_canvas, pred_canvas]):
            filename = f"{timestamp}_{name}_{base}.png"
            for target_dir in target_dirs:
                Image.fromarray(canvas).save(os.path.join(target_dir, filename))

        gc.collect()


def evaluate_on_test(model, test_gen, test_df, pred_out_dir, image_dir, label_dir,
                     tile_size=256, n_vis=10, reconstruct=False):
    """Predict every test batch, save colour-coded predictions and report metrics.

    Args:
        model: Object with a ``predict(batch, verbose=0)`` method returning
            per-class scores on the last axis.
        test_gen: Object exposing ``as_numpy_iterator()`` that yields
            ``(batch_x, batch_y)`` pairs (presumably a ``tf.data.Dataset`` —
            verify against the caller).
        test_df: DataFrame with a ``tile_id`` column. Tile ids are assigned
            to predictions purely by position, so the rows are assumed to be
            in the same order as the tiles produced by ``test_gen`` — TODO
            confirm.
        pred_out_dir: Directory that ``<tile_id>.png`` predictions are
            written to; created if missing.
        image_dir: Directory of input tiles, used only when ``reconstruct``.
        label_dir: Directory of label tiles, used only when ``reconstruct``.
        tile_size: Tile edge length, used only when ``reconstruct``.
        n_vis: Maximum number of tiles shown with ``visualise_prediction``.
        reconstruct: When true, additionally stitch the tiles of each base
            file together and save the input, label and prediction canvases.
    """
    print("🧪 Running evaluation on test set...")

    # Label that evaluate_predictions drops from the metrics by default.
    ignore_label = 6

    # One flat uint8 array per batch; concatenated once at the end. Extending
    # a Python list pixel by pixel would create one object per pixel.
    pred_chunks = []
    true_chunks = []
    n_shown = 0
    os.makedirs(pred_out_dir, exist_ok=True)

    test_tile_ids = test_df['tile_id'].tolist()
    tile_index = 0

    for batch_x, batch_y in test_gen.as_numpy_iterator():
        if batch_x.size == 0:
            continue

        pred = model.predict(batch_x, verbose=0)
        pred_mask = np.argmax(pred, axis=-1).astype(np.uint8)
        true_mask = np.argmax(batch_y, axis=-1).astype(np.uint8)
        # argmax maps an all-zero one-hot vector to class 0, so unlabeled
        # pixels would be scored as class 0. Mark them with the ignore label.
        true_mask[np.all(batch_y == 0, axis=-1)] = ignore_label

        for j in range(batch_x.shape[0]):
            if tile_index >= len(test_tile_ids):
                break  # safety stop: more tiles than ids in test_df
            tile_id = test_tile_ids[tile_index]
            tile_index += 1

            # 💾 Save prediction
            pred_rgb = np.zeros((*pred_mask[j].shape, 3), dtype=np.uint8)
            for class_id, color in CLASS_TO_COLOR.items():
                pred_rgb[pred_mask[j] == class_id] = color
            Image.fromarray(pred_rgb).save(os.path.join(pred_out_dir, tile_id + ".png"))

            # 👁️ Visualise
            if n_shown < n_vis:
                # assumes the first three channels are RGB scaled to [0, 1] — TODO confirm
                rgb_tile = (batch_x[j][:, :, :3] * 255).astype(np.uint8)
                visualise_prediction(rgb_tile, batch_y[j], pred_mask[j])
                n_shown += 1

        pred_chunks.append(pred_mask.reshape(-1))
        true_chunks.append(true_mask.reshape(-1))

        del batch_x, batch_y, pred, pred_mask, true_mask
        gc.collect()

    if not pred_chunks:
        print("⚠️ No test predictions were collected.")
        return

    print("\n📊 Test Set Evaluation Results:")
    evaluate_predictions(np.concatenate(pred_chunks), np.concatenate(true_chunks))

    if reconstruct:
        # 🧩 Reconstruct each base file and save outputs
        _save_reconstructed_canvases(test_df, tile_size, image_dir, label_dir, pred_out_dir)
        print("✅ Reconstruction and saving complete.")
    else:
        print("✅ Evaluation complete.")




def _paste_tile(canvas, path, x_offset, y_offset, tile_size, tile_id):
    """Paste the image at ``path`` onto ``canvas``; a missing file is skipped."""
    if not os.path.exists(path):
        return

    # Imported here so that OpenCV is only required once a tile is loaded.
    import cv2

    try:
        bgr = cv2.imread(path)
        if bgr is None:
            # imread signals an unreadable file by returning None.
            raise ValueError(f"could not read image {path}")
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        # Clip to the tile slot and to the canvas so that a tile of a
        # different size is pasted as far as it fits instead of being dropped.
        height = min(rgb.shape[0], tile_size, canvas.shape[0] - y_offset)
        width = min(rgb.shape[1], tile_size, canvas.shape[1] - x_offset)
        canvas[y_offset:y_offset + height, x_offset:x_offset + width] = rgb[:height, :width]
    except (cv2.error, ValueError) as e:
        print(f"⚠️ Failed to load tile {tile_id}: {e}")


def reconstruct_prediction_canvas(df, tile_size, image_dir, label_dir, pred_dir):
    """Stitch image, label and prediction tiles into three full canvases.

    Args:
        df: DataFrame with ``tile_id``, ``x`` and ``y`` columns; ``x``/``y``
            are used as the pixel position of each tile's top-left corner.
        tile_size: Edge length of one tile in pixels.
        image_dir: Directory holding ``<tile_id>-ortho.png`` files.
        label_dir: Directory holding ``<tile_id>-label.png`` files.
        pred_dir: Directory holding ``<tile_id>.png`` files.

    Returns:
        Tuple ``(img_canvas, label_canvas, pred_canvas)`` of uint8 RGB arrays
        covering the bounding box of all tiles. Areas without a tile stay
        magenta. A missing file is skipped; an unreadable one is reported and
        skipped without affecting the other images of the same tile.

    Raises:
        ValueError: If ``df`` has no rows.
    """
    if len(df) == 0:
        raise ValueError("Cannot reconstruct a canvas from an empty tile table.")

    tile_size = int(tile_size)

    # Determine canvas size from tile coordinates. Coordinates are cast to
    # int because float-typed columns cannot be used as shapes or slices.
    x_coords = df['x'].to_numpy()
    y_coords = df['y'].to_numpy()
    min_x = int(x_coords.min())
    min_y = int(y_coords.min())
    canvas_width = int(x_coords.max()) + tile_size - min_x
    canvas_height = int(y_coords.max()) + tile_size - min_y

    canvas_shape = (canvas_height, canvas_width, 3)
    img_canvas = np.full(canvas_shape, (255, 0, 255), dtype=np.uint8)  # Magenta default
    label_canvas = np.full(canvas_shape, (255, 0, 255), dtype=np.uint8)
    pred_canvas = np.full(canvas_shape, (255, 0, 255), dtype=np.uint8)

    for tile_id, x, y in zip(df['tile_id'], df['x'], df['y']):
        x_offset = int(x) - min_x
        y_offset = int(y) - min_y

        sources = (
            (img_canvas, os.path.join(image_dir, tile_id + '-ortho.png')),
            (label_canvas, os.path.join(label_dir, tile_id + '-label.png')),
            (pred_canvas, os.path.join(pred_dir, tile_id + '.png')),
        )
        for canvas, path in sources:
            _paste_tile(canvas, path, x_offset, y_offset, tile_size, tile_id)

    return img_canvas, label_canvas, pred_canvas


