## 1. Setup and Imports

In [1]:
import os
import sys
import warnings
# Suppress all Python warnings for the rest of the session.
warnings.filterwarnings('ignore')

# Single local folder for downloaded datasets. The cache-related environment
# variables are pointed at it before `kagglehub` is imported further down.
CUSTOM_DATA_PATH = "datasets"
os.makedirs(CUSTOM_DATA_PATH, exist_ok=True)
os.environ.update({
    'KAGGLEHUB_CACHE': CUSTOM_DATA_PATH,   # older versions
    'KAGGLE_CACHE_DIR': CUSTOM_DATA_PATH,  # some versions
    'KAGGLEHUB_HOME': CUSTOM_DATA_PATH,    # newer versions
})

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from tqdm import tqdm
import cv2
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
import kagglehub

# Check device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Enable optimizations if available
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
    if hasattr(torch.backends.cudnn, 'allow_tf32'):
        torch.backends.cudnn.allow_tf32 = True
    # Probe the namespace that actually owns the flag (torch.backends.cuda),
    # so the guard and the assignment below refer to the same object.
    if hasattr(torch.backends.cuda, 'matmul'):
        torch.backends.cuda.matmul.allow_tf32 = True

print("‚úÖ Setup complete")

# ImageNet normalization for pretrained backbone (per-channel RGB statistics)
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


Using device: cuda
‚úÖ Setup complete


## 2. Configuration

In [2]:
class CFG:
    """Central table of hyperparameters for SSL pretraining and fine-tuning.

    Every setting is a plain class attribute. The notebook reads them through
    the module-level ``cfg`` instance created right after this class.
    """
    # Data
    img_size = 224
    in_channels = 2  # grayscale image + segmentation mask
    
    # Training
    batch_size = 64
    pretrain_epochs = 50       # was 5 - too few for SSL
    finetune_epochs = 100
    pretrain_lr = 1e-3
    finetune_lr = 1e-4
    weight_decay = 1e-5
    patience = 5               # LR scheduler patience
    
    # SSL
    temperature = 0.5          # temperature passed to the NT-Xent loss
    projection_dim = 128       # output size of the SSL projection head
    
    # Device
    # The right-hand side resolves to the module-level `device` from the setup cell.
    device = device
    
    # Subset for testing (set to None for full dataset)
    subset_size = None
    
    # Fine-tuning data fraction (1.0 = use all training data)
    # Set to e.g. 0.01, 0.1, 0.5 to fine-tune on a subset
    finetune_fraction = .1

# Single shared config instance, followed by a short summary of the key settings.
cfg = CFG()
print(f"Configuration: img_size={cfg.img_size}, in_channels={cfg.in_channels}")
print(f"Pretrain: {cfg.pretrain_epochs} epochs, Finetune: {cfg.finetune_epochs} epochs")
print(f"Batch size: {cfg.batch_size}, AdamW weight_decay: {cfg.weight_decay}")
print(f"Fine-tune fraction: {cfg.finetune_fraction} of training data")


Configuration: img_size=224, in_channels=2
Pretrain: 50 epochs, Finetune: 100 epochs
Batch size: 64, AdamW weight_decay: 1e-05
Fine-tune fraction: 0.1 of training data


## 3. Data Loading

In [3]:
# A Kaggle session is recognised by the presence of the /kaggle/input mount.
IN_KAGGLE = os.path.exists('/kaggle/input')

if not IN_KAGGLE:
    # Outside Kaggle the dataset location is whatever kagglehub returns.
    data_dir = Path(kagglehub.dataset_download('nih-chest-xrays/data'))
    checkpoint_dir = Path('./checkpoints')
else:
    data_dir = Path('/kaggle/input/nih-chest-xrays')
    checkpoint_dir = Path('/kaggle/working/checkpoints')

# The checkpoint folder is created eagerly so later saves cannot fail on a missing path.
checkpoint_dir.mkdir(parents=True, exist_ok=True)
print(f"Data directory: {data_dir}")
print(f"Checkpoint directory: {checkpoint_dir}")

Data directory: datasets/datasets/nih-chest-xrays/data/versions/3
Checkpoint directory: checkpoints


In [4]:
# Read the per-image metadata table.
csv_path = data_dir / 'Data_Entry_2017.csv'
df = pd.read_csv(csv_path)

# The 14 finding names used as classification targets.
disease_categories = [
    'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration', 'Mass',
    'Nodule', 'Pneumonia', 'Pneumothorax', 'Consolidation', 'Edema',
    'Emphysema', 'Fibrosis', 'Pleural_Thickening', 'Hernia'
]

# One 0/1 column per finding: 1 when the finding name occurs in 'Finding Labels'.
for disease in disease_categories:
    df[disease] = [int(disease in labels) for labels in df['Finding Labels']]

# Locate the image folders; fall back to a flat `images/` directory when the
# `images_*/images` layout is not present.
image_dirs = list(data_dir.glob('images_*/images')) or [data_dir / 'images']

# file name -> full path for every PNG found (later folders win on duplicate names)
image_path_map = {
    img_path.name: str(img_path)
    for img_dir in image_dirs
    for img_path in img_dir.glob('*.png')
}

# Attach paths and drop metadata rows whose image file was not found.
df['Image Path'] = df['Image Index'].map(image_path_map)
df = df.dropna(subset=['Image Path'])

# Optional down-sampling for quick experiments.
if cfg.subset_size:
    df = df.sample(n=min(cfg.subset_size, len(df)), random_state=42)

print(f"Total samples: {len(df)}")
print(f"Disease distribution:")
print(df[disease_categories].sum())

Total samples: 112120
Disease distribution:
Atelectasis           11559
Cardiomegaly           2776
Effusion              13317
Infiltration          19894
Mass                   5782
Nodule                 6331
Pneumonia              1431
Pneumothorax           5302
Consolidation          4667
Edema                  2303
Emphysema              2516
Fibrosis               1686
Pleural_Thickening     3385
Hernia                  227
dtype: int64


## 3.5. Checkpoint & Resume Configuration

In [5]:
# ============================================
# üíæ Checkpoint & Resume Configuration
# ============================================

import shutil
from datetime import datetime

OPTION_NAME = "option6"

# ===== RESUME CONFIGURATION =====
CHECKPOINT_DATASET_NAME = f"{OPTION_NAME}-ssl-checkpoints"  # Unique for Option 6
RESUME_SSL_PRETRAINING = True
RESUME_FINETUNING = True
SSL_CHECKPOINT_FILE = "latest"
FINETUNE_CHECKPOINT_FILE = "latest"

if not IN_KAGGLE:
    CHECKPOINT_DIR = str(checkpoint_dir)
else:
    CHECKPOINT_DIR = '/kaggle/working/checkpoints'
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

    # Import .pth files from every attached input dataset whose folder name
    # starts with CHECKPOINT_DATASET_NAME (i.e. all versions of it).
    input_path = '/kaggle/input'
    if os.path.exists(input_path):
        found_any = False
        for dataset_folder in sorted(os.listdir(input_path)):
            dataset_path = os.path.join(input_path, dataset_folder)
            if not dataset_folder.startswith(CHECKPOINT_DATASET_NAME):
                continue
            if not os.path.isdir(dataset_path):
                continue

            # Folders to scan: the dataset root, its 'checkpoints' subfolder
            # (if any), then every other direct subfolder.
            search_paths = [dataset_path]
            checkpoints_subdir = os.path.join(dataset_path, 'checkpoints')
            if os.path.isdir(checkpoints_subdir):
                search_paths.append(checkpoints_subdir)
            search_paths += [
                os.path.join(dataset_path, item)
                for item in os.listdir(dataset_path)
                if item != 'checkpoints' and os.path.isdir(os.path.join(dataset_path, item))
            ]

            for search_path in search_paths:
                pth_files = [f for f in os.listdir(search_path) if f.endswith('.pth')]
                if not pth_files:
                    continue

                found_any = True
                rel_path = os.path.relpath(search_path, input_path)
                print(f"üìÇ Found checkpoints in: {rel_path}")
                for f in pth_files:
                    src = os.path.join(search_path, f)
                    dst = os.path.join(CHECKPOINT_DIR, f)
                    if not os.path.exists(dst):
                        shutil.copy2(src, dst)
                        print(f"   üì¶ Copied: {f}")
                    elif os.path.getmtime(src) > os.path.getmtime(dst):
                        # An existing copy is only replaced by a strictly newer file.
                        shutil.copy2(src, dst)
                        print(f"   üîÑ Updated: {f} (newer version)")

        if not found_any:
            print(f"‚ÑπÔ∏è No checkpoint datasets found matching: {CHECKPOINT_DATASET_NAME}*")

    existing = [f for f in os.listdir(CHECKPOINT_DIR) if f.endswith('.pth')]
    if existing:
        print(f"‚úÖ Total checkpoints available: {len(existing)}")
    else:
        print(f"‚ÑπÔ∏è Starting fresh - no checkpoints loaded")

os.makedirs(CHECKPOINT_DIR, exist_ok=True)

def save_checkpoint(state, filename):
    """Serialize ``state`` to ``CHECKPOINT_DIR/filename`` with ``torch.save``.

    A ``'saved_at'`` timestamp is added to the stored copy; the dict passed in
    by the caller is not modified. The data is first written to a temporary
    file in the same directory and then moved into place with ``os.replace``,
    so an interrupted save does not leave a truncated file under the final
    name. When running on Kaggle the finished file is also copied to
    ``/kaggle/working/<filename>``.

    Args:
        state: dict of objects to store (model/optimizer state, metrics, ...).
        filename: file name (not a path) of the checkpoint to write.
    """
    filepath = os.path.join(CHECKPOINT_DIR, filename)

    # Shallow copy so the timestamp does not leak into the caller's dict.
    payload = dict(state)
    payload['saved_at'] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    # Leading dot + '.tmp' suffix keeps the temp file out of '*.pth' listings
    # and out of the '<prefix>_epochN.pth' lookup.
    tmp_path = os.path.join(CHECKPOINT_DIR, f'.{filename}.tmp')
    try:
        torch.save(payload, tmp_path)
        os.replace(tmp_path, filepath)
    finally:
        # Only still present if the save or the move failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print(f"üíæ Saved: {filename}")

    if IN_KAGGLE:
        # Mirror the already-written file instead of serializing a second time.
        shutil.copy2(filepath, f'/kaggle/working/{filename}')

def load_checkpoint(filename):
    """Load ``CHECKPOINT_DIR/filename`` onto ``cfg.device``.

    Returns the deserialized object, or ``None`` when the file does not exist.

    NOTE(review): ``weights_only=False`` performs a full pickle load, so this
    should only be pointed at checkpoint files from a trusted source.
    """
    filepath = os.path.join(CHECKPOINT_DIR, filename)
    if not os.path.exists(filepath):
        return None

    checkpoint = torch.load(filepath, map_location=cfg.device, weights_only=False)
    print(f"‚úÖ Loaded: {filename}")
    return checkpoint

def find_latest_checkpoint(prefix, directory=None):
    """Return the file name of the most recent checkpoint for ``prefix``.

    Lookup order:
      1. ``<prefix>_latest.pth`` if it exists;
      2. otherwise the ``<prefix>_epoch<N>.pth`` file with the largest ``N``.

    Args:
        prefix: checkpoint name prefix; it is matched literally.
        directory: folder to search. Defaults to the module-level
            ``CHECKPOINT_DIR`` when omitted.

    Returns:
        The matching file name (not a full path), or ``None`` when the
        directory is missing or holds no matching file.
    """
    import re

    if directory is None:
        directory = CHECKPOINT_DIR
    if not os.path.isdir(directory):
        return None

    latest = f'{prefix}_latest.pth'
    if os.path.exists(os.path.join(directory, latest)):
        return latest

    # re.escape: the prefix is data, not a pattern. fullmatch: names with
    # trailing text after '.pth' (e.g. '.pth.tmp') must not qualify.
    pattern = re.compile(rf'{re.escape(prefix)}_epoch(\d+)\.pth')
    max_epoch, best = -1, None
    for name in os.listdir(directory):
        m = pattern.fullmatch(name)
        if m is None:
            continue
        epoch = int(m.group(1))
        if epoch > max_epoch:
            max_epoch, best = epoch, name
    return best

# Report which environment was detected and which folder holds the checkpoints.
print(f"üîß Environment: {'Kaggle' if IN_KAGGLE else 'Local'}")
print(f"üìÇ Checkpoint dir: {CHECKPOINT_DIR}")


üîß Environment: Local
üìÇ Checkpoint dir: checkpoints


## 4. Rule-Based Lung Segmentation (Loading Pre-Computed Masks)

In [6]:
# ============================================
# Load Pre-Computed Lung Masks (ALL into memory)
# ============================================
# Masks pre-computed by precompute_lung_masks.ipynb
# Loading all masks upfront avoids file I/O in DataLoader workers

# Folder holding the pre-computed per-image .npy masks; it differs between Kaggle and local runs.
PIXEL_MASK_DIR = (
    "/kaggle/working/lung_masks/pixel_masks" if IN_KAGGLE
    else "./lung_masks/pixel_masks"
)

