# Improved Cervical Cancer Cell Classification
## Project Phoenix - Enhanced Pipeline

**Improvements over baseline:**
1. Proper train/validation/test split
2. Class-weighted loss for imbalanced data
3. Modern architectures (EfficientNet, ConvNeXt, Swin Transformer)
4. Cosine annealing with warm restarts
5. Mixup augmentation plus Cutout-style random erasing (CutMix is not implemented)
6. End-to-end fine-tuning of pretrained backbones (gradual unfreezing is not yet implemented)
7. Ensemble methods
8. Test-time augmentation (TTA)

In [None]:
# Install required packages
# !pip install torch torchvision timm scikit-learn matplotlib seaborn pillow tqdm shap streamlit

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from tqdm import tqdm
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler, Subset
import torchvision.transforms as transforms
import torchvision.models as models

# timm for modern architectures
import timm

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, classification_report, confusion_matrix,
    precision_recall_fscore_support, balanced_accuracy_score
)
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.svm import SVC

# Pick the compute device once; later cells move models and tensors onto it.
_has_cuda = torch.cuda.is_available()
device = torch.device("cuda") if _has_cuda else torch.device("cpu")
print(f"Using device: {device}")

# Fix the NumPy and torch RNG states so NumPy- and torch-driven randomness
# (mixup draws, sampler order, weight init) starts from a known state.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
if _has_cuda:
    torch.cuda.manual_seed_all(SEED)

## Configuration

In [None]:
# Paths
DATA_PATH = "/content/drive/MyDrive/Projects/6_Project Phoenix_Cervical Cancer Cell Classification/Herlev Dataset/Preprocessing Analysis v3.0"
TRAIN_PATH = os.path.join(DATA_PATH, "train")
TEST_PATH = os.path.join(DATA_PATH, "test")

# Hyperparameters
IMG_SIZE = 224
BATCH_SIZE = 16  # Smaller batch for better generalization
NUM_EPOCHS = 30
LEARNING_RATE = 1e-4  # Lower LR for fine-tuning
WEIGHT_DECAY = 1e-4
NUM_CLASSES = 7
VAL_SPLIT = 0.15  # 15% of training data for validation

# Early stopping
PATIENCE = 7

## 1. Data Loading with Proper Splits

In [None]:
class HerlevDataset(Dataset):
    """Map-style dataset over image files and their integer class labels.

    Images are read lazily on access, converted to 3-channel RGB and, when a
    ``transform`` is given, passed through it.

    Args:
        image_paths: Sequence of file paths, index-aligned with ``labels``.
        labels: Sequence of class indices, one per path.
        transform: Optional callable applied to each PIL image.
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        # Exactly one sample per listed path.
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx``."""
        sample = Image.open(self.image_paths[idx]).convert('RGB')
        if self.transform:
            sample = self.transform(sample)
        return sample, self.labels[idx]


def load_data_paths(root_dir):
    """Collect image paths and integer labels from a class-per-folder tree.

    ``root_dir`` must contain one sub-directory per class. Plain files sitting
    directly in ``root_dir`` (e.g. ``.DS_Store``, README files) are ignored and
    do not consume a class index. Class indices follow the sorted order of the
    sub-directory names. Files inside each class folder are also visited in
    sorted order, so the returned lists are identical across runs and
    filesystems. This matters because the stratified train/validation split
    depends on list order.

    Only files ending in .bmp/.png/.jpg/.jpeg (case-insensitive) are kept.

    Args:
        root_dir: Path to one dataset split (e.g. the ``train`` folder).

    Returns:
        ``(image_paths, labels, class_to_idx)``. ``image_paths`` and ``labels``
        are index-aligned lists. ``class_to_idx`` maps each class folder name
        to its integer label.
    """
    image_exts = ('.bmp', '.png', '.jpg', '.jpeg')

    # Only directories define classes; a stray file must not shift the indices.
    class_names = sorted(
        entry for entry in os.listdir(root_dir)
        if os.path.isdir(os.path.join(root_dir, entry))
    )
    class_to_idx = {cls: idx for idx, cls in enumerate(class_names)}

    image_paths = []
    labels = []
    for class_name in class_names:
        class_path = os.path.join(root_dir, class_name)
        # os.listdir order is arbitrary; sort so downstream splits are reproducible.
        for img_name in sorted(os.listdir(class_path)):
            if img_name.lower().endswith(image_exts):
                image_paths.append(os.path.join(class_path, img_name))
                labels.append(class_to_idx[class_name])

    return image_paths, labels, class_to_idx


# Load paths
train_paths, train_labels, class_to_idx = load_data_paths(TRAIN_PATH)
test_paths, test_labels, test_class_to_idx = load_data_paths(TEST_PATH)

# Each split builds its own name -> index mapping. Test labels are only
# comparable with the model's outputs if both mappings agree; otherwise every
# test metric would be silently wrong.
if test_class_to_idx != class_to_idx:
    raise ValueError(
        "Class folders differ between train and test: "
        f"{sorted(class_to_idx)} vs {sorted(test_class_to_idx)}"
    )
