### **Import libraries**

In [1]:
import os
%pwd

'c:\\09_AHFID\\via-cervix-ai\\notebook'

In [2]:
os.chdir("../")
%pwd

'c:\\09_AHFID\\via-cervix-ai'

### **Import and Config**

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from collections import Counter
import random
from torchvision import transforms
from PIL import Image, ImageFilter, ImageDraw, ImageEnhance
import timm
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import classification_report, confusion_matrix
from pathlib import Path
import json
import os

# Set seeds for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DATA_DIR = Path("artifacts") / "via-cervix"
RESULTS_DIR = Path("artifacts") / "training_runs"
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

CLASS_NAMES = ["Negative", "Positive", "Suspicious cancer"]
CLASS_TO_IDX = {c:i for i,c in enumerate(CLASS_NAMES)}

print(f"Using device: {DEVICE}")
print(f"Data directory: {DATA_DIR}")
print(f"Results directory: {RESULTS_DIR}")

  from .autonotebook import tqdm as notebook_tqdm


Using device: cpu
Data directory: artifacts\via-cervix
Results directory: artifacts\training_runs


### **Utilities for listing the images**

In [4]:
def list_images_by_class(root_dir, class_names):
    """Load images from class folders"""
    items = []
    for cls in class_names:
        folder = root_dir / cls
        if not folder.exists():
            print(f"Warning: Missing folder: {folder}")
            continue
        
        count = 0
        for p in folder.rglob("*"):
            if p.suffix.lower() in [".jpg", ".jpeg", ".png"]:
                items.append((str(p), CLASS_TO_IDX[cls]))
                count += 1
        print(f"Found {count} images in {cls} folder")
    
    return items

### **Class imbalance handling strategies**

In [5]:
class FocalLoss(nn.Module):
    """Focal Loss for addressing class imbalance"""
    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        
    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, weight=self.alpha, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss
        
        if self.reduction == 'mean':
            return focal_loss.mean()
        return focal_loss

class CostSensitiveLoss(nn.Module):
    """Custom loss with heavy penalties for missing cancer cases"""
    def __init__(self, cost_matrix):
        super(CostSensitiveLoss, self).__init__()
        self.cost_matrix = cost_matrix
        
    def forward(self, inputs, targets):
        probs = F.softmax(inputs, dim=1)
        costs = self.cost_matrix[targets]
        expected_costs = torch.sum(probs * costs, dim=1)
        return expected_costs.mean()

def create_cost_matrix(n_classes, cancer_penalty=15.0):
    """Create cost matrix where missing cancer is heavily penalized"""
    cost_matrix = torch.ones(n_classes, n_classes)
    
    for i in range(n_classes):
        for j in range(n_classes):
            if i != j:
                cost_matrix[i, j] = 1.0
            else:
                cost_matrix[i, j] = 0.0
    
    # Heavy penalty for missing suspicious cancer (class 2)
    cost_matrix[2, 0] = cancer_penalty  # suspicious -> negative
    cost_matrix[2, 1] = cancer_penalty  # suspicious -> positive
    
    return cost_matrix

class AggressiveAugmentation:
    """Aggressive augmentation for minority classes"""
    
    @staticmethod
    def add_realistic_artifacts(pil_img):
        """Add medical imaging artifacts"""
        img = pil_img.copy()
        
        # Random brightness/contrast changes
        enhancer = ImageEnhance.Brightness(img)
        img = enhancer.enhance(random.uniform(0.7, 1.3))
        
        enhancer = ImageEnhance.Contrast(img)
        img = enhancer.enhance(random.uniform(0.8, 1.4))
        
        # Add specular highlights
        if random.random() < 0.5:
            img = AggressiveAugmentation.add_specular_highlights(img)
            
        return img
    
    @staticmethod
    def add_specular_highlights(pil_img, n_spots=None):
        if n_spots is None:
            n_spots = random.randint(1, 2)
            
        img = pil_img.convert("RGBA")
        w, h = img.size
        
        for _ in range(n_spots):
            radius = int(min(w, h) * random.uniform(0.02, 0.12))
            if radius < 2:
                continue
                
            cx = random.randint(radius, max(radius+1, w - radius))
            cy = random.randint(radius, max(radius+1, h - radius))
            
            highlight = Image.new('RGBA', (radius*2, radius*2), (0, 0, 0, 0))
            draw = ImageDraw.Draw(highlight)
            
            for i in range(2):
                r = max(1, radius - i * radius // 3)
                alpha = int(255 * random.uniform(0.2, 0.5) * (1 - i/2))
                draw.ellipse([radius-r, radius-r, radius+r, radius+r], 
                           fill=(255, 255, 255, alpha))
            
            img.paste(highlight, (cx-radius, cy-radius), highlight)
        
        return img.convert("RGB")


### **Transforms for minority class**

In [6]:
# Enhanced transforms
minority_transforms = transforms.Compose([
    transforms.Lambda(AggressiveAugmentation.add_realistic_artifacts),
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.6),
    transforms.RandomRotation(degrees=25),
    transforms.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
])

