In [None]:
# ==============================================================================
# CELL 0: DEVICE SETUP WITH TPU/GPU/CPU FALLBACK
# This cell should be run FIRST before any other cells
# ==============================================================================

import os
import tensorflow as tf

def setup_device_strategy():
    """
    Setup device strategy with fallback: TPU -> GPU -> CPU.

    Returns:
        tuple: ``(strategy, device_type)`` where ``strategy`` is a
        ``tf.distribute`` strategy object and ``device_type`` is one of
        ``"TPU"``, ``"Multi-GPU"``, ``"Single-GPU"`` or ``"CPU"``.
    """
    print("=" * 60)
    print("DEVICE SETUP")
    print("=" * 60)

    # Try TPU first
    try:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
        print(f"✓ TPU detected: {tpu.cluster_spec().as_dict()}")
        strategy = tf.distribute.TPUStrategy(tpu)
        device_type = "TPU"
        print(f"✓ Using TPU strategy with {strategy.num_replicas_in_sync} replicas")
        return strategy, device_type
    except (ValueError, tf.errors.NotFoundError):
        print("✗ No TPU detected")

    # Try GPU next
    gpus = tf.config.list_physical_devices('GPU')
    if gpus:
        # Memory growth is only an optimisation. set_memory_growth raises
        # RuntimeError once the runtime has been initialised (e.g. when this
        # cell is re-run), and that must not demote a working GPU to the CPU
        # path, so the failure is reported and setup continues.
        try:
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError as e:
            print(f"⚠ Could not enable GPU memory growth: {e}")

        try:
            print(f"✓ GPU(s) detected: {len(gpus)} device(s)")
            for i, gpu in enumerate(gpus):
                print(f"  GPU {i}: {gpu}")

            if len(gpus) > 1:
                # Use MirroredStrategy for multi-GPU
                strategy = tf.distribute.MirroredStrategy()
                device_type = "Multi-GPU"
                print(f"✓ Using MirroredStrategy with {strategy.num_replicas_in_sync} GPUs")
            else:
                # Use default strategy for single GPU
                strategy = tf.distribute.get_strategy()
                device_type = "Single-GPU"
                print("✓ Using single GPU")

            return strategy, device_type
        except RuntimeError as e:
            print(f"✗ GPU setup failed: {e}")
    else:
        print("✗ No GPU detected")

    # Fallback to CPU
    print("✓ Falling back to CPU")
    strategy = tf.distribute.get_strategy()
    device_type = "CPU"
    print("⚠ Warning: Training on CPU will be significantly slower")

    return strategy, device_type

# Resolve the strategy once; STRATEGY and DEVICE_TYPE are read by later cells
STRATEGY, DEVICE_TYPE = setup_device_strategy()

# Summary banner for the selected configuration
_RULE = "=" * 60
print("\n" + _RULE)
print(f"DEVICE CONFIGURATION COMPLETE: {DEVICE_TYPE}")
print(_RULE)
print(f"Number of devices: {STRATEGY.num_replicas_in_sync}")
print(f"TensorFlow version: {tf.__version__}")
print(_RULE + "\n")

In [None]:
# ==============================================================================
# CELL 1: TRAIN IDENTIFIER MODULE WITH TPU/GPU/CPU SUPPORT
# Uses ALL data from single_chromosomes_object folder
# ==============================================================================

import os
import numpy as np
import cv2
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.applications import ResNet50
import xml.etree.ElementTree as ET
from sklearn.metrics import precision_recall_curve, auc
import matplotlib.pyplot as plt
from typing import List, Tuple, Dict
import albumentations as A

# ==================== SHARED UTILITIES ====================

class AnnotationParser:
    """Parse XML annotations for both single and 24-chromosome datasets"""

    # Order matters: it fixes the key order of every returned 'bbox' dict.
    _BBOX_TAGS = ('xmin', 'ymin', 'xmax', 'ymax')

    @staticmethod
    def _to_int(text: str) -> int:
        """Convert numeric text to int, accepting float notation such as '800.0'."""
        return int(float(text))

    @staticmethod
    def parse_xml(xml_path: str) -> Dict:
        """
        Read one annotation file.

        Args:
            xml_path: Path of the XML file to parse.

        Returns:
            A dict with three keys:
              * 'filename': text of the <filename> element, '' when absent.
              * 'size': {'width', 'height', 'depth'} as ints when a <size>
                element exists ('depth' defaults to 3), otherwise {}.
              * 'objects': one {'name', 'bbox'} dict per <object> that has a
                complete <bndbox>; 'bbox' holds int xmin/ymin/xmax/ymax.

        Objects without a <bndbox>, or with a <bndbox> missing a coordinate,
        are skipped. A missing <name> yields an empty name instead of raising,
        so the box is still available to callers that only need geometry.
        """
        root = ET.parse(xml_path).getroot()
        to_int = AnnotationParser._to_int

        annotations = {
            'filename': root.findtext('filename', default=''),
            'size': {},
            'objects': []
        }

        size = root.find('size')
        if size is not None:
            depth_text = size.findtext('depth')
            annotations['size'] = {
                # Same int(float(...)) conversion as the box coordinates below
                'width': to_int(size.find('width').text),
                'height': to_int(size.find('height').text),
                'depth': to_int(depth_text) if depth_text is not None else 3
            }

        for obj in root.findall('object'):
            bndbox = obj.find('bndbox')
            if bndbox is None:
                continue

            raw_coords = {tag: bndbox.findtext(tag)
                          for tag in AnnotationParser._BBOX_TAGS}
            if any(value is None or not value.strip()
                   for value in raw_coords.values()):
                # Incomplete box: nothing usable to crop or regress against
                continue

            annotations['objects'].append({
                # Strip surrounding whitespace so ' X ' still matches class 'X'
                'name': (obj.findtext('name') or '').strip(),
                'bbox': {tag: to_int(value) for tag, value in raw_coords.items()}
            })

        return annotations


# ==================== IDENTIFIER MODULE ====================