# Model heads and class weights are sized by NUM_CLASSES.
if len(class_to_idx) != NUM_CLASSES:
    raise ValueError(
        f"Expected {NUM_CLASSES} class folders in {TRAIN_PATH}, "
        f"found {len(class_to_idx)}: {sorted(class_to_idx)}"
    )

idx_to_class = {v: k for k, v in class_to_idx.items()}

print(f"Total training samples: {len(train_paths)}")
print(f"Total test samples: {len(test_paths)}")
print(f"Classes: {list(class_to_idx.keys())}")

In [None]:
# Carve a stratified validation slice out of the training folder; the test
# folder stays untouched.
(
    train_paths_split,
    val_paths,
    train_labels_split,
    val_labels,
) = train_test_split(
    train_paths,
    train_labels,
    test_size=VAL_SPLIT,
    stratify=train_labels,  # same class proportions in both parts
    random_state=SEED,
)

for split_name, split_size in (
    ("Training", len(train_paths_split)),
    ("Validation", len(val_paths)),
    ("Test", len(test_paths)),
):
    print(f"{split_name} samples: {split_size}")

# Per-class counts of the training portion
print("\nTraining class distribution:")
train_split_counts = Counter(train_labels_split)
for cls_idx in sorted(train_split_counts):
    print(f"  {idx_to_class[cls_idx]}: {train_split_counts[cls_idx]}")

In [None]:
# Inverse-frequency class weights: w_c = N / (K * n_c). A perfectly balanced
# training split would give every class a weight of 1.0.
class_counts = Counter(train_labels_split)
total_samples = len(train_labels_split)
_raw_weights = [
    total_samples / (NUM_CLASSES * class_counts[cls_idx])
    for cls_idx in range(NUM_CLASSES)
]
class_weights = torch.tensor(_raw_weights, dtype=torch.float32).to(device)

print("Class weights:")
for cls_idx in range(NUM_CLASSES):
    print(f"  {idx_to_class[cls_idx]}: {class_weights[cls_idx]:.3f}")

## 2. Advanced Augmentation

In [None]:
# Standard ImageNet channel statistics, the usual normalization for
# ImageNet-pretrained timm/torchvision backbones.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]


def _eval_pipeline(*extra_ops):
    """Resize to IMG_SIZE x IMG_SIZE, apply ``extra_ops``, then tensorize and normalize."""
    return transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        *extra_ops,
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ])


# Strong augmentation for training.
# NOTE(review): training takes an IMG_SIZE crop from an (IMG_SIZE + 32) resize,
# while evaluation resizes straight to IMG_SIZE, so cells appear at slightly
# different scales in train vs. eval. Confirm this is intended.
train_transform = transforms.Compose([
    # Geometry
    transforms.Resize((IMG_SIZE + 32, IMG_SIZE + 32)),
    transforms.RandomCrop(IMG_SIZE),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.RandomRotation(30),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    # Photometric
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.1),
    transforms.RandomGrayscale(p=0.1),
    transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
    # Tensor-space
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    transforms.RandomErasing(p=0.2, scale=(0.02, 0.2)),  # cutout-style occlusion
])

# Validation/test: deterministic resize + normalize only.
eval_transform = _eval_pipeline()

# TTA views; each flip/rotation op is applied with certainty.
tta_transforms = [
    eval_transform,
    _eval_pipeline(transforms.RandomHorizontalFlip(p=1.0)),
    _eval_pipeline(transforms.RandomVerticalFlip(p=1.0)),
    _eval_pipeline(transforms.RandomRotation((90, 90))),  # fixed 90-degree rotation
]

In [None]:
# Datasets: only the training portion is augmented.
train_dataset = HerlevDataset(train_paths_split, train_labels_split, transform=train_transform)
val_dataset = HerlevDataset(val_paths, val_labels, transform=eval_transform)
test_dataset = HerlevDataset(test_paths, test_labels, transform=eval_transform)

# Draw each training example with probability inversely proportional to its
# class frequency, with replacement, so batches are roughly class-balanced.
# The number of draws equals the training-split size.
sample_weights = [1.0 / class_counts[lbl] for lbl in train_labels_split]
sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)

