In [None]:
# ============================================
# 📦 Step 1: Import Libraries
# ============================================

import warnings
warnings.filterwarnings('ignore')

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_auc_score
import cv2
import random
from tqdm import tqdm
from pathlib import Path
import kagglehub

def set_seed(seed=42):
    """Seed every random number generator the notebook relies on.

    Seeds Python's `random`, NumPy's global RNG and PyTorch (CPU), and all
    CUDA devices when CUDA is available.

    Args:
        seed: Integer seed applied to every generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


set_seed(42)
print("‚úÖ Libraries imported successfully")

In [None]:
# ============================================
# 📁 Step 2: Download and Load Dataset
# ============================================

# kagglehub returns the local directory that holds the downloaded dataset.
path = kagglehub.dataset_download("khanfashee/nih-chest-x-ray-14-224x224-resized")
BASE_PATH = Path(path)

images_dir = BASE_PATH / "images-224" / "images-224"
df_labels = pd.read_csv(BASE_PATH / "Data_Entry_2017.csv")
# Full path of every image, built from the file name stored in the CSV.
df_labels["Image Path"] = df_labels["Image Index"].map(lambda name: str(images_dir / name))

# The order of this list defines the order of the label vector used everywhere below.
DISEASE_CATEGORIES = [
    'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration', 'Mass',
    'Nodule', 'Pneumonia', 'Pneumothorax', 'Consolidation', 'Edema',
    'Emphysema', 'Fibrosis', 'Pleural_Thickening', 'Hernia'
]

# One 0/1 column per disease: 1 when the disease name occurs in 'Finding Labels'.
for disease in DISEASE_CATEGORIES:
    df_labels[disease] = (
        df_labels['Finding Labels'].str.contains(disease, regex=False).astype(int)
    )

# Spot-check a fixed random sample of 200 paths instead of stat-ing every file.
sample_paths = df_labels['Image Path'].sample(200, random_state=42).values
missing = list(filter(lambda p: not os.path.exists(p), sample_paths))
if len(missing) > 0:
    raise FileNotFoundError(f"‚ùå Missing {len(missing)} images!")

print(f"‚úÖ Loaded {len(df_labels):,} images")

In [None]:
# ============================================
# ⚙️ Step 3: Configuration
# ============================================

class Config:
    """Hyper-parameters and runtime settings read by the rest of the notebook."""

    # --- input / representation sizes ---
    img_size = 224
    feat_dim = 256
    proj_dim = 128

    # --- optimisation schedule ---
    batch_size = 64
    pretrain_epochs = 50
    finetune_epochs = 30
    lr_pretrain = 1e-3
    lr_finetune = 1e-4

    # --- SSL loss ---
    temperature = 0.1        # NT-Xent softmax temperature
    pathology_weight = 0.5   # strength of the pathology-coverage scaling

    # --- data loading ---
    num_workers = 4
    use_subset = False       # not read by any cell visible in this notebook

    # --- adaptive thresholding params ---
    adaptive_block_size = 11
    adaptive_C = 2
    gradient_threshold = 0.15

    # Evaluated once, when the class body runs.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


cfg = Config()

print("‚öôÔ∏è Configuration:")
print(f"   Device: {cfg.device}")
print(f"   Adaptive block size: {cfg.adaptive_block_size}")
print(f"   Gradient threshold: {cfg.gradient_threshold}")

# Let cuDNN auto-tune convolution algorithms when a GPU is present.
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True

In [None]:
# ============================================
# 💾 Step 3.5: Checkpoint & Resume Configuration
# ============================================

import os
import shutil
from datetime import datetime

# Prefix used for every checkpoint file name written by this notebook.
OPTION_NAME = "option3"

# ===== RESUME CONFIGURATION =====
CHECKPOINT_DATASET_NAME = "chest-xray-ssl-checkpoints"
RESUME_SSL_PRETRAINING = False
RESUME_FINETUNING = False
SSL_CHECKPOINT_FILE = "latest"        # "latest" or an explicit file name inside CHECKPOINT_DIR
FINETUNE_CHECKPOINT_FILE = "latest"   # "latest" or an explicit file name inside CHECKPOINT_DIR

IN_KAGGLE = os.path.exists('/kaggle')
IN_COLAB = False

try:
    from google.colab import drive
    drive.mount('/content/drive')
    CHECKPOINT_DIR = '/content/drive/MyDrive/chest_xray_ssl'
    IN_COLAB = True
except ImportError:
    # Not running inside Colab: fall through to the Kaggle / local branches.
    pass
except Exception as exc:
    # The mount itself failed. Report it instead of hiding it; Ctrl-C and
    # SystemExit are no longer swallowed because only Exception is caught.
    print(f"Google Drive mount failed ({exc!r}); falling back to a non-Colab checkpoint directory.")

if IN_KAGGLE:
    CHECKPOINT_DIR = '/kaggle/working/checkpoints'
    PREV_CHECKPOINT_DIR = f'/kaggle/input/{CHECKPOINT_DATASET_NAME}'
    if os.path.exists(PREV_CHECKPOINT_DIR):
        print(f"‚úÖ Found checkpoints at: {PREV_CHECKPOINT_DIR}")
        os.makedirs(CHECKPOINT_DIR, exist_ok=True)
        # Copy earlier .pth files into the checkpoint dir without overwriting newer ones.
        for fname in os.listdir(PREV_CHECKPOINT_DIR):
            if not fname.endswith('.pth'):
                continue
            src = os.path.join(PREV_CHECKPOINT_DIR, fname)
            dst = os.path.join(CHECKPOINT_DIR, fname)
            if not os.path.exists(dst):
                shutil.copy2(src, dst)
elif not IN_COLAB:
    CHECKPOINT_DIR = './checkpoints'

os.makedirs(CHECKPOINT_DIR, exist_ok=True)

def save_checkpoint(state, filename):
    """Write a checkpoint dict to CHECKPOINT_DIR, stamped with the save time.

    Args:
        state: Dict of objects to serialise with `torch.save`. It is not
            modified; the timestamp is added to a shallow copy.
        filename: File name (not a path) created inside CHECKPOINT_DIR.

    On Kaggle a second copy is written to /kaggle/working/<filename>.
    """
    # Shallow copy so the 'saved_at' key does not leak into the caller's dict.
    stamped = dict(state)
    stamped['saved_at'] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    filepath = os.path.join(CHECKPOINT_DIR, filename)
    torch.save(stamped, filepath)
    print(f"üíæ Saved: {filename}")
    if IN_KAGGLE:
        torch.save(stamped, f'/kaggle/working/{filename}')

def load_checkpoint(filename):
    """Load a checkpoint from CHECKPOINT_DIR onto `cfg.device`.

    Args:
        filename: File name (not a path) inside CHECKPOINT_DIR.

    Returns:
        The object returned by `torch.load`, or None when the file does not exist.
    """
    filepath = os.path.join(CHECKPOINT_DIR, filename)
    if not os.path.exists(filepath):
        return None
    checkpoint = torch.load(filepath, map_location=cfg.device)
    print(f"‚úÖ Loaded: {filename}")
    return checkpoint

def find_latest_checkpoint(prefix):
    """Return the file name of the most recent checkpoint for `prefix`.

    Prefers `<prefix>_latest.pth`. Otherwise returns the `<prefix>_epoch<N>.pth`
    file with the largest N.

    Args:
        prefix: Checkpoint name prefix, e.g. "option3_ssl".

    Returns:
        A file name relative to CHECKPOINT_DIR, or None if the directory is
        missing or holds no matching checkpoint.
    """
    if not os.path.exists(CHECKPOINT_DIR):
        return None
    latest = f'{prefix}_latest.pth'
    if os.path.exists(os.path.join(CHECKPOINT_DIR, latest)):
        return latest
    import re
    # Escape the prefix so regex metacharacters in it are literal, and require a
    # full match so names like "..._epoch3.pth.bak" are not picked up.
    pattern = re.compile(rf'{re.escape(prefix)}_epoch(\d+)\.pth')
    max_epoch, best = -1, None
    for fname in os.listdir(CHECKPOINT_DIR):
        m = pattern.fullmatch(fname)
        if m is None:
            continue
        epoch_num = int(m.group(1))
        if epoch_num > max_epoch:
            max_epoch, best = epoch_num, fname
    return best

print(f"üîß Environment: {'Kaggle' if IN_KAGGLE else 'Colab' if IN_COLAB else 'Local'}")
print(f"üìÇ Checkpoint dir: {CHECKPOINT_DIR}")

In [None]:
# ============================================
# 🔍 Step 4: Adaptive Pathology Segmentation
# ============================================

def adaptive_pathology_segmentation(image, block_size=11, C=2, gradient_threshold=0.15, min_size=100):
    """
    Detect potential pathological regions using adaptive thresholding + gradients

    Algorithm:
    1. Adaptive Gaussian thresholding: keeps pixels above their local threshold
    2. Sobel gradients: keeps pixels whose normalised gradient magnitude
       exceeds `gradient_threshold`
    3. Combined mask: pixels that pass BOTH tests, cleaned with a 5x5
       morphological close then open, with small components removed

    Args:
        image: Grayscale image (H, W) or (1, H, W), expected in [0, 1];
            values outside that range are clipped before quantisation
        block_size: Adaptive thresholding block size (passed to
            cv2.adaptiveThreshold unchanged)
        C: Constant subtracted in adaptive thresholding
        gradient_threshold: Threshold on the gradient magnitude after it has
            been divided by its per-image maximum
        min_size: Minimum connected-component area to keep (pixels)

    Returns:
        Pathology mask (H, W), float32, with values 0.0 or 1.0
    """
    if len(image.shape) == 3 and image.shape[0] == 1:
        image = image[0]

    # Clip first: a value outside [0, 1] would otherwise wrap around in the uint8 cast.
    img_uint8 = (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)

    # Adaptive Gaussian thresholding - detects local high-contrast areas
    adaptive = cv2.adaptiveThreshold(
        img_uint8,
        255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY,
        block_size,
        C
    )

    # Sobel edge detection for boundaries
    sobelx = cv2.Sobel(img_uint8, cv2.CV_64F, 1, 0, ksize=3)
    sobely = cv2.Sobel(img_uint8, cv2.CV_64F, 0, 1, ksize=3)

    gradient_mag = np.sqrt(sobelx**2 + sobely**2)
    # The epsilon keeps a perfectly flat image (max gradient 0) from dividing by zero.
    gradient_mag = gradient_mag / (gradient_mag.max() + 1e-8)
    gradient_mask = (gradient_mag > gradient_threshold).astype(np.uint8) * 255

    # Combine: regions with BOTH high local contrast AND edges
    combined = cv2.bitwise_and(adaptive, gradient_mask)

    # Morphology cleanup
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    combined = cv2.morphologyEx(combined, cv2.MORPH_CLOSE, kernel)
    combined = cv2.morphologyEx(combined, cv2.MORPH_OPEN, kernel)

    # Remove small noise regions
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(combined, connectivity=8)

    # Per-label keep flag used as a lookup table over the label image; this
    # replaces one full-image comparison per component.
    keep = stats[:, cv2.CC_STAT_AREA] >= min_size
    keep[0] = False  # label 0 is never kept (the original loop started at 1)

    return keep[labels].astype(np.float32)

print("‚úÖ Adaptive pathology segmentation function defined")

In [None]:
# ============================================
# 👁️ Step 5: Visualize Pathology Detection
# ============================================

# Rows of df_labels to visualise; one figure row per index.
sample_indices = [0, 100, 500]
n_rows = len(sample_indices)
fig, axes = plt.subplots(n_rows, 4, figsize=(16, 4*n_rows))

for i, idx in enumerate(sample_indices):
    img_path = df_labels.iloc[idx]['Image Path']
    img = Image.open(img_path).convert('L').resize((cfg.img_size, cfg.img_size), Image.LANCZOS)
    img_np = np.array(img, dtype=np.float32) / 255.0
    img_uint8 = (img_np * 255).astype(np.uint8)

    # Final combined mask, as used during pretraining
    pathology = adaptive_pathology_segmentation(
        img_np, cfg.adaptive_block_size, cfg.adaptive_C, cfg.gradient_threshold
    )

    # Intermediate 1: adaptive thresholding alone
    adaptive = cv2.adaptiveThreshold(
        img_uint8, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY, cfg.adaptive_block_size, cfg.adaptive_C
    )

    # Intermediate 2: gradient magnitude normalised by its maximum
    sobelx = cv2.Sobel(img_uint8, cv2.CV_64F, 1, 0, ksize=3)
    sobely = cv2.Sobel(img_uint8, cv2.CV_64F, 0, 1, ksize=3)
    gradient_mag = np.sqrt(sobelx**2 + sobely**2)
    gradient_mag = gradient_mag / (gradient_mag.max() + 1e-8)

    # (data, colormap, title) for the four panels of this row, left to right
    panels = [
        (img_np, 'gray', 'Original'),
        (adaptive, 'hot', 'Adaptive Threshold'),
        (gradient_mag, 'hot', 'Gradient Magnitude'),
        (pathology, 'hot', f'Combined Pathology ({pathology.mean():.1%})'),
    ]
    for col, (panel_data, panel_cmap, panel_title) in enumerate(panels):
        axes[i, col].imshow(panel_data, cmap=panel_cmap)
        axes[i, col].set_title(panel_title)
        axes[i, col].axis('off')

plt.suptitle('Option 3: Adaptive Pathology Detection', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('option3_pathology_detection.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# ============================================
# 🔄 Step 6: Augmentation & Datasets
# ============================================

class ChestXrayAugment:
    """Random augmentation pipeline that produces one view of an image.

    Each call applies, independently and in this order: horizontal flip
    (p=0.5), rotation by an angle in [-15, 15] (p=0.7), brightness scaling by a
    factor in [0.8, 1.2] (p=0.8), contrast scaling by a factor in [0.8, 1.2]
    (p=0.8), and additive Gaussian noise with std 0.05 followed by a clamp to
    [0, 1] (p=0.5).
    """

    def __init__(self, img_size=224):
        # Stored but not read by __call__.
        self.img_size = img_size

    def __call__(self, img):
        """Return an augmented float tensor; `img` is a NumPy array or a tensor.

        The flip is applied along dim 2, so the input must have at least three
        dimensions.
        """
        x = torch.tensor(img, dtype=torch.float32) if isinstance(img, np.ndarray) else img.clone()

        # The order of random draws below is part of the behaviour under a fixed seed.
        if random.random() < 0.5:
            x = torch.flip(x, dims=[2])

        if random.random() < 0.7:
            x = transforms.functional.rotate(x, random.uniform(-15, 15))

        if random.random() < 0.8:
            x = transforms.functional.adjust_brightness(x, 1 + random.uniform(-0.2, 0.2))

        if random.random() < 0.8:
            x = transforms.functional.adjust_contrast(x, 1 + random.uniform(-0.2, 0.2))

        if random.random() < 0.5:
            x = torch.clamp(x + torch.randn_like(x) * 0.05, 0, 1)

        return x

augment = ChestXrayAugment(cfg.img_size)


class PathologyAwareDataset(Dataset):
    """SSL dataset yielding two views of an image plus its pathology mask.

    The mask is computed from the un-augmented image; the two views are
    produced by calling `transform` twice on that same image.
    """

    def __init__(self, df, transform=None, img_size=224):
        """Copy `df` (needs an 'Image Path' column) and spot-check that files exist.

        Raises:
            FileNotFoundError: if any of up to 200 sampled paths is missing.
        """
        self.df = df.copy().reset_index(drop=True)
        self.transform = transform
        self.img_size = img_size

        n_probe = min(200, len(self.df))
        probe_paths = self.df['Image Path'].sample(n_probe, random_state=42).values
        missing = [p for p in probe_paths if not os.path.exists(p)]
        if len(missing) > 0:
            raise FileNotFoundError(f"‚ùå Missing {len(missing)} images!")

        print(f"üì¶ PathologyAwareDataset: {len(self.df)} samples")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return `(view1, view2, mask)`; the mask is a float32 tensor of shape (1, H, W)."""
        img_path = self.df.iloc[idx]['Image Path']

        pil_img = Image.open(img_path).convert('L')
        pil_img = pil_img.resize((self.img_size, self.img_size), Image.LANCZOS)
        # Grayscale in [0, 1] with a leading channel axis: (1, H, W)
        img = (np.array(pil_img, dtype=np.float32) / 255.0)[np.newaxis]

        mask_np = adaptive_pathology_segmentation(
            img, cfg.adaptive_block_size, cfg.adaptive_C, cfg.gradient_threshold
        )
        mask = torch.tensor(mask_np[np.newaxis], dtype=torch.float32)

        if self.transform:
            make_view = self.transform
        else:
            def make_view(arr):
                return torch.tensor(arr, dtype=torch.float32)

        view1 = make_view(img)
        view2 = make_view(img)

        return view1, view2, mask