class ChromosomeIdentifierDataGenerator(keras.utils.Sequence):
    """Data generator for training chromosome identifier.

    Each batch is a pair ``(images, targets)`` of Python lists:
      * ``images``: RGB float32 arrays resized to ``img_size x img_size`` and
        scaled to [0, 1].
      * ``targets``: one ``(num_boxes, 4)`` array per image with boxes as
        ``[ymin, xmin, ymax, xmax]`` divided by the original image size
        (shape ``(0, 4)`` when the annotation has no objects).

    A batch can hold fewer than ``batch_size`` items because images that are
    missing or cannot be decoded are skipped.
    """

    def __init__(self, images_dir: str, annotations_dir: str,
                 batch_size=2, img_size=800, shuffle=True):
        """
        Args:
            images_dir: Directory holding the .jpg/.png images.
            annotations_dir: Directory holding one ``<base>.xml`` per image.
            batch_size: Maximum number of images per batch.
            img_size: Side length the images are resized to.
            shuffle: Shuffle the sample order now and on every epoch end.
        """
        super().__init__()
        self.images_dir = images_dir
        self.annotations_dir = annotations_dir
        self.batch_size = batch_size
        self.img_size = img_size
        self.shuffle = shuffle
        self.parser = AnnotationParser()

        # Collect base names that have both an image and an annotation.
        # A set avoids listing a sample twice when both .jpg and .png exist;
        # sorting makes the unshuffled order independent of os.listdir.
        base_names = set()
        for filename in os.listdir(images_dir):
            if filename.endswith(('.jpg', '.png')):
                base_name = os.path.splitext(filename)[0]
                xml_path = os.path.join(annotations_dir, f"{base_name}.xml")
                if os.path.exists(xml_path):
                    base_names.add(base_name)
        self.image_list = sorted(base_names)

        print(f"Found {len(self.image_list)} images for identifier training")

        self.indexes = np.arange(len(self.image_list))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __len__(self):
        """Return the number of batches per epoch (last batch may be partial)."""
        return int(np.ceil(len(self.image_list) / self.batch_size))

    def __getitem__(self, index):
        """Return batch ``index`` as ``(images, targets)`` lists."""
        batch_indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
        batch_images = [self.image_list[k] for k in batch_indexes]

        X, y = self._generate_data(batch_images)
        return X, y

    def _generate_data(self, batch_images):
        """Load, resize and normalise the named images together with their boxes."""
        images = []
        targets = []

        for img_name in batch_images:
            img_path = os.path.join(self.images_dir, f"{img_name}.jpg")
            xml_path = os.path.join(self.annotations_dir, f"{img_name}.xml")

            # Fall back to PNG when no JPG with this base name exists
            if not os.path.exists(img_path):
                img_path = os.path.join(self.images_dir, f"{img_name}.png")

            if not os.path.exists(img_path) or not os.path.exists(xml_path):
                continue

            image = cv2.imread(img_path)
            if image is None:
                # cv2.imread signals an unreadable/corrupt file by returning None
                print(f"Warning: Could not read image {img_path}, skipping")
                continue

            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            orig_h, orig_w = image.shape[:2]

            image = cv2.resize(image, (self.img_size, self.img_size))
            image = image.astype(np.float32) / 255.0

            annotations = self.parser.parse_xml(xml_path)

            # Normalise by the original size so boxes survive the resize
            boxes = []
            for obj in annotations['objects']:
                bbox = obj['bbox']
                xmin = bbox['xmin'] / orig_w
                ymin = bbox['ymin'] / orig_h
                xmax = bbox['xmax'] / orig_w
                ymax = bbox['ymax'] / orig_h
                boxes.append([ymin, xmin, ymax, xmax])

            images.append(image)
            targets.append(np.array(boxes) if len(boxes) > 0 else np.zeros((0, 4)))

        return images, targets

    def on_epoch_end(self):
        """Reshuffle the sample order when shuffling is enabled."""
        if self.shuffle:
            np.random.shuffle(self.indexes)


class ChromosomeIdentifier:
    """Chromosome detection model.

    A ResNet50 backbone feeds two dense heads: ``bbox`` with
    ``max_detections * 4`` box values and ``objectness`` with
    ``max_detections`` sigmoid scores. Boxes are ``[ymin, xmin, ymax, xmax]``
    as fractions of the image size.
    """

    def __init__(self, img_size=800, max_detections=50):
        """
        Args:
            img_size: Side length of the square network input.
            max_detections: Number of box slots the model predicts per image.
        """
        self.img_size = img_size
        self.max_detections = max_detections
        self.model = None

    def _build_model(self):
        """Build the backbone plus the box-regression and objectness heads."""
        inputs = layers.Input(shape=(self.img_size, self.img_size, 3))

        base_model = ResNet50(include_top=False, weights='imagenet', input_tensor=inputs)

        # Freeze the first 100 backbone layers; the rest stay trainable
        for layer in base_model.layers[:100]:
            layer.trainable = False

        x = base_model.output
        x = layers.Conv2D(512, 3, padding='same', activation='relu')(x)
        x = layers.Conv2D(512, 3, padding='same', activation='relu')(x)
        x = layers.Conv2D(256, 3, padding='same', activation='relu')(x)
        x = layers.GlobalAveragePooling2D()(x)
        x = layers.Dense(1024, activation='relu')(x)
        x = layers.Dropout(0.5)(x)

        bbox_output = layers.Dense(self.max_detections * 4, name='bbox')(x)
        objectness = layers.Dense(self.max_detections, activation='sigmoid', name='objectness')(x)

        model = models.Model(inputs=inputs, outputs=[bbox_output, objectness])
        return model

    def train(self, train_gen, epochs=10, strategy=None):
        """
        Train with device strategy support.

        Args:
            train_gen: Indexable batch source; ``train_gen[i]`` must return
                ``(images, targets)`` and ``len(train_gen)`` the batch count.
            epochs: Number of passes over ``train_gen``.
            strategy: Optional ``tf.distribute`` strategy; the model and
                optimizer are created inside its scope.

        Returns:
            The trained Keras model (also stored in ``self.model``).

        The per-image loss is the mean squared error over all box slots
        (unused slots are zero-padded) plus 0.5 x binary cross-entropy on the
        objectness slots. Images without boxes contribute no loss term.
        """
        # NOTE(review): the loop below calls the model and optimizer directly
        # rather than through strategy.run; confirm this is supported by the
        # multi-replica strategies this file may pass in.

        # Build model within strategy scope
        if strategy is not None:
            with strategy.scope():
                self.model = self._build_model()
                optimizer = optimizers.Adam(0.001)
        else:
            self.model = self._build_model()
            optimizer = optimizers.Adam(0.001)

        for epoch in range(epochs):
            print(f"\nEpoch {epoch+1}/{epochs}")
            epoch_losses = []
            num_batches = len(train_gen)

            for batch_idx in range(num_batches):
                images, targets = train_gen[batch_idx]

                if len(images) == 0:
                    continue

                # Without a single annotated box the loss would stay the
                # Python number 0, which tape.gradient cannot differentiate.
                if not any(len(target_boxes) > 0 for target_boxes in targets):
                    continue

                images = np.array(images)

                with tf.GradientTape() as tape:
                    bbox_pred, obj_pred = self.model(images, training=True)
                    bbox_pred = tf.reshape(bbox_pred, (-1, self.max_detections, 4))

                    batch_loss = 0
                    for i, target_boxes in enumerate(targets):
                        if len(target_boxes) > 0:
                            # Annotations beyond max_detections are dropped
                            num_objs = min(len(target_boxes), self.max_detections)

                            target_padded = np.zeros((self.max_detections, 4))
                            target_padded[:num_objs] = target_boxes[:num_objs]

                            bbox_loss = tf.reduce_mean(tf.square(bbox_pred[i] - target_padded))

                            # The first num_objs slots are the positives
                            obj_target = np.zeros(self.max_detections)
                            obj_target[:num_objs] = 1
                            obj_loss = tf.keras.losses.binary_crossentropy(obj_target, obj_pred[i])

                            batch_loss += bbox_loss + 0.5 * tf.reduce_mean(obj_loss)

                    batch_loss = batch_loss / len(targets)

                grads = tape.gradient(batch_loss, self.model.trainable_variables)
                optimizer.apply_gradients(zip(grads, self.model.trainable_variables))

                epoch_losses.append(float(batch_loss))

                if (batch_idx + 1) % 10 == 0:
                    print(f"  Batch {batch_idx+1}/{num_batches}, Loss: {np.mean(epoch_losses[-10:]):.4f}")

            if epoch_losses:
                print(f"Epoch Loss: {np.mean(epoch_losses):.4f}")
            else:
                print("Epoch Loss: n/a (no batch contained annotated boxes)")

            # This loop bypasses Keras fit(), so the generator's epoch hook
            # (used for reshuffling) has to be triggered explicitly.
            epoch_hook = getattr(train_gen, 'on_epoch_end', None)
            if callable(epoch_hook):
                epoch_hook()

        return self.model

    @staticmethod
    def _decode_boxes(bbox_pred, obj_pred, orig_w, orig_h, confidence_threshold):
        """
        Turn raw per-slot predictions into pixel boxes with matching scores.

        Args:
            bbox_pred: ``(num_slots, 4)`` array of ``[ymin, xmin, ymax, xmax]``
                fractions of the image size.
            obj_pred: ``(num_slots,)`` array of objectness scores.
            orig_w: Width in pixels of the image the boxes refer to.
            orig_h: Height in pixels of the image the boxes refer to.
            confidence_threshold: Slots scoring at or below this are dropped.

        Returns:
            ``(boxes, scores)`` where ``boxes`` holds ``[xmin, ymin, xmax,
            ymax]`` pixel boxes clipped to the image and ``scores[k]`` is the
            objectness of ``boxes[k]``. Boxes with no area are dropped
            together with their score.
        """
        keep = obj_pred > confidence_threshold

        final_boxes = []
        final_scores = []
        for (ymin, xmin, ymax, xmax), score in zip(bbox_pred[keep], obj_pred[keep]):
            xmin = max(0, min(int(xmin * orig_w), orig_w))
            xmax = max(0, min(int(xmax * orig_w), orig_w))
            ymin = max(0, min(int(ymin * orig_h), orig_h))
            ymax = max(0, min(int(ymax * orig_h), orig_h))

            if xmax > xmin and ymax > ymin:
                final_boxes.append([xmin, ymin, xmax, ymax])
                # Keep the score of *this* box so both lists stay aligned
                final_scores.append(score)

        return np.array(final_boxes), np.array(final_scores, dtype=obj_pred.dtype)

    def predict_boxes(self, image, confidence_threshold=0.3):
        """
        Detect chromosomes in one image.

        Args:
            image: Image array of shape ``(height, width, channels)``.
            confidence_threshold: Minimum objectness (exclusive) to keep a box.

        Returns:
            ``(boxes, scores)``: pixel boxes ``[xmin, ymin, xmax, ymax]`` in
            the coordinates of ``image`` and their objectness scores.

        Raises:
            RuntimeError: If no model has been trained or assigned yet.
        """
        if self.model is None:
            raise RuntimeError("ChromosomeIdentifier has no model; train or load one first")

        orig_h, orig_w = image.shape[:2]
        image_resized = cv2.resize(image, (self.img_size, self.img_size))
        image_resized = image_resized.astype(np.float32) / 255.0
        image_resized = np.expand_dims(image_resized, axis=0)

        bbox_pred, obj_pred = self.model.predict(image_resized, verbose=0)

        bbox_pred = bbox_pred.reshape(self.max_detections, 4)
        obj_pred = obj_pred.reshape(self.max_detections)

        return self._decode_boxes(bbox_pred, obj_pred, orig_w, orig_h,
                                  confidence_threshold)