# The sampler replaces shuffling for training; val/test keep file order.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=2)
train_loader = DataLoader(train_dataset, sampler=sampler, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

for tag, loader in (("Train", train_loader), ("Val", val_loader), ("Test", test_loader)):
    print(f"{tag} batches: {len(loader)}")

## 3. Mixup Augmentation

In [None]:
def mixup_data(x, y, alpha=0.4):
    """Apply mixup to a batch.

    Draws ``lam ~ Beta(alpha, alpha)``. When ``alpha <= 0``, ``lam`` is 1.0,
    i.e. no mixing. The batch is then randomly permuted and each sample is
    blended with its permuted partner.

    Args:
        x: Input batch tensor; dimension 0 is the batch.
        y: Label tensor aligned with ``x`` along dimension 0.
        alpha: Beta-distribution concentration parameter.

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` where ``mixed_x = lam * x + (1 - lam) *
        x[perm]``, ``y_a = y``, ``y_b = y[perm]`` and ``lam`` is a Python float.
    """
    lam = float(np.random.beta(alpha, alpha)) if alpha > 0 else 1.0

    # Create the permutation on the inputs' own device rather than a
    # module-level device, so the helper works for CPU and GPU tensors alike.
    index = torch.randperm(x.size(0), device=x.device)

    # Plain x[index] (not x[index, :]) also supports 1-D inputs.
    mixed_x = lam * x + (1 - lam) * x[index]
    y_b = y[index.to(y.device)]
    return mixed_x, y, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Loss for a mixup batch: the ``lam``-weighted blend of the losses against
    both label sets produced by ``mixup_data``."""
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b

## 4. Model Definitions

In [None]:
def create_model(model_name, num_classes, pretrained=True):
    """Build a timm classifier with a ``num_classes``-way head and move it to ``device``.

    Args:
        model_name: Short alias: ``efficientnet_b0``, ``efficientnet_b2``,
            ``convnext_tiny``, ``convnext_small``, ``resnet50``,
            ``densenet121`` or ``swin_tiny``.
        num_classes: Output size of the classification head.
        pretrained: Forwarded to ``timm.create_model``.

    Returns:
        The model, already on ``device``.

    Raises:
        ValueError: If ``model_name`` is not a known alias.
    """
    # Alias -> timm architecture identifier (only swin_tiny differs).
    timm_names = {
        'efficientnet_b0': 'efficientnet_b0',
        'efficientnet_b2': 'efficientnet_b2',
        'convnext_tiny': 'convnext_tiny',
        'convnext_small': 'convnext_small',
        'resnet50': 'resnet50',
        'densenet121': 'densenet121',
        'swin_tiny': 'swin_tiny_patch4_window7_224',
    }
    if model_name not in timm_names:
        raise ValueError(f"Unknown model: {model_name}")

    net = timm.create_model(timm_names[model_name], pretrained=pretrained, num_classes=num_classes)
    return net.to(device)

# Architectures compared in this notebook
model_names = ['efficientnet_b0', 'efficientnet_b2', 'convnext_tiny', 'swin_tiny']
print("Models to evaluate:")
print("\n".join(f"  - {name}" for name in model_names))

## 5. Training with Best Practices

In [None]:
class EarlyStopping:
    """Track a validation score, keep the best weights, and signal when to stop.

    Higher scores are better. A score improves on the best only if it is at
    least ``best_score + min_delta``. After ``patience`` consecutive
    non-improving calls, ``early_stop`` becomes True.

    Attributes:
        best_score: Best score seen so far (None before the first call).
        best_model: Deep copy of ``model.state_dict()`` taken at ``best_score``.
        counter: Consecutive calls without improvement.
        early_stop: Whether patience has been exhausted.
    """

    def __init__(self, patience=7, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.best_model = None

    @staticmethod
    def _snapshot(model):
        """Return an independent copy of ``model``'s state dict.

        ``state_dict().copy()`` only copies the dict: its tensors still share
        storage with the live parameters, so later optimizer steps would
        overwrite the "best" weights. A deep copy clones the tensors and keeps
        the state dict's metadata.
        """
        import copy
        return copy.deepcopy(model.state_dict())

    def __call__(self, val_score, model):
        """Record ``val_score`` for ``model``; return True once training should stop."""
        if self.best_score is None:
            self.best_score = val_score
            self.best_model = self._snapshot(model)
        elif val_score < self.best_score + self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = val_score
            self.best_model = self._snapshot(model)
            self.counter = 0

        return self.early_stop

In [None]:
def train_epoch(model, loader, criterion, optimizer, use_mixup=True):
    """Run one optimisation pass over ``loader``.

    When ``use_mixup`` is True, each batch is mixed with probability 0.5 via
    ``mixup_data`` and scored with ``mixup_criterion``. Gradients are clipped
    to a total norm of 1.0 before each optimizer step.

    Args:
        model: Network to train (switched to train mode here).
        loader: Yields ``(images, labels)`` batches.
        criterion: Loss taking ``(logits, targets)``.
        optimizer: Optimizer over ``model``'s parameters.
        use_mixup: Enable random per-batch mixup.

    Returns:
        ``(mean_loss, accuracy)`` averaged over all samples seen. On mixed
        batches, a prediction earns ``lam`` credit for matching ``y_a`` and
        ``1 - lam`` for matching ``y_b``, so accuracy reflects the targets the
        loss was actually computed against.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    correct = 0.0
    total = 0

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()

        # Short-circuit keeps the RNG untouched when mixup is disabled.
        mixed = use_mixup and np.random.random() > 0.5
        if mixed:
            images, labels_a, labels_b, lam = mixup_data(images, labels)
            outputs = model(images)
            loss = mixup_criterion(criterion, outputs, labels_a, labels_b, lam)
        else:
            outputs = model(images)
            loss = criterion(outputs, labels)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Gradient clipping
        optimizer.step()

        batch_size = images.size(0)
        total_loss += loss.item() * batch_size
        total += batch_size
        _, predicted = outputs.max(1)
        if mixed:
            correct += (lam * predicted.eq(labels_a).sum().item()
                        + (1 - lam) * predicted.eq(labels_b).sum().item())
        else:
            correct += predicted.eq(labels).sum().item()

    if total == 0:
        raise ValueError("train_epoch: loader yielded no samples")
    return total_loss / total, correct / total


def evaluate(model, loader, criterion):
    """Score ``model`` on ``loader`` without gradient tracking.

    Returns:
        ``(mean_loss, accuracy, balanced_accuracy, preds, labels)``. The loss
        is averaged per sample; ``preds`` and ``labels`` are NumPy arrays in
        loader order.
    """
    model.eval()
    running_loss = 0
    preds_acc, labels_acc = [], []

    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits = model(inputs)
            running_loss += criterion(logits, targets).item() * inputs.size(0)
            preds_acc.extend(logits.max(1)[1].cpu().numpy())
            labels_acc.extend(targets.cpu().numpy())

    acc = accuracy_score(labels_acc, preds_acc)
    bal_acc = balanced_accuracy_score(labels_acc, preds_acc)
    n_seen = len(labels_acc)
    return running_loss / n_seen, acc, bal_acc, np.array(preds_acc), np.array(labels_acc)

In [None]:
def train_model(model, train_loader, val_loader, num_epochs, model_name):
    """Fine-tune ``model`` and return it restored to its best validation checkpoint.

    Training uses:
    - class-weighted, label-smoothed cross-entropy;
    - AdamW;
    - a cosine schedule with warm restarts (first cycle 10 epochs, each later
      cycle twice as long), stepped once per epoch;
    - early stopping on validation balanced accuracy.

    Returns:
        ``(model, history, best_val_score)``, where ``history`` maps each
        metric name to a per-epoch list.
    """
    # NOTE(review): train_loader is built with a class-balancing
    # WeightedRandomSampler; weighting the loss by class as well compensates
    # for imbalance twice. Confirm this is intended.
    criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)
    stopper = EarlyStopping(patience=PATIENCE)

    metric_keys = ('train_loss', 'train_acc', 'val_loss', 'val_acc', 'val_bal_acc')
    history = {key: [] for key in metric_keys}

    print(f"\nTraining {model_name}...")

    for epoch_no in range(1, num_epochs + 1):
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, use_mixup=True)
        val_loss, val_acc, val_bal_acc, _, _ = evaluate(model, val_loader, criterion)
        scheduler.step()

        epoch_metrics = (train_loss, train_acc, val_loss, val_acc, val_bal_acc)
        for key, value in zip(metric_keys, epoch_metrics):
            history[key].append(value)

        print(f"Epoch {epoch_no:02d}: Train Loss={train_loss:.4f}, Train Acc={train_acc:.4f}, "
              f"Val Loss={val_loss:.4f}, Val Acc={val_acc:.4f}, Val Bal Acc={val_bal_acc:.4f}")

        if stopper(val_bal_acc, model):
            print(f"Early stopping at epoch {epoch_no}")
            break

    # Restore the checkpoint with the best validation balanced accuracy.
    model.load_state_dict(stopper.best_model)
    return model, history, stopper.best_score

