# In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import pandas as pd
import seaborn as sns
# Select the tf.keras backend for the segmentation_models package. The variable
# presumably needs to be in the environment before that package is imported.
os.environ.update(SM_FRAMEWORK="tf.keras")


# --- Constants ---
TILE_SIZE = 512  # Side length, in pixels, of each square training chip (512x512)
IGNORE_COLOR = (255, 0, 255)  # RGB colour marking ignored regions (magenta)

# Mapping from RGB colour (as tuple) to integer class ID
COLOR_TO_CLASS = {
    (230, 25, 75): 0,    # Building
    (60, 180, 75): 1,    # Vegetation
    (0, 130, 200): 2,    # Water
    (255, 255, 255): 3,  # Background
    (245, 130, 48): 4,   # Car
    (128, 128, 128): 5,  # Road
}
# Human-readable class names, indexed by class ID
CLASS_NAMES = ['Building', 'Vegetation', 'Water', 'Background', 'Car', 'Road']
# Column names for class pixel counts in the metadata CSV ("<id>: <name>").
# Derived from CLASS_NAMES so the two can never drift apart.
class_cols = [f"{i}: {name}" for i, name in enumerate(CLASS_NAMES)]


NUM_CLASSES = len(COLOR_TO_CLASS)
# Inverse mapping from integer class ID to RGB colour (as tuple)
CLASS_TO_COLOR = {v: k for k, v in COLOR_TO_CLASS.items() if v < NUM_CLASSES}  # Ensure it matches NUM_CLASSES
# (NUM_CLASSES, 3) palette where row i is the colour of class i. Built by class ID,
# not by dict insertion order, so COLOR_PALETTE[class_map] is always correct; a
# non-contiguous set of class IDs fails loudly here with a KeyError.
COLOR_PALETTE = np.array([CLASS_TO_COLOR[i] for i in range(NUM_CLASSES)], dtype=np.uint8)
# Colour tuple -> class ID lookup used when decoding label images
COLOR_LOOKUP = {tuple(c): i for c, i in COLOR_TO_CLASS.items()}


# --- Model input configurations ---
# Maps an input-type key to a human-readable description and the number of
# image channels that input type carries.
INPUT_TYPE_CONFIG = dict(
    rgb=dict(description="RGB only", channels=3),
    rgb_elev=dict(description="RGB + elevation", channels=4),
    rgb_elev_slope=dict(description="RGB + elevation + slope", channels=5),
)


# Root of the chipped dataset and the folder where figures are written.
base_dir = "/content/chipped_data/chipped_data"
out_dir = "/content/figs"

# Per-modality directories of the training split.
img_dir, elev_dir, slope_dir, label_dir = (
    os.path.join(base_dir, "train", sub)
    for sub in ("images", "elevations", "slopes", "labels")
)

metadata_path = os.path.join(base_dir, "train_metadata.csv")

# Make sure the figure output folder exists (no error if it already does).
os.makedirs(out_dir, exist_ok=True)


# --- Jaccard Index (IoU) Metric ---
class MeanIoUMetric(tf.keras.metrics.MeanIoU):
    """Mean Intersection-over-Union (mIoU) metric for multi-class segmentation.

    This subclass overrides update_state so it can consume one-hot encoded
    ground truth and softmax predictions directly: both are reduced to integer
    class labels with `argmax` before being passed to `tf.keras.metrics.MeanIoU`.

    Pixels whose ground-truth vector is all zeros are treated as "ignore"
    pixels (the convention plot_augmented_grid_from_dataset uses in this file)
    and receive zero weight. Without this, `argmax` would map them to class 0
    and bias the IoU of that class.
    """

    def __init__(self, num_classes: int, name: str = "mean_iou", dtype=None):
        """Initialises the MeanIoUMetric.

        Args:
            num_classes (int): Total number of segmentation classes.
            name (str): Optional name for the metric.
            dtype (Optional): Optional data type.
        """
        super().__init__(num_classes=num_classes, name=name, dtype=dtype)

    def update_state(self, y_true: tf.Tensor, y_pred: tf.Tensor, sample_weight=None):
        """Updates the confusion matrix with batch predictions.

        Args:
            y_true (tf.Tensor): One-hot encoded ground truth labels (classes on
                the last axis). All-zero vectors mark ignored pixels.
            y_pred (tf.Tensor): Softmax predictions from the model.
            sample_weight (Optional): Optional per-pixel weights. They are
                multiplied with the valid-pixel mask.

        Returns:
            Whatever the parent `MeanIoU.update_state` returns (an update op in
            TF2 Keras).
        """
        # A pixel is labelled if any class has non-zero ground-truth mass.
        valid = tf.cast(tf.reduce_any(tf.not_equal(y_true, 0), axis=-1), tf.float32)
        if sample_weight is not None:
            # NOTE(review): assumes sample_weight broadcasts against the
            # per-pixel label map (i.e. y_true without its class axis).
            valid = valid * tf.cast(sample_weight, tf.float32)

        y_true = tf.argmax(y_true, axis=-1)
        y_pred = tf.argmax(y_pred, axis=-1)
        return super().update_state(y_true, y_pred, sample_weight=valid)


class TransformerLRSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Custom learning rate schedule based on the Transformer paper.

    lr(step) = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)

    The rate rises linearly for the first `warmup_steps` steps and then decays
    with the inverse square root of the step number. At step 0 the result is 0,
    because min(inf, 0) == 0.

    Attributes:
        d_model (tf.Tensor): The model dimensionality, as float32.
        warmup_steps (tf.Tensor): Number of linear warm-up steps, as float32.
    """

    def __init__(self, d_model: int, warmup_steps: int = 4000):
        """Initialises the TransformerLRSchedule.

        Args:
            d_model (int): The model dimensionality (e.g., 512).
            warmup_steps (int): Number of warm-up steps. Default is 4000.
        """
        super().__init__()
        # Keep the raw constructor arguments. get_config() can then return them
        # as-is, without calling .numpy(), which only works in eager mode.
        self._d_model_arg = d_model
        self._warmup_steps_arg = warmup_steps
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = tf.cast(warmup_steps, tf.float32)

    def __call__(self, step: tf.Tensor) -> tf.Tensor:
        """Computes the learning rate at a given training step.

        Args:
            step (tf.Tensor): The current training step.

        Returns:
            tf.Tensor: The calculated learning rate for this step.
        """
        step = tf.cast(step, tf.float32)

        # Inverse square root decay and linear warm-up terms
        arg1 = tf.math.rsqrt(step)
        arg2 = step * tf.pow(self.warmup_steps, -1.5)

        # The smaller term dominates: warm-up early on, decay afterwards
        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

    def get_config(self) -> dict:
        """Returns the constructor arguments, for serialization.

        Returns:
            dict: {"d_model": ..., "warmup_steps": ...} holding the values
            originally passed to __init__, so that
            ``TransformerLRSchedule(**config)`` rebuilds the schedule.
        """
        return {
            "d_model": self._d_model_arg,
            "warmup_steps": self._warmup_steps_arg,
        }


def decode_label_image(label_img: np.ndarray, colour_lookup=None) -> np.ndarray:
    """Converts a colour-coded label image into a class ID map.

    Vectorised: one boolean mask per known colour instead of a Python loop over
    every pixel.

    Args:
        label_img (np.ndarray): A (H, W, C) label image. Each pixel's channel
            values must exactly equal one colour key of ``colour_lookup``.
        colour_lookup (dict, optional): Mapping from colour tuple to class ID.
            Defaults to COLOR_LOOKUP.

    Returns:
        np.ndarray: A (H, W) uint8 array of class IDs corresponding to each pixel.

    Raises:
        ValueError: If an unknown colour is encountered in the label image. The
            message reports the first such pixel in row-major order.
    """
    if colour_lookup is None:
        colour_lookup = COLOR_LOOKUP

    h, w, channels = label_img.shape
    label_map = np.zeros((h, w), dtype=np.uint8)
    matched = np.zeros((h, w), dtype=bool)

    for colour, class_id in colour_lookup.items():
        # A pixel can only equal a key with the same number of channels
        # (e.g. RGBA pixels never match RGB keys).
        if len(colour) != channels:
            continue
        hit = np.all(label_img == np.asarray(colour), axis=-1)
        label_map[hit] = class_id
        matched |= hit

    if not matched.all():
        # argwhere lists coordinates in row-major order, so [0] is the first offender.
        y, x = (int(v) for v in np.argwhere(~matched)[0])
        pixel = tuple(label_img[y, x])
        raise ValueError(f"Unknown label colour {pixel} at ({y}, {x})")

    return label_map


def apply_label_smoothing(y_true: tf.Tensor, smoothing: float = 0.1) -> tf.Tensor:
    """Soften one-hot targets: ``y * (1 - smoothing) + smoothing / num_classes``.

    The number of classes is read from the size of the last axis at run time.
    An all-zero row becomes a uniform row of ``smoothing / num_classes``.

    Args:
        y_true (tf.Tensor): One-hot encoded labels of shape (..., num_classes).
        smoothing (float): Smoothing factor in the range [0, 1].

    Returns:
        tf.Tensor: Smoothed label tensor with the same shape as ``y_true``.
    """
    n_classes = tf.cast(tf.shape(y_true)[-1], tf.float32)
    keep = 1.0 - smoothing
    spread = smoothing / n_classes
    return y_true * keep + spread


def _label_ids_to_rgb(label_ids: np.ndarray, ignore: np.ndarray) -> np.ndarray:
    """Colour a (H, W) class-ID map with CLASS_TO_COLOR; pixels flagged in `ignore` get IGNORE_COLOR."""
    label_rgb = np.zeros((*label_ids.shape, 3), dtype=np.uint8)
    for class_id, color in CLASS_TO_COLOR.items():
        label_rgb[label_ids == class_id] = color
    label_rgb[ignore] = IGNORE_COLOR
    return label_rgb


def plot_augmented_grid_from_dataset(
    tf_dataset: tf.data.Dataset,
    input_type: str,  # a key of INPUT_TYPE_CONFIG: 'rgb', 'rgb_elev' or 'rgb_elev_slope'
    n_rows: int = 3,
    n_cols: int = 4,
    title: str = "Augmented Training Chips"
):
    """Plot RGB chips beside their colour-coded label masks from one dataset batch.

    Each grid cell is a pair of axes: the RGB image and its label mask coloured
    with CLASS_TO_COLOR. Pixels whose one-hot label vector is all zeros are drawn
    in IGNORE_COLOR. The layout matches visualise_prediction_grid, minus the
    predictions. Failures are printed rather than raised, so a plotting problem
    does not interrupt the caller.

    Args:
        tf_dataset: Dataset yielding ``(images, labels_one_hot)`` batches.
        input_type: Input-type key. Only the first three image channels (RGB)
            are displayed, whatever the type; extra channels such as elevation
            and slope are skipped.
        n_rows: Number of grid rows.
        n_cols: Number of (RGB, label) pairs per row.
        title: Figure title. NOTE: currently not drawn on the figure.
    """
    print(f"Fetching one batch for {n_rows * n_cols} RGB + label pairs...")

    fig = None
    try:
        batch = next(iter(tf_dataset.take(1)), None)

        if batch is None:
            print("Warning: Dataset is empty. Skipping plot.")
            return

        images, labels_one_hot = batch
        # RGB comes first for every input type. Slicing also covers
        # 'rgb_elev_slope', whose 5-channel images imshow cannot render.
        rgb_images = np.clip(images[..., :3].numpy(), 0.0, 1.0)

        label_ids = tf.argmax(labels_one_hot, axis=-1).numpy()
        # All-zero one-hot vectors mark ignored pixels.
        ignore_mask = tf.reduce_all(labels_one_hot == 0, axis=-1).numpy()

        n_slots = n_rows * n_cols
        n_shown = min(n_slots, rgb_images.shape[0])

        # squeeze=False keeps `axs` 2-D even when n_rows == 1.
        fig, axs = plt.subplots(
            n_rows, n_cols * 2, figsize=(n_cols * 5.6, n_rows * 2.6), squeeze=False
        )

        for i in range(n_shown):
            row, pair = divmod(i, n_cols)
            rgb_ax = axs[row, 2 * pair]
            label_ax = axs[row, 2 * pair + 1]

            rgb_ax.imshow(rgb_images[i])
            rgb_ax.set_title("RGB")
            rgb_ax.axis("off")

            label_ax.imshow(_label_ids_to_rgb(label_ids[i], ignore_mask[i]))
            label_ax.set_title("Label")
            label_ax.axis("off")

        # Hide the axes of grid cells left empty when the batch is smaller than the grid.
        for j in range(n_shown, n_slots):
            row, pair = divmod(j, n_cols)
            axs[row, 2 * pair].axis("off")
            axs[row, 2 * pair + 1].axis("off")

        plt.tight_layout()
        plt.show()

    except Exception as e:
        # Best-effort visualisation: report the problem instead of raising.
        print(f"Error during plotting: {e}")
    finally:
        # Always release the figure, including when plotting failed part-way.
        if fig is not None:
            plt.close(fig)





def filter_tile_ids_by_substring(image_dir: str, base_names: list[str]) -> list[str]:
    """Filters filenames in a directory to extract tile IDs that match given substrings.

    Only files ending in '-ortho.png' are considered. A file is kept if any of the
    given base names occurs in its filename, and its tile ID is the filename with
    the trailing '-ortho.png' removed.

    Args:
        image_dir (str): Path to the directory containing image files.
        base_names (list[str]): Substrings to match in the filenames. Any iterable
            is accepted; it is read once up front.

    Returns:
        list[str]: Sorted tile IDs (filenames without '-ortho.png') whose filename
        contains any of the specified base substrings.
    """
    suffix = '-ortho.png'
    # Materialise once: a one-shot iterable (e.g. a generator) would otherwise be
    # exhausted after the first file.
    bases = tuple(base_names)
    # sorted() makes the result independent of os.listdir's arbitrary order.
    return sorted(
        f.removesuffix(suffix)
        for f in os.listdir(image_dir)
        if f.endswith(suffix) and any(base in f for base in bases)
    )




# --- Global Configuration and Constants ---

# Explicit scene -> split assignment. csv_to_df (not shown here) is meant to use
# ONLY files whose scene appears in one of these lists.
#
# NOTE(review): some scenes are shared between splits, which leaks data:
#   84410645db_8D20F02042OPENPIPELINE and d9161f7e18_C05BA1BC72OPENPIPELINE also
#   appear in VAL_SCENES, ec09336a6f_06BA0AF311OPENPIPELINE is also listed for
#   VAL_SCENES, and 6f93b9026b_F1BFB8B17DOPENPIPELINE also appears in TEST_SCENES.
#   Confirm the intended split.
TRAIN_SCENES = [
    "1726eb08ef_60693DB04DINSPIRE",
    "1df70e7340_4413A67E91INSPIRE",
    "2552eb56dd_2AABB46C86OPENPIPELINE",
    "25f1c24f30_EB81FE6E2BOPENPIPELINE",
    "2ef883f08d_F317F9C1DFOPENPIPELINE",
    "420d6b69b8_84B52814D2OPENPIPELINE",
    "520947aa07_8FCB044F58OPENPIPELINE",
    "57426ebe1e_84B52814D2OPENPIPELINE",
    "5fa39d6378_DB9FF730D9OPENPIPELINE",
    "6f93b9026b_F1BFB8B17DOPENPIPELINE",
    "74d7796531_EB81FE6E2BOPENPIPELINE",
    "84410645db_8D20F02042OPENPIPELINE",
    "888432f840_80E7FD39EBINSPIRE",
    "9170479165_625EDFBAB6OPENPIPELINE",
    "a1af86939f_F1BE1D4184OPENPIPELINE",
    "b61673f780_4413A67E91INSPIRE",
    "b705d0cc9c_E5F5E0E316OPENPIPELINE",
    "b771104de5_7E02A41EBEOPENPIPELINE",
    "c37dbfae2f_84B52814D2OPENPIPELINE",
    "c6d131e346_536DE05ED2OPENPIPELINE",
    "c8a7031e5f_32156F5DC2INSPIRE",
    "cc4b443c7d_A9CBEF2C97INSPIRE",
    "d06b2c67d2_2A62B67B52OPENPIPELINE",
    "d9161f7e18_C05BA1BC72OPENPIPELINE",
    "e87da4ebdb_29FEA32BC7INSPIRE",
    "ebffe540d0_7BA042D858OPENPIPELINE",
    "ec09336a6f_06BA0AF311OPENPIPELINE",
    "f0747ed88d_E74C0DD8FDOPENPIPELINE",
    "f4dd768188_NOLANOPENPIPELINE",
    "f56b6b2232_2A62B67B52OPENPIPELINE",
    "f971256246_MIKEINSPIRE",
]

# Validation scenes. Every ID is a separate entry; a missing comma after
# "1d056881e8_..." used to fuse it with "ec09336a6f_..." into one bogus ID
# that matched no scene.
#
# NOTE(review): overlaps with other splits leak data:
#   with TRAIN_SCENES: 84410645db_8D20F02042OPENPIPELINE,
#     ec09336a6f_06BA0AF311OPENPIPELINE, d9161f7e18_C05BA1BC72OPENPIPELINE;
#   with TEST_SCENES: f9f43e5144_1DB9E6F68BINSPIRE, fc5837dcf8_7CD52BE09EINSPIRE,
#     1d4fbe33f3_F1BE1D4184INSPIRE.
#   Confirm the intended split.
VAL_SCENES = [
    "c644f91210_27E21B7F30OPENPIPELINE",
    "f9f43e5144_1DB9E6F68BINSPIRE",
    "8710b98ea0_06E6522D6DINSPIRE",
    "84410645db_8D20F02042OPENPIPELINE",
    "1553541487_APIGENERATED",
    "1553541585_APIGENERATED",
    "fc5837dcf8_7CD52BE09EINSPIRE",
    "1d056881e8_29FEA32BC7INSPIRE",
    "ec09336a6f_06BA0AF311OPENPIPELINE",
    "15efe45820_D95DF0B1F4INSPIRE",
    "107f24d6e9_F1BE1D4184INSPIRE",
    "1d4fbe33f3_F1BE1D4184INSPIRE",
    "2ef3a4994a_0CCD105428INSPIRE",
    "34fbf7c2bd_E8AD935CEDINSPIRE",
    # (a repeated "f9f43e5144_1DB9E6F68BINSPIRE" entry was dropped; it is listed above)
    "3502e187b2_23071E4605OPENPIPELINE",
    "d9161f7e18_C05BA1BC72OPENPIPELINE",
    "551063e3c5_8FCB044F58INSPIRE",
    "39e77bedd0_729FB913CDOPENPIPELINE",
    "c2e8370ca3_3340CAC7AEOPENPIPELINE",
]

# Test scenes.
# NOTE(review): 1d4fbe33f3_F1BE1D4184INSPIRE, fc5837dcf8_7CD52BE09EINSPIRE and
# f9f43e5144_1DB9E6F68BINSPIRE also appear in VAL_SCENES, and
# 6f93b9026b_F1BFB8B17DOPENPIPELINE also appears in TRAIN_SCENES. Confirm intent.
TEST_SCENES = [
    "1d4fbe33f3_F1BE1D4184INSPIRE",
    "1476907971_CHADGRISMOPENPIPELINE",
    "6f93b9026b_F1BFB8B17DOPENPIPELINE",
    "12fa5e614f_53197F206FOPENPIPELINE",
    "11cdce7802_B6A62F8BE0INSPIRE",
    "7c719dfcc0_310490364FINSPIRE",
    "7008b80b00_FF24A4975DINSPIRE",
    "dabec5e872_E8AD935CEDINSPIRE",
    "130a76ebe1_68B40B480AOPENPIPELINE",
    "fc5837dcf8_7CD52BE09EINSPIRE",
    "f9f43e5144_1DB9E6F68BINSPIRE",
    "1553627230_APIGENERATED",
]




# Every available scene prefix (a tile_id without its trailing "_x_y" coordinates),
# kept in sorted order. These identify the source files when splitting the dataset.
all_files = [
    "107f24d6e9_F1BE1D4184INSPIRE",
    "11cdce7802_B6A62F8BE0INSPIRE",
    "12fa5e614f_53197F206FOPENPIPELINE",
    "130a76ebe1_68B40B480AOPENPIPELINE",
    "1476907971_CHADGRISMOPENPIPELINE",
    "1553541487_APIGENERATED",
    "1553541585_APIGENERATED",
    "1553627230_APIGENERATED",
    "15efe45820_D95DF0B1F4INSPIRE",
    "1726eb08ef_60693DB04DINSPIRE",
    "1d056881e8_29FEA32BC7INSPIRE",
    "1d4fbe33f3_F1BE1D4184INSPIRE",
    "1df70e7340_4413A67E91INSPIRE",
    "2552eb56dd_2AABB46C86OPENPIPELINE",
    "25f1c24f30_EB81FE6E2BOPENPIPELINE",
    "2ef3a4994a_0CCD105428INSPIRE",
    "2ef883f08d_F317F9C1DFOPENPIPELINE",
    "34fbf7c2bd_E8AD935CEDINSPIRE",
    "3502e187b2_23071E4605OPENPIPELINE",
    "39e77bedd0_729FB913CDOPENPIPELINE",
    "420d6b69b8_84B52814D2OPENPIPELINE",
    "520947aa07_8FCB044F58OPENPIPELINE",
    "551063e3c5_8FCB044F58INSPIRE",
    "57426ebe1e_84B52814D2OPENPIPELINE",
    "5fa39d6378_DB9FF730D9OPENPIPELINE",
    "6f93b9026b_F1BFB8B17DOPENPIPELINE",
    "7008b80b00_FF24A4975DINSPIRE",
    "74d7796531_EB81FE6E2BOPENPIPELINE",
    "7c719dfcc0_310490364FINSPIRE",
    "84410645db_8D20F02042OPENPIPELINE",
    "8710b98ea0_06E6522D6DINSPIRE",
    "888432f840_80E7FD39EBINSPIRE",
    "9170479165_625EDFBAB6OPENPIPELINE",
    "a1af86939f_F1BE1D4184OPENPIPELINE",
    "b61673f780_4413A67E91INSPIRE",
    "b705d0cc9c_E5F5E0E316OPENPIPELINE",
    "b771104de5_7E02A41EBEOPENPIPELINE",
    "c2e8370ca3_3340CAC7AEOPENPIPELINE",
    "c37dbfae2f_84B52814D2OPENPIPELINE",
    "c644f91210_27E21B7F30OPENPIPELINE",
    "c6d131e346_536DE05ED2OPENPIPELINE",
    "c8a7031e5f_32156F5DC2INSPIRE",
    "cc4b443c7d_A9CBEF2C97INSPIRE",
    "d06b2c67d2_2A62B67B52OPENPIPELINE",
    "d9161f7e18_C05BA1BC72OPENPIPELINE",
    "dabec5e872_E8AD935CEDINSPIRE",
    "e87da4ebdb_29FEA32BC7INSPIRE",
    "ebffe540d0_7BA042D858OPENPIPELINE",
    "ec09336a6f_06BA0AF311OPENPIPELINE",
    "f0747ed88d_E74C0DD8FDOPENPIPELINE",
    "f4dd768188_NOLANOPENPIPELINE",
    "f56b6b2232_2A62B67B52OPENPIPELINE",
    "f971256246_MIKEINSPIRE",
    "f9f43e5144_1DB9E6F68BINSPIRE",
    "fc5837dcf8_7CD52BE09EINSPIRE",
]

# Scene prefixes assigned to the validation set (a split definition separate
# from VAL_SCENES). One prefix per line; split() turns the block into a list.
val_files = """
    c644f91210_27E21B7F30OPENPIPELINE
    f9f43e5144_1DB9E6F68BINSPIRE
    1d056881e8_29FEA32BC7INSPIRE
    3502e187b2_23071E4605OPENPIPELINE
    d9161f7e18_C05BA1BC72OPENPIPELINE
    c8a7031e5f_32156F5DC2INSPIRE
    551063e3c5_8FCB044F58INSPIRE
    fc5837dcf8_7CD52BE09EINSPIRE
    39e77bedd0_729FB913CDOPENPIPELINE
