## Quick Configuration Guide

### Configuration Flags (set in the `Config` class near the top of the script)

```python
    TRAIN_IDENTIFIER: bool = True
    TRAIN_CLASSIFIER: bool = True
    RUN_EVALUATION: bool = True
    VISUALIZE_PREDICTIONS: bool = True # Set to True to save example images
    NUM_VISUALIZATION_IMAGES: int = 10 # Number of test images to visualize

    # --- Model Paths ---
    IDENTIFIER_MODEL_PATH: str = 'identifier_model.h5'
    CLASSIFIER_MODEL_PATH: str = 'best_classifier.h5'
    OUTPUT_DIR: str = 'pipeline_output' # Directory for plots and images

    # --- Training Parameters ---
    IDENTIFIER_EPOCHS: int = 5
    CLASSIFIER_EPOCHS: int = 15
    IDENTIFIER_IMG_SIZE: int = 800
    CLASSIFIER_IMG_SIZE: int = 224
    VALIDATION_SPLIT: float = 0.2

    # --- Pipeline Parameters ---
    DETECTION_THRESHOLD: float = 0.3      # Minimum confidence for a detected object
    EVAL_IOU_THRESHOLD: float = 0.5       # IoU threshold to match pred/GT boxes

    # --- Dataset Path ---
    # Ensure this points to the correct directory
    DATA_ROOT: str = '/kaggle/input/2025-karyogram-cv-camp/2025_Karyogram_CV_Camp'
```


In [None]:
# ==============================================================================
#      ENHANCED CHROMOSOME DETECTION & CLASSIFICATION PIPELINE
#
# Key Improvements:
# - Config Class: Centralized configuration for easier tuning.
# - Memory-Efficient Generator: Classifier data generator loads images on-the-fly,
#   drastically reducing RAM usage.
# - Robust Evaluation: Uses IoU matching for a more accurate assessment of model
#   performance on the test set.
# - Added Visualizations:
#   1. Training history plots (Accuracy/Loss).
#   2. Detailed confusion matrix for classifier performance.
#   3. Output images with predicted and ground-truth bounding boxes.
#
# Dependencies:
# pip install tensorflow scikit-learn matplotlib seaborn pandas numpy opencv-python albumentations
# ==============================================================================

import os
import sys
import time
import warnings
import xml.etree.ElementTree as ET
from pathlib import Path
from itertools import cycle
from typing import List, Tuple, Dict, Any

# Suppress TensorFlow logs
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

import cv2
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
from sklearn.metrics import (
    classification_report, confusion_matrix, accuracy_score,
    balanced_accuracy_score
)
from sklearn.preprocessing import label_binarize
from tensorflow import keras
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.applications import ResNet50

# albumentations is imported lazily-installed: if the import fails, install it
# with pip and retry the import once.
try:
    import albumentations as A
except ImportError:
    print("Installing albumentations...")
    os.system('pip install -q albumentations')
    import albumentations as A

# Silence Python warnings and restrict TensorFlow's logger to errors only.
warnings.filterwarnings('ignore')
tf.get_logger().setLevel('ERROR')
# NOTE(review): presumably disables XLA JIT compilation — confirm against the
# TensorFlow version in use.
tf.config.optimizer.set_jit(False)
# NOTE(review): presumably resets Keras global state left over from a previous
# run of this cell — confirm.
tf.keras.backend.clear_session()


# ==============================================================================
# 1. CONFIGURATION
# ==============================================================================

class Config:
    """Central collection of every tunable setting read by the pipeline.

    All values are plain class attributes, so they are accessed as
    ``Config.NAME`` without instantiating the class.
    """

    # --- Execution Control ---
    TRAIN_IDENTIFIER: bool = True
    TRAIN_CLASSIFIER: bool = True
    RUN_EVALUATION: bool = True
    VISUALIZE_PREDICTIONS: bool = True  # Set to True to save example images
    NUM_VISUALIZATION_IMAGES: int = 10  # Number of test images to visualize

    # --- Model Paths ---
    IDENTIFIER_MODEL_PATH: str = 'identifier_model.h5'
    CLASSIFIER_MODEL_PATH: str = 'best_classifier.h5'
    OUTPUT_DIR: str = 'pipeline_output'  # Directory for plots and images

    # --- Training Parameters ---
    IDENTIFIER_EPOCHS: int = 5
    CLASSIFIER_EPOCHS: int = 15
    IDENTIFIER_IMG_SIZE: int = 800
    CLASSIFIER_IMG_SIZE: int = 224
    VALIDATION_SPLIT: float = 0.2

    # --- Pipeline Parameters ---
    DETECTION_THRESHOLD: float = 0.3  # Minimum confidence for a detected object
    EVAL_IOU_THRESHOLD: float = 0.5   # IoU threshold to match pred/GT boxes

    # --- Dataset Path ---
    # Ensure this points to the correct directory
    DATA_ROOT: str = '/kaggle/input/2025-karyogram-cv-camp/2025_Karyogram_CV_Camp'

    # --- Class Mapping ---
    # Labels '1'..'22' map to ids 0..21; 'X' and 'Y' take the last two ids.
    CLASS_MAP: Dict[str, int] = {
        **{str(number): number - 1 for number in range(1, 23)},
        'X': 22,
        'Y': 23,
    }
    # Label names ordered by class id, so CLASS_NAMES[class_id] is the label.
    CLASS_NAMES: List[str] = sorted(CLASS_MAP, key=CLASS_MAP.get)