class ImbalancedDataset(Dataset):
    """Dataset with intelligent augmentation based on class frequency"""
    def __init__(self, samples, val=False):
        self.samples = samples
        self.val = val
        
        if not val:
            # Calculate augmentation factors
            class_counts = Counter([label for _, label in samples])
            max_count = max(class_counts.values()) if class_counts else 1
            
            self.augmentation_factor = {}
            for class_id, count in class_counts.items():
                self.augmentation_factor[class_id] = max(1, max_count // max(count, 1))
            
            print(f"Augmentation factors: {self.augmentation_factor}")
            
            # Expand dataset
            self.expanded_samples = []
            for path, label in samples:
                self.expanded_samples.append((path, label))
                # Add augmented versions for minority classes
                for _ in range(self.augmentation_factor.get(label, 1) - 1):
                    self.expanded_samples.append((path, label))
        else:
            self.expanded_samples = samples
    
    def __len__(self):
        return len(self.expanded_samples)
    
    def __getitem__(self, idx):
        path, label = self.expanded_samples[idx]
        
        try:
            img = Image.open(path).convert("RGB")
            
            if self.val:
                # Standard validation transforms
                transform = transforms.Compose([
                    transforms.Resize((224, 224)),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
                ])
                img = transform(img)
            else:
                # Choose transforms based on class
                if label == 2:  # Suspicious cancer - most aggressive
                    img = minority_transforms(img)
                elif label == 1:  # Positive - moderate augmentation
                    transform = transforms.Compose([
                        transforms.Resize((224, 224)),
                        transforms.RandomHorizontalFlip(p=0.5),
                        transforms.RandomRotation(degrees=15),
                        transforms.ColorJitter(brightness=0.2, contrast=0.2),
                        transforms.ToTensor(),
                        transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
                    ])
                    img = transform(img)
                else:  # Negative - light augmentation
                    transform = transforms.Compose([
                        transforms.Resize((224, 224)),
                        transforms.RandomHorizontalFlip(p=0.3),
                        transforms.RandomRotation(degrees=8),
                        transforms.ToTensor(),
                        transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
                    ])
                    img = transform(img)
            
            return img, label, path
            
        except Exception as e:
            print(f"Error loading {path}: {e}")
            return torch.zeros(3, 224, 224), label, path

### **Multi-stage training class**

In [None]:
class MultiStageTrainer:
    """Multi-stage training to handle extreme imbalance"""
    
    def __init__(self, model, device, class_names):
        self.model = model
        self.device = device
        self.class_names = class_names
        
    def _modify_classifier_output(self, num_classes):
        """Helper method to modify classifier output size"""
        if hasattr(self.model, 'classifier'):
            if isinstance(self.model.classifier, nn.Sequential):
                # Find the last Linear layer in Sequential
                layers = list(self.model.classifier.children())
                for i in range(len(layers)-1, -1, -1):
                    if isinstance(layers[i], nn.Linear):
                        in_features = layers[i].in_features
                        layers[i] = nn.Linear(in_features, num_classes)
                        break
                self.model.classifier = nn.Sequential(*layers)
            else:
                # Single Linear layer
                in_features = self.model.classifier.in_features
                self.model.classifier = nn.Linear(in_features, num_classes)
        elif hasattr(self.model, 'head'):
            in_features = self.model.head.in_features
            self.model.head = nn.Linear(in_features, num_classes)
        
        return self.model.to(self.device)
        
    def stage1_binary_training(self, train_loader, epochs=3):
        """Stage 1: Train binary classifier (cancer vs non-cancer)"""
        print("Stage 1: Binary cancer detection training...")
        
        # Modify final layer for binary classification
        self.model = self._modify_classifier_output(2)
        
        # Binary focal loss with higher weight for cancer
        alpha = torch.tensor([0.3, 0.7]).to(self.device)
        criterion = FocalLoss(alpha=alpha, gamma=2.0)
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=2e-5, weight_decay=0.01)
        
        for epoch in range(epochs):
            self.model.train()
            total_loss = 0
            batch_count = 0
            
            for data, target, _ in train_loader:
                data = data.to(self.device)
                # Convert to binary labels (0: non-cancer, 1: cancer)
                binary_target = (target == 2).long().to(self.device)
                
                optimizer.zero_grad()
                output = self.model(data)
                loss = criterion(output, binary_target)
                loss.backward()
                optimizer.step()
                
                total_loss += loss.item()
                batch_count += 1
            
            avg_loss = total_loss / max(batch_count, 1)
            print(f"  Binary epoch {epoch+1}/{epochs}, loss: {avg_loss:.4f}")
    
    def stage2_multiclass_training(self, train_loader, val_loader, epochs=8):
        """Stage 2: Fine-tune for 3-class classification"""
        print("Stage 2: Multi-class fine-tuning...")
        
        # Restore 3-class head
        self.model = self._modify_classifier_output(3)
        
        # Cost-sensitive loss with heavy penalty for missing cancer
        cost_matrix = create_cost_matrix(3, cancer_penalty=15.0).to(self.device)
        criterion = CostSensitiveLoss(cost_matrix)
        
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-5, weight_decay=0.01)
        
        best_cancer_recall = 0
        
        for epoch in range(epochs):
            self.model.train()
            total_loss = 0
            batch_count = 0
            
            for data, target, _ in train_loader:
                data, target = data.to(self.device), target.to(self.device)
                
                optimizer.zero_grad()
                output = self.model(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()
                
                total_loss += loss.item()
                batch_count += 1
            
            # Evaluate
            report, cm = self.evaluate(val_loader)
            cancer_recall = report.get('Suspicious cancer', {}).get('recall', 0)
            
            avg_loss = total_loss / max(batch_count, 1)
            print(f"  Epoch {epoch+1}/{epochs}, loss: {avg_loss:.4f}, cancer recall: {cancer_recall:.3f}")
            
            if cancer_recall > best_cancer_recall:
                best_cancer_recall = cancer_recall
        
        return best_cancer_recall
    
    def evaluate(self, val_loader):
        """Evaluate model performance"""
        self.model.eval()
        all_preds, all_targets = [], []
        
        with torch.no_grad():
            for data, target, _ in val_loader:
                data = data.to(self.device)
                output = self.model(data)
                pred = output.argmax(dim=1).cpu().numpy()
                all_preds.extend(pred)
                all_targets.extend(target.numpy())
        
        try:
            report = classification_report(all_targets, all_preds, 
                                         target_names=self.class_names, 
                                         output_dict=True, zero_division=0)
            cm = confusion_matrix(all_targets, all_preds, labels=[0,1,2])
        except Exception as e:
            print(f"Error in evaluation: {e}")
            report = {'Suspicious cancer': {'recall': 0}}
            cm = np.zeros((3,3))
        
        return report, cm 

### **Model optimization for imbalanced data**

In [8]:
def create_improved_model():
    """Create model optimized for imbalanced data"""
    model = timm.create_model('efficientnet_b0', pretrained=True, num_classes=3)
    
    # Add dropout for regularization
    if hasattr(model, 'classifier'):
        in_features = model.classifier.in_features
        model.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(in_features, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 3)
        )
    
    return model