# ==================== TRAIN IDENTIFIER ====================

_RULE = "=" * 60
print(_RULE)
print("CELL 1: TRAINING CHROMOSOME IDENTIFIER")
print(_RULE)

# Dataset locations for the single-chromosome detection data
DATA_ROOT = '/content/drive/MyDrive/ParadoX/inter_iit/2025_Karyogram_CV_Camp'
_SINGLE_CHR_ROOT = os.path.join(DATA_ROOT, 'single_chromosomes_object')
SINGLE_CHR_IMAGES = os.path.join(_SINGLE_CHR_ROOT, 'images')
SINGLE_CHR_ANNOTATIONS = os.path.join(_SINGLE_CHR_ROOT, 'annotations')

print(f"\nUsing device: {DEVICE_TYPE}")
print(f"Number of replicas: {STRATEGY.num_replicas_in_sync}")

# Per-replica batch size for accelerators; any other device type (CPU)
# uses a fixed batch of 1.
_IDENTIFIER_PER_REPLICA_BATCH = {"TPU": 8, "Single-GPU": 2, "Multi-GPU": 2}
if DEVICE_TYPE in _IDENTIFIER_PER_REPLICA_BATCH:
    BATCH_SIZE = _IDENTIFIER_PER_REPLICA_BATCH[DEVICE_TYPE] * STRATEGY.num_replicas_in_sync
else:
    BATCH_SIZE = 1

print(f"Adjusted batch size: {BATCH_SIZE}")

train_identifier_gen = ChromosomeIdentifierDataGenerator(
    images_dir=SINGLE_CHR_IMAGES,
    annotations_dir=SINGLE_CHR_ANNOTATIONS,
    batch_size=BATCH_SIZE,
    img_size=800,
    shuffle=True,
)

identifier = ChromosomeIdentifier(img_size=800, max_detections=50)

print("\nStarting identifier training...")
identifier.train(train_identifier_gen, epochs=10, strategy=STRATEGY)

# Persist the detector so the evaluation cell can reload it from disk
identifier.model.save('identifier_model.h5')
print("\n✓ Identifier model saved to 'identifier_model.h5'")
print(_RULE)

In [None]:
# ==============================================================================
# CELL 2: TRAIN CLASSIFIER MODULE WITH TPU/GPU/CPU SUPPORT
# Uses ALL data from 24_chromosomes_object folder
# ==============================================================================