# ==============================================================================
# 2. DEVICE SETUP
# ==============================================================================

def setup_device_strategy() -> Tuple[tf.distribute.Strategy, str]:
    """Pick a distribution strategy for the available hardware.

    Tries a TPU first, then GPUs, and finally falls back to the default
    strategy on CPU.

    Returns:
        A ``(strategy, device_type)`` tuple where ``device_type`` is one of
        ``"TPU"``, ``"Multi-GPU"``, ``"Single-GPU"`` or ``"CPU"``.
    """
    print("="*60 + "\nINITIALIZING DEVICE\n" + "="*60)
    try:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
        strategy = tf.distribute.TPUStrategy(tpu)
        device_type = "TPU"
        print(f"✅ TPU detected and initialized with {strategy.num_replicas_in_sync} cores.")
        return strategy, device_type
    except (ValueError, tf.errors.NotFoundError):
        print("INFO: No TPU detected.")

    gpus = tf.config.list_physical_devices('GPU')
    if gpus:
        for gpu in gpus:
            # Memory growth can only be changed before the runtime has been
            # initialised; re-running this code (e.g. a notebook cell) raises
            # RuntimeError. That is not fatal, so report it and carry on.
            try:
                tf.config.experimental.set_memory_growth(gpu, True)
            except RuntimeError as err:
                print(f"INFO: Could not enable memory growth for {gpu.name}: {err}")
        if len(gpus) > 1:
            strategy = tf.distribute.MirroredStrategy()
            device_type = "Multi-GPU"
            print(f"✅ Multi-GPU detected with {strategy.num_replicas_in_sync} devices.")
        else:
            strategy = tf.distribute.get_strategy()
            device_type = "Single-GPU"
            print("✅ Single GPU detected.")
        return strategy, device_type

    print("✅ No GPU or TPU detected. Falling back to CPU.")
    return tf.distribute.get_strategy(), "CPU"

# Resolved once when the module is loaded; main() reads these module-level
# values to choose batch sizes and to build models under the strategy scope.
STRATEGY, DEVICE_TYPE = setup_device_strategy()

# ==============================================================================
# 3. UTILITIES & VISUALIZATION
# ==============================================================================

class AnnotationParser:
    """Reads bounding-box annotations stored as XML files."""

    @staticmethod
    def parse_xml(xml_path: Path) -> Dict[str, Any]:
        """Extract every ``<object>`` that carries a ``<bndbox>``.

        Args:
            xml_path: Location of the XML annotation file.

        Returns:
            ``{'objects': [...]}`` where each entry has the object's ``name``
            text and a ``bbox`` list ``[xmin, ymin, xmax, ymax]``. Coordinates
            are parsed as floats and truncated to ints. Objects without a
            ``<bndbox>`` child are left out.
        """
        corner_tags = ('xmin', 'ymin', 'xmax', 'ymax')
        parsed_objects = []
        for node in ET.parse(xml_path).getroot().findall('object'):
            box_node = node.find('bndbox')
            if box_node is None:
                continue
            label = node.find('name').text
            corners = [int(float(box_node.find(tag).text)) for tag in corner_tags]
            parsed_objects.append({'name': label, 'bbox': corners})
        return {'objects': parsed_objects}

def compute_iou(boxA: List[int], boxB: List[int]) -> float:
    """Return the Intersection over Union of two ``[xmin, ymin, xmax, ymax]`` boxes.

    The result is ``0.0`` when the union area is not positive.
    """
    ax1, ay1, ax2, ay2 = boxA[0], boxA[1], boxA[2], boxA[3]
    bx1, by1, bx2, by2 = boxB[0], boxB[1], boxB[2], boxB[3]

    # Width/height of the overlap rectangle, clamped at zero for disjoint boxes.
    overlap_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0, min(ay2, by2) - max(ay1, by1))
    overlap = overlap_w * overlap_h

    area_a = (ax2 - ax1) * (ay2 - ay1)
    area_b = (bx2 - bx1) * (by2 - by1)
    union = float(area_a + area_b - overlap)

    if not union > 0:
        return 0.0
    return overlap / union

def plot_training_history(history: keras.callbacks.History, save_path: Path):
    """Plot the training/validation accuracy and loss curves and save them.

    Accuracy and loss are drawn on separate panels. The accuracy panel is
    fixed to the 0-1 range; the loss panel autoscales so values above 1 are
    not clipped. History keys that contain neither ``accuracy`` nor ``loss``
    are not plotted.

    Args:
        history: Object whose ``history`` attribute maps metric names to
            per-epoch values.
        save_path: Destination image file.
    """
    metrics = pd.DataFrame(history.history)
    accuracy_cols = [col for col in metrics.columns if 'accuracy' in col]
    loss_cols = [col for col in metrics.columns if 'loss' in col]

    fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(16, 6))

    if accuracy_cols:
        metrics[accuracy_cols].plot(ax=ax_acc)
    ax_acc.set_ylim(0, 1)
    ax_acc.set_title("Accuracy")
    ax_acc.set_xlabel("Epoch")
    ax_acc.grid(True)

    if loss_cols:
        metrics[loss_cols].plot(ax=ax_loss)
    ax_loss.set_title("Loss")
    ax_loss.set_xlabel("Epoch")
    ax_loss.grid(True)

    fig.suptitle("Classifier Training History")
    fig.tight_layout()
    fig.savefig(save_path)
    plt.close(fig)
    print(f"✅ Training history saved to '{save_path}'")

