In [None]:
# ============================================
# 📦 Step 1: Import Libraries
# ============================================

import warnings
warnings.filterwarnings('ignore')

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_auc_score, classification_report
import cv2
import random
from tqdm import tqdm
from pathlib import Path
import kagglehub

# Set random seeds for reproducibility
def set_seed(seed=42):
    """Seed every RNG the notebook draws from.

    Seeds Python's `random`, NumPy's global RNG, torch's CPU RNG and, when CUDA
    is available, the RNG of every visible GPU.
    """
    seeders = [random.seed, np.random.seed, torch.manual_seed]
    if torch.cuda.is_available():
        seeders.append(torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)

set_seed(42)
print("‚úÖ Libraries imported successfully")

In [None]:
# ============================================
# 📁 Step 2: Download and Load Dataset
# ============================================

# Download NIH Chest X-ray 14 dataset (pre-resized to 224x224)
path = kagglehub.dataset_download("khanfashee/nih-chest-x-ray-14-224x224-resized")
BASE_PATH = Path(path)
print(f"üìÇ Dataset path: {BASE_PATH}")

# Load labels
df_labels = pd.read_csv(BASE_PATH / "Data_Entry_2017.csv")
images_dir = BASE_PATH / "images-224" / "images-224"
df_labels["Image Path"] = [str(images_dir / p) for p in df_labels["Image Index"].values]

# Define disease categories
DISEASE_CATEGORIES = [
    'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration', 'Mass',
    'Nodule', 'Pneumonia', 'Pneumothorax', 'Consolidation', 'Edema',
    'Emphysema', 'Fibrosis', 'Pleural_Thickening', 'Hernia'
]

# Create binary (0/1, int64) columns for each disease.
# 'Finding Labels' holds '|'-separated findings in the NIH CSV (e.g.
# "Effusion|Mass"), so split into whole tokens. A substring test such as
# `disease in x` would also fire for any label that merely contains a disease name.
finding_indicators = df_labels['Finding Labels'].str.get_dummies(sep='|')

# Guard against a changed separator or label vocabulary, which would otherwise
# silently zero out the disease columns.
unexpected_findings = set(finding_indicators.columns) - set(DISEASE_CATEGORIES) - {'No Finding'}
if unexpected_findings:
    raise ValueError(f"Unexpected labels in 'Finding Labels': {sorted(unexpected_findings)}")

for disease in DISEASE_CATEGORIES:
    if disease in finding_indicators.columns:
        df_labels[disease] = finding_indicators[disease].astype('int64')
    else:
        # Category never occurs in this CSV: keep the column, all negatives.
        df_labels[disease] = 0

# Validate that a random sample of image files exists (cheap spot check of the path layout).
sample_paths = df_labels['Image Path'].sample(min(200, len(df_labels)), random_state=42).values
missing = [p for p in sample_paths if not os.path.exists(p)]
if missing:
    raise FileNotFoundError(f"‚ùå Missing {len(missing)} images! First 3: {missing[:3]}")

print(f"‚úÖ Loaded {len(df_labels):,} images")
print(f"üìä Disease categories: {len(DISEASE_CATEGORIES)}")

In [None]:
# ============================================
# ⚙️ Step 3: Configuration
# ============================================

class Config:
    """Hyper-parameters for the baseline SSL run, stored as class attributes.

    Read them through the `cfg` instance created below.
    """

    # --- Model ---
    img_size = 224          # square input resolution
    feat_dim = 256          # encoder output size
    proj_dim = 128          # contrastive projection size

    # --- Training (DannyNet-inspired settings) ---
    batch_size = 64         # same as DannyNet
    pretrain_epochs = 50
    finetune_epochs = 30
    lr_pretrain = 1e-3      # SSL pretraining
    lr_finetune = 5e-5      # DannyNet uses 5e-5 (was 1e-4)
    temperature = 0.1       # NT-Xent temperature

    # --- Data ---
    num_workers = 4
    use_subset = False      # set True for quick testing
    subset_size = 10000

    # --- Device ---
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

cfg = Config()

print("‚öôÔ∏è Configuration:")
for setting_name, setting_value in (
    ("Device", cfg.device),
    ("Batch size", cfg.batch_size),
    ("Pretrain epochs", cfg.pretrain_epochs),
    ("Finetune epochs", cfg.finetune_epochs),
    ("LR pretrain", cfg.lr_pretrain),
    ("LR finetune", f"{cfg.lr_finetune} (DannyNet setting)"),
):
    print(f"   {setting_name}: {setting_value}")

