In [None]:
import numpy as np
import matplotlib.pyplot as plt

# --- Sanity ---
def test_util_sanity():
    """Print a one-line confirmation that the util notebook's definitions loaded."""
    banner = "✅ from util.ipynb"
    print(banner)

# Map each RGB colour used in ground-truth masks to its integer class index.
COLOR_TO_CLASS = {
    (230, 25, 75): 0,       # BUILDING
    (145, 30, 180): 1,      # CLUTTER
    (60, 180, 75): 2,       # VEGETATION
    (245, 130, 48): 3,      # WATER
    (255, 255, 255): 4,     # GROUND
    (0, 130, 200): 5,       # CAR
}

# Inverse lookup: class index -> RGB colour.
CLASS_TO_COLOR = {v: k for k, v in COLOR_TO_CLASS.items()}
NUM_CLASSES = len(COLOR_TO_CLASS)

# COLOR_PALETTE is indexed by class id: indexing it with an integer mask yields
# an RGB array with the mask's shape plus a trailing channel axis. Build it in
# class-index order instead of relying on the dict's insertion order, and fail
# fast if the indices are duplicated or not exactly 0..NUM_CLASSES-1.
if sorted(CLASS_TO_COLOR) != list(range(NUM_CLASSES)):
    raise ValueError(
        "COLOR_TO_CLASS must map distinct colours to the indices 0..NUM_CLASSES-1"
    )
COLOR_PALETTE = np.array(
    [CLASS_TO_COLOR[i] for i in range(NUM_CLASSES)], dtype=np.uint8
)

# --- Visualisation ---
def visualise_prediction(rgb, true_mask, pred_mask):
    """Display an input image beside its ground-truth and predicted masks.

    The two masks are coloured by indexing ``COLOR_PALETTE`` with them, so
    they are expected to hold integer class indices. The three panels are
    drawn side by side and shown with ``plt.show()``.
    """
    _, axes = plt.subplots(1, 3, figsize=(16, 5))
    # (data, panel title, whether data is a class-index mask needing colouring)
    panels = [
        (rgb, "RGB Image", False),
        (true_mask, "True Mask", True),
        (pred_mask, "Predicted Mask", True),
    ]
    for ax, (data, caption, is_mask) in zip(axes, panels):
        shown = COLOR_PALETTE[data] if is_mask else data
        ax.imshow(shown)
        ax.set_title(caption)
        ax.axis("off")
    plt.tight_layout()
    plt.show()

# --- Class Weights ---
def compute_class_weights(generator, num_classes=None, as_tensor=True):
    """Compute inverse-frequency class weights from one-hot label batches.

    Iterates ``generator`` once. Each item must be an ``(inputs, labels)`` pair
    whose ``labels`` have the class axis last; each pixel's class is taken as
    the ``argmax`` over that axis.

    Every class that occurs gets a weight proportional to ``1 / pixel_count``,
    and the weights of the occurring classes are scaled to average 1, so rarer
    classes weigh more. A class that never occurs gets weight 0. Treating its
    count as 1 instead would give it a weight that dwarfs every real class and,
    after normalisation, pushes their weights toward zero.

    Args:
        generator: Finite iterable of ``(inputs, labels)`` pairs.
        num_classes: Number of classes; defaults to the module-level
            ``NUM_CLASSES``.
        as_tensor: If True (default), return a ``tf.float32`` constant;
            otherwise return a ``float32`` NumPy array.

    Returns:
        Weights of shape ``(num_classes,)``.

    Raises:
        ValueError: If the generator yields no labelled pixels, or a label's
            argmax index is ``>= num_classes``.
    """
    if num_classes is None:
        num_classes = NUM_CLASSES

    pixel_counts = np.zeros(num_classes, dtype=np.int64)
    for _, labels in generator:
        flat = np.argmax(labels, axis=-1).ravel()
        counts = np.bincount(flat, minlength=num_classes)
        # bincount grows past minlength when an index >= num_classes occurs;
        # report that clearly instead of failing on a shape mismatch.
        if len(counts) > num_classes:
            raise ValueError(
                f"labels contain class index {len(counts) - 1}, "
                f"but num_classes is {num_classes}"
            )
        pixel_counts += counts

    total = pixel_counts.sum()
    if total == 0:
        # Without this guard the normalisation below would yield NaN weights.
        raise ValueError("generator yielded no labelled pixels")

    present = pixel_counts > 0
    weights = np.zeros(num_classes, dtype=np.float64)
    weights[present] = total / pixel_counts[present]
    # Scale so the occurring classes average 1. When every class occurs this
    # equals normalising all weights to sum to num_classes.
    weights[present] *= present.sum() / weights[present].sum()

    if not as_tensor:
        return weights.astype(np.float32)

    # Imported here so the NumPy path does not require TensorFlow.
    import tensorflow as tf
    return tf.constant(weights, dtype=tf.float32)

# --- Distribution Plot ---
def plot_class_distribution(generator, title="Class Distribution"):
    """Plot a bar chart of how many label pixels belong to each class.

    Iterates ``generator`` once. Each item must be an ``(inputs, labels)`` pair
    whose ``labels`` have the class axis last; each pixel's class is taken as
    the ``argmax`` over that axis. Each bar is filled with its class's mask
    colour from ``CLASS_TO_COLOR`` and outlined in black.

    Args:
        generator: Finite iterable of ``(inputs, labels)`` pairs.
        title: Title of the plot.

    Raises:
        ValueError: If a label's argmax index is ``>= NUM_CLASSES``.
    """
    pixel_counts = np.zeros(NUM_CLASSES, dtype=np.int64)
    for _, labels in generator:
        flat = np.argmax(labels, axis=-1).ravel()
        counts = np.bincount(flat, minlength=NUM_CLASSES)
        # bincount grows past minlength when an index >= NUM_CLASSES occurs;
        # report that clearly instead of failing on a shape mismatch.
        if len(counts) > NUM_CLASSES:
            raise ValueError(
                f"labels contain class index {len(counts) - 1}, "
                f"but NUM_CLASSES is {NUM_CLASSES}"
            )
        pixel_counts += counts

    # x-axis labels; class 4 is GROUND (white) in the colour mapping.
    class_names = [
        "0: Building",
        "1: Clutter",
        "2: Vegetation",
        "3: Water",
        "4: Ground",
        "5: Car",
    ]

    # Matplotlib expects RGB components in [0, 1].
    colours = [np.array(CLASS_TO_COLOR[i]) / 255.0 for i in range(NUM_CLASSES)]

    plt.figure(figsize=(10, 5))
    plt.bar(class_names, pixel_counts, color=colours, edgecolor="black")
    plt.title(title)
    plt.xlabel("Class")
    plt.ylabel("Pixel Count")
    plt.grid(True, axis="y", linestyle="--", alpha=0.5)
    plt.tight_layout()
    plt.show()