class ChromosomeClassifierDataGenerator(keras.utils.Sequence):
    """Data generator for training chromosome classifier.

    Every annotated object whose name is a key of ``CHROMOSOME_CLASSES`` is
    cropped out of its image once, at construction time, and kept in memory.
    Batches are ``(images, labels)`` arrays: crops resized to
    ``img_size x img_size``, scaled to [0, 1] and normalised with the
    mean/std constants below, plus integer class ids.
    """

    CHROMOSOME_CLASSES = {
        '1': 0, '2': 1, '3': 2, '4': 3, '5': 4, '6': 5,
        '7': 6, '8': 7, '9': 8, '10': 9, '11': 10, '12': 11,
        '13': 12, '14': 13, '15': 14, '16': 15, '17': 16, '18': 17,
        '19': 18, '20': 19, '21': 20, '22': 21, 'X': 22, 'Y': 23
    }

    # ImageNet channel statistics, built once instead of once per sample
    IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
    IMAGENET_STD = np.array([0.229, 0.224, 0.225])

    def __init__(self, images_dir: str, annotations_dir: str, is_validation=False,
                 batch_size=32, img_size=224, augment=False, validation_split=0.2,
                 data=None, indexes=None):
        """
        Args:
            images_dir: Directory holding the .jpg/.png images.
            annotations_dir: Directory holding one ``<base>.xml`` per image.
            is_validation: Validation generators are never augmented or
                reshuffled.
            batch_size: Number of crops per batch.
            img_size: Side length the crops are resized to.
            augment: Apply random augmentation (training generators only).
            validation_split: Fraction of samples reserved for validation
                when this generator builds and splits the dataset itself.
            data: Optional pre-built sample list; must come with ``indexes``.
            indexes: Optional index array into ``data``.
        """
        super().__init__()
        print(f"\nInitializing ChromosomeClassifierDataGenerator:")
        self.images_dir = images_dir
        self.annotations_dir = annotations_dir
        self.batch_size = batch_size
        self.img_size = img_size
        self.augment = augment
        self.validation_split = validation_split
        self.is_validation = is_validation
        self.parser = AnnotationParser()

        if data is not None and indexes is not None:
            # Use provided data and indexes for validation generator
            self.chromosome_data = data
            self.indexes = indexes
            self.val_indexes = []
            print(f"Validation samples: {len(self.indexes)}")
        else:
            self.chromosome_data = []
            self._build_dataset()
            total_samples = len(self.chromosome_data)

            if self.is_validation:
                # A validation generator built straight from disk serves every
                # sample; no further split is carved out of it.
                self.indexes = np.arange(total_samples)
                self.val_indexes = []
                print(f"Validation samples: {len(self.indexes)}")
            else:
                val_size = int(total_samples * validation_split)

                # Create train/val indexes (fixed seed keeps the split stable)
                np.random.seed(42)
                all_indexes = np.arange(total_samples)
                np.random.shuffle(all_indexes)

                self.indexes = all_indexes[val_size:]
                self.val_indexes = all_indexes[:val_size]

                print(f"Train samples: {len(self.indexes)}, Validation samples: {len(self.val_indexes)}")

        # Setup augmentation if required
        if augment and not self.is_validation:
            self.aug = A.Compose([
                A.HorizontalFlip(p=0.5),
                A.VerticalFlip(p=0.5),
                A.Rotate(limit=15, p=0.5),
                A.RandomBrightnessContrast(p=0.3),
                A.GaussNoise(p=0.2),
            ])
        else:
            self.aug = None

    def _build_dataset(self):
        """Build dataset from images and annotations"""
        print("\nBuilding classifier dataset...")

        # Get list of valid image files
        image_files = []
        for filename in os.listdir(self.images_dir):
            if filename.endswith(('.jpg', '.png')):
                base_name = os.path.splitext(filename)[0]
                xml_path = os.path.join(self.annotations_dir, f"{base_name}.xml")
                if os.path.exists(xml_path):
                    image_files.append(base_name)

        print(f"Found {len(image_files)} images to process...")

        # Process each image and its annotations
        for img_name in image_files:
            img_path = os.path.join(self.images_dir, f"{img_name}.jpg")
            xml_path = os.path.join(self.annotations_dir, f"{img_name}.xml")

            # Try PNG if JPG doesn't exist
            if not os.path.exists(img_path):
                img_path = os.path.join(self.images_dir, f"{img_name}.png")

            if not os.path.exists(img_path) or not os.path.exists(xml_path):
                print(f"Warning: Missing files for {img_name}")
                continue

            # Load and process image
            image = cv2.imread(img_path)
            if image is None:
                # cv2.imread signals an unreadable/corrupt file by returning None
                print(f"Warning: Could not read image {img_path}")
                continue
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            img_h, img_w = image.shape[:2]

            # Parse annotations
            annotations = self.parser.parse_xml(xml_path)

            # Extract individual chromosomes
            for obj in annotations['objects']:
                label = obj['name']
                if label not in self.CHROMOSOME_CLASSES:
                    continue

                # Clamp to the image: a negative coordinate would otherwise
                # be read as an index from the far edge and crop the wrong area.
                bbox = obj['bbox']
                x0 = min(max(bbox['xmin'], 0), img_w)
                x1 = min(max(bbox['xmax'], 0), img_w)
                y0 = min(max(bbox['ymin'], 0), img_h)
                y1 = min(max(bbox['ymax'], 0), img_h)

                chromosome_img = image[y0:y1, x0:x1]

                if chromosome_img.size > 0:
                    self.chromosome_data.append({
                        # Copy so the stored crop does not keep the whole
                        # source image alive as its base array.
                        'image': chromosome_img.copy(),
                        'label': self.CHROMOSOME_CLASSES[label]
                    })

        print(f"Built dataset with {len(self.chromosome_data)} chromosome samples")

    def get_validation_generator(self):
        """Create validation generator using validation split"""
        val_gen = ChromosomeClassifierDataGenerator(
            images_dir=self.images_dir,
            annotations_dir=self.annotations_dir,
            batch_size=self.batch_size,
            img_size=self.img_size,
            augment=False,
            validation_split=0.0,
            is_validation=True,
            data=[self.chromosome_data[i] for i in self.val_indexes],
            indexes=np.arange(len(self.val_indexes))
        )
        return val_gen

    def __len__(self):
        """Return number of batches per epoch"""
        return int(np.ceil(len(self.indexes) / self.batch_size))

    def __getitem__(self, index):
        """Get batch of data"""
        batch_indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

        batch_images = []
        batch_labels = []

        for idx in batch_indexes:
            data = self.chromosome_data[idx]
            image = data['image'].copy()
            label = data['label']

            # Apply augmentation if enabled
            if self.aug is not None:
                augmented = self.aug(image=image)
                image = augmented['image']

            # Preprocess image
            image = cv2.resize(image, (self.img_size, self.img_size))
            image = image.astype(np.float32) / 255.0

            # Normalize using ImageNet stats
            image = (image - self.IMAGENET_MEAN) / self.IMAGENET_STD

            batch_images.append(image)
            batch_labels.append(label)

        return np.array(batch_images), np.array(batch_labels)

    def on_epoch_end(self):
        """Called at the end of every epoch"""
        if not self.is_validation:
            np.random.shuffle(self.indexes)


class ChromosomeClassifier:
    """Chromosome classification model (24 classes).

    Wraps a ResNet50 backbone with a small dense head ending in a softmax
    over ``num_classes``. ``model`` stays ``None`` until ``compile_model``
    is called (or a model is assigned from outside).
    """

    def __init__(self, num_classes=24, img_size=224):
        self.num_classes, self.img_size = num_classes, img_size
        self.model = None

    def _build_model(self):
        """Assemble the backbone and classification head (uncompiled)."""
        backbone = ResNet50(
            include_top=False,
            weights='imagenet',
            input_shape=(self.img_size, self.img_size, 3),
        )

        # Only layers from index 140 onwards remain trainable
        for frozen_layer in backbone.layers[:140]:
            frozen_layer.trainable = False

        head = layers.GlobalAveragePooling2D()(backbone.output)
        # Two Dense+Dropout stages: 512 units / 0.5, then 256 units / 0.3
        for units, drop_rate in ((512, 0.5), (256, 0.3)):
            head = layers.Dense(units, activation='relu')(head)
            head = layers.Dropout(drop_rate)(head)
        class_probs = layers.Dense(self.num_classes, activation='softmax')(head)

        return models.Model(inputs=backbone.input, outputs=class_probs)

    def compile_model(self, lr=0.001, strategy=None):
        """Build and compile the model, inside ``strategy.scope()`` when given."""

        def _build_and_compile():
            self.model = self._build_model()
            self.model.compile(
                optimizer=optimizers.Adam(lr),
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy'],
            )

        if strategy is None:
            _build_and_compile()
            return

        with strategy.scope():
            _build_and_compile()

    def train(self, train_gen, val_gen, epochs=30):
        """Fit the compiled model and return the Keras history object.

        All three callbacks watch ``val_accuracy``: the learning rate is
        halved after 3 stagnant epochs, the best model is written to
        'best_classifier.h5', and training stops after 7 stagnant epochs
        with the best weights restored.
        """
        monitored = 'val_accuracy'

        lr_schedule = keras.callbacks.ReduceLROnPlateau(
            monitor=monitored, factor=0.5, patience=3,
            min_lr=1e-7, verbose=1
        )
        checkpoint = keras.callbacks.ModelCheckpoint(
            'best_classifier.h5', monitor=monitored,
            save_best_only=True, verbose=1
        )
        early_stop = keras.callbacks.EarlyStopping(
            monitor=monitored, patience=7,
            restore_best_weights=True, verbose=1
        )

        return self.model.fit(
            train_gen,
            validation_data=val_gen,
            epochs=epochs,
            callbacks=[lr_schedule, checkpoint, early_stop],
            verbose=1
        )