class ClassificationDataset(Dataset):
    """Supervised dataset yielding `(image, labels)` pairs.

    `image` is a float32 tensor of shape (1, img_size, img_size) in [0, 1];
    `labels` is a float32 vector with one entry per name in
    `disease_categories`, read from the same-named DataFrame columns.
    """

    def __init__(self, df, disease_categories, img_size=224):
        self.df = df.copy().reset_index(drop=True)
        self.disease_categories = disease_categories
        self.img_size = img_size
        print(f"üì¶ ClassificationDataset: {len(self.df)} samples")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        pil_img = Image.open(row['Image Path']).convert('L')
        pil_img = pil_img.resize((self.img_size, self.img_size), Image.LANCZOS)
        pixels = np.array(pil_img, dtype=np.float32) / 255.0
        img = torch.tensor(pixels, dtype=torch.float32).unsqueeze(0)

        label_values = [row[name] for name in self.disease_categories]
        labels = torch.tensor(label_values, dtype=torch.float32)
        return img, labels

print("‚úÖ Dataset classes defined")

In [None]:
# ============================================
# 🏗️ Step 7: Model Architecture
# ============================================

def conv_block(in_c, out_c, kernel=3, stride=1, padding=1):
    """Build a Conv2d -> BatchNorm2d -> ReLU(inplace) stack.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
        kernel: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding.

    Returns:
        An `nn.Sequential` holding the three layers in that order.
    """
    layers = [
        nn.Conv2d(in_c, out_c, kernel_size=kernel, stride=stride, padding=padding),
        nn.BatchNorm2d(out_c),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)

