In [None]:
# ============================================
# üì¶ Step 1: Import Libraries
# ============================================

import warnings
warnings.filterwarnings('ignore')

import os
os.environ['OPENCV_LOG_LEVEL'] = 'SILENT'  # Suppress libpng ICC warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import seaborn as sns
from scipy.ndimage import rotate
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
import cv2
import random
from tqdm import tqdm
from pathlib import Path
import kagglehub

# Set random seeds
def set_seed(seed=42):
    """Seed every RNG this notebook draws from.

    Covers Python's `random`, NumPy's legacy global generator and PyTorch's CPU
    generator, plus all CUDA devices when a GPU is present.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)
print("‚úÖ Libraries imported successfully")

In [None]:
# ============================================
# üìÅ Step 2: Download and Load Dataset
# ============================================

path = kagglehub.dataset_download("khanfashee/nih-chest-x-ray-14-224x224-resized")
BASE_PATH = Path(path)
print(f"üìÇ Dataset path: {BASE_PATH}")

df = pd.read_csv(BASE_PATH / "Data_Entry_2017.csv")
images_dir = BASE_PATH / "images-224" / "images-224"
df["Image Path"] = [str(images_dir / p) for p in df["Image Index"].values]

DISEASE_CATEGORIES = [
    'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration', 'Mass',
    'Nodule', 'Pneumonia', 'Pneumothorax', 'Consolidation', 'Edema',
    'Emphysema', 'Fibrosis', 'Pleural_Thickening', 'Hernia'
]

# 'Finding Labels' is a '|'-separated list (e.g. "Effusion|Mass"). Split it into whole
# tokens instead of using a substring test (`disease in x`), which would also fire when
# one label name happens to be contained in another. Categories absent from the data
# still get an all-zero column thanks to reindex(fill_value=0).
label_indicators = (
    df['Finding Labels']
    .str.get_dummies(sep='|')
    .reindex(columns=DISEASE_CATEGORIES, fill_value=0)
)
for disease in DISEASE_CATEGORIES:
    df[disease] = label_indicators[disease].astype(int)

# Validate a sample of image paths up front so a broken download fails fast
sample_paths = df['Image Path'].sample(min(200, len(df)), random_state=42).values
missing = [p for p in sample_paths if not os.path.exists(p)]
if missing:
    raise FileNotFoundError(f"‚ùå Missing {len(missing)} images! First 3: {missing[:3]}")

print(f"‚úÖ Loaded {len(df):,} images")
print(f"üìä Disease categories: {len(DISEASE_CATEGORIES)}")

In [None]:
# ============================================
# ‚öôÔ∏è Step 3: Configuration
# ============================================

class Config:
    """Hyperparameters and runtime settings shared by every stage of the notebook."""

    # --- Model ---
    img_size = 224
    encoder_backbone = 'custom'  # one of: 'custom', 'mobilenet_v2'
    feat_dim = 256               # encoder output width / classifier input width
    proj_dim = 128               # contrastive projection-head output width

    # --- Optimisation ---
    batch_size = 32              # lowered from 64 to avoid out-of-memory errors
    pretrain_epochs = 50
    finetune_epochs = 30
    lr_pretrain = 1e-3
    lr_finetune = 1e-4
    temperature = 0.1            # NT-Xent similarity temperature
    region_weight = 0.5          # strength of the region-coverage term in region_aware_loss

    # --- Data loading ---
    num_workers = 4
    use_subset = False           # True -> keep only the first `subset_size` training rows
    subset_size = 10000

    # --- Region partition ---
    num_vertical_regions = 3     # upper / middle / lower bands

    # --- Hardware ---
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

cfg = Config()

# Speed / memory settings; only meaningful on CUDA
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.backends.cudnn.benchmark = True          # autotune conv kernels for the fixed input size
    torch.backends.cuda.matmul.allow_tf32 = True   # allow TF32 matmuls on GPUs that support it
    torch.backends.cudnn.allow_tf32 = True

print("‚öôÔ∏è Configuration:")
print(f"   Device: {cfg.device}")
print(f"   Batch size: {cfg.batch_size}")
print(f"   Encoder backbone: {cfg.encoder_backbone}")
# multi_region_segmentation builds `num_vertical_regions` bands, plus left/right halves,
# plus one central box. That is 6 masks for the default 3 bands, not a bands x halves grid.
print(f"   Num anatomical regions: {cfg.num_vertical_regions + 3}")
print("   (Upper/Middle/Lower + Left/Right + Central)")

if torch.cuda.is_available():
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# ============================================
# ‚öôÔ∏è Step 3: Configuration
# ============================================

# `Config`, `cfg` and the CUDA backend flags are defined once, in the previous cell.
# A verbatim re-declaration here would silently shadow that definition and hide any
# edits made there. This cell therefore only reports the settings that downstream
# cells use, so re-running it has no side effects.
print("‚öôÔ∏è Configuration:")
print(f"   Device: {cfg.device}")
print(f"   Batch size: {cfg.batch_size}")
print(f"   Encoder backbone: {cfg.encoder_backbone}")
# multi_region_segmentation yields N vertical bands + left/right halves + 1 central box
print(f"   Num anatomical regions: {cfg.num_vertical_regions + 3}")
print("   (Upper/Middle/Lower + Left/Right + Central)")

if torch.cuda.is_available():
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")


In [None]:
# ============================================
# üó∫Ô∏è Step 4: Multi-Region Segmentation
# ============================================

def multi_region_segmentation(image, num_vertical_regions=3):
    """
    Partition a chest X-ray into fixed geometric "anatomical" regions.

    The split is purely positional. The masks depend only on the image height and
    width, never on pixel values.

    Regions:
    - Vertical: ``num_vertical_regions`` bands stacked top to bottom. For up to 3
      bands they are named upper/middle/lower (``vert_upper`` ...). For more than
      3 bands they are named by index (``vert_0`` ... ``vert_{n-1}``).
    - Horizontal: left/right halves of the image (``horiz_left``, ``horiz_right``).
    - Central: the middle third in both directions (``central_mediastinum``).

    Args:
        image: 2-D array (H, W), or 3-D array (1, H, W) whose leading singleton
            axis is dropped.
        num_vertical_regions: number of vertical bands (>= 1, default 3). The
            last band absorbs the remainder rows when H is not divisible.

    Returns:
        Dictionary with
          'regions': name -> slice of the 2-D image covering that region,
          'masks': name -> array shaped and typed like the 2-D image, ones inside
              the region and zeros elsewhere,
          'region_names': list of region names in insertion order.

    Raises:
        ValueError: if ``num_vertical_regions`` is smaller than 1.
    """
    if num_vertical_regions < 1:
        raise ValueError(f"num_vertical_regions must be >= 1, got {num_vertical_regions}")

    if len(image.shape) == 3 and image.shape[0] == 1:
        image = image[0]

    h, w = image.shape
    regions = {}
    region_masks = {}

    # Vertical bands (upper/middle/lower lung fields by default). The anatomical names
    # only make sense for <= 3 bands; beyond that, index names avoid the IndexError the
    # fixed 3-name list used to raise.
    anatomical_names = ['upper', 'middle', 'lower']
    if num_vertical_regions <= len(anatomical_names):
        vert_names = anatomical_names[:num_vertical_regions]
    else:
        vert_names = [str(i) for i in range(num_vertical_regions)]

    region_height = h // num_vertical_regions
    for i, vert_name in enumerate(vert_names):
        start_h = i * region_height
        end_h = h if i == num_vertical_regions - 1 else (i + 1) * region_height
        mask = np.zeros_like(image)
        mask[start_h:end_h, :] = 1.0

        region_name = f'vert_{vert_name}'
        regions[region_name] = image[start_h:end_h, :]
        region_masks[region_name] = mask

    # Horizontal regions (left/right halves of the image)
    left_mask = np.zeros_like(image)
    right_mask = np.zeros_like(image)
    left_mask[:, :w//2] = 1.0
    right_mask[:, w//2:] = 1.0

    regions['horiz_left'] = image[:, :w//2]
    regions['horiz_right'] = image[:, w//2:]
    region_masks['horiz_left'] = left_mask
    region_masks['horiz_right'] = right_mask

    # Central region (mediastinum/heart area): middle third in both directions
    central_mask = np.zeros_like(image)
    central_mask[h//3:2*h//3, w//3:2*w//3] = 1.0
    regions['central_mediastinum'] = image[h//3:2*h//3, w//3:2*w//3]
    region_masks['central_mediastinum'] = central_mask

    return {
        'regions': regions,
        'masks': region_masks,
        'region_names': list(region_masks.keys())
    }

print("‚úÖ Multi-region segmentation function defined")
print("   6 anatomical regions: upper/middle/lower + left/right + mediastinum")

In [None]:
# ============================================
# üëÅÔ∏è Step 5: Visualize Region Segmentation
# ============================================

sample_indices = [0, 100, 500]
n_rows = len(sample_indices)
fig, axes = plt.subplots(n_rows, 2, figsize=(16, 4 * n_rows), squeeze=False)

color_map = {
    'vert_upper': (1, 0, 0),      # Red
    'vert_middle': (0, 1, 0),     # Green
    'vert_lower': (0, 0, 1),      # Blue
    'horiz_left': (1, 1, 0),      # Yellow
    'horiz_right': (1, 0, 1),     # Magenta
    'central_mediastinum': (0, 1, 1)  # Cyan
}

for (ax_img, ax_regions), sample_idx in zip(axes, sample_indices):
    img_path = df.iloc[sample_idx]['Image Path']
    with Image.open(img_path) as pil_img:  # close the file handle promptly
        img = pil_img.convert('L')
    img = img.resize((cfg.img_size, cfg.img_size), Image.LANCZOS)
    img_np = np.array(img, dtype=np.float32) / 255.0

    # Same region configuration the pretraining dataset uses
    seg_result = multi_region_segmentation(img_np, cfg.num_vertical_regions)

    ax_img.imshow(img_np, cmap='gray')
    ax_img.set_title(f'Original Image {sample_idx}')
    ax_img.axis('off')

    # 40% region colour (overlapping masks add up) + 60% grayscale image, clipped to [0, 1].
    # Regions without a colour entry (e.g. index-named bands) are drawn grey.
    region_overlay = np.zeros((*img_np.shape, 3))
    for region_name, mask in seg_result['masks'].items():
        color = np.asarray(color_map.get(region_name, (0.5, 0.5, 0.5)))
        region_overlay += mask[..., None] * color * 0.4
    region_overlay += img_np[..., None] * 0.6
    region_overlay = np.clip(region_overlay, 0, 1)

    ax_regions.imshow(region_overlay)
    ax_regions.set_title(f"{len(seg_result['masks'])} Anatomical Regions")
    ax_regions.axis('off')

fig.suptitle('Option 2: Multi-Region Segmentation', fontsize=14, fontweight='bold')
fig.tight_layout()
fig.savefig('option2_region_segmentation.png', dpi=150, bbox_inches='tight')
plt.show()

print("‚úÖ Region visualization complete")

In [None]:
# ============================================
# üîÑ Step 6: Augmentation
# ============================================

class ChestXrayAugment:
    """
    Random augmentation that produces one SSL view of a grayscale X-ray.

    Accepts a NumPy array or a torch tensor laid out as [..., H, W], e.g. (1, H, W).
    Values are presumably in [0, 1], since the datasets divide by 255. Returns a new
    float tensor of the same shape. Each transform fires independently with the
    probability noted next to it.
    """
    def __init__(self, img_size=224):
        self.img_size = img_size  # stored for API compatibility; __call__ does not use it

    def __call__(self, img):
        if isinstance(img, np.ndarray):
            x = torch.tensor(img, dtype=torch.float32)
        else:
            x = img.clone()  # never mutate the caller's tensor

        # p=0.5: horizontal flip. Flip the width (last) axis rather than a fixed index
        # so this stays a horizontal flip for both (C, H, W) and batched (B, C, H, W) input.
        if random.random() < 0.5:
            x = torch.flip(x, dims=[-1])

        # p=0.7: rotation by up to +/-15 degrees
        if random.random() < 0.7:
            angle = random.uniform(-15, 15)
            x = transforms.functional.rotate(x, angle)

        # p=0.8: brightness x[0.8, 1.2]
        if random.random() < 0.8:
            factor = 1 + random.uniform(-0.2, 0.2)
            x = transforms.functional.adjust_brightness(x, factor)

        # p=0.8: contrast x[0.8, 1.2]
        if random.random() < 0.8:
            factor = 1 + random.uniform(-0.2, 0.2)
            x = transforms.functional.adjust_contrast(x, factor)

        # p=0.5: additive Gaussian noise (sigma 0.05), clamped back to [0, 1]
        if random.random() < 0.5:
            noise = torch.randn_like(x) * 0.05
            x = torch.clamp(x + noise, 0, 1)

        return x

# One shared instance; MultiRegionPretrainDataset calls it twice per sample (two views).
augment = ChestXrayAugment(img_size=cfg.img_size)
print("‚úÖ Augmentation pipeline ready")

In [None]:
# ============================================
# üì¶ Step 7: Dataset Classes
# ============================================

class MultiRegionPretrainDataset(Dataset):
    """
    SSL pretraining dataset yielding two augmented views plus fixed region masks.

    Each item is ``(view1, view2, region_masks)``:
      - view1 / view2: two independent ``transform`` outputs of the same image,
        or two plain float32 tensors of shape (1, img_size, img_size) when no
        transform is given;
      - region_masks: dict name -> float32 tensor of shape (1, img_size, img_size)
        from multi_region_segmentation.

    The masks depend only on ``img_size`` and ``num_vertical_regions``, never on
    pixel values, so they are computed once in ``__init__`` instead of per sample.
    """
    def __init__(self, df, transform=None, img_size=224, num_vertical_regions=3):
        self.df = df.copy().reset_index(drop=True)
        self.transform = transform
        self.img_size = img_size
        self.num_vertical_regions = num_vertical_regions

        # Fail fast if a sample of the referenced files is missing
        sample_paths = self.df['Image Path'].sample(min(200, len(self.df)), random_state=42).values
        missing = [p for p in sample_paths if not os.path.exists(p)]
        if missing:
            raise FileNotFoundError(f"‚ùå Missing {len(missing)} images!")

        # Every image is resized to (img_size, img_size), so one template suffices
        template = np.zeros((1, img_size, img_size), dtype=np.float32)
        seg_result = multi_region_segmentation(template, num_vertical_regions)
        self._region_masks = {
            name: torch.tensor(mask[None, ...], dtype=torch.float32)
            for name, mask in seg_result['masks'].items()
        }

        print(f"üì¶ MultiRegionPretrainDataset: {len(self.df)} samples")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img_path = self.df.iloc[idx]['Image Path']

        with Image.open(img_path) as pil_img:  # close the file handle promptly
            img = pil_img.convert('L')
        img = img.resize((self.img_size, self.img_size), Image.LANCZOS)
        img = np.array(img, dtype=np.float32) / 255.0
        img = np.expand_dims(img, 0)  # (1, H, W)

        # Augmented views
        if self.transform:
            view1 = self.transform(img)
            view2 = self.transform(img)
        else:
            view1 = torch.tensor(img, dtype=torch.float32)
            view2 = torch.tensor(img, dtype=torch.float32)

        # Shallow copy so a caller mutating the dict cannot corrupt the cache
        return view1, view2, dict(self._region_masks)


class ClassificationDataset(Dataset):
    """
    Multi-label classification dataset for fine-tuning and evaluation.

    Each item is ``(image, labels)``:
      - image: float32 tensor of shape (1, img_size, img_size), values clipped to
        [0, 1] when augmented;
      - labels: float32 tensor with one entry per name in ``disease_categories``,
        read from the df column of the same name.

    When ``is_training`` is True, a light random augmentation (flip, brightness,
    contrast, small rotation) is applied via ``_augment``.
    """
    def __init__(self, df, disease_categories, img_size=224, is_training=False):
        self.df = df.copy().reset_index(drop=True)
        self.disease_categories = disease_categories
        self.img_size = img_size
        self.is_training = is_training  # enables _augment() in __getitem__
        print(f"üì¶ ClassificationDataset: {len(self.df)} samples (training={is_training})")

    def __len__(self):
        return len(self.df)

    def _augment(self, img):
        """Return a randomly augmented copy of a 2-D float image; the input is not modified."""
        # Random horizontal flip
        if np.random.random() > 0.5:
            img = np.fliplr(img).copy()
        # Random brightness, x[0.8, 1.2]
        img = img * (0.8 + 0.4 * np.random.random())
        # Random contrast around the image mean, x[0.8, 1.2]
        mean = img.mean()
        img = (img - mean) * (0.8 + 0.4 * np.random.random()) + mean
        # Random small rotation (+/-10 degrees), zero-filled corners.
        # `rotate` is scipy.ndimage.rotate, imported in the first cell.
        if np.random.random() > 0.5:
            angle = np.random.uniform(-10, 10)
            img = rotate(img, angle, reshape=False, mode='constant', cval=0)
        return np.clip(img, 0, 1)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        with Image.open(row['Image Path']) as pil_img:  # close the file handle promptly
            img = pil_img.convert('L')
        img = img.resize((self.img_size, self.img_size), Image.LANCZOS)
        img = np.array(img, dtype=np.float32) / 255.0

        if self.is_training:
            img = self._augment(img)

        img = torch.tensor(img, dtype=torch.float32).unsqueeze(0)
        labels = torch.tensor([row[d] for d in self.disease_categories], dtype=torch.float32)
        return img, labels

print("‚úÖ Dataset classes defined (with training augmentation support)")

In [None]:
class Classifier(nn.Module):
    """Two-layer MLP head mapping encoder features to one logit per disease.

    Layout: Linear(feat_dim -> 256) -> ReLU -> Dropout(0.3) -> Linear(256 -> num_classes).
    The output is raw logits; apply a sigmoid for multi-label probabilities.
    """
    def __init__(self, feat_dim=256, num_classes=14):
        super().__init__()
        hidden_dim = 256
        layers = [
            nn.Linear(feat_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, num_classes),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.net(x)
        return logits

# Initialize models with backbone selection.
# NOTE(review): Encoder, MobileNetV2Encoder, ProjectionHead and Decoder are not defined in
# the visible part of this notebook (there is no "Step 8" model cell). A fresh
# Restart & Run All raises NameError here until that cell is restored.
SUPPORTED_BACKBONES = ('custom', 'mobilenet_v2')
if cfg.encoder_backbone not in SUPPORTED_BACKBONES:
    # Previously any unknown value (e.g. a typo) silently fell back to the custom CNN
    raise ValueError(
        f"Unknown cfg.encoder_backbone={cfg.encoder_backbone!r}; expected one of {SUPPORTED_BACKBONES}"
    )

if cfg.encoder_backbone == 'mobilenet_v2':
    encoder = MobileNetV2Encoder(feat_dim=cfg.feat_dim, pretrained=True).to(cfg.device)
    print(f"‚úÖ Using MobileNetV2 encoder backbone")
else:
    encoder = Encoder(feat_dim=cfg.feat_dim).to(cfg.device)
    print(f"‚úÖ Using custom CNN encoder backbone")

proj_head = ProjectionHead(cfg.feat_dim, cfg.proj_dim).to(cfg.device)
decoder = Decoder(cfg.feat_dim, cfg.img_size).to(cfg.device)

total_params = sum(p.numel() for m in [encoder, proj_head, decoder] for p in m.parameters())
print(f"‚úÖ Models initialized ({total_params:,} parameters)")


In [None]:
# ============================================
# üî• Step 9: Loss Functions
# ============================================

def nt_xent_loss(z1, z2, temperature=0.1):
    """
    NT-Xent (normalized temperature-scaled cross-entropy) contrastive loss.

    Args:
        z1, z2: (batch, dim) projections of two views. Row i of z1 and row i of z2
            form the positive pair; every other row in the 2*batch set is a negative.
        temperature: divisor applied to the cosine similarities.

    Returns:
        Scalar tensor: mean cross-entropy over all 2*batch anchors.
    """
    device = z1.device
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)

    batch_size = z1.shape[0]
    representations = torch.cat([z1, z2], dim=0)  # (2B, dim)
    similarity = torch.matmul(representations, representations.T) / temperature

    # An anchor must never be scored against itself
    mask = torch.eye(2 * batch_size, dtype=torch.bool, device=device)
    similarity = similarity.masked_fill(mask, -float('inf'))

    # The positive for anchor i is i+B (first half) or i-B (second half)
    labels = torch.cat([torch.arange(batch_size, device=device) + batch_size,
                        torch.arange(batch_size, device=device)])

    return F.cross_entropy(similarity, labels)


def region_aware_loss(proj_1, proj_2, region_masks_1, region_masks_2,
                      temperature=0.1, region_weight=0.5):
    """
    Region-aware contrastive loss: NT-Xent scaled by average region-mask coverage.

        loss = nt_xent(proj_1, proj_2) * (1 + region_weight * (coverage - 0.5) * 2)

    where ``coverage`` is the mean value over all masks in both dicts.

    NOTE(review): with masks from multi_region_segmentation, coverage depends only on the
    image size, never on content. It is therefore identical for every batch: about 0.35
    for 224x224 inputs with 3 bands, i.e. a constant factor of about 0.85. In practice
    this is a fixed rescaling of NT-Xent, not a per-sample weighting.

    Args:
        proj_1, proj_2: (batch, dim) projections of the two views.
        region_masks_1, region_masks_2: dicts name -> mask tensor (any shape/dtype).
        temperature: NT-Xent temperature.
        region_weight: strength of the coverage adjustment.

    Returns:
        Scalar tensor. Equal to the plain NT-Xent loss when no masks are given.
    """
    base_loss = nt_xent_loss(proj_1, proj_2, temperature)

    all_masks = list(region_masks_1.values()) + list(region_masks_2.values())
    if not all_masks:
        return base_loss

    # Stay in torch: np.mean over CUDA tensors raises TypeError (no implicit GPU->CPU copy).
    # The masks are constants, so detaching them keeps them out of the autograd graph.
    coverage = torch.stack([m.detach().float().mean() for m in all_masks]).mean()
    coverage = coverage.to(device=base_loss.device, dtype=base_loss.dtype)

    weight_factor = 1.0 + region_weight * (coverage - 0.5) * 2
    return base_loss * weight_factor


class FocalLoss(nn.Module):
    """
    Binary focal loss on logits, averaged over every element (multi-label safe).

        FL(p_t) = -alpha * (1 - p_t)^gamma * log(p_t)

    ``alpha`` is a single scalar applied to all elements, not a per-class balance
    term. With gamma=0 and alpha=1 this reduces to mean BCE-with-logits.
    """
    def __init__(self, alpha=1.0, gamma=2.0):
        super().__init__()
        self.alpha, self.gamma = alpha, gamma

    def forward(self, inputs, targets):
        per_element_bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        # BCE = -log(p_t), so exp(-BCE) recovers p_t
        p_t = per_element_bce.neg().exp()
        modulating = (1 - p_t).pow(self.gamma)
        weighted = self.alpha * modulating * per_element_bce
        return weighted.mean()


print("‚úÖ Loss functions defined")
print("   üó∫Ô∏è region_aware_loss: Weights by anatomical region coverage")
print("   ‚≠ê FocalLoss: For class imbalance (Œ±=1.0, Œ≥=2.0)")

In [None]:
# ============================================
# üìä Step 10: Create Data Loaders (Patient-Level Split)
# ============================================

# ‚ö†Ô∏è CRITICAL: Patient-level splitting to prevent data leakage
# Same patient's images must stay in the same split
print("="*60)
print("üîÄ PATIENT-LEVEL SPLITTING")
print("="*60)

unique_patients = df['Patient ID'].unique()
print(f"Total unique patients: {len(unique_patients):,}")

# Split patients: 93% train, 5% val, 2% test
train_val_patients, test_patients = train_test_split(
    unique_patients, test_size=0.02, random_state=42
)
train_patients, val_patients = train_test_split(
    train_val_patients, test_size=0.052, random_state=42  # ~5% of total
)

# Create dataframes based on patient splits
train_df = df[df['Patient ID'].isin(train_patients)].copy()
val_df = df[df['Patient ID'].isin(val_patients)].copy()
test_df = df[df['Patient ID'].isin(test_patients)].copy()

print(f"‚úì Train: {len(train_df):,} images from {len(train_patients):,} patients")
print(f"‚úì Val: {len(val_df):,} images from {len(val_patients):,} patients")
print(f"‚úì Test: {len(test_df):,} images from {len(test_patients):,} patients")
print("="*60)

if cfg.use_subset:
    train_df = train_df.head(cfg.subset_size)
    val_df = val_df.head(cfg.subset_size // 4)
    test_df = test_df.head(cfg.subset_size // 8)
    print(f"‚ö° Using subset: {len(train_df)} train, {len(val_df)} val, {len(test_df)} test")

# Datasets - NOW WITH AUGMENTATION FOR TRAINING
train_pretrain_ds = MultiRegionPretrainDataset(train_df, transform=augment, img_size=cfg.img_size)
train_class_ds = ClassificationDataset(train_df, DISEASE_CATEGORIES, cfg.img_size, is_training=True)  # ‚úÖ Augmentation ON
val_class_ds = ClassificationDataset(val_df, DISEASE_CATEGORIES, cfg.img_size, is_training=False)
test_class_ds = ClassificationDataset(test_df, DISEASE_CATEGORIES, cfg.img_size, is_training=False)

# DataLoaders - FAST PIPELINE (like tf.data)
# üöÄ num_workers: Parallel data loading (like num_parallel_calls)
# üöÄ pin_memory: Faster CPU‚ÜíGPU transfer  
# üöÄ prefetch_factor: Prefetch batches per worker (like prefetch)
# üöÄ persistent_workers: Keep workers alive between epochs
pretrain_loader = DataLoader(
    train_pretrain_ds, batch_size=cfg.batch_size, shuffle=True,
    num_workers=cfg.num_workers, pin_memory=True, drop_last=True,
    prefetch_factor=2, persistent_workers=True if cfg.num_workers > 0 else False
)
train_loader = DataLoader(
    train_class_ds, batch_size=cfg.batch_size, shuffle=True,
    num_workers=cfg.num_workers, pin_memory=True, drop_last=True,
    prefetch_factor=2, persistent_workers=True if cfg.num_workers > 0 else False
)
val_loader = DataLoader(
    val_class_ds, batch_size=cfg.batch_size, shuffle=False,
    num_workers=cfg.num_workers, pin_memory=True,
    prefetch_factor=2, persistent_workers=True if cfg.num_workers > 0 else False
)
test_loader = DataLoader(
    test_class_ds, batch_size=cfg.batch_size, shuffle=False,
    num_workers=cfg.num_workers, pin_memory=True,
    prefetch_factor=2, persistent_workers=True if cfg.num_workers > 0 else False
)

print(f"‚úÖ DataLoaders ready - FAST PIPELINE (with training augmentation)")
print(f"   Train batches: {len(pretrain_loader)}")
print(f"   Test batches: {len(test_loader)}")

In [None]:
# ============================================
# üöÄ Step 11: Region-Aware SSL Pretraining
# ============================================

# Clear GPU cache before training
# Stage 1 (self-supervised): jointly trains encoder + proj_head + decoder with
#   loss = region_aware_loss(proj(view1), proj(view2)) + 0.5 * mean MSE reconstruction of both views.
# NOTE(review): RESUME_SSL_PRETRAINING, SSL_CHECKPOINT_FILE, OPTION_NAME, IN_KAGGLE,
# find_latest_checkpoint, load_checkpoint and save_checkpoint are not defined in the visible
# part of this notebook; they must come from a checkpoint-utilities cell, otherwise a fresh
# Restart & Run All raises NameError here.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# A single optimizer covers all three modules trained in this stage
optimizer_ssl = torch.optim.Adam(
    list(encoder.parameters()) + list(proj_head.parameters()) + list(decoder.parameters()),
    lr=cfg.lr_pretrain, weight_decay=1e-4
)

ssl_history = {'loss': [], 'contrastive': [], 'reconstruction': []}
START_EPOCH = 1  # epochs are 1-based throughout

# Optional resume: restores weights, optimizer state and loss history, then continues
# from the epoch after the one stored in the checkpoint.
if RESUME_SSL_PRETRAINING:
    ckpt_file = find_latest_checkpoint(f'{OPTION_NAME}_ssl') if SSL_CHECKPOINT_FILE == "latest" else SSL_CHECKPOINT_FILE
    if ckpt_file:
        checkpoint = load_checkpoint(ckpt_file)
        # NOTE(review): if load_checkpoint returns a falsy value, training silently starts fresh
        if checkpoint:
            encoder.load_state_dict(checkpoint['encoder'])
            proj_head.load_state_dict(checkpoint['proj_head'])
            decoder.load_state_dict(checkpoint['decoder'])
            if 'optimizer' in checkpoint:
                optimizer_ssl.load_state_dict(checkpoint['optimizer'])
            ssl_history = checkpoint.get('ssl_history', ssl_history)
            START_EPOCH = checkpoint['epoch'] + 1
            print(f"üîÑ Resuming from epoch {START_EPOCH}")
    else:
        print("‚ö†Ô∏è RESUME_SSL_PRETRAINING=True but no checkpoint found. Starting fresh.")

if START_EPOCH > cfg.pretrain_epochs:
    print(f"‚úÖ SSL Pretraining already complete ({cfg.pretrain_epochs} epochs)")
else:
    print(f"\nüöÄ Starting Option 2: Region-Aware SSL Pretraining")
    print(f"   Epochs: {START_EPOCH} ‚Üí {cfg.pretrain_epochs}")
    print("=" * 60)
    
    # Checkpoint frequency in epochs. With 1, a separate per-epoch snapshot file is
    # written every epoch (in addition to the "_latest" file), which can use a lot of disk.
    SAVE_EVERY = 1
    
    for epoch in range(START_EPOCH, cfg.pretrain_epochs + 1):
        encoder.train()
        proj_head.train()
        decoder.train()
        
        total_loss = 0
        total_cont = 0
        total_recon = 0
        
        loader = tqdm(pretrain_loader, desc=f"Epoch {epoch}/{cfg.pretrain_epochs}") if not IN_KAGGLE else pretrain_loader
        for view1, view2, region_masks in loader:
            view1 = view1.to(cfg.device)
            view2 = view2.to(cfg.device)
            region_masks = {k: v.to(cfg.device) for k, v in region_masks.items()}
            
            optimizer_ssl.zero_grad()
            
            z1 = encoder(view1)
            z2 = encoder(view2)
            
            # Contrastive term on projections. The same mask dict is passed for both views:
            # the masks are not augmented along with the images.
            p1 = proj_head(z1)
            p2 = proj_head(z2)
            cont_loss = region_aware_loss(p1, p2, region_masks, region_masks, 
                                           cfg.temperature, cfg.region_weight)
            
            # Reconstruction term: each decoder output is compared to its own augmented view
            rec1 = decoder(z1)
            rec2 = decoder(z2)
            recon_loss = (F.mse_loss(rec1, view1) + F.mse_loss(rec2, view2)) / 2
            
            loss = cont_loss + 0.5 * recon_loss
            
            loss.backward()
            optimizer_ssl.step()
            
            total_loss += loss.item()
            total_cont += cont_loss.item()
            total_recon += recon_loss.item()
            
            if not IN_KAGGLE:
                loader.set_postfix({'loss': f'{loss.item():.4f}'})
            
            # Free memory
            del z1, z2, p1, p2, rec1, rec2, loss, cont_loss, recon_loss
        
        # Clear cache at end of epoch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        
        # Per-epoch means over batches.
        # NOTE(review): n is 0 (ZeroDivisionError) if pretrain_loader yields no batches,
        # e.g. drop_last=True with fewer samples than batch_size.
        n = len(pretrain_loader)
        ssl_history['loss'].append(total_loss / n)
        ssl_history['contrastive'].append(total_cont / n)
        ssl_history['reconstruction'].append(total_recon / n)
        
        print(f"Epoch {epoch}: Loss={total_loss/n:.4f}, Cont={total_cont/n:.4f}, Recon={total_recon/n:.4f}")
        
        # "_latest" includes optimizer state so training can resume; the per-epoch
        # snapshot stores weights + history only.
        if epoch % SAVE_EVERY == 0 or epoch == cfg.pretrain_epochs:
            save_checkpoint({
                'epoch': epoch,
                'encoder': encoder.state_dict(),
                'proj_head': proj_head.state_dict(),
                'decoder': decoder.state_dict(),
                'optimizer': optimizer_ssl.state_dict(),
                'ssl_history': ssl_history,
            }, f'{OPTION_NAME}_ssl_latest.pth')
            save_checkpoint({
                'epoch': epoch,
                'encoder': encoder.state_dict(),
                'proj_head': proj_head.state_dict(),
                'decoder': decoder.state_dict(),
                'ssl_history': ssl_history,
            }, f'{OPTION_NAME}_ssl_epoch{epoch}.pth')
    
    print("\n‚úÖ Region-Aware SSL Pretraining Complete!")

In [None]:
# ============================================
# üìà Step 12: Plot SSL Training Curves
# ============================================

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# (history key, panel title, line style) - one panel per tracked loss
panels = [
    ('loss', 'Total Loss', 'b-'),
    ('contrastive', 'Region-Aware Contrastive Loss', 'r-'),
    ('reconstruction', 'Reconstruction Loss', 'g-'),
]
for ax, (key, title, line_style) in zip(axes, panels):
    values = ssl_history[key]
    # Epochs are 1-based in the training loop and checkpoint names, so label them that way
    epochs = range(1, len(values) + 1)
    ax.plot(epochs, values, line_style, linewidth=2)
    ax.set_title(title, fontsize=12)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.grid(True, alpha=0.3)

fig.suptitle('Option 2: Region-Aware SSL Training', fontsize=14, fontweight='bold')
fig.tight_layout()
fig.savefig('option2_ssl_curves.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# ============================================
# üíæ Step 13: Save Pretrained Model
# ============================================

# Persist the pretrained weights together with the architecture metadata needed to
# rebuild the same modules before calling load_state_dict.
pretrained_path = 'option2_ssl_pretrained.pth'
torch.save({
    'encoder': encoder.state_dict(),
    'proj_head': proj_head.state_dict(),
    'decoder': decoder.state_dict(),
    'config': {
        'feat_dim': cfg.feat_dim,
        'proj_dim': cfg.proj_dim,
        # Lets a reloader choose between the custom and MobileNetV2 encoder classes
        # (and the decoder's output size) without guessing from state-dict keys
        'encoder_backbone': cfg.encoder_backbone,
        'img_size': cfg.img_size,
    },
}, pretrained_path)

print(f"üíæ Pretrained model saved: {pretrained_path}")

In [None]:
# ============================================
# üéØ Step 14: Fine-tuning
# ============================================
# KEY IMPROVEMENTS (inspired by DannyNet SOTA):
# 1. UNFREEZE encoder with differential learning rate
# 2. Use Focal Loss instead of BCE (handles class imbalance better)
# 3. Use AdamW optimizer (better generalization)
# 4. More aggressive LR scheduler (factor=0.1, patience=2)
# ============================================

# ‚úÖ UNFREEZE encoder for fine-tuning (CRITICAL for performance!)
for param in encoder.parameters():
    param.requires_grad = True  # whole encoder is fine-tuned end to end
encoder.train()

classifier = Classifier(cfg.feat_dim, len(DISEASE_CATEGORIES)).to(cfg.device)

# Focal loss down-weights easy examples in this imbalanced multi-label task
criterion = FocalLoss(alpha=1.0, gamma=2.0)

# Differential learning rates: the pretrained encoder moves 10x slower than the new head
encoder_lr = cfg.lr_finetune / 10  # 1e-5 if base is 1e-4
classifier_lr = cfg.lr_finetune    # 1e-4

optimizer_ft = torch.optim.AdamW([
    {'params': encoder.parameters(), 'lr': encoder_lr},
    {'params': classifier.parameters(), 'lr': classifier_lr}
], weight_decay=1e-4)

# Monitors validation AUC (higher is better), hence mode='max'
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_ft, mode='max', patience=2, factor=0.1, min_lr=1e-7
)

print("üîß Fine-tuning Configuration:")
print(f"   ‚úÖ Encoder: UNFROZEN with LR={encoder_lr:.2e}")
print(f"   ‚úÖ Classifier LR: {classifier_lr:.2e}")
print(f"   ‚úÖ Loss: FocalLoss (Œ±=1.0, Œ≥=2.0)")
print(f"   ‚úÖ Optimizer: AdamW")
print(f"   ‚úÖ Scheduler: ReduceLROnPlateau (patience=2, factor=0.1)")


def _mean_auc(targets, preds):
    """Mean ROC-AUC over disease columns that contain both classes; NaN if none do.

    Single-class columns are skipped because ROC-AUC is undefined for them.
    """
    scores = [roc_auc_score(targets[:, i], preds[:, i])
              for i in range(len(DISEASE_CATEGORIES))
              if len(np.unique(targets[:, i])) > 1]
    return float(np.mean(scores)) if scores else float('nan')


finetune_history = {'train_loss': [], 'train_auc': [], 'val_loss': [], 'val_auc': []}
best_val_auc = 0
FINETUNE_START_EPOCH = 1

# NOTE(review): RESUME_FINETUNING, FINETUNE_CHECKPOINT_FILE, OPTION_NAME, IN_KAGGLE,
# find_latest_checkpoint, load_checkpoint and save_checkpoint are not defined in the
# visible part of this notebook; they must come from a checkpoint-utilities cell.
if RESUME_FINETUNING:
    ckpt_file = find_latest_checkpoint(f'{OPTION_NAME}_finetune') if FINETUNE_CHECKPOINT_FILE == "latest" else FINETUNE_CHECKPOINT_FILE
    if ckpt_file:
        ft_checkpoint = load_checkpoint(ckpt_file)
        if ft_checkpoint:
            classifier.load_state_dict(ft_checkpoint['classifier'])
            if 'encoder' in ft_checkpoint:
                encoder.load_state_dict(ft_checkpoint['encoder'])
            if 'optimizer' in ft_checkpoint:
                try:
                    optimizer_ft.load_state_dict(ft_checkpoint['optimizer'])
                except Exception as exc:  # e.g. param groups changed since the checkpoint
                    # A bare `except:` would also swallow KeyboardInterrupt/SystemExit
                    print("‚ö†Ô∏è Optimizer state incompatible, starting fresh")
                    print(f"   ({type(exc).__name__}: {exc})")
            # Restore the plateau counters / best metric so LR decay resumes where it stopped
            if 'scheduler' in ft_checkpoint:
                scheduler.load_state_dict(ft_checkpoint['scheduler'])
            finetune_history = ft_checkpoint.get('finetune_history', finetune_history)
            best_val_auc = ft_checkpoint.get('best_val_auc', 0)
            FINETUNE_START_EPOCH = ft_checkpoint['epoch'] + 1
            print(f"üîÑ Resuming fine-tuning from epoch {FINETUNE_START_EPOCH}")
    else:
        print("‚ö†Ô∏è RESUME_FINETUNING=True but no checkpoint found. Starting fresh.")

if FINETUNE_START_EPOCH > cfg.finetune_epochs:
    print(f"‚úÖ Fine-tuning already complete ({cfg.finetune_epochs} epochs)")
else:
    print(f"\nüéØ Starting Fine-tuning (ENCODER UNFROZEN)")
    print(f"   Epochs: {FINETUNE_START_EPOCH} ‚Üí {cfg.finetune_epochs}")
    print("=" * 50)

    SAVE_EVERY_FT = 5

    for epoch in range(FINETUNE_START_EPOCH, cfg.finetune_epochs + 1):
        # ---- train: encoder and classifier both receive gradients ----
        encoder.train()
        classifier.train()
        train_loss = 0
        train_preds, train_targets = [], []

        loader = tqdm(train_loader, desc=f"Train {epoch}/{cfg.finetune_epochs}") if not IN_KAGGLE else train_loader
        for images, targets in loader:
            images = images.to(cfg.device)
            targets = targets.to(cfg.device)

            optimizer_ft.zero_grad()

            features = encoder(images)
            logits = classifier(features)
            loss = criterion(logits, targets)

            loss.backward()

            # Gradient clipping for stability (each module clipped separately)
            torch.nn.utils.clip_grad_norm_(encoder.parameters(), max_norm=1.0)
            torch.nn.utils.clip_grad_norm_(classifier.parameters(), max_norm=1.0)

            optimizer_ft.step()

            train_loss += loss.item()
            train_preds.append(torch.sigmoid(logits).detach().cpu())
            train_targets.append(targets.cpu())

        # ---- validate ----
        encoder.eval()
        classifier.eval()
        val_loss = 0
        val_preds, val_targets = [], []

        with torch.no_grad():
            for images, targets in val_loader:
                images = images.to(cfg.device)
                targets = targets.to(cfg.device)

                features = encoder(images)
                logits = classifier(features)
                loss = criterion(logits, targets)

                val_loss += loss.item()
                val_preds.append(torch.sigmoid(logits).cpu())
                val_targets.append(targets.cpu())

        train_preds = torch.cat(train_preds).numpy()
        train_targets = torch.cat(train_targets).numpy()
        val_preds = torch.cat(val_preds).numpy()
        val_targets = torch.cat(val_targets).numpy()

        train_auc = _mean_auc(train_targets, train_preds)
        val_auc = _mean_auc(val_targets, val_preds)

        finetune_history['train_loss'].append(train_loss / len(train_loader))
        finetune_history['train_auc'].append(train_auc)
        finetune_history['val_loss'].append(val_loss / len(val_loader))
        finetune_history['val_auc'].append(val_auc)

        scheduler.step(val_auc)

        current_lr = optimizer_ft.param_groups[0]['lr']  # group 0 = encoder
        print(f"Epoch {epoch}: Train AUC={train_auc:.4f}, Val AUC={val_auc:.4f}, LR={current_lr:.2e}")

        if val_auc > best_val_auc:
            best_val_auc = val_auc
            save_checkpoint({
                'encoder': encoder.state_dict(),
                'classifier': classifier.state_dict(),
                'val_auc': val_auc,
                'epoch': epoch,
            }, f'{OPTION_NAME}_best_model.pth')
            print(f"  ‚úÖ Best model saved! Val AUC: {val_auc:.4f}")

        if epoch % SAVE_EVERY_FT == 0 or epoch == cfg.finetune_epochs:
            save_checkpoint({
                'epoch': epoch,
                'encoder': encoder.state_dict(),
                'classifier': classifier.state_dict(),
                'optimizer': optimizer_ft.state_dict(),
                'scheduler': scheduler.state_dict(),
                'finetune_history': finetune_history,
                'best_val_auc': best_val_auc,
            }, f'{OPTION_NAME}_finetune_latest.pth')
            save_checkpoint({
                'epoch': epoch,
                'encoder': encoder.state_dict(),
                'classifier': classifier.state_dict(),
                'finetune_history': finetune_history,
                'best_val_auc': best_val_auc,
            }, f'{OPTION_NAME}_finetune_epoch{epoch}.pth')

    print(f"\nüèÜ Best Validation AUC: {best_val_auc:.4f}")

In [None]:
# ============================================
# üìä Step 15: Final Evaluation on Validation Set
# ============================================

from sklearn.metrics import f1_score, precision_score, recall_score

# Restore the best checkpoint saved during fine-tuning (highest validation AUC).
# map_location remaps stored tensors onto this session's device, so a checkpoint
# written from GPU tensors can still be restored on a CPU-only runtime instead of
# failing inside torch.load. weights_only=False is required because the checkpoint
# dict also holds plain Python values; only load checkpoints this notebook wrote.
best_model_path = os.path.join(CHECKPOINT_DIR, f'{OPTION_NAME}_best_model.pth')
checkpoint = torch.load(best_model_path, map_location=cfg.device, weights_only=False)
encoder.load_state_dict(checkpoint['encoder'])
classifier.load_state_dict(checkpoint['classifier'])
print(f"Loaded {best_model_path} (epoch {checkpoint.get('epoch', '?')}, "
      f"val AUC {checkpoint.get('val_auc', float('nan')):.4f})")

# Inference mode: disable dropout / use running batch-norm statistics.
encoder.eval()
classifier.eval()

# Collect sigmoid probabilities and ground-truth labels for the whole validation split.
all_preds, all_targets = [], []
with torch.no_grad():
    loader = tqdm(val_loader, desc="Evaluating Val") if not IN_KAGGLE else val_loader
    for images, targets in loader:
        images = images.to(cfg.device)
        features = encoder(images)
        logits = classifier(features)
        all_preds.append(torch.sigmoid(logits).cpu())
        all_targets.append(targets)

all_preds = torch.cat(all_preds).numpy()
all_targets = torch.cat(all_targets).numpy()

print("\nüìä Validation Set - Per-Disease AUC Scores:")
print("=" * 40)
val_auc_scores = []
for i, disease in enumerate(DISEASE_CATEGORIES):
    # ROC AUC is undefined when only one class is present for this disease.
    if len(np.unique(all_targets[:, i])) > 1:
        auc = roc_auc_score(all_targets[:, i], all_preds[:, i])
        val_auc_scores.append((disease, auc))
        print(f"{disease:20s}: {auc:.4f}")

val_mean_auc = np.mean([a for _, a in val_auc_scores])
print(f"\n{'Val Mean AUC':20s}: {val_mean_auc:.4f}")

# ============================================
# üéØ Find Optimal Per-Disease Thresholds (like DannyNet)
# ============================================
print("\n" + "="*60)
print("üéØ OPTIMAL THRESHOLD SEARCH (per-disease)")
print("="*60)

# Candidate decision thresholds, tuned on the validation split only.
# linspace includes the upper bound (np.arange(0.1, 0.9, 0.02) stops at 0.88, so
# 0.9 was never tried). Rounding removes float drift such as 0.16000000000000003.
THRESH_MIN, THRESH_MAX, THRESH_STEP = 0.1, 0.9, 0.02
n_thresholds = int(round((THRESH_MAX - THRESH_MIN) / THRESH_STEP)) + 1
threshold_grid = np.round(np.linspace(THRESH_MIN, THRESH_MAX, n_thresholds), 6)

optimal_thresholds = {}
for i, disease in enumerate(DISEASE_CATEGORIES):
    y_true = all_targets[:, i]
    if len(np.unique(y_true)) < 2:
        # F1 is meaningless without both classes; no threshold is stored.
        continue

    best_f1, best_thresh = 0.0, 0.5
    for thresh in threshold_grid:
        # Same strict ">" rule as used when applying thresholds to the test set.
        preds_binary = (all_preds[:, i] > thresh).astype(int)
        # Skip degenerate cut-offs that predict a single class for every image.
        if 0 < preds_binary.sum() < len(preds_binary):
            f1 = f1_score(y_true, preds_binary, zero_division=0)
            if f1 > best_f1:
                best_f1, best_thresh = f1, float(thresh)

    optimal_thresholds[disease] = best_thresh
    # When nothing in the grid beats F1 = 0 the 0.5 value is just the default, not an optimum.
    note = "" if best_f1 > 0 else "  (no grid threshold gave F1 > 0; default 0.5 kept)"
    print(f"{disease:20s}: optimal thresh = {best_thresh:.2f}, F1 = {best_f1:.4f}{note}")

# ============================================
# üß™ Step 16: Test Set Evaluation (Held-Out)
# ============================================

print("\n" + "="*60)
print("üß™ TEST SET EVALUATION (Patient-Level Held-Out)")
print("="*60)

# Make sure dropout / batch-norm are in inference mode regardless of earlier state.
encoder.eval()
classifier.eval()

# Collect sigmoid probabilities and labels for the held-out test split.
test_preds, test_targets = [], []
with torch.no_grad():
    loader = tqdm(test_loader, desc="Evaluating Test") if not IN_KAGGLE else test_loader
    for images, targets in loader:
        images = images.to(cfg.device)
        features = encoder(images)
        logits = classifier(features)
        test_preds.append(torch.sigmoid(logits).cpu())
        test_targets.append(targets)

test_preds = torch.cat(test_preds).numpy()
test_targets = torch.cat(test_targets).numpy()

# ROC AUC ranks the raw probabilities, so it does not depend on any decision threshold.
print("\nüìä Test Set - Per-Disease AUC Scores (threshold-independent):")
print("=" * 40)
test_auc_scores = []
for i, disease in enumerate(DISEASE_CATEGORIES):
    # ROC AUC is undefined when only one class is present for this disease.
    if len(np.unique(test_targets[:, i])) > 1:
        auc = roc_auc_score(test_targets[:, i], test_preds[:, i])
        test_auc_scores.append((disease, auc))
        print(f"{disease:20s}: {auc:.4f}")

test_mean_auc = np.mean([a for _, a in test_auc_scores])
print(f"\n{'Test Mean AUC':20s}: {test_mean_auc:.4f}")

# Apply the validation-tuned thresholds to the test set and report
# threshold-dependent metrics next to each disease's (threshold-free) AUC.
print("\nüìä Test Set - With Optimal Thresholds (from validation):")
print("=" * 60)
print(f"{'Disease':20s} {'AUC':>8s} {'Thresh':>8s} {'F1':>8s} {'Precision':>10s} {'Recall':>8s}")
print("-" * 60)

# Per-disease test AUCs were already computed; look them up instead of re-scoring.
test_auc_by_disease = dict(test_auc_scores)

test_f1_scores, test_precision_scores, test_recall_scores = [], [], []
for i, disease in enumerate(DISEASE_CATEGORIES):
    if len(np.unique(test_targets[:, i])) > 1:
        auc = test_auc_by_disease[disease]
        # Diseases without a tuned threshold fall back to 0.5.
        thresh = optimal_thresholds.get(disease, 0.5)
        preds_binary = (test_preds[:, i] > thresh).astype(int)

        f1 = f1_score(test_targets[:, i], preds_binary, zero_division=0)
        prec = precision_score(test_targets[:, i], preds_binary, zero_division=0)
        rec = recall_score(test_targets[:, i], preds_binary, zero_division=0)

        test_f1_scores.append(f1)
        test_precision_scores.append(prec)
        test_recall_scores.append(rec)
        print(f"{disease:20s} {auc:8.4f} {thresh:8.2f} {f1:8.4f} {prec:10.4f} {rec:8.4f}")

print("-" * 60)
# Fill every column the header declares (previously Precision/Recall were left blank).
print(f"{'MEAN':20s} {test_mean_auc:8.4f} {'--':>8s} {np.mean(test_f1_scores):8.4f} "
      f"{np.mean(test_precision_scores):10.4f} {np.mean(test_recall_scores):8.4f}")

# Ranked horizontal bar chart of per-disease test AUC.
# The in-place sort is kept so test_auc_scores stays ranked (best first) afterwards.
test_auc_scores.sort(key=lambda x: x[1], reverse=True)
diseases, aucs = zip(*test_auc_scores)

# Colour bands: green >= 0.7, orange >= 0.6, red below that.
colors = ['green' if a >= 0.7 else 'orange' if a >= 0.6 else 'red' for a in aucs]

fig, ax = plt.subplots(figsize=(12, 6))
ax.barh(diseases, aucs, color=colors, alpha=0.8)
# barh draws the first category at the bottom; invert so the highest AUC sits on top.
ax.invert_yaxis()
ax.axvline(0.5, color='red', linestyle='--', alpha=0.5, label='Random')
ax.axvline(test_mean_auc, color='blue', linestyle='--', alpha=0.7, label=f'Test Mean: {test_mean_auc:.3f}')
ax.set_xlabel('AUC Score')
ax.set_title('Option 2: Test Set Per-Disease AUC Performance', fontsize=14, fontweight='bold')
ax.legend()
fig.tight_layout()
fig.savefig('option2_test_auc_performance.png', dpi=150, bbox_inches='tight')
plt.show()

# ============================================
# üìù Summary
# ============================================

# One print call with sep="\n" emits exactly the same lines as printing each
# entry separately; only the dataset sizes and mean AUCs are interpolated.
print(
    "\n" + "=" * 60,
    "üìù OPTION 2: MULTI-REGION SEGMENTATION SUMMARY",
    "=" * 60,
    "Method: Region-Aware SSL with 6 Anatomical Regions",
    "Regions: Upper/Middle/Lower √ó Left/Right + Mediastinum",
    "\nDataset: NIH Chest X-ray 14 (Patient-Level Split)",
    f"  - Training: {len(train_df):,} images ({len(train_patients):,} patients)",
    f"  - Validation: {len(val_df):,} images ({len(val_patients):,} patients)",
    f"  - Test: {len(test_df):,} images ({len(test_patients):,} patients)",
    "\nResults:",
    f"  üìà Validation Mean AUC: {val_mean_auc:.4f}",
    f"  üß™ Test Mean AUC: {test_mean_auc:.4f}",
    "=" * 60,
    sep="\n",
)