def save_evaluation_files(model, val_loader, results_dir):
    """Save evaluation files for comprehensive evaluation notebook"""
    print("Saving evaluation files...")
    
    model.eval()
    all_labels, all_probs = [], []
    
    with torch.no_grad():
        for data, target, _ in val_loader:
            data = data.to(DEVICE)
            logits = model(data)
            probs = F.softmax(logits, dim=1).cpu().numpy()
            
            all_labels.extend(target.numpy())
            all_probs.extend(probs)
    
    # Convert to numpy arrays
    eval_labels = np.array(all_labels)
    eval_probs = np.array(all_probs)
    
    # Save files
    np.save(results_dir / "eval_labels.npy", eval_labels)
    np.save(results_dir / "eval_probs.npy", eval_probs)
    
    print(f"Saved evaluation files: {len(eval_labels)} samples")
    print(f"  Labels shape: {eval_labels.shape}")
    print(f"  Probabilities shape: {eval_probs.shape}")
    
    # Verify files
    try:
        test_labels = np.load(results_dir / "eval_labels.npy")
        test_probs = np.load(results_dir / "eval_probs.npy")
        print("  Files verified successfully")
        return True
    except Exception as e:
        print(f"  Error verifying files: {e}")
        return False

### **Main training function**

In [9]:
def run_training():
    """Main training function"""
    print("=" * 60)
    print("STARTING IMPROVED CLASS IMBALANCE TRAINING")
    print("=" * 60)
    
    # Load data
    try:
        all_items = list_images_by_class(DATA_DIR, CLASS_NAMES)
        if len(all_items) == 0:
            print(f"ERROR: No images found in {DATA_DIR}")
            print("Expected structure:")
            print(f"  {DATA_DIR}/")
            print("    Negative/")
            print("    Positive/")
            print("    Suspicious cancer/")
            return None
        
        class_counts = Counter([y for _, y in all_items])
        print(f"Total images: {len(all_items)}")
        print(f"Class distribution: {dict(class_counts)}")
        
    except Exception as e:
        print(f"Error loading data: {e}")
        return None
    
    # Prepare cross-validation
    paths = [item[0] for item in all_items]
    labels = [item[1] for item in all_items]
    
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=SEED)
    
    fold_results = []
    best_model = None
    best_val_loader = None
    best_recall = 0
    
    # Run cross-validation
    for fold, (train_idx, val_idx) in enumerate(skf.split(paths, labels), 1):
        print(f"\nFOLD {fold}/5")
        print("-" * 30)
        
        train_samples = [(paths[i], labels[i]) for i in train_idx]
        val_samples = [(paths[i], labels[i]) for i in val_idx]
        
        print(f"Train samples: {len(train_samples)}")
        print(f"Validation samples: {len(val_samples)}")
        
        # Create datasets
        train_ds = ImbalancedDataset(train_samples, val=False)
        val_ds = ImbalancedDataset(val_samples, val=True)
        
        train_loader = DataLoader(train_ds, batch_size=8, shuffle=True, num_workers=0)
        val_loader = DataLoader(val_ds, batch_size=8, shuffle=False, num_workers=0)
        
        # Create and train model
        model = create_improved_model()
        trainer = MultiStageTrainer(model, DEVICE, CLASS_NAMES)
        
        # Two-stage training
        trainer.stage1_binary_training(train_loader, epochs=3)
        cancer_recall = trainer.stage2_multiclass_training(train_loader, val_loader, epochs=6)
        
        fold_results.append(cancer_recall)
        print(f"Fold {fold} final cancer recall: {cancer_recall:.3f}")
        
        # Track best model
        if cancer_recall > best_recall:
            best_recall = cancer_recall
            best_model = model
            best_val_loader = val_loader
        
        # Save model
        torch.save(model.state_dict(), RESULTS_DIR / f"fold_{fold}_model.pth")
    
    # Save evaluation files
    print(f"\nSaving evaluation files using best model (recall: {best_recall:.3f})")
    if best_model is not None and save_evaluation_files(best_model, best_val_loader, RESULTS_DIR):
        print("Evaluation files saved successfully")
    else:
        print("Failed to save evaluation files")
    
    # Save summary
    cv_summary = {
        "n_folds": 5,
        "cancer_recall_mean": float(np.mean(fold_results)),
        "cancer_recall_std": float(np.std(fold_results)),
        "best_recall": float(best_recall),
        "fold_results": [float(x) for x in fold_results]
    }
    
    with open(RESULTS_DIR / "cv_summary.json", 'w') as f:
        json.dump(cv_summary, f, indent=2)
    
    print("\n" + "=" * 60)
    print("TRAINING COMPLETED")
    print("=" * 60)
    print(f"Cancer recall: {np.mean(fold_results):.3f} ± {np.std(fold_results):.3f}")
    print(f"Best fold recall: {best_recall:.3f}")
    print(f"Files saved to: {RESULTS_DIR}")
    print("\nFiles created:")
    print("- fold_*_model.pth (model checkpoints)")
    print("- eval_labels.npy (for evaluation notebook)")
    print("- eval_probs.npy (for evaluation notebook)")
    print("- cv_summary.json (training summary)")
    
    return fold_results