# GPU optimizations: let cuDNN benchmark and pick the fastest conv algorithms.
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
    print(f"   GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# ============================================
# 💾 Step 3.5: Checkpoint & Resume Configuration
# ============================================
# ⚠️ EDIT THIS SECTION WHEN RESUMING AFTER DAYS/WEEKS

import os
import shutil
from datetime import datetime

# ╔══════════════════════════════════════════════════════════════╗
# ║  🔧 RESUME CONFIGURATION - EDIT THESE VALUES WHEN RESUMING   ║
# ╚══════════════════════════════════════════════════════════════╝

# ===== STEP 1: Checkpoint dataset name =====
# After the first run, publish the notebook output as a Kaggle dataset and put
# its EXACT name here (lowercase, hyphens) - check /kaggle/input/.
# Every /kaggle/input folder whose name starts with this string is scanned.
CHECKPOINT_DATASET_NAME = "baseline-ssl-checkpoints"  # unique for the baseline

# ===== STEP 2: Resume flags =====
RESUME_SSL_PRETRAINING = True    # True: continue SSL pretraining from a checkpoint when one is found
RESUME_FINETUNING = True         # True: resume fine-tuning

# ===== STEP 3: Which checkpoint to load =====
# "latest" auto-detects (<prefix>_latest.pth, otherwise the highest-epoch file);
# or name a file explicitly, e.g. "baseline_ssl_epoch20.pth".
SSL_CHECKPOINT_FILE = "latest"
FINETUNE_CHECKPOINT_FILE = "latest"

# ╔══════════════════════════════════════════════════════════════╗
# ║                     END OF USER CONFIG                       ║
# ╚══════════════════════════════════════════════════════════════╝

# Detect environment: Kaggle kernels expose a /kaggle directory.
IN_KAGGLE = os.path.exists('/kaggle')


def _checkpoint_search_paths(dataset_path):
    """Return the folders of one input dataset that may contain .pth files.

    Order: the dataset root, then its 'checkpoints' subfolder (if present), then
    every other immediate subfolder in sorted order.
    """
    search_paths = [dataset_path]
    checkpoints_subdir = os.path.join(dataset_path, 'checkpoints')
    if os.path.isdir(checkpoints_subdir):
        search_paths.append(checkpoints_subdir)
    for item in sorted(os.listdir(dataset_path)):
        item_path = os.path.join(dataset_path, item)
        if item != 'checkpoints' and os.path.isdir(item_path):
            search_paths.append(item_path)
    return search_paths


def _sync_checkpoint_file(src, dst):
    """Copy src to dst when dst is missing or has an older mtime, printing what happened."""
    name = os.path.basename(dst)
    if not os.path.exists(dst):
        shutil.copy2(src, dst)
        print(f"   üì¶ Copied: {name}")
    elif os.path.getmtime(src) > os.path.getmtime(dst):
        shutil.copy2(src, dst)
        print(f"   üîÑ Updated: {name} (newer version)")


if IN_KAGGLE:
    CHECKPOINT_DIR = '/kaggle/working/checkpoints'
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

    print("üîç Scanning for checkpoint datasets...")
    print("="*60)

    # Load checkpoints from ALL versions of the dataset (name, name-v2, ...).
    # Folders are visited in sorted order. When a file name appears in several
    # versions, a later copy replaces an earlier one only if its mtime is newer.
    input_path = '/kaggle/input'
    if os.path.exists(input_path):
        found_any = False
        for dataset_folder in sorted(os.listdir(input_path)):
            dataset_path = os.path.join(input_path, dataset_folder)
            if not (dataset_folder.startswith(CHECKPOINT_DATASET_NAME) and os.path.isdir(dataset_path)):
                continue
            for search_path in _checkpoint_search_paths(dataset_path):
                pth_files = [f for f in os.listdir(search_path) if f.endswith('.pth')]
                if not pth_files:
                    continue
                found_any = True
                rel_path = os.path.relpath(search_path, input_path)
                print(f"üìÇ Found checkpoints in: {rel_path}")
                for f in pth_files:
                    _sync_checkpoint_file(os.path.join(search_path, f), os.path.join(CHECKPOINT_DIR, f))

        if not found_any:
            print(f"‚ÑπÔ∏è No checkpoint datasets found matching: {CHECKPOINT_DATASET_NAME}*")
            print("   This is normal for a fresh start!")

    existing = [f for f in os.listdir(CHECKPOINT_DIR) if f.endswith('.pth')]
    print("="*60)
    if existing:
        print(f"‚úÖ Total checkpoints available: {len(existing)}")
    else:
        print(f"‚ÑπÔ∏è Starting fresh - no checkpoints loaded")

else:
    CHECKPOINT_DIR = './checkpoints'

os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# ===== CHECKPOINT UTILITIES =====
def save_checkpoint(state, filename):
    """Serialize `state` with torch.save to CHECKPOINT_DIR/filename.

    A 'saved_at' timestamp is added to `state` (the caller's dict is updated in
    place). The data is first written to '<file>.tmp' and then moved over the
    real name with os.replace. An interrupted save, such as a session timeout,
    therefore cannot leave a truncated checkpoint under the real name. On Kaggle
    the finished file is also copied to /kaggle/working/ for easy access.
    """
    filepath = os.path.join(CHECKPOINT_DIR, filename)
    state['saved_at'] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    tmp_path = filepath + '.tmp'
    torch.save(state, tmp_path)
    os.replace(tmp_path, filepath)
    print(f"üíæ Saved: {filename}")

    if IN_KAGGLE:
        # Copy the finished file rather than serializing the whole state a second time.
        shutil.copy2(filepath, os.path.join('/kaggle/working', filename))

def load_checkpoint(filename):
    """Load CHECKPOINT_DIR/filename onto cfg.device.

    Returns:
        The deserialized checkpoint (a dict as written by save_checkpoint), or
        None after printing a warning when the file does not exist.

    NOTE: weights_only=False unpickles arbitrary objects; only load trusted files.
    """
    filepath = os.path.join(CHECKPOINT_DIR, filename)
    if not os.path.exists(filepath):
        print(f"‚ö†Ô∏è Not found: {filepath}")
        return None

    checkpoint = torch.load(filepath, map_location=cfg.device, weights_only=False)
    saved_at = checkpoint.get('saved_at', 'Unknown')
    print(f"‚úÖ Loaded: {filename} (saved: {saved_at})")
    return checkpoint

def find_latest_checkpoint(prefix, directory=None):
    """Return the file name of the newest checkpoint for `prefix`, or None.

    '<prefix>_latest.pth' wins if present. Otherwise the '<prefix>_epoch<N>.pth'
    with the largest N is returned. The prefix is matched literally (regex
    metacharacters are escaped), and the whole name must match, so files like
    '<prefix>_epoch5.pth.tmp' are ignored.

    Args:
        prefix: checkpoint name prefix, e.g. 'baseline_ssl'.
        directory: folder to scan; defaults to CHECKPOINT_DIR.
    """
    import re

    if directory is None:
        directory = CHECKPOINT_DIR
    if not os.path.exists(directory):
        return None

    # First check for the rolling 'latest' checkpoint.
    latest_file = f'{prefix}_latest.pth'
    if os.path.exists(os.path.join(directory, latest_file)):
        return latest_file

    # Otherwise find the highest epoch number.
    pattern = re.compile(rf'{re.escape(prefix)}_epoch(\d+)\.pth')
    max_epoch = -1
    best_file = None
    for name in os.listdir(directory):
        match = pattern.fullmatch(name)
        if match and int(match.group(1)) > max_epoch:
            max_epoch = int(match.group(1))
            best_file = name

    return best_file

def list_checkpoints(directory=None):
    """Print every .pth file in the checkpoint folder with its epoch, size and save time.

    Each file is fully loaded on CPU to read its metadata. A file that cannot be
    loaded is still listed, together with the error type.

    Args:
        directory: folder to list; defaults to CHECKPOINT_DIR.

    Returns:
        Sorted list of .pth file names (empty if the folder is missing or empty).
    """
    if directory is None:
        directory = CHECKPOINT_DIR
    print(f"\nüìÅ Checkpoints in {directory}:")
    if not os.path.exists(directory):
        print("   (empty)")
        return []

    files = sorted(f for f in os.listdir(directory) if f.endswith('.pth'))
    if not files:
        print("   (empty)")
        return []

    for f in files:
        filepath = os.path.join(directory, f)
        size = os.path.getsize(filepath) / (1024*1024)
        try:
            # weights_only=False unpickles arbitrary objects: only list trusted files.
            ckpt = torch.load(filepath, map_location='cpu', weights_only=False)
            epoch = ckpt.get('epoch', '?')
            saved_at = ckpt.get('saved_at', 'Unknown')
            print(f"   üì¶ {f} | Epoch {epoch} | {size:.1f}MB | {saved_at}")
        except Exception as exc:  # corrupt / foreign file; don't swallow KeyboardInterrupt
            print(f"   üì¶ {f} | {size:.1f}MB | unreadable ({type(exc).__name__})")
    return files

def get_training_status():
    """Print SSL-pretraining and fine-tuning progress read from the newest checkpoints.

    Looks up find_latest_checkpoint('baseline_ssl') and
    find_latest_checkpoint('baseline_finetune') and loads each one on CPU, one
    at a time.

    Returns:
        dict with 'ssl_epoch' and 'finetune_epoch' (0 when no checkpoint exists)
        and 'best_val_auc' (None when there is no fine-tuning checkpoint).
    """
    def _load_latest(prefix):
        # Newest checkpoint for `prefix`, loaded on CPU; None when there is none.
        ckpt_file = find_latest_checkpoint(prefix)
        if not ckpt_file:
            return None
        return torch.load(os.path.join(CHECKPOINT_DIR, ckpt_file),
                          map_location='cpu', weights_only=False)

    status = {'ssl_epoch': 0, 'finetune_epoch': 0, 'best_val_auc': None}

    print("\n" + "="*60)
    print("üìä TRAINING STATUS")
    print("="*60)

    ssl_ckpt = _load_latest('baseline_ssl')
    if ssl_ckpt is not None:
        ssl_epoch = ssl_ckpt.get('epoch', 0)
        status['ssl_epoch'] = ssl_epoch
        print(f"SSL Pretraining: Epoch {ssl_epoch}/{cfg.pretrain_epochs} "
              f"({'COMPLETE ‚úÖ' if ssl_epoch >= cfg.pretrain_epochs else 'IN PROGRESS'})")
    else:
        print("SSL Pretraining: NOT STARTED")
    del ssl_ckpt  # release the weights before loading the next checkpoint

    ft_ckpt = _load_latest('baseline_finetune')
    if ft_ckpt is not None:
        ft_epoch = ft_ckpt.get('epoch', 0)
        best_auc = ft_ckpt.get('best_val_auc', 0)
        status['finetune_epoch'] = ft_epoch
        status['best_val_auc'] = best_auc
        print(f"Fine-tuning: Epoch {ft_epoch}/{cfg.finetune_epochs} "
              f"({'COMPLETE ‚úÖ' if ft_epoch >= cfg.finetune_epochs else 'IN PROGRESS'})")
        print(f"Best Val AUC: {best_auc:.4f}")
    else:
        print("Fine-tuning: NOT STARTED")

    print("="*60)
    return status

# Show current status: environment, checkpoint folder contents, and training progress.
env_label = 'Kaggle' if IN_KAGGLE else 'Local'
print(f"\nüîß Environment: {env_label}")
print(f"üìÇ Checkpoint directory: {CHECKPOINT_DIR}")
list_checkpoints()
get_training_status()

# How to carry checkpoints across Kaggle sessions.
banner_rule = "=" * 60
print("\n" + banner_rule)
print("üìå SAVING & RESUMING WORKFLOW:")
print(banner_rule)
print(f"""
After each run on Kaggle:
  1. Click 'Save Version' ‚Üí 'Quick Save'
  2. Go to Output tab ‚Üí '+ New Dataset'
  3. Name it: {CHECKPOINT_DATASET_NAME}
     (Kaggle will auto-version: {CHECKPOINT_DATASET_NAME}, {CHECKPOINT_DATASET_NAME}-v2, etc.)

To resume in a NEW session:
  1. Click 'Add Input' (right panel)
  2. Select 'Your Datasets' ‚Üí Add ALL versions of {CHECKPOINT_DATASET_NAME}
  3. Run notebook - it will automatically load the latest checkpoints!
""")
print(banner_rule)

In [None]:
# ============================================
# 🔄 Step 4: Data Augmentation
# ============================================

class ChestXrayAugment:
    """Random perturbations used to build the two contrastive views of an X-ray.

    Each call applies, in this fixed order and each with its own probability:
    horizontal flip (p=0.5), rotation by U(-15, 15) degrees (p=0.7), brightness
    scaling by U(0.8, 1.2) (p=0.8), contrast scaling by U(0.8, 1.2) (p=0.8), and
    additive Gaussian noise (std 0.05) clamped to [0, 1] (p=0.5). Randomness
    comes from Python's `random` module and torch's global RNG.

    Accepts a NumPy array (converted to float32) or a tensor (cloned, so the
    caller's tensor is never modified). The flip acts on dim 2, i.e. width for
    a (C, H, W) input as produced by PretrainDataset.
    """

    def __init__(self, img_size=224):
        # Kept for reference only; __call__ does not resize.
        self.img_size = img_size

    def __call__(self, img):
        out = torch.tensor(img, dtype=torch.float32) if isinstance(img, np.ndarray) else img.clone()

        if random.random() < 0.5:
            out = out.flip(dims=[2])
        if random.random() < 0.7:
            out = transforms.functional.rotate(out, random.uniform(-15, 15))
        if random.random() < 0.8:
            out = transforms.functional.adjust_brightness(out, 1 + random.uniform(-0.2, 0.2))
        if random.random() < 0.8:
            out = transforms.functional.adjust_contrast(out, 1 + random.uniform(-0.2, 0.2))
        if random.random() < 0.5:
            out = (out + torch.randn_like(out) * 0.05).clamp(0, 1)

        return out

# Shared augmentation instance, passed to the SSL pretraining dataset.
augment = ChestXrayAugment(img_size=cfg.img_size)
print("‚úÖ Augmentation pipeline ready")

In [None]:
# ============================================
# 📦 Step 5: Dataset Classes
# ============================================

class PretrainDataset(Dataset):
    """Yields two views of each image for SSL pretraining.

    Args:
        df: DataFrame with an 'Image Path' column (copied and re-indexed 0..n-1).
        transform: callable applied separately to produce each view. When it is
            falsy, both views are the un-augmented float32 (1, H, W) tensor.
        img_size: every image is resized to img_size x img_size.
    """

    def __init__(self, df, transform=None, img_size=224):
        self.df = df.copy().reset_index(drop=True)
        self.transform = transform
        self.img_size = img_size
        print(f"üì¶ PretrainDataset: {len(self.df)} samples")

    def __len__(self):
        return len(self.df)

    def _load_image(self, idx):
        """Read image `idx` as a grayscale float32 (1, H, W) array scaled to [0, 1]."""
        path = self.df.iloc[idx]['Image Path']
        gray = Image.open(path).convert('L').resize((self.img_size, self.img_size), Image.LANCZOS)
        return (np.array(gray, dtype=np.float32) / 255.0)[np.newaxis]

    def __getitem__(self, idx):
        img = self._load_image(idx)
        if self.transform:
            # Two independent draws of the same random pipeline.
            return self.transform(img), self.transform(img)
        return torch.tensor(img, dtype=torch.float32), torch.tensor(img, dtype=torch.float32)


class ClassificationDataset(Dataset):
    """Multi-label chest X-ray dataset returning (image, labels) pairs.

    Args:
        df: DataFrame with an 'Image Path' column plus one column per entry of
            `disease_categories` (0/1 indicators in this notebook); it is copied
            and re-indexed 0..n-1.
        disease_categories: ordered column names; fixes the label-vector order.
        img_size: every image is resized to img_size x img_size.
        is_training: when True, each item gets a random flip, rotation,
            brightness and contrast change (each independently with p=0.5).

    Items: a float32 (1, img_size, img_size) image tensor scaled to [0, 1]
    before augmentation, and a float32 label tensor of length
    len(disease_categories).
    """

    def __init__(self, df, disease_categories, img_size=224, is_training=False):
        self.df = df.copy().reset_index(drop=True)
        self.disease_categories = disease_categories
        self.img_size = img_size
        self.is_training = is_training
        print(f"üì¶ ClassificationDataset: {len(self.df)} samples (training={is_training})")

    def __len__(self):
        return len(self.df)

    def _augment(self, img):
        """Training-time augmentation; every step fires independently with p=0.5."""
        if random.random() < 0.5:
            img = torch.flip(img, dims=[2])
        if random.random() < 0.5:
            img = transforms.functional.rotate(img, random.uniform(-15, 15))
        if random.random() < 0.5:
            img = transforms.functional.adjust_brightness(img, 1 + random.uniform(-0.2, 0.2))
        if random.random() < 0.5:
            img = transforms.functional.adjust_contrast(img, 1 + random.uniform(-0.2, 0.2))
        return img

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        gray = Image.open(row['Image Path']).convert('L').resize(
            (self.img_size, self.img_size), Image.LANCZOS
        )
        pixels = np.array(gray, dtype=np.float32) / 255.0
        img = torch.tensor(pixels, dtype=torch.float32).unsqueeze(0)

        if self.is_training:
            img = self._augment(img)

        labels = torch.tensor([row[name] for name in self.disease_categories], dtype=torch.float32)
        return img, labels

print("‚úÖ Dataset classes defined")

In [None]:
# ============================================
# 🏗️ Step 6: Model Architecture
# ============================================

def conv_block(in_c, out_c, kernel=3, stride=1, padding=1):
    """Conv2d -> BatchNorm2d -> ReLU(inplace); the defaults preserve spatial size.

    The module order (indices 0, 1, 2) determines state_dict keys stored in
    checkpoints, so keep it stable.
    """
    layers = [
        nn.Conv2d(in_c, out_c, kernel_size=kernel, stride=stride, padding=padding),
        nn.BatchNorm2d(out_c),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)

def residual_block(channels):
    """Two stacked channel-preserving conv_blocks.

    NOTE: despite the name there is no skip connection (output is f(x), not
    x + f(x)). Adding one would need a custom module and would likely change the
    state_dict keys of existing checkpoints.
    """
    return nn.Sequential(*(conv_block(channels, channels) for _ in range(2)))


class Encoder(nn.Module):
    """CNN encoder: (B, in_channels, H, W) images -> (B, feat_dim) feature vectors.

    Five conv stages (four 2x max-pools, then global average pooling) yield a
    512-d vector that a small MLP head maps to `feat_dim`. The spatial-size
    comments below assume 224x224 input; AdaptiveAvgPool2d makes the head
    itself independent of H and W.

    NOTE: `residual_block` is two stacked conv_blocks WITHOUT a skip connection.
    NOTE: the nn.Sequential layout fixes the state_dict keys stored in
    checkpoints (loaded via load_state_dict); reordering or inserting layers
    breaks loading existing checkpoints.
    """
    
    def __init__(self, in_channels=1, feat_dim=256):
        super().__init__()
        self.features = nn.Sequential(
            # Stage 1: 224 -> 112
            conv_block(in_channels, 64),
            residual_block(64),
            nn.MaxPool2d(2),
            
            # Stage 2: 112 -> 56
            conv_block(64, 128),
            residual_block(128),
            nn.MaxPool2d(2),
            
            # Stage 3: 56 -> 28
            conv_block(128, 256),
            residual_block(256),
            residual_block(256),
            nn.MaxPool2d(2),
            
            # Stage 4: 28 -> 14
            conv_block(256, 512),
            residual_block(512),
            residual_block(512),
            nn.MaxPool2d(2),
            
            # Stage 5: 14 -> 1 (convs keep 14x14, then global average pooling)
            conv_block(512, 512),
            residual_block(512),
            nn.AdaptiveAvgPool2d((1, 1))
        )
        
        # MLP head: 512 -> 512 -> feat_dim, with dropout 0.3 in between.
        self.fc = nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, feat_dim)
        )
    
    def forward(self, x):
        x = self.features(x)
        # Flatten (B, 512, 1, 1) -> (B, 512) for the linear head.
        x = x.view(x.size(0), -1)
        return self.fc(x)


class ProjectionHead(nn.Module):
    """MLP mapping encoder features into the space where NT-Xent is computed.

    Linear(feat_dim, feat_dim) -> BatchNorm1d -> ReLU -> Linear(feat_dim, proj_dim).
    """

    def __init__(self, feat_dim=256, proj_dim=128):
        super().__init__()
        hidden = [nn.Linear(feat_dim, feat_dim), nn.BatchNorm1d(feat_dim), nn.ReLU()]
        self.net = nn.Sequential(*hidden, nn.Linear(feat_dim, proj_dim))

    def forward(self, x):
        projected = self.net(x)
        return projected


class Decoder(nn.Module):
    """Reconstructs a (B, 1, img_size, img_size) image in [0, 1] from (B, feat_dim) features.

    A linear layer produces a 256 x (img_size // 32)^2 seed map. Five stride-2
    transposed convolutions then upsample it 32x (7 -> 224 for the default
    size), so the output matches the input size only when img_size is a
    multiple of 32.
    """

    def __init__(self, feat_dim=256, img_size=224):
        super().__init__()
        self.init_size = img_size // 32  # 7 for 224

        self.fc = nn.Sequential(
            nn.Linear(feat_dim, 256 * self.init_size ** 2),
            nn.ReLU(),
        )

        # Channels 256 -> 128 -> 64 -> 32 -> 16, each step doubling H and W;
        # a final ConvTranspose2d to 1 channel plus Sigmoid yields the image.
        # Layer creation order (and thus Sequential indices) matches checkpoints.
        widths = [256, 128, 64, 32, 16]
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers += [nn.ConvTranspose2d(c_in, c_out, 4, 2, 1), nn.BatchNorm2d(c_out), nn.ReLU()]
        layers += [nn.ConvTranspose2d(widths[-1], 1, 4, 2, 1), nn.Sigmoid()]
        self.decoder = nn.Sequential(*layers)

    def forward(self, z):
        seed = self.fc(z).view(z.size(0), 256, self.init_size, self.init_size)
        return self.decoder(seed)


