In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import pandas as pd
import seaborn as sns

# --- Color and Class Constants ---
# Display name of every class; the position in the list is the class index.
CLASS_NAMES = ['Building', 'Clutter', 'Vegetation', 'Water', 'Background', 'Car']
class_names = CLASS_NAMES  # For compatibility with existing code

# RGB triplet associated with each class index.
CLASS_TO_COLOR = {
    0: (230, 25, 75),      # Building
    1: (145, 30, 180),     # Clutter
    2: (60, 180, 75),      # Vegetation
    3: (245, 130, 48),     # Water
    4: (255, 255, 255),    # Background (a.k.a. ground)
    5: (0, 130, 200),      # Car
}

# Reverse mapping (RGB triplet -> class index), kept in class-index order.
COLOR_TO_CLASS = {rgb: class_id for class_id, rgb in CLASS_TO_COLOR.items()}
NUM_CLASSES = len(CLASS_TO_COLOR)

# (NUM_CLASSES, 3) uint8 array whose row i is the colour of class i.
COLOR_PALETTE = np.array(list(COLOR_TO_CLASS), dtype=np.uint8)

# Independent copy of COLOR_TO_CLASS (the keys are already plain tuples).
COLOR_LOOKUP = dict(COLOR_TO_CLASS)


# --- Constants ---
# Supported model input variants, keyed by input-type name. Each entry holds a
# short human-readable description and the number of input channels.
# NOTE(review): "rgb_elev" declares 5 channels, whereas RGB plus a single
# elevation band would be 4 — confirm against the data loader.
INPUT_TYPE_CONFIG = dict(
    rgb=dict(description="RGB only", channels=3),
    rgb_elev=dict(description="RGB + elevation", channels=5),
)


# Input data location and the folder that receives saved figures.
base_dir = "/content/chipped_data/content/chipped_data"
out_dir = "/content/figs"

# Ensure the figure folder is present before any plot is saved into it.
os.makedirs(out_dir, exist_ok=True)





def filter_tile_ids_by_substring(image_dir: str, base_names: list[str]) -> list[str]:
    """Return the tile IDs of '-ortho.png' files whose name contains a given substring.

    Only files whose name ends in '-ortho.png' are considered. A file is kept
    when at least one entry of ``base_names`` occurs anywhere in its filename;
    the trailing '-ortho.png' is then stripped to obtain the tile ID.

    Files with any other suffix are skipped, so a directory that also holds
    non-ortho files sharing the same prefix does not produce bogus or duplicate
    IDs. The result is sorted so it does not depend on the arbitrary order in
    which ``os.listdir`` reports directory entries.

    Args:
        image_dir (str): Path to the directory containing image files.
        base_names (list[str]): Substrings to look for in the filenames. An
            empty list matches nothing.

    Returns:
        list[str]: Sorted tile IDs (filenames without '-ortho.png') that contain
        any of the specified base substrings.

    Raises:
        OSError: If ``image_dir`` cannot be listed (propagated from ``os.listdir``).
    """
    suffix = '-ortho.png'
    return sorted(
        f[:-len(suffix)]  # strip the suffix at the end only, never mid-name
        for f in os.listdir(image_dir)
        if f.endswith(suffix) and any(base in f for base in base_names)
    )