def plot_confusion_matrix(y_true: List, y_pred: List, class_names: List, save_path: Path):
    """Plot a confusion matrix as a heatmap and save it.

    Args:
        y_true: Ground-truth class ids, expected in ``0..len(class_names)-1``.
        y_pred: Predicted class ids, same range and length as ``y_true``.
        class_names: Display name for each class id, in id order.
        save_path: Destination image file.
    """
    # Pin the label set so the matrix always has one row/column per entry in
    # class_names. Without this, classes absent from the data shrink the
    # matrix and the tick labels no longer line up with its rows/columns.
    label_ids = list(range(len(class_names)))
    cm = confusion_matrix(y_true, y_pred, labels=label_ids)
    plt.figure(figsize=(18, 15))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()
    print(f"✅ Confusion matrix saved to '{save_path}'")

def visualize_predictions(image_path: Path, gt_boxes: List, predictions: List, save_path: Path, iou_thresh: float):
    """Draw ground-truth and predicted boxes on an image and save the figure.

    Args:
        image_path: Image to draw on.
        gt_boxes: Dicts with ``'bbox'`` (``[xmin, ymin, xmax, ymax]``) and
            ``'name'``; drawn in green.
        predictions: Dicts with ``'bbox'``, ``'class_name'`` and
            ``'combined_score'``; drawn in red.
        save_path: Destination image file.
        iou_thresh: Accepted for interface compatibility; not used by the
            drawing code.
    """
    image = cv2.imread(str(image_path))
    if image is None:
        # cv2.imread signals an unreadable file by returning None; skip the
        # figure instead of aborting the caller's loop.
        print(f"⚠️ Warning: could not read image '{image_path}'. Skipping visualization.")
        return
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    fig, ax = plt.subplots(1, figsize=(15, 15))
    ax.imshow(image)
    ax.axis('off')

    # Draw Ground Truth Boxes (Green); label sits just above the box.
    for gt in gt_boxes:
        xmin, ymin, xmax, ymax = gt['bbox']
        rect = patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, linewidth=2, edgecolor='g', facecolor='none')
        ax.add_patch(rect)
        ax.text(xmin, ymin - 10, f"GT: {gt['name']}", color='g', fontsize=12, weight='bold')

    # Draw Predicted Boxes (Red); label sits just below the box.
    for pred in predictions:
        xmin, ymin, xmax, ymax = map(int, pred['bbox'])
        rect = patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, linewidth=2, edgecolor='r', facecolor='none')
        ax.add_patch(rect)
        label = f"Pred: {pred['class_name']} ({pred['combined_score']:.2f})"
        ax.text(xmin, ymax + 20, label, color='r', fontsize=12, weight='bold')

    fig.tight_layout()
    fig.savefig(save_path, bbox_inches='tight', pad_inches=0.1)
    plt.close(fig)


# ==============================================================================
# 4. IDENTIFIER MODULE (Chromosome Detector)
# ==============================================================================