class Classifier(nn.Module):
    """Multi-label head mapping (B, feat_dim) features to (B, num_classes) raw logits.

    No sigmoid is applied; pass the logits to a with-logits loss such as FocalLoss.
    """

    def __init__(self, feat_dim=256, num_classes=14):
        super().__init__()
        hidden_dim = 256
        self.net = nn.Sequential(
            nn.Linear(feat_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x):
        logits = self.net(x)
        return logits


# Instantiate the three SSL components on the configured device.
encoder = Encoder(feat_dim=cfg.feat_dim).to(cfg.device)
proj_head = ProjectionHead(cfg.feat_dim, cfg.proj_dim).to(cfg.device)
decoder = Decoder(cfg.feat_dim, cfg.img_size).to(cfg.device)

total_params = sum(
    param.numel()
    for module in (encoder, proj_head, decoder)
    for param in module.parameters()
)

print(f"‚úÖ Models initialized")
print(f"   Total parameters: {total_params:,}")

In [None]:
# ============================================
# 🔥 Step 7: Loss Functions
# ============================================

def nt_xent_loss(z1, z2, temperature=0.1):
    """NT-Xent (normalized temperature-scaled cross-entropy) contrastive loss.

    Args:
        z1, z2: (N, D) projections of two views. Row i of z1 and row i of z2 are
            the positive pair; the other 2N - 2 rows act as negatives.
        temperature: divisor applied to the cosine similarities.

    Returns:
        Scalar tensor: mean cross-entropy over all 2N anchors.
    """
    n = z1.shape[0]
    feats = F.normalize(torch.cat([z1, z2], dim=0), dim=1)
    logits = feats @ feats.T / temperature

    # An anchor must never be scored against itself.
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z1.device)
    logits = logits.masked_fill(self_mask, float('-inf'))

    # Anchor i's positive sits at i + n (first half) or i - n (second half).
    idx = torch.arange(n, device=z1.device)
    targets = torch.cat([idx + n, idx])
    return F.cross_entropy(logits, targets)


