In [1]:
# ============================================================
# ENHANCED SSL TRAINING WITH AGGRESSIVE OPTIMIZATION
# Implements: Focal Loss tuning, threshold optimization,
# architecture improvements, training strategies, and ensemble capability
# ============================================================

In [2]:
# ============================================================
# 1. Import Required Libraries, Modules, and Packages
# ============================================================
import os
import sys
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import torchvision.transforms as transforms
import numpy as np
import json
import random
from PIL import Image
from collections import Counter
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import timm
import warnings
warnings.filterwarnings('ignore')

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# ============================================================
# 2. Setup and Seeds
# ============================================================
if 'notebook' in os.getcwd():
    os.chdir('../')

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

In [4]:
# ============================================================
# 3. ENHANCED MODEL ARCHITECTURE with Attention
# ============================================================
class EnhancedSSLEfficientNet(nn.Module):
    def __init__(self, model_name='efficientnet_b0', num_classes=2, dropout_rate=0.4):
        super(EnhancedSSLEfficientNet, self).__init__()
        self.backbone = timm.create_model(model_name, pretrained=True, num_classes=0)
        self.feature_dim = self.backbone.num_features
        
        # Attention mechanism
        self.attention = nn.Sequential(
            nn.Linear(self.feature_dim, self.feature_dim // 4),
            nn.ReLU(inplace=True),
            nn.Linear(self.feature_dim // 4, self.feature_dim),
            nn.Sigmoid()
        )
        
        # Deeper classifier with batch normalization
        self.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(self.feature_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate * 0.6),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate * 0.4),
            nn.Linear(256, num_classes)
        )
        
        self.projector = nn.Sequential(
            nn.Linear(self.feature_dim, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 64)
        )
        
        self._initialize_weights()
    
    def _initialize_weights(self):
        for m in [self.attention, self.classifier, self.projector]:
            for layer in m:
                if isinstance(layer, nn.Linear):
                    nn.init.xavier_uniform_(layer.weight)
                    if layer.bias is not None:
                        nn.init.constant_(layer.bias, 0)
    
    def forward(self, x):
        features = self.backbone(x)
        attention_weights = self.attention(features)
        features = features * attention_weights
        return self.classifier(features)

In [5]:
# ============================================================
# 4. ENHANCED FOCAL LOSS (Recommendation #2)
# ============================================================
class AggressiveFocalLoss(nn.Module):
    """Enhanced focal loss with higher alpha and gamma for sensitivity"""
    def __init__(self, alpha=0.95, gamma=3.0):
        super(AggressiveFocalLoss, self).__init__()
        self.alpha = torch.tensor([1 - alpha, alpha])
        self.gamma = gamma

    def forward(self, inputs, targets):
        BCE_loss = F.cross_entropy(inputs, targets, reduction='none')
        targets = targets.type(torch.long)
        at = self.alpha.gather(0, targets.data.view(-1)).to(inputs.device)
        pt = torch.exp(-BCE_loss)
        F_loss = at * (1 - pt)**self.gamma * BCE_loss
        return F_loss.mean()

In [6]:
# ============================================================
# 5. AGGRESSIVE DATA AUGMENTATION (Recommendation #2)
# ============================================================
norm_params = {'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]}

# Very strong augmentation for positives
very_strong_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.6),
    transforms.RandomVerticalFlip(p=0.6),
    transforms.RandomRotation(degrees=45),
    transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.3, hue=0.15),
    transforms.RandomAffine(degrees=30, translate=(0.15, 0.15), scale=(0.7, 1.3)),
    transforms.RandomPerspective(distortion_scale=0.3, p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(**norm_params)
])

moderate_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(**norm_params)
])

weak_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(**norm_params)
])

val_transform = weak_transform

In [7]:
# ============================================================
# 6. Datasets
# ============================================================
class LabeledDataset(Dataset):
    def __init__(self, paths, labels, transform_map):
        self.paths, self.labels, self.tm = paths, labels, transform_map
    def __len__(self):
        return len(self.paths)
    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert('RGB')
        label = self.labels[idx]
        return self.tm[label](img), label

class UnlabeledDataset(Dataset):
    def __init__(self, paths, weak_tf, strong_tf):
        self.paths, self.weak_tf, self.strong_tf = paths, weak_tf, strong_tf
    def __len__(self):
        return len(self.paths)
    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert('RGB')
        return self.weak_tf(img), self.strong_tf(img)

class ValidationDataset(Dataset):
    def __init__(self, paths, labels, transform):
        self.paths, self.labels, self.tf = paths, labels, transform
    def __len__(self):
        return len(self.paths)
    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert('RGB')
        return self.tf(img), self.labels[idx]


In [8]:
# ============================================================
# 7. Data Loading
# ============================================================
def load_and_prepare_data():
    ARTIFACTS_DIR = Path("artifacts")
    with open(ARTIFACTS_DIR / "data_metadata.json", "r") as f:
        data_metadata = json.load(f)
    data_dir = Path(data_metadata["data_dir"])
    labeled_dir, unlabeled_dir = data_dir / "labeled", data_dir / "unlabeled"
    
    labeled_paths, labels = [], []
    for name, idx in {"Negative": 0, "Positive": 1}.items():
        class_dir = labeled_dir / name
        if class_dir.exists():
            for path in class_dir.glob("*.[jp][pn]g"):
                labeled_paths.append(str(path))
                labels.append(idx)
    
    unlabeled_paths = [str(p) for p in unlabeled_dir.glob("*.[jp][pn]g")] if unlabeled_dir.exists() else []
    print(f"Loaded {len(labeled_paths)} labeled ({Counter(labels)}) and {len(unlabeled_paths)} unlabeled images.")
    
    train_paths, val_paths, train_labels, val_labels = train_test_split(
        labeled_paths, labels, test_size=0.2, random_state=SEED, stratify=labels
    )
    
    # Use very strong augmentation for positives
    train_dataset = LabeledDataset(train_paths, train_labels, 
                                   transform_map={0: moderate_transform, 1: very_strong_transform})
    val_dataset = ValidationDataset(val_paths, val_labels, val_transform)
    unlabeled_dataset = UnlabeledDataset(unlabeled_paths, weak_transform, very_strong_transform)
    
    return train_dataset, val_dataset, unlabeled_dataset, train_labels