def residual_block(channels):
    """Stack two `conv_block`s that keep the channel count at `channels`.

    NOTE(review): despite the name, no skip connection is added here. The
    result is a plain two-layer Sequential. Adding one would change the
    architecture that existing checkpoints were trained with.
    """
    stages = [conv_block(channels, channels) for _ in range(2)]
    return nn.Sequential(*stages)

class Encoder(nn.Module):
    """Convolutional encoder mapping an image batch to `feat_dim`-sized vectors.

    Five convolutional stages (64, 128, 256, 512, 512 channels). The first four
    end in a 2x2 max-pool and the last in global average pooling, followed by a
    two-layer MLP with dropout 0.3.
    """

    def __init__(self, in_channels=1, feat_dim=256):
        super().__init__()
        # The layer order fixes the Sequential indices and therefore the state_dict keys.
        layers = [
            conv_block(in_channels, 64), residual_block(64), nn.MaxPool2d(2),
            conv_block(64, 128), residual_block(128), nn.MaxPool2d(2),
            conv_block(128, 256), residual_block(256), residual_block(256), nn.MaxPool2d(2),
            conv_block(256, 512), residual_block(512), residual_block(512), nn.MaxPool2d(2),
            conv_block(512, 512), residual_block(512), nn.AdaptiveAvgPool2d((1, 1)),
        ]
        self.features = nn.Sequential(*layers)
        self.fc = nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, feat_dim),
        )

    def forward(self, x):
        """Return features of shape (batch, feat_dim)."""
        pooled = self.features(x)
        flat = pooled.view(pooled.size(0), -1)  # (batch, 512, 1, 1) -> (batch, 512)
        return self.fc(flat)