## 6. Train Multiple Models

In [None]:
# Fine-tune each architecture, then score it on the held-out test folder.
models_to_train = ['efficientnet_b0', 'efficientnet_b2', 'convnext_tiny', 'swin_tiny']
results = {}
rule = '=' * 60

for model_name in models_to_train:
    print(f"\n{rule}")
    print(f"Model: {model_name.upper()}")
    print(rule)

    model = create_model(model_name, NUM_CLASSES)
    model, history, best_val_score = train_model(model, train_loader, val_loader, NUM_EPOCHS, model_name)

    # Test loss uses class weights without label smoothing; only the accuracy
    # metrics are reported below.
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    test_loss, test_acc, test_bal_acc, preds, labels = evaluate(model, test_loader, criterion)

    results[model_name] = dict(
        model=model,
        history=history,
        best_val_score=best_val_score,
        test_acc=test_acc,
        test_bal_acc=test_bal_acc,
        predictions=preds,
        labels=labels,
    )

    print(f"\nTest Accuracy: {test_acc:.4f}")
    print(f"Test Balanced Accuracy: {test_bal_acc:.4f}")
    print(classification_report(labels, preds, target_names=list(class_to_idx)))

## 7. Test-Time Augmentation (TTA)

In [None]:
def predict_with_tta(model, image_paths, labels, tta_transforms):
    """Classify each image by averaging softmax probabilities over TTA views.

    Args:
        model: Trained classifier.
        image_paths: Files to classify, processed one at a time.
        labels: Unused; kept for call-site compatibility.
        tta_transforms: Callables mapping a PIL image to an unbatched tensor
            (a batch dimension is added here). Each callable yields one view.

    Returns:
        ``(preds, probs)``: integer predictions of length N and the averaged
        class probabilities, one row per image.
    """
    model.eval()
    predictions, probabilities = [], []

    with torch.no_grad():
        for path in tqdm(image_paths, desc="TTA Prediction"):
            pil_image = Image.open(path).convert('RGB')
            view_probs = [
                F.softmax(model(tfm(pil_image).unsqueeze(0).to(device)), dim=1).cpu().numpy()
                for tfm in tta_transforms
            ]
            mean_probs = np.mean(view_probs, axis=0)  # shape (1, num_classes)
            predictions.append(np.argmax(mean_probs))
            probabilities.append(mean_probs[0])

    return np.array(predictions), np.array(probabilities)