class ChromosomeIdentifierDataGenerator(keras.utils.Sequence):
    """Generates batches of images and box targets for the identifier model.

    Each batch is ``(images, targets)``: ``images`` is an array of resized,
    0-1 scaled RGB images and ``targets`` is a list with one array of
    normalised ``[ymin, xmin, ymax, xmax]`` boxes per image.
    """

    def __init__(self, images_dir: Path, annotations_dir: Path, batch_size: int, img_size: int, shuffle: bool = True):
        """Index the ``.jpg``/``.png`` images found in ``images_dir``.

        Args:
            images_dir: Directory holding the images.
            annotations_dir: Directory holding ``<stem>.xml`` annotations.
            batch_size: Number of images requested per batch.
            img_size: Side length images are resized to (square).
            shuffle: Whether to reshuffle the sample order on each epoch end.
        """
        super().__init__()
        self.images_dir = images_dir
        self.annotations_dir = annotations_dir
        self.batch_size = batch_size
        self.img_size = img_size
        self.shuffle = shuffle
        self.parser = AnnotationParser()
        # A set keeps a stem from being listed twice when both a .jpg and a
        # .png with the same name exist.
        stems = {
            path.stem
            for pattern in ('*.jpg', '*.png')
            for path in self.images_dir.glob(pattern)
        }
        self.image_list = sorted(stems)
        self.indexes = np.arange(len(self.image_list))
        self.on_epoch_end()
        print(f"IdentifierDataGenerator: Found {len(self.image_list)} images.")

    def __len__(self):
        """Return the number of batches per epoch."""
        return int(np.ceil(len(self.image_list) / self.batch_size))

    def __getitem__(self, index):
        """Return batch ``index`` as ``(images_array, list_of_box_arrays)``."""
        batch_indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
        batch_images = [self.image_list[k] for k in batch_indexes]
        images, targets = self._generate_data(batch_images)
        return np.array(images), targets

    def on_epoch_end(self):
        """Reshuffle the sample order when shuffling is enabled."""
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def _generate_data(self, batch_images):
        """Load images and normalised boxes for the given image stems.

        Samples whose image or annotation file is missing, or whose image
        cannot be decoded, are skipped, so the batch may be shorter than
        requested.
        """
        images, targets = [], []
        for img_name in batch_images:
            img_path = self.images_dir / f"{img_name}.jpg"
            if not img_path.exists():
                img_path = self.images_dir / f"{img_name}.png"
            xml_path = self.annotations_dir / f"{img_name}.xml"

            if not img_path.exists() or not xml_path.exists():
                continue

            image = cv2.imread(str(img_path))
            if image is None:
                # Unreadable/corrupt file: cv2.imread returns None.
                continue
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            orig_h, orig_w = image.shape[:2]

            image_resized = cv2.resize(image, (self.img_size, self.img_size))
            image_resized = image_resized.astype(np.float32) / 255.0

            annotations = self.parser.parse_xml(xml_path)
            boxes = []
            for obj in annotations['objects']:
                bbox = obj['bbox']
                # Stored as ymin, xmin, ymax, xmax, each divided by the
                # original image height/width.
                boxes.append([bbox[1]/orig_h, bbox[0]/orig_w, bbox[3]/orig_h, bbox[2]/orig_w])

            images.append(image_resized)
            targets.append(np.array(boxes) if boxes else np.zeros((0, 4)))

        return images, targets


class ChromosomeIdentifier:
    """The chromosome detection model (based on ResNet50).

    The network has two heads: ``max_detections`` box slots of four values
    each, and one sigmoid objectness score per slot.
    """

    def __init__(self, img_size: int = 800, max_detections: int = 50):
        """Store the input size and slot count; the model is built in ``train``."""
        self.img_size = img_size
        self.max_detections = max_detections
        self.model = None

    def _build_model(self):
        """Build the ResNet50 backbone with the bbox and objectness heads."""
        inputs = layers.Input(shape=(self.img_size, self.img_size, 3))
        base_model = ResNet50(include_top=False, weights='imagenet', input_tensor=inputs)
        # Freeze the first 100 backbone layers; the rest stay trainable.
        for layer in base_model.layers[:100]:
            layer.trainable = False
        x = base_model.output
        x = layers.Conv2D(512, 3, padding='same', activation='relu')(x)
        x = layers.GlobalAveragePooling2D()(x)
        x = layers.Dense(1024, activation='relu')(x)
        x = layers.Dropout(0.5)(x)
        bbox_output = layers.Dense(self.max_detections * 4, name='bbox')(x)
        objectness = layers.Dense(self.max_detections, activation='sigmoid', name='objectness')(x)
        return models.Model(inputs=inputs, outputs=[bbox_output, objectness])

    def train(self, train_gen: ChromosomeIdentifierDataGenerator, epochs: int, strategy: tf.distribute.Strategy):
        """Train the detector with a manual gradient-tape loop.

        Args:
            train_gen: Generator yielding ``(images, targets)`` batches.
            epochs: Number of passes over ``train_gen``.
            strategy: Distribution strategy; the model and optimizer are
                created inside its scope. The loop itself runs outside it.

        Returns:
            The trained Keras model (also stored on ``self.model``).
        """
        print("\n" + "="*60 + "\n🚀 STARTING IDENTIFIER TRAINING\n" + "="*60)
        with strategy.scope():
            self.model = self._build_model()
            optimizer = optimizers.Adam(0.001)

        for epoch in range(epochs):
            print(f"\nEpoch {epoch+1}/{epochs}")
            epoch_losses = []
            for batch_idx in range(len(train_gen)):
                images, targets = train_gen[batch_idx]
                if images.size == 0:
                    continue

                with tf.GradientTape() as tape:
                    bbox_pred, obj_pred = self.model(images, training=True)
                    bbox_pred = tf.reshape(bbox_pred, (-1, self.max_detections, 4))

                    batch_loss = 0
                    contributing = 0  # images in this batch that have boxes
                    for i, target_boxes in enumerate(targets):
                        num_objs = min(len(target_boxes), self.max_detections)
                        if num_objs > 0:
                            # Unused slots are padded with zero boxes and
                            # objectness 0.
                            target_padded = np.zeros((self.max_detections, 4))
                            target_padded[:num_objs] = target_boxes[:num_objs]
                            bbox_loss = tf.reduce_mean(tf.square(bbox_pred[i] - target_padded))
                            obj_target = np.zeros(self.max_detections)
                            obj_target[:num_objs] = 1
                            obj_loss = tf.keras.losses.binary_crossentropy(obj_target, obj_pred[i])
                            batch_loss += bbox_loss + 0.5 * tf.reduce_mean(obj_loss)
                            contributing += 1
                    batch_loss /= len(targets)

                if contributing == 0:
                    # No image in the batch had boxes, so batch_loss is a
                    # plain number with no gradient; skip the update instead
                    # of failing in the optimizer step.
                    continue

                grads = tape.gradient(batch_loss, self.model.trainable_variables)
                optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
                epoch_losses.append(float(batch_loss))
                if (batch_idx + 1) % 10 == 0:
                    print(f"  Batch {batch_idx+1}/{len(train_gen)}, Loss: {np.mean(epoch_losses[-10:]):.4f}")
            if epoch_losses:
                print(f"Epoch {epoch+1} Average Loss: {np.mean(epoch_losses):.4f}")
            else:
                print(f"Epoch {epoch+1}: no batches with annotated boxes were trained.")
        print("\n✅ Identifier training complete.")
        return self.model

    def predict_boxes(self, image: np.ndarray, confidence_threshold: float) -> Tuple[np.ndarray, np.ndarray]:
        """Detect boxes in a single image.

        Args:
            image: Image array; its first two dimensions are taken as
                height and width.
            confidence_threshold: Slots with objectness at or below this
                value are dropped.

        Returns:
            ``(boxes, scores)`` where ``boxes`` holds pixel-space
            ``[xmin, ymin, xmax, ymax]`` rows clamped to the image bounds and
            ``scores`` holds the matching objectness values.
        """
        orig_h, orig_w = image.shape[:2]
        image_resized = cv2.resize(image, (self.img_size, self.img_size))
        image_resized = np.expand_dims(image_resized.astype(np.float32) / 255.0, axis=0)
        bbox_pred, obj_pred = self.model.predict(image_resized, verbose=0)
        bbox_pred = bbox_pred.reshape(self.max_detections, 4)
        obj_pred = obj_pred.flatten()

        keep_indices = obj_pred > confidence_threshold
        boxes = bbox_pred[keep_indices]
        scores = obj_pred[keep_indices]

        # The model emits normalised (ymin, xmin, ymax, xmax); convert to
        # pixel-space (xmin, ymin, xmax, ymax).
        final_boxes = []
        for ymin, xmin, ymax, xmax in boxes:
            final_boxes.append([
                max(0, int(xmin * orig_w)), max(0, int(ymin * orig_h)),
                min(orig_w, int(xmax * orig_w)), min(orig_h, int(ymax * orig_h))
            ])
        return np.array(final_boxes), scores