In [9]:
# ============================================================
# 8. ENHANCED TRAINER with Better Threshold Search (Recommendation #3)
# ============================================================
class EnhancedTrainer:
    def __init__(self, model, device='cuda', ssl_weight=0.5, confidence_thresh=0.97):
        self.model, self.device = model.to(device), device
        self.criterion = AggressiveFocalLoss(alpha=0.95, gamma=3.0)
        self.ssl_weight, self.confidence_thresh = ssl_weight, confidence_thresh
        
        # Lower learning rate with higher weight decay (Recommendation #4)
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=1e-5, weight_decay=1e-2
        )
        
        # Cosine annealing with warm restarts (Recommendation #4)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer, T_0=10, T_mult=2
        )
        
        self.history = {
            'train_loss': [], 'consistency_loss': [], 
            'sensitivity': [], 'specificity': [], 
            'false_negatives': [], 'false_positives': []
        }
        self.best_clinical_score = float('inf')
        self.best_metrics = {}

    def train_epoch(self, labeled_loader, unlabeled_loader, epoch, warmup_epochs):
        self.model.train()
        total_sup_loss, total_cons_loss = 0, 0
        unlabeled_iter = iter(unlabeled_loader)
        pbar = tqdm(labeled_loader, desc=f'Epoch {epoch+1}')
        
        for labeled_images, labels in pbar:
            labeled_images, labels = labeled_images.to(self.device), labels.to(self.device)
            logits_l = self.model(labeled_images)
            sup_loss = self.criterion(logits_l, labels)
            total_loss = sup_loss
            
            if epoch >= warmup_epochs:
                try:
                    images_u_w, images_u_s = next(unlabeled_iter)
                except StopIteration:
                    unlabeled_iter = iter(unlabeled_loader)
                    images_u_w, images_u_s = next(unlabeled_iter)
                images_u_w, images_u_s = images_u_w.to(self.device), images_u_s.to(self.device)
                
                with torch.no_grad():
                    logits_u_w = self.model(images_u_w)
                    max_probs, p_targets = torch.max(F.softmax(logits_u_w, dim=1), dim=1)
                    mask = max_probs.ge(self.confidence_thresh).float()
                
                logits_u_s = self.model(images_u_s)
                cons_loss = (F.cross_entropy(logits_u_s, p_targets, reduction='none') * mask).mean()
                total_loss += self.ssl_weight * cons_loss
                total_cons_loss += cons_loss.item()
            
            self.optimizer.zero_grad()
            total_loss.backward()
            self.optimizer.step()
            total_sup_loss += sup_loss.item()
            pbar.set_postfix({'sup_loss': f'{sup_loss.item():.3f}', 
                            'cons_loss': f'{total_cons_loss / (len(pbar) or 1):.3f}'})
        
        return total_sup_loss / len(labeled_loader), total_cons_loss / len(labeled_loader)

    def validate_and_find_threshold(self, val_loader, fn_target=2, fp_target=3):
        """Enhanced threshold search prioritizing sensitivity (Recommendation #3)"""
        self.model.eval()
        all_probs, all_labels = [], []
        with torch.no_grad():
            for images, labels in val_loader:
                logits = self.model(images.to(self.device))
                all_probs.extend(F.softmax(logits, dim=1)[:, 1].cpu().numpy())
                all_labels.extend(labels.numpy())
        all_probs, all_labels = np.array(all_probs), np.array(all_labels)
        
        # Finer granularity threshold search
        best_threshold = -1
        best_fp = float('inf')
        
        # First pass: Find thresholds meeting FN target
        fn_acceptable_thresholds = []
        for threshold in np.arange(0.01, 0.95, 0.005):  # Finer steps
            preds = (all_probs > threshold).astype(int)
            fn = np.sum((all_labels == 1) & (preds == 0))
            if fn <= fn_target:
                fn_acceptable_thresholds.append((threshold, fn))
        
        # Second pass: Among FN-acceptable, optimize FP
        if fn_acceptable_thresholds:
            for threshold, fn in fn_acceptable_thresholds:
                preds = (all_probs > threshold).astype(int)
                fp = np.sum((all_labels == 0) & (preds == 1))
                if fp <= fp_target and fp < best_fp:
                    best_fp = fp
                    best_threshold = threshold
        
        # If no threshold meets both, relax FP constraint
        if best_threshold == -1 and fn_acceptable_thresholds:
            for threshold, fn in fn_acceptable_thresholds:
                preds = (all_probs > threshold).astype(int)
                fp = np.sum((all_labels == 0) & (preds == 1))
                if fp < best_fp:
                    best_fp = fp
                    best_threshold = threshold
        
        # Fallback: prioritize FN minimization
        if best_threshold == -1:
            best_fn = float('inf')
            for threshold in np.arange(0.01, 0.95, 0.005):
                preds = (all_probs > threshold).astype(int)
                fn = np.sum((all_labels == 1) & (preds == 0))
                fp = np.sum((all_labels == 0) & (preds == 1))
                if fn < best_fn or (fn == best_fn and fp < best_fp):
                    best_fn, best_fp, best_threshold = fn, fp, threshold
        
        y_pred = (all_probs >= best_threshold).astype(int)
        tp = np.sum((all_labels == 1) & (y_pred == 1))
        fn_final = np.sum((all_labels == 1) & (y_pred == 0))
        tn = np.sum((all_labels == 0) & (y_pred == 0))
        fp_final = np.sum((all_labels == 0) & (y_pred == 1))
        
        metrics = {
            'sensitivity': tp/(tp+fn_final) if tp+fn_final>0 else 0,
            'specificity': tn/(tn+fp_final) if tn+fp_final>0 else 0,
            'false_negatives': fn_final,
            'false_positives': fp_final
        }
        return metrics, best_threshold

    def train(self, labeled_loader, unlabeled_loader, val_loader, epochs=60, warmup_epochs=5):
        print("\n" + "="*60)
        print("ENHANCED SSL TRAINING WITH AGGRESSIVE OPTIMIZATION")
        print("="*60)
        patience, max_patience = 0, 20  # Increased patience
        
        for epoch in range(epochs):
            train_loss, cons_loss = self.train_epoch(labeled_loader, unlabeled_loader, epoch, warmup_epochs)
            metrics, threshold = self.validate_and_find_threshold(val_loader, fn_target=2, fp_target=3)
            
            fn, fp = metrics['false_negatives'], metrics['false_positives']
            current_score = (fn * 5) + fp
            
            # Step scheduler
            self.scheduler.step()
            
            for key, value in metrics.items():
                self.history[key].append(value)
            self.history['train_loss'].append(train_loss)
            self.history['consistency_loss'].append(cons_loss)
            
            print(f"\nEpoch {epoch+1}/{epochs} | Sens: {metrics['sensitivity']:.3f}, "
                  f"Spec: {metrics['specificity']:.3f}, FN: {fn}, FP: {fp}, Thr: {threshold:.3f}")

            is_clinically_acceptable = fn <= 2 and fp <= 3
            is_best = False
            
            if is_clinically_acceptable:
                if not self.best_metrics.get('is_clinically_acceptable', False) or current_score < self.best_clinical_score:
                    is_best = True
            elif not self.best_metrics.get('is_clinically_acceptable', False) and current_score < self.best_clinical_score:
                is_best = True

            if is_best:
                self.best_clinical_score = current_score
                self.best_metrics = {**metrics, 'is_clinically_acceptable': is_clinically_acceptable}
                patience = 0
                checkpoint = {
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'metrics': metrics,
                    'threshold': threshold,
                    'history': self.history
                }
                torch.save(checkpoint, 'artifacts/enhanced_model.pth')
                print(f"  ✓ Saved best model (FN: {fn}, FP: {fp}, Score: {current_score})")
            else:
                patience += 1
            
            if (patience >= max_patience and epoch > 15) or (self.best_metrics.get('is_clinically_acceptable', False) and epoch > 15):
                print(f"\n{'✓ CLINICAL TARGET MET' if self.best_metrics.get('is_clinically_acceptable', False) else 'Early stopping'}!")
                break
        
        print("\n" + "="*60)
        print("TRAINING COMPLETED")
        print(f"Best Results: {self.best_metrics}")
        print("="*60)
        return self.history