class ProjectionHead(nn.Module):
    """MLP that projects encoder features into the contrastive-loss space.

    Layers: Linear(feat_dim, feat_dim) -> BatchNorm1d -> ReLU -> Linear(feat_dim, proj_dim).
    """

    def __init__(self, feat_dim=256, proj_dim=128):
        super().__init__()
        hidden = [nn.Linear(feat_dim, feat_dim), nn.BatchNorm1d(feat_dim), nn.ReLU()]
        self.net = nn.Sequential(*hidden, nn.Linear(feat_dim, proj_dim))

    def forward(self, x):
        """Return projections of shape (batch, proj_dim)."""
        projected = self.net(x)
        return projected

class Decoder(nn.Module):
    """Transposed-convolution decoder reconstructing an image from a feature vector.

    The feature vector is expanded to a (256, s, s) map with s = img_size // 32
    and upsampled by five stride-2 transposed convolutions (a factor of 32).
    When `img_size` is not a multiple of 32 the result is bilinearly resized so
    the output always has spatial size (img_size, img_size) and can be compared
    with the input image.
    """

    def __init__(self, feat_dim=256, img_size=224):
        """
        Args:
            feat_dim: Size of the input feature vector.
            img_size: Side length of the square reconstruction.

        Raises:
            ValueError: if `img_size` is below 32, which would give an empty
                initial feature map.
        """
        super().__init__()
        if img_size < 32:
            raise ValueError(f"img_size must be >= 32, got {img_size}")
        self.img_size = img_size
        self.init_size = img_size // 32
        self.fc = nn.Sequential(
            nn.Linear(feat_dim, 256 * self.init_size * self.init_size), nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1), nn.BatchNorm2d(32), nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 4, 2, 1), nn.BatchNorm2d(16), nn.ReLU(),
            nn.ConvTranspose2d(16, 1, 4, 2, 1), nn.Sigmoid()
        )

    def forward(self, z):
        """Return reconstructions of shape (batch, 1, img_size, img_size) in [0, 1]."""
        x = self.fc(z)
        x = x.view(z.size(0), 256, self.init_size, self.init_size)
        x = self.decoder(x)
        # 32 * (img_size // 32) differs from img_size when img_size is not a
        # multiple of 32; resize so the reconstruction matches the input.
        if x.shape[-1] != self.img_size:
            x = F.interpolate(x, size=(self.img_size, self.img_size),
                              mode='bilinear', align_corners=False)
        return x