def load_all_pixel_masks(dataframe, mask_dir=PIXEL_MASK_DIR, img_size=224):
    """Bulk-load ALL pixel-level lung masks into a dict keyed by Image Index.

    For every name in ``dataframe["Image Index"]`` the file
    ``<mask_dir>/<name without '.png'>.npy`` is loaded, converted to float32
    and resized to ``img_size x img_size`` when its first two dimensions
    differ. Names without a mask file get an all-zero mask.

    Args:
        dataframe: table with an ``"Image Index"`` column of image file names.
        mask_dir: folder containing the ``.npy`` mask files.
        img_size: side length of the square masks returned.

    Returns:
        dict mapping image name -> float32 mask array.
    """
    masks = {}
    missing = 0
    for img_name in tqdm(dataframe["Image Index"], desc="Loading pixel masks into memory"):
        mask_name = img_name.replace(".png", "")
        mask_path = os.path.join(mask_dir, f"{mask_name}.npy")

        if not os.path.exists(mask_path):
            # No mask on disk: fall back to an empty (all-zero) mask.
            masks[img_name] = np.zeros((img_size, img_size), dtype=np.float32)
            missing += 1
            continue

        mask = np.load(mask_path)
        if mask.dtype == np.uint8:
            # uint8 masks are rescaled from 0..255 to 0..1.
            mask = mask.astype(np.float32) / 255.0
        elif mask.dtype != np.float32:
            # Any other dtype (bool, int, float64, ...) is cast without
            # rescaling, so every stored mask is float32 like the zero fallback.
            # NOTE(review): cv2.resize is not expected to accept bool input - confirm.
            mask = mask.astype(np.float32)

        if mask.shape[0] != img_size or mask.shape[1] != img_size:
            mask = cv2.resize(mask, (img_size, img_size))
        masks[img_name] = mask

    if missing > 0:
        print(f"‚ö†Ô∏è {missing} masks not found, using zero fallback")
    print(f"‚úÖ Loaded {len(masks)} pixel masks into memory from {mask_dir}")
    return masks

print(f"Pixel mask directory: {PIXEL_MASK_DIR}")


Pixel mask directory: ./lung_masks/pixel_masks


## 5. Dataset Classes

In [7]:
class SSLAugmentation:
    """Random paired image/mask augmentation used to build SSL views.

    Flip, rotation and crop are applied to image and mask together;
    brightness/contrast and Gaussian noise are applied to the image only.
    The returned mask is re-binarized at 0.5; both outputs are float32.
    """
    def __init__(self, img_size=224):
        self.img_size = img_size

    def __call__(self, image, mask):
        out_size = (self.img_size, self.img_size)

        # Horizontal flip of both arrays (taken when the draw exceeds 0.5).
        if np.random.random() > 0.5:
            image, mask = np.fliplr(image).copy(), np.fliplr(mask).copy()

        # Rotation by an angle in [-15, 15] degrees around the centre
        # (taken when the draw exceeds 0.3). The image border is replicated,
        # the mask border uses the constant-border mode.
        if np.random.random() > 0.3:
            angle = np.random.uniform(-15, 15)
            rows, cols = image.shape[:2]
            rotation = cv2.getRotationMatrix2D((cols/2, rows/2), angle, 1.0)
            image = cv2.warpAffine(image, rotation, (cols, rows), borderMode=cv2.BORDER_REPLICATE)
            mask = cv2.warpAffine(mask, rotation, (cols, rows), borderMode=cv2.BORDER_CONSTANT)

        # Crop a window covering 80-100% of each side, then resize back to
        # img_size (taken when the draw exceeds 0.3).
        if np.random.random() > 0.3:
            rows, cols = image.shape[:2]
            crop_scale = np.random.uniform(0.8, 1.0)
            crop_h, crop_w = int(rows * crop_scale), int(cols * crop_scale)
            y0 = np.random.randint(0, rows - crop_h + 1)
            x0 = np.random.randint(0, cols - crop_w + 1)
            window = (slice(y0, y0 + crop_h), slice(x0, x0 + crop_w))
            image = cv2.resize(image[window], out_size)
            mask = cv2.resize(mask[window], out_size)

        # Linear brightness/contrast change, image only, clipped to [0, 1].
        if np.random.random() > 0.3:
            gain = np.random.uniform(0.8, 1.2)
            offset = np.random.uniform(-0.1, 0.1)
            image = np.clip(gain * image + offset, 0, 1)

        # Additive Gaussian noise (sigma 0.02), image only, clipped to [0, 1].
        if np.random.random() > 0.5:
            jitter = np.random.normal(0, 0.02, image.shape).astype(np.float32)
            image = np.clip(image + jitter, 0, 1)

        # Interpolation above can leave fractional mask values; make it 0/1 again.
        binary_mask = (mask > 0.5).astype(np.float32)
        return image.astype(np.float32), binary_mask


def to_2ch_normalized(image, mask):
    """Build a (2, H, W) float tensor: channel 0 = standardized image, channel 1 = mask.

    The image is standardized with a fixed grayscale mean (0.449) and std
    (0.226); the mask is stacked unchanged.
    """
    gray_mean, gray_std = 0.449, 0.226
    channels = [(image - gray_mean) / gray_std, mask]
    return torch.from_numpy(np.stack(channels, axis=0)).float()


class SSLPretrainDataset(Dataset):
    """Contrastive-pretraining dataset yielding two augmented views per image.

    Images are read from disk on access; lung masks are taken from the
    ``preloaded_masks`` dict (keyed by image name), so no mask files are
    opened inside DataLoader workers.
    """

    def __init__(self, df, preloaded_masks, img_size=224):
        self.img_size = img_size
        self.preloaded_masks = preloaded_masks
        self.augmentation = SSLAugmentation(img_size)
        self.df = df.reset_index(drop=True)
        self.paths = df['Image Path'].tolist()
        self.img_names = df['Image Index'].tolist()
        print(f"üì¶ SSLPretrainDataset: {len(self.df)} samples (masks in memory)")

    def __len__(self):
        return len(self.df)

    def _load_image_and_mask(self, idx):
        """Return the resized [0, 1] float32 image and a private copy of its mask."""
        side = self.img_size
        img = cv2.imread(self.paths[idx], cv2.IMREAD_GRAYSCALE)
        if img is None:
            # Image could not be read: substitute an all-black frame.
            img = np.zeros((side, side), dtype=np.uint8)
        img = cv2.resize(img, (side, side)).astype(np.float32) / 255.0

        # Copy so the shared in-memory mask is never modified.
        mask = self.preloaded_masks[self.img_names[idx]].copy()
        if mask.shape != img.shape:
            mask = cv2.resize(mask, (side, side))
        return img, mask

    def __getitem__(self, idx):
        img, mask = self._load_image_and_mask(idx)

        # Two independently augmented views of the same sample.
        aug_a = self.augmentation(img.copy(), mask.copy())
        aug_b = self.augmentation(img.copy(), mask.copy())
        return to_2ch_normalized(*aug_a), to_2ch_normalized(*aug_b)


class ClassificationDataset(Dataset):
    """Multi-label classification dataset yielding ``(2-channel tensor, labels)``.

    Images are read from disk on access; lung masks are taken from the
    ``preloaded_masks`` dict (keyed by image name). With ``augment=True`` a
    light random flip / brightness-contrast / crop is applied.
    """

    def __init__(self, df, disease_categories, preloaded_masks, img_size=224, augment=False):
        self.img_size = img_size
        self.augment = augment
        self.disease_categories = disease_categories
        self.preloaded_masks = preloaded_masks
        self.df = df.reset_index(drop=True)
        # One float32 row of 0/1 targets per sample, in `disease_categories` order.
        label_matrix = df[disease_categories].values.astype(np.float32)
        self.labels = torch.tensor(label_matrix)
        self.paths = df['Image Path'].tolist()
        self.img_names = df['Image Index'].tolist()
        print(f"üì¶ ClassificationDataset: {len(self.df)} samples (augment={augment}, masks in memory)")

    def __len__(self):
        return len(self.df)

    def _augment(self, img, mask):
        """Training-time augmentation; geometric steps are shared by image and mask."""
        if np.random.random() > 0.5:
            img, mask = np.fliplr(img).copy(), np.fliplr(mask).copy()

        # Brightness/contrast is always applied (image only), then clipped to [0, 1].
        gain = np.random.uniform(0.9, 1.1)
        offset = np.random.uniform(-0.05, 0.05)
        img = np.clip(gain * img + offset, 0, 1)

        if np.random.random() > 0.5:
            rows, cols = img.shape[:2]
            crop_scale = np.random.uniform(0.85, 1.0)
            crop_h, crop_w = int(rows * crop_scale), int(cols * crop_scale)
            y0 = np.random.randint(0, rows - crop_h + 1)
            x0 = np.random.randint(0, cols - crop_w + 1)
            out_size = (self.img_size, self.img_size)
            img = cv2.resize(img[y0:y0 + crop_h, x0:x0 + crop_w], out_size)
            mask = cv2.resize(mask[y0:y0 + crop_h, x0:x0 + crop_w], out_size)
        return img, mask

    def __getitem__(self, idx):
        side = self.img_size
        img = cv2.imread(self.paths[idx], cv2.IMREAD_GRAYSCALE)
        if img is None:
            # Image could not be read: substitute an all-black frame.
            img = np.zeros((side, side), dtype=np.uint8)
        img = cv2.resize(img, (side, side)).astype(np.float32) / 255.0

        # Copy so the shared in-memory mask is never modified.
        mask = self.preloaded_masks[self.img_names[idx]].copy()
        if mask.shape != img.shape:
            mask = cv2.resize(mask, (side, side))

        if self.augment:
            img, mask = self._augment(img, mask)

        mask = (mask > 0.5).astype(np.float32)
        return to_2ch_normalized(img, mask), self.labels[idx]


print("‚úÖ Dataset classes defined (masks loaded from memory, no disk I/O in workers)")


‚úÖ Dataset classes defined (masks loaded from memory, no disk I/O in workers)


## 6. Model Architecture

In [8]:
def get_resnet50_multichannel(in_channels=2, pretrained=True):
    """
    Create a ResNet50 backbone whose first conv accepts ``in_channels`` inputs.

    When ``pretrained`` is true, the pretrained first-layer kernels are
    averaged over their RGB input dimension and that average is replicated
    for each new input channel; all other layers keep their pretrained weights.

    Args:
        in_channels: number of input channels of the returned model.
        pretrained: load ImageNet weights (True) or use random init (False).

    Returns:
        A torchvision ResNet50 with ``conv1`` replaced.
    """
    # Newer torchvision selects weights through `weights=`; `pretrained=` is
    # deprecated there. IMAGENET1K_V1 is the weight set `pretrained=True`
    # referred to. Older torchvision without the enum keeps the legacy call.
    if hasattr(models, 'ResNet50_Weights'):
        weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        model = models.resnet50(weights=weights)
    else:
        model = models.resnet50(pretrained=pretrained)

    # Get original first conv layer
    original_conv = model.conv1

    # Same geometry as the original stem conv, only the input width differs.
    new_conv = nn.Conv2d(
        in_channels=in_channels,
        out_channels=original_conv.out_channels,
        kernel_size=original_conv.kernel_size,
        stride=original_conv.stride,
        padding=original_conv.padding,
        bias=original_conv.bias is not None
    )

    # Initialize weights from pretrained model
    with torch.no_grad():
        if pretrained:
            # Average the RGB weights and replicate for each input channel
            original_weights = original_conv.weight.data
            avg_weight = original_weights.mean(dim=1, keepdim=True)
            new_conv.weight.data = torch.cat([avg_weight] * in_channels, dim=1)

            if original_conv.bias is not None:
                new_conv.bias.data = original_conv.bias.data.clone()

    model.conv1 = new_conv
    return model


class ProjectionHead(nn.Module):
    """Two-layer MLP (Linear -> BatchNorm1d -> ReLU -> Linear) used for contrastive learning."""

    def __init__(self, in_dim, hidden_dim=512, out_dim=128):
        super().__init__()
        # Layer order (and therefore the `net.<i>` state_dict keys) is fixed.
        layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        projected = self.net(x)
        return projected


class SSLModel(nn.Module):
    """ResNet50 backbone followed by a projection head, for SSL pretraining.

    ``forward`` returns L2-normalized projections; ``get_features`` returns
    the raw backbone output.
    """

    def __init__(self, in_channels=2, projection_dim=128, pretrained=True):
        super().__init__()

        # Backbone with its classification layer swapped for an identity.
        backbone = get_resnet50_multichannel(in_channels, pretrained)
        self.feature_dim = backbone.fc.in_features
        backbone.fc = nn.Identity()
        self.backbone = backbone

        # Head that maps backbone features into the contrastive space.
        self.projection = ProjectionHead(self.feature_dim, out_dim=projection_dim)

    def forward(self, x):
        embedded = self.projection(self.backbone(x))
        return F.normalize(embedded, dim=1)

    def get_features(self, x):
        return self.backbone(x)


class ClassificationModel(nn.Module):
    """Classifier built on the backbone of an existing SSL model.

    The backbone module is shared with ``ssl_model`` (it is referenced, not
    copied). The head is Dropout(0.3) followed by a single Linear layer.
    """

    def __init__(self, ssl_model, num_classes=14, freeze_backbone=False):
        super().__init__()
        self.backbone = ssl_model.backbone
        self.feature_dim = ssl_model.feature_dim

        if freeze_backbone:
            # Exclude every backbone weight from gradient updates.
            for weight in self.backbone.parameters():
                weight.requires_grad = False

        head_layers = [
            nn.Dropout(0.3),
            nn.Linear(self.feature_dim, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        return self.classifier(self.backbone(x))

print("‚úÖ Model classes defined")

‚úÖ Model classes defined


## 7. Contrastive Loss and Focal Loss

In [9]:
class NTXentLoss(nn.Module):
    """NT-Xent (SimCLR) contrastive loss over two batches of embeddings.

    ``z1[i]`` and ``z2[i]`` form the positive pair; every other row of the
    concatenated batch acts as a negative. Similarities are dot products
    divided by ``temperature``.
    """

    def __init__(self, temperature=0.5):
        super().__init__()
        self.temperature = temperature

    def forward(self, z1, z2):
        n = z1.shape[0]

        # All 2n embeddings and their pairwise scaled similarities.
        z = torch.cat([z1, z2], dim=0)
        sim = torch.mm(z, z.t()) / self.temperature

        # A sample must never be its own candidate: push the diagonal to -inf.
        diagonal = torch.eye(2 * n, dtype=torch.bool, device=z.device)
        sim.masked_fill_(diagonal, float('-inf'))

        # Row i pairs with row i + n, and row i + n pairs with row i.
        first_half = torch.arange(n)
        labels = torch.cat([first_half + n, first_half]).to(z.device)

        return F.cross_entropy(sim, labels)

print("‚úÖ Contrastive loss defined")


class FocalLoss(nn.Module):
    """Focal loss on logits for multi-label targets.

    Per-element BCE is scaled by ``alpha * (1 - p_t) ** gamma`` with
    ``p_t = exp(-BCE)``, then averaged over all elements.
    """
    def __init__(self, alpha=1, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, logits, targets):
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
        # exp(-BCE) recovers the probability assigned to the true label.
        p_true = torch.exp(-bce)
        modulator = (1 - p_true) ** self.gamma
        return (self.alpha * modulator * bce).mean()

print("‚úÖ Loss functions: NTXentLoss, FocalLoss")


‚úÖ Contrastive loss defined
‚úÖ Loss functions: NTXentLoss, FocalLoss


## 8. Training Functions

In [10]:
def pretrain_epoch(model, dataloader, optimizer, criterion, device):
    """Train ``model`` for one contrastive epoch and return the mean batch loss.

    Each batch from ``dataloader`` is a pair of views; both are passed through
    the model separately and ``criterion`` is evaluated on the two outputs.
    """
    model.train()
    running_loss = 0

    progress = tqdm(dataloader, desc="Pretraining")
    for first_view, second_view in progress:
        first_view = first_view.to(device)
        second_view = second_view.to(device)

        optimizer.zero_grad()
        loss = criterion(model(first_view), model(second_view))
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        progress.set_postfix({'loss': batch_loss})

    return running_loss / len(dataloader)


def finetune_epoch(model, dataloader, optimizer, criterion, device):
    """Train ``model`` for one supervised epoch and return the mean batch loss.

    Each batch from ``dataloader`` is an ``(images, labels)`` pair.
    """
    model.train()
    running_loss = 0

    progress = tqdm(dataloader, desc="Fine-tuning")
    for images, labels in progress:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        progress.set_postfix({'loss': batch_loss})

    return running_loss / len(dataloader)


@torch.no_grad()
def evaluate(model, dataloader, device):
    """Evaluate ``model`` and return ``(mean_auc, per_class_aucs)``.

    Sigmoid probabilities are collected over the whole ``dataloader`` and a
    ROC-AUC is computed per label column. AUC is only defined when a column
    contains both positive and negative samples; columns that lack either
    class get the neutral value 0.5 instead.

    Args:
        model: network producing one logit per class.
        dataloader: yields ``(images, labels)`` batches.
        device: device the images are moved to before the forward pass.

    Returns:
        Tuple of the mean of the per-class values and the list of those values.
    """
    model.eval()
    pred_batches = []
    label_batches = []

    for images, labels in tqdm(dataloader, desc="Evaluating"):
        images = images.to(device)
        outputs = torch.sigmoid(model(images))
        pred_batches.append(outputs.cpu().numpy())
        # .cpu() is a no-op for CPU labels and keeps this safe for GPU ones.
        label_batches.append(labels.cpu().numpy())

    all_preds = np.vstack(pred_batches)
    all_labels = np.vstack(label_batches)

    # Calculate AUC for each disease
    aucs = []
    for i in range(all_labels.shape[1]):
        column = all_labels[:, i]
        n_positive = column.sum()
        # Require at least one positive AND one negative; roc_auc_score
        # raises when only a single class is present.
        if 0 < n_positive < len(column):
            aucs.append(roc_auc_score(column, all_preds[:, i]))
        else:
            aucs.append(0.5)

    return np.mean(aucs), aucs

print("‚úÖ Training functions defined")

‚úÖ Training functions defined


## 9. Data Preparation

In [11]:
# Split by patient rather than by image, so that no patient contributes
# images to more than one of train / val / test.
print("üîÄ PATIENT-LEVEL SPLITTING")
print("=" * 60)

unique_patients = df['Patient ID'].unique()
print(f"Total unique patients: {len(unique_patients):,}")

# First carve off the test patients, then split the remainder into train / val.
train_val_patients, test_patients = train_test_split(unique_patients, test_size=0.02, random_state=42)
train_patients, val_patients = train_test_split(train_val_patients, test_size=0.052, random_state=42)

# Materialize one image-level frame per patient group.
train_df, val_df, test_df = (
    df[df['Patient ID'].isin(patient_group)].copy()
    for patient_group in (train_patients, val_patients, test_patients)
)

print(f"‚úì Train: {len(train_df):,} images from {len(train_patients):,} patients")
print(f"‚úì Val: {len(val_df):,} images from {len(val_patients):,} patients")
print(f"‚úì Test: {len(test_df):,} images from {len(test_patients):,} patients")

# Optional patient-level subsample of the training set, used for fine-tuning only.
if not (cfg.finetune_fraction < 1.0):
    finetune_train_df = train_df
    print(f"\nüî¨ Fine-tuning uses all {len(train_df):,} training images (fraction=1.0)")
else:
    n_finetune_patients = max(1, int(len(train_patients) * cfg.finetune_fraction))
    rng = np.random.RandomState(42)
    finetune_patient_indices = rng.choice(len(train_patients), size=n_finetune_patients, replace=False)
    finetune_patients = train_patients[finetune_patient_indices]
    finetune_train_df = train_df[train_df['Patient ID'].isin(finetune_patients)].copy()
    print(f"\nüî¨ FINE-TUNING DATA SUBSET")
    print(f"   Fraction: {cfg.finetune_fraction} ({cfg.finetune_fraction*100:.1f}%)")
    print(f"   Finetune: {len(finetune_train_df):,} images from {n_finetune_patients:,} patients")
    print(f"   (SSL pretraining still uses all {len(train_df):,} training images)")

# ‚îÄ‚îÄ Bulk-load ALL masks into memory (eliminates disk I/O in DataLoader workers) ‚îÄ‚îÄ
all_masks = load_all_pixel_masks(df, img_size=cfg.img_size)

# Datasets share the single in-memory mask dict. SSL pretraining sees the
# full training split; classification training uses the fine-tune subset.
pretrain_dataset = SSLPretrainDataset(train_df, all_masks, img_size=cfg.img_size)
train_dataset, val_dataset, test_dataset = (
    ClassificationDataset(split_df, disease_categories, all_masks, img_size=cfg.img_size, augment=use_aug)
    for split_df, use_aug in ((finetune_train_df, True), (val_df, False), (test_df, False))
)

# DataLoaders
# NOTE: 8 workers are used here. A low count (e.g. 2) was previously advised
# to avoid OOM-induced worker crashes on Kaggle - lower `_nw` if workers die.
_nw = 8
_pin = torch.cuda.is_available()
_loader_kwargs = dict(
    batch_size=cfg.batch_size,
    num_workers=_nw,
    pin_memory=_pin,
    persistent_workers=_nw > 0,
)
# Only the two training loaders shuffle; evaluation order stays fixed.
pretrain_loader = DataLoader(pretrain_dataset, shuffle=True, **_loader_kwargs)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)
print(f"üì¶ DataLoaders: num_workers={_nw}, pin_memory={_pin}")

print("‚úÖ Data prepared with patient-level splitting")


üîÄ PATIENT-LEVEL SPLITTING
Total unique patients: 30,805
‚úì Train: 103,847 images from 28,618 patients
‚úì Val: 5,974 images from 1,570 patients
‚úì Test: 2,299 images from 617 patients

üî¨ FINE-TUNING DATA SUBSET
   Fraction: 0.1 (10.0%)
   Finetune: 10,232 images from 2,861 patients
   (SSL pretraining still uses all 103,847 training images)


Loading pixel masks into memory: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 112120/112120 [00:29<00:00, 3838.64it/s]


‚úÖ Loaded 112120 pixel masks into memory from ./lung_masks/pixel_masks
üì¶ SSLPretrainDataset: 103847 samples (masks in memory)
üì¶ ClassificationDataset: 10232 samples (augment=True, masks in memory)
üì¶ ClassificationDataset: 5974 samples (augment=False, masks in memory)
üì¶ ClassificationDataset: 2299 samples (augment=False, masks in memory)
üì¶ DataLoaders: num_workers=8, pin_memory=True
‚úÖ Data prepared with patient-level splitting


## 10. SSL Pretraining

In [None]:
# Build the SSL network on the configured device.
ssl_model = SSLModel(
    in_channels=cfg.in_channels,
    projection_dim=cfg.projection_dim,
    pretrained=True
).to(cfg.device)

# Contrastive loss, AdamW, and a cosine LR schedule spanning all pretraining epochs.
ssl_criterion = NTXentLoss(temperature=cfg.temperature)
ssl_optimizer = optim.AdamW(ssl_model.parameters(), lr=cfg.pretrain_lr, weight_decay=cfg.weight_decay)
ssl_scheduler = optim.lr_scheduler.CosineAnnealingLR(ssl_optimizer, T_max=cfg.pretrain_epochs, eta_min=1e-6)

print(f"SSL Model parameters: {sum(p.numel() for p in ssl_model.parameters()):,}")
print(f"Optimizer: AdamW, LR: {cfg.pretrain_lr}, Schedule: CosineAnnealing")

pretrain_losses = []
START_EPOCH = 1

# ---- optional resume ----
if RESUME_SSL_PRETRAINING:
    if SSL_CHECKPOINT_FILE == "latest":
        ckpt_file = find_latest_checkpoint(f'{OPTION_NAME}_ssl')
    else:
        ckpt_file = SSL_CHECKPOINT_FILE

    if not ckpt_file:
        print("‚ö†Ô∏è No SSL checkpoint found. Starting fresh.")
    else:
        checkpoint = load_checkpoint(ckpt_file)
        if checkpoint:
            ssl_model.load_state_dict(checkpoint['model'])
            # Per-epoch snapshots carry no optimizer/scheduler state; restore only what is there.
            if 'optimizer' in checkpoint:
                ssl_optimizer.load_state_dict(checkpoint['optimizer'])
            if 'scheduler' in checkpoint:
                ssl_scheduler.load_state_dict(checkpoint['scheduler'])
            pretrain_losses = checkpoint.get('pretrain_losses', pretrain_losses)
            START_EPOCH = checkpoint['epoch'] + 1
            print(f"üîÑ Resuming SSL pretraining from epoch {START_EPOCH}")

# ---- training ----
if START_EPOCH <= cfg.pretrain_epochs:
    print(f"\nüöÄ Starting SSL Pretraining!")
    print(f"   Epochs: {START_EPOCH} ‚Üí {cfg.pretrain_epochs}")
    print("=" * 60)
    SAVE_EVERY = 5

    for epoch in range(START_EPOCH, cfg.pretrain_epochs + 1):
        loss = pretrain_epoch(ssl_model, pretrain_loader, ssl_optimizer, ssl_criterion, cfg.device)
        ssl_scheduler.step()
        pretrain_losses.append(loss)
        print(f"Epoch {epoch}/{cfg.pretrain_epochs} - Loss: {loss:.4f} - LR: {ssl_scheduler.get_last_lr()[0]:.6f}")

        # Checkpoint every SAVE_EVERY epochs and always after the final epoch.
        if epoch % SAVE_EVERY == 0 or epoch == cfg.pretrain_epochs:
            # Full resumable state under the rolling "latest" name ...
            save_checkpoint({
                'epoch': epoch, 'model': ssl_model.state_dict(),
                'optimizer': ssl_optimizer.state_dict(),
                'scheduler': ssl_scheduler.state_dict(),
                'pretrain_losses': pretrain_losses,
            }, f'{OPTION_NAME}_ssl_latest.pth')
            # ... plus a weights-and-losses snapshot tagged with the epoch.
            save_checkpoint({
                'epoch': epoch, 'model': ssl_model.state_dict(),
                'pretrain_losses': pretrain_losses,
            }, f'{OPTION_NAME}_ssl_epoch{epoch}.pth')

    # Also save with the legacy filename for compatibility
    torch.save(ssl_model.state_dict(), checkpoint_dir / 'option6_ssl_pretrained.pth')
    print("\n‚úÖ SSL Pretraining complete")
else:
    print(f"‚úÖ SSL Pretraining already complete ({cfg.pretrain_epochs} epochs)")


SSL Model parameters: 24,620,672
Optimizer: AdamW, LR: 0.001, Schedule: CosineAnnealing
‚ö†Ô∏è No SSL checkpoint found. Starting fresh.

üöÄ Starting SSL Pretraining!
   Epochs: 1 ‚Üí 50


Pretraining:   0%|          | 0/1623 [00:00<?, ?it/s]

## 11. Fine-tuning

In [None]:
# Build the multi-label classifier on top of the SSL-pretrained model.
# freeze_backbone=False is passed so the encoder is not frozen during fine-tuning.
classifier = ClassificationModel(
    ssl_model,
    num_classes=len(disease_categories),
    freeze_backbone=False,
).to(cfg.device)

# Loss: FocalLoss. Optimizer: AdamW with two parameter groups (differential LR) --
# the backbone group runs at one tenth of the classifier-head learning rate.
criterion = FocalLoss(alpha=1, gamma=2)
finetune_param_groups = [
    {'params': classifier.backbone.parameters(), 'lr': cfg.finetune_lr * 0.1},
    {'params': classifier.classifier.parameters(), 'lr': cfg.finetune_lr},
]
optimizer = optim.AdamW(finetune_param_groups, weight_decay=cfg.weight_decay)

# mode='max': the scheduler halves the LR when the monitored metric stops increasing.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='max',
    factor=0.5,
    patience=cfg.patience,
)

n_classifier_params = sum(p.numel() for p in classifier.parameters())
print(f"Classifier parameters: {n_classifier_params:,}")
print("Optimizer: AdamW (differential LR), Loss: FocalLoss")

# Fresh-run training state; a resumed run may overwrite these from a checkpoint.
FINETUNE_START_EPOCH = 1
best_auc = 0
patience_counter = 0
train_losses, val_aucs = [], []

# Optionally restore fine-tuning state (weights, optimizer, scheduler, history)
# from a checkpoint so that an interrupted run continues from the saved epoch.
if RESUME_FINETUNING:
    # "latest" means: look up the newest fine-tuning checkpoint; anything else is
    # treated as an explicit checkpoint file name.
    ckpt_file = find_latest_checkpoint(f'{OPTION_NAME}_finetune') if FINETUNE_CHECKPOINT_FILE == "latest" else FINETUNE_CHECKPOINT_FILE
    if ckpt_file:
        ft_ckpt = load_checkpoint(ckpt_file)
        if ft_ckpt:
            classifier.load_state_dict(ft_ckpt['classifier'])
            # Optimizer / scheduler restoration is best-effort: on failure training
            # continues with freshly initialised state. `except Exception` (rather
            # than a bare `except:`) lets KeyboardInterrupt/SystemExit propagate,
            # and the actual error is reported instead of being hidden.
            if 'optimizer' in ft_ckpt:
                try:
                    optimizer.load_state_dict(ft_ckpt['optimizer'])
                except Exception as exc:
                    print("‚ö†Ô∏è Optimizer state incompatible, starting fresh")
                    print(f"   Reason: {type(exc).__name__}: {exc}")
            if 'scheduler' in ft_ckpt:
                try:
                    scheduler.load_state_dict(ft_ckpt['scheduler'])
                except Exception as exc:
                    print("‚ö†Ô∏è Scheduler state incompatible, starting fresh")
                    print(f"   Reason: {type(exc).__name__}: {exc}")
            # Missing history keys fall back to the fresh-run defaults.
            train_losses = ft_ckpt.get('train_losses', train_losses)
            val_aucs = ft_ckpt.get('val_aucs', val_aucs)
            best_auc = ft_ckpt.get('best_auc', 0)
            patience_counter = ft_ckpt.get('patience_counter', 0)
            FINETUNE_START_EPOCH = ft_ckpt['epoch'] + 1
            print(f"üîÑ Resuming fine-tuning from epoch {FINETUNE_START_EPOCH} (best AUC: {best_auc:.4f})")
        else:
            # load_checkpoint returned a falsy value: say so instead of silently
            # starting from epoch 1.
            print(f"Could not load fine-tuning checkpoint '{ckpt_file}'. Starting fresh.")
    else:
        print("‚ö†Ô∏è No fine-tuning checkpoint found. Starting fresh.")