In [None]:
# ============================================================
# 9. ENSEMBLE TRAINING
# ============================================================
def train_ensemble(n_models=3):
    """Train multiple models with different seeds for ensemble"""
    models = []
    trainers = []
    
    for i in range(n_models):
        print(f"\n{'='*60}")
        print(f"TRAINING ENSEMBLE MODEL {i+1}/{n_models}")
        print(f"{'='*60}\n")
        
        # Set different seed
        seed = SEED + i * 100
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        
        # Create fresh data loaders
        train_dataset, val_dataset, unlabeled_dataset, train_labels = load_and_prepare_data()
        
        class_counts = np.bincount(train_labels)
        class_weights = 1.0 / class_counts
        sample_weights = np.array([class_weights[label] for label in train_labels])
        sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
        
        labeled_loader = DataLoader(train_dataset, batch_size=16, sampler=sampler)
        val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
        unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=32, shuffle=True)
        
        # Create and train model
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = EnhancedSSLEfficientNet(num_classes=2, dropout_rate=0.4)
        
        # Load pretrained weights if available
        pretrained_path = Path("artifacts/ssl_model_best_010_78.9.pth")
        # if pretrained_path.exists():
        #     checkpoint = torch.load(pretrained_path, map_location=device, weights_only=False)
        #     model.load_state_dict(checkpoint['model_state_dict'], strict=False)
        if pretrained_path.exists():
            checkpoint = torch.load(pretrained_path, map_location=device, weights_only=False)

            # Drop mismatched classifier weights
            for key in list(checkpoint['model_state_dict'].keys()):
                if "classifier.1.weight" in key or "classifier.1.bias" in key:
                    del checkpoint['model_state_dict'][key]

            model.load_state_dict(checkpoint['model_state_dict'], strict=False)


        
        trainer = EnhancedTrainer(model, device=device, ssl_weight=0.5, confidence_thresh=0.97)
        trainer.train(labeled_loader, unlabeled_loader, val_loader, epochs=60, warmup_epochs=5)
        
        models.append(model)
        trainers.append(trainer)
        
        # Save individual model
        torch.save({
            'model_state_dict': model.state_dict(),
            'metrics': trainer.best_metrics
        }, f'artifacts/enhanced_model_ensemble_{i}.pth')
    
    return models, trainers