class Classifier(nn.Module):
    """Two-layer MLP head producing one logit per class.

    Layers: Linear(feat_dim, 256) -> ReLU -> Dropout(0.3) -> Linear(256, num_classes).
    The output is raw logits; no sigmoid is applied here.
    """

    def __init__(self, feat_dim=256, num_classes=14):
        super().__init__()
        hidden_width = 256
        self.net = nn.Sequential(
            nn.Linear(feat_dim, hidden_width),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(hidden_width, num_classes),
        )

    def forward(self, x):
        """Return logits of shape (batch, num_classes)."""
        logits = self.net(x)
        return logits

# Build the three SSL modules and move them to the configured device.
encoder = Encoder(feat_dim=cfg.feat_dim).to(cfg.device)
proj_head = ProjectionHead(cfg.feat_dim, cfg.proj_dim).to(cfg.device)
decoder = Decoder(cfg.feat_dim, cfg.img_size).to(cfg.device)

# Total parameter count across the three modules (trainable or not).
total_params = sum(
    sum(p.numel() for p in module.parameters())
    for module in (encoder, proj_head, decoder)
)
print(f"‚úÖ Models initialized ({total_params:,} parameters)")

In [None]:
# ============================================
# 🔥 Step 8: Loss Functions
# ============================================

def nt_xent_loss(z1, z2, temperature=0.1):
    """NT-Xent contrastive loss between two batches of projections.

    Row i of `z1` and row i of `z2` form a positive pair; every other row of
    the concatenated batch is a negative. Both inputs are L2-normalised first.

    Args:
        z1: Projections of the first view, shape (batch, dim).
        z2: Projections of the second view, shape (batch, dim).
        temperature: Divisor applied to the cosine similarities.

    Returns:
        Scalar tensor: cross-entropy averaged over all 2 * batch rows.
    """
    n = z1.shape[0]
    dev = z1.device

    reps = torch.cat([F.normalize(z1, dim=1), F.normalize(z2, dim=1)], dim=0)
    logits = (reps @ reps.T) / temperature

    # A sample must never be selected as its own positive.
    self_pairs = torch.eye(2 * n, dtype=torch.bool, device=dev)
    logits = logits.masked_fill(self_pairs, -float('inf'))

    # Row i (view 1) targets row i + n (view 2) and vice versa.
    idx = torch.arange(n)
    targets = torch.cat([idx + n, idx]).to(dev)
    return F.cross_entropy(logits, targets)


def pathology_aware_loss(proj_1, proj_2, pathology_mask_1, pathology_mask_2, 
                         temperature=0.1, pathology_weight=0.5):
    """
    Pathology-aware contrastive loss: NT-Xent scaled by pathology coverage.

    The base NT-Xent loss of the batch is multiplied by a single scalar,
    mean_i(1 + pathology_weight * coverage_i), where coverage_i is the average
    of the two masks' mean values for sample i.

    NOTE(review): the weighting is applied at batch level, not per sample, so
    all samples in a batch share the same scale factor.

    Args:
        proj_1, proj_2: Projections of the two views, shape (batch, dim).
        pathology_mask_1, pathology_mask_2: 4-D mask tensors, averaged over
            dims 1-3 to get one coverage score per sample.
        temperature: NT-Xent temperature.
        pathology_weight: Multiplier on the coverage score.

    Returns:
        Scalar loss tensor.
    """
    contrastive = nt_xent_loss(proj_1, proj_2, temperature)

    coverage_1 = pathology_mask_1.mean(dim=[1, 2, 3])
    coverage_2 = pathology_mask_2.mean(dim=[1, 2, 3])
    coverage = (coverage_1 + coverage_2) / 2

    per_sample_weight = 1.0 + pathology_weight * coverage.to(proj_1.device)
    batch_scale = per_sample_weight.mean()

    return contrastive * batch_scale

print("‚úÖ Loss functions defined")
print("   üéØ pathology_aware_loss: Emphasizes abnormal regions")

In [None]:
# ============================================
# 📊 Step 9: Create Data Loaders
# ============================================

df_shuffled = df_labels.sample(frac=1, random_state=42).reset_index(drop=True)

# 80/20 train/validation split.
# When the CSV has a 'Patient ID' column, split by patient so that images of
# the same patient never land on both sides. An image-level split would leak
# patient-specific appearance into validation and inflate the reported AUC.
# Without that column, fall back to the plain image-level split.
if 'Patient ID' in df_shuffled.columns:
    patient_ids = pd.Series(df_shuffled['Patient ID'].unique()).sample(frac=1, random_state=42)
    n_train_patients = int(0.8 * len(patient_ids))
    train_patients = set(patient_ids.iloc[:n_train_patients])
    in_train = df_shuffled['Patient ID'].isin(train_patients)
    train_df = df_shuffled[in_train]
    val_df = df_shuffled[~in_train]
    train_size = len(train_df)