### **Execute training**

In [10]:
# Execute training
if DATA_DIR.exists():
    print(f"Data directory found: {DATA_DIR}")
    results = run_training()
else:
    print(f"Data directory not found: {DATA_DIR}")
    print("Please ensure your data is in the correct location.")
    results = None

Data directory found: artifacts\via-cervix
STARTING IMPROVED CLASS IMBALANCE TRAINING
Found 92 images in Negative folder
Found 78 images in Positive folder
Found 20 images in Suspicious cancer folder
Total images: 190
Class distribution: {0: 92, 1: 78, 2: 20}

FOLD 1/5
------------------------------
Train samples: 152
Validation samples: 38
Augmentation factors: {0: 1, 1: 1, 2: 4}
Stage 1: Binary cancer detection training...
  Binary epoch 1/3, loss: 0.0273
  Binary epoch 2/3, loss: 0.0222
  Binary epoch 3/3, loss: 0.0193
Stage 2: Multi-class fine-tuning...
  Epoch 1/6, loss: 3.5570, cancer recall: 0.250
  Epoch 2/6, loss: 3.4318, cancer recall: 0.500
  Epoch 3/6, loss: 3.2414, cancer recall: 0.750
  Epoch 4/6, loss: 3.1092, cancer recall: 0.750
  Epoch 5/6, loss: 2.8208, cancer recall: 0.750
  Epoch 6/6, loss: 2.4976, cancer recall: 0.750
Fold 1 final cancer recall: 0.750

FOLD 2/5
------------------------------
Train samples: 152
Validation samples: 38
Augmentation factors: {0: 1, 1: