In [None]:
# ============================================
# 📦 Step 1: Import Libraries
# ============================================

import warnings
warnings.filterwarnings('ignore')

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_auc_score, classification_report
import cv2
import random
from tqdm import tqdm
from pathlib import Path
import kagglehub

# Seed every random number generator the notebook relies on
def set_seed(seed=42):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (CPU + all CUDA devices).

    ``random`` drives the augmentation decisions in ChestXrayAugment, so it must
    be seeded alongside NumPy and torch for runs to be repeatable.

    Args:
        seed: Integer applied to every generator. Defaults to 42.
    """
    seeders = [random.seed, np.random.seed, torch.manual_seed]
    if torch.cuda.is_available():
        seeders.append(torch.cuda.manual_seed_all)
    for seed_fn in seeders:
        seed_fn(seed)


set_seed(42)
print("‚úÖ Libraries imported successfully")

In [None]:
# ============================================
# 📁 Step 2: Download and Load Dataset
# ============================================

# Download NIH Chest X-ray 14 dataset (pre-resized to 224x224)
path = kagglehub.dataset_download("khanfashee/nih-chest-x-ray-14-224x224-resized")
BASE_PATH = Path(path)
print(f"üìÇ Dataset path: {BASE_PATH}")

# Load labels and attach the on-disk path of every image
df_labels = pd.read_csv(BASE_PATH / "Data_Entry_2017.csv")
images_dir = BASE_PATH / "images-224" / "images-224"
df_labels["Image Path"] = [str(images_dir / p) for p in df_labels["Image Index"].values]

# Define disease categories
DISEASE_CATEGORIES = [
    'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration', 'Mass',
    'Nodule', 'Pneumonia', 'Pneumothorax', 'Consolidation', 'Edema',
    'Emphysema', 'Fibrosis', 'Pleural_Thickening', 'Hernia'
]

# Create binary columns for each disease.
# 'Finding Labels' stores all findings of an image joined by '|'
# (e.g. "Effusion|Mass"). Tokenising on '|' matches each category by its exact
# name, whereas a plain substring test (`disease in labels`) would also fire
# whenever one label name happened to be contained in another token.
finding_dummies = df_labels['Finding Labels'].str.get_dummies(sep='|')
absent = [d for d in DISEASE_CATEGORIES if d not in finding_dummies.columns]
if absent:
    # An all-zero label column would silently ruin training; fail early instead.
    raise ValueError(f"Disease categories never found in 'Finding Labels': {absent}")
for disease in DISEASE_CATEGORIES:
    df_labels[disease] = finding_dummies[disease].astype(int)

# Surface tokens outside DISEASE_CATEGORIES (presumably only the dataset's
# 'No Finding' marker) so an unexpected label format is noticed.
extra_labels = sorted(set(finding_dummies.columns) - set(DISEASE_CATEGORIES) - {'No Finding'})
if extra_labels:
    print(f"Note: finding labels not mapped to any category: {extra_labels}")

# Validate that a random sample of image paths exists (sample capped at dataset size)
sample_paths = df_labels['Image Path'].sample(min(200, len(df_labels)), random_state=42).values
missing = [p for p in sample_paths if not os.path.exists(p)]
if missing:
    raise FileNotFoundError(f"‚ùå Missing {len(missing)} images! First 3: {missing[:3]}")

print(f"‚úÖ Loaded {len(df_labels):,} images")
print(f"üìä Disease categories: {len(DISEASE_CATEGORIES)}")

In [None]:
# ============================================
# ⚙️ Step 3: Configuration
# ============================================

class Config:
    """Central hyper-parameter container; every value is a class attribute."""

    # --- Model ---------------------------------------------------------
    img_size: int = 224   # square input resolution
    feat_dim: int = 256   # encoder output dimension
    proj_dim: int = 128   # contrastive projection dimension

    # --- Training ------------------------------------------------------
    batch_size: int = 64
    pretrain_epochs: int = 50
    finetune_epochs: int = 30
    lr_pretrain: float = 1e-3
    lr_finetune: float = 1e-4
    temperature: float = 0.1  # NT-Xent temperature

    # --- Data ----------------------------------------------------------
    num_workers: int = 4
    use_subset: bool = False  # Set True for quick testing
    subset_size: int = 10000

    # --- Device --------------------------------------------------------
    device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


cfg = Config()

print("‚öôÔ∏è Configuration:")
for label, value in (
    ("Device", cfg.device),
    ("Batch size", cfg.batch_size),
    ("Pretrain epochs", cfg.pretrain_epochs),
    ("Finetune epochs", cfg.finetune_epochs),
):
    print(f"   {label}: {value}")

# GPU optimizations: let cuDNN autotune conv algorithms for the fixed input size.
# NOTE(review): autotuning can make runs non-bit-reproducible despite set_seed.
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
    print(f"   GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# ============================================
# 🔄 Step 4: Data Augmentation
# ============================================

class ChestXrayAugment:
    """Random perturbations that turn one chest X-ray into a contrastive view.

    Each call independently applies, in this fixed order:
      * horizontal flip along dim 2                          (p = 0.5)
      * rotation by an angle drawn uniformly from [-15, 15]  (p = 0.7)
      * brightness scaling by a factor in [0.8, 1.2]         (p = 0.8)
      * contrast scaling by a factor in [0.8, 1.2]           (p = 0.8)
      * additive Gaussian noise (std 0.05), clamped to [0,1] (p = 0.5)

    Which ops fire and their magnitudes come from Python's ``random`` module;
    the noise tensor comes from torch's RNG.
    """

    def __init__(self, img_size=224):
        # Kept for reference only; __call__ never resizes.
        self.img_size = img_size

    def __call__(self, img):
        """Return an augmented float copy of ``img`` (NumPy array or tensor).

        The flip acts on dim 2, so a channel-first (C, H, W) layout is assumed.
        """
        if isinstance(img, np.ndarray):
            view = torch.tensor(img, dtype=torch.float32)
        else:
            view = img.clone()

        if random.random() < 0.5:
            view = view.flip(2)

        if random.random() < 0.7:
            view = transforms.functional.rotate(view, random.uniform(-15, 15))

        if random.random() < 0.8:
            view = transforms.functional.adjust_brightness(view, 1 + random.uniform(-0.2, 0.2))

        if random.random() < 0.8:
            view = transforms.functional.adjust_contrast(view, 1 + random.uniform(-0.2, 0.2))

        if random.random() < 0.5:
            view = (view + torch.randn_like(view) * 0.05).clamp(0, 1)

        return view

augment = ChestXrayAugment(img_size=cfg.img_size)  # shared augmentation used for SSL views
print("‚úÖ Augmentation pipeline ready")

In [None]:
# ============================================
# 📦 Step 5: Dataset Classes
# ============================================

class PretrainDataset(Dataset):
    """Unlabelled dataset returning two views of each image for SSL pretraining.

    Args:
        df: DataFrame with an 'Image Path' column.
        transform: Optional callable applied independently to produce each view.
            It receives a float32 NumPy array of shape (1, img_size, img_size)
            with values in [0, 1]. When falsy, both views are the plain image
            as float32 tensors.
        img_size: Side length every image is resized to.
    """

    def __init__(self, df, transform=None, img_size=224):
        self.df = df.copy().reset_index(drop=True)
        # Cache the paths as a plain list: building a pandas row via .iloc on
        # every __getitem__ call is needlessly slow inside DataLoader workers.
        self.paths = self.df['Image Path'].tolist()
        self.transform = transform
        self.img_size = img_size
        print(f"üì¶ PretrainDataset: {len(self.df)} samples")

    def __len__(self):
        return len(self.paths)

    def _load(self, img_path):
        """Read ``img_path`` as grayscale, resize, scale to [0, 1]; returns (1, H, W) float32."""
        # Context manager guarantees the underlying file handle is released.
        with Image.open(img_path) as im:
            gray = im.convert('L').resize((self.img_size, self.img_size), Image.LANCZOS)
        img = np.asarray(gray, dtype=np.float32) / 255.0
        return np.expand_dims(img, 0)  # (1, H, W)

    def __getitem__(self, idx):
        img = self._load(self.paths[idx])

        if self.transform:
            view1 = self.transform(img)
            view2 = self.transform(img)
        else:
            view1 = torch.tensor(img, dtype=torch.float32)
            view2 = torch.tensor(img, dtype=torch.float32)

        return view1, view2


class ClassificationDataset(Dataset):
    """Labelled dataset for multi-label disease classification.

    Args:
        df: DataFrame with an 'Image Path' column plus one numeric 0/1 column
            per entry of ``disease_categories``.
        disease_categories: Ordered label column names; fixes the order of the
            returned label vector.
        img_size: Side length every image is resized to.

    Each item is ``(image, labels)``: a float32 tensor of shape
    (1, img_size, img_size) in [0, 1] and a float32 tensor of shape
    (len(disease_categories),).
    """

    def __init__(self, df, disease_categories, img_size=224):
        self.df = df.copy().reset_index(drop=True)
        self.disease_categories = disease_categories
        self.img_size = img_size
        # Pre-extract paths and the label matrix once instead of building a
        # pandas row Series on every __getitem__ call.
        self.paths = self.df['Image Path'].tolist()
        self.labels = self.df[list(disease_categories)].to_numpy(dtype=np.float32)
        print(f"üì¶ ClassificationDataset: {len(self.df)} samples")

    def __len__(self):
        return len(self.paths)

    def _load(self, img_path):
        """Read ``img_path`` as grayscale, resize, scale to [0, 1]; returns a (1, H, W) tensor."""
        # Context manager guarantees the underlying file handle is released.
        with Image.open(img_path) as im:
            gray = im.convert('L').resize((self.img_size, self.img_size), Image.LANCZOS)
        img = np.asarray(gray, dtype=np.float32) / 255.0
        return torch.tensor(img, dtype=torch.float32).unsqueeze(0)

    def __getitem__(self, idx):
        img = self._load(self.paths[idx])
        labels = torch.tensor(self.labels[idx], dtype=torch.float32)
        return img, labels

print("‚úÖ Dataset classes defined")

In [None]:
# ============================================
# 🏗️ Step 6: Model Architecture
# ============================================

def conv_block(in_c, out_c, kernel=3, stride=1, padding=1):
    """Conv2d -> BatchNorm2d -> ReLU.

    With the defaults (3x3 kernel, stride 1, padding 1) the spatial size is preserved.
    """
    return nn.Sequential(
        nn.Conv2d(in_c, out_c, kernel, stride, padding),
        nn.BatchNorm2d(out_c),
        nn.ReLU(inplace=True)
    )


class _ResidualSequential(nn.Sequential):
    """nn.Sequential whose output is ``x + body(x)`` (identity skip connection).

    Subclassing nn.Sequential keeps the child module names ('0', '1', ...), so
    state_dict keys are identical to those of a plain Sequential of the same layers.
    """

    def forward(self, x):
        return x + super().forward(x)


def residual_block(channels):
    """Two shape-preserving conv blocks wrapped in an identity skip connection.

    The skip path is what makes the block residual: the output is
    ``x + conv_block(conv_block(x))``. Input and output shapes match because
    both conv blocks keep ``channels`` and use the shape-preserving defaults.
    """
    return _ResidualSequential(
        conv_block(channels, channels),
        conv_block(channels, channels)
    )


class Encoder(nn.Module):
    """CNN encoder mapping (N, in_channels, H, W) images to (N, feat_dim) features.

    Layout (spatial sizes shown for a 224x224 input):
      stage 1: conv 64  + 1 residual_block, max-pool   224 -> 112
      stage 2: conv 128 + 1 residual_block, max-pool   112 -> 56
      stage 3: conv 256 + 2 residual_blocks, max-pool   56 -> 28
      stage 4: conv 512 + 2 residual_blocks, max-pool   28 -> 14
      stage 5: conv 512 + 1 residual_block, global average pool -> 1x1
    followed by an MLP head 512 -> 512 -> feat_dim (ReLU, dropout 0.3).
    """

    # (out_channels, number of residual blocks) for the four max-pooled stages
    _POOLED_STAGES = ((64, 1), (128, 1), (256, 2), (512, 2))

    def __init__(self, in_channels=1, feat_dim=256):
        super().__init__()
        layers = []
        width = in_channels
        for out_channels, n_residual in self._POOLED_STAGES:
            layers.append(conv_block(width, out_channels))
            layers.extend(residual_block(out_channels) for _ in range(n_residual))
            layers.append(nn.MaxPool2d(2))
            width = out_channels
        # Stage 5 keeps 512 channels and collapses the spatial grid to 1x1
        layers += [conv_block(width, 512), residual_block(512), nn.AdaptiveAvgPool2d((1, 1))]
        # Layers are appended in the same order as a hand-written Sequential, so
        # module indices (and therefore state_dict keys) are unchanged.
        self.features = nn.Sequential(*layers)

        self.fc = nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, feat_dim),
        )

    def forward(self, x):
        pooled = self.features(x)  # (N, 512, 1, 1)
        return self.fc(torch.flatten(pooled, 1))


class ProjectionHead(nn.Module):
    """MLP (Linear -> BatchNorm1d -> ReLU -> Linear) feeding the contrastive loss.

    Maps (N, feat_dim) encoder features to (N, proj_dim) projections.
    """

    def __init__(self, feat_dim=256, proj_dim=128):
        super().__init__()
        hidden = feat_dim  # hidden width equals the input width
        layers = [
            nn.Linear(feat_dim, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(),
            nn.Linear(hidden, proj_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        projected = self.net(x)
        return projected


class Decoder(nn.Module):
    """Reconstructs an (N, 1, img_size, img_size) image in [0, 1] from (N, feat_dim) features.

    A linear layer produces a (256, s, s) seed map with s = img_size // 32. Five
    stride-2 transposed convolutions then upsample it 32x. When img_size is not
    a multiple of 32 that yields s * 32 != img_size. In that case the output is
    bilinearly resized to img_size so it matches the encoder input; for
    multiples of 32 (e.g. 224) no resizing happens.

    Raises:
        ValueError: if img_size < 32 (the seed map would have zero size).
    """

    def __init__(self, feat_dim=256, img_size=224):
        super().__init__()
        if img_size < 32:
            raise ValueError(f"Decoder requires img_size >= 32, got {img_size}")
        self.img_size = img_size
        self.init_size = img_size // 32  # 7 for 224

        self.fc = nn.Sequential(
            nn.Linear(feat_dim, 256 * self.init_size * self.init_size),
            nn.ReLU()
        )

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, 2, 1),  # 7->14
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),   # 14->28
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),    # 28->56
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 4, 2, 1),    # 56->112
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, 4, 2, 1),     # 112->224
            nn.Sigmoid()
        )

    def forward(self, z):
        x = self.fc(z)
        x = x.view(z.size(0), 256, self.init_size, self.init_size)
        out = self.decoder(x)
        if out.shape[-1] != self.img_size:
            # Sigmoid output lies in [0, 1]; bilinear interpolation is a convex
            # combination of neighbours, so the result stays in [0, 1].
            out = F.interpolate(out, size=(self.img_size, self.img_size),
                                mode='bilinear', align_corners=False)
        return out


class Classifier(nn.Module):
    """MLP head producing raw multi-label logits (no sigmoid applied).

    Maps (N, feat_dim) features to (N, num_classes) logits through one hidden
    layer of width 256 with ReLU and dropout 0.3.
    """

    def __init__(self, feat_dim=256, num_classes=14):
        super().__init__()
        hidden_width = 256
        self.net = nn.Sequential(
            nn.Linear(feat_dim, hidden_width),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(hidden_width, num_classes),
        )

    def forward(self, x):
        logits = self.net(x)
        return logits


# Build the three SSL components on the configured device
encoder = Encoder(feat_dim=cfg.feat_dim).to(cfg.device)
proj_head = ProjectionHead(cfg.feat_dim, cfg.proj_dim).to(cfg.device)
decoder = Decoder(cfg.feat_dim, cfg.img_size).to(cfg.device)

total_params = sum(
    p.numel()
    for module in (encoder, proj_head, decoder)
    for p in module.parameters()
)

print(f"‚úÖ Models initialized")
print(f"   Total parameters: {total_params:,}")

In [None]:
# ============================================
# 🔥 Step 7: Loss Functions
# ============================================

def nt_xent_loss(z1, z2, temperature=0.1):
    """NT-Xent (normalized temperature-scaled cross-entropy) contrastive loss.

    Args:
        z1: (N, D) projections of the first view.
        z2: (N, D) projections of the second view; row i pairs with z1[i].
        temperature: Divisor applied to the cosine-similarity logits.

    Returns:
        Scalar tensor: mean cross-entropy over all 2N rows. Each row's positive
        is its paired view, and the other 2N - 2 rows (itself excluded) serve
        as negatives.
    """
    n = z1.shape[0]
    dev = z1.device

    # Unit-normalise each view so the dot products below are cosine similarities
    stacked = torch.cat([F.normalize(z1, dim=1), F.normalize(z2, dim=1)], dim=0)
    logits = (stacked @ stacked.T) / temperature

    # A row must never select itself: send the diagonal to -inf before softmax
    logits = logits.masked_fill(torch.eye(2 * n, dtype=torch.bool, device=dev), float('-inf'))

    # Positive for row i is row i + n (first half) or row i - n (second half)
    half = torch.arange(n, device=dev)
    targets = torch.cat([half + n, half])

    return F.cross_entropy(logits, targets)

print("‚úÖ Loss functions defined")

In [None]:
# ============================================
# 📊 Step 8: Create Data Loaders
# ============================================

# Train/Val split.
# NIH ChestX-ray14 contains several images (follow-up studies) per patient.
# Splitting by image lets one patient appear in both train and val, which
# inflates validation AUC. When a 'Patient ID' column exists, split by patient
# so that every patient lands in exactly one split. Otherwise fall back to a
# per-image split.
SPLIT_SEED = 42
TRAIN_FRACTION = 0.8
df_shuffled = df_labels.sample(frac=1, random_state=SPLIT_SEED).reset_index(drop=True)

if 'Patient ID' in df_shuffled.columns:
    # Uniformly permute the unique patients, then take the first 80% for training
    patient_order = pd.Series(df_shuffled['Patient ID'].unique()).sample(frac=1, random_state=SPLIT_SEED)
    train_patients = set(patient_order.iloc[:int(TRAIN_FRACTION * len(patient_order))])
    in_train = df_shuffled['Patient ID'].isin(train_patients)
    train_df = df_shuffled[in_train].reset_index(drop=True)
    val_df = df_shuffled[~in_train].reset_index(drop=True)
    assert train_patients.isdisjoint(val_df['Patient ID']), "patient leakage between splits"
    print(f"Patient-level split: {len(train_patients):,} train / "
          f"{len(patient_order) - len(train_patients):,} val patients")
else:
    train_size = int(TRAIN_FRACTION * len(df_shuffled))
    train_df = df_shuffled[:train_size]
    val_df = df_shuffled[train_size:]

if cfg.use_subset:
    train_df = train_df.head(cfg.subset_size)
    val_df = val_df.head(cfg.subset_size // 4)
    print(f"‚ö° Using subset: {len(train_df)} train, {len(val_df)} val")

# Datasets
train_pretrain_ds = PretrainDataset(train_df, transform=augment, img_size=cfg.img_size)
train_class_ds = ClassificationDataset(train_df, DISEASE_CATEGORIES, cfg.img_size)
val_class_ds = ClassificationDataset(val_df, DISEASE_CATEGORIES, cfg.img_size)

# DataLoaders (pinned host memory only helps, and is only meaningful, for CUDA transfers)
use_pin_memory = cfg.device.type == 'cuda'
pretrain_loader = DataLoader(
    train_pretrain_ds, batch_size=cfg.batch_size, shuffle=True,
    num_workers=cfg.num_workers, pin_memory=use_pin_memory, drop_last=True
)
train_loader = DataLoader(
    train_class_ds, batch_size=cfg.batch_size, shuffle=True,
    num_workers=cfg.num_workers, pin_memory=use_pin_memory, drop_last=True
)
val_loader = DataLoader(
    val_class_ds, batch_size=cfg.batch_size, shuffle=False,
    num_workers=cfg.num_workers, pin_memory=use_pin_memory
)

print(f"‚úÖ DataLoaders ready")
print(f"   Train batches: {len(pretrain_loader)} (pretrain), {len(train_loader)} (classify)")
print(f"   Val batches: {len(val_loader)}")

In [None]:
# ============================================
# 🚀 Step 9: SSL Pretraining
# ============================================

# Encoder, projection head and decoder are optimised jointly
ssl_modules = (encoder, proj_head, decoder)
optimizer = torch.optim.Adam(
    [param for module in ssl_modules for param in module.parameters()],
    lr=cfg.lr_pretrain, weight_decay=1e-4
)

ssl_history = {'loss': [], 'contrastive': [], 'reconstruction': []}

print("üöÄ Starting SSL Pretraining (Baseline)")
print("=" * 50)

for epoch in range(1, cfg.pretrain_epochs + 1):
    for module in ssl_modules:
        module.train()

    # Running sums of per-batch losses, keyed like ssl_history
    epoch_sums = dict.fromkeys(ssl_history, 0)

    progress = tqdm(pretrain_loader, desc=f"Epoch {epoch}/{cfg.pretrain_epochs}")
    for view_a, view_b in progress:
        view_a, view_b = view_a.to(cfg.device), view_b.to(cfg.device)

        optimizer.zero_grad()

        # Both objectives share the same encoder features
        feat_a = encoder(view_a)
        feat_b = encoder(view_b)

        # Contrastive term on the projected features
        cont_loss = nt_xent_loss(proj_head(feat_a), proj_head(feat_b), cfg.temperature)

        # Reconstruction term: each augmented view is decoded back to itself
        recon_a = decoder(feat_a)
        recon_b = decoder(feat_b)
        recon_loss = (F.mse_loss(recon_a, view_a) + F.mse_loss(recon_b, view_b)) / 2

        loss = cont_loss + 0.5 * recon_loss
        loss.backward()
        optimizer.step()

        batch_losses = {
            'loss': loss.item(),
            'contrastive': cont_loss.item(),
            'reconstruction': recon_loss.item(),
        }
        for key, value in batch_losses.items():
            epoch_sums[key] += value

        progress.set_postfix({'loss': f"{batch_losses['loss']:.4f}"})

    # Per-epoch means over batches
    n = len(pretrain_loader)
    for key, total in epoch_sums.items():
        ssl_history[key].append(total / n)

    print(f"Epoch {epoch}: Loss={epoch_sums['loss']/n:.4f}, Cont={epoch_sums['contrastive']/n:.4f}, Recon={epoch_sums['reconstruction']/n:.4f}")

print("\n‚úÖ SSL Pretraining Complete!")

In [None]:
# ============================================
# 📈 Step 10: Plot SSL Training Curves
# ============================================

# Training logs number epochs from 1; plot against the same 1-based epochs
# (plotting the bare list would place epoch 1 at x = 0).
ssl_epochs = range(1, len(ssl_history['loss']) + 1)
ssl_panels = [
    ('loss', 'Total Loss', 'b-'),
    ('contrastive', 'Contrastive Loss', 'r-'),
    ('reconstruction', 'Reconstruction Loss', 'g-'),
]

fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for ax, (key, title, style) in zip(axes, ssl_panels):
    ax.plot(ssl_epochs, ssl_history[key], style, linewidth=2)
    ax.set_title(title, fontsize=12)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Mean loss per batch')
    ax.grid(True, alpha=0.3)

plt.suptitle('Baseline SSL Training Curves', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('baseline_ssl_curves.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# ============================================
# 💾 Step 11: Save Pretrained Model
# ============================================

ssl_checkpoint = {
    'encoder': encoder.state_dict(),
    'proj_head': proj_head.state_dict(),
    'decoder': decoder.state_dict(),
    # Constructor hyper-parameters needed to rebuild the modules before loading
    'config': {key: getattr(cfg, key) for key in ('feat_dim', 'proj_dim', 'img_size')},
}
torch.save(ssl_checkpoint, 'baseline_ssl_pretrained.pth')

print("üíæ Pretrained model saved: baseline_ssl_pretrained.pth")

In [None]:
# ============================================
# 🎯 Step 12: Fine-tuning for Classification
# ============================================

def _mean_auc(targets, preds):
    """Mean per-column ROC-AUC for multi-label predictions.

    Columns whose targets contain only one class are skipped because ROC-AUC
    is undefined for them. Returns a plain Python float (NaN when no column
    qualifies), so the value can be stored in a checkpoint and compared safely.
    """
    per_class = [
        roc_auc_score(targets[:, i], preds[:, i])
        for i in range(targets.shape[1])
        if len(np.unique(targets[:, i])) > 1
    ]
    return float(np.mean(per_class)) if per_class else float('nan')


# Freeze encoder: only the classifier head is trained on top of fixed SSL features
for param in encoder.parameters():
    param.requires_grad = False
encoder.eval()  # also keeps the encoder's Dropout/BatchNorm layers in inference mode

# Initialize classifier
classifier = Classifier(cfg.feat_dim, len(DISEASE_CATEGORIES)).to(cfg.device)

# Class weights for imbalanced data: pos_weight = #negatives / #positives per disease
pos_counts = train_df[DISEASE_CATEGORIES].sum().values
neg_counts = len(train_df) - pos_counts
pos_weights = torch.tensor(neg_counts / (pos_counts + 1e-6), dtype=torch.float32).to(cfg.device)

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weights)
optimizer = torch.optim.Adam(classifier.parameters(), lr=cfg.lr_finetune, weight_decay=1e-4)
# Halve the LR when validation AUC (higher is better) has not improved for 5 epochs
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=5, factor=0.5)

finetune_history = {'train_loss': [], 'train_auc': [], 'val_loss': [], 'val_auc': []}
best_val_auc = 0

print("üéØ Starting Fine-tuning")
print("=" * 50)

for epoch in range(1, cfg.finetune_epochs + 1):
    # Training
    classifier.train()
    train_loss = 0
    train_preds, train_targets = [], []

    for images, targets in tqdm(train_loader, desc=f"Train {epoch}/{cfg.finetune_epochs}"):
        images = images.to(cfg.device)
        targets = targets.to(cfg.device)

        optimizer.zero_grad()
        with torch.no_grad():  # frozen encoder: no need to build its autograd graph
            features = encoder(images)
        logits = classifier(features)
        loss = criterion(logits, targets)

        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_preds.append(torch.sigmoid(logits).detach().cpu())
        train_targets.append(targets.cpu())

    # Validation
    classifier.eval()
    val_loss = 0
    val_preds, val_targets = [], []

    with torch.no_grad():
        for images, targets in val_loader:
            images = images.to(cfg.device)
            targets = targets.to(cfg.device)

            features = encoder(images)
            logits = classifier(features)
            loss = criterion(logits, targets)

            val_loss += loss.item()
            val_preds.append(torch.sigmoid(logits).cpu())
            val_targets.append(targets.cpu())

    # Calculate metrics (mean AUC over diseases with both classes present)
    train_auc = _mean_auc(torch.cat(train_targets).numpy(), torch.cat(train_preds).numpy())
    val_auc = _mean_auc(torch.cat(val_targets).numpy(), torch.cat(val_preds).numpy())

    # Log
    finetune_history['train_loss'].append(train_loss / len(train_loader))
    finetune_history['train_auc'].append(train_auc)
    finetune_history['val_loss'].append(val_loss / len(val_loader))
    finetune_history['val_auc'].append(val_auc)

    scheduler.step(val_auc)

    print(f"Epoch {epoch}: Train AUC={train_auc:.4f}, Val AUC={val_auc:.4f}")

    # Save best model
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        torch.save({
            'encoder': encoder.state_dict(),
            'classifier': classifier.state_dict(),
            # Plain Python float (not np.float64), so torch.load(weights_only=True),
            # the PyTorch default since 2.6, can read this checkpoint.
            'val_auc': val_auc
        }, 'baseline_best_model.pth')
        print(f"  ‚úÖ Best model saved! Val AUC: {val_auc:.4f}")

if best_val_auc == 0:
    print("Warning: validation AUC never exceeded 0; no checkpoint was saved in this run.")

print(f"\nüèÜ Best Validation AUC: {best_val_auc:.4f}")

In [None]:
# ============================================
# 📊 Step 13: Plot Fine-tuning Curves
# ============================================

# Epochs are logged from 1, so plot against 1-based x values to match the printed logs
ft_epochs = range(1, len(finetune_history['train_loss']) + 1)
ft_panels = [
    ('loss', 'Loss', 'Mean weighted BCE per batch'),
    ('auc', 'Mean AUC', 'Mean ROC-AUC over diseases'),
]

fig, axes = plt.subplots(1, 2, figsize=(12, 4))
for ax, (metric, title, ylabel) in zip(axes, ft_panels):
    ax.plot(ft_epochs, finetune_history[f'train_{metric}'], 'b-', label='Train', linewidth=2)
    ax.plot(ft_epochs, finetune_history[f'val_{metric}'], 'r-', label='Val', linewidth=2)
    ax.set_title(title, fontsize=12)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.suptitle('Baseline Fine-tuning Curves', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('baseline_finetune_curves.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# ============================================
# 📈 Step 14: Final Evaluation
# ============================================

# NOTE(review): this evaluates on val_loader, the same split Step 12 used to
# pick the best epoch, so the AUC reported below is optimistically biased.
# A separate held-out test split would give an unbiased estimate.

# Load best model (checkpoint written by Step 12 of this notebook).
# - map_location: a checkpoint saved from GPU tensors also loads on a CPU-only runtime.
# - weights_only=False: the checkpoint may store 'val_auc' as a NumPy scalar,
#   which PyTorch's weights-only unpickler (the default since 2.6) rejects.
#   Acceptable only because this file is produced locally; never pass
#   weights_only=False for files from untrusted sources.
checkpoint = torch.load('baseline_best_model.pth', map_location=cfg.device, weights_only=False)
encoder.load_state_dict(checkpoint['encoder'])
classifier.load_state_dict(checkpoint['classifier'])

encoder.eval()
classifier.eval()

all_preds, all_targets = [], []
with torch.no_grad():
    for images, targets in tqdm(val_loader, desc="Evaluating"):
        images = images.to(cfg.device)
        features = encoder(images)
        logits = classifier(features)
        all_preds.append(torch.sigmoid(logits).cpu())
        all_targets.append(targets)

all_preds = torch.cat(all_preds).numpy()
all_targets = torch.cat(all_targets).numpy()

# Per-disease AUC. Diseases whose validation labels are all one class are
# skipped, since ROC-AUC is undefined for them.
print("\nüìä Per-Disease AUC Scores:")
print("=" * 40)
auc_scores = []
for i, disease in enumerate(DISEASE_CATEGORIES):
    if len(np.unique(all_targets[:, i])) > 1:
        auc = roc_auc_score(all_targets[:, i], all_preds[:, i])
        auc_scores.append((disease, auc))
        print(f"{disease:20s}: {auc:.4f}")

if not auc_scores:
    raise ValueError("No disease has both positive and negative validation labels; AUC is undefined.")

mean_auc = np.mean([a for _, a in auc_scores])
print(f"\n{'Mean AUC':20s}: {mean_auc:.4f}")

# Plot AUC bar chart (best-performing disease first)
auc_scores.sort(key=lambda x: x[1], reverse=True)
diseases, aucs = zip(*auc_scores)

plt.figure(figsize=(12, 6))
colors = ['green' if a >= 0.7 else 'orange' if a >= 0.6 else 'red' for a in aucs]
plt.barh(diseases, aucs, color=colors, alpha=0.8)
plt.axvline(0.5, color='red', linestyle='--', alpha=0.5, label='Random')
plt.axvline(mean_auc, color='blue', linestyle='--', alpha=0.7, label=f'Mean: {mean_auc:.3f}')
plt.xlabel('AUC Score')
plt.title('Baseline: Per-Disease AUC Performance', fontsize=14, fontweight='bold')
plt.legend()
plt.tight_layout()
plt.savefig('baseline_auc_performance.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"\nüèÜ BASELINE RESULTS")
print(f"   Mean AUC: {mean_auc:.4f}")

In [None]:
# ============================================
# 📝 Summary
# ============================================

rule = "=" * 60
saved_files = [
    "baseline_ssl_pretrained.pth",
    "baseline_best_model.pth",
    "baseline_ssl_curves.png",
    "baseline_finetune_curves.png",
    "baseline_auc_performance.png",
]

print("\n" + rule)
print("üìù BASELINE SSL SUMMARY")
print(rule)
print(f"\nMethod: Standard SimCLR (NT-Xent + Reconstruction)")
print(f"Dataset: NIH Chest X-ray 14")
print(f"Training samples: {len(train_df):,}")
print(f"Validation samples: {len(val_df):,}")
print(f"\nPretraining epochs: {cfg.pretrain_epochs}")
print(f"Fine-tuning epochs: {cfg.finetune_epochs}")
print(f"\nüèÜ Final Mean AUC: {mean_auc:.4f}")
print("\nFiles saved:")
for filename in saved_files:
    print(f"  - {filename}")
print(rule)