# Fine-tuning loop: train one epoch, validate, step the LR scheduler on the
# validation AUC, keep the best weights, and write resumable checkpoints.

# Consecutive non-improving epochs tolerated before training stops early.
EARLY_STOP_PATIENCE = 10

if FINETUNE_START_EPOCH > cfg.finetune_epochs:
    print(f"‚úÖ Fine-tuning already complete ({cfg.finetune_epochs} epochs)")
else:
    print("\nüéØ Starting Fine-tuning")
    print(f"   Epochs: {FINETUNE_START_EPOCH} ‚Üí {cfg.finetune_epochs}")
    print(f"   Training data: {len(train_loader.dataset):,} samples (fraction={cfg.finetune_fraction})")
    print("=" * 50)
    SAVE_EVERY = 5  # epochs between resumable "latest" checkpoints

    for epoch in range(FINETUNE_START_EPOCH, cfg.finetune_epochs + 1):
        train_loss = finetune_epoch(classifier, train_loader, optimizer, criterion, cfg.device)
        val_auc, _ = evaluate(classifier, val_loader, cfg.device)

        train_losses.append(train_loss)
        val_aucs.append(val_auc)
        scheduler.step(val_auc)

        # param_groups[0] is the backbone group (it is listed first when the
        # optimizer is built), so the LR logged here is the backbone LR.
        current_lr = optimizer.param_groups[0]['lr']
        print(f"Epoch {epoch}/{cfg.finetune_epochs} - Loss: {train_loss:.4f}, Val AUC: {val_auc:.4f}, LR: {current_lr:.2e}")

        if val_auc > best_auc:
            best_auc = val_auc
            # Written twice: a bare state_dict under the legacy file name, and a
            # checkpoint dict (with epoch and AUC) via save_checkpoint.
            torch.save(classifier.state_dict(), checkpoint_dir / 'option6_ssl_best.pth')
            save_checkpoint({
                'epoch': epoch, 'classifier': classifier.state_dict(),
                'val_auc': val_auc,
            }, f'{OPTION_NAME}_best_model.pth')
            print(f"  ‚úÖ Best model saved! Val AUC: {val_auc:.4f}")
            patience_counter = 0
        else:
            patience_counter += 1

        stop_early = patience_counter >= EARLY_STOP_PATIENCE

        # Also checkpoint when stopping early; otherwise the optimizer/scheduler
        # state and loss/AUC history since the last periodic save would be lost.
        if epoch % SAVE_EVERY == 0 or epoch == cfg.finetune_epochs or stop_early:
            save_checkpoint({
                'epoch': epoch, 'classifier': classifier.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'train_losses': train_losses, 'val_aucs': val_aucs,
                'best_auc': best_auc, 'patience_counter': patience_counter,
            }, f'{OPTION_NAME}_finetune_latest.pth')

        if stop_early:
            print(f"Early stopping at epoch {epoch}")
            break