# Every available source-file prefix (a tile_id with its coordinate suffix
# removed), one per line in ascending string order. The dataset split below is
# expressed in terms of these prefixes.
all_files = [
    "107f24d6e9_F1BE1D4184INSPIRE",
    "11cdce7802_B6A62F8BE0INSPIRE",
    "12fa5e614f_53197F206FOPENPIPELINE",
    "130a76ebe1_68B40B480AOPENPIPELINE",
    "1476907971_CHADGRISMOPENPIPELINE",
    "1553541487_APIGENERATED",
    "1553541585_APIGENERATED",
    "1553627230_APIGENERATED",
    "15efe45820_D95DF0B1F4INSPIRE",
    "1726eb08ef_60693DB04DINSPIRE",
    "1d056881e8_29FEA32BC7INSPIRE",
    "1d4fbe33f3_F1BE1D4184INSPIRE",
    "1df70e7340_4413A67E91INSPIRE",
    "2552eb56dd_2AABB46C86OPENPIPELINE",
    "25f1c24f30_EB81FE6E2BOPENPIPELINE",
    "2ef3a4994a_0CCD105428INSPIRE",
    "2ef883f08d_F317F9C1DFOPENPIPELINE",
    "34fbf7c2bd_E8AD935CEDINSPIRE",
    "3502e187b2_23071E4605OPENPIPELINE",
    "39e77bedd0_729FB913CDOPENPIPELINE",
    "420d6b69b8_84B52814D2OPENPIPELINE",
    "520947aa07_8FCB044F58OPENPIPELINE",
    "551063e3c5_8FCB044F58INSPIRE",
    "57426ebe1e_84B52814D2OPENPIPELINE",
    "5fa39d6378_DB9FF730D9OPENPIPELINE",
    "6f93b9026b_F1BFB8B17DOPENPIPELINE",
    "7008b80b00_FF24A4975DINSPIRE",
    "74d7796531_EB81FE6E2BOPENPIPELINE",
    "7c719dfcc0_310490364FINSPIRE",
    "84410645db_8D20F02042OPENPIPELINE",
    "8710b98ea0_06E6522D6DINSPIRE",
    "888432f840_80E7FD39EBINSPIRE",
    "9170479165_625EDFBAB6OPENPIPELINE",
    "a1af86939f_F1BE1D4184OPENPIPELINE",
    "b61673f780_4413A67E91INSPIRE",
    "b705d0cc9c_E5F5E0E316OPENPIPELINE",
    "b771104de5_7E02A41EBEOPENPIPELINE",
    "c2e8370ca3_3340CAC7AEOPENPIPELINE",
    "c37dbfae2f_84B52814D2OPENPIPELINE",
    "c644f91210_27E21B7F30OPENPIPELINE",
    "c6d131e346_536DE05ED2OPENPIPELINE",
    "c8a7031e5f_32156F5DC2INSPIRE",
    "cc4b443c7d_A9CBEF2C97INSPIRE",
    "d06b2c67d2_2A62B67B52OPENPIPELINE",
    "d9161f7e18_C05BA1BC72OPENPIPELINE",
    "dabec5e872_E8AD935CEDINSPIRE",
    "e87da4ebdb_29FEA32BC7INSPIRE",
    "ebffe540d0_7BA042D858OPENPIPELINE",
    "ec09336a6f_06BA0AF311OPENPIPELINE",
    "f0747ed88d_E74C0DD8FDOPENPIPELINE",
    "f4dd768188_NOLANOPENPIPELINE",
    "f56b6b2232_2A62B67B52OPENPIPELINE",
    "f971256246_MIKEINSPIRE",
    "f9f43e5144_1DB9E6F68BINSPIRE",
    "fc5837dcf8_7CD52BE09EINSPIRE",
]

# Validation split: source-file prefixes whose tiles are held out for validation.
val_files = [
    "c644f91210_27E21B7F30OPENPIPELINE", "f9f43e5144_1DB9E6F68BINSPIRE",
    "1d056881e8_29FEA32BC7INSPIRE", "3502e187b2_23071E4605OPENPIPELINE",
    "d9161f7e18_C05BA1BC72OPENPIPELINE", "c8a7031e5f_32156F5DC2INSPIRE",
    "551063e3c5_8FCB044F58INSPIRE", "fc5837dcf8_7CD52BE09EINSPIRE",
    "39e77bedd0_729FB913CDOPENPIPELINE",
]

# Test split: source-file prefixes whose tiles are held out for final testing.
test_files = [
    "25f1c24f30_EB81FE6E2BOPENPIPELINE", "1d4fbe33f3_F1BE1D4184INSPIRE",
    "15efe45820_D95DF0B1F4INSPIRE", "c6d131e346_536DE05ED2OPENPIPELINE",
    "12fa5e614f_53197F206FOPENPIPELINE", "5fa39d6378_DB9FF730D9OPENPIPELINE",
    "ebffe540d0_7BA042D858OPENPIPELINE", "8710b98ea0_06E6522D6DINSPIRE",
    "84410645db_8D20F02042OPENPIPELINE", "a1af86939f_F1BE1D4184OPENPIPELINE",
]





# List of specific tile IDs to select for detailed visualization (e.g., problematic chips).
# These are the full tile_id strings: "<source-file prefix>_<x>_<y>".
# Each group below holds the tiles of one source file; all six prefixes used
# here also appear in `test_files`.
# Every ID is listed exactly once: a repeated entry
# ("1d4fbe33f3_F1BE1D4184INSPIRE_1280_2432" in Group 2) was removed so the same
# chip is not selected twice.
specific_tile_ids = [
    # Group 1
    "25f1c24f30_EB81FE6E2BOPENPIPELINE_3456_1280", "25f1c24f30_EB81FE6E2BOPENPIPELINE_3584_8320",
    "25f1c24f30_EB81FE6E2BOPENPIPELINE_896_2816", "25f1c24f30_EB81FE6E2BOPENPIPELINE_3840_4736",
    "25f1c24f30_EB81FE6E2BOPENPIPELINE_3968_384", "25f1c24f30_EB81FE6E2BOPENPIPELINE_4736_512",
    "25f1c24f30_EB81FE6E2BOPENPIPELINE_4736_768", "25f1c24f30_EB81FE6E2BOPENPIPELINE_1024_5888",
    "25f1c24f30_EB81FE6E2BOPENPIPELINE_896_5888", "25f1c24f30_EB81FE6E2BOPENPIPELINE_1024_6016",

    # Group 2
    "1d4fbe33f3_F1BE1D4184INSPIRE_2560_4864", "1d4fbe33f3_F1BE1D4184INSPIRE_896_3584",
    "1d4fbe33f3_F1BE1D4184INSPIRE_768_3584", "1d4fbe33f3_F1BE1D4184INSPIRE_896_3712",
    "1d4fbe33f3_F1BE1D4184INSPIRE_1280_2432", "1d4fbe33f3_F1BE1D4184INSPIRE_1536_4608",
    "1d4fbe33f3_F1BE1D4184INSPIRE_1152_2432", "1d4fbe33f3_F1BE1D4184INSPIRE_1664_4864",
    "1d4fbe33f3_F1BE1D4184INSPIRE_1664_4736", "1d4fbe33f3_F1BE1D4184INSPIRE_1408_1280",
    "1d4fbe33f3_F1BE1D4184INSPIRE_1152_4864",
    "1d4fbe33f3_F1BE1D4184INSPIRE_1408_1408", "1d4fbe33f3_F1BE1D4184INSPIRE_1536_4736",
    "1d4fbe33f3_F1BE1D4184INSPIRE_384_1152",

    # Group 3
    "15efe45820_D95DF0B1F4INSPIRE_4736_9472", "15efe45820_D95DF0B1F4INSPIRE_9600_6016",
    "15efe45820_D95DF0B1F4INSPIRE_5888_6272", "15efe45820_D95DF0B1F4INSPIRE_7168_7936",
    "15efe45820_D95DF0B1F4INSPIRE_6016_6272", "15efe45820_D95DF0B1F4INSPIRE_8704_1024",
    "15efe45820_D95DF0B1F4INSPIRE_7040_6912", "15efe45820_D95DF0B1F4INSPIRE_8064_3968",
    "15efe45820_D95DF0B1F4INSPIRE_2688_2048", "15efe45820_D95DF0B1F4INSPIRE_7680_1920",
    "15efe45820_D95DF0B1F4INSPIRE_6272_10624", "15efe45820_D95DF0B1F4INSPIRE_6784_6784",
    "15efe45820_D95DF0B1F4INSPIRE_6528_8576",

    # Group 4
    "c6d131e346_536DE05ED2OPENPIPELINE_128_896", "c6d131e346_536DE05ED2OPENPIPELINE_256_768",
    "c6d131e346_536DE05ED2OPENPIPELINE_256_896", "c6d131e346_536DE05ED2OPENPIPELINE_1792_512",
    "c6d131e346_536DE05ED2OPENPIPELINE_1792_640", "c6d131e346_536DE05ED2OPENPIPELINE_256_640",
    "c6d131e346_536DE05ED2OPENPIPELINE_128_640", "c6d131e346_536DE05ED2OPENPIPELINE_128_128",
    "c6d131e346_536DE05ED2OPENPIPELINE_256_128", "c6d131e346_536DE05ED2OPENPIPELINE_256_512",
    "c6d131e346_536DE05ED2OPENPIPELINE_2688_2176", "c6d131e346_536DE05ED2OPENPIPELINE_2560_2176",
    "c6d131e346_536DE05ED2OPENPIPELINE_2688_2048", "c6d131e346_536DE05ED2OPENPIPELINE_2560_2048",
    "c6d131e346_536DE05ED2OPENPIPELINE_2688_2304", "c6d131e346_536DE05ED2OPENPIPELINE_2560_2304",
    "c6d131e346_536DE05ED2OPENPIPELINE_2816_2176", "c6d131e346_536DE05ED2OPENPIPELINE_2816_2048",
    "c6d131e346_536DE05ED2OPENPIPELINE_2816_2304", "c6d131e346_536DE05ED2OPENPIPELINE_2304_2560",
    "c6d131e346_536DE05ED2OPENPIPELINE_2304_2688", "c6d131e346_536DE05ED2OPENPIPELINE_2432_2688",
    "c6d131e346_536DE05ED2OPENPIPELINE_2432_2560", "c6d131e346_536DE05ED2OPENPIPELINE_2176_2560",
    "c6d131e346_536DE05ED2OPENPIPELINE_2176_2688",

    # Group 5
    "12fa5e614f_53197F206FOPENPIPELINE_384_3072", "12fa5e614f_53197F206FOPENPIPELINE_512_3072",
    "12fa5e614f_53197F206FOPENPIPELINE_256_3200", "12fa5e614f_53197F206FOPENPIPELINE_1024_3712",
    "12fa5e614f_53197F206FOPENPIPELINE_384_3200", "12fa5e614f_53197F206FOPENPIPELINE_640_3072",
    "12fa5e614f_53197F206FOPENPIPELINE_256_3328", "12fa5e614f_53197F206FOPENPIPELINE_256_3072",
    "12fa5e614f_53197F206FOPENPIPELINE_3200_1152", "12fa5e614f_53197F206FOPENPIPELINE_1152_2688",
    "12fa5e614f_53197F206FOPENPIPELINE_1536_2432", "12fa5e614f_53197F206FOPENPIPELINE_1280_2560",
    "12fa5e614f_53197F206FOPENPIPELINE_1536_2048", "12fa5e614f_53197F206FOPENPIPELINE_512_3840",
    "12fa5e614f_53197F206FOPENPIPELINE_512_3712", "12fa5e614f_53197F206FOPENPIPELINE_1664_2304",
    "12fa5e614f_53197F206FOPENPIPELINE_384_3456", "12fa5e614f_53197F206FOPENPIPELINE_384_3328",
    "12fa5e614f_53197F206FOPENPIPELINE_1280_3584", "12fa5e614f_53197F206FOPENPIPELINE_384_3584",
    "12fa5e614f_53197F206FOPENPIPELINE_3072_1152", "12fa5e614f_53197F206FOPENPIPELINE_3456_1024",

    # Group 6
    "5fa39d6378_DB9FF730D9OPENPIPELINE_3072_2688", "5fa39d6378_DB9FF730D9OPENPIPELINE_1024_6784",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_3712_2816", "5fa39d6378_DB9FF730D9OPENPIPELINE_3200_2688",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_2944_2816", "5fa39d6378_DB9FF730D9OPENPIPELINE_4224_3072",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_3328_4992", "5fa39d6378_DB9FF730D9OPENPIPELINE_1024_6528",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_3840_5888", "5fa39d6378_DB9FF730D9OPENPIPELINE_2816_4224",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_5760_1920", "5fa39d6378_DB9FF730D9OPENPIPELINE_3328_2816",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_4352_4864", "5fa39d6378_DB9FF730D9OPENPIPELINE_3072_6912",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_4096_3328", "5fa39d6378_DB9FF730D9OPENPIPELINE_2816_3968",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_5888_1920", "5fa39d6378_DB9FF730D9OPENPIPELINE_1280_2432",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_3584_2560", "5fa39d6378_DB9FF730D9OPENPIPELINE_1280_5632",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_1280_5504", "5fa39d6378_DB9FF730D9OPENPIPELINE_1408_5504",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_1408_5632", "5fa39d6378_DB9FF730D9OPENPIPELINE_4608_4480",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_4608_4352", "5fa39d6378_DB9FF730D9OPENPIPELINE_4736_4480",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_2816_6400", "5fa39d6378_DB9FF730D9OPENPIPELINE_2944_6400",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_1280_5760", "5fa39d6378_DB9FF730D9OPENPIPELINE_2816_6528",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_2944_6528", "5fa39d6378_DB9FF730D9OPENPIPELINE_1408_5760",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_4608_4608", "5fa39d6378_DB9FF730D9OPENPIPELINE_4480_4352",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_4736_4608", "5fa39d6378_DB9FF730D9OPENPIPELINE_4480_4480",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_1280_5376", "5fa39d6378_DB9FF730D9OPENPIPELINE_1408_5376",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_4736_4352", "5fa39d6378_DB9FF730D9OPENPIPELINE_2816_6272",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_2944_6272", "5fa39d6378_DB9FF730D9OPENPIPELINE_1152_5760"
]





# --- Visualisation ---
# NOTE: `visualise_prediction` below is disabled. The whole definition sits inside
# a triple-quoted string literal, so Python evaluates it as a bare string
# expression and never defines the function; it is kept for reference only.
# The disabled code treats pixels whose one-hot label vector is all zeros as
# "ignore" regions and paints them magenta in both the true and predicted masks.

'''
def visualise_prediction(rgb, true_mask_onehot, pred_mask):
    """
    Visualise prediction with magenta overlay for ignore regions.
    """
    IGNORE_COLOR = (255, 0, 255)
    CLASS_TO_COLOR = {
        0: (230, 25, 75),    # Building
        1: (145, 30, 180),   # Clutter
        2: (60, 180, 75),    # Vegetation
        3: (245, 130, 48),   # Water
        4: (255, 255, 255),  # Background
        5: (0, 130, 200),    # Car
    }

    ignore_mask = np.all(true_mask_onehot == 0, axis=-1)  # 🟣 Detect ignored pixels

    true_mask = np.argmax(true_mask_onehot, axis=-1)      # 🔢 Decode only after ignore_mask is made

    h, w = true_mask.shape
    true_rgb = np.zeros((h, w, 3), dtype=np.uint8)
    pred_rgb = np.zeros((h, w, 3), dtype=np.uint8)

    for cid, col in CLASS_TO_COLOR.items():
        true_rgb[true_mask == cid] = col
        pred_rgb[pred_mask == cid] = col

    # 🟣 Overlay ignored regions
    true_rgb[ignore_mask] = IGNORE_COLOR
    pred_rgb[ignore_mask] = IGNORE_COLOR

    # --- Plot ---
    fig, axs = plt.subplots(1, 3, figsize=(16, 5))
    axs[0].imshow(rgb)
    axs[0].set_title("RGB Image")
    axs[0].axis("off")
    axs[1].imshow(true_rgb)
    axs[1].set_title("True Mask")
    axs[1].axis("off")
    axs[2].imshow(pred_rgb)
    axs[2].set_title("Predicted Mask")
    axs[2].axis("off")
    plt.tight_layout()
    plt.show()
'''





# --- Class Weights ---
def compute_class_weights(generator):
    """Compute normalised inverse-frequency class weights from one pass over a generator.

    Each batch's one-hot labels are decoded with ``argmax`` over the last axis
    and the resulting class indices are tallied. Pixels whose label vector is
    all zeros carry no class (the "ignore" encoding) and are left out of the
    tally; a plain ``argmax`` would otherwise count them as class 0. Empty
    batches are skipped.

    The weight of class ``c`` is ``total / (NUM_CLASSES * max(count_c, 1))``,
    after which the weights are rescaled so that they sum to ``NUM_CLASSES``.
    Flooring the count at 1 avoids a division by zero, which also means a class
    that never occurs receives the largest weight.

    Args:
        generator: Finite iterable yielding ``(inputs, labels)`` pairs, where
            ``labels`` is array-like with the one-hot class axis last.

    Returns:
        tf.Tensor: float32 tensor of shape ``(NUM_CLASSES,)`` with the weights.

    Raises:
        ValueError: If a batch contains a class index >= ``NUM_CLASSES``.
    """
    pixel_counts = np.zeros(NUM_CLASSES, dtype=np.int64)

    for _, labels in generator:
        labels = np.asarray(labels)
        if labels.size == 0:
            continue

        class_ids = np.argmax(labels, axis=-1)
        labelled = np.any(labels != 0, axis=-1)  # False where the one-hot vector is all zeros
        pixel_counts += np.bincount(class_ids[labelled], minlength=NUM_CLASSES)

    total = np.sum(pixel_counts)
    weights = total / (NUM_CLASSES * np.maximum(pixel_counts, 1))
    weights = weights / np.sum(weights) * NUM_CLASSES
    return tf.constant(weights, dtype=tf.float32)



# --- Distribution ---
def plot_class_distribution(class_counts, title="Class Distribution", filename="epoch_dist.png"):
    """Draw, save and show a bar chart of per-class pixel counts.

    Every bar is drawn in the colour of its class (``CLASS_TO_COLOR``) and
    annotated with the absolute count and its share of the total.

    Args:
        class_counts: Either a dict mapping class index -> pixel count (missing
            indices count as 0) or an iterable of counts ordered by class index.
        title (str): Title shown above the chart.
        filename (str): Name of the image file written inside ``out_dir``.
            Defaults to "epoch_dist.png", the name that used to be hard-coded;
            pass a different name to keep several plots side by side instead
            of overwriting the same file.

    Returns:
        None. If the counts sum to zero a message is printed and nothing is
        plotted or saved.
    """
    if isinstance(class_counts, dict):
        pixel_counts = [class_counts.get(i, 0) for i in range(NUM_CLASSES)]
    else:
        pixel_counts = list(class_counts)

    total_pixels = sum(pixel_counts)
    if total_pixels == 0:
        print("No pixels counted for class distribution plot.")
        return

    # Named `tick_labels` so the module-level `class_names` alias is not shadowed.
    tick_labels = [f"{i}: {name}" for i, name in enumerate(CLASS_NAMES)]
    colours = [np.array(CLASS_TO_COLOR[i]) / 255.0 for i in range(NUM_CLASSES)]

    plt.figure(figsize=(10, 5))
    bars = plt.bar(tick_labels, pixel_counts, color=colours, edgecolor='black')

    max_height = max(pixel_counts)
    plt.ylim(0, max_height * 1.15)  # Add 15% headroom above the tallest bar

    plt.title(title)
    plt.xlabel("Class")
    plt.ylabel("Pixel Count")
    plt.grid(True, axis='y', linestyle='--', alpha=0.5)

    for bar, count in zip(bars, pixel_counts):
        percent = 100.0 * count / total_pixels
        plt.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height() + max_height * 0.01,  # small gap above the bar
            f"{count:,}\n({percent:.2f}%)",
            ha='center',
            va='bottom',
            fontsize=9
        )

    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, filename))
    plt.show()


def print_class_distribution(generator, title="Class Distribution", max_batches=32):
    """Print the per-class pixel distribution of the first batches of a generator.

    One-hot labels are decoded with ``argmax`` over the last axis and tallied per
    class. Pixels whose label vector is all zeros carry no class (the "ignore"
    encoding) and are excluded; a plain ``argmax`` would otherwise report them
    as class 0. Empty batches are skipped and do not count towards
    ``max_batches``.

    Args:
        generator: Iterable yielding ``(inputs, labels)`` pairs, where ``labels``
            is array-like with the one-hot class axis last.
        title (str): Heading printed before the report.
        max_batches (int): Scanning stops once this many non-empty batches have
            been processed.

    Returns:
        None. The report is written to stdout.

    Raises:
        ValueError: If a batch contains a class index >= ``NUM_CLASSES``.
    """
    pixel_counts = np.zeros(NUM_CLASSES, dtype=np.int64)
    total_batches = 0

    print(f"\n📊 {title}")
    print(f"🔍 Starting distribution scan (max {max_batches} batches)...")

    for i, (_, labels) in enumerate(generator):
        labels = np.asarray(labels)
        if labels.size == 0:
            print(f"⚠️ Skipping empty batch at index {i}")
            continue

        class_ids = np.argmax(labels, axis=-1)
        labelled = np.any(labels != 0, axis=-1)  # False where the one-hot vector is all zeros
        pixel_counts += np.bincount(class_ids[labelled], minlength=NUM_CLASSES)
        total_batches += 1

        if total_batches >= max_batches:
            print(f"✅ Processed {total_batches} batches. Stopping early.")
            break

    total_pixels = np.sum(pixel_counts)
    if total_pixels == 0:
        print("❌ No valid pixels found.")
        return

    # Built from CLASS_NAMES (same labels as plot_class_distribution) instead of
    # a second hard-coded copy of the class list.
    row_labels = [f"{i}: {name}" for i, name in enumerate(CLASS_NAMES)]

    print(f"\n🧮 Total pixels processed: {total_pixels:,}")
    print("📈 Pixel Distribution (percentages):\n")
    for i in range(NUM_CLASSES):
        pct = (pixel_counts[i] / total_pixels) * 100
        print(f"Class {row_labels[i]:<14} : {pct:6.2f}% ({pixel_counts[i]:,} px)")

    print("\n✅ Done.\n")












# =================================================================== 
# ------------------------------------------------------------------- 
# =================================================================== 
# ------------------------------------------------------------------- 
# =================================================================== 
# -------------------------------------------------------------------
# =================================================================== 



# Per-class loss weights for the 6 classes, indexed by class id (0 to 5) and
# chosen to offset class imbalance. The six values add up to 1.0.
weights = np.array(
    [
        0.2374,  # 0: Building
        0.2374,  # 1: Clutter
        0.0356,  # 2: Vegetation
        0.2374,  # 3: Water
        0.0148,  # 4: Background
        0.2374,  # 5: Car
    ],
    dtype=np.float32,
)

# Dice Loss
def dice_loss(y_true, y_pred, smooth=1e-6):
    """Class-weighted soft Dice loss for one-hot targets.

    A Dice score is computed per class over all pixels of the batch, the scores
    are combined as a weighted average using the module-level ``weights``
    vector, and ``1 - average`` is returned.

    The average is ``sum(w * dice) / sum(w)``. Taking ``reduce_mean`` of
    ``w * dice`` instead divides by the number of classes a second time; with
    weights that sum to 1 that caps the weighted Dice at 1/6, so the loss could
    never drop below about 0.833 even for a perfect prediction.

    Args:
        y_true: One-hot ground truth with the class axis last.
        y_pred: Predicted class probabilities, same shape as ``y_true``.
        smooth (float): Small constant added to numerator and denominator to
            avoid division by zero.

    Returns:
        Scalar float32 tensor: 0 for a perfect prediction, approaching 1 as the
        overlap vanishes.
    """
    # Ensure inputs are float32
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)

    # One row per pixel, one column per class. The column count follows the
    # weight vector so the reshape and the weighting cannot drift apart.
    num_classes = int(weights.shape[0])
    y_true_f = tf.reshape(y_true, [-1, num_classes])
    y_pred_f = tf.reshape(y_pred, [-1, num_classes])

    # Per-class overlap and per-class total mass, summed over all pixels.
    intersection = tf.reduce_sum(y_true_f * y_pred_f, axis=0)
    denominator = tf.reduce_sum(y_true_f, axis=0) + tf.reduce_sum(y_pred_f, axis=0) + smooth

    # Dice score per class
    dice = (2.0 * intersection + smooth) / denominator

    # Weighted average of the per-class scores.
    class_weights = tf.cast(weights, tf.float32)
    weighted_dice = tf.reduce_sum(class_weights * dice) / tf.reduce_sum(class_weights)

    return 1.0 - weighted_dice

# Categorical Focal Loss
def categorical_focal_loss(gamma=2.0):
    """Build a categorical focal loss with focusing exponent ``gamma``.

    Args:
        gamma (float): Exponent of the ``(1 - p)`` modulating factor.

    Returns:
        A function ``focal_loss(y_true, y_pred)`` that returns a scalar tensor:
        the mean, over every element (pixels and classes), of
        ``(1 - p)^gamma * (-y_true * log(p))``, where ``p`` is ``y_pred``
        clipped to ``[1e-7, 1 - 1e-7]``.
    """
    def focal_loss(y_true, y_pred):
        targets = tf.cast(y_true, tf.float32)

        # Clipping keeps log() finite for probabilities of exactly 0 or 1.
        probs = tf.clip_by_value(tf.cast(y_pred, tf.float32), 1e-7, 1.0 - 1e-7)

        # Modulating factor (1 - p)^gamma and element-wise cross-entropy.
        modulating = tf.pow(1.0 - probs, gamma)
        cross_entropy = -targets * tf.math.log(probs)

        # Average across classes and pixels alike.
        return tf.reduce_mean(modulating * cross_entropy)

    return focal_loss

# Combined Loss
def combined_loss(y_true, y_pred):
    """Return the sum of ``dice_loss`` and the focal loss with ``gamma=2.0``.

    Args:
        y_true: Ground-truth tensor, passed unchanged to both losses.
        y_pred: Prediction tensor, passed unchanged to both losses.

    Returns:
        Scalar tensor: Dice term plus focal term.
    """
    dice_term = dice_loss(y_true, y_pred)
    focal_fn = categorical_focal_loss(gamma=2.0)
    focal_term = focal_fn(y_true, y_pred)
    return dice_term + focal_term

# Example: Compile model
# model = your_model  # e.g., U-Net with 6 classes
# model.compile(optimizer='adam', loss=combined_loss, metrics=['accuracy'])