def ensemble_predict(models, val_loader, device):
    """Make ensemble predictions"""
    all_probs_list = []
    all_labels = None
    
    for model in models:
        model.eval()
        probs = []
        labels = []
        with torch.no_grad():
            for images, batch_labels in val_loader:
                logits = model(images.to(device))
                probs.extend(F.softmax(logits, dim=1)[:, 1].cpu().numpy())
                labels.extend(batch_labels.numpy())
        all_probs_list.append(np.array(probs))
        if all_labels is None:
            all_labels = np.array(labels)
    
    # Average probabilities
    ensemble_probs = np.mean(all_probs_list, axis=0)
    return ensemble_probs, all_labels

In [11]:
# ============================================================
# 10. Visualization
# ============================================================
def plot_training_history(history, filename='enhanced_training_history.png'):
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    fig.suptitle('Enhanced SSL Training Progress', fontsize=16, fontweight='bold')
    
    axes[0, 0].plot(history['train_loss'], label='Supervised Loss')
    axes[0, 0].plot(history['consistency_loss'], label='Consistency Loss', linestyle='--')
    axes[0, 0].set_title('Training Loss Components')
    axes[0, 0].legend()
    
    axes[0, 1].plot(history['sensitivity'], label='Sensitivity', color='g')
    axes[0, 1].plot(history['specificity'], label='Specificity', color='b')
    axes[0, 1].axhline(y=0.90, color='g', linestyle=':', alpha=0.5, label='Target Sens (0.90)')
    axes[0, 1].axhline(y=0.85, color='b', linestyle=':', alpha=0.5, label='Target Spec (0.85)')
    axes[0, 1].set_title('Sensitivity vs Specificity')
    axes[0, 1].legend()

    axes[1, 0].plot(history['false_negatives'], label='False Negatives', color='r', marker='o')
    axes[1, 0].axhline(y=2, color='r', linestyle=':', alpha=0.5, label='Target (≤2)')
    axes[1, 0].set_title('False Negatives per Epoch')
    axes[1, 0].legend()
    
    axes[1, 1].plot(history['false_positives'], label='False Positives', color='orange', marker='s')
    axes[1, 1].axhline(y=3, color='orange', linestyle=':', alpha=0.5, label='Target (≤3)')
    axes[1, 1].set_title('False Positives per Epoch')
    axes[1, 1].legend()
    
    for ax in axes.flat:
        ax.set_xlabel('Epoch')
        ax.grid(True, alpha=0.3)
    
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.savefig(f'artifacts/{filename}', dpi=150)
    plt.show()