""".split()

# Scene prefixes assigned to the test set (a split definition separate from
# TEST_SCENES). One prefix per line; split() turns the block into a list.
test_files = """
    25f1c24f30_EB81FE6E2BOPENPIPELINE
    1d4fbe33f3_F1BE1D4184INSPIRE
    15efe45820_D95DF0B1F4INSPIRE
    c6d131e346_536DE05ED2OPENPIPELINE
    12fa5e614f_53197F206FOPENPIPELINE
    5fa39d6378_DB9FF730D9OPENPIPELINE
    ebffe540d0_7BA042D858OPENPIPELINE
    8710b98ea0_06E6522D6DINSPIRE
    84410645db_8D20F02042OPENPIPELINE
    a1af86939f_F1BE1D4184OPENPIPELINE
""".split()


# List of specific tile IDs to select for detailed visualization (e.g., problematic chips)
# These are the full tile_id strings including x_y coordinates. Each entry appears
# once, so no chip is rendered twice.
specific_tile_ids = [
    # Group 1
    "25f1c24f30_EB81FE6E2BOPENPIPELINE_3456_1280", "25f1c24f30_EB81FE6E2BOPENPIPELINE_3584_8320",
    "25f1c24f30_EB81FE6E2BOPENPIPELINE_896_2816", "25f1c24f30_EB81FE6E2BOPENPIPELINE_3840_4736",
    "25f1c24f30_EB81FE6E2BOPENPIPELINE_3968_384", "25f1c24f30_EB81FE6E2BOPENPIPELINE_4736_512",
    "25f1c24f30_EB81FE6E2BOPENPIPELINE_4736_768", "25f1c24f30_EB81FE6E2BOPENPIPELINE_1024_5888",
    "25f1c24f30_EB81FE6E2BOPENPIPELINE_896_5888", "25f1c24f30_EB81FE6E2BOPENPIPELINE_1024_6016",

    # Group 2
    "1d4fbe33f3_F1BE1D4184INSPIRE_2560_4864", "1d4fbe33f3_F1BE1D4184INSPIRE_896_3584",
    "1d4fbe33f3_F1BE1D4184INSPIRE_768_3584", "1d4fbe33f3_F1BE1D4184INSPIRE_896_3712",
    "1d4fbe33f3_F1BE1D4184INSPIRE_1280_2432", "1d4fbe33f3_F1BE1D4184INSPIRE_1536_4608",
    "1d4fbe33f3_F1BE1D4184INSPIRE_1152_2432", "1d4fbe33f3_F1BE1D4184INSPIRE_1664_4864",
    "1d4fbe33f3_F1BE1D4184INSPIRE_1664_4736", "1d4fbe33f3_F1BE1D4184INSPIRE_1408_1280",
    "1d4fbe33f3_F1BE1D4184INSPIRE_1152_4864",  # a duplicate "..._1280_2432" was dropped here
    "1d4fbe33f3_F1BE1D4184INSPIRE_1408_1408", "1d4fbe33f3_F1BE1D4184INSPIRE_1536_4736",
    "1d4fbe33f3_F1BE1D4184INSPIRE_384_1152",

    # Group 3
    "15efe45820_D95DF0B1F4INSPIRE_4736_9472", "15efe45820_D95DF0B1F4INSPIRE_9600_6016",
    "15efe45820_D95DF0B1F4INSPIRE_5888_6272", "15efe45820_D95DF0B1F4INSPIRE_7168_7936",
    "15efe45820_D95DF0B1F4INSPIRE_6016_6272", "15efe45820_D95DF0B1F4INSPIRE_8704_1024",
    "15efe45820_D95DF0B1F4INSPIRE_7040_6912", "15efe45820_D95DF0B1F4INSPIRE_8064_3968",
    "15efe45820_D95DF0B1F4INSPIRE_2688_2048", "15efe45820_D95DF0B1F4INSPIRE_7680_1920",
    "15efe45820_D95DF0B1F4INSPIRE_6272_10624", "15efe45820_D95DF0B1F4INSPIRE_6784_6784",
    "15efe45820_D95DF0B1F4INSPIRE_6528_8576",

    # Group 4
    "c6d131e346_536DE05ED2OPENPIPELINE_128_896", "c6d131e346_536DE05ED2OPENPIPELINE_256_768",
    "c6d131e346_536DE05ED2OPENPIPELINE_256_896", "c6d131e346_536DE05ED2OPENPIPELINE_1792_512",
    "c6d131e346_536DE05ED2OPENPIPELINE_1792_640", "c6d131e346_536DE05ED2OPENPIPELINE_256_640",
    "c6d131e346_536DE05ED2OPENPIPELINE_128_640", "c6d131e346_536DE05ED2OPENPIPELINE_128_128",
    "c6d131e346_536DE05ED2OPENPIPELINE_256_128", "c6d131e346_536DE05ED2OPENPIPELINE_256_512",
    "c6d131e346_536DE05ED2OPENPIPELINE_2688_2176", "c6d131e346_536DE05ED2OPENPIPELINE_2560_2176",
    "c6d131e346_536DE05ED2OPENPIPELINE_2688_2048", "c6d131e346_536DE05ED2OPENPIPELINE_2560_2048",
    "c6d131e346_536DE05ED2OPENPIPELINE_2688_2304", "c6d131e346_536DE05ED2OPENPIPELINE_2560_2304",
    "c6d131e346_536DE05ED2OPENPIPELINE_2816_2176", "c6d131e346_536DE05ED2OPENPIPELINE_2816_2048",
    "c6d131e346_536DE05ED2OPENPIPELINE_2816_2304", "c6d131e346_536DE05ED2OPENPIPELINE_2304_2560",
    "c6d131e346_536DE05ED2OPENPIPELINE_2304_2688", "c6d131e346_536DE05ED2OPENPIPELINE_2432_2688",
    "c6d131e346_536DE05ED2OPENPIPELINE_2432_2560", "c6d131e346_536DE05ED2OPENPIPELINE_2176_2560",
    "c6d131e346_536DE05ED2OPENPIPELINE_2176_2688",

    # Group 5
    "12fa5e614f_53197F206FOPENPIPELINE_384_3072", "12fa5e614f_53197F206FOPENPIPELINE_512_3072",
    "12fa5e614f_53197F206FOPENPIPELINE_256_3200", "12fa5e614f_53197F206FOPENPIPELINE_1024_3712",
    "12fa5e614f_53197F206FOPENPIPELINE_384_3200", "12fa5e614f_53197F206FOPENPIPELINE_640_3072",
    "12fa5e614f_53197F206FOPENPIPELINE_256_3328", "12fa5e614f_53197F206FOPENPIPELINE_256_3072",
    "12fa5e614f_53197F206FOPENPIPELINE_3200_1152", "12fa5e614f_53197F206FOPENPIPELINE_1152_2688",
    "12fa5e614f_53197F206FOPENPIPELINE_1536_2432", "12fa5e614f_53197F206FOPENPIPELINE_1280_2560",
    "12fa5e614f_53197F206FOPENPIPELINE_1536_2048", "12fa5e614f_53197F206FOPENPIPELINE_512_3840",
    "12fa5e614f_53197F206FOPENPIPELINE_512_3712", "12fa5e614f_53197F206FOPENPIPELINE_1664_2304",
    "12fa5e614f_53197F206FOPENPIPELINE_384_3456", "12fa5e614f_53197F206FOPENPIPELINE_384_3328",
    "12fa5e614f_53197F206FOPENPIPELINE_1280_3584", "12fa5e614f_53197F206FOPENPIPELINE_384_3584",
    "12fa5e614f_53197F206FOPENPIPELINE_3072_1152", "12fa5e614f_53197F206FOPENPIPELINE_3456_1024",

    # Group 6
    "5fa39d6378_DB9FF730D9OPENPIPELINE_3072_2688", "5fa39d6378_DB9FF730D9OPENPIPELINE_1024_6784",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_3712_2816", "5fa39d6378_DB9FF730D9OPENPIPELINE_3200_2688",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_2944_2816", "5fa39d6378_DB9FF730D9OPENPIPELINE_4224_3072",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_3328_4992", "5fa39d6378_DB9FF730D9OPENPIPELINE_1024_6528",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_3840_5888", "5fa39d6378_DB9FF730D9OPENPIPELINE_2816_4224",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_5760_1920", "5fa39d6378_DB9FF730D9OPENPIPELINE_3328_2816",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_4352_4864", "5fa39d6378_DB9FF730D9OPENPIPELINE_3072_6912",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_4096_3328", "5fa39d6378_DB9FF730D9OPENPIPELINE_2816_3968",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_5888_1920", "5fa39d6378_DB9FF730D9OPENPIPELINE_1280_2432",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_3584_2560", "5fa39d6378_DB9FF730D9OPENPIPELINE_1280_5632",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_1280_5504", "5fa39d6378_DB9FF730D9OPENPIPELINE_1408_5504",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_1408_5632", "5fa39d6378_DB9FF730D9OPENPIPELINE_4608_4480",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_4608_4352", "5fa39d6378_DB9FF730D9OPENPIPELINE_4736_4480",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_2816_6400", "5fa39d6378_DB9FF730D9OPENPIPELINE_2944_6400",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_1280_5760", "5fa39d6378_DB9FF730D9OPENPIPELINE_2816_6528",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_2944_6528", "5fa39d6378_DB9FF730D9OPENPIPELINE_1408_5760",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_4608_4608", "5fa39d6378_DB9FF730D9OPENPIPELINE_4480_4352",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_4736_4608", "5fa39d6378_DB9FF730D9OPENPIPELINE_4480_4480",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_1280_5376", "5fa39d6378_DB9FF730D9OPENPIPELINE_1408_5376",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_4736_4352", "5fa39d6378_DB9FF730D9OPENPIPELINE_2816_6272",
    "5fa39d6378_DB9FF730D9OPENPIPELINE_2944_6272", "5fa39d6378_DB9FF730D9OPENPIPELINE_1152_5760"
]





# --- Class Weights ---
def compute_class_weights(generator):
    """Compute inverse-frequency class weights from every label batch of `generator`.

    The whole iterable is consumed, so `generator` must be finite.

    Args:
        generator: Iterable of ``(inputs, labels)`` pairs. ``labels`` holds
            per-pixel class vectors on its last axis (e.g. one-hot) and is
            converted with ``np.asarray``. All-zero vectors are treated as
            ignored pixels and are not counted.

    Returns:
        tf.Tensor: float32 vector of NUM_CLASSES weights, proportional to
        ``1 / pixel_count`` (a class with no pixels is treated as having 1) and
        rescaled to sum to NUM_CLASSES. If no labelled pixel is found, every
        weight is 1.
    """
    pixel_counts = np.zeros(NUM_CLASSES, dtype=np.int64)

    for _, labels in generator:
        labels = np.asarray(labels)
        # argmax maps an all-zero (ignored) vector to class 0. Excluding those
        # pixels stops them inflating class 0's count and deflating its weight.
        valid = np.any(labels != 0, axis=-1)
        flat = np.argmax(labels, axis=-1)[valid]
        counts = np.bincount(flat, minlength=NUM_CLASSES)
        pixel_counts[:len(counts)] += counts

    total = np.sum(pixel_counts)
    if total == 0:
        # Nothing to weight by; avoid 0/0 producing NaN weights.
        return tf.constant(np.ones(NUM_CLASSES), dtype=tf.float32)

    weights = total / (NUM_CLASSES * np.maximum(pixel_counts, 1))
    weights = weights / np.sum(weights) * NUM_CLASSES
    return tf.constant(weights, dtype=tf.float32)



# --- Distribution ---
def plot_class_distribution(class_counts, title="Class Distribution", save_name="epoch_dist.png"):
    """Bar-plot per-class pixel counts, annotate each bar, save the figure and show it.

    Args:
        class_counts: Either a dict mapping class ID -> pixel count (missing IDs
            count as 0) or a sequence of counts ordered by class ID.
        title: Plot title.
        save_name: File name of the PNG written into ``out_dir``. The default
            keeps the existing "epoch_dist.png"; pass a different name to avoid
            overwriting an earlier plot.
    """
    if isinstance(class_counts, dict):
        pixel_counts = [class_counts.get(i, 0) for i in range(NUM_CLASSES)]
    else:
        pixel_counts = list(class_counts)

    total_pixels = sum(pixel_counts)
    if total_pixels == 0:
        print("No pixels counted for class distribution plot.")
        return

    class_names = [f"{i}: {name}" for i, name in enumerate(CLASS_NAMES)]
    colours = [np.array(CLASS_TO_COLOR[i]) / 255.0 for i in range(NUM_CLASSES)]

    fig = plt.figure(figsize=(10, 5))
    bars = plt.bar(class_names, pixel_counts, color=colours, edgecolor='black')

    max_height = max(pixel_counts)
    plt.ylim(0, max_height * 1.15)  # 15% headroom above the tallest bar for the labels

    plt.title(title)
    plt.xlabel("Class")
    plt.ylabel("Pixel Count")
    plt.grid(True, axis='y', linestyle='--', alpha=0.5)

    # Label every bar with its absolute count and its share of all pixels.
    for bar, count in zip(bars, pixel_counts):
        percent = 100.0 * count / total_pixels
        plt.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height() + max_height * 0.01,
            f"{count:,}\n({percent:.2f}%)",
            ha='center',
            va='bottom',
            fontsize=9
        )

    plt.tight_layout()
    plt.savefig(os.path.join(out_dir, save_name))
    plt.show()
    # Release the figure so repeated calls (e.g. once per epoch) don't pile up open figures.
    plt.close(fig)


def print_class_distribution(generator, title="Class Distribution", max_batches=32):
    """Print the per-class pixel distribution over the first batches of `generator`.

    Args:
        generator: Iterable of ``(inputs, labels)`` pairs. ``labels`` holds
            per-pixel class vectors on its last axis (e.g. one-hot) and is
            converted with ``np.asarray``. All-zero vectors are treated as
            ignored pixels: they are counted separately and excluded from the
            percentages.
        title: Heading printed before the scan.
        max_batches: Stop after this many non-empty batches.
    """
    pixel_counts = np.zeros(NUM_CLASSES, dtype=np.int64)
    ignored_pixels = 0
    total_batches = 0

    print(f"\n{title}")
    print(f"Starting distribution scan (max {max_batches} batches)...")

    for i, (_, labels) in enumerate(generator):
        # np.asarray gives tensors and lists the `.size` / comparison API used below.
        labels = np.asarray(labels)
        if labels.size == 0:
            print(f"Skipping empty batch at index {i}")
            continue

        # argmax would count all-zero (ignored) vectors as class 0, so mask them out.
        valid = np.any(labels != 0, axis=-1)
        ignored_pixels += int(valid.size - np.count_nonzero(valid))
        flat = np.argmax(labels, axis=-1)[valid]
        counts = np.bincount(flat, minlength=NUM_CLASSES)
        pixel_counts[:len(counts)] += counts
        total_batches += 1

        if total_batches >= max_batches:
            print(f"Processed {total_batches} batches. Stopping early.")
            break

    total_pixels = np.sum(pixel_counts)
    if total_pixels == 0:
        print("No valid pixels found.")
        return

    # Derived from CLASS_NAMES so each printed label matches its class ID.
    class_names = [f"{i}: {name}" for i, name in enumerate(CLASS_NAMES)]

    print(f"\nTotal pixels processed: {total_pixels:,}")
    if ignored_pixels:
        print(f"Ignored (all-zero label) pixels excluded: {ignored_pixels:,}")
    print("Pixel Distribution (percentages):\n")
    for i in range(NUM_CLASSES):
        pct = (pixel_counts[i] / total_pixels) * 100
        print(f"Class {class_names[i]:<14} : {pct:6.2f}% ({pixel_counts[i]:,} px)")

    print("\nDone.\n")











# --- Loss Functions (Unused) ---

# Default per-class weights for dice_loss (uniform over the six classes).
weights = np.array([0.1666, 0.1666, 0.1666, 0.1666, 0.1666, 0.1666], dtype=np.float32)

# Dice Loss
def dice_loss(y_true, y_pred, smooth=1e-6, class_weights=None):
    """Weighted multi-class soft Dice loss.

    Args:
        y_true: One-hot ground truth, classes on the last axis.
        y_pred: Softmax probabilities with the same shape as ``y_true``.
        smooth: Small constant that keeps the ratio defined (and equal to 1) for
            classes absent from both tensors.
        class_weights: Optional per-class weights, one per class. Defaults to the
            module-level ``weights``.

    Returns:
        Scalar tensor ``1 - weighted_mean(per-class Dice)``: 0 for a perfect
        prediction, up to 1 for no overlap.
    """
    if class_weights is None:
        class_weights = weights

    # Ensure inputs are float32
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)

    # Flatten to (pixels, classes); the class count is read from the last axis.
    n_classes = tf.shape(y_true)[-1]
    y_true_f = tf.reshape(y_true, [-1, n_classes])
    y_pred_f = tf.reshape(y_pred, [-1, n_classes])

    # Per-class intersection and denominator (sums over pixels)
    intersection = tf.reduce_sum(y_true_f * y_pred_f, axis=0)
    denominator = tf.reduce_sum(y_true_f, axis=0) + tf.reduce_sum(y_pred_f, axis=0) + smooth

    # Dice score per class
    dice = (2.0 * intersection + smooth) / denominator

    # Weighted *mean* of the per-class scores. reduce_mean(weights * dice) would
    # divide by the class count a second time: with weights ~1/6 the loss could
    # never drop below ~0.83.
    w = tf.cast(class_weights, tf.float32)
    weighted_dice = tf.reduce_sum(w * dice) / tf.reduce_sum(w)

    return 1.0 - weighted_dice

# Categorical Focal Loss
def categorical_focal_loss(gamma=2.0):
    """Build a categorical focal loss: ``(1 - p)^gamma * (-y * log p)``.

    Args:
        gamma: Focusing exponent; larger values down-weight confident predictions.

    Returns:
        A ``focal_loss(y_true, y_pred)`` callable. It returns the mean of the
        element-wise focal terms over every element of the tensor (pixels and
        classes alike).
    """
    eps = 1e-7

    def focal_loss(y_true, y_pred):
        targets = tf.cast(y_true, tf.float32)
        # Keep probabilities inside (0, 1) so log() stays finite.
        probs = tf.clip_by_value(tf.cast(y_pred, tf.float32), eps, 1.0 - eps)

        cross_entropy = -targets * tf.math.log(probs)
        modulating = tf.pow(1.0 - probs, gamma)

        return tf.reduce_mean(modulating * cross_entropy)

    return focal_loss