print(f"\n‚úÖ Fine-tuning complete. Best Val AUC: {best_auc:.4f}")


## 12. Evaluation

In [None]:
# Load the best fine-tuned weights, then report overall and per-disease AUC on
# the test set.
checkpoint = load_checkpoint(f'{OPTION_NAME}_best_model.pth')
if checkpoint:
    classifier.load_state_dict(checkpoint['classifier'])
    print(f"‚úÖ Loaded best model (Val AUC: {checkpoint.get('val_auc', 'N/A')})")
else:
    # Fallback to legacy path
    legacy_path = checkpoint_dir / 'option6_ssl_best.pth'
    if legacy_path.exists():
        # The legacy file is written with torch.save(classifier.state_dict(), ...),
        # i.e. a plain tensor state_dict, so weights_only=True is sufficient and
        # avoids unpickling arbitrary objects.
        classifier.load_state_dict(torch.load(legacy_path, map_location=cfg.device, weights_only=True))
        print('‚úÖ Loaded best model from legacy path')
    else:
        # NOTE: evaluation below then runs with whatever weights `classifier`
        # currently holds in memory.
        print('‚ö†Ô∏è No best model found!')

# Evaluate on test set
test_auc, disease_aucs = evaluate(classifier, test_loader, cfg.device)

print(f"\n{'='*50}")
print("TEST RESULTS - Option 6 SSL (Segmentation Channel)")
print(f"{'='*50}")
print(f"\nOverall Test AUC: {test_auc:.4f}")
print("\nPer-disease AUC scores:")
print("-" * 40)

# disease_aucs is paired positionally with disease_categories.
for disease, auc in zip(disease_categories, disease_aucs):
    print(f"{disease:20s}: {auc:.4f}")

print(f"\n{'='*50}")

## 13. Visualization

In [None]:
# Training curves: SSL pretraining loss, fine-tuning loss, validation AUC.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# One entry per panel: (values, line format, y-axis label, title).
# The fine-tuning criterion is FocalLoss, so that panel is labelled accordingly
# (it was previously mislabelled as BCE).
curve_specs = [
    (pretrain_losses, 'b-', 'Contrastive Loss', 'SSL Pretraining Loss'),
    (train_losses, 'g-', 'Focal Loss', 'Fine-tuning Loss'),
    (val_aucs, 'r-', 'AUC', 'Validation AUC'),
]
for panel, (curve_values, curve_fmt, curve_ylabel, curve_title) in zip(axes, curve_specs):
    panel.plot(curve_values, curve_fmt, linewidth=2)
    panel.set_xlabel('Epoch')
    panel.set_ylabel(curve_ylabel)
    panel.set_title(curve_title)
    panel.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(checkpoint_dir / 'option6_ssl_training_curves.png', dpi=150, bbox_inches='tight')
plt.show()

# Disease-wise AUC bar plot, with the overall test AUC drawn as a dashed line.
fig, ax = plt.subplots(figsize=(12, 6))
colors = plt.cm.viridis(np.linspace(0, 0.8, len(disease_categories)))
ax.bar(disease_categories, disease_aucs, color=colors)
ax.axhline(y=test_auc, color='red', linestyle='--', linewidth=2, label=f'Mean AUC: {test_auc:.4f}')
ax.set_xlabel('Disease')
ax.set_ylabel('AUC Score')
ax.set_title('Option 6 SSL - Per-Disease AUC Scores')
ax.set_ylim(0, 1)
plt.xticks(rotation=45, ha='right')
ax.legend()
plt.tight_layout()
plt.savefig(checkpoint_dir / 'option6_ssl_disease_aucs.png', dpi=150, bbox_inches='tight')
plt.show()

print("‚úÖ Visualizations saved")

## 14. Sample Predictions Visualization

In [None]:
# Visualize sample predictions: the first test samples are drawn in a grid with
# channel 0 in grayscale, channel 1 overlaid in red, and the highest-scoring
# classes above the display threshold shown in the title.
classifier.eval()

SAMPLE_GRID_ROWS, SAMPLE_GRID_COLS = 2, 4
SAMPLE_TOP_K = 3             # number of highest-scoring classes considered per sample
SAMPLE_SCORE_THRESHOLD = 0.3  # minimum sigmoid score for a class to be listed

# Assumes test_dataset implements __len__ (map-style Dataset) -- TODO confirm.
n_test_samples = len(test_dataset)

fig, axes = plt.subplots(SAMPLE_GRID_ROWS, SAMPLE_GRID_COLS, figsize=(16, 8))

# axes.flat walks the grid row by row, so idx matches row * n_cols + col.
for idx, ax in enumerate(axes.flat):
    if idx >= n_test_samples:
        # Fewer test samples than panels: leave the panel blank instead of
        # raising IndexError.
        ax.axis('off')
        continue

    sample, _ = test_dataset[idx]

    with torch.no_grad():
        pred = torch.sigmoid(classifier(sample.unsqueeze(0).to(cfg.device)))

    # Show original image (channel 0)
    ax.imshow(sample[0].numpy(), cmap='gray')

    # Show mask as overlay
    mask_overlay = sample[1].numpy()
    ax.imshow(mask_overlay, cmap='Reds', alpha=0.3)

    # Get top predictions, highest score first
    pred_np = pred.cpu().numpy().flatten()
    top_idx = pred_np.argsort()[-SAMPLE_TOP_K:][::-1]

    # Class names are truncated to 8 characters to keep the titles compact.
    title_lines = [
        f"{disease_categories[tidx][:8]}: {pred_np[tidx]:.2f}"
        for tidx in top_idx
        if pred_np[tidx] > SAMPLE_SCORE_THRESHOLD
    ]

    ax.set_title('\n'.join(title_lines) if title_lines else 'No Finding', fontsize=8)
    ax.axis('off')

plt.suptitle('Sample Predictions (Image + Segmentation Overlay)', fontsize=12)
plt.tight_layout()
plt.savefig(checkpoint_dir / 'option6_ssl_samples.png', dpi=150, bbox_inches='tight')
plt.show()

print("‚úÖ Sample predictions visualized")

## Summary

This notebook implemented **Option 6 with SSL**:

1. **2-Channel Input**: Combined grayscale image + rule-based lung segmentation mask
2. **SSL Pretraining**: Contrastive learning (NT-Xent loss) to learn representations
3. **Fine-tuning**: Multi-label classification for 14 diseases

### Key Benefits:
- Preserves all original image information
- Provides anatomical context through segmentation channel
- SSL pretraining helps learn robust features before supervised learning
- Consistent augmentations applied to both image and mask channels