else:
    train_size = int(0.8 * len(df_shuffled))
    train_df = df_shuffled[:train_size]
    val_df = df_shuffled[train_size:]

train_pretrain_ds = PathologyAwareDataset(train_df, transform=augment, img_size=cfg.img_size)
train_class_ds = ClassificationDataset(train_df, DISEASE_CATEGORIES, cfg.img_size)
val_class_ds = ClassificationDataset(val_df, DISEASE_CATEGORIES, cfg.img_size)

# Both training loaders drop the last incomplete batch; the validation loader
# keeps every sample and is not shuffled.
pretrain_loader = DataLoader(
    train_pretrain_ds, batch_size=cfg.batch_size, shuffle=True,
    num_workers=cfg.num_workers, pin_memory=True, drop_last=True
)
train_loader = DataLoader(
    train_class_ds, batch_size=cfg.batch_size, shuffle=True,
    num_workers=cfg.num_workers, pin_memory=True, drop_last=True
)
val_loader = DataLoader(
    val_class_ds, batch_size=cfg.batch_size, shuffle=False,
    num_workers=cfg.num_workers, pin_memory=True
)

print(f"‚úÖ DataLoaders ready")

In [None]:
# ============================================
# 🚀 Step 10: Pathology-Aware SSL Pretraining
# ============================================

# One Adam optimizer over all three SSL modules (encoder, projection head, decoder).
optimizer_ssl = torch.optim.Adam(
    list(encoder.parameters()) + list(proj_head.parameters()) + list(decoder.parameters()),
    lr=cfg.lr_pretrain, weight_decay=1e-4
)

# Per-epoch averages of the total, contrastive and reconstruction losses.
ssl_history = {'loss': [], 'contrastive': [], 'reconstruction': []}
START_EPOCH = 1

if RESUME_SSL_PRETRAINING:
    ckpt_file = find_latest_checkpoint(f'{OPTION_NAME}_ssl') if SSL_CHECKPOINT_FILE == "latest" else SSL_CHECKPOINT_FILE
    if ckpt_file:
        checkpoint = load_checkpoint(ckpt_file)
        if checkpoint:
            encoder.load_state_dict(checkpoint['encoder'])
            proj_head.load_state_dict(checkpoint['proj_head'])
            decoder.load_state_dict(checkpoint['decoder'])
            # Per-epoch checkpoints carry no optimizer state; only "latest" does.
            if 'optimizer' in checkpoint: optimizer_ssl.load_state_dict(checkpoint['optimizer'])
            ssl_history = checkpoint.get('ssl_history', ssl_history)
            START_EPOCH = checkpoint['epoch'] + 1
            print(f"üîÑ Resuming from epoch {START_EPOCH}")
        else:
            # A requested checkpoint that cannot be loaded previously restarted
            # training from scratch without any message.
            print(f"Checkpoint '{ckpt_file}' could not be loaded from {CHECKPOINT_DIR}. Starting fresh.")
    else:
        print("‚ö†Ô∏è No checkpoint found. Starting fresh.")

if START_EPOCH > cfg.pretrain_epochs:
    print(f"‚úÖ SSL Pretraining already complete")
else:
    print(f"\nüöÄ Starting Option 3: Pathology-Aware SSL Pretraining")
    print(f"   Epochs: {START_EPOCH} ‚Üí {cfg.pretrain_epochs}")
    print("=" * 60)
    
    for epoch in range(START_EPOCH, cfg.pretrain_epochs + 1):
        encoder.train()
        proj_head.train()
        decoder.train()
        
        total_loss = total_cont = total_recon = 0
        
        pbar = tqdm(pretrain_loader, desc=f"Epoch {epoch}/{cfg.pretrain_epochs}")
        for view1, view2, pathology_mask in pbar:
            view1 = view1.to(cfg.device)
            view2 = view2.to(cfg.device)
            pathology_mask = pathology_mask.to(cfg.device)
            
            optimizer_ssl.zero_grad()
            
            z1, z2 = encoder(view1), encoder(view2)
            p1, p2 = proj_head(z1), proj_head(z2)
            # The dataset yields one mask per image (from the un-augmented
            # image), so the same mask is passed for both views.
            cont_loss = pathology_aware_loss(p1, p2, pathology_mask, pathology_mask, 
                                             cfg.temperature, cfg.pathology_weight)
            
            rec1, rec2 = decoder(z1), decoder(z2)
            recon_loss = (F.mse_loss(rec1, view1) + F.mse_loss(rec2, view2)) / 2
            
            # Reconstruction is an auxiliary term weighted at 0.5.
            loss = cont_loss + 0.5 * recon_loss
            loss.backward()
            optimizer_ssl.step()
            
            total_loss += loss.item()
            total_cont += cont_loss.item()
            total_recon += recon_loss.item()
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})
        
        n = len(pretrain_loader)
        ssl_history['loss'].append(total_loss / n)
        ssl_history['contrastive'].append(total_cont / n)
        ssl_history['reconstruction'].append(total_recon / n)
        
        print(f"Epoch {epoch}: Loss={total_loss/n:.4f}, Cont={total_cont/n:.4f}, Recon={total_recon/n:.4f}")
        
        # Model weights and history shared by both checkpoint files of this epoch.
        model_state = {
            'epoch': epoch, 'encoder': encoder.state_dict(),
            'proj_head': proj_head.state_dict(), 'decoder': decoder.state_dict(),
            'ssl_history': ssl_history,
        }
        # "latest" also stores the optimizer so training can resume exactly.
        save_checkpoint({**model_state, 'optimizer': optimizer_ssl.state_dict()},
                        f'{OPTION_NAME}_ssl_latest.pth')
        # The per-epoch snapshot keeps weights and history only.
        save_checkpoint(dict(model_state), f'{OPTION_NAME}_ssl_epoch{epoch}.pth')
    
    print("\n‚úÖ Pathology-Aware SSL Pretraining Complete!")

