In [5]:
# ==============================================================================
# ENVIRONMENT SETUP - RUN THIS CELL FIRST
# ==============================================================================

print("🔧 Setting up environment with required libraries...")
print("This may take a few minutes...")

# Uninstall existing packages to avoid conflicts
%pip uninstall -y numpy tensorflow torch

# Install all required packages with specific versions
%pip install \
    numpy==1.26.3 \
    pandas==2.2.2 \
    scikit-learn==1.4.2 \
    matplotlib==3.8.4 \
    seaborn==0.13.2 \
    tensorflow==2.15.0 \
    torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu121 \
    ultralytics==8.1.47 \
    albumentations==1.4.0 \
    pyyaml==6.0.1

print("\n✅ Environment setup complete!")
print("⚠️  IMPORTANT: Please RESTART the runtime/kernel now before running the next cells.")
print("   Go to: Runtime → Restart runtime (or Kernel → Restart)")


🔧 Setting up environment with required libraries...
This may take a few minutes...
[0mNote: you may need to restart the kernel to use updated packages.
Looking in indexes: https://download.pytorch.org/whl/cu121
Collecting numpy==1.26.3
  Downloading https://download.pytorch.org/whl/numpy-1.26.3-cp310-cp310-macosx_11_0_arm64.whl (14.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.0/14.0 MB[0m [31m10.6 MB/s[0m eta [36m0:00:00[0m00:01[0m0:01[0m
[?25h[31mERROR: Could not find a version that satisfies the requirement pandas==2.2.2 (from versions: none)[0m[31m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.1.1[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
[31mERROR: No matching distribution found for pandas==2.2.2[0m[31m
[0mNote: you may need to restart the kernel to use updated packages.

✅ Envi

## Quick Configuration Guide

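### Configuration Flags (set in the `Config` class in Cell 1)

> **Note:** the snippet below is a reference example. The values actually used are the ones in the `Config` class. For example, `IDENTIFIER_EPOCHS`, `IDENTIFIER_IMG_SIZE` and `DATA_ROOT` differ there. If the two disagree, `Config` wins.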

```py
    TRAIN_IDENTIFIER: bool = True
    TRAIN_CLASSIFIER: bool = True
    RUN_EVALUATION: bool = True
    VISUALIZE_PREDICTIONS: bool = True # Set to True to save example images
    NUM_VISUALIZATION_IMAGES: int = 10 # Number of test images to visualize

    # --- Model Paths ---
    IDENTIFIER_MODEL_PATH: str = 'identifier_model.h5'
    CLASSIFIER_MODEL_PATH: str = 'best_classifier.h5'
    OUTPUT_DIR: str = 'pipeline_output' # Directory for plots and images

    # --- Training Parameters ---
    IDENTIFIER_EPOCHS: int = 5
    CLASSIFIER_EPOCHS: int = 15
    IDENTIFIER_IMG_SIZE: int = 800
    CLASSIFIER_IMG_SIZE: int = 224
    VALIDATION_SPLIT: float = 0.2

    # --- Pipeline Parameters ---
    DETECTION_THRESHOLD: float = 0.3      # Minimum confidence for a detected object
    EVAL_IOU_THRESHOLD: float = 0.5       # IoU threshold to match pred/GT boxes

    # --- Dataset Path ---
    # Ensure this points to the correct directory
    DATA_ROOT: str = '/kaggle/input/2025-karyogram-cv-camp/2025_Karyogram_CV_Camp'
```


In [2]:
# ==============================================================================
# ⚙️ CELL 1: ENVIRONMENT SETUP, IMPORTS & CONFIGURATION
# ==============================================================================
# This cell imports all libraries and defines the pipeline configuration (`Config`).
# Library versions are pinned by the ENVIRONMENT SETUP cell above: run that cell once,
# RESTART the kernel, and then run this cell and the ones after it.


import os
import shutil
import yaml
import warnings
import xml.etree.ElementTree as ET
from pathlib import Path
from typing import List, Tuple, Dict, Any
from sklearn.model_selection import train_test_split
from itertools import cycle

# Suppress TensorFlow logs
# Keep these assignments ABOVE `import tensorflow` (further down in this cell) so
# the variables are already set when TensorFlow initialises -- do not reorder.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'   # presumably hides C++ INFO/WARNING logs; see TF docs
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'  # turns off oneDNN optimisations (per TF env-var docs; confirm)

import cv2
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
from sklearn.metrics import (
    classification_report, confusion_matrix, accuracy_score,
    balanced_accuracy_score, roc_curve, auc, precision_recall_curve,
    average_precision_score
)
from sklearn.preprocessing import label_binarize
from tensorflow import keras
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.applications import ResNet50
from ultralytics import YOLO
import albumentations as A

# Global noise control: ignore all Python warnings, let only errors through the
# TensorFlow Python logger, keep XLA JIT compilation off and reset Keras state.
warnings.simplefilter('ignore')
tf.get_logger().setLevel('ERROR')
tf.config.optimizer.set_jit(False)
tf.keras.backend.clear_session()

class Config:
    """Holds all configuration parameters for the pipeline.

    The rest of the notebook reads these as class attributes (``Config.X``).
    """
    # --- Execution Control ---
    # Set these to False to skip re-training and use saved models
    TRAIN_IDENTIFIER: bool = True
    TRAIN_CLASSIFIER: bool = True

    # --- Model & Training Paths ---
    OUTPUT_DIR: str = 'pipeline_output'               # plots, reports, visualizations
    CLASSIFIER_MODEL_PATH: str = 'best_classifier.h5'

    # --- YOLO Identifier Parameters ---
    YOLO_MODEL_NAME: str = 'yolov8n.pt'               # starting checkpoint for training
    YOLO_DATA_DIR: str = 'yolo_dataset'
    # Derived from YOLO_DATA_DIR so the two paths cannot drift apart.
    YOLO_DATA_CONFIG: str = f'{YOLO_DATA_DIR}/data.yaml'
    YOLO_PROJECT_NAME: str = 'chromosome_detector'
    YOLO_RUN_NAME: str = 'train_run'
    YOLO_MODEL_PATH: str = f'{YOLO_PROJECT_NAME}/{YOLO_RUN_NAME}/weights/best.pt'

    # --- Training Parameters ---
    IDENTIFIER_EPOCHS: int = 25
    CLASSIFIER_EPOCHS: int = 15
    IDENTIFIER_IMG_SIZE: int = 640
    CLASSIFIER_IMG_SIZE: int = 224
    VALIDATION_SPLIT: float = 0.2

    # --- Pipeline Parameters ---
    DETECTION_THRESHOLD: float = 0.3   # minimum detector confidence kept
    EVAL_IOU_THRESHOLD: float = 0.5    # IoU needed to match a prediction to a GT box

    # --- Visualization ---
    NUM_VISUALIZATION_IMAGES: int = 10  # first N test images rendered with GT/pred boxes

    # --- Dataset Path ---
    DATA_ROOT: str = '/kaggle/input/inter-iit-small-set/2025_Karyogram_CV_Camp_small'

    # --- Class Mapping (for Classifier) ---
    # Annotation label -> contiguous class id (22 autosomes + X + Y).
    CLASS_MAP: Dict[str, int] = {
        '1': 0, '2': 1, '3': 2, '4': 3, '5': 4, '6': 5, '7': 6, '8': 7,
        '9': 8, '10': 9, '11': 10, '12': 11, '13': 12, '14': 13, '15': 14,
        '16': 15, '17': 16, '18': 17, '19': 18, '20': 19, '21': 20, '22': 21,
        'X': 22, 'Y': 23
    }
    # Names ordered by class id, so CLASS_NAMES[class_id] is the label.
    CLASS_NAMES: List[str] = [name for name, _ in sorted(CLASS_MAP.items(), key=lambda item: item[1])]
    NUM_CLASSES = len(CLASS_NAMES)

ModuleNotFoundError: No module named 'numpy'

In [None]:
# ==============================================================================
# 🛠️ CELL 2: UTILITIES & DATA VISUALIZATION
# ==============================================================================

class AnnotationParser:
    """Parses Pascal-VOC style XML annotation files."""

    @staticmethod
    def parse_xml(xml_path: Path) -> Dict[str, Any]:
        """Read one XML annotation file.

        Args:
            xml_path: path to the ``.xml`` file.

        Returns:
            ``{'objects': [{'name': str, 'bbox': [xmin, ymin, xmax, ymax]}, ...]}``.
            Coordinates are truncated to ints. ``'width'`` and ``'height'`` (ints)
            are added only when a ``<size>`` node provides both values. Objects
            without a complete ``<bndbox>`` are skipped.
        """
        root = ET.parse(xml_path).getroot()
        annotations: Dict[str, Any] = {'objects': []}

        def _text(node, tag):
            # Stripped text of child `tag`, or None when the child is missing or blank.
            value = node.findtext(tag)
            if value is None:
                return None
            value = value.strip()
            return value or None

        # Compare Elements with `is not None`. An Element's truth value reflects its
        # number of children (and is deprecated), not whether the node exists.
        size_node = root.find('size')
        if size_node is not None:
            width, height = _text(size_node, 'width'), _text(size_node, 'height')
            if width is not None and height is not None:
                annotations['width'] = int(float(width))
                annotations['height'] = int(float(height))

        for obj in root.findall('object'):
            bndbox = obj.find('bndbox')
            if bndbox is None:
                continue
            coords = [_text(bndbox, tag) for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
            if any(c is None for c in coords):
                continue  # incomplete box: skip it rather than crash the whole file
            annotations['objects'].append({
                'name': _text(obj, 'name'),
                'bbox': [int(float(c)) for c in coords],
            })
        return annotations

def compute_iou(boxA: List[int], boxB: List[int]) -> float:
    """Intersection-over-union of two ``[xmin, ymin, xmax, ymax]`` boxes.

    Only the first four entries of each box are read. Returns 0.0 when the
    union area is not positive (for example, two degenerate boxes).
    """
    a_x0, a_y0, a_x1, a_y1 = boxA[0], boxA[1], boxA[2], boxA[3]
    b_x0, b_y0, b_x1, b_y1 = boxB[0], boxB[1], boxB[2], boxB[3]

    overlap_w = max(0, min(a_x1, b_x1) - max(a_x0, b_x0))
    overlap_h = max(0, min(a_y1, b_y1) - max(a_y0, b_y0))
    overlap = overlap_w * overlap_h

    area_a = (a_x1 - a_x0) * (a_y1 - a_y0)
    area_b = (b_x1 - b_x0) * (b_y1 - b_y0)
    total = float(area_a + area_b - overlap)

    if total > 0:
        return overlap / total
    return 0.0

def visualize_ground_truth_examples(image_dir: Path, annotation_dir: Path, num_examples: int = 4):
    """Plot a random sample of annotated images with their ground-truth boxes.

    Args:
        image_dir: directory with ``*.jpg`` / ``*.png`` images.
        annotation_dir: directory with matching ``<stem>.xml`` annotations.
        num_examples: maximum number of images to show, two per row. Fewer are
            shown if fewer annotated images exist.
    """
    print("Displaying ground truth examples...")
    image_files = list(image_dir.glob('*.jpg')) + list(image_dir.glob('*.png'))
    # Sample only images that have an annotation, so no subplot slot is wasted.
    annotated = [p for p in image_files if (annotation_dir / (p.stem + ".xml")).exists()]
    # np.random.choice(replace=False) raises if asked for more items than exist.
    n_show = min(num_examples, len(annotated))
    if n_show <= 0:
        print("⚠️ No annotated images found; nothing to display.")
        return

    n_cols = min(2, n_show)
    n_rows = (n_show + n_cols - 1) // n_cols  # ceil, so 1 or odd counts still get a row
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows), squeeze=False)
    for ax in axes.ravel():
        ax.axis('off')  # also hides an unused trailing slot

    chosen = np.random.choice(len(annotated), n_show, replace=False)
    for ax, idx in zip(axes.ravel(), chosen):
        img_path = annotated[idx]
        image = cv2.imread(str(img_path))
        if image is None:
            continue  # unreadable file
        ax.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        ax.set_title(img_path.name)

        annotations = AnnotationParser.parse_xml(annotation_dir / (img_path.stem + ".xml"))
        for obj in annotations['objects']:
            xmin, ymin, xmax, ymax = obj['bbox']
            ax.add_patch(patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                           linewidth=2, edgecolor='lime', facecolor='none'))
    plt.tight_layout()
    plt.show()

def plot_training_history(history: keras.callbacks.History, save_path: Path):
    """Save loss and accuracy curves from a Keras ``History`` to ``save_path``.

    Loss and accuracy go on separate axes. Only the accuracy axis is fixed to
    [0, 1]; cross-entropy loss can be well above 1, so its axis is unclamped.
    Other logged columns (e.g. a learning rate added by ReduceLROnPlateau) are
    not plotted.
    """
    hist_df = pd.DataFrame(history.history)
    loss_cols = [c for c in hist_df.columns if 'loss' in c]
    acc_cols = [c for c in hist_df.columns if 'acc' in c]

    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(16, 6))
    if loss_cols:
        hist_df[loss_cols].plot(ax=ax_loss)
    ax_loss.set_title("Loss"); ax_loss.set_xlabel("Epoch"); ax_loss.grid(True)

    if acc_cols:
        hist_df[acc_cols].plot(ax=ax_acc)
    ax_acc.set_ylim(0, 1)
    ax_acc.set_title("Accuracy"); ax_acc.set_xlabel("Epoch"); ax_acc.grid(True)

    fig.suptitle("Classifier Training History")
    fig.tight_layout()
    fig.savefig(save_path)
    plt.close(fig)
    print(f"✅ Classifier training history saved to '{save_path}'")

def visualize_predictions(image_path: Path, gt_boxes: List, predictions: List, save_path: Path):
    """Draw GT boxes (green) and predicted boxes (red) on an image and save it.

    Args:
        image_path: image to draw on.
        gt_boxes: dicts with ``'bbox'`` ([xmin, ymin, xmax, ymax]) and ``'name'``.
        predictions: dicts with ``'bbox'``, ``'class_name'`` and ``'combined_score'``.
        save_path: output image path. The figure is closed after saving.
    """
    rgb = cv2.cvtColor(cv2.imread(str(image_path)), cv2.COLOR_BGR2RGB)
    fig, ax = plt.subplots(1, figsize=(15, 15))
    ax.imshow(rgb)
    ax.axis('off')

    def _outline(x0, y0, x1, y1, colour):
        ax.add_patch(patches.Rectangle((x0, y0), x1 - x0, y1 - y0,
                                       linewidth=2, edgecolor=colour, facecolor='none'))

    # Ground truth: label drawn just above the box.
    for gt in gt_boxes:
        x0, y0, x1, y1 = gt['bbox']
        _outline(x0, y0, x1, y1, 'g')
        ax.text(x0, y0 - 10, f"GT: {gt['name']}", color='g', fontsize=12, weight='bold')

    # Predictions: label drawn just below the box so it doesn't collide with GT text.
    for pred in predictions:
        x0, y0, x1, y1 = (int(v) for v in pred['bbox'])
        _outline(x0, y0, x1, y1, 'r')
        caption = f"Pred: {pred['class_name']} ({pred['combined_score']:.2f})"
        ax.text(x0, y1 + 20, caption, color='r', fontsize=12, weight='bold')

    plt.tight_layout()
    plt.savefig(save_path, bbox_inches='tight', pad_inches=0.1)
    plt.close()

In [None]:
# ==============================================================================
# 🎯 CELL 3: STEP 1 - TRAIN THE YOLO IDENTIFIER (DETECTOR)
# ==============================================================================

def prepare_yolo_data(image_dir: Path, annotation_dir: Path, output_dir: Path, val_split: float):
    """Convert XML annotations to a single-class YOLO dataset and write ``data.yaml``.

    Every annotated object becomes class 0 ("chromosome"): the detector only
    localises, and the separate classifier assigns the chromosome label.

    Args:
        image_dir: source images (``.jpg`` or ``.png``).
        annotation_dir: ``<stem>.xml`` annotation files.
        output_dir: dataset root to (re)create. It is deleted first if it exists.
        val_split: fraction of images held out for validation (``random_state=42``).
    """
    print("\n" + "="*60 + "\n📦 PREPARING YOLO DATASET\n" + "="*60)
    if output_dir.exists():
        shutil.rmtree(output_dir)

    dir_paths = {
        'train_images': output_dir / 'images' / 'train',
        'val_images': output_dir / 'images' / 'val',
        'train_labels': output_dir / 'labels' / 'train',
        'val_labels': output_dir / 'labels' / 'val',
    }
    for path in dir_paths.values():
        path.mkdir(parents=True, exist_ok=True)

    # Map each stem to the image file that actually exists (.jpg preferred).
    # The previous `(stem.jpg) or (stem.png)` always evaluated to the .jpg Path
    # (Paths are truthy), so every .png image was silently dropped.
    image_by_stem = {}
    for pattern in ('*.jpg', '*.png'):
        for img_file in image_dir.glob(pattern):
            image_by_stem.setdefault(img_file.stem, img_file)
    image_stems = sorted(image_by_stem)  # unique stems -> no image in both splits
    train_stems, val_stems = train_test_split(image_stems, test_size=val_split, random_state=42)

    n_written = 0
    for split, stems in (('train', train_stems), ('val', val_stems)):
        for stem in stems:
            xml_path = annotation_dir / f"{stem}.xml"
            img_path = image_by_stem[stem]
            if not xml_path.exists():
                continue

            ann = AnnotationParser.parse_xml(xml_path)
            img_w, img_h = ann.get('width'), ann.get('height')
            if not img_w or not img_h:
                # <size> missing or zero: read the real dimensions from the image.
                image = cv2.imread(str(img_path))
                if image is None:
                    continue
                img_h, img_w = image.shape[:2]

            shutil.copy(img_path, dir_paths[f'{split}_images'])
            yolo_lines = []
            for obj in ann['objects']:
                class_id = 0  # single class "chromosome" for the detector
                xmin, ymin, xmax, ymax = obj['bbox']
                # YOLO label format: normalised centre x/y, width, height.
                x_center, width = (xmin + xmax) / 2 / img_w, (xmax - xmin) / img_w
                y_center, height = (ymin + ymax) / 2 / img_h, (ymax - ymin) / img_h
                yolo_lines.append(f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")

            with open(dir_paths[f'{split}_labels'] / f"{stem}.txt", 'w') as f:
                f.write("\n".join(yolo_lines))
            n_written += 1

    with open(output_dir / 'data.yaml', 'w') as f:
        yaml.dump({'train': str(dir_paths['train_images'].resolve()),
                   'val': str(dir_paths['val_images'].resolve()),
                   'nc': 1, 'names': ['chromosome']}, f, default_flow_style=False)
    print(f"✅ YOLO dataset created at '{output_dir.resolve()}'")
    print(f"   {n_written} labelled images written.")

class YOLOIdentifier:
    """Thin wrapper around an Ultralytics YOLO model used as the chromosome detector."""

    def __init__(self, model_name: str = 'yolov8n.pt'):
        # Checkpoint name or weights path, loaded through Ultralytics.
        self.model = YOLO(model_name)

    def train(self, data_config_path: str, epochs: int, img_size: int, project: str, name: str):
        """Train ``self.model`` on ``data_config_path``; run outputs go under ``project/name``.

        ``exist_ok=True`` reuses that run directory instead of creating a new one.
        """
        self.model.train(
            data=data_config_path,
            epochs=epochs,
            imgsz=img_size,
            project=project,
            name=name,
            exist_ok=True,
        )

    def load_weights(self, weights_path: str):
        """Replace the current model with one loaded from ``weights_path``."""
        self.model = YOLO(weights_path)

    def predict_boxes(self, image: np.ndarray, confidence_threshold: float):
        """Detect objects in a single image.

        Returns:
            ``(boxes, scores)`` as numpy arrays: xyxy box corners and their
            confidences for the first result. Two empty arrays are returned if
            the model produced no results.
        """
        outputs = self.model.predict(image, conf=confidence_threshold, verbose=False)
        if not outputs:
            return np.array([]), np.array([])
        detections = outputs[0].boxes
        return detections.xyxy.cpu().numpy(), detections.conf.cpu().numpy()

# --- Main execution for Identifier Training ---
if not Config.TRAIN_IDENTIFIER:
    print("☑️ Skipping Identifier training as per Config.")
else:
    data_root = Path(Config.DATA_ROOT)
    single_chr_images, single_chr_annotations = (
        data_root / 'single_chromosomes_object' / sub for sub in ('images', 'annotations')
    )

    # Eyeball the raw annotations before spending time on training.
    visualize_ground_truth_examples(single_chr_images, single_chr_annotations, num_examples=4)

    # XML -> YOLO txt labels + data.yaml, then fine-tune the detector.
    prepare_yolo_data(
        single_chr_images,
        single_chr_annotations,
        Path(Config.YOLO_DATA_DIR),
        Config.VALIDATION_SPLIT,
    )

    print("\n🚀 STARTING YOLO IDENTIFIER TRAINING...")
    identifier_trainer = YOLOIdentifier(model_name=Config.YOLO_MODEL_NAME)
    identifier_trainer.train(
        data_config_path=Config.YOLO_DATA_CONFIG,
        epochs=Config.IDENTIFIER_EPOCHS,
        img_size=Config.IDENTIFIER_IMG_SIZE,
        project=Config.YOLO_PROJECT_NAME,
        name=Config.YOLO_RUN_NAME,
    )
    print("\n✅ YOLO Identifier training complete.")

In [None]:
# ==============================================================================
# 🏷️ CELL 4: STEP 2 - TRAIN THE TENSORFLOW CLASSIFIER
# ==============================================================================

class ChromosomeClassifierDataGenerator(keras.utils.Sequence):
    """Memory-efficient batch generator of single-chromosome crops for the classifier.

    Each sample is one annotated bounding box from an image. Crops are read from
    disk lazily in ``__getitem__``, converted to RGB, optionally augmented, resized
    to ``img_size`` and normalised with ImageNet mean/std.
    """
    # ImageNet statistics in RGB order. They must match the normalisation used at
    # inference time (KaryotypePipeline.process_image works on RGB crops).
    _MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    _STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    def __init__(self, images_dir: Path, annotations_dir: Path, batch_size: int, img_size: int,
                 class_map: Dict, is_validation: bool = False, augment: bool = False,
                 validation_split: float = 0.2, data_list=None, indexes=None):
        """
        Args:
            images_dir / annotations_dir: image files and ``<stem>.xml`` annotations.
            batch_size: samples per batch.
            img_size: square side length crops are resized to.
            class_map: annotation label -> class id. Objects with other labels are ignored.
            is_validation: serve the validation subset (never reshuffled).
            augment: apply albumentations augmentation to each crop.
            validation_split: fraction of samples held out (seeded with 42).
            data_list / indexes: a pre-built sample index and the positions to
                serve. ``get_validation_generator`` uses them to share one index.
        """
        super().__init__()
        self.images_dir, self.annotations_dir = images_dir, annotations_dir
        self.batch_size, self.img_size = batch_size, img_size
        self.class_map, self.is_validation = class_map, is_validation
        if data_list is None:
            self.chromosome_data = self._build_dataset_index()
            all_indexes = np.arange(len(self.chromosome_data))
            # Fixed seed so train/val membership is reproducible across runs.
            np.random.seed(42); np.random.shuffle(all_indexes)
            val_size = int(len(self.chromosome_data) * validation_split)
            self.val_indexes = all_indexes[:val_size]
            self.train_indexes = all_indexes[val_size:]
            # Copy so per-epoch reshuffling never mutates the stored split arrays.
            self.indexes = (self.val_indexes if is_validation else self.train_indexes).copy()
        else:
            self.chromosome_data, self.indexes = data_list, indexes
        self.aug = A.Compose([A.HorizontalFlip(), A.VerticalFlip(), A.Rotate(limit=20),
                              A.RandomBrightnessContrast(p=0.3), A.GaussNoise(p=0.2)]) if augment else None

    def _image_path(self, stem: str) -> Path:
        """Return the existing image for ``stem`` (``.jpg`` preferred, else ``.png``).

        A bare ``a or b`` on Paths always yields ``a`` (Paths are truthy), which
        previously made every ``.png`` image unreadable.
        """
        for suffix in ('.jpg', '.png'):
            candidate = self.images_dir / f"{stem}{suffix}"
            if candidate.exists():
                return candidate
        # Nothing on disk: cv2.imread returns None and the sample is skipped.
        return self.images_dir / f"{stem}.jpg"

    def _build_dataset_index(self):
        """Scan annotations and return one ``{'img_name', 'bbox', 'label'}`` dict per object."""
        print("\nBuilding classifier dataset index..."); index = []
        # Set: a stem present as both .jpg and .png must not be indexed twice.
        image_files = sorted({p.stem for pattern in ('*.jpg', '*.png')
                              for p in self.images_dir.glob(pattern)})
        for name in image_files:
            xml_path = self.annotations_dir / f"{name}.xml"
            if not xml_path.exists(): continue
            for obj in AnnotationParser.parse_xml(xml_path)['objects']:
                if obj['name'] in self.class_map:
                    index.append({'img_name': name, 'bbox': obj['bbox'], 'label': self.class_map[obj['name']]})
        print(f"✅ Found {len(index)} chromosome samples.")
        return index

    def get_validation_generator(self):
        """Return a non-augmented generator over this generator's validation split."""
        return ChromosomeClassifierDataGenerator(self.images_dir, self.annotations_dir, self.batch_size,
                                                 self.img_size, self.class_map, is_validation=True,
                                                 data_list=self.chromosome_data, indexes=self.val_indexes)

    def __len__(self): return int(np.ceil(len(self.indexes) / self.batch_size))

    def on_epoch_end(self):
        """Reshuffle training samples between epochs; validation order stays fixed."""
        if not self.is_validation:
            np.random.shuffle(self.indexes)

    def __getitem__(self, index):
        """Return ``(images, labels)`` for batch ``index``; unreadable or empty crops are skipped."""
        batch_indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
        batch_images, batch_labels = [], []
        for idx in batch_indexes:
            item = self.chromosome_data[idx]
            image = cv2.imread(str(self._image_path(item['img_name'])))
            if image is None: continue
            # cv2 loads BGR. Convert so training sees the same RGB order as inference.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            img_h, img_w = image.shape[:2]
            xmin, ymin, xmax, ymax = item['bbox']
            # Clamp to the image: a negative start index would wrap around in the slice.
            xmin, ymin = max(0, xmin), max(0, ymin)
            xmax, ymax = min(img_w, xmax), min(img_h, ymax)
            crop = image[ymin:ymax, xmin:xmax]
            if crop.size == 0: continue
            if self.aug: crop = self.aug(image=crop)['image']
            crop = cv2.resize(crop, (self.img_size, self.img_size))
            norm_crop = (crop.astype(np.float32) / 255.0 - self._MEAN) / self._STD
            batch_images.append(norm_crop); batch_labels.append(item['label'])
        return np.array(batch_images, dtype=np.float32), np.array(batch_labels)

class ChromosomeClassifier:
    """ResNet50-backed ``num_classes``-way chromosome classifier (Keras)."""

    def __init__(self, num_classes: int, img_size: int):
        self.num_classes = num_classes
        self.img_size = img_size
        self.model = None  # created by compile_model()

    def _build_model(self):
        """ResNet50 (ImageNet weights) -> GAP -> Dense(512, relu) -> Dropout(0.5) -> softmax."""
        backbone = ResNet50(include_top=False, weights='imagenet',
                            input_shape=(self.img_size, self.img_size, 3))
        # Freeze the first 140 backbone layers; later blocks and the head stay trainable.
        for frozen in backbone.layers[:140]:
            frozen.trainable = False
        head = layers.GlobalAveragePooling2D()(backbone.output)
        head = layers.Dense(512, activation='relu')(head)
        head = layers.Dropout(0.5)(head)
        class_probs = layers.Dense(self.num_classes, activation='softmax')(head)
        return models.Model(inputs=backbone.input, outputs=class_probs)

    def compile_model(self, lr: float, strategy: tf.distribute.Strategy):
        """Build and compile the model inside ``strategy.scope()`` (Adam + sparse CE)."""
        with strategy.scope():
            self.model = self._build_model()
            adam = optimizers.Adam(lr)
            self.model.compile(optimizer=adam, loss='sparse_categorical_crossentropy', metrics=['accuracy'])

    def train(self, train_gen, val_gen, epochs: int, save_path: str):
        """Fit with LR reduction, best-model checkpointing to ``save_path`` and early stopping."""
        lr_schedule = keras.callbacks.ReduceLROnPlateau(patience=3, min_lr=1e-7)
        checkpoint = keras.callbacks.ModelCheckpoint(save_path, save_best_only=True)
        early_stop = keras.callbacks.EarlyStopping(patience=7, restore_best_weights=True)
        return self.model.fit(train_gen, validation_data=val_gen, epochs=epochs,
                              callbacks=[lr_schedule, checkpoint, early_stop])

def setup_device_strategy() -> Tuple[tf.distribute.Strategy, str]:
    """Choose a ``tf.distribute`` strategy for the visible accelerators.

    Returns:
        ``(strategy, device_type)``: MirroredStrategy over all GPUs when more than
        one is visible, OneDeviceStrategy on the single GPU, otherwise TensorFlow's
        default strategy with ``device_type == 'CPU'``.
    """
    gpus = tf.config.list_physical_devices('GPU')
    if len(gpus) > 1:
        strategy, device_type = tf.distribute.MirroredStrategy(), 'GPU'
    elif gpus:
        strategy, device_type = tf.distribute.OneDeviceStrategy('/gpu:0'), 'GPU'
    else:
        strategy, device_type = tf.distribute.get_strategy(), 'CPU'
    print(f"🖥️ Using {device_type} with {strategy.num_replicas_in_sync} replica(s).")
    return strategy, device_type


# --- Main execution for Classifier Training ---
if Config.TRAIN_CLASSIFIER:
    STRATEGY, DEVICE_TYPE = setup_device_strategy()
    data_root = Path(Config.DATA_ROOT)
    multi_chr_images = data_root / '24_chromosomes_object' / 'images'
    multi_chr_annotations = data_root / '24_chromosomes_object' / 'annotations'

    # Global batch size scales with the number of replicas (64 per replica).
    cls_batch = 64 * STRATEGY.num_replicas_in_sync
    train_gen = ChromosomeClassifierDataGenerator(
        multi_chr_images, multi_chr_annotations, cls_batch, Config.CLASSIFIER_IMG_SIZE,
        Config.CLASS_MAP, augment=True, validation_split=Config.VALIDATION_SPLIT)
    val_gen = train_gen.get_validation_generator()

    print("\n🚀 STARTING CLASSIFIER TRAINING...")
    classifier = ChromosomeClassifier(num_classes=Config.NUM_CLASSES, img_size=Config.CLASSIFIER_IMG_SIZE)
    classifier.compile_model(lr=0.001, strategy=STRATEGY)
    history = classifier.train(train_gen, val_gen, Config.CLASSIFIER_EPOCHS, Config.CLASSIFIER_MODEL_PATH)

    output_dir = Path(Config.OUTPUT_DIR); output_dir.mkdir(parents=True, exist_ok=True)
    plot_training_history(history, output_dir / 'classifier_training_history.png')
    print(f"\n✅ Classifier training complete. Model saved to {Config.CLASSIFIER_MODEL_PATH}")
else:
    print("☑️ Skipping Classifier training as per Config.")

In [None]:
# ==============================================================================
# 🔎 CELL 5: PREDICTION PIPELINE & VISUALIZATION
# ==============================================================================

class KaryotypePipeline:
    """Two-stage inference: the detector proposes boxes, the Keras classifier labels each crop."""

    # ImageNet mean/std in RGB order, applied to RGB crops (the classifier's input order).
    _MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    _STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    def __init__(self, identifier, classifier):
        """
        Args:
            identifier: object with ``predict_boxes(image, threshold) -> (boxes, scores)``,
                e.g. ``YOLOIdentifier``.
            classifier: Keras model. Its input is assumed square, since only
                ``input_shape[1]`` is read.
        """
        self.identifier = identifier
        self.classifier = classifier
        self.classifier_img_size = classifier.input_shape[1]

    def process_image(self, image_path: Path, detection_threshold: float):
        """Detect and classify every chromosome in one image.

        Returns:
            A list of dicts, one per usable detection, with keys:
            ``bbox`` (detector xyxy floats), ``class_id``, ``class_name``,
            ``combined_score`` (detector confidence × class probability) and
            ``probs`` (full softmax vector). Returns ``[]`` if the image can't be
            read or nothing usable is detected.
        """
        image = cv2.imread(str(image_path))
        if image is None:
            return []

        # Ultralytics interprets numpy input as BGR (the cv2 convention), so the
        # detector gets the image exactly as cv2 loaded it.
        boxes, scores = self.identifier.predict_boxes(image, detection_threshold)
        if len(boxes) == 0:
            return []

        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        img_h, img_w = image_rgb.shape[:2]
        crops, kept = [], []
        for det_idx, box in enumerate(boxes):
            xmin, ymin, xmax, ymax = map(int, box)
            # Clamp to the image so a slightly negative coordinate can't wrap the slice.
            xmin, ymin = max(0, xmin), max(0, ymin)
            xmax, ymax = min(img_w, xmax), min(img_h, ymax)
            if xmax <= xmin or ymax <= ymin:
                continue
            crop = cv2.resize(image_rgb[ymin:ymax, xmin:xmax],
                              (self.classifier_img_size, self.classifier_img_size))
            crops.append((crop.astype(np.float32) / 255.0 - self._MEAN) / self._STD)
            kept.append(det_idx)
        if not crops:
            return []

        # One batched forward pass instead of one predict() call per detection.
        all_probs = self.classifier.predict(np.stack(crops), verbose=0)

        predictions = []
        for det_idx, probs in zip(kept, all_probs):
            class_id = int(np.argmax(probs))
            predictions.append({
                'bbox': boxes[det_idx].tolist(),
                'class_id': class_id,
                'class_name': Config.CLASS_NAMES[class_id],
                'combined_score': float(scores[det_idx]) * float(probs[class_id]),
                'probs': probs,
            })
        return predictions

# --- Main execution for Prediction ---
print("\n" + "="*60 + "\n📊 RUNNING PREDICTION ON TEST SET\n" + "="*60)


def _resolve_image(images_dir: Path, stem: str):
    """Return the existing ``<stem>.jpg`` or ``<stem>.png`` in ``images_dir``, else None."""
    for suffix in ('.jpg', '.png'):
        candidate = images_dir / f"{stem}{suffix}"
        if candidate.exists():
            return candidate
    return None


def _match_predictions_to_gt(gt_boxes: List[Dict], predictions: List[Dict],
                             iou_threshold: float) -> Dict[int, Dict]:
    """Greedy one-to-one matching of predictions to GT boxes by descending IoU.

    Each GT box and each prediction is used at most once, so several
    overlapping chromosomes cannot all be credited to the same detection.
    Returns ``{gt_index: prediction}`` for pairs with IoU >= ``iou_threshold``.
    """
    candidates = sorted(
        ((compute_iou(gt['bbox'], pred['bbox']), g, p)
         for g, gt in enumerate(gt_boxes) for p, pred in enumerate(predictions)),
        key=lambda t: t[0], reverse=True)
    matches, used_preds = {}, set()
    for iou, g, p in candidates:
        if iou < iou_threshold:
            break  # sorted descending: nothing further can qualify
        if g in matches or p in used_preds:
            continue
        matches[g] = predictions[p]
        used_preds.add(p)
    return matches


# 1. Load trained models
print("Loading trained models...")
# Load the trained weights directly (avoids first loading the default yolov8n.pt).
identifier = YOLOIdentifier(model_name=Config.YOLO_MODEL_PATH)
classifier_model = keras.models.load_model(Config.CLASSIFIER_MODEL_PATH)
print("✅ Models loaded successfully.")

# 2. Setup pipeline and paths
pipeline = KaryotypePipeline(identifier, classifier_model)
data_root = Path(Config.DATA_ROOT)
test_file = data_root / 'test.txt'
multi_chr_images = data_root / '24_chromosomes_object' / 'images'
multi_chr_annotations = data_root / '24_chromosomes_object' / 'annotations'
output_dir = Path(Config.OUTPUT_DIR); output_dir.mkdir(parents=True, exist_ok=True)
vis_dir = output_dir / 'visualizations'; vis_dir.mkdir(exist_ok=True)
# Fall back to 10 if this Config has no NUM_VISUALIZATION_IMAGES attribute.
num_vis = getattr(Config, 'NUM_VISUALIZATION_IMAGES', 10)

# 3. Process test set and store results for evaluation
with open(test_file, 'r') as f:
    # Skip blank lines (e.g. a trailing newline) so they aren't treated as image names.
    test_list = [line.strip() for line in f if line.strip()]

all_gt_labels = []    # one entry per GT chromosome
all_pred_labels = []  # matched class_id, or -1 for a missed GT
all_pred_scores = []  # matched softmax vector (zeros when missed), for ROC/PRC
n_images = n_predictions = n_matched = 0

for i, img_name in enumerate(test_list):
    print(f"Processing image {i+1}/{len(test_list)}: {img_name}", end='\r')
    img_path = _resolve_image(multi_chr_images, img_name)
    xml_path = multi_chr_annotations / f"{img_name}.xml"
    if img_path is None or not xml_path.exists(): continue

    predictions = pipeline.process_image(img_path, Config.DETECTION_THRESHOLD)
    gt_annotations = AnnotationParser.parse_xml(xml_path)
    gt_boxes = [obj for obj in gt_annotations['objects'] if obj['name'] in Config.CLASS_MAP]

    # Match predictions to ground truth (one-to-one)
    matches = _match_predictions_to_gt(gt_boxes, predictions, Config.EVAL_IOU_THRESHOLD)
    for g, gt_box_info in enumerate(gt_boxes):
        all_gt_labels.append(Config.CLASS_MAP[gt_box_info['name']])
        matched = matches.get(g)
        if matched is None:
            all_pred_labels.append(-1)  # Missed detection
            all_pred_scores.append(np.zeros(Config.NUM_CLASSES))
        else:
            all_pred_labels.append(matched['class_id'])
            all_pred_scores.append(matched['probs'])
    n_images += 1
    n_predictions += len(predictions)
    n_matched += len(matches)

    # Visualize first few predictions
    if i < num_vis:
        vis_save_path = vis_dir / f"{img_name}_pred.png"
        visualize_predictions(img_path, gt_boxes, predictions, vis_save_path)

print(f"\n✅ Prediction complete. Visualizations saved to '{vis_dir}'.")
print(f"   Images processed: {n_images} | GT chromosomes: {len(all_gt_labels)} | "
      f"matched: {n_matched} | missed: {len(all_gt_labels) - n_matched} | "
      f"unmatched detections: {n_predictions - n_matched}")

In [None]:
# ==============================================================================
# 📈 CELL 6: PERFORMANCE EVALUATION METRICS
# ==============================================================================

def plot_confusion_matrix(y_true, y_pred, class_names, save_path):
    """Plot a confusion-matrix heatmap and save it to ``save_path``.

    Cell colour is the row-normalised rate (fraction of each true class) and the
    annotation is the raw count. All ``len(class_names)`` classes are shown,
    even those that never occur.
    """
    labels = list(range(len(class_names)))
    counts = confusion_matrix(y_true, y_pred, labels=labels)
    row_totals = counts.sum(axis=1, keepdims=True)
    # Rows with no samples stay 0 instead of dividing by zero.
    rates = np.divide(counts, row_totals, out=np.zeros(counts.shape, dtype=float),
                      where=row_totals > 0)

    fig, ax = plt.subplots(figsize=(14, 12))
    sns.heatmap(rates, annot=counts, fmt='d', cmap='Blues', vmin=0, vmax=1,
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_xlabel('Predicted label'); ax.set_ylabel('True label')
    ax.set_title('Confusion Matrix (colour = row-normalised, text = counts)')
    fig.tight_layout()
    fig.savefig(save_path); plt.show()
    print(f"✅ Confusion matrix saved to '{save_path}'")


def plot_roc_curves(y_true_bin, y_pred_scores, class_names, save_path, n_plot: int = 5):
    """Plot one-vs-rest ROC curves with AUC, plus the micro-average.

    Args:
        y_true_bin: binarised labels, one column per class (e.g. from ``label_binarize``).
        y_pred_scores: per-class scores with the same shape as ``y_true_bin``.
        class_names: display names; ``len(class_names)`` is the number of classes.
        save_path: where to write the PNG.
        n_plot: how many per-class curves to draw (the first ``n_plot`` classes),
            kept small for legibility. Classes whose ROC is undefined are not drawn.
    """
    fpr, tpr, roc_auc = {}, {}, {}
    n_classes = len(class_names)

    # Compute ROC for each class
    for i in range(n_classes):
        column = y_true_bin[:, i]
        # ROC is undefined for a column with only positives or only negatives
        # (e.g. a class absent from the evaluated boxes). Record NaN instead.
        if column.min() == column.max():
            roc_auc[i] = float('nan')
            continue
        fpr[i], tpr[i], _ = roc_curve(column, y_pred_scores[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    # Compute micro-average ROC (all class columns pooled)
    fpr["micro"], tpr["micro"], _ = roc_curve(y_true_bin.ravel(), y_pred_scores.ravel())
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

    # Plotting
    plt.figure(figsize=(12, 10))
    plt.plot(fpr["micro"], tpr["micro"], label=f'Micro-average ROC (AUC = {roc_auc["micro"]:.2f})', color='deeppink', linestyle=':', linewidth=4)
    colors = cycle(['aqua', 'darkorange', 'cornflowerblue', 'green', 'red'])
    for i, color in zip(range(min(n_plot, n_classes)), colors):
        if i in fpr:
            plt.plot(fpr[i], tpr[i], color=color, lw=2, label=f'ROC curve of class {class_names[i]} (AUC = {roc_auc[i]:.2f})')

    plt.plot([0, 1], [0, 1], 'k--', lw=2)
    plt.xlim([0.0, 1.0]); plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate'); plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curves')
    plt.legend(loc="lower right"); plt.grid(True)
    plt.savefig(save_path); plt.show()
    print(f"✅ ROC curves saved to '{save_path}'")

def plot_prc_curves(y_true_bin, y_pred_scores, class_names, save_path):
    """Plot one-vs-rest precision-recall curves with average precision (AP).

    Draws the micro-averaged curve and the first five classes' curves, saves the
    figure to ``save_path`` and shows it.
    """
    n_classes = len(class_names)
    precision, recall, avg_precision = {}, {}, {}

    for cls in range(n_classes):
        truth, score = y_true_bin[:, cls], y_pred_scores[:, cls]
        precision[cls], recall[cls], _ = precision_recall_curve(truth, score)
        avg_precision[cls] = average_precision_score(truth, score)

    # Micro-average: pool every class column into a single binary problem.
    flat_truth, flat_score = y_true_bin.ravel(), y_pred_scores.ravel()
    precision["micro"], recall["micro"], _ = precision_recall_curve(flat_truth, flat_score)
    avg_precision["micro"] = average_precision_score(y_true_bin, y_pred_scores, average="micro")

    plt.figure(figsize=(12, 10))
    plt.plot(recall["micro"], precision["micro"], label=f'Micro-average PRC (AP = {avg_precision["micro"]:.2f})', color='navy', linestyle=':', linewidth=4)

    # Only the first five classes get their own curve, to keep the legend readable.
    palette = ['aqua', 'darkorange', 'cornflowerblue', 'green', 'red']
    for cls, color in zip(range(min(5, n_classes)), palette):
        plt.plot(recall[cls], precision[cls], color=color, lw=2, label=f'PRC of class {class_names[cls]} (AP = {avg_precision[cls]:.2f})')

    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision-Recall Curves (PRC)')
    plt.legend(loc="best")
    plt.grid(True)
    plt.savefig(save_path)
    plt.show()
    print(f"✅ PRC curves saved to '{save_path}'")

# --- Main execution for Evaluation ---
print("\n" + "="*60 + "\nEVALUATION RESULTS\n" + "="*60)
output_dir = Path(Config.OUTPUT_DIR); output_dir.mkdir(parents=True, exist_ok=True)

# Pass the full label set explicitly. Without it, classification_report raises
# when some class never appears in y_true/y_pred, because target_names would
# then be longer than the set of observed labels.
all_class_ids = list(range(Config.NUM_CLASSES))

gt_arr = np.asarray(all_gt_labels, dtype=int)
pred_arr = np.asarray(all_pred_labels, dtype=int)
matched_mask = pred_arr != -1  # -1 marks a GT box with no matching detection

# Filter out missed detections for classification metrics
y_true_eval = gt_arr[matched_mask]
y_pred_eval = pred_arr[matched_mask]
y_scores_eval = np.array([s for s, m in zip(all_pred_scores, matched_mask) if m])

if len(gt_arr) > 0:
    print(f"Detection recall (GT matched at IoU >= {Config.EVAL_IOU_THRESHOLD}): {matched_mask.mean():.3f}")
    # End-to-end accuracy counts every missed chromosome as an error.
    print(f"End-to-end accuracy (misses count as wrong): {np.mean(pred_arr == gt_arr):.3f}\n")

if len(y_true_eval) > 0:
    print(f"Accuracy on matched boxes: {accuracy_score(y_true_eval, y_pred_eval):.3f}")
    print(f"Balanced accuracy on matched boxes: {balanced_accuracy_score(y_true_eval, y_pred_eval):.3f}\n")

    # 1. Classification Report
    print("Classification Report (on matched boxes):\n")
    report = classification_report(y_true_eval, y_pred_eval, labels=all_class_ids,
                                   target_names=Config.CLASS_NAMES, zero_division=0)
    print(report)

    # 2. Confusion Matrix
    plot_confusion_matrix(y_true_eval, y_pred_eval, Config.CLASS_NAMES, output_dir / 'confusion_matrix.png')

    # Binarize labels for ROC/PRC
    y_true_bin = label_binarize(y_true_eval, classes=all_class_ids)

    # 3. ROC Curve
    plot_roc_curves(y_true_bin, y_scores_eval, Config.CLASS_NAMES, output_dir / 'roc_curves.png')

    # 4. PRC Curve
    plot_prc_curves(y_true_bin, y_scores_eval, Config.CLASS_NAMES, output_dir / 'prc_curves.png')
else:
    print("No valid matched boxes found to generate an evaluation report.")