# Apply TTA to the model with the best *validation* balanced accuracy.
# Choosing by test score would leak test information into model selection and
# bias the reported TTA numbers upward.
best_model_name = max(results, key=lambda name: results[name]['best_val_score'])
print(f"\nApplying TTA to best model: {best_model_name}")

best_model = results[best_model_name]['model']
tta_preds, tta_probs = predict_with_tta(best_model, test_paths, test_labels, tta_transforms)

tta_acc = accuracy_score(test_labels, tta_preds)
tta_bal_acc = balanced_accuracy_score(test_labels, tta_preds)

print(f"\nTTA Results for {best_model_name}:")
print(f"  Accuracy: {tta_acc:.4f}")
print(f"  Balanced Accuracy: {tta_bal_acc:.4f}")
print(classification_report(test_labels, tta_preds, target_names=list(class_to_idx.keys())))

## 8. Ensemble Model

In [None]:
def get_model_probabilities(model, dataloader):
    """Return softmax probabilities and labels for every batch in ``dataloader``.

    Returns:
        ``(probs, labels)``: probabilities stacked row-wise into one NumPy
        array and labels as a 1-D array, both in loader order.
    """
    model.eval()
    prob_chunks = []
    label_list = []

    with torch.no_grad():
        for batch_images, batch_labels in dataloader:
            logits = model(batch_images.to(device))
            prob_chunks.append(F.softmax(logits, dim=1).cpu().numpy())
            label_list.extend(batch_labels.numpy())

    return np.vstack(prob_chunks), np.array(label_list)


# Soft-voting ensemble: average every fine-tuned model's class probabilities on
# the test set, then take the arg-max.
print("\nCreating ensemble from all trained models...")

ensemble_probs = []
for model_name in results:
    # test_loader is not shuffled, so each model sees the same sample order and
    # `labels` (reused by later cells) is the shared test ground truth.
    probs, labels = get_model_probabilities(results[model_name]['model'], test_loader)
    ensemble_probs.append(probs)
    print(f"  Added {model_name}")

avg_ensemble_probs = np.stack(ensemble_probs).mean(axis=0)
ensemble_preds = avg_ensemble_probs.argmax(axis=1)

ensemble_acc = accuracy_score(labels, ensemble_preds)
ensemble_bal_acc = balanced_accuracy_score(labels, ensemble_preds)

print("\nEnsemble Results:")
print(f"  Accuracy: {ensemble_acc:.4f}")
print(f"  Balanced Accuracy: {ensemble_bal_acc:.4f}")
print(classification_report(labels, ensemble_preds, target_names=list(class_to_idx.keys())))

## 9. Enhanced Hybrid Model

In [None]:
def extract_features_multi_model(models_dict, dataloader):
    """Concatenate pooled penultimate-layer features from several timm models.

    For every batch, each model's ``forward_features`` output is passed
    through ``forward_head(..., pre_logits=True)``. This applies timm's own
    global pooling (and any head norm) without the final classifier, giving a
    ``(batch, num_features)`` tensor per architecture. Hand-written pooling is
    avoided because feature-map layouts differ between architectures: with
    the old per-architecture code, ConvNeXt's (B, C, 1, 1) pool output was
    averaged over the channel axis and reduced to a single value.

    Args:
        models_dict: Mapping of model name -> dict with a ``'model'`` entry
            (the structure of ``results``).
        dataloader: Yields ``(images, labels)`` batches.

    Returns:
        ``(features, labels)``: a NumPy feature matrix with one row per sample
        (per-model features concatenated in ``models_dict`` order) and a 1-D
        label array, both in loader order.
    """
    # Put every backbone in eval mode once, not once per batch.
    backbones = [result['model'].eval() for result in models_dict.values()]

    all_features = []
    all_labels = []

    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Extracting features"):
            images = images.to(device)
            batch_features = []
            for backbone in backbones:
                feature_map = backbone.forward_features(images)
                pooled = backbone.forward_head(feature_map, pre_logits=True)
                # forward_head(pre_logits=True) already returns (B, C); flatten
                # is a no-op safeguard against any lingering singleton dims.
                batch_features.append(pooled.flatten(1).cpu().numpy())

            all_features.append(np.concatenate(batch_features, axis=1))
            all_labels.extend(labels.numpy())

    return np.vstack(all_features), np.array(all_labels)


# Extract multi-model features.
# train_loader is unsuitable here for two reasons:
# - its WeightedRandomSampler draws with replacement, duplicating some images
#   and skipping others;
# - it applies random augmentation, so features would not cover the training
#   split exactly once and would not be reproducible.
# Use a deterministic, un-augmented pass over the same training split instead.
train_feature_loader = DataLoader(
    HerlevDataset(train_paths_split, train_labels_split, transform=eval_transform),
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
)

print("Extracting features from all models...")
X_train_feat, y_train_feat = extract_features_multi_model(results, train_feature_loader)
X_val_feat, y_val_feat = extract_features_multi_model(results, val_loader)
X_test_feat, y_test_feat = extract_features_multi_model(results, test_loader)

print(f"Combined feature dimensions: {X_train_feat.shape[1]}")

In [None]:
# Fit the scaler on training features only, then apply it to val/test.
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train_feat)
X_val_scaled, X_test_scaled = (scaler.transform(feats) for feats in (X_val_feat, X_test_feat))

# Classical classifiers trained on top of the concatenated CNN features
classifiers = {
    'Logistic Regression': LogisticRegression(max_iter=2000, C=0.1, random_state=SEED),
    'SVM (RBF)': SVC(kernel='rbf', C=10, gamma='scale', probability=True, random_state=SEED),
    'Random Forest': RandomForestClassifier(n_estimators=200, max_depth=20, random_state=SEED),
    'Gradient Boosting': GradientBoostingClassifier(n_estimators=100, max_depth=5, random_state=SEED),
}


def _acc_pair(y_true, y_pred):
    """Return ``(accuracy, balanced_accuracy)`` for one prediction vector."""
    return accuracy_score(y_true, y_pred), balanced_accuracy_score(y_true, y_pred)


hybrid_results = {}

for clf_name, clf in classifiers.items():
    print(f"\nTraining {clf_name}...")
    clf.fit(X_train_scaled, y_train_feat)

    val_acc, val_bal_acc = _acc_pair(y_val_feat, clf.predict(X_val_scaled))
    test_preds = clf.predict(X_test_scaled)
    test_acc, test_bal_acc = _acc_pair(y_test_feat, test_preds)

    hybrid_results[clf_name] = dict(
        classifier=clf,
        val_acc=val_acc,
        val_bal_acc=val_bal_acc,
        test_acc=test_acc,
        test_bal_acc=test_bal_acc,
        predictions=test_preds,
    )

    print(f"  Val Acc: {val_acc:.4f}, Val Bal Acc: {val_bal_acc:.4f}")
    print(f"  Test Acc: {test_acc:.4f}, Test Bal Acc: {test_bal_acc:.4f}")

## 10. Final Comparison

In [None]:
# Gather every evaluated approach into one table, ranked by test balanced accuracy.
def _row(method, acc, bal_acc):
    """One comparison-table row."""
    return {'Method': method, 'Test Accuracy': acc, 'Balanced Accuracy': bal_acc}


final_results = [
    _row(f"{name} (Fine-tuned)", res['test_acc'], res['test_bal_acc'])
    for name, res in results.items()
]
final_results.append(_row(f"{best_model_name} + TTA", tta_acc, tta_bal_acc))
final_results.append(_row('Model Ensemble (Avg)', ensemble_acc, ensemble_bal_acc))
final_results.extend(
    _row(f"Hybrid + {name}", res['test_acc'], res['test_bal_acc'])
    for name, res in hybrid_results.items()
)

final_df = pd.DataFrame(final_results).sort_values('Balanced Accuracy', ascending=False)

banner = "=" * 70
print("\n" + banner)
print("FINAL MODEL COMPARISON")
print(banner)
print(final_df.to_string(index=False))

In [None]:
# Grouped horizontal bars: plain vs. balanced accuracy for each method.
fig, ax = plt.subplots(figsize=(14, 8))

positions = np.arange(len(final_df))
bar_height = 0.35

bars1 = ax.barh(positions - bar_height / 2, final_df['Test Accuracy'], bar_height,
                label='Accuracy', color='#3498db')
bars2 = ax.barh(positions + bar_height / 2, final_df['Balanced Accuracy'], bar_height,
                label='Balanced Accuracy', color='#2ecc71')

ax.set_xlabel('Score')
ax.set_title('Model Performance Comparison')
ax.set_yticks(positions)
ax.set_yticklabels(final_df['Method'])
ax.set_xlim(0, 1)
ax.legend()
ax.grid(axis='x', alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Confusion matrix for the top-ranked row of the comparison table
best_method = final_df.iloc[0]['Method']
print(f"\nBest Method: {best_method}")

# Recover the prediction/label vectors behind that row. The checks mirror the
# method names built in the comparison table ("Hybrid + ...", "... + TTA",
# "Model Ensemble (Avg)", "... (Fine-tuned)").
if 'Hybrid' in best_method:
    clf_name = best_method.replace('Hybrid + ', '')
    best_preds, best_labels = hybrid_results[clf_name]['predictions'], y_test_feat
elif 'TTA' in best_method:
    best_preds, best_labels = tta_preds, test_labels
elif 'Ensemble' in best_method:
    # `labels` holds the test labels collected in the ensemble cell.
    best_preds, best_labels = ensemble_preds, labels
else:
    model_name = best_method.replace(' (Fine-tuned)', '')
    best_preds, best_labels = results[model_name]['predictions'], results[model_name]['labels']

label_names = list(class_to_idx.keys())

cm = confusion_matrix(best_labels, best_preds)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=label_names, yticklabels=label_names)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title(f'Confusion Matrix - {best_method}')
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()

print("\nClassification Report:")
print(classification_report(best_labels, best_preds, target_names=label_names))

## 11. SHAP Explainability

In [None]:
import shap

# Explain the hybrid classifier with the best *validation* balanced accuracy.
# Selecting by test score would leak test information into model choice; this
# classifier is also the one pickled in the save cell.
best_clf_name = max(hybrid_results, key=lambda name: hybrid_results[name]['val_bal_acc'])
best_clf = hybrid_results[best_clf_name]['classifier']

print(f"Using {best_clf_name} for SHAP analysis...")

# Background sample that KernelExplainer integrates missing features over
background_size = min(100, len(X_train_scaled))
background_idx = np.random.choice(len(X_train_scaled), background_size, replace=False)
background = X_train_scaled[background_idx]

# Explain class probabilities when the classifier exposes them, otherwise labels.
# NOTE(review): KernelExplainer cost grows with the number of features (here the
# concatenated CNN feature width) and the background size; expect long runtimes.
predict_fn = best_clf.predict_proba if hasattr(best_clf, 'predict_proba') else best_clf.predict
explainer = shap.KernelExplainer(predict_fn, background)

# SHAP values for a random test subset
test_subset_size = min(30, len(X_test_scaled))
test_idx = np.random.choice(len(X_test_scaled), test_subset_size, replace=False)
X_test_subset = X_test_scaled[test_idx]

print(f"Calculating SHAP values for {test_subset_size} samples...")
shap_values = explainer.shap_values(X_test_subset)

In [None]:
# Feature importance summary: mean |SHAP| per feature, averaged over samples
# and (for multi-class output) over classes.
# Depending on the shap version, multi-output explainers return either a list
# of (n_samples, n_features) arrays or one (n_samples, n_features, n_outputs)
# array; both layouts are reduced to a 1-D per-feature vector here.
if isinstance(shap_values, list):
    mean_abs_shap = np.mean([np.abs(sv).mean(axis=0) for sv in shap_values], axis=0)
else:
    abs_shap = np.abs(np.asarray(shap_values))
    if abs_shap.ndim == 3:
        mean_abs_shap = abs_shap.mean(axis=(0, 2))
    else:
        mean_abs_shap = abs_shap.mean(axis=0)

# Up to 30 most important features (fewer if the feature vector is shorter)
top_k = min(30, mean_abs_shap.shape[0])
top_idx = np.argsort(mean_abs_shap)[-top_k:]

plt.figure(figsize=(12, 8))
plt.barh(range(top_k), mean_abs_shap[top_idx], color='steelblue')
plt.yticks(range(top_k), [f"Feature {i}" for i in top_idx])
plt.xlabel('Mean |SHAP value|')
plt.title(f'Top {top_k} Most Important Combined Features')
plt.tight_layout()
plt.show()

## 12. Save Best Models

In [None]:
import pickle

SAVE_DIR = "./saved_models_improved"
os.makedirs(SAVE_DIR, exist_ok=True)

# One state_dict per fine-tuned backbone, named after its alias.
for model_name, result in results.items():
    torch.save(result['model'].state_dict(), os.path.join(SAVE_DIR, f"{model_name}.pth"))

# Pickled sklearn objects and the label mapping needed at inference time.
# Only unpickle these files from a trusted location.
pickled_artifacts = {
    "best_hybrid_classifier.pkl": best_clf,
    "feature_scaler.pkl": scaler,
    "class_mapping.pkl": class_to_idx,
}
for file_name, obj in pickled_artifacts.items():
    with open(os.path.join(SAVE_DIR, file_name), 'wb') as fh:
        pickle.dump(obj, fh)

# Comparison table alongside the models
final_df.to_csv(os.path.join(SAVE_DIR, "model_comparison.csv"), index=False)

print(f"Models saved to {SAVE_DIR}")

## 13. Updated Streamlit App

In [None]:
# Source of the standalone Streamlit dashboard. The emoji below are real
# Unicode characters; earlier copies held mojibake (UTF-8 bytes mis-decoded as
# Mac Roman), which rendered as garbage in the UI.
streamlit_code = '''
"""
Project Phoenix - Improved Classification Dashboard
Run: streamlit run streamlit_app_improved.py
"""

import streamlit as st
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import pickle
import timm
import os

st.set_page_config(page_title="Project Phoenix", page_icon="🔬", layout="wide")

MODELS_DIR = "./saved_models_improved"
IMG_SIZE = 224

@st.cache_resource
def load_models():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load all CNN models
    model_configs = [
        ('efficientnet_b0', 'efficientnet_b0'),
        ('efficientnet_b2', 'efficientnet_b2'),
        ('convnext_tiny', 'convnext_tiny'),
        ('swin_tiny', 'swin_tiny_patch4_window7_224')
    ]

    models = {}
    for name, timm_name in model_configs:
        model = timm.create_model(timm_name, pretrained=False, num_classes=7)
        model.load_state_dict(torch.load(os.path.join(MODELS_DIR, f"{name}.pth"), map_location=device))
        model.eval()
        model.to(device)
        models[name] = model

    with open(os.path.join(MODELS_DIR, "class_mapping.pkl"), 'rb') as f:
        class_mapping = pickle.load(f)

    return models, class_mapping, device

def preprocess(image):
    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    return transform(image).unsqueeze(0)

def ensemble_predict(models, image, device):
    img_tensor = preprocess(image).to(device)
    all_probs = []

    with torch.no_grad():
        for model in models.values():
            output = model(img_tensor)
            probs = F.softmax(output, dim=1)
            all_probs.append(probs.cpu().numpy())

    avg_probs = np.mean(all_probs, axis=0)[0]
    pred_class = int(np.argmax(avg_probs))
    return pred_class, avg_probs

def main():
    st.title("🔬 Project Phoenix")
    st.subheader("Cervical Cancer Cell Classification (Improved)")
    st.markdown("---")

    st.sidebar.title("About")
    st.sidebar.info(
        "Enhanced model using ensemble of EfficientNet, ConvNeXt, and Swin Transformer."
    )

    try:
        models, class_mapping, device = load_models()
        st.success(f"Loaded {len(models)} models for ensemble prediction")
    except Exception as e:
        st.error(f"Error loading models: {e}")
        return

    idx_to_class = {v: k for k, v in class_mapping.items()}

    col1, col2 = st.columns(2)

    with col1:
        st.subheader("Upload Image")
        uploaded = st.file_uploader("Choose cell image...", type=['bmp', 'png', 'jpg', 'jpeg'])

        if uploaded:
            image = Image.open(uploaded).convert('RGB')
            st.image(image, caption="Uploaded Image", use_column_width=True)

    with col2:
        st.subheader("Results")

        if uploaded and st.button("Classify", type="primary"):
            with st.spinner("Analyzing with ensemble..."):
                pred_idx, probs = ensemble_predict(models, image, device)
                pred_class = idx_to_class[pred_idx]

            st.success(f"**Predicted:** {pred_class.replace('_', ' ').title()}")
            st.write(f"**Confidence:** {probs[pred_idx]*100:.1f}%")

            st.subheader("Class Probabilities")
            prob_dict = {idx_to_class[i].replace('_', ' ').title(): float(probs[i]) for i in range(len(probs))}
            st.bar_chart(prob_dict)

            abnormal = ['carcinoma_in_situ', 'light_dysplastic', 'moderate_dysplastic', 'severe_dysplastic']
            if pred_class in abnormal:
                st.warning("⚠️ Abnormal cell detected. Consult a medical professional.")
            else:
                st.info("✅ Cell appears normal.")

    st.markdown("---")
    st.caption("Disclaimer: For research purposes only.")

if __name__ == "__main__":
    main()
'''

# Write with explicit UTF-8: the app source contains emoji, which platform
# default encodings (e.g. cp1252 on Windows) cannot encode.
with open("streamlit_app_improved.py", 'w', encoding='utf-8') as f:
    f.write(streamlit_code)

print("Streamlit app saved as 'streamlit_app_improved.py'")

## Summary

**Key Improvements:**

1. **Proper validation split** - 15% of training data for validation
2. **Class-weighted loss** - Handles imbalanced classes
3. **Modern architectures** - EfficientNet, ConvNeXt, Swin Transformer
4. **Strong augmentation** - RandomErasing, ColorJitter, GaussianBlur
5. **Mixup augmentation** - Regularization during training
6. **Weighted sampling** - Balanced batches
7. **Label smoothing** - Prevents overconfidence
8. **AdamW + Cosine annealing** - Better optimization
9. **Early stopping** - Prevents overfitting
10. **Test-time augmentation** - Averages predictions over flipped/rotated views, which can improve test accuracy
11. **Model ensemble** - Combines multiple models
12. **Multi-model hybrid** - Features from all models + ML classifiers

In [None]:
_rule = "=" * 60
print(f"\n{_rule}\nIMPROVED PIPELINE COMPLETE\n{_rule}")