# ==================== TRAIN CLASSIFIER ====================

_RULE = "=" * 60
print("\n" + _RULE)
print("CELL 2: TRAINING CHROMOSOME CLASSIFIER")
print(_RULE)

# Dataset locations for the 24-class chromosome data
DATA_ROOT = '/content/drive/MyDrive/ParadoX/inter_iit/2025_Karyogram_CV_Camp'
_MULTI_CHR_ROOT = os.path.join(DATA_ROOT, '24_chromosomes_object')
MULTI_CHR_IMAGES = os.path.join(_MULTI_CHR_ROOT, 'images')
MULTI_CHR_ANNOTATIONS = os.path.join(_MULTI_CHR_ROOT, 'annotations')

print(f"\nUsing device: {DEVICE_TYPE}")
print(f"Number of replicas: {STRATEGY.num_replicas_in_sync}")

# Per-replica batch size for accelerators; any other device type (CPU)
# uses a fixed batch of 8.
_CLASSIFIER_PER_REPLICA_BATCH = {"TPU": 128, "Single-GPU": 32, "Multi-GPU": 32}
if DEVICE_TYPE in _CLASSIFIER_PER_REPLICA_BATCH:
    CLASSIFIER_BATCH_SIZE = (
        _CLASSIFIER_PER_REPLICA_BATCH[DEVICE_TYPE] * STRATEGY.num_replicas_in_sync
    )
else:
    CLASSIFIER_BATCH_SIZE = 8

print(f"Adjusted batch size: {CLASSIFIER_BATCH_SIZE}")

train_classifier_gen = ChromosomeClassifierDataGenerator(
    images_dir=MULTI_CHR_IMAGES,
    annotations_dir=MULTI_CHR_ANNOTATIONS,
    batch_size=CLASSIFIER_BATCH_SIZE,
    img_size=224,
    augment=True,
    validation_split=0.2,
)

# The validation generator reuses the crops held back by the training generator
val_classifier_gen = train_classifier_gen.get_validation_generator()

classifier = ChromosomeClassifier(num_classes=24)
classifier.compile_model(lr=0.001, strategy=STRATEGY)

print("\nStarting classifier training...")
history = classifier.train(train_classifier_gen, val_classifier_gen, epochs=30)

print("\n✓ Classifier model saved to 'best_classifier.h5'")
print(_RULE)

In [None]:
# ==============================================================================
# CELL 3: INTEGRATION AND EVALUATION WITH DEVICE SUPPORT
# Uses train.txt and test.txt to select specific images from 24_chromosomes_object
# ==============================================================================

class KaryotypePipeline:
    """End-to-end pipeline integrating identifier and classifier.

    The identifier proposes boxes, each box is cropped from the image and
    the classifier assigns one of the 24 chromosome classes to every crop.
    """

    # Same normalisation constants the classifier's training generator applies
    _IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
    _IMAGENET_STD = np.array([0.229, 0.224, 0.225])

    def __init__(self, identifier: "ChromosomeIdentifier",
                 classifier: "ChromosomeClassifier"):
        """
        Args:
            identifier: Object exposing ``predict_boxes(image, confidence_threshold=...)``.
            classifier: Object whose ``model`` attribute is a Keras classifier.
        """
        self.identifier = identifier
        self.classifier = classifier
        # Crop size follows the classifier's input size; 224 when it does not say
        self.img_size = getattr(classifier, 'img_size', None) or 224

    def _preprocess_crop(self, crop):
        """Resize a crop to the classifier input size and normalise it."""
        crop = cv2.resize(crop, (self.img_size, self.img_size))
        crop = crop.astype(np.float32) / 255.0
        return (crop - self._IMAGENET_MEAN) / self._IMAGENET_STD

    def process_image(self, image_path: str, confidence_threshold=0.3):
        """
        Detect and classify every chromosome in one image file.

        Args:
            image_path: Path of the image to process.
            confidence_threshold: Objectness threshold passed to the identifier.

        Returns:
            A list with one dict per non-empty detected box, holding 'bbox',
            'class_id', 'class_name', 'confidence' (classifier probability of
            the chosen class) and 'score' (identifier objectness). Empty when
            nothing is detected.

        Raises:
            OSError: If the image cannot be read.
        """
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals an unreadable/corrupt file by returning None
            raise OSError(f"Could not read image: {image_path}")
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        boxes, scores = self.identifier.predict_boxes(image_rgb, confidence_threshold=confidence_threshold)

        if len(boxes) == 0:
            return []

        # Collect all usable crops first so the classifier runs once per
        # image instead of once per box.
        crops = []
        kept_indexes = []
        for i, box in enumerate(boxes):
            xmin, ymin, xmax, ymax = map(int, box)

            chromosome_img = image_rgb[ymin:ymax, xmin:xmax]

            if chromosome_img.size == 0:
                continue

            crops.append(self._preprocess_crop(chromosome_img))
            kept_indexes.append(i)

        if not crops:
            return []

        all_probs = self.classifier.model.predict(np.stack(crops), verbose=0)

        predictions = []
        for i, probs in zip(kept_indexes, all_probs):
            class_id = int(np.argmax(probs))

            predictions.append({
                'bbox': boxes[i],
                'class_id': class_id,
                'class_name': self._get_class_name(class_id),
                'confidence': float(probs[class_id]),
                'score': float(scores[i]) if i < len(scores) else 1.0
            })

        return predictions

    @staticmethod
    def _get_class_name(class_id: int) -> str:
        """Map a class id to its label: 0-21 -> '1'-'22', 22 -> 'X', otherwise 'Y'."""
        if class_id < 22:
            return str(class_id + 1)
        elif class_id == 22:
            return 'X'
        else:
            return 'Y'