class FocalLoss(nn.Module):
    """Binary focal loss on logits, averaged over all elements (from DannyNet SOTA).

    loss = alpha * (1 - p_t)^gamma * BCE(logits, targets), where p_t = exp(-BCE)
    is the probability given to the true label. With gamma > 0, easy
    (well-classified) examples are down-weighted so training focuses on hard ones.
    """

    def __init__(self, alpha=1.0, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        modulating = (1 - torch.exp(-bce)) ** self.gamma
        return (self.alpha * modulating * bce).mean()


print("‚úÖ Loss functions defined")
print("   üéØ NT-Xent: Contrastive loss")
print("   üéØ FocalLoss: For class imbalance (Œ±=1.0, Œ≥=2.0)")

In [None]:
# ============================================
# 📊 Step 8: Create Data Loaders (Patient-Level Split)
# ============================================
from sklearn.model_selection import train_test_split

# ⚠️ CRITICAL: Patient-level splitting to prevent data leakage
print("="*60)
print("üîÄ PATIENT-LEVEL SPLITTING")
print("="*60)

unique_patients = df_labels['Patient ID'].unique()
print(f"Total unique patients: {len(unique_patients):,}")

# Split patients: ~93% train, ~5% val, 2% test (0.98 * 0.052 ≈ 0.051 of patients go to val)
train_val_patients, test_patients = train_test_split(
    unique_patients, test_size=0.02, random_state=42
)
train_patients, val_patients = train_test_split(
    train_val_patients, test_size=0.052, random_state=42
)

# Create dataframes based on patient splits
train_df = df_labels[df_labels['Patient ID'].isin(train_patients)].copy()
val_df = df_labels[df_labels['Patient ID'].isin(val_patients)].copy()
test_df = df_labels[df_labels['Patient ID'].isin(test_patients)].copy()

# No patient may span two splits, and every image must land in exactly one split.
train_pid_set, val_pid_set, test_pid_set = set(train_patients), set(val_patients), set(test_patients)
assert not (train_pid_set & val_pid_set or train_pid_set & test_pid_set or val_pid_set & test_pid_set), \
    "patient overlap between splits"
assert len(train_df) + len(val_df) + len(test_df) == len(df_labels), "images lost or duplicated by the split"

print(f"‚úì Train: {len(train_df):,} images from {len(train_patients):,} patients")
print(f"‚úì Val: {len(val_df):,} images from {len(val_patients):,} patients")
print(f"‚úì Test: {len(test_df):,} images from {len(test_patients):,} patients")
print("="*60)

if cfg.use_subset:
    train_df = train_df.head(cfg.subset_size)
    val_df = val_df.head(cfg.subset_size // 4)
    test_df = test_df.head(cfg.subset_size // 8)
    print(f"‚ö° Using subset: {len(train_df)} train, {len(val_df)} val, {len(test_df)} test")

# Datasets - augmentation only for the training split
train_pretrain_ds = PretrainDataset(train_df, transform=augment, img_size=cfg.img_size)
train_class_ds = ClassificationDataset(train_df, DISEASE_CATEGORIES, cfg.img_size, is_training=True)
val_class_ds = ClassificationDataset(val_df, DISEASE_CATEGORIES, cfg.img_size, is_training=False)
test_class_ds = ClassificationDataset(test_df, DISEASE_CATEGORIES, cfg.img_size, is_training=False)


def _make_loader(dataset, shuffle, drop_last=False):
    """Build a DataLoader with the notebook's batch size and worker settings.

    prefetch_factor / persistent_workers only apply to worker processes, and
    recent PyTorch rejects prefetch_factor when num_workers == 0, so they are
    passed only when cfg.num_workers > 0. Pinned memory only helps CUDA
    transfers, so it is enabled only for a CUDA device.
    """
    worker_opts = {'prefetch_factor': 2, 'persistent_workers': True} if cfg.num_workers > 0 else {}
    return DataLoader(
        dataset, batch_size=cfg.batch_size, shuffle=shuffle,
        num_workers=cfg.num_workers, pin_memory=cfg.device.type == 'cuda',
        drop_last=drop_last, **worker_opts,
    )


# Training loaders drop the last partial batch (NT-Xent / BatchNorm need full batches);
# evaluation loaders keep every sample.
pretrain_loader = _make_loader(train_pretrain_ds, shuffle=True, drop_last=True)
train_loader = _make_loader(train_class_ds, shuffle=True, drop_last=True)
val_loader = _make_loader(val_class_ds, shuffle=False)
test_loader = _make_loader(test_class_ds, shuffle=False)

print(f"‚úÖ DataLoaders ready - FAST PIPELINE (with training augmentation)")
print(f"   Train batches: {len(pretrain_loader)} (pretrain), {len(train_loader)} (classify)")
print(f"   Val batches: {len(val_loader)}")
print(f"   Test batches: {len(test_loader)}")

In [None]:
# ============================================
# 🚀 Step 9: SSL Pretraining
# ============================================

# Clear GPU cache before training
if torch.cuda.is_available():
    torch.cuda.empty_cache()

optimizer_ssl = torch.optim.Adam(
    list(encoder.parameters()) + list(proj_head.parameters()) + list(decoder.parameters()),
    lr=cfg.lr_pretrain, weight_decay=1e-4
)

ssl_history = {'loss': [], 'contrastive': [], 'reconstruction': []}
START_EPOCH = 1

# Plain-data snapshot of every Config setting, stored in the 'latest' checkpoint.
# `vars(cfg)` alone is empty because the settings are class attributes of
# Config, not instance attributes. The device is stored as a string.
ssl_config_snapshot = {
    key: str(value) if isinstance(value, torch.device) else value
    for key, value in {**vars(type(cfg)), **vars(cfg)}.items()
    if not key.startswith('_') and not callable(value)
}

# ===== AUTO-RESUME FROM CHECKPOINT =====
if RESUME_SSL_PRETRAINING:
    if SSL_CHECKPOINT_FILE == "latest":
        ckpt_file = find_latest_checkpoint('baseline_ssl')
    else:
        ckpt_file = SSL_CHECKPOINT_FILE

    if ckpt_file:
        checkpoint = load_checkpoint(ckpt_file)
        if checkpoint:
            encoder.load_state_dict(checkpoint['encoder'])
            proj_head.load_state_dict(checkpoint['proj_head'])
            decoder.load_state_dict(checkpoint['decoder'])
            if 'optimizer' in checkpoint:
                optimizer_ssl.load_state_dict(checkpoint['optimizer'])
            ssl_history = checkpoint.get('ssl_history', ssl_history)
            START_EPOCH = checkpoint['epoch'] + 1
            print(f"üîÑ Resuming SSL pretraining from epoch {START_EPOCH}")
    else:
        print("‚ö†Ô∏è RESUME_SSL_PRETRAINING=True but no checkpoint found. Starting fresh.")

if START_EPOCH > cfg.pretrain_epochs:
    print(f"‚úÖ SSL Pretraining already complete ({cfg.pretrain_epochs} epochs)")
    print("   Skipping to next step...")
else:
    # pretrain_loader drops the last partial batch, so a tiny training set yields no batches;
    # fail loudly instead of dividing by zero when averaging the epoch losses.
    if len(pretrain_loader) == 0:
        raise ValueError("pretrain_loader has no batches: training set is smaller than one batch.")

    print(f"\nüöÄ Starting Baseline SSL Pretraining")
    print(f"   Epochs: {START_EPOCH} ‚Üí {cfg.pretrain_epochs}")
    print("=" * 60)

    SAVE_EVERY = 5  # Save every 5 epochs

    for epoch in range(START_EPOCH, cfg.pretrain_epochs + 1):
        encoder.train()
        proj_head.train()
        decoder.train()

        total_loss = 0.0
        total_cont = 0.0
        total_recon = 0.0

        # tqdm progress bars only outside Kaggle.
        loader = tqdm(pretrain_loader, desc=f"Epoch {epoch}/{cfg.pretrain_epochs}") if not IN_KAGGLE else pretrain_loader
        for view1, view2 in loader:
            view1 = view1.to(cfg.device, non_blocking=True)
            view2 = view2.to(cfg.device, non_blocking=True)

            optimizer_ssl.zero_grad()

            # Encode
            z1 = encoder(view1)
            z2 = encoder(view2)

            # Contrastive loss
            p1 = proj_head(z1)
            p2 = proj_head(z2)
            cont_loss = nt_xent_loss(p1, p2, cfg.temperature)

            # Reconstruction loss: each view is reconstructed from its own features.
            rec1 = decoder(z1)
            rec2 = decoder(z2)
            recon_loss = (F.mse_loss(rec1, view1) + F.mse_loss(rec2, view2)) / 2

            # Combined loss
            loss = cont_loss + 0.5 * recon_loss

            loss.backward()
            optimizer_ssl.step()

            total_loss += loss.item()
            total_cont += cont_loss.item()
            total_recon += recon_loss.item()

            if not IN_KAGGLE:
                loader.set_postfix({'loss': f'{loss.item():.4f}'})

            # Free memory
            del z1, z2, p1, p2, rec1, rec2, loss, cont_loss, recon_loss

        # Clear cache at end of epoch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Log epoch metrics (mean over batches)
        n = len(pretrain_loader)
        ssl_history['loss'].append(total_loss / n)
        ssl_history['contrastive'].append(total_cont / n)
        ssl_history['reconstruction'].append(total_recon / n)

        print(f"Epoch {epoch}: Loss={total_loss/n:.4f}, Cont={total_cont/n:.4f}, Recon={total_recon/n:.4f}")

        # Save checkpoints periodically: a resumable 'latest' one (with optimizer state)
        # and a lighter per-epoch snapshot.
        if epoch % SAVE_EVERY == 0 or epoch == cfg.pretrain_epochs:
            save_checkpoint({
                'epoch': epoch,
                'encoder': encoder.state_dict(),
                'proj_head': proj_head.state_dict(),
                'decoder': decoder.state_dict(),
                'optimizer': optimizer_ssl.state_dict(),
                'ssl_history': ssl_history,
                'config': ssl_config_snapshot,
                'phase': 'ssl_pretraining'
            }, 'baseline_ssl_latest.pth')

            save_checkpoint({
                'epoch': epoch,
                'encoder': encoder.state_dict(),
                'proj_head': proj_head.state_dict(),
                'decoder': decoder.state_dict(),
                'ssl_history': ssl_history,
            }, f'baseline_ssl_epoch{epoch}.pth')

    print("\n‚úÖ Baseline SSL Pretraining Complete!")

In [None]:
# ============================================
# 📈 Step 10: Plot SSL Training Curves
# ============================================

# One panel per tracked SSL loss. Epochs are numbered from 1 on the x-axis:
# when the history covers every epoch from the start, entry i belongs to epoch i + 1.
ssl_panels = [
    ('loss', 'Total Loss', 'b-'),
    ('contrastive', 'Contrastive Loss', 'r-'),
    ('reconstruction', 'Reconstruction Loss', 'g-'),
]

fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for ax, (history_key, panel_title, line_style) in zip(axes, ssl_panels):
    values = ssl_history[history_key]
    ax.plot(range(1, len(values) + 1), values, line_style, linewidth=2)
    ax.set_title(panel_title, fontsize=12)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.grid(True, alpha=0.3)

plt.suptitle('Baseline SSL Training Curves', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('baseline_ssl_curves.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# ============================================
# üíæ Step 11: Save Pretrained Model
# ============================================
# Persist the SSL-pretrained weights of all three modules together with the
# architecture settings (feature/projection dims, image size) needed to
# rebuild them before calling load_state_dict.

PRETRAINED_PATH = 'baseline_ssl_pretrained.pth'

pretrained_config = {
    'feat_dim': cfg.feat_dim,
    'proj_dim': cfg.proj_dim,
    'img_size': cfg.img_size,
}

pretrained_bundle = {
    'encoder': encoder.state_dict(),
    'proj_head': proj_head.state_dict(),
    'decoder': decoder.state_dict(),
    'config': pretrained_config,
}

torch.save(pretrained_bundle, PRETRAINED_PATH)

print(f"üíæ Pretrained model saved: {PRETRAINED_PATH}")

In [None]:
# ============================================
# üéØ Step 12: Fine-tuning for Classification
# ============================================
# KEY IMPROVEMENTS (inspired by DannyNet SOTA):
# 1. UNFREEZE encoder with differential learning rate
# 2. Use Focal Loss instead of BCE
# 3. Use AdamW optimizer
# 4. More aggressive LR scheduler
#
# Checkpointing notes:
# - 'baseline_finetune_latest.pth' is written EVERY epoch. baseline_best_model.pth
#   is written whenever Val AUC improves. If "latest" were only refreshed every
#   few epochs, resuming from it could restore a stale best_val_auc and later
#   overwrite a better best model with a worse one.
# - The ReduceLROnPlateau state is checkpointed as well. Otherwise its patience
#   counter and best score would silently reset on resume.
# ============================================


def _mean_auc(targets, preds):
    """Mean per-class ROC-AUC, skipping classes whose targets hold one label only.

    ROC-AUC is undefined for a single-label column, so such classes are left
    out of the average. Returns NaN (without NumPy's empty-mean warning) when
    no class can be scored.
    """
    scores = [
        roc_auc_score(targets[:, i], preds[:, i])
        for i in range(preds.shape[1])
        if len(np.unique(targets[:, i])) > 1
    ]
    return float(np.mean(scores)) if scores else float('nan')


# UNFREEZE encoder for fine-tuning (critical for performance)
for param in encoder.parameters():
    param.requires_grad = True
encoder.train()

# Initialize classifier: one output per disease category
classifier = Classifier(cfg.feat_dim, len(DISEASE_CATEGORIES)).to(cfg.device)

# Focal Loss instead of weighted BCE
criterion = FocalLoss(alpha=1.0, gamma=2.0)

# Differential learning rates with AdamW: the pretrained encoder moves 10x slower
encoder_lr = cfg.lr_finetune / 10
classifier_lr = cfg.lr_finetune

optimizer = torch.optim.AdamW([
    {'params': encoder.parameters(), 'lr': encoder_lr},
    {'params': classifier.parameters(), 'lr': classifier_lr}
], weight_decay=1e-4)

# mode='max' because the scheduler is stepped with validation AUC
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', patience=2, factor=0.1, min_lr=1e-7
)

print("üîß Fine-tuning Configuration:")
print(f"   ‚úÖ Encoder: UNFROZEN with LR={encoder_lr:.2e}")
print(f"   ‚úÖ Classifier LR: {classifier_lr:.2e}")
print(f"   ‚úÖ Loss: FocalLoss (Œ±=1.0, Œ≥=2.0)")
print(f"   ‚úÖ Optimizer: AdamW")

finetune_history = {'train_loss': [], 'train_auc': [], 'val_loss': [], 'val_auc': []}
best_val_auc = 0
FINETUNE_START_EPOCH = 1

# ===== AUTO-RESUME FROM CHECKPOINT =====
if RESUME_FINETUNING:
    if FINETUNE_CHECKPOINT_FILE == "latest":
        ckpt_file = find_latest_checkpoint('baseline_finetune')
    else:
        ckpt_file = FINETUNE_CHECKPOINT_FILE

    if not ckpt_file:
        print("‚ö†Ô∏è RESUME_FINETUNING=True but no checkpoint found. Starting fresh.")
    else:
        ft_checkpoint = load_checkpoint(ckpt_file)
        if not ft_checkpoint:
            # Previously silent: make the fallback to a fresh start visible.
            print(f"WARNING: could not load fine-tuning checkpoint '{ckpt_file}'. Starting fresh.")
        else:
            classifier.load_state_dict(ft_checkpoint['classifier'])
            if 'encoder' in ft_checkpoint:
                encoder.load_state_dict(ft_checkpoint['encoder'])
            if 'optimizer' in ft_checkpoint:
                try:
                    optimizer.load_state_dict(ft_checkpoint['optimizer'])
                except (ValueError, KeyError) as exc:
                    # Mismatched param groups / malformed state: keep the fresh
                    # optimizer, but do not swallow unrelated errors (or Ctrl-C).
                    print("‚ö†Ô∏è Optimizer state incompatible, starting fresh")
                    print(f"   Reason: {exc}")
            # Older checkpoints have no scheduler state; the scheduler then starts fresh.
            if 'scheduler' in ft_checkpoint:
                scheduler.load_state_dict(ft_checkpoint['scheduler'])
            finetune_history = ft_checkpoint.get('finetune_history', finetune_history)
            best_val_auc = ft_checkpoint.get('best_val_auc', 0)
            FINETUNE_START_EPOCH = ft_checkpoint['epoch'] + 1
            print(f"üîÑ Resuming fine-tuning from epoch {FINETUNE_START_EPOCH}")

if FINETUNE_START_EPOCH > cfg.finetune_epochs:
    print(f"‚úÖ Fine-tuning already complete ({cfg.finetune_epochs} epochs)")
    print(f"   Best Val AUC: {best_val_auc:.4f}")
else:
    print(f"\nüéØ Starting Baseline Fine-tuning (ENCODER UNFROZEN)")
    print(f"   Epochs: {FINETUNE_START_EPOCH} ‚Üí {cfg.finetune_epochs}")
    print("=" * 50)

    SAVE_EVERY_FT = 5  # Keep a numbered snapshot every 5 epochs

    for epoch in range(FINETUNE_START_EPOCH, cfg.finetune_epochs + 1):
        # ---- Training: encoder and classifier are both updated ----
        encoder.train()
        classifier.train()
        train_loss = 0
        train_preds, train_targets = [], []

        loader = tqdm(train_loader, desc=f"Train {epoch}/{cfg.finetune_epochs}") if not IN_KAGGLE else train_loader
        for images, targets in loader:
            images = images.to(cfg.device)
            targets = targets.to(cfg.device)

            optimizer.zero_grad()

            features = encoder(images)
            logits = classifier(features)
            loss = criterion(logits, targets)

            loss.backward()

            # Gradient clipping for stability (each module's norm is clipped separately)
            torch.nn.utils.clip_grad_norm_(encoder.parameters(), max_norm=1.0)
            torch.nn.utils.clip_grad_norm_(classifier.parameters(), max_norm=1.0)

            optimizer.step()

            train_loss += loss.item()
            train_preds.append(torch.sigmoid(logits).detach().cpu())
            train_targets.append(targets.cpu())

        # ---- Validation ----
        encoder.eval()
        classifier.eval()
        val_loss = 0
        val_preds, val_targets = [], []

        with torch.no_grad():
            for images, targets in val_loader:
                images = images.to(cfg.device)
                targets = targets.to(cfg.device)

                features = encoder(images)
                logits = classifier(features)
                loss = criterion(logits, targets)

                val_loss += loss.item()
                val_preds.append(torch.sigmoid(logits).cpu())
                val_targets.append(targets.cpu())

        # ---- Metrics ----
        train_preds = torch.cat(train_preds).numpy()
        train_targets = torch.cat(train_targets).numpy()
        val_preds = torch.cat(val_preds).numpy()
        val_targets = torch.cat(val_targets).numpy()

        train_auc = _mean_auc(train_targets, train_preds)
        val_auc = _mean_auc(val_targets, val_preds)

        finetune_history['train_loss'].append(train_loss / len(train_loader))
        finetune_history['train_auc'].append(train_auc)
        finetune_history['val_loss'].append(val_loss / len(val_loader))
        finetune_history['val_auc'].append(val_auc)

        scheduler.step(val_auc)

        print(f"Epoch {epoch}: Train AUC={train_auc:.4f}, Val AUC={val_auc:.4f}")

        # Save best model
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            save_checkpoint({
                'encoder': encoder.state_dict(),
                'classifier': classifier.state_dict(),
                'val_auc': val_auc,
                'epoch': epoch,
                'phase': 'best_model'
            }, 'baseline_best_model.pth')
            print(f"  ‚úÖ Best model saved! Val AUC: {val_auc:.4f}")

        # Resumable "latest" checkpoint every epoch. It carries optimizer and
        # scheduler state, and its best_val_auc stays in sync with the best model.
        save_checkpoint({
            'epoch': epoch,
            'encoder': encoder.state_dict(),
            'classifier': classifier.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'finetune_history': finetune_history,
            'best_val_auc': best_val_auc,
            'phase': 'finetuning'
        }, 'baseline_finetune_latest.pth')

        # Numbered snapshot periodically and at the final epoch
        if epoch % SAVE_EVERY_FT == 0 or epoch == cfg.finetune_epochs:
            save_checkpoint({
                'epoch': epoch,
                'encoder': encoder.state_dict(),
                'classifier': classifier.state_dict(),
                'finetune_history': finetune_history,
                'best_val_auc': best_val_auc,
                'phase': 'finetuning'
            }, f'baseline_finetune_epoch{epoch}.pth')

    print(f"\nüèÜ Best Validation AUC: {best_val_auc:.4f}")

In [None]:
# ============================================
# üìä Step 13: Plot Fine-tuning Curves
# ============================================
# finetune_history holds one entry per fine-tuning epoch. The fine-tuning loop
# runs `range(FINETUNE_START_EPOCH, cfg.finetune_epochs + 1)` with
# FINETUNE_START_EPOCH = 1 on a fresh run, so epochs are 1-based. Plot against
# those epoch numbers rather than the 0-based list index.

ft_panels = [
    # (history key suffix, panel title)
    ('loss', 'Loss'),
    ('auc', 'Mean AUC'),
]
ft_epochs = range(1, len(finetune_history['train_loss']) + 1)

fig, axes = plt.subplots(1, len(ft_panels), figsize=(12, 4))

for ax, (metric, title) in zip(axes, ft_panels):
    ax.plot(ft_epochs, finetune_history[f'train_{metric}'], 'b-', label='Train', linewidth=2)
    ax.plot(ft_epochs, finetune_history[f'val_{metric}'], 'r-', label='Val', linewidth=2)
    ax.set_title(title, fontsize=12)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.suptitle('Baseline Fine-tuning Curves', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('baseline_finetune_curves.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# ============================================
# üìà Step 14: Final Evaluation on TEST SET
# ============================================
# Per-disease decision thresholds are tuned on the VALIDATION set and applied
# unchanged to the TEST set. Tuning them on the test labels leaks test
# information into the reported F1 / precision / recall and inflates them.
from sklearn.metrics import precision_recall_curve, f1_score, precision_score, recall_score


def _collect_predictions(data_loader, desc):
    """Run encoder + classifier over ``data_loader`` without gradients.

    Returns ``(probs, targets)`` as NumPy arrays, where ``probs`` are the
    sigmoid of the classifier logits. Callers are expected to have put the
    models in eval mode.
    """
    probs, labels = [], []
    with torch.no_grad():
        iterator = tqdm(data_loader, desc=desc) if not IN_KAGGLE else data_loader
        for images, targets in iterator:
            logits = classifier(encoder(images.to(cfg.device)))
            probs.append(torch.sigmoid(logits).cpu())
            labels.append(targets.cpu())
    return torch.cat(probs).numpy(), torch.cat(labels).numpy()


def _best_f1_threshold(y_true, y_score, default=0.5):
    """Return the score threshold that maximizes F1 on ``(y_true, y_score)``.

    Falls back to ``default`` when ``y_true`` contains a single class.
    ``precision_recall_curve`` returns one more precision/recall point than
    thresholds (the final point has no threshold), hence the bounds check.
    """
    if len(np.unique(y_true)) < 2:
        return default
    precision, recall, thresholds = precision_recall_curve(y_true, y_score)
    f1 = 2 * (precision * recall) / (precision + recall + 1e-8)
    best_idx = int(np.argmax(f1))
    return float(thresholds[best_idx]) if best_idx < len(thresholds) else default


# Load best model. map_location lets a checkpoint written on GPU load on CPU.
# weights_only=False because the checkpoint may hold non-tensor values such as
# a NumPy scalar val_auc. Only use this on checkpoints this notebook produced.
best_model_path = os.path.join(CHECKPOINT_DIR, 'baseline_best_model.pth')
checkpoint = torch.load(best_model_path, map_location=cfg.device, weights_only=False)
encoder.load_state_dict(checkpoint['encoder'])
classifier.load_state_dict(checkpoint['classifier'])

encoder.eval()
classifier.eval()

print("="*60)
print("üìä TEST SET EVALUATION")
print("="*60)

# Validation predictions are used ONLY for threshold tuning
val_probs, val_labels = _collect_predictions(val_loader, "Predicting on VAL set (threshold tuning)")
all_preds, all_targets = _collect_predictions(test_loader, "Evaluating on TEST set")

# Find optimal thresholds per disease (on VALIDATION predictions)
print("\nüéØ OPTIMAL THRESHOLDS:")
print("   (tuned on VALIDATION set, applied unchanged to TEST set)")
print("-"*40)
optimal_thresholds = []
for i, disease in enumerate(DISEASE_CATEGORIES):
    best_threshold = _best_f1_threshold(val_labels[:, i], val_probs[:, i])
    optimal_thresholds.append(best_threshold)
    print(f"{disease:20s}: {best_threshold:.3f}")

# Per-disease TEST metrics with the validation-tuned thresholds
print("\nüìä PER-DISEASE METRICS (TEST SET):")
print("="*60)
auc_by_disease = []  # (disease, AUC) pairs, only for diseases scorable on TEST
auc_scores = []
f1_scores_list = []
for i, disease in enumerate(DISEASE_CATEGORIES):
    y_true = all_targets[:, i]
    if len(np.unique(y_true)) < 2:
        print(f"{disease:20s}: skipped (only one class present in TEST set)")
        continue
    y_score = all_preds[:, i]
    auc = roc_auc_score(y_true, y_score)
    # `>=` matches precision_recall_curve's convention (score >= threshold is positive)
    pred_binary = (y_score >= optimal_thresholds[i]).astype(int)
    f1 = f1_score(y_true, pred_binary, zero_division=0)
    prec = precision_score(y_true, pred_binary, zero_division=0)
    rec = recall_score(y_true, pred_binary, zero_division=0)
    auc_by_disease.append((disease, auc))
    auc_scores.append(auc)
    f1_scores_list.append(f1)
    print(f"{disease:20s}: AUC={auc:.4f} | F1={f1:.4f} | Prec={prec:.4f} | Rec={rec:.4f}")

mean_auc = float(np.mean(auc_scores)) if auc_scores else float('nan')
mean_f1 = float(np.mean(f1_scores_list)) if f1_scores_list else float('nan')

print("\n" + "="*60)
print(f"üèÜ TEST SET RESULTS:")
print(f"   Mean AUC: {mean_auc:.4f}")
print(f"   Mean F1:  {mean_f1:.4f}")
print("="*60)

# Plot AUC bar chart. Each name is paired with its own AUC, so a skipped
# disease cannot shift the labels onto the wrong bars.
if auc_by_disease:
    diseases, aucs = zip(*sorted(auc_by_disease, key=lambda pair: pair[1], reverse=True))

    plt.figure(figsize=(12, 6))
    colors = ['green' if a >= 0.7 else 'orange' if a >= 0.6 else 'red' for a in aucs]
    plt.barh(diseases, aucs, color=colors, alpha=0.8)
    plt.axvline(0.5, color='red', linestyle='--', alpha=0.5, label='Random')
    plt.axvline(mean_auc, color='blue', linestyle='--', alpha=0.7, label=f'Mean: {mean_auc:.3f}')
    plt.xlabel('AUC Score')
    plt.title('Baseline: Per-Disease AUC Performance (TEST SET)', fontsize=14, fontweight='bold')
    plt.legend()
    plt.tight_layout()
    plt.savefig('baseline_auc_performance.png', dpi=150, bbox_inches='tight')
    plt.show()

In [None]:
# ============================================
# üìù Summary
# ============================================
# Recap of the run: split sizes, epoch budgets, techniques applied, headline
# TEST-set metrics, and the artifact files listed for this notebook.

rule = "=" * 60

improvements = [
    "Patient-level train/val/test splits",
    "Unfrozen encoder during fine-tuning",
    "Focal Loss for class imbalance",
    "AdamW optimizer with differential LR",
    "Training augmentation",
    "Per-disease optimal thresholds",
    "Fast DataLoader pipeline",
]

artifacts = [
    "baseline_ssl_pretrained.pth",
    "baseline_best_model.pth",
    "baseline_ssl_curves.png",
    "baseline_finetune_curves.png",
    "baseline_auc_performance.png",
]

print("\n" + rule)
print("üìù BASELINE SSL SUMMARY")
print(rule)
print("\nMethod: SimCLR (NT-Xent + Reconstruction) + Unfrozen Fine-tuning")
print("Dataset: NIH Chest X-ray 14")
for split_label, split_df in (("Training", train_df), ("Validation", val_df), ("Test", test_df)):
    print(f"{split_label} samples: {len(split_df):,}")
print(f"\nPretraining epochs: {cfg.pretrain_epochs}")
print(f"Fine-tuning epochs: {cfg.finetune_epochs}")
print("\nüîß Key Improvements Applied:")
for improvement in improvements:
    print(f"   ‚úÖ {improvement}")
print(f"\nüèÜ TEST SET Mean AUC: {mean_auc:.4f}")
print(f"üèÜ TEST SET Mean F1:  {mean_f1:.4f}")
print("\nFiles saved:")
for artifact in artifacts:
    print(f"  - {artifact}")
print(rule)