# ==============================================================================
# 5. CLASSIFIER MODULE (Chromosome Type Classifier)
# ==============================================================================
class ChromosomeClassifierDataGenerator(keras.utils.Sequence):
    """
    Memory-efficient data generator for the classifier.
    Loads and crops images on-the-fly instead of storing them in memory.

    Crops are converted to RGB before normalisation so that training sees the
    same channel order that ``KaryotypePipeline.process_image`` uses at
    inference time.
    """

    def __init__(self, images_dir: Path, annotations_dir: Path, batch_size: int, img_size: int,
                 class_map: Dict, is_validation: bool = False, augment: bool = False,
                 validation_split: float = 0.2, data_list=None, indexes=None):
        """Build (or reuse) the sample index and select the train/val subset.

        Args:
            images_dir: Directory holding the images.
            annotations_dir: Directory holding ``<stem>.xml`` annotations.
            batch_size: Number of crops requested per batch.
            img_size: Side length each crop is resized to (square).
            class_map: Maps annotation names to integer labels.
            is_validation: Serve the validation subset instead of training.
            augment: Apply augmentations (ignored for validation).
            validation_split: Fraction of samples held out for validation.
            data_list: Pre-built sample index to reuse; requires ``indexes``.
            indexes: Positions in ``data_list`` this generator serves.
        """
        super().__init__()
        self.images_dir, self.annotations_dir = images_dir, annotations_dir
        self.batch_size, self.img_size = batch_size, img_size
        self.class_map, self.is_validation = class_map, is_validation

        if data_list is None or indexes is None:
            self.chromosome_data = self._build_dataset_index()
            all_indexes = np.arange(len(self.chromosome_data))
            # Fixed seed so the train/validation split is the same every run.
            np.random.seed(42)
            np.random.shuffle(all_indexes)
            val_size = int(len(self.chromosome_data) * validation_split)
            self.val_indexes = all_indexes[:val_size]
            self.train_indexes = all_indexes[val_size:]
            self.indexes = self.val_indexes if self.is_validation else self.train_indexes
        else:
            self.chromosome_data = data_list
            self.indexes = indexes

        self.aug = A.Compose([
            A.HorizontalFlip(p=0.5), A.VerticalFlip(p=0.5), A.Rotate(limit=20, p=0.5),
            A.RandomBrightnessContrast(p=0.3), A.GaussNoise(p=0.2),
        ]) if augment and not self.is_validation else None

    def _build_dataset_index(self) -> List[Dict]:
        """Builds an index of chromosome locations, not the images themselves."""
        print("\nBuilding classifier dataset index...")
        index = []
        image_files = sorted([p.stem for p in self.images_dir.glob('*.jpg')] + [p.stem for p in self.images_dir.glob('*.png')])
        for img_name in image_files:
            xml_path = self.annotations_dir / f"{img_name}.xml"
            if not xml_path.exists():
                continue
            try:
                annotations = AnnotationParser.parse_xml(xml_path)
                for obj in annotations['objects']:
                    if obj['name'] in self.class_map:
                        index.append({
                            'img_name': img_name,
                            'bbox': obj['bbox'],
                            'label': self.class_map[obj['name']]
                        })
            except Exception:
                # Best-effort indexing: an annotation file that fails to
                # parse is skipped rather than aborting the whole index.
                continue
        print(f"✅ Found {len(index)} chromosome samples across {len(image_files)} images.")
        return index

    def get_validation_generator(self):
        """Creates a generator for the validation set.

        Only valid on a generator that built its own index, since it relies
        on ``self.val_indexes``.
        """
        return ChromosomeClassifierDataGenerator(
            self.images_dir, self.annotations_dir, self.batch_size, self.img_size,
            self.class_map, is_validation=True, augment=False,
            data_list=self.chromosome_data, indexes=self.val_indexes
        )

    def __len__(self):
        """Return the number of batches per epoch."""
        return int(np.ceil(len(self.indexes) / self.batch_size))

    def __getitem__(self, index):
        """Return batch ``index`` as ``(crops, labels)``.

        Samples that cannot be loaded or cropped are skipped. If nothing in
        the batch is usable, an all-zero batch with label 0 is returned.
        """
        batch_indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
        batch_images, batch_labels = [], []

        for idx in batch_indexes:
            item = self.chromosome_data[idx]
            img_path = self.images_dir / f"{item['img_name']}.jpg"
            if not img_path.exists():
                img_path = self.images_dir / f"{item['img_name']}.png"

            try:
                image = cv2.imread(str(img_path))
                if image is None:
                    # Missing or undecodable file.
                    continue
                xmin, ymin, xmax, ymax = item['bbox']
                crop = image[ymin:ymax, xmin:xmax]
                if crop.size == 0:
                    continue

                # cv2.imread yields BGR; the mean/std constants below and the
                # inference pipeline both operate on RGB.
                crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)

                if self.aug:
                    crop = self.aug(image=crop)['image']

                resized_crop = cv2.resize(crop, (self.img_size, self.img_size))
                normalized_crop = (resized_crop.astype(np.float32) / 255.0 - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]

                batch_images.append(normalized_crop)
                batch_labels.append(item['label'])
            except Exception:
                # Best-effort loading: a sample that fails to process is
                # dropped from the batch.
                continue

        if not batch_images:  # Handle empty batches
            return np.zeros((self.batch_size, self.img_size, self.img_size, 3), dtype=np.float32), \
                   np.zeros(self.batch_size, dtype=np.int32)
        return np.array(batch_images), np.array(batch_labels)

class ChromosomeClassifier:
    """ResNet50-based model that assigns a chromosome class to a crop."""

    def __init__(self, num_classes: int, img_size: int):
        """Record the class count and input size; the model is built later."""
        self.num_classes = num_classes
        self.img_size = img_size
        self.model = None

    def _build_model(self):
        """Assemble the ResNet50 backbone plus a small softmax head."""
        input_shape = (self.img_size, self.img_size, 3)
        backbone = ResNet50(include_top=False, weights='imagenet', input_shape=input_shape)
        # Only layers from index 140 onwards remain trainable.
        for frozen_layer in backbone.layers[:140]:
            frozen_layer.trainable = False

        head = layers.GlobalAveragePooling2D()(backbone.output)
        head = layers.Dense(512, activation='relu')(head)
        head = layers.Dropout(0.5)(head)
        class_probs = layers.Dense(self.num_classes, activation='softmax')(head)
        return models.Model(inputs=backbone.input, outputs=class_probs)

    def compile_model(self, lr: float, strategy: tf.distribute.Strategy):
        """Build and compile the model inside the given strategy's scope."""
        with strategy.scope():
            self.model = self._build_model()
            self.model.compile(
                optimizer=optimizers.Adam(lr),
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy'],
            )

    def train(self, train_gen, val_gen, epochs: int, model_save_path: str):
        """Fit the compiled model and return the Keras history object.

        All three callbacks watch ``val_accuracy``: the learning rate is
        halved after 3 stalled epochs, the best model is written to
        ``model_save_path``, and training stops after 7 stalled epochs.
        """
        print("\n" + "="*60 + "\n🚀 STARTING CLASSIFIER TRAINING\n" + "="*60)
        watched_metric = 'val_accuracy'
        lr_scheduler = keras.callbacks.ReduceLROnPlateau(
            monitor=watched_metric, factor=0.5, patience=3, min_lr=1e-7, verbose=1)
        checkpoint = keras.callbacks.ModelCheckpoint(
            model_save_path, monitor=watched_metric, save_best_only=True, verbose=1)
        early_stop = keras.callbacks.EarlyStopping(
            monitor=watched_metric, patience=7, restore_best_weights=True, verbose=1)

        history = self.model.fit(
            train_gen,
            validation_data=val_gen,
            epochs=epochs,
            callbacks=[lr_scheduler, checkpoint, early_stop],
            verbose=1,
        )
        print("\n✅ Classifier training complete.")
        return history

# ==============================================================================
# 6. END-TO-END PIPELINE
# ==============================================================================

class KaryotypePipeline:
    """Orchestrates the detection and classification process."""

    def __init__(self, identifier: ChromosomeIdentifier, classifier: ChromosomeClassifier, detection_threshold: float):
        """Store the two models and the detection confidence threshold."""
        self.identifier = identifier
        self.classifier = classifier
        self.detection_threshold = detection_threshold
        self.classifier_img_size = classifier.img_size

    def process_image(self, image_path: Path) -> List[Dict]:
        """Detect chromosomes in an image and classify each detected crop.

        Args:
            image_path: Image file to process.

        Returns:
            One dict per usable detection with keys ``bbox``, ``class_id``,
            ``confidence``, ``class_name`` and ``combined_score`` (detector
            score times classifier confidence). Returns an empty list when
            the image cannot be read or nothing usable is detected.
        """
        image = cv2.imread(str(image_path))
        if image is None:
            return []

        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        boxes, scores = self.identifier.predict_boxes(image_rgb, self.detection_threshold)
        if len(boxes) == 0:
            return []

        # Collect every valid crop first so the classifier runs once on the
        # whole batch instead of once per crop.
        kept_boxes, kept_scores, crops = [], [], []
        for box, score in zip(boxes, scores):
            xmin, ymin, xmax, ymax = map(int, box)
            if xmax <= xmin or ymax <= ymin:
                continue

            crop = image_rgb[ymin:ymax, xmin:xmax]
            if crop.size == 0:
                continue

            resized_crop = cv2.resize(crop, (self.classifier_img_size, self.classifier_img_size))
            normalized_crop = (resized_crop.astype(np.float32)/255.0 - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]

            kept_boxes.append(box)
            kept_scores.append(score)
            crops.append(normalized_crop)

        if not crops:
            return []

        batch = np.asarray(crops, dtype=np.float32)
        all_probs = self.classifier.model.predict(batch, verbose=0)

        predictions = []
        for box, score, probs in zip(kept_boxes, kept_scores, all_probs):
            class_id = int(np.argmax(probs))
            confidence = float(probs[class_id])
            predictions.append({
                'bbox': box.tolist(), 'class_id': class_id, 'confidence': confidence,
                'class_name': self._get_class_name(class_id),
                'combined_score': float(score) * confidence
            })
        return predictions

    @staticmethod
    def _get_class_name(class_id: int) -> str:
        """Map a class id to its display name via ``Config.CLASS_NAMES``."""
        return Config.CLASS_NAMES[class_id]

# ==============================================================================
# 7. MAIN EXECUTION SCRIPT
# ==============================================================================
def main():
    """Run the whole pipeline: train or load both models, then evaluate.

    Behaviour is driven by ``Config``. Evaluation matches each ground-truth
    box to the prediction with the highest IoU; matches below
    ``Config.EVAL_IOU_THRESHOLD`` count as missed detections and are left out
    of the classification metrics.
    """
    # Setup paths and output directory
    output_dir = Path(Config.OUTPUT_DIR)
    output_dir.mkdir(parents=True, exist_ok=True)
    data_root = Path(Config.DATA_ROOT)
    single_chr_images = data_root / 'single_chromosomes_object' / 'images'
    single_chr_annotations = data_root / 'single_chromosomes_object' / 'annotations'
    multi_chr_images = data_root / '24_chromosomes_object' / 'images'
    multi_chr_annotations = data_root / '24_chromosomes_object' / 'annotations'

    # Dynamic batch sizes based on device
    if DEVICE_TYPE == "TPU":
        id_batch, cls_batch = 8 * STRATEGY.num_replicas_in_sync, 128 * STRATEGY.num_replicas_in_sync
    else:  # GPU or CPU
        id_batch, cls_batch = 4 * STRATEGY.num_replicas_in_sync, 64 * STRATEGY.num_replicas_in_sync
    print(f"\nUsing Batch Sizes -> Identifier: {id_batch}, Classifier: {cls_batch}\n")

    # --- Step 1: Train or Load Identifier ---
    identifier = ChromosomeIdentifier(img_size=Config.IDENTIFIER_IMG_SIZE)
    if Config.TRAIN_IDENTIFIER:
        id_gen = ChromosomeIdentifierDataGenerator(single_chr_images, single_chr_annotations, id_batch, Config.IDENTIFIER_IMG_SIZE)
        identifier.train(id_gen, Config.IDENTIFIER_EPOCHS, STRATEGY)
        identifier.model.save(Config.IDENTIFIER_MODEL_PATH)
        print(f"✅ Identifier model saved to '{Config.IDENTIFIER_MODEL_PATH}'")
    else:
        print(f"🔄 Loading pre-trained identifier from '{Config.IDENTIFIER_MODEL_PATH}'...")
        identifier.model = keras.models.load_model(Config.IDENTIFIER_MODEL_PATH)

    # --- Step 2: Train or Load Classifier ---
    classifier = ChromosomeClassifier(num_classes=len(Config.CLASS_NAMES), img_size=Config.CLASSIFIER_IMG_SIZE)
    if Config.TRAIN_CLASSIFIER:
        train_cls_gen = ChromosomeClassifierDataGenerator(
            multi_chr_images, multi_chr_annotations, cls_batch, Config.CLASSIFIER_IMG_SIZE,
            Config.CLASS_MAP, augment=True, validation_split=Config.VALIDATION_SPLIT
        )
        val_cls_gen = train_cls_gen.get_validation_generator()
        classifier.compile_model(lr=0.001, strategy=STRATEGY)
        history = classifier.train(train_cls_gen, val_cls_gen, Config.CLASSIFIER_EPOCHS, Config.CLASSIFIER_MODEL_PATH)
        plot_training_history(history, output_dir / 'classifier_training_history.png')
    else:
        print(f"🔄 Loading pre-trained classifier from '{Config.CLASSIFIER_MODEL_PATH}'...")
        classifier.model = keras.models.load_model(Config.CLASSIFIER_MODEL_PATH)

    # --- Step 3: Run Evaluation and Visualization ---
    if Config.RUN_EVALUATION:
        print("\n" + "="*60 + "\n📊 RUNNING EVALUATION\n" + "="*60)
        test_file = data_root / 'test.txt'
        if not test_file.exists():
            print(f"⚠️ Warning: test.txt not found. Skipping evaluation.")
            return

        with open(test_file, 'r') as f:
            # Blank lines (e.g. a trailing newline) are not sample names.
            test_list = [line.strip() for line in f if line.strip()]
        print(f"Found {len(test_list)} test samples.")

        pipeline = KaryotypePipeline(identifier, classifier, Config.DETECTION_THRESHOLD)
        parser = AnnotationParser()
        y_true, y_pred = [], []

        # Create a directory for visualization images
        vis_dir = output_dir / 'visualizations'
        if Config.VISUALIZE_PREDICTIONS:
            vis_dir.mkdir(parents=True, exist_ok=True)
            print(f"Visualizations will be saved to '{vis_dir}'")

        for i, img_name in enumerate(test_list):
            print(f"Processing image {i+1}/{len(test_list)}: {img_name}")
            img_path = multi_chr_images / f"{img_name}.jpg"
            if not img_path.exists():
                img_path = multi_chr_images / f"{img_name}.png"
            xml_path = multi_chr_annotations / f"{img_name}.xml"
            if not img_path.exists() or not xml_path.exists():
                continue

            gt_annotations = parser.parse_xml(xml_path)
            predictions = pipeline.process_image(img_path)

            gt_boxes = [obj for obj in gt_annotations['objects'] if obj['name'] in Config.CLASS_MAP]

            # Match each ground-truth box to its highest-IoU prediction
            for gt_box_info in gt_boxes:
                gt_bbox = gt_box_info['bbox']
                gt_label = Config.CLASS_MAP[gt_box_info['name']]
                best_iou, best_pred_label = -1, -1

                for pred in predictions:
                    iou = compute_iou(gt_bbox, pred['bbox'])
                    if iou > best_iou:
                        best_iou = iou
                        best_pred_label = pred['class_id']

                y_true.append(gt_label)
                if best_iou >= Config.EVAL_IOU_THRESHOLD:
                    y_pred.append(best_pred_label)
                else:
                    y_pred.append(-1)  # Special class for missed detections

            # Visualize some predictions
            if Config.VISUALIZE_PREDICTIONS and i < Config.NUM_VISUALIZATION_IMAGES:
                vis_save_path = vis_dir / f"{img_name}_pred.png"
                visualize_predictions(img_path, gt_boxes, predictions, vis_save_path, Config.EVAL_IOU_THRESHOLD)

        # Filter out missed detections for classification metrics
        matched_pairs = [(yt, yp) for yt, yp in zip(y_true, y_pred) if yp != -1]
        y_true_cls = [yt for yt, _ in matched_pairs]
        y_pred_cls = [yp for _, yp in matched_pairs]

        # --- Display Results ---
        print("\n" + "="*60 + "\nEVALUATION RESULTS\n" + "="*60)
        if y_true:
            print(f"Detection: matched {len(matched_pairs)}/{len(y_true)} ground-truth boxes "
                  f"(IoU >= {Config.EVAL_IOU_THRESHOLD}).")
        if y_true_cls:
            # Pin the label set: without it, classification_report raises when
            # fewer classes appear in the data than there are target_names.
            all_label_ids = list(range(len(Config.CLASS_NAMES)))
            report = classification_report(
                y_true_cls, y_pred_cls, labels=all_label_ids,
                target_names=Config.CLASS_NAMES, zero_division=0
            )
            print("Classification Report (on matched boxes):")
            print(report)
            plot_confusion_matrix(y_true_cls, y_pred_cls, Config.CLASS_NAMES, output_dir / 'confusion_matrix.png')

            acc = accuracy_score(y_true_cls, y_pred_cls)
            bal_acc = balanced_accuracy_score(y_true_cls, y_pred_cls)
            print(f"\nOverall Accuracy (matched): {acc:.4f}")
            print(f"Balanced Accuracy (matched): {bal_acc:.4f}")
        else:
            print("No valid matched boxes found for classification report.")


# Call main() only when this file is executed as a script.
if __name__ == "__main__":
    main()