def evaluate_pipeline(pipeline: KaryotypePipeline, images_dir: str,
                     annotations_dir: str, test_list: List[str],
                     output_dir='results'):
    """
    Evaluate the pipeline with per-class precision-recall curves.

    For every listed image, the highest classifier confidence per class over
    all predictions of that image forms one score vector. That vector is
    recorded once for each ground-truth object whose name is a known class,
    paired with the object's one-hot label.

    Args:
        pipeline: Pipeline whose ``process_image`` produces the predictions.
        images_dir: Directory holding the test images (.jpg or .png).
        annotations_dir: Directory holding one ``<name>.xml`` per image.
        test_list: Image base names; surrounding whitespace is stripped.
        output_dir: Directory receiving 'auprc_curves.png' and
            'auprc_scores.txt'; created when missing.

    Returns:
        ``(mean_auprc, auprc_scores)`` where ``auprc_scores`` lists the auPRC
        of each class that has at least one positive sample, in ascending
        class-id order. ``(nan, [])`` when no sample could be evaluated.
    """
    os.makedirs(output_dir, exist_ok=True)

    parser = AnnotationParser()
    class_map = ChromosomeClassifierDataGenerator.CHROMOSOME_CLASSES
    num_classes = 24

    all_true_labels = []
    all_pred_probs = []

    print("Evaluating pipeline on test set...")
    for idx, img_name in enumerate(test_list):
        img_name = img_name.strip()
        img_path = os.path.join(images_dir, f"{img_name}.jpg")
        xml_path = os.path.join(annotations_dir, f"{img_name}.xml")

        if not os.path.exists(img_path):
            img_path = os.path.join(images_dir, f"{img_name}.png")

        if not os.path.exists(img_path) or not os.path.exists(xml_path):
            continue

        if (idx + 1) % 10 == 0:
            print(f"Processed {idx+1}/{len(test_list)} images...")

        annotations = parser.parse_xml(xml_path)
        gt_chromosomes = [obj['name'] for obj in annotations['objects']]

        predictions = pipeline.process_image(img_path)

        # The score vector depends only on the image, so build it once
        # rather than once per ground-truth object.
        pred_probs = np.zeros(num_classes)
        for pred in predictions:
            class_id = pred['class_id']
            pred_probs[class_id] = max(pred_probs[class_id], pred['confidence'])

        for gt_chr in gt_chromosomes:
            if gt_chr in class_map:
                true_labels = np.zeros(num_classes)
                true_labels[class_map[gt_chr]] = 1

                all_true_labels.append(true_labels)
                all_pred_probs.append(pred_probs)

    all_true_labels = np.array(all_true_labels)
    all_pred_probs = np.array(all_pred_probs)

    if all_true_labels.size == 0:
        # Nothing matched on disk (or no known labels): the 2-D indexing
        # below would fail, so report it and return an empty result.
        print("\n⚠ No evaluable test samples found; skipping auPRC computation")
        return float('nan'), []

    auprc_scores = []
    scored_classes = []  # class id belonging to each entry of auprc_scores
    plt.figure(figsize=(15, 10))

    for i in range(num_classes):
        if np.sum(all_true_labels[:, i]) > 0:
            precision, recall, _ = precision_recall_curve(
                all_true_labels[:, i],
                all_pred_probs[:, i]
            )
            auprc = auc(recall, precision)
            auprc_scores.append(auprc)
            scored_classes.append(i)

            plt.plot(recall, precision,
                    label=f'Class {pipeline._get_class_name(i)} (auPRC={auprc:.3f})')

    plt.xlabel('Recall', fontsize=12)
    plt.ylabel('Precision', fontsize=12)
    plt.title('Precision-Recall Curves for All Chromosome Classes', fontsize=14)
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'auprc_curves.png'),
                dpi=300, bbox_inches='tight')
    plt.close()

    mean_auprc = np.mean(auprc_scores)
    print(f"\nMean auPRC: {mean_auprc:.4f}")

    with open(os.path.join(output_dir, 'auprc_scores.txt'), 'w') as f:
        f.write(f"Mean auPRC: {mean_auprc:.4f}\n\n")
        f.write("Per-class auPRC scores:\n")
        # Classes without positives are absent from auprc_scores, so the
        # label must come from the recorded class id, not the list position.
        for class_id, score in zip(scored_classes, auprc_scores):
            f.write(f"Class {pipeline._get_class_name(class_id)}: {score:.4f}\n")

    return mean_auprc, auprc_scores


# ==================== INTEGRATION AND EVALUATION ====================

_RULE = "=" * 60
print("\n" + _RULE)
print("CELL 3: INTEGRATION AND EVALUATION")
print(_RULE)

# Read the image-name lists that define the train/test split
print("\nLoading train/test splits...")
with open(os.path.join(DATA_ROOT, 'train.txt'), 'r') as f:
    train_list = list(f)

with open(os.path.join(DATA_ROOT, 'test.txt'), 'r') as f:
    test_list = list(f)

print(f"Train samples from train.txt: {len(train_list)}")
print(f"Test samples from test.txt: {len(test_list)}")

# Reload both saved models (inside the strategy scope when running on TPU)
print("\nLoading trained models...")
print(f"Loading on device: {DEVICE_TYPE}")


def _load_trained_models():
    """Restore both saved models from disk and wrap them in their helper classes."""
    identifier_wrapper = ChromosomeIdentifier(img_size=800, max_detections=50)
    identifier_wrapper.model = keras.models.load_model('identifier_model.h5')
    print("✓ Identifier model loaded")

    classifier_wrapper = ChromosomeClassifier(num_classes=24)
    classifier_wrapper.model = keras.models.load_model('best_classifier.h5')
    print("✓ Classifier model loaded")

    return identifier_wrapper, classifier_wrapper


if DEVICE_TYPE == "TPU":
    # On TPU the models are loaded within the strategy scope
    with STRATEGY.scope():
        identifier_loaded, classifier_loaded = _load_trained_models()
else:
    # On GPU/CPU a plain load is used
    identifier_loaded, classifier_loaded = _load_trained_models()

# Chain detector and classifier into one pipeline
pipeline = KaryotypePipeline(identifier_loaded, classifier_loaded)
print("\n✓ Pipeline created successfully")

# Score the pipeline on the images named in test.txt
print("\nEvaluating pipeline on test set from test.txt...")
mean_auprc, class_auprcs = evaluate_pipeline(
    pipeline,
    MULTI_CHR_IMAGES,
    MULTI_CHR_ANNOTATIONS,
    test_list,
    output_dir='results',
)

print("\n" + _RULE)
print("EVALUATION COMPLETE!")
print(_RULE)
print(f"Mean auPRC: {mean_auprc:.4f}")
print("\nResults saved to 'results/' directory:")
print("  - auPRC curve: results/auprc_curves.png")
print("  - Detailed scores: results/auprc_scores.txt")
print(_RULE)

In [None]:
# ==============================================================================
# CELL 4: COMPREHENSIVE EVALUATION METRICS WITH DEVICE SUPPORT
# Includes Confusion Matrix, ROC Curves, and Detailed Metrics
# ==============================================================================

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    confusion_matrix, classification_report, 
    roc_curve, roc_auc_score, auc,
    precision_recall_fscore_support,
    accuracy_score, balanced_accuracy_score
)
import pandas as pd
from itertools import cycle
import warnings
warnings.filterwarnings('ignore')

def _collect_labels_and_scores(pipeline, parser, images_dir, annotations_dir,
                               test_list, num_classes):
    """Run the pipeline over the test images and gather labels and scores.

    Returns three parallel lists: ground-truth class ids, predicted class
    ids, and per-sample score vectors of length ``num_classes``.
    """
    class_lookup = ChromosomeClassifierDataGenerator.CHROMOSOME_CLASSES
    true_labels, pred_labels, pred_scores = [], [], []

    for idx, img_name in enumerate(test_list):
        img_name = img_name.strip()
        img_path = os.path.join(images_dir, f"{img_name}.jpg")
        xml_path = os.path.join(annotations_dir, f"{img_name}.xml")

        # Fall back to a .png image when no .jpg exists
        if not os.path.exists(img_path):
            img_path = os.path.join(images_dir, f"{img_name}.png")

        if not os.path.exists(img_path) or not os.path.exists(xml_path):
            continue

        if (idx + 1) % 20 == 0:
            print(f"  Processed {idx+1}/{len(test_list)} images...")

        annotations = parser.parse_xml(xml_path)
        predictions = pipeline.process_image(img_path, confidence_threshold=0.3)

        # NOTE(review): every ground-truth object in an image is paired with
        # the same image-level top-confidence prediction; no spatial matching
        # is done. Confirm this is the intended evaluation protocol.
        # The choice does not depend on the object, so compute it once.
        scores = np.zeros(num_classes)
        if predictions:
            best_pred = max(predictions, key=lambda x: x['confidence'])
            pred_class = best_pred['class_id']
            scores[pred_class] = best_pred['confidence']
        else:
            # No detection: counted as a full-confidence prediction of class 0
            pred_class = 0
            scores[0] = 1.0

        for obj in annotations['objects']:
            gt_chr = obj['name']
            if gt_chr not in class_lookup:
                continue
            true_labels.append(class_lookup[gt_chr])
            pred_labels.append(pred_class)
            pred_scores.append(scores)

    return true_labels, pred_labels, pred_scores


