## Quick Configuration Guide

### Configuration Flags (Set at the top of script)

```python
TRAIN_IDENTIFIER = True   # Train identifier? (False = load existing)
TRAIN_CLASSIFIER = True   # Train classifier? (False = load existing)
RUN_EVALUATION = True     # Run evaluation?

RESUME_IDENTIFIER = False # Continue identifier training from saved weights?
RESUME_CLASSIFIER = False # Continue classifier training from saved weights?

IDENTIFIER_EPOCHS = 15    # How many epochs for identifier
CLASSIFIER_EPOCHS = 50    # How many epochs for classifier
```

### Important Notes

- **`RESUME_* = True`** only takes effect when the matching **`TRAIN_* = True`** (resuming is a form of training)
- Models save automatically: `identifier_model.h5` and `best_classifier.h5`
- If a model file is missing and its `TRAIN_*` flag is `False`, the script will fail because there is nothing to load
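- **Resume = continue training** from the saved weights instead of starting over. The optimizer state and epoch counter are not restored, so the full configured number of epochs runs again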


### Quick Decision Tree

- **First time running?** → All `TRAIN = True`, All `RESUME = False`
- **Kaggle timed out?** → All `TRAIN = True`, All `RESUME = True`
- **Just want results?** → All `TRAIN = False`, `RUN_EVALUATION = True`
- **Only the classifier needs retraining?** → `TRAIN_CLASSIFIER = True`, `TRAIN_IDENTIFIER = False`

In [None]:
# ==============================================================================
# COMPLETE CHROMOSOME CLASSIFICATION PIPELINE - TPU v5e-8 OPTIMIZED
# Features:
# - TPU v5e-8 optimized batch sizes
# - Robust error handling
# - Resume training capability
# - Complete end-to-end pipeline
#
# TPU v5e-8 Setup (Kaggle):
# 1. Go to notebook settings (right sidebar)
# 2. Set Accelerator to "TPU v5e-8"
# 3. Run this script - TPU will be auto-detected
# ==============================================================================

import os
import sys
import numpy as np
import cv2
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.applications import ResNet50
import xml.etree.ElementTree as ET
from sklearn.metrics import (
    precision_recall_curve, auc, confusion_matrix, classification_report,
    roc_curve, precision_recall_fscore_support, accuracy_score, 
    balanced_accuracy_score
)
from sklearn.preprocessing import label_binarize
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from typing import List, Tuple, Dict
from itertools import cycle
import warnings
warnings.filterwarnings('ignore')

# Install albumentations if not available
try:
    import albumentations as A
    print("✓ Albumentations loaded successfully")
except ImportError:
    import importlib
    import subprocess

    print("Installing albumentations...")
    # Install into the interpreter that is running this script (not whichever
    # `pip` happens to be first on PATH) and fail loudly if pip fails, instead
    # of ignoring the exit status the way os.system() did.
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', 'albumentations'])
    # Make the freshly installed package visible to the import system of this
    # already-running process.
    importlib.invalidate_caches()
    import albumentations as A
    print("✓ Albumentations installed and loaded")

# ==============================================================================
# CONFIGURATION - SET THESE BEFORE RUNNING
# ==============================================================================

# --- Stages to run -----------------------------------------------------------
TRAIN_IDENTIFIER = True     # False -> skip identifier training (load existing model)
TRAIN_CLASSIFIER = True     # False -> skip classifier training (load existing model)
RUN_EVALUATION = True       # False -> skip the evaluation stage

# --- Resuming ------------------------------------------------------------------
# Only meaningful when the matching TRAIN_* flag is True.
RESUME_IDENTIFIER = False   # True -> start identifier training from saved weights
RESUME_CLASSIFIER = False   # True -> start classifier training from saved weights

# --- Saved model files ---------------------------------------------------------
IDENTIFIER_MODEL_PATH = 'identifier_model.h5'
CLASSIFIER_MODEL_PATH = 'best_classifier.h5'

# --- Epoch budgets -------------------------------------------------------------
IDENTIFIER_EPOCHS = 15
CLASSIFIER_EPOCHS = 50

# --- Dataset location (change when running outside this Kaggle competition) ----
DATA_ROOT = '/kaggle/input/2025-karyogram-cv-camp/2025_Karyogram_CV_Camp'

# ==============================================================================
# DEVICE SETUP - TPU v5e-8 OPTIMIZED
# ==============================================================================

# Reset Keras' global state (e.g. models/layers left behind by an earlier run
# of this notebook cell) before the distribution strategy is created below.
tf.keras.backend.clear_session()

def setup_device_strategy():
    """
    Choose a tf.distribute strategy, trying TPU first, then GPU(s), then CPU.

    TPU v5e-8 (Kaggle) notes:
    - 8 cores with improved performance over v3-8
    - Better memory efficiency
    - Recommended batch sizes per core:
      * Identifier: 4-8 per core
      * Classifier: 128-256 per core

    Returns:
        tuple: ``(strategy, device_type)`` where ``device_type`` is one of
        ``"TPU"``, ``"Multi-GPU"``, ``"Single-GPU"`` or ``"CPU"``.
    """
    print("="*60)
    print("DEVICE SETUP")
    print("="*60)

    # ==================== TPU DETECTION ====================
    try:
        print("Attempting TPU detection...")

        # Resolve the TPU, connect to it, initialise it, then wrap it in a strategy.
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        strategy = tf.distribute.TPUStrategy(tpu)
        replicas = strategy.num_replicas_in_sync

        print("✓ TPU detected and initialized")
        print(f"✓ TPU cores: {replicas}")
        # NOTE: this label is hard-coded, it is not read from the hardware.
        print("✓ TPU type: v5e-8 (optimized)")
        print("✓ Recommended batch sizes:")
        print(f"  - Identifier: {4 * replicas}-{8 * replicas} total (4-8 per core)")
        print(f"  - Classifier: {128 * replicas}-{256 * replicas} total (128-256 per core)")

        return strategy, "TPU"

    except Exception as e:
        # Any failure while resolving/connecting/initialising (ValueError,
        # tf.errors.NotFoundError, ...) means "no usable TPU": fall through.
        print(f"✗ No TPU detected: {type(e).__name__}")

    # ==================== GPU DETECTION ====================
    gpus = tf.config.list_physical_devices('GPU')
    if gpus:
        try:
            # Allocate GPU memory on demand instead of grabbing it all up front.
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)

            print(f"✓ GPU(s) detected: {len(gpus)} device(s)")
            for i, gpu in enumerate(gpus):
                print(f"  GPU {i}: {gpu.name}")

            if len(gpus) > 1:
                strategy = tf.distribute.MirroredStrategy()
                print(f"✓ Using MirroredStrategy with {strategy.num_replicas_in_sync} GPUs")
                return strategy, "Multi-GPU"

            print("✓ Using single GPU")
            return tf.distribute.get_strategy(), "Single-GPU"

        except RuntimeError as e:
            # set_memory_growth raises RuntimeError once GPUs are initialised.
            print(f"✗ GPU setup failed: {e}")
    else:
        print("✗ No GPU detected")

    # ==================== CPU FALLBACK ====================
    print("✓ Falling back to CPU")
    strategy = tf.distribute.get_strategy()
    print("⚠ Warning: Training on CPU will be significantly slower")
    print("  Consider enabling GPU or TPU in Kaggle notebook settings")

    return strategy, "CPU"

STRATEGY, DEVICE_TYPE = setup_device_strategy()

# Summary banner describing the device configuration that was selected.
_RULE = '=' * 60
print(f"\n{_RULE}")
print(f"CONFIGURATION: {DEVICE_TYPE}")
print(_RULE)
print(f"Devices: {STRATEGY.num_replicas_in_sync}")
print(f"TensorFlow: {tf.__version__}")
print(f"{_RULE}\n")
del _RULE

# ==============================================================================
# SHARED UTILITIES
# ==============================================================================

class AnnotationParser:
    """Parse Pascal-VOC-style XML annotation files.

    Reads the ``<filename>``, ``<size>`` and every ``<object>`` (``<name>`` +
    ``<bndbox>``) element of a single annotation file.
    """

    _BBOX_KEYS = ('xmin', 'ymin', 'xmax', 'ymax')

    @staticmethod
    def parse_xml(xml_path: str) -> Dict:
        """Parse one annotation file.

        Args:
            xml_path: Path of the XML annotation file.

        Returns:
            A dict with keys:
              * ``'filename'``: text of ``<filename>`` ('' if absent or empty).
              * ``'size'``: ``{'width', 'height', 'depth'}`` as ints, or ``{}``
                when there is no ``<size>`` element. ``depth`` defaults to 3
                when missing or empty.
              * ``'objects'``: one ``{'name': str, 'bbox': {...}}`` per
                ``<object>`` that has a ``<bndbox>``. The name is stripped of
                surrounding whitespace ('' if missing). Coordinates are parsed
                as floats and truncated to int.

        Raises:
            xml.etree.ElementTree.ParseError: if the file is not well-formed.
            TypeError / ValueError: if a present ``<size>`` or ``<bndbox>``
                lacks a required value or holds a non-numeric one.
        """
        root = ET.parse(xml_path).getroot()

        annotations = {
            # findtext() does a single lookup and yields '' for a missing or empty tag.
            'filename': root.findtext('filename', default=''),
            'size': {},
            'objects': []
        }

        size = root.find('size')
        if size is not None:
            depth = size.findtext('depth')
            annotations['size'] = {
                'width': int(size.findtext('width')),
                'height': int(size.findtext('height')),
                'depth': int(depth) if depth and depth.strip() else 3
            }

        for obj in root.findall('object'):
            bndbox = obj.find('bndbox')
            if bndbox is None:
                continue

            # Strip whitespace so labels such as ' X ' still match class names.
            name = (obj.findtext('name') or '').strip()
            bbox = {
                key: int(float(bndbox.findtext(key)))
                for key in AnnotationParser._BBOX_KEYS
            }
            annotations['objects'].append({
                'name': name,
                'bbox': bbox
            })

        return annotations


# ==============================================================================
# IDENTIFIER MODULE
# ==============================================================================

class ChromosomeIdentifierDataGenerator(keras.utils.Sequence):
    """Batch generator of (image, boxes) pairs for training the identifier.

    Each sample is a full image resized to ``img_size`` x ``img_size`` and
    scaled to [0, 1], paired with an ``(N, 4)`` array of boxes in
    ``[ymin, xmin, ymax, xmax]`` order, normalised by the original image size.
    Batches are Python lists because ``N`` varies from image to image.
    """

    def __init__(self, images_dir: str, annotations_dir: str,
                 batch_size=2, img_size=800, shuffle=True):
        # Keras 3 PyDataset (the class behind keras.utils.Sequence) expects
        # its initialiser to run; on older tf.keras this is a harmless no-op.
        super().__init__()
        self.images_dir = images_dir
        self.annotations_dir = annotations_dir
        self.batch_size = batch_size
        self.img_size = img_size
        self.shuffle = shuffle
        self.parser = AnnotationParser()

        # dict.fromkeys de-duplicates while keeping directory-listing order, so
        # an image present as both 'x.jpg' and 'x.png' is used once per epoch.
        candidates = dict.fromkeys(
            os.path.splitext(filename)[0]
            for filename in os.listdir(images_dir)
            if filename.endswith(('.jpg', '.png'))
        )
        self.image_list = [
            base_name for base_name in candidates
            if os.path.exists(os.path.join(annotations_dir, f"{base_name}.xml"))
        ]

        if len(self.image_list) == 0:
            raise ValueError(f"No valid image-annotation pairs found in {images_dir}")

        print(f"Found {len(self.image_list)} images for identifier")

        self.indexes = np.arange(len(self.image_list))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __len__(self):
        return int(np.ceil(len(self.image_list) / self.batch_size))

    def _batch_names(self, index):
        """Return the image base names that make up batch ``index``."""
        batch_indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
        return [self.image_list[k] for k in batch_indexes]

    def __getitem__(self, index):
        X, y = self._generate_data(self._batch_names(index))

        # A batch can come back empty when every image in it failed to load.
        # Move forward to the next batch until one yields data; if all remaining
        # batches are empty, the (empty) result is returned as-is.
        next_index = index + 1
        while len(X) == 0 and next_index < len(self):
            X, y = self._generate_data(self._batch_names(next_index))
            next_index += 1

        return X, y

    def _generate_data(self, batch_images):
        """Load, resize and normalise images and their boxes for one batch.

        Images that are missing, unreadable or whose annotation cannot be
        parsed are skipped (with a warning for parse/processing errors).
        """
        images = []
        targets = []

        for img_name in batch_images:
            # Prefer the .jpg file; fall back to .png.
            img_path = os.path.join(self.images_dir, f"{img_name}.jpg")
            if not os.path.exists(img_path):
                img_path = os.path.join(self.images_dir, f"{img_name}.png")
            xml_path = os.path.join(self.annotations_dir, f"{img_name}.xml")

            if not os.path.exists(img_path) or not os.path.exists(xml_path):
                continue

            try:
                image = cv2.imread(img_path)
                if image is None:
                    continue

                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                orig_h, orig_w = image.shape[:2]

                image = cv2.resize(image, (self.img_size, self.img_size))
                image = image.astype(np.float32) / 255.0

                annotations = self.parser.parse_xml(xml_path)

                # Normalise by the ORIGINAL size so boxes stay valid after resize.
                boxes = [
                    [obj['bbox']['ymin'] / orig_h, obj['bbox']['xmin'] / orig_w,
                     obj['bbox']['ymax'] / orig_h, obj['bbox']['xmax'] / orig_w]
                    for obj in annotations['objects']
                ]

                images.append(image)
                # reshape keeps an image without objects as a (0, 4) array.
                targets.append(np.asarray(boxes, dtype=np.float32).reshape(-1, 4))
            except Exception as e:
                print(f"Warning: Error processing {img_name}: {e}")
                continue

        return images, targets

    def on_epoch_end(self):
        if self.shuffle:
            np.random.shuffle(self.indexes)


class ChromosomeIdentifier:
    """Chromosome detector: ResNet50 backbone with a fixed-slot box head.

    The model outputs ``max_detections`` boxes, flattened as
    ``[ymin, xmin, ymax, xmax] * max_detections`` in normalised coordinates
    (output ``'bbox'``), plus one sigmoid objectness score per slot (output
    ``'objectness'``).
    """

    def __init__(self, img_size=800, max_detections=50):
        self.img_size = img_size
        self.max_detections = max_detections
        self.model = None

    def _build_model(self):
        """Build the (uncompiled) detection model."""
        inputs = layers.Input(shape=(self.img_size, self.img_size, 3))

        base_model = ResNet50(
            include_top=False,
            weights='imagenet',
            input_tensor=inputs,
            pooling=None
        )

        # Freeze the early backbone layers; fine-tune the rest.
        for layer in base_model.layers[:100]:
            layer.trainable = False

        x = base_model.output
        x = layers.Conv2D(512, 3, padding='same')(x)
        x = layers.ReLU()(x)
        x = layers.BatchNormalization()(x)

        x = layers.Conv2D(512, 3, padding='same')(x)
        x = layers.ReLU()(x)
        x = layers.BatchNormalization()(x)

        x = layers.Conv2D(256, 3, padding='same')(x)
        x = layers.ReLU()(x)
        x = layers.BatchNormalization()(x)

        x = layers.GlobalAveragePooling2D()(x)
        x = layers.Dense(1024)(x)
        x = layers.ReLU()(x)
        x = layers.Dropout(0.5)(x)

        bbox_output = layers.Dense(self.max_detections * 4, name='bbox')(x)
        objectness_output = layers.Dense(self.max_detections)(x)
        objectness_output = layers.Activation('sigmoid', name='objectness')(objectness_output)

        model = models.Model(inputs=inputs, outputs=[bbox_output, objectness_output])
        return model

    @staticmethod
    def _scope(strategy):
        """Return ``strategy.scope()``, or a no-op context when no strategy is given."""
        from contextlib import nullcontext
        return strategy.scope() if strategy is not None else nullcontext()

    def _compile(self):
        """Compile ``self.model`` with the identifier optimizer and losses."""
        self.model.compile(
            optimizer=optimizers.Adam(learning_rate=0.0005),
            loss={'bbox': 'mse', 'objectness': 'binary_crossentropy'},
            loss_weights={'bbox': 1.0, 'objectness': 0.5}
        )

    def load_weights(self, path, strategy=None):
        """Load a saved identifier model and recompile it.

        The saved optimizer state is not restored (``compile=False``); a fresh
        optimizer is attached.

        Returns:
            True if ``path`` existed and was loaded, False if it does not exist.
        """
        if not os.path.exists(path):
            print(f"✗ Weights file not found: {path}")
            return False

        print(f"✓ Loading identifier weights from {path}")
        with self._scope(strategy):
            self.model = keras.models.load_model(path, compile=False)
            self._compile()
        print("✓ Identifier weights loaded successfully")
        return True

    def train(self, train_gen, epochs=15, strategy=None, resume=False, weights_path=None):
        """Train the identifier with a manual batch loop.

        Args:
            train_gen: indexable batch source returning ``(images, targets)``
                where ``targets[i]`` is an ``(N, 4)`` box array.
            epochs: number of passes over ``train_gen``.
            strategy: optional ``tf.distribute`` strategy to build/load under.
            resume: start from ``weights_path`` when it exists.
            weights_path: saved model to resume from.

        Returns:
            The trained Keras model.
        """
        print(f"\n{'='*60}")
        print("IDENTIFIER TRAINING")
        print(f"{'='*60}")

        loaded = False
        if resume and weights_path:
            loaded = self.load_weights(weights_path, strategy)
            if loaded:
                print("✓ Resuming training from epoch 0 (with pretrained weights)")
            else:
                print("⚠ Resume requested but no saved weights found; training from scratch")

        if not loaded:
            with self._scope(strategy):
                self.model = self._build_model()
                self._compile()

        for epoch in range(epochs):
            print(f"\nEpoch {epoch+1}/{epochs}")
            epoch_losses = []

            for batch_idx in range(len(train_gen)):
                try:
                    images, targets = train_gen[batch_idx]

                    if len(images) == 0:
                        continue

                    images = np.array(images, dtype=np.float32)
                    batch_size = len(images)

                    # Pack the variable number of boxes into the fixed slots:
                    # unused slots keep zero coordinates and objectness 0.
                    target_bbox_batch = np.zeros((batch_size, self.max_detections * 4), dtype=np.float32)
                    obj_target_batch = np.zeros((batch_size, self.max_detections), dtype=np.float32)

                    for i, target in enumerate(targets):
                        if len(target) > 0:
                            num_objs = min(len(target), self.max_detections)
                            target_flat = target[:num_objs].flatten()
                            target_bbox_batch[i, :len(target_flat)] = target_flat
                            obj_target_batch[i, :num_objs] = 1.0

                    loss = self.model.train_on_batch(
                        images,
                        {'bbox': target_bbox_batch, 'objectness': obj_target_batch}
                    )

                    # Multi-output models report [total, per-output...]; keep the total.
                    epoch_losses.append(loss[0] if isinstance(loss, list) else loss)

                    if (batch_idx + 1) % 50 == 0:
                        print(f"  Batch {batch_idx+1}/{len(train_gen)}, Loss: {np.mean(epoch_losses[-50:]):.4f}")
                except Exception as e:
                    print(f"Warning: Error in batch {batch_idx}: {e}")
                    continue

            if len(epoch_losses) > 0:
                print(f"Epoch {epoch+1} Loss: {np.mean(epoch_losses):.4f}")
            else:
                print(f"Warning: No valid batches in epoch {epoch+1}")

            # Model.fit() would call this hook; the manual loop must do it
            # itself, otherwise the batch order never changes between epochs.
            if hasattr(train_gen, 'on_epoch_end'):
                train_gen.on_epoch_end()

        return self.model

    def predict_boxes(self, image, confidence_threshold=0.5):
        """Detect chromosomes in one RGB image.

        Args:
            image: RGB image array of shape ``(H, W, 3)`` (scaled by 1/255 here).
            confidence_threshold: minimum objectness for a slot to be kept.

        Returns:
            ``(boxes, scores)``:
              * ``boxes`` is an int array of shape ``(K, 4)`` with
                ``[xmin, ymin, xmax, ymax]`` in original-image pixels.
              * ``scores[k]`` is the objectness of ``boxes[k]``.
            Degenerate boxes (zero width or height after clipping) are dropped
            together with their score.
        """
        orig_h, orig_w = image.shape[:2]
        image_resized = cv2.resize(image, (self.img_size, self.img_size))
        image_resized = image_resized.astype(np.float32) / 255.0
        image_resized = np.expand_dims(image_resized, axis=0)

        bbox_pred, obj_pred = self.model.predict(image_resized, verbose=0)

        bbox_pred = bbox_pred.reshape(self.max_detections, 4)
        obj_pred = obj_pred.reshape(self.max_detections)

        keep = obj_pred > confidence_threshold

        final_boxes = []
        final_scores = []
        for (ymin, xmin, ymax, xmax), score in zip(bbox_pred[keep], obj_pred[keep]):
            xmin = int(np.clip(xmin * orig_w, 0, orig_w))
            ymin = int(np.clip(ymin * orig_h, 0, orig_h))
            xmax = int(np.clip(xmax * orig_w, 0, orig_w))
            ymax = int(np.clip(ymax * orig_h, 0, orig_h))

            if xmax > xmin and ymax > ymin:
                final_boxes.append([xmin, ymin, xmax, ymax])
                # Collect the score with its box so the two stay aligned even
                # when an earlier slot was dropped as degenerate.
                final_scores.append(score)

        return (np.array(final_boxes, dtype=int).reshape(-1, 4),
                np.array(final_scores, dtype=np.float32))


# ==============================================================================
# CLASSIFIER MODULE
# ==============================================================================

class ChromosomeClassifierDataGenerator(keras.utils.Sequence):
    """Batch generator of (chromosome crop, class id) pairs for the classifier.

    When built from disk, every annotated object whose name is a key of
    ``CHROMOSOME_CLASSES`` is cropped out of its image. A training generator
    keeps a fixed, seeded train/validation split. ``get_validation_generator()``
    returns a generator over the held-out part.
    """

    CHROMOSOME_CLASSES = {
        '1': 0, '2': 1, '3': 2, '4': 3, '5': 4, '6': 5,
        '7': 6, '8': 7, '9': 8, '10': 9, '11': 10, '12': 11,
        '13': 12, '14': 13, '15': 14, '16': 15, '17': 16, '18': 17,
        '19': 18, '20': 19, '21': 20, '22': 21, 'X': 22, 'Y': 23
    }

    # ImageNet channel statistics used to normalise crops for the ResNet50 backbone.
    _IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
    _IMAGENET_STD = np.array([0.229, 0.224, 0.225])

    def __init__(self, images_dir: str, annotations_dir: str, is_validation=False,
                 batch_size=32, img_size=224, augment=False, validation_split=0.2,
                 data=None, indexes=None):
        # Keras 3 PyDataset (the class behind keras.utils.Sequence) expects
        # its initialiser to run; on older tf.keras this is a harmless no-op.
        super().__init__()
        self.images_dir = images_dir
        self.annotations_dir = annotations_dir
        self.batch_size = batch_size
        self.img_size = img_size
        self.augment = augment
        self.validation_split = validation_split
        self.is_validation = is_validation
        self.parser = AnnotationParser()

        if data is not None and indexes is not None:
            # Pre-built samples, e.g. the held-out split from get_validation_generator().
            self.chromosome_data = data
            self.indexes = indexes
            self.val_indexes = []
        else:
            self.chromosome_data = []
            self._build_dataset()

            if len(self.chromosome_data) == 0:
                raise ValueError("No chromosome data found!")

            if self.is_validation:
                # Stand-alone validation generator: use every sample on disk.
                self.indexes = np.arange(len(self.chromosome_data))
                self.val_indexes = []
            else:
                total_samples = len(self.chromosome_data)
                val_size = int(total_samples * validation_split)

                # A private RNG seeded with 42 gives the same split as seeding
                # numpy's global RNG, without resetting global random state
                # for the rest of the program.
                all_indexes = np.arange(total_samples)
                np.random.RandomState(42).shuffle(all_indexes)

                self.indexes = all_indexes[val_size:]
                self.val_indexes = all_indexes[:val_size]

                print(f"Train: {len(self.indexes)}, Val: {len(self.val_indexes)}")

        self._print_class_distribution()

        if augment and not self.is_validation:
            # NOTE(review): newer albumentations releases replace GaussNoise's
            # `var_limit` with `std_range`; confirm against the installed version.
            self.aug = A.Compose([
                A.HorizontalFlip(p=0.5),
                A.VerticalFlip(p=0.5),
                A.Rotate(limit=180, p=0.7),
                A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),
                A.GaussNoise(var_limit=(10, 50), p=0.3),
                A.GaussianBlur(blur_limit=(3, 5), p=0.2),
                A.CLAHE(clip_limit=2.0, p=0.3),
            ])
        else:
            self.aug = None

    def _build_dataset(self):
        """Crop every labelled chromosome out of every annotated image."""
        print("\nBuilding classifier dataset...")

        image_files = []
        for filename in os.listdir(self.images_dir):
            if filename.endswith(('.jpg', '.png')):
                base_name = os.path.splitext(filename)[0]
                xml_path = os.path.join(self.annotations_dir, f"{base_name}.xml")
                if os.path.exists(xml_path):
                    image_files.append(base_name)

        print(f"Processing {len(image_files)} images...")

        for idx, img_name in enumerate(image_files):
            if (idx + 1) % 100 == 0:
                print(f"  Processed {idx+1}/{len(image_files)}...")

            img_path = os.path.join(self.images_dir, f"{img_name}.jpg")
            xml_path = os.path.join(self.annotations_dir, f"{img_name}.xml")

            if not os.path.exists(img_path):
                img_path = os.path.join(self.images_dir, f"{img_name}.png")

            if not os.path.exists(img_path) or not os.path.exists(xml_path):
                continue

            try:
                image = cv2.imread(img_path)
                if image is None:
                    continue

                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

                annotations = self.parser.parse_xml(xml_path)

                for obj in annotations['objects']:
                    label = obj['name']
                    if label not in self.CHROMOSOME_CLASSES:
                        continue

                    bbox = obj['bbox']
                    # Clamp at 0: a negative start/stop would wrap around in
                    # numpy slicing and produce a wrong crop. Upper bounds are
                    # clipped by slicing itself.
                    xmin, ymin = max(bbox['xmin'], 0), max(bbox['ymin'], 0)
                    xmax, ymax = max(bbox['xmax'], 0), max(bbox['ymax'], 0)

                    # copy() so each stored crop does not keep the whole
                    # source image alive in memory.
                    chromosome_img = image[ymin:ymax, xmin:xmax].copy()

                    if chromosome_img.size > 0:
                        self.chromosome_data.append({
                            'image': chromosome_img,
                            'label': self.CHROMOSOME_CLASSES[label],
                            'label_name': label
                        })
            except Exception as e:
                print(f"Warning: Error processing {img_name}: {e}")
                continue

        print(f"Dataset: {len(self.chromosome_data)} samples")

    def _print_class_distribution(self):
        """Print how many samples of each class ``chromosome_data`` holds."""
        if len(self.chromosome_data) == 0:
            return

        id_to_name = {v: k for k, v in self.CHROMOSOME_CLASSES.items()}
        class_counts = {}
        for sample in self.chromosome_data:
            class_counts[sample['label']] = class_counts.get(sample['label'], 0) + 1

        print("\nClass Distribution:")
        for class_id in sorted(class_counts):
            print(f"  Class {id_to_name[class_id]:>2}: {class_counts[class_id]:5d} samples")

    def get_validation_generator(self):
        """Create a non-augmented generator over this generator's held-out split."""
        val_gen = ChromosomeClassifierDataGenerator(
            images_dir=self.images_dir,
            annotations_dir=self.annotations_dir,
            batch_size=self.batch_size,
            img_size=self.img_size,
            augment=False,
            validation_split=0.0,
            is_validation=True,
            data=[self.chromosome_data[i] for i in self.val_indexes],
            indexes=np.arange(len(self.val_indexes))
        )
        return val_gen

    def __len__(self):
        return int(np.ceil(len(self.indexes) / self.batch_size))

    def _load_batch(self, index):
        """Augment, resize and normalise the samples of batch ``index``.

        Samples that raise during preprocessing are skipped with a warning.
        """
        batch_indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

        batch_images = []
        batch_labels = []

        for idx in batch_indexes:
            try:
                data = self.chromosome_data[idx]
                image = data['image'].copy()

                if self.aug is not None:
                    image = self.aug(image=image)['image']

                image = cv2.resize(image, (self.img_size, self.img_size))
                image = image.astype(np.float32) / 255.0
                image = (image - self._IMAGENET_MEAN) / self._IMAGENET_STD
            except Exception as e:
                print(f"Warning: Skipping classifier sample {idx}: {e}")
                continue

            batch_images.append(image)
            batch_labels.append(data['label'])

        return batch_images, batch_labels

    def __getitem__(self, index):
        batch_images, batch_labels = self._load_batch(index)

        # If every sample failed, move forward to the next batch that yields data.
        next_index = index + 1
        while len(batch_images) == 0 and next_index < len(self):
            batch_images, batch_labels = self._load_batch(next_index)
            next_index += 1

        # Last resort: Keras cannot consume an empty batch, so emit one blank
        # image. NOTE(review): it is labelled class 0 (chromosome '1'), i.e. a
        # fabricated sample; acceptable only because this path should be rare.
        if len(batch_images) == 0:
            batch_images = [np.zeros((self.img_size, self.img_size, 3), dtype=np.float32)]
            batch_labels = [0]

        return np.array(batch_images, dtype=np.float32), np.array(batch_labels, dtype=np.int32)

    def on_epoch_end(self):
        if not self.is_validation:
            np.random.shuffle(self.indexes)


class ChromosomeClassifier:
    """ResNet50-based classifier for single-chromosome crops.

    With the default ``num_classes=24`` the class ids follow
    ``ChromosomeClassifierDataGenerator.CHROMOSOME_CLASSES`` (1-22, X, Y).
    """

    def __init__(self, num_classes=24, img_size=224):
        self.num_classes = num_classes
        self.img_size = img_size
        self.model = None

    def _build_model(self):
        """Build the (uncompiled) classification model."""
        base_model = ResNet50(
            include_top=False,
            weights='imagenet',
            input_shape=(self.img_size, self.img_size, 3)
        )

        # Freeze the first 120 backbone layers; fine-tune the rest.
        for layer in base_model.layers[:120]:
            layer.trainable = False
        for layer in base_model.layers[120:]:
            layer.trainable = True

        x = base_model.output
        x = layers.GlobalAveragePooling2D()(x)

        x = layers.Dense(1024)(x)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)
        x = layers.Dropout(0.5)(x)

        x = layers.Dense(512)(x)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)
        x = layers.Dropout(0.4)(x)

        x = layers.Dense(256)(x)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)
        x = layers.Dropout(0.3)(x)

        outputs = layers.Dense(self.num_classes)(x)
        outputs = layers.Activation('softmax')(outputs)

        model = models.Model(inputs=base_model.input, outputs=outputs)
        return model

    @staticmethod
    def _scope(strategy):
        """Return ``strategy.scope()``, or a no-op context when no strategy is given."""
        from contextlib import nullcontext
        return strategy.scope() if strategy is not None else nullcontext()

    def _compile(self, lr):
        """Compile ``self.model`` for sparse integer labels with learning rate ``lr``."""
        self.model.compile(
            optimizer=optimizers.Adam(lr),
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy']
        )

    def load_weights(self, path, strategy=None, lr=0.001):
        """Load a saved classifier and recompile it with learning rate ``lr``.

        The saved optimizer state is not restored (``compile=False``).

        Returns:
            True if ``path`` existed and was loaded, False if it does not exist.
        """
        if not os.path.exists(path):
            print(f"✗ Weights file not found: {path}")
            return False

        print(f"✓ Loading classifier weights from {path}")
        with self._scope(strategy):
            self.model = keras.models.load_model(path, compile=False)
            self._compile(lr)
        print("✓ Classifier weights loaded successfully")
        return True

    def compile_model(self, lr=0.001, strategy=None, resume=False, weights_path=None):
        """Create (or, with ``resume``, load) and compile the classifier.

        ``lr`` is honoured in both paths: the resumed model is recompiled with
        it rather than with a fixed default.
        """
        if resume and weights_path:
            if self.load_weights(weights_path, strategy, lr=lr):
                print("✓ Resuming from checkpoint")
                return
            print("⚠ Resume requested but no checkpoint found; building a new model")

        with self._scope(strategy):
            self.model = self._build_model()
            self._compile(lr)

    def train(self, train_gen, val_gen, epochs=50, checkpoint_path='best_classifier.h5'):
        """Fit the classifier, keeping the best ``val_accuracy`` model on disk.

        Args:
            train_gen: training batch source passed to ``Model.fit``.
            val_gen: validation batch source.
            epochs: maximum number of epochs (early stopping may end sooner).
            checkpoint_path: where the best model is saved.

        Returns:
            The Keras ``History`` object returned by ``fit``.
        """
        callbacks = [
            keras.callbacks.ReduceLROnPlateau(
                monitor='val_accuracy',
                factor=0.5,
                patience=5,
                min_lr=1e-7,
                verbose=1
            ),
            keras.callbacks.ModelCheckpoint(
                checkpoint_path,
                monitor='val_accuracy',
                save_best_only=True,
                verbose=1
            ),
            keras.callbacks.EarlyStopping(
                monitor='val_accuracy',
                patience=10,
                restore_best_weights=True,
                verbose=1
            )
        ]

        history = self.model.fit(
            train_gen,
            validation_data=val_gen,
            epochs=epochs,
            callbacks=callbacks,
            verbose=1
        )
        return history


# ==============================================================================
# PIPELINE AND EVALUATION
# ==============================================================================

class KaryotypePipeline:
    """End-to-end pipeline: detect chromosomes with the identifier, then classify each crop.

    Every detected box is clipped to the image and cropped. Each crop is resized
    to ``img_size`` x ``img_size``, scaled to [0, 1] and normalized with the
    per-channel mean/std below. All crops of one image are then scored by the
    classifier in a single batched ``predict`` call.
    """

    # Per-channel normalization constants applied to RGB crops.
    _MEAN = np.array([0.485, 0.456, 0.406])
    _STD = np.array([0.229, 0.224, 0.225])

    def __init__(self, identifier: 'ChromosomeIdentifier',
                 classifier: 'ChromosomeClassifier'):
        self.identifier = identifier
        self.classifier = classifier
        self.img_size = 224

    def process_image(self, image_path: str, confidence_threshold=0.5):
        """Detect and classify every chromosome in the image at ``image_path``.

        Args:
            image_path: Path readable by ``cv2.imread``.
            confidence_threshold: Forwarded to ``identifier.predict_boxes``.

        Returns:
            A list with one dict per non-empty crop, with keys ``'bbox'``
            (the box as returned by the identifier), ``'class_id'`` (int),
            ``'class_name'`` (str), ``'confidence'`` (the classifier's max
            probability) and ``'score'`` (the detector score, or 1.0 if no
            score is available for that box). Returns ``[]`` if the image
            cannot be read or nothing is detected.
        """
        image = cv2.imread(image_path)
        if image is None:
            print(f"Warning: Could not read image {image_path}")
            return []

        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        boxes, scores = self.identifier.predict_boxes(image_rgb, confidence_threshold=confidence_threshold)

        if len(boxes) == 0:
            return []

        height, width = image_rgb.shape[:2]
        crops = []
        kept = []  # (original index, box) for each box that produced a non-empty crop

        for i, box in enumerate(boxes):
            xmin, ymin, xmax, ymax = self._clip_box(box, height, width)
            crop = image_rgb[ymin:ymax, xmin:xmax]
            if crop.size == 0:
                continue
            crop = cv2.resize(crop, (self.img_size, self.img_size))
            crops.append(self._normalize(crop))
            kept.append((i, box))

        if not crops:
            return []

        # One batched call instead of one predict() per chromosome; on
        # accelerators per-call overhead dominates for single images.
        batch = np.stack(crops).astype(np.float32)
        all_probs = self.classifier.model.predict(batch, verbose=0)

        predictions = []
        for (i, box), probs in zip(kept, all_probs):
            class_id = int(np.argmax(probs))
            predictions.append({
                'bbox': box,
                'class_id': class_id,
                'class_name': self._get_class_name(class_id),
                'confidence': float(probs[class_id]),
                'score': float(scores[i]) if i < len(scores) else 1.0
            })

        return predictions

    @staticmethod
    def _clip_box(box, height: int, width: int):
        """Truncate box coordinates to ints and clip them into the image bounds.

        Clipping matters because negative coordinates would otherwise wrap
        around in NumPy slicing and crop the wrong region.
        """
        xmin, ymin, xmax, ymax = (int(v) for v in box)
        xmin = min(max(xmin, 0), width)
        xmax = min(max(xmax, 0), width)
        ymin = min(max(ymin, 0), height)
        ymax = min(max(ymax, 0), height)
        return xmin, ymin, xmax, ymax

    @classmethod
    def _normalize(cls, crop):
        """Scale an 8-bit RGB crop to [0, 1], then apply per-channel mean/std."""
        scaled = crop.astype(np.float32) / 255.0
        return (scaled - cls._MEAN) / cls._STD

    @staticmethod
    def _get_class_name(class_id: int) -> str:
        """Map a class index to its label: 0-21 -> '1'..'22', 22 -> 'X', otherwise 'Y'."""
        if class_id < 22:
            return str(class_id + 1)
        elif class_id == 22:
            return 'X'
        else:
            return 'Y'


def evaluate_pipeline_simple(pipeline, images_dir, annotations_dir, test_list, output_dir='results'):
    """Image-level one-vs-rest auPRC evaluation of the full pipeline.

    For each test image (``<name>.jpg``, falling back to ``<name>.png``, plus
    ``<name>.xml``), the pipeline is run once. A 24-dim score vector is built
    holding, per class, the max classifier confidence among that image's
    predictions. Every ground-truth chromosome in the image whose name is in
    ``CHROMOSOME_CLASSES`` contributes one row: a one-hot true label paired
    with that image's score vector.

    NOTE(review): all ground-truth objects of one image share the same score
    vector, so this measures class presence per image, not per-box accuracy.

    Writes ``auprc_curves.png`` and ``auprc_scores.txt`` into ``output_dir``.

    Returns:
        ``(mean_auprc, auprc_scores)``. ``auprc_scores`` lists the auPRC of each
        class that has at least one positive, in class-index order.
        ``(0.0, [])`` is returned when no usable samples are found.
    """
    os.makedirs(output_dir, exist_ok=True)
    parser = AnnotationParser()
    class_map = ChromosomeClassifierDataGenerator.CHROMOSOME_CLASSES
    num_classes = 24

    all_true_labels = []
    all_pred_probs = []

    print("Evaluating pipeline...")
    for idx, img_name in enumerate(test_list):
        img_name = img_name.strip()
        img_path = os.path.join(images_dir, f"{img_name}.jpg")
        xml_path = os.path.join(annotations_dir, f"{img_name}.xml")

        if not os.path.exists(img_path):
            img_path = os.path.join(images_dir, f"{img_name}.png")

        if not os.path.exists(img_path) or not os.path.exists(xml_path):
            continue

        if (idx + 1) % 10 == 0:
            print(f"  Processed {idx+1}/{len(test_list)}...")

        try:
            annotations = parser.parse_xml(xml_path)
            gt_chromosomes = [obj['name'] for obj in annotations['objects']]

            predictions = pipeline.process_image(img_path)

            # The score vector depends only on the image's predictions, so
            # build it once instead of once per ground-truth object.
            image_probs = np.zeros(num_classes)
            for pred in predictions:
                cid = pred['class_id']
                image_probs[cid] = max(image_probs[cid], pred['confidence'])

            image_rows = []
            for gt_chr in gt_chromosomes:
                if gt_chr not in class_map:
                    continue
                true_labels = np.zeros(num_classes)
                true_labels[class_map[gt_chr]] = 1
                image_rows.append(true_labels)
        except Exception as e:
            print(f"Warning: Error processing {img_name}: {e}")
            continue

        # Append only after the whole image succeeded, so a failure never
        # leaves partial rows behind. np.array() below copies image_probs,
        # so sharing the same vector across rows is safe.
        all_true_labels.extend(image_rows)
        all_pred_probs.extend([image_probs] * len(image_rows))

    if len(all_true_labels) == 0:
        print("Warning: No valid samples found for evaluation!")
        return 0.0, []

    all_true_labels = np.array(all_true_labels)
    all_pred_probs = np.array(all_pred_probs)

    CLASS_NAMES = [str(i+1) for i in range(22)] + ['X', 'Y']
    per_class = []  # (class name, auPRC) for classes with at least one positive

    plt.figure(figsize=(15, 10))
    try:
        for i, class_name in enumerate(CLASS_NAMES):
            if not all_true_labels[:, i].any():
                continue
            precision, recall, _ = precision_recall_curve(
                all_true_labels[:, i],
                all_pred_probs[:, i]
            )
            auprc = auc(recall, precision)
            per_class.append((class_name, auprc))

            plt.plot(recall, precision,
                     label=f'Class {class_name} (auPRC={auprc:.3f})')

        plt.xlabel('Recall', fontsize=12)
        plt.ylabel('Precision', fontsize=12)
        plt.title('Precision-Recall Curves', fontsize=14)
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, 'auprc_curves.png'),
                    dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close()

    auprc_scores = [score for _, score in per_class]
    mean_auprc = np.mean(auprc_scores) if len(auprc_scores) > 0 else 0.0
    print(f"\nMean auPRC: {mean_auprc:.4f}")

    with open(os.path.join(output_dir, 'auprc_scores.txt'), 'w') as f:
        f.write(f"Mean auPRC: {mean_auprc:.4f}\n\n")
        f.write("Per-class auPRC scores:\n")
        for class_name, score in per_class:
            f.write(f"Class {class_name}: {score:.4f}\n")

    return mean_auprc, auprc_scores


def evaluate_comprehensive(pipeline, images_dir, annotations_dir, test_list, output_dir='evaluation_results'):
    """Compute accuracy, a confusion matrix and a classification report.

    For each test image (``<name>.jpg``, falling back to ``<name>.png``, plus
    ``<name>.xml``), the pipeline is run once with ``confidence_threshold=0.5``.
    Each ground-truth chromosome whose name is in ``CHROMOSOME_CLASSES`` is
    paired with the class of the image's single highest-confidence prediction.

    NOTE(review): because of that pairing, every ground-truth object in an
    image receives the same predicted class, so this is not a box-matched,
    per-object accuracy. Images with no predictions are recorded as class
    index 0 (chromosome '1').

    Writes ``confusion_matrix.png``, ``classification_report.txt`` and
    ``summary.txt`` into ``output_dir``.

    Returns:
        dict with ``'Overall Accuracy'`` (float) and ``'Classes Present'``
        (number of distinct true/predicted classes).
    """
    os.makedirs(output_dir, exist_ok=True)
    parser = AnnotationParser()
    class_map = ChromosomeClassifierDataGenerator.CHROMOSOME_CLASSES

    all_true_labels = []
    all_pred_labels = []

    CLASS_NAMES = [str(i+1) for i in range(22)] + ['X', 'Y']

    print("Computing comprehensive metrics...")
    for idx, img_name in enumerate(test_list):
        img_name = img_name.strip()
        img_path = os.path.join(images_dir, f"{img_name}.jpg")
        xml_path = os.path.join(annotations_dir, f"{img_name}.xml")

        if not os.path.exists(img_path):
            img_path = os.path.join(images_dir, f"{img_name}.png")

        if not os.path.exists(img_path) or not os.path.exists(xml_path):
            continue

        if (idx + 1) % 20 == 0:
            print(f"  Processed {idx+1}/{len(test_list)}...")

        try:
            annotations = parser.parse_xml(xml_path)
            predictions = pipeline.process_image(img_path, confidence_threshold=0.5)

            # The chosen prediction depends only on the image, so pick it once.
            if predictions:
                pred_class = max(predictions, key=lambda x: x['confidence'])['class_id']
            else:
                pred_class = 0

            image_true = [
                class_map[obj['name']]
                for obj in annotations['objects']
                if obj['name'] in class_map
            ]
        except Exception as e:
            print(f"Warning: Error processing {img_name}: {e}")
            continue

        # Append per image only after everything succeeded (no partial rows).
        all_true_labels.extend(image_true)
        all_pred_labels.extend([pred_class] * len(image_true))

    if len(all_true_labels) == 0:
        print("Warning: No valid samples found for comprehensive evaluation!")
        return {'Overall Accuracy': 0.0, 'Classes Present': 0}

    all_true_labels = np.array(all_true_labels)
    all_pred_labels = np.array(all_pred_labels)

    print(f"\nTotal samples: {len(all_true_labels)}")

    # Restrict the matrix/report to classes that actually occur.
    unique_classes = np.unique(np.concatenate([all_true_labels, all_pred_labels]))
    num_present = len(unique_classes)
    print(f"Unique classes: {num_present}")

    present_class_names = [CLASS_NAMES[i] for i in unique_classes]

    # Confusion Matrix
    cm = confusion_matrix(all_true_labels, all_pred_labels, labels=unique_classes)

    plt.figure(figsize=(max(12, num_present), max(10, num_present - 2)))
    try:
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=present_class_names, yticklabels=present_class_names,
                    cbar_kws={'label': 'Count'})
        plt.xlabel('Predicted Class', fontsize=14, fontweight='bold')
        plt.ylabel('True Class', fontsize=14, fontweight='bold')
        plt.title('Confusion Matrix', fontsize=16, fontweight='bold')
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, 'confusion_matrix.png'),
                    dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close()

    # Classification Report (computed once, written as text)
    report_text = classification_report(all_true_labels, all_pred_labels,
                                        labels=unique_classes,
                                        target_names=present_class_names,
                                        zero_division=0)

    with open(os.path.join(output_dir, 'classification_report.txt'), 'w') as f:
        f.write("="*80 + "\n")
        f.write("CLASSIFICATION REPORT\n")
        f.write("="*80 + "\n\n")
        f.write(report_text)

    overall_acc = accuracy_score(all_true_labels, all_pred_labels)

    summary = {
        'Overall Accuracy': overall_acc,
        'Classes Present': num_present
    }

    with open(os.path.join(output_dir, 'summary.txt'), 'w') as f:
        f.write("="*80 + "\n")
        f.write("SUMMARY\n")
        f.write("="*80 + "\n\n")
        f.write(f"Overall Accuracy: {overall_acc:.4f}\n")
        f.write(f"Classes Present: {num_present}\n")

    print(f"✓ Results saved to '{output_dir}/'")
    return summary


# ==============================================================================
# MAIN EXECUTION
# ==============================================================================

if __name__ == "__main__":

    print("\n" + "="*60)
    print("CHROMOSOME CLASSIFICATION PIPELINE - TPU v5e-8")
    print("="*60)
    print(f"Train Identifier: {TRAIN_IDENTIFIER}")
    print(f"Train Classifier: {TRAIN_CLASSIFIER}")
    print(f"Run Evaluation: {RUN_EVALUATION}")
    print(f"Resume Identifier: {RESUME_IDENTIFIER}")
    print(f"Resume Classifier: {RESUME_CLASSIFIER}")
    print("="*60 + "\n")

    # RESUME_* flags are only consulted inside the training branches below;
    # make it visible when they are set but will be ignored.
    if RESUME_IDENTIFIER and not TRAIN_IDENTIFIER:
        print("Warning: RESUME_IDENTIFIER=True has no effect because TRAIN_IDENTIFIER=False")
    if RESUME_CLASSIFIER and not TRAIN_CLASSIFIER:
        print("Warning: RESUME_CLASSIFIER=True has no effect because TRAIN_CLASSIFIER=False")

    # Setup paths
    SINGLE_CHR_IMAGES = os.path.join(DATA_ROOT, 'single_chromosomes_object', 'images')
    SINGLE_CHR_ANNOTATIONS = os.path.join(DATA_ROOT, 'single_chromosomes_object', 'annotations')
    MULTI_CHR_IMAGES = os.path.join(DATA_ROOT, '24_chromosomes_object', 'images')
    MULTI_CHR_ANNOTATIONS = os.path.join(DATA_ROOT, '24_chromosomes_object', 'annotations')

    # Batch sizes are per-replica sizes multiplied by the number of replicas.
    if DEVICE_TYPE == "TPU":
        # Identifier: 4 per replica (conservative for 800x800 images)
        # Classifier: 128 per replica (224x224 images)
        IDENTIFIER_BATCH = 4 * STRATEGY.num_replicas_in_sync
        CLASSIFIER_BATCH = 128 * STRATEGY.num_replicas_in_sync
        print(f"\n✓ TPU v5e-8 Batch Sizes:")
        print(f"  Identifier: {IDENTIFIER_BATCH} (4 per core)")
        print(f"  Classifier: {CLASSIFIER_BATCH} (128 per core)")
    elif DEVICE_TYPE in ["Single-GPU", "Multi-GPU"]:
        IDENTIFIER_BATCH = 2 * STRATEGY.num_replicas_in_sync
        CLASSIFIER_BATCH = 32 * STRATEGY.num_replicas_in_sync
        print(f"\n✓ GPU Batch Sizes:")
        print(f"  Identifier: {IDENTIFIER_BATCH}")
        print(f"  Classifier: {CLASSIFIER_BATCH}")
    else:
        IDENTIFIER_BATCH = 1
        CLASSIFIER_BATCH = 8
        print(f"\n✓ CPU Batch Sizes:")
        print(f"  Identifier: {IDENTIFIER_BATCH}")
        print(f"  Classifier: {CLASSIFIER_BATCH}")

    # ==============================================================================
    # STEP 1: TRAIN OR LOAD IDENTIFIER
    # ==============================================================================

    identifier = ChromosomeIdentifier(img_size=800, max_detections=50)

    if TRAIN_IDENTIFIER:
        print("\n" + "="*60)
        print("STEP 1: IDENTIFIER TRAINING")
        print("="*60)

        try:
            train_identifier_gen = ChromosomeIdentifierDataGenerator(
                SINGLE_CHR_IMAGES, SINGLE_CHR_ANNOTATIONS,
                batch_size=IDENTIFIER_BATCH, img_size=800, shuffle=True
            )

            identifier.train(
                train_identifier_gen,
                epochs=IDENTIFIER_EPOCHS,
                strategy=STRATEGY,
                resume=RESUME_IDENTIFIER,
                weights_path=IDENTIFIER_MODEL_PATH
            )

            identifier.model.save(IDENTIFIER_MODEL_PATH)
            print(f"\n✓ Identifier saved to '{IDENTIFIER_MODEL_PATH}'")
        except Exception as e:
            print(f"✗ Error during identifier training: {e}")
            raise
    else:
        print("\n" + "="*60)
        print("STEP 1: LOADING IDENTIFIER")
        print("="*60)

        if os.path.exists(IDENTIFIER_MODEL_PATH):
            identifier.load_weights(IDENTIFIER_MODEL_PATH, STRATEGY)
        else:
            raise FileNotFoundError(f"Model file not found: {IDENTIFIER_MODEL_PATH}. "
                                   "Please set TRAIN_IDENTIFIER=True or provide valid model path")

    # ==============================================================================
    # STEP 2: TRAIN OR LOAD CLASSIFIER
    # ==============================================================================

    classifier = ChromosomeClassifier(num_classes=24)

    if TRAIN_CLASSIFIER:
        print("\n" + "="*60)
        print("STEP 2: CLASSIFIER TRAINING")
        print("="*60)

        try:
            train_classifier_gen = ChromosomeClassifierDataGenerator(
                MULTI_CHR_IMAGES, MULTI_CHR_ANNOTATIONS,
                batch_size=CLASSIFIER_BATCH, img_size=224, augment=True, validation_split=0.2
            )

            val_classifier_gen = train_classifier_gen.get_validation_generator()

            classifier.compile_model(
                lr=0.001,
                strategy=STRATEGY,
                resume=RESUME_CLASSIFIER,
                weights_path=CLASSIFIER_MODEL_PATH
            )

            print("\nStarting classifier training...")
            classifier.train(train_classifier_gen, val_classifier_gen, epochs=CLASSIFIER_EPOCHS)

            # NOTE(review): this step does not call save() itself; persistence
            # relies on the ModelCheckpoint callback inside classifier.train.
            # Confirm its path matches CLASSIFIER_MODEL_PATH.
            print(f"\n✓ Classifier saved to '{CLASSIFIER_MODEL_PATH}'")
        except Exception as e:
            print(f"✗ Error during classifier training: {e}")
            raise
    else:
        print("\n" + "="*60)
        print("STEP 2: LOADING CLASSIFIER")
        print("="*60)

        if not os.path.exists(CLASSIFIER_MODEL_PATH):
            raise FileNotFoundError(f"Model file not found: {CLASSIFIER_MODEL_PATH}. "
                                   "Please set TRAIN_CLASSIFIER=True or provide valid model path")
        # compile_model treats load_weights' return value as a success flag;
        # fail fast here rather than crash later on a missing model.
        if not classifier.load_weights(CLASSIFIER_MODEL_PATH, STRATEGY):
            raise RuntimeError(f"Failed to load classifier weights from {CLASSIFIER_MODEL_PATH}")

    # ==============================================================================
    # STEP 3: EVALUATION
    # ==============================================================================

    if RUN_EVALUATION:
        print("\n" + "="*60)
        print("STEP 3: EVALUATION")
        print("="*60)

        try:
            # Load test list
            test_file_path = os.path.join(DATA_ROOT, 'test.txt')
            if not os.path.exists(test_file_path):
                print(f"Warning: test.txt not found at {test_file_path}")
                print("Skipping evaluation...")
            else:
                with open(test_file_path, 'r') as f:
                    # Drop blank lines so the reported count matches real entries.
                    test_list = [line.strip() for line in f if line.strip()]

                print(f"Test samples: {len(test_list)}")

                # Create pipeline
                pipeline = KaryotypePipeline(identifier, classifier)

                # Simple evaluation (auPRC)
                print("\nRunning auPRC evaluation...")
                mean_auprc, class_auprcs = evaluate_pipeline_simple(
                    pipeline, MULTI_CHR_IMAGES, MULTI_CHR_ANNOTATIONS,
                    test_list, output_dir='results'
                )

                # Comprehensive evaluation
                print("\nRunning comprehensive evaluation...")
                summary = evaluate_comprehensive(
                    pipeline, MULTI_CHR_IMAGES, MULTI_CHR_ANNOTATIONS,
                    test_list, output_dir='evaluation_results'
                )

                print("\n" + "="*60)
                print("EVALUATION COMPLETE!")
                print("="*60)
                print(f"Mean auPRC: {mean_auprc:.4f}")
                print(f"Overall Accuracy: {summary['Overall Accuracy']:.4f}")
                print(f"Classes Detected: {summary['Classes Present']}")
                print("\nResults saved to:")
                print("  - results/auprc_curves.png")
                print("  - results/auprc_scores.txt")
                print("  - evaluation_results/confusion_matrix.png")
                print("  - evaluation_results/classification_report.txt")
                print("  - evaluation_results/summary.txt")
                print("="*60)
        except Exception as e:
            print(f"✗ Error during evaluation: {e}")
            raise

    print("\n" + "="*60)
    print("PIPELINE EXECUTION COMPLETE!")
    print("="*60)