In [None]:
# ============================================
# 📈 Step 11: Plot SSL Training Curves
# ============================================

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# (history key, line format, panel title) for the three panels, left to right
curve_specs = [
    ('loss', 'b-', 'Total Loss'),
    ('contrastive', 'r-', 'Pathology-Aware Contrastive Loss'),
    ('reconstruction', 'g-', 'Reconstruction Loss'),
]
for ax, (history_key, line_fmt, panel_title) in zip(axes, curve_specs):
    ax.plot(ssl_history[history_key], line_fmt, linewidth=2)
    ax.set_title(panel_title, fontsize=12)
    ax.set_xlabel('Epoch')
    ax.grid(True, alpha=0.3)

plt.suptitle('Option 3: Pathology-Aware SSL Training', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('option3_ssl_curves.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# ============================================
# 💾 Step 12: Save Pretrained Model
# ============================================

# Bundle the three SSL modules' weights with the sizes needed to rebuild them.
pretrained_state = {
    module_name: module.state_dict()
    for module_name, module in (('encoder', encoder), ('proj_head', proj_head), ('decoder', decoder))
}
pretrained_state['config'] = {'feat_dim': cfg.feat_dim, 'proj_dim': cfg.proj_dim}
torch.save(pretrained_state, 'option3_ssl_pretrained.pth')

print("üíæ Pretrained model saved: option3_ssl_pretrained.pth")

In [None]:
# ============================================
# 🎯 Step 13: Fine-tuning
# ============================================

# Freeze the pretrained encoder: only the classifier head is optimised below.
for param in encoder.parameters():
    param.requires_grad = False
encoder.eval()

classifier = Classifier(cfg.feat_dim, len(DISEASE_CATEGORIES)).to(cfg.device)

# Per-class positive weight = #negatives / #positives, to counter class imbalance.
# The epsilon avoids division by zero for a class with no positives.
pos_counts = train_df[DISEASE_CATEGORIES].sum().values
neg_counts = len(train_df) - pos_counts
pos_weights = torch.tensor(neg_counts / (pos_counts + 1e-6), dtype=torch.float32).to(cfg.device)

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weights)
optimizer_ft = torch.optim.Adam(classifier.parameters(), lr=cfg.lr_finetune, weight_decay=1e-4)
# Halve the learning rate when validation AUC has not improved for 5 epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_ft, 'max', patience=5, factor=0.5)

finetune_history = {'train_loss': [], 'train_auc': [], 'val_loss': [], 'val_auc': []}
best_val_auc = 0
FINETUNE_START_EPOCH = 1


def _mean_auc(targets, preds):
    """Mean per-class ROC-AUC over classes that have both labels present.

    Returned as a built-in float (not np.float64) so the value can be stored in
    checkpoints that `torch.load` reads with its `weights_only=True` default.
    """
    per_class = [roc_auc_score(targets[:, i], preds[:, i])
                 for i in range(len(DISEASE_CATEGORIES)) if len(np.unique(targets[:, i])) > 1]
    return float(np.mean(per_class))


if RESUME_FINETUNING:
    ckpt_file = find_latest_checkpoint(f'{OPTION_NAME}_finetune') if FINETUNE_CHECKPOINT_FILE == "latest" else FINETUNE_CHECKPOINT_FILE
    if ckpt_file:
        ft_ckpt = load_checkpoint(ckpt_file)
        if ft_ckpt:
            classifier.load_state_dict(ft_ckpt['classifier'])
            if 'optimizer' in ft_ckpt: optimizer_ft.load_state_dict(ft_ckpt['optimizer'])
            finetune_history = ft_ckpt.get('finetune_history', finetune_history)
            best_val_auc = ft_ckpt.get('best_val_auc', 0)
            FINETUNE_START_EPOCH = ft_ckpt['epoch'] + 1
            print(f"üîÑ Resuming fine-tuning from epoch {FINETUNE_START_EPOCH}")

if FINETUNE_START_EPOCH > cfg.finetune_epochs:
    print(f"‚úÖ Fine-tuning already complete")
else:
    print(f"\nüéØ Starting Fine-tuning: Epochs {FINETUNE_START_EPOCH} ‚Üí {cfg.finetune_epochs}")
    print("=" * 50)
    
    for epoch in range(FINETUNE_START_EPOCH, cfg.finetune_epochs + 1):
        # ---- training pass (classifier only; encoder stays in eval mode) ----
        classifier.train()
        train_loss = 0
        train_preds, train_targets = [], []
        
        for images, targets in tqdm(train_loader, desc=f"Train {epoch}/{cfg.finetune_epochs}"):
            images, targets = images.to(cfg.device), targets.to(cfg.device)
            optimizer_ft.zero_grad()
            with torch.no_grad(): features = encoder(images)
            logits = classifier(features)
            loss = criterion(logits, targets)
            loss.backward()
            optimizer_ft.step()
            train_loss += loss.item()
            train_preds.append(torch.sigmoid(logits).detach().cpu())
            train_targets.append(targets.cpu())
        
        # ---- validation pass ----
        classifier.eval()
        val_loss = 0
        val_preds, val_targets = [], []
        
        with torch.no_grad():
            for images, targets in val_loader:
                # Targets must be on the same device as the logits and pos_weight,
                # otherwise the loss raises a device-mismatch error on GPU.
                images, targets = images.to(cfg.device), targets.to(cfg.device)
                features = encoder(images)
                logits = classifier(features)
                loss = criterion(logits, targets)
                val_loss += loss.item()
                val_preds.append(torch.sigmoid(logits).cpu())
                val_targets.append(targets.cpu())
        
        train_preds, train_targets = torch.cat(train_preds).numpy(), torch.cat(train_targets).numpy()
        val_preds, val_targets = torch.cat(val_preds).numpy(), torch.cat(val_targets).numpy()
        
        train_auc = _mean_auc(train_targets, train_preds)
        val_auc = _mean_auc(val_targets, val_preds)
        
        finetune_history['train_loss'].append(train_loss / len(train_loader))
        finetune_history['train_auc'].append(train_auc)
        finetune_history['val_loss'].append(val_loss / len(val_loader))
        finetune_history['val_auc'].append(val_auc)
        scheduler.step(val_auc)
        
        print(f"Epoch {epoch}: Train AUC={train_auc:.4f}, Val AUC={val_auc:.4f}")
        
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            save_checkpoint({'encoder': encoder.state_dict(), 'classifier': classifier.state_dict(),
                            'val_auc': val_auc, 'epoch': epoch}, f'{OPTION_NAME}_best_model.pth')
            print(f"  ‚úÖ Best model saved! Val AUC: {val_auc:.4f}")
        
        # Resume checkpoint every 5 epochs and at the final epoch.
        if epoch % 5 == 0 or epoch == cfg.finetune_epochs:
            save_checkpoint({'epoch': epoch, 'classifier': classifier.state_dict(),
                            'optimizer': optimizer_ft.state_dict(), 'finetune_history': finetune_history,
                            'best_val_auc': best_val_auc}, f'{OPTION_NAME}_finetune_latest.pth')
    
    print(f"\nüèÜ Best Validation AUC: {best_val_auc:.4f}")

In [None]:
# ============================================
# 📊 Step 14: Final Evaluation & Summary
# ============================================

# The best model is written by save_checkpoint into CHECKPOINT_DIR, so load it
# from there; the working directory only holds a copy on Kaggle. Tensors are
# mapped onto cfg.device so a GPU-trained checkpoint also loads on CPU.
best_model_file = f'{OPTION_NAME}_best_model.pth'
checkpoint = load_checkpoint(best_model_file)
if checkpoint is None and os.path.exists(best_model_file):
    # Fall back to a copy in the working directory (the original location).
    checkpoint = torch.load(best_model_file, map_location=cfg.device)
if checkpoint is None:
    raise FileNotFoundError(
        f"Best model checkpoint '{best_model_file}' not found in {CHECKPOINT_DIR} "
        f"or the working directory; run the fine-tuning cell first."
    )
encoder.load_state_dict(checkpoint['encoder'])
classifier.load_state_dict(checkpoint['classifier'])

encoder.eval()
classifier.eval()

# Collect sigmoid probabilities and targets over the whole validation set.
all_preds, all_targets = [], []
with torch.no_grad():
    for images, targets in tqdm(val_loader, desc="Evaluating"):
        images = images.to(cfg.device)
        features = encoder(images)
        logits = classifier(features)
        all_preds.append(torch.sigmoid(logits).cpu())
        all_targets.append(targets)

all_preds = torch.cat(all_preds).numpy()
all_targets = torch.cat(all_targets).numpy()

print("\nüìä Per-Disease AUC Scores:")
print("=" * 40)
auc_scores = []
for i, disease in enumerate(DISEASE_CATEGORIES):
    # ROC-AUC is undefined when a class has only one label value in the set.
    if len(np.unique(all_targets[:, i])) > 1:
        auc = roc_auc_score(all_targets[:, i], all_preds[:, i])
        auc_scores.append((disease, auc))
        print(f"{disease:20s}: {auc:.4f}")

mean_auc = np.mean([a for _, a in auc_scores])
print(f"\n{'Mean AUC':20s}: {mean_auc:.4f}")

auc_scores.sort(key=lambda x: x[1], reverse=True)
diseases, aucs = zip(*auc_scores)

plt.figure(figsize=(12, 6))
# Colour bands: green >= 0.7, orange >= 0.6, red below.
colors = ['green' if a >= 0.7 else 'orange' if a >= 0.6 else 'red' for a in aucs]
plt.barh(diseases, aucs, color=colors, alpha=0.8)
plt.axvline(0.5, color='red', linestyle='--', alpha=0.5, label='Random')
plt.axvline(mean_auc, color='blue', linestyle='--', alpha=0.7, label=f'Mean: {mean_auc:.3f}')
plt.xlabel('AUC Score')
plt.title('Option 3: Per-Disease AUC Performance', fontsize=14, fontweight='bold')
plt.legend()
plt.tight_layout()
plt.savefig('option3_auc_performance.png', dpi=150, bbox_inches='tight')
plt.show()

print("\n" + "=" * 60)
print("üìù OPTION 3: PATHOLOGY-AWARE SSL SUMMARY")
print("=" * 60)
print(f"Method: Adaptive Thresholding + Gradient-Based Pathology Detection")
print(f"Key: Emphasizes images with detected abnormal regions")
print(f"\nüèÜ Final Mean AUC: {mean_auc:.4f}")
print("=" * 60)