def _save_confusion_heatmap(matrix, class_names, fmt, cbar_label, title, out_path):
    """Draw *matrix* as an annotated heatmap and save it to *out_path*."""
    plt.figure(figsize=(20, 18))
    sns.heatmap(matrix, annot=True, fmt=fmt, cmap='Blues',
                xticklabels=class_names, yticklabels=class_names,
                cbar_kws={'label': cbar_label})
    plt.xlabel('Predicted Class', fontsize=14, fontweight='bold')
    plt.ylabel('True Class', fontsize=14, fontweight='bold')
    plt.title(title, fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.close()


def _plot_metric_bars(ax, values, class_names, metric_name, color):
    """Draw one per-class bar panel for *metric_name* on axes *ax*."""
    positions = range(len(class_names))
    ax.bar(positions, values, color=color, alpha=0.7)
    ax.set_xlabel('Chromosome Class', fontweight='bold')
    ax.set_ylabel(metric_name, fontweight='bold')
    ax.set_title(f'{metric_name} by Class', fontweight='bold', fontsize=14)
    ax.set_xticks(positions)
    ax.set_xticklabels(class_names, rotation=45)
    ax.grid(axis='y', alpha=0.3)


def compute_comprehensive_metrics(pipeline, images_dir, annotations_dir, 
                                  test_list, output_dir='evaluation_results'):
    """
    Compute comprehensive evaluation metrics including:
    - Confusion Matrix (raw and row-normalized)
    - ROC Curves (One-vs-Rest, plus micro/macro aggregates)
    - Classification Report
    - Per-class metrics and summary statistics

    All metrics are computed over the fixed set of 24 class ids so the
    outputs keep a stable shape even when some classes never occur in the
    evaluated samples.

    Args:
        pipeline: object exposing ``process_image(path, confidence_threshold=...)``
            that returns a sequence of dicts with ``'confidence'`` and
            ``'class_id'`` keys.
        images_dir: directory holding ``<name>.jpg`` / ``<name>.png`` images.
        annotations_dir: directory holding ``<name>.xml`` annotations.
        test_list: iterable of image base names (surrounding whitespace is
            stripped); entries with a missing image or annotation are skipped.
        output_dir: directory that receives all plots and reports; created
            if it does not exist.

    Returns:
        Tuple ``(summary, metrics_df, cm)``: a dict of summary statistics, a
        per-class metrics DataFrame, and the 24x24 raw confusion matrix.

    Raises:
        ValueError: if no sample could be evaluated.

    Reads the module-level ``DEVICE_TYPE`` and ``STRATEGY`` for reporting.
    """
    
    os.makedirs(output_dir, exist_ok=True)
    parser = AnnotationParser()
    
    # Class mapping: index i is the display name of class id i
    CLASS_NAMES = [str(i+1) for i in range(22)] + ['X', 'Y']
    num_classes = len(CLASS_NAMES)
    class_ids = list(range(num_classes))
    
    print("Computing comprehensive metrics...")
    print(f"Processing {len(test_list)} test images...")
    print(f"Using device: {DEVICE_TYPE}")
    
    true_list, pred_list, score_list = _collect_labels_and_scores(
        pipeline, parser, images_dir, annotations_dir, test_list, num_classes
    )
    
    all_true_labels = np.array(true_list)
    all_pred_labels = np.array(pred_list)
    all_pred_probs = np.array(score_list)
    n_samples = len(all_true_labels)
    
    print(f"\nTotal samples evaluated: {n_samples}")
    
    if n_samples == 0:
        # Everything below indexes into (n_samples, num_classes) arrays and
        # would otherwise fail with an opaque error.
        raise ValueError(
            "No samples were evaluated: check that the images and annotations "
            "named in test_list exist and contain known chromosome classes."
        )
    
    # ==================== CONFUSION MATRIX ====================
    print("\n" + "="*60)
    print("GENERATING CONFUSION MATRIX")
    print("="*60)
    
    # Explicit labels keep the matrix 24x24 even if some classes are absent
    cm = confusion_matrix(all_true_labels, all_pred_labels, labels=class_ids)
    
    _save_confusion_heatmap(
        cm, CLASS_NAMES, 'd', 'Count',
        'Confusion Matrix - Chromosome Classification',
        os.path.join(output_dir, 'confusion_matrix.png'),
    )
    
    # Normalized confusion matrix; rows without ground-truth samples are all
    # zero, so dividing them by 1 instead of 0 leaves them at 0 rather than NaN
    row_totals = cm.sum(axis=1)[:, np.newaxis]
    safe_totals = row_totals.copy()
    safe_totals[safe_totals == 0] = 1
    cm_normalized = cm.astype('float') / safe_totals
    
    _save_confusion_heatmap(
        cm_normalized, CLASS_NAMES, '.2f', 'Proportion',
        'Normalized Confusion Matrix - Chromosome Classification',
        os.path.join(output_dir, 'confusion_matrix_normalized.png'),
    )
    
    print("✓ Confusion matrices saved")
    
    # ==================== ROC CURVES ====================
    print("\n" + "="*60)
    print("GENERATING ROC CURVES")
    print("="*60)
    
    # Convert to binary format for ROC (One-vs-Rest)
    from sklearn.preprocessing import label_binarize
    y_true_bin = label_binarize(all_true_labels, classes=class_ids)
    
    # Compute ROC curve and ROC area for each class
    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    
    for i in class_ids:
        positives = int(np.sum(y_true_bin[:, i]))
        # A ROC curve needs at least one positive and one negative sample
        if 0 < positives < n_samples:
            fpr[i], tpr[i], _ = roc_curve(y_true_bin[:, i], all_pred_probs[:, i])
            roc_auc[i] = auc(fpr[i], tpr[i])
    
    # Plot ROC curves
    plt.figure(figsize=(16, 12))
    colors = cycle(plt.cm.tab20.colors)
    
    for i, color in zip(class_ids, colors):
        if i in roc_auc:
            plt.plot(fpr[i], tpr[i], color=color, lw=2,
                    label=f'Class {CLASS_NAMES[i]} (AUC = {roc_auc[i]:.3f})')
    
    plt.plot([0, 1], [0, 1], 'k--', lw=2, label='Random Classifier')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate', fontsize=14, fontweight='bold')
    plt.ylabel('True Positive Rate', fontsize=14, fontweight='bold')
    plt.title('ROC Curves - One-vs-Rest (All Classes)', fontsize=16, fontweight='bold')
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=9)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'roc_curves_all.png'),
                dpi=300, bbox_inches='tight')
    plt.close()
    
    # Micro-average: pool every (sample, class) decision into one ROC curve
    fpr_micro, tpr_micro, _ = roc_curve(y_true_bin.ravel(), all_pred_probs.ravel())
    roc_auc_micro = auc(fpr_micro, tpr_micro)
    
    # Macro-average: mean of the per-class curves on a shared FPR grid.
    # Undefined (NaN) when no class had a computable ROC curve.
    if roc_auc:
        all_fpr = np.unique(np.concatenate([fpr[i] for i in roc_auc]))
        mean_tpr = np.zeros_like(all_fpr)
        for i in roc_auc:
            mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
        mean_tpr /= len(roc_auc)
        roc_auc_macro = auc(all_fpr, mean_tpr)
    else:
        all_fpr = None
        mean_tpr = None
        roc_auc_macro = float('nan')
    
    # Plot aggregate ROC curves
    plt.figure(figsize=(10, 8))
    plt.plot(fpr_micro, tpr_micro, 'b-', lw=3,
             label=f'Micro-average (AUC = {roc_auc_micro:.3f})')
    if all_fpr is not None:
        plt.plot(all_fpr, mean_tpr, 'r-', lw=3,
                 label=f'Macro-average (AUC = {roc_auc_macro:.3f})')
    plt.plot([0, 1], [0, 1], 'k--', lw=2, label='Random Classifier')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate', fontsize=14, fontweight='bold')
    plt.ylabel('True Positive Rate', fontsize=14, fontweight='bold')
    plt.title('Aggregate ROC Curves', fontsize=16, fontweight='bold')
    plt.legend(loc='lower right', fontsize=12)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'roc_curves_aggregate.png'),
                dpi=300, bbox_inches='tight')
    plt.close()
    
    print("✓ ROC curves saved")
    
    # ==================== CLASSIFICATION REPORT ====================
    print("\n" + "="*60)
    print("GENERATING CLASSIFICATION REPORT")
    print("="*60)
    
    # labels= must accompany target_names, otherwise the report raises when
    # the observed classes do not number exactly len(CLASS_NAMES)
    report = classification_report(all_true_labels, all_pred_labels,
                                   labels=class_ids,
                                   target_names=CLASS_NAMES, 
                                   output_dict=True, zero_division=0)
    
    # Convert to DataFrame for better visualization
    report_df = pd.DataFrame(report).transpose()
    
    # Save detailed report
    with open(os.path.join(output_dir, 'classification_report.txt'), 'w') as f:
        f.write("="*80 + "\n")
        f.write("COMPREHENSIVE CLASSIFICATION REPORT\n")
        f.write("="*80 + "\n\n")
        f.write(classification_report(all_true_labels, all_pred_labels,
                                     labels=class_ids,
                                     target_names=CLASS_NAMES, zero_division=0))
        f.write("\n" + "="*80 + "\n")
    
    # Save as CSV
    report_df.to_csv(os.path.join(output_dir, 'classification_report.csv'))
    
    print("✓ Classification report saved")
    
    # ==================== PER-CLASS METRICS ====================
    print("\n" + "="*60)
    print("COMPUTING PER-CLASS METRICS")
    print("="*60)
    
    # One entry per class id, aligned with CLASS_NAMES
    precision, recall, f1, support = precision_recall_fscore_support(
        all_true_labels, all_pred_labels, labels=class_ids,
        average=None, zero_division=0
    )
    
    # Create comprehensive metrics DataFrame
    metrics_df = pd.DataFrame({
        'Class': CLASS_NAMES,
        'Precision': precision,
        'Recall': recall,
        'F1-Score': f1,
        'Support': support,
        'ROC-AUC': [roc_auc.get(i, np.nan) for i in class_ids]
    })
    
    # Save metrics
    metrics_df.to_csv(os.path.join(output_dir, 'per_class_metrics.csv'), index=False)
    
    # Plot per-class metrics comparison (missing ROC-AUC values drawn as 0)
    fig, axes = plt.subplots(2, 2, figsize=(18, 14))
    roc_values = [roc_auc.get(i, 0) for i in class_ids]
    
    _plot_metric_bars(axes[0, 0], precision, CLASS_NAMES, 'Precision', 'steelblue')
    _plot_metric_bars(axes[0, 1], recall, CLASS_NAMES, 'Recall', 'coral')
    _plot_metric_bars(axes[1, 0], f1, CLASS_NAMES, 'F1-Score', 'mediumseagreen')
    _plot_metric_bars(axes[1, 1], roc_values, CLASS_NAMES, 'ROC-AUC', 'mediumpurple')
    
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'per_class_metrics.png'),
                dpi=300, bbox_inches='tight')
    plt.close()
    
    print("✓ Per-class metrics saved")
    
    # ==================== SUMMARY STATISTICS ====================
    print("\n" + "="*60)
    print("COMPUTING SUMMARY STATISTICS")
    print("="*60)
    
    # Overall accuracy
    overall_acc = accuracy_score(all_true_labels, all_pred_labels)
    balanced_acc = balanced_accuracy_score(all_true_labels, all_pred_labels)
    
    # Weighted averages (classes without support get zero weight)
    weighted_precision = np.average(precision, weights=support)
    weighted_recall = np.average(recall, weights=support)
    weighted_f1 = np.average(f1, weights=support)
    
    # Macro averages over classes that occur as a true or predicted label,
    # so never-seen classes do not drag the mean toward zero
    observed = (support > 0) | (cm.sum(axis=0) > 0)
    macro_precision = np.mean(precision[observed])
    macro_recall = np.mean(recall[observed])
    macro_f1 = np.mean(f1[observed])
    
    summary = {
        'Overall Accuracy': overall_acc,
        'Balanced Accuracy': balanced_acc,
        'Macro-avg Precision': macro_precision,
        'Macro-avg Recall': macro_recall,
        'Macro-avg F1-Score': macro_f1,
        'Weighted-avg Precision': weighted_precision,
        'Weighted-avg Recall': weighted_recall,
        'Weighted-avg F1-Score': weighted_f1,
        'Micro-avg ROC-AUC': roc_auc_micro,
        'Macro-avg ROC-AUC': roc_auc_macro
    }
    
    # Save summary
    with open(os.path.join(output_dir, 'summary_statistics.txt'), 'w') as f:
        f.write("="*80 + "\n")
        f.write("SUMMARY STATISTICS\n")
        f.write("="*80 + "\n\n")
        f.write(f"Device Used: {DEVICE_TYPE}\n")
        f.write(f"Number of Replicas: {STRATEGY.num_replicas_in_sync}\n\n")
        for metric, value in summary.items():
            f.write(f"{metric:.<50} {value:.4f}\n")
        f.write("\n" + "="*80 + "\n")
    
    # Print summary
    print("\n" + "="*60)
    print("SUMMARY STATISTICS")
    print("="*60)
    for metric, value in summary.items():
        print(f"{metric:.<50} {value:.4f}")
    print("="*60)
    
    print(f"\n✓ All evaluation metrics saved to '{output_dir}/' directory")
    
    return summary, metrics_df, cm