In [12]:
# ============================================================
# 11. Main Execution
# ============================================================
def convert_numpy_to_python(obj):
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, dict):
        return {k: convert_numpy_to_python(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [convert_numpy_to_python(item) for item in obj]
    return obj

def main_single_model():
    """Train single enhanced model"""
    print("="*60)
    print("ENHANCED SSL TRAINING - SINGLE MODEL")
    print("="*60)
    
    ARTIFACTS_DIR = Path("artifacts")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    train_dataset, val_dataset, unlabeled_dataset, train_labels = load_and_prepare_data()
    
    class_counts = np.bincount(train_labels)
    class_weights = 1.0 / class_counts
    sample_weights = np.array([class_weights[label] for label in train_labels])
    sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
    
    labeled_loader = DataLoader(train_dataset, batch_size=16, sampler=sampler)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
    unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=32, shuffle=True)

    print(f"\nData loader summary: Labeled Batches: {len(labeled_loader)}, Unlabeled Batches: {len(unlabeled_loader)}")
    
    model = EnhancedSSLEfficientNet(num_classes=2, dropout_rate=0.4)
    print(f"\nModel summary: Total trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

    pretrained_path = ARTIFACTS_DIR / "ssl_model_best_010_78.9.pth"
    if pretrained_path.exists():
        print(f"Loading pretrained weights from: {pretrained_path}")
        checkpoint = torch.load(pretrained_path, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint['model_state_dict'], strict=False)
        print("✓ Pretrained weights loaded")
    
    trainer = EnhancedTrainer(model, device=device, ssl_weight=0.5, confidence_thresh=0.97)
    history = trainer.train(labeled_loader, unlabeled_loader, val_loader, epochs=60, warmup_epochs=5)
    
    plot_training_history(history)
    
    history_serializable = convert_numpy_to_python(history)
    with open(ARTIFACTS_DIR / "enhanced_training_history.json", "w") as f:
        json.dump(history_serializable, f, indent=2)
    
    return trainer, history

def main_ensemble():
    """Train ensemble of models"""
    models, trainers = train_ensemble(n_models=3)
    
    # Evaluate ensemble
    _, val_dataset, _, _ = load_and_prepare_data()
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    ensemble_probs, all_labels = ensemble_predict(models, val_loader, device)
    
    # Find best threshold for ensemble
    best_threshold = -1
    best_fn = float('inf')
    best_fp = float('inf')
    
    for threshold in np.arange(0.01, 0.95, 0.005):
        preds = (ensemble_probs > threshold).astype(int)
        fn = np.sum((all_labels == 1) & (preds == 0))
        fp = np.sum((all_labels == 0) & (preds == 1))
        if fn <= 2 and fp <= 3:
            best_threshold = threshold
            best_fn, best_fp = fn, fp
            break
        if fn < best_fn or (fn == best_fn and fp < best_fp):
            best_fn, best_fp, best_threshold = fn, fp, threshold
    
    y_pred = (ensemble_probs >= best_threshold).astype(int)
    tp = np.sum((all_labels == 1) & (y_pred == 1))
    tn = np.sum((all_labels == 0) & (y_pred == 0))
    
    ensemble_metrics = {
        'sensitivity': tp/(tp+best_fn) if tp+best_fn>0 else 0,
        'specificity': tn/(tn+best_fp) if tn+best_fp>0 else 0,
        'false_negatives': best_fn,
        'false_positives': best_fp,
        'threshold': best_threshold
    }
    
    print("\n" + "="*60)
    print("ENSEMBLE RESULTS")
    print("="*60)
    print(f"Sensitivity: {ensemble_metrics['sensitivity']:.3f}")
    print(f"Specificity: {ensemble_metrics['specificity']:.3f}")
    print(f"False Negatives: {ensemble_metrics['false_negatives']}")
    print(f"False Positives: {ensemble_metrics['false_positives']}")
    print(f"Threshold: {ensemble_metrics['threshold']:.3f}")
    
    return models, trainers, ensemble_metrics

if __name__ == "__main__":
    # Choose mode
    USE_ENSEMBLE = True  # Set to False for single model training
    
    if USE_ENSEMBLE:
        models, trainers, ensemble_metrics = main_ensemble()
    else:
        trainer, history = main_single_model()
        
        print("\n" + "="*60)
        print("FINAL PERFORMANCE SUMMARY")
        print("="*60)
        
        if trainer.best_metrics:
            best_metrics = trainer.best_metrics
            print(f"Best Model Performance:")
            print(f"  - False Negatives: {best_metrics['false_negatives']}")
            print(f"  - False Positives: {best_metrics['false_positives']}")
            print(f"  - Sensitivity: {best_metrics['sensitivity']:.3f}")
            print(f"  - Specificity: {best_metrics['specificity']:.3f}")
            print(f"  - Threshold: {best_metrics['threshold']:.3f}")
            print(f"  - Best Epoch: {best_metrics['epoch']}")


TRAINING ENSEMBLE MODEL 1/3

Loaded 190 labeled (Counter({1: 98, 0: 92})) and 6377 unlabeled images.

ENHANCED SSL TRAINING WITH AGGRESSIVE OPTIMIZATION


Epoch 1: 100%|██████████| 10/10 [00:48<00:00,  4.89s/it, sup_loss=0.445, cons_loss=0.000]



Epoch 1/60 | Sens: 0.950, Spec: 0.278, FN: 1, FP: 13, Thr: 0.415
  ✓ Saved best model (FN: 1, FP: 13, Score: 18)


Epoch 2: 100%|██████████| 10/10 [00:40<00:00,  4.05s/it, sup_loss=0.189, cons_loss=0.000]



Epoch 2/60 | Sens: 0.900, Spec: 0.389, FN: 2, FP: 11, Thr: 0.380


Epoch 3: 100%|██████████| 10/10 [00:41<00:00,  4.14s/it, sup_loss=0.138, cons_loss=0.000]



Epoch 3/60 | Sens: 0.900, Spec: 0.444, FN: 2, FP: 10, Thr: 0.310


Epoch 4: 100%|██████████| 10/10 [00:43<00:00,  4.36s/it, sup_loss=0.095, cons_loss=0.000]



Epoch 4/60 | Sens: 0.950, Spec: 0.389, FN: 1, FP: 11, Thr: 0.215
  ✓ Saved best model (FN: 1, FP: 11, Score: 16)


Epoch 5: 100%|██████████| 10/10 [00:45<00:00,  4.54s/it, sup_loss=0.137, cons_loss=0.000]



Epoch 5/60 | Sens: 0.900, Spec: 0.389, FN: 2, FP: 11, Thr: 0.250


Epoch 6: 100%|██████████| 10/10 [02:26<00:00, 14.61s/it, sup_loss=0.445, cons_loss=0.019]



Epoch 6/60 | Sens: 0.900, Spec: 0.111, FN: 2, FP: 16, Thr: 0.105


Epoch 7: 100%|██████████| 10/10 [02:21<00:00, 14.20s/it, sup_loss=0.443, cons_loss=0.040]



Epoch 7/60 | Sens: 0.950, Spec: 0.056, FN: 1, FP: 17, Thr: 0.095


Epoch 8: 100%|██████████| 10/10 [02:18<00:00, 13.89s/it, sup_loss=0.008, cons_loss=0.018]



Epoch 8/60 | Sens: 0.900, Spec: 0.056, FN: 2, FP: 17, Thr: 0.075


Epoch 9: 100%|██████████| 10/10 [02:20<00:00, 14.08s/it, sup_loss=0.120, cons_loss=0.009]



Epoch 9/60 | Sens: 0.900, Spec: 0.056, FN: 2, FP: 17, Thr: 0.085


Epoch 10: 100%|██████████| 10/10 [02:17<00:00, 13.73s/it, sup_loss=0.148, cons_loss=0.018]



Epoch 10/60 | Sens: 0.950, Spec: 0.111, FN: 1, FP: 16, Thr: 0.100


Epoch 11: 100%|██████████| 10/10 [02:17<00:00, 13.72s/it, sup_loss=0.081, cons_loss=0.044]



Epoch 11/60 | Sens: 0.900, Spec: 0.056, FN: 2, FP: 17, Thr: 0.080


Epoch 12: 100%|██████████| 10/10 [03:58<00:00, 23.89s/it, sup_loss=0.010, cons_loss=0.021]



Epoch 12/60 | Sens: 0.950, Spec: 0.056, FN: 1, FP: 17, Thr: 0.050


Epoch 13: 100%|██████████| 10/10 [03:54<00:00, 23.41s/it, sup_loss=0.062, cons_loss=0.018]



Epoch 13/60 | Sens: 0.900, Spec: 0.167, FN: 2, FP: 15, Thr: 0.070


Epoch 14: 100%|██████████| 10/10 [03:15<00:00, 19.52s/it, sup_loss=0.695, cons_loss=0.019]



Epoch 14/60 | Sens: 1.000, Spec: 0.000, FN: 0, FP: 18, Thr: 0.010


Epoch 15: 100%|██████████| 10/10 [03:40<00:00, 22.03s/it, sup_loss=0.105, cons_loss=0.035]



Epoch 15/60 | Sens: 0.950, Spec: 0.056, FN: 1, FP: 17, Thr: 0.035


Epoch 16: 100%|██████████| 10/10 [02:53<00:00, 17.39s/it, sup_loss=0.031, cons_loss=0.019]



Epoch 16/60 | Sens: 0.900, Spec: 0.222, FN: 2, FP: 14, Thr: 0.050


Epoch 17: 100%|██████████| 10/10 [01:57<00:00, 11.79s/it, sup_loss=0.174, cons_loss=0.019]



Epoch 17/60 | Sens: 0.950, Spec: 0.111, FN: 1, FP: 16, Thr: 0.030


Epoch 18: 100%|██████████| 10/10 [01:55<00:00, 11.51s/it, sup_loss=0.039, cons_loss=0.043]



Epoch 18/60 | Sens: 0.950, Spec: 0.000, FN: 1, FP: 18, Thr: 0.010


Epoch 19: 100%|██████████| 10/10 [02:03<00:00, 12.33s/it, sup_loss=0.002, cons_loss=0.029]



Epoch 19/60 | Sens: 0.900, Spec: 0.167, FN: 2, FP: 15, Thr: 0.045


Epoch 20: 100%|██████████| 10/10 [02:06<00:00, 12.63s/it, sup_loss=0.135, cons_loss=0.041]



Epoch 20/60 | Sens: 0.900, Spec: 0.278, FN: 2, FP: 13, Thr: 0.045


Epoch 21: 100%|██████████| 10/10 [02:02<00:00, 12.22s/it, sup_loss=0.094, cons_loss=0.029]



Epoch 21/60 | Sens: 0.950, Spec: 0.056, FN: 1, FP: 17, Thr: 0.020


Epoch 22: 100%|██████████| 10/10 [01:59<00:00, 11.98s/it, sup_loss=0.043, cons_loss=0.024]



Epoch 22/60 | Sens: 0.900, Spec: 0.222, FN: 2, FP: 14, Thr: 0.040


Epoch 23: 100%|██████████| 10/10 [02:10<00:00, 13.09s/it, sup_loss=0.056, cons_loss=0.036]



Epoch 23/60 | Sens: 0.900, Spec: 0.111, FN: 2, FP: 16, Thr: 0.035


Epoch 24: 100%|██████████| 10/10 [02:09<00:00, 12.91s/it, sup_loss=0.030, cons_loss=0.043]



Epoch 24/60 | Sens: 0.950, Spec: 0.056, FN: 1, FP: 17, Thr: 0.020

Early stopping!

TRAINING COMPLETED
Best Results: {'sensitivity': np.float64(0.95), 'specificity': np.float64(0.3888888888888889), 'false_negatives': np.int64(1), 'false_positives': np.int64(11), 'is_clinically_acceptable': np.False_}

TRAINING ENSEMBLE MODEL 2/3

Loaded 190 labeled (Counter({1: 98, 0: 92})) and 6377 unlabeled images.

ENHANCED SSL TRAINING WITH AGGRESSIVE OPTIMIZATION


Epoch 1: 100%|██████████| 10/10 [00:38<00:00,  3.83s/it, sup_loss=0.190, cons_loss=0.000]



Epoch 1/60 | Sens: 1.000, Spec: 0.389, FN: 0, FP: 11, Thr: 0.395
  ✓ Saved best model (FN: 0, FP: 11, Score: 11)


Epoch 2: 100%|██████████| 10/10 [00:36<00:00,  3.61s/it, sup_loss=0.102, cons_loss=0.000]



Epoch 2/60 | Sens: 0.900, Spec: 0.611, FN: 2, FP: 7, Thr: 0.385


Epoch 3: 100%|██████████| 10/10 [00:36<00:00,  3.65s/it, sup_loss=0.286, cons_loss=0.000]



Epoch 3/60 | Sens: 0.900, Spec: 0.556, FN: 2, FP: 8, Thr: 0.355


Epoch 4: 100%|██████████| 10/10 [00:43<00:00,  4.39s/it, sup_loss=0.184, cons_loss=0.000]



Epoch 4/60 | Sens: 1.000, Spec: 0.333, FN: 0, FP: 12, Thr: 0.200


Epoch 5: 100%|██████████| 10/10 [00:30<00:00,  3.05s/it, sup_loss=0.177, cons_loss=0.000]



Epoch 5/60 | Sens: 1.000, Spec: 0.556, FN: 0, FP: 8, Thr: 0.270
  ✓ Saved best model (FN: 0, FP: 8, Score: 8)


Epoch 6: 100%|██████████| 10/10 [01:46<00:00, 10.61s/it, sup_loss=0.195, cons_loss=0.024]



Epoch 6/60 | Sens: 0.900, Spec: 0.278, FN: 2, FP: 13, Thr: 0.215


Epoch 7: 100%|██████████| 10/10 [01:46<00:00, 10.70s/it, sup_loss=0.087, cons_loss=0.002]



Epoch 7/60 | Sens: 0.900, Spec: 0.333, FN: 2, FP: 12, Thr: 0.250


Epoch 8: 100%|██████████| 10/10 [01:43<00:00, 10.36s/it, sup_loss=0.473, cons_loss=0.005]



Epoch 8/60 | Sens: 0.900, Spec: 0.333, FN: 2, FP: 12, Thr: 0.245


Epoch 9: 100%|██████████| 10/10 [01:48<00:00, 10.87s/it, sup_loss=0.087, cons_loss=0.010]



Epoch 9/60 | Sens: 0.950, Spec: 0.278, FN: 1, FP: 13, Thr: 0.185


Epoch 10: 100%|██████████| 10/10 [02:02<00:00, 12.29s/it, sup_loss=0.178, cons_loss=0.013]



Epoch 10/60 | Sens: 0.950, Spec: 0.333, FN: 1, FP: 12, Thr: 0.270


Epoch 11: 100%|██████████| 10/10 [02:07<00:00, 12.70s/it, sup_loss=0.039, cons_loss=0.007]



Epoch 11/60 | Sens: 0.900, Spec: 0.333, FN: 2, FP: 12, Thr: 0.265


Epoch 12: 100%|██████████| 10/10 [02:33<00:00, 15.39s/it, sup_loss=0.039, cons_loss=0.004]



Epoch 12/60 | Sens: 0.900, Spec: 0.278, FN: 2, FP: 13, Thr: 0.175


Epoch 13: 100%|██████████| 10/10 [01:57<00:00, 11.77s/it, sup_loss=0.236, cons_loss=0.008]



Epoch 13/60 | Sens: 0.900, Spec: 0.333, FN: 2, FP: 12, Thr: 0.235


Epoch 14: 100%|██████████| 10/10 [01:59<00:00, 11.98s/it, sup_loss=0.018, cons_loss=0.013]



Epoch 14/60 | Sens: 0.950, Spec: 0.278, FN: 1, FP: 13, Thr: 0.150


Epoch 15: 100%|██████████| 10/10 [01:55<00:00, 11.53s/it, sup_loss=0.039, cons_loss=0.010]



Epoch 15/60 | Sens: 0.950, Spec: 0.222, FN: 1, FP: 14, Thr: 0.135


Epoch 16: 100%|██████████| 10/10 [01:59<00:00, 11.99s/it, sup_loss=0.276, cons_loss=0.013]



Epoch 16/60 | Sens: 0.950, Spec: 0.278, FN: 1, FP: 13, Thr: 0.135


Epoch 17: 100%|██████████| 10/10 [01:55<00:00, 11.53s/it, sup_loss=0.255, cons_loss=0.009]



Epoch 17/60 | Sens: 0.950, Spec: 0.333, FN: 1, FP: 12, Thr: 0.145


Epoch 18: 100%|██████████| 10/10 [01:56<00:00, 11.68s/it, sup_loss=0.279, cons_loss=0.012]



Epoch 18/60 | Sens: 0.900, Spec: 0.278, FN: 2, FP: 13, Thr: 0.135


Epoch 19: 100%|██████████| 10/10 [01:56<00:00, 11.64s/it, sup_loss=0.635, cons_loss=0.008]



Epoch 19/60 | Sens: 0.900, Spec: 0.222, FN: 2, FP: 14, Thr: 0.080


Epoch 20: 100%|██████████| 10/10 [01:55<00:00, 11.55s/it, sup_loss=0.465, cons_loss=0.007]



Epoch 20/60 | Sens: 0.950, Spec: 0.278, FN: 1, FP: 13, Thr: 0.110


Epoch 21: 100%|██████████| 10/10 [02:04<00:00, 12.43s/it, sup_loss=0.052, cons_loss=0.014]



Epoch 21/60 | Sens: 0.900, Spec: 0.500, FN: 2, FP: 9, Thr: 0.135


Epoch 22: 100%|██████████| 10/10 [01:55<00:00, 11.56s/it, sup_loss=0.146, cons_loss=0.023]



Epoch 22/60 | Sens: 0.950, Spec: 0.444, FN: 1, FP: 10, Thr: 0.145


Epoch 23: 100%|██████████| 10/10 [01:57<00:00, 11.78s/it, sup_loss=0.128, cons_loss=0.014]



Epoch 23/60 | Sens: 0.900, Spec: 0.389, FN: 2, FP: 11, Thr: 0.125


Epoch 24: 100%|██████████| 10/10 [01:53<00:00, 11.37s/it, sup_loss=0.012, cons_loss=0.005]



Epoch 24/60 | Sens: 0.900, Spec: 0.500, FN: 2, FP: 9, Thr: 0.125


Epoch 25: 100%|██████████| 10/10 [01:55<00:00, 11.54s/it, sup_loss=0.090, cons_loss=0.006]



Epoch 25/60 | Sens: 0.950, Spec: 0.389, FN: 1, FP: 11, Thr: 0.110

Early stopping!

TRAINING COMPLETED
Best Results: {'sensitivity': np.float64(1.0), 'specificity': np.float64(0.5555555555555556), 'false_negatives': np.int64(0), 'false_positives': np.int64(8), 'is_clinically_acceptable': np.False_}

TRAINING ENSEMBLE MODEL 3/3

Loaded 190 labeled (Counter({1: 98, 0: 92})) and 6377 unlabeled images.

ENHANCED SSL TRAINING WITH AGGRESSIVE OPTIMIZATION


Epoch 1: 100%|██████████| 10/10 [00:34<00:00,  3.47s/it, sup_loss=0.054, cons_loss=0.000]



Epoch 1/60 | Sens: 0.900, Spec: 0.111, FN: 2, FP: 16, Thr: 0.420
  ✓ Saved best model (FN: 2, FP: 16, Score: 26)


Epoch 2: 100%|██████████| 10/10 [00:37<00:00,  3.76s/it, sup_loss=0.180, cons_loss=0.000]



Epoch 2/60 | Sens: 1.000, Spec: 0.000, FN: 0, FP: 18, Thr: 0.010
  ✓ Saved best model (FN: 0, FP: 18, Score: 18)


Epoch 3: 100%|██████████| 10/10 [00:36<00:00,  3.69s/it, sup_loss=0.080, cons_loss=0.000]



Epoch 3/60 | Sens: 0.900, Spec: 0.278, FN: 2, FP: 13, Thr: 0.350


Epoch 4: 100%|██████████| 10/10 [00:35<00:00,  3.56s/it, sup_loss=0.027, cons_loss=0.000]



Epoch 4/60 | Sens: 0.950, Spec: 0.056, FN: 1, FP: 17, Thr: 0.175


Epoch 5: 100%|██████████| 10/10 [00:33<00:00,  3.35s/it, sup_loss=0.051, cons_loss=0.000]



Epoch 5/60 | Sens: 0.900, Spec: 0.167, FN: 2, FP: 15, Thr: 0.270


Epoch 6: 100%|██████████| 10/10 [01:55<00:00, 11.52s/it, sup_loss=0.023, cons_loss=0.021]



Epoch 6/60 | Sens: 0.900, Spec: 0.278, FN: 2, FP: 13, Thr: 0.255


Epoch 7: 100%|██████████| 10/10 [02:00<00:00, 12.02s/it, sup_loss=0.063, cons_loss=0.015]



Epoch 7/60 | Sens: 0.950, Spec: 0.222, FN: 1, FP: 14, Thr: 0.180


Epoch 8: 100%|██████████| 10/10 [01:55<00:00, 11.59s/it, sup_loss=0.008, cons_loss=0.012]



Epoch 8/60 | Sens: 0.950, Spec: 0.222, FN: 1, FP: 14, Thr: 0.165


Epoch 9: 100%|██████████| 10/10 [01:53<00:00, 11.40s/it, sup_loss=0.252, cons_loss=0.035]



Epoch 9/60 | Sens: 0.950, Spec: 0.278, FN: 1, FP: 13, Thr: 0.240


Epoch 10: 100%|██████████| 10/10 [01:55<00:00, 11.52s/it, sup_loss=0.026, cons_loss=0.007]



Epoch 10/60 | Sens: 0.950, Spec: 0.278, FN: 1, FP: 13, Thr: 0.240


Epoch 11: 100%|██████████| 10/10 [01:54<00:00, 11.41s/it, sup_loss=0.181, cons_loss=0.011]



Epoch 11/60 | Sens: 0.950, Spec: 0.278, FN: 1, FP: 13, Thr: 0.235


Epoch 12: 100%|██████████| 10/10 [01:56<00:00, 11.66s/it, sup_loss=0.053, cons_loss=0.016]



Epoch 12/60 | Sens: 0.950, Spec: 0.278, FN: 1, FP: 13, Thr: 0.190


Epoch 13: 100%|██████████| 10/10 [01:54<00:00, 11.41s/it, sup_loss=0.003, cons_loss=0.022]



Epoch 13/60 | Sens: 0.900, Spec: 0.333, FN: 2, FP: 12, Thr: 0.275


Epoch 14: 100%|██████████| 10/10 [02:08<00:00, 12.85s/it, sup_loss=0.040, cons_loss=0.008]



Epoch 14/60 | Sens: 0.950, Spec: 0.278, FN: 1, FP: 13, Thr: 0.190


Epoch 15: 100%|██████████| 10/10 [02:12<00:00, 13.23s/it, sup_loss=0.203, cons_loss=0.019]



Epoch 15/60 | Sens: 0.950, Spec: 0.278, FN: 1, FP: 13, Thr: 0.165


Epoch 16: 100%|██████████| 10/10 [02:21<00:00, 14.11s/it, sup_loss=0.017, cons_loss=0.015]



Epoch 16/60 | Sens: 0.950, Spec: 0.278, FN: 1, FP: 13, Thr: 0.145


Epoch 17: 100%|██████████| 10/10 [02:28<00:00, 14.84s/it, sup_loss=0.043, cons_loss=0.008]



Epoch 17/60 | Sens: 0.900, Spec: 0.333, FN: 2, FP: 12, Thr: 0.185


Epoch 18: 100%|██████████| 10/10 [02:09<00:00, 12.95s/it, sup_loss=0.020, cons_loss=0.011]



Epoch 18/60 | Sens: 0.900, Spec: 0.278, FN: 2, FP: 13, Thr: 0.170


Epoch 19: 100%|██████████| 10/10 [01:58<00:00, 11.81s/it, sup_loss=0.085, cons_loss=0.012]



Epoch 19/60 | Sens: 0.900, Spec: 0.278, FN: 2, FP: 13, Thr: 0.125


Epoch 20: 100%|██████████| 10/10 [01:57<00:00, 11.78s/it, sup_loss=0.082, cons_loss=0.025]



Epoch 20/60 | Sens: 0.900, Spec: 0.389, FN: 2, FP: 11, Thr: 0.285


Epoch 21: 100%|██████████| 10/10 [01:56<00:00, 11.67s/it, sup_loss=0.068, cons_loss=0.017]



Epoch 21/60 | Sens: 0.900, Spec: 0.278, FN: 2, FP: 13, Thr: 0.145


Epoch 22: 100%|██████████| 10/10 [01:54<00:00, 11.43s/it, sup_loss=0.013, cons_loss=0.005]



Epoch 22/60 | Sens: 0.950, Spec: 0.278, FN: 1, FP: 13, Thr: 0.145

Early stopping!

TRAINING COMPLETED
Best Results: {'sensitivity': np.float64(1.0), 'specificity': np.float64(0.0), 'false_negatives': np.int64(0), 'false_positives': np.int64(18), 'is_clinically_acceptable': np.False_}
Loaded 190 labeled (Counter({1: 98, 0: 92})) and 6377 unlabeled images.

ENSEMBLE RESULTS
Sensitivity: 1.000
Specificity: 0.000
False Negatives: 0
False Positives: 18
Threshold: 0.010