# ==================== RUN EVALUATION ====================

# Execute the full metrics suite on the test list; keeps the returned summary
# dict, per-class metrics table, and confusion matrix as notebook globals.
summary_stats, per_class_metrics, conf_matrix = compute_comprehensive_metrics(
    pipeline,
    MULTI_CHR_IMAGES,
    MULTI_CHR_ANNOTATIONS,
    test_list,
    'evaluation_results',
)

# Closing banner with the device that ran the evaluation
print(f"\n{'=' * 60}")
print("EVALUATION METRICS GENERATION COMPLETE!")
print('=' * 60)
print(f"\nDevice used: {DEVICE_TYPE}")
print(f"Number of replicas: {STRATEGY.num_replicas_in_sync}")

# Numbered listing of every artifact written by the evaluation
print("\nGenerated files:")
print("\n".join(
    f"  {rank}. {entry}"
    for rank, entry in enumerate(
        (
            "confusion_matrix.png - Raw confusion matrix",
            "confusion_matrix_normalized.png - Normalized confusion matrix",
            "roc_curves_all.png - Individual ROC curves for all classes",
            "roc_curves_aggregate.png - Micro and macro-average ROC",
            "classification_report.txt - Detailed classification report",
            "classification_report.csv - Report in CSV format",
            "per_class_metrics.csv - Per-class metrics table",
            "per_class_metrics.png - Visual comparison of metrics",
            "summary_statistics.txt - Overall performance summary",
        ),
        start=1,
    )
))
print('=' * 60)