In [None]:
import os
import random

import albumentations as A
import numpy as np
import pandas as pd
import timm
import torch
import torch.nn.functional as F
import torchmetrics
import wandb
from albumentations.pytorch import ToTensorV2
from peft import LoraConfig, get_peft_model
from PIL import Image
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from tqdm.notebook import tqdm
# Authenticate with Weights & Biases without pasting a secret into the notebook: the key is
# read from the WANDB_API_KEY environment variable. If it is unset, key=None is passed and
# wandb falls back to its own credential lookup (e.g. a prior `wandb login`) or prompts.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

In [None]:
# Single source of hyper-parameters for both phases (also logged to W&B as the run config).
config = dict(
    root_dir='',  # prefix joined onto every image path taken from the CSV index
    image_size=224,
    embedding_dim=768,
    num_workers=2,
    vit_dropout=0.1,
    # --- LoRA adapter settings (shared by both phases) ---
    lora_r=64,
    lora_alpha=256,
    lora_dropout=0.1,
    # --- Phase 1: contrastive pre-training ---
    batch_size_contrastive=32,
    projection_dim=128,
    lr_contrastive=2e-4,
    wd_contrastive=1e-6,
    epochs_contrastive=10,
    temperature=0.07,  # NT-Xent softmax temperature
    # --- Phase 2: supervised fine-tuning ---
    batch_size_finetune=32,
    lr_finetune=1e-4,
    wd_finetune=1e-5,
    epochs_finetune=15,
    early_stopping_patience=3,  # epochs without val mAUROC improvement before stopping
)

In [None]:
PRETRAINED_BACKBONE_PATH = "./vit-contrastive-lora-backbone.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Seed the Python, NumPy and PyTorch global RNGs so that randomness drawn from these
# generators (new-layer initialisation, sampling, augmentation) repeats across
# Restart & Run All. Bit-exact GPU reproducibility would additionally need
# deterministic kernels; this only removes the unseeded-RNG variation.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Load the label tables (image path as index, one label column per class) and rewrite the
# archive folder name inside the paths to the local 'chexpert' directory.
test = pd.read_csv('u0_test.csv', index_col=0)
train_df = pd.read_csv('u0_train.csv', index_col=0)
val_df = pd.read_csv('u0_val.csv', index_col=0)

# regex=False: the prefix contains '.', which must match literally (older pandas
# defaulted to regex=True, where '.' is a wildcard).
for split_df in (train_df, val_df, test):
    split_df.index = split_df.index.str.replace('CheXpert-v1.0-small', 'chexpert', regex=False)

class_names = train_df.columns.tolist()

# Sanity checks: validation must use the same label columns (in the same order) as
# training, and labels must be complete -- NaNs would propagate into the sampler
# weights, the BCE loss and the AUROC metric.
assert val_df.columns.tolist() == class_names, "validation label columns differ from training"
assert not train_df.isna().any().any(), "training labels contain NaN"
assert not val_df.isna().any().any(), "validation labels contain NaN"

print(f"Training data shape: {train_df.shape}")
print(f"Validation data shape: {val_df.shape}")


In [None]:
def get_transforms(image_size=224):
    """Build the albumentations pipelines used by each phase.

    Args:
        image_size: side length of the square output image. Defaults to 224, matching
            the ``vit_base_patch16_224`` backbone used elsewhere in this notebook.

    Returns:
        dict mapping ``'contrastive'``, ``'supervised'`` and ``'validation'`` to an
        ``A.Compose`` pipeline. Both training pipelines apply the same random affine
        augmentations (scale, rotate-or-shear, translate) before resizing; validation
        only resizes. Every pipeline ends with normalisation and ``ToTensorV2``.
    """
    def make_normalize():
        # One statistic replicated over the three channels (images are loaded with
        # convert("RGB"); presumably grey-scale radiographs -- confirm if inputs change).
        return A.Normalize(mean=[0.506] * 3, std=[0.287] * 3)

    def make_train_pipeline():
        # Fresh transform instances per pipeline, so the two training Compose objects
        # do not share any transform object.
        # NOTE(review): these are mild geometric augmentations; contrastive pre-training
        # usually relies on stronger view differences -- worth revisiting.
        return A.Compose([
            A.Affine(scale=(0.95, 1.05), p=0.5),
            A.OneOf([A.Affine(rotate=(-20, 20), p=0.5), A.Affine(shear=(-5, 5), p=0.5)], p=0.5),
            A.Affine(translate_percent=(-0.05, 0.05), p=0.5),
            A.Resize(image_size, image_size),
            make_normalize(),
            ToTensorV2(),
        ])

    return {
        'contrastive': make_train_pipeline(),
        'supervised': make_train_pipeline(),
        'validation': A.Compose([
            A.Resize(image_size, image_size),
            make_normalize(),
            ToTensorV2(),
        ]),
    }
def get_weighted_sampler(data_frame):
    """
    Build a WeightedRandomSampler that over-samples rows carrying rare labels.

    Each class gets weight ``1 / column_sum`` (for 0/1 labels: one over its number of
    positive rows). A row's weight is the dot product of the row with those class
    weights. ``len(data_frame)`` indices are then drawn with replacement.

    Classes whose column sum is zero get weight 0 instead of ``1/0 = inf``. Otherwise
    ``0 * inf = NaN`` would make every row weight NaN, and the sampler would fail when
    iterated.

    Note: rows whose labels are all zero receive weight 0 and are therefore never drawn.

    Args:
        data_frame: label matrix with one row per sample and one numeric column per class
            (assumed to hold 0/1 labels -- TODO confirm for other encodings).

    Returns:
        torch.utils.data.WeightedRandomSampler over ``range(len(data_frame))``.
    """
    class_totals = data_frame.sum().to_numpy(dtype=np.float64)
    class_weights = np.zeros_like(class_totals)
    # Inverse frequency only where defined; absent classes keep weight 0.
    np.divide(1.0, class_totals, out=class_weights, where=class_totals != 0)
    sample_weights = data_frame.to_numpy(dtype=np.float64) @ class_weights
    return WeightedRandomSampler(
        weights=torch.tensor(sample_weights, dtype=torch.float),
        num_samples=len(sample_weights),
        replacement=True,
    )

In [None]:
# Phase 1 (contrastive learning) components: paired-view dataset, NT-Xent loss and ViT with projection head

class ContrastiveViewDataset(Dataset):
    """
    Dataset that yields two independently augmented views of every image.

    Each item is a ``(view1, view2)`` pair obtained by calling ``transform`` twice on the
    same RGB image; the pair serves as the positive example for contrastive learning.
    Image paths are the frame's index joined onto ``root_dir``. If an image cannot be
    read, a warning is printed and two all-zero ``(3, image_size, image_size)`` tensors
    are returned, with the size taken from the global ``config``.
    """

    def __init__(self, data_frame, root_dir, transform):
        self.img_paths = [os.path.join(root_dir, relative) for relative in data_frame.index]
        self.transform = transform

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        path = self.img_paths[idx]
        try:
            pixels = np.array(Image.open(path).convert("RGB"))
            # Two separate transform calls -> two different random augmentations.
            first, second = (self.transform(image=pixels)['image'] for _ in range(2))
            return first, second
        except (IOError, FileNotFoundError):
            print(f"Warning: Could not load image at {path}. Returning zeros.")
            side = config["image_size"]
            return torch.zeros((3, side, side)), torch.zeros((3, side, side))


class NTXentLoss(nn.Module):
    """NT-Xent (normalized temperature-scaled cross-entropy) contrastive loss.

    For N pairs ``(z_i[k], z_j[k])`` the 2N embeddings are L2-normalised and compared
    with cosine similarity. Each embedding's partner from the other view is its
    positive, the remaining 2N - 2 embeddings are its negatives, and self-similarity
    is excluded. The loss is the cross-entropy of picking the positive, averaged over
    all 2N anchors.

    Args:
        temperature: divisor applied to the similarities before the softmax.
        device: device on which the index/mask tensors are created (should match the inputs).
    """

    def __init__(self, temperature, device):
        super().__init__()
        self.temperature = temperature
        self.device = device
        # Summed here and divided by 2N in forward(), i.e. a mean over anchors.
        self.criterion = nn.CrossEntropyLoss(reduction="sum")

    def forward(self, z_i, z_j):
        n_pairs = z_i.shape[0]
        n_total = 2 * n_pairs
        embeddings = torch.cat([F.normalize(z_i, p=2, dim=1), F.normalize(z_j, p=2, dim=1)], dim=0)
        # (2N, 2N) pairwise cosine similarities.
        sim = F.cosine_similarity(embeddings.unsqueeze(1), embeddings.unsqueeze(0), dim=2)

        # Rows k and k + N share an id, so equal ids mark the positive pairs (plus the diagonal).
        pair_ids = torch.arange(n_pairs).repeat(2).to(self.device)
        same_pair = (pair_ids.unsqueeze(0) == pair_ids.unsqueeze(1)).float()

        # Drop self-similarities: every row keeps 2N - 1 entries.
        off_diag = ~torch.eye(n_total, dtype=torch.bool).to(self.device)
        same_pair = same_pair[off_diag].view(n_total, -1)
        sim = sim[off_diag].view(n_total, -1)

        pos = sim[same_pair.bool()].view(n_total, -1)    # one positive per row
        neg = sim[~same_pair.bool()].view(n_total, -1)   # 2N - 2 negatives per row

        # Positive sits in column 0, so the target class is always 0.
        logits = torch.cat([pos, neg], dim=1) / self.temperature
        targets = torch.zeros(n_total, dtype=torch.long).to(self.device)
        return self.criterion(logits, targets) / n_total

class ContrastiveViT(nn.Module):
    """timm Vision Transformer backbone followed by an MLP projection head.

    The backbone is created with ImageNet-pretrained weights, and its ``head`` is replaced
    by ``nn.Identity`` so that it outputs features rather than class logits (assumed to
    be ``embedding_dim`` wide). The projection head maps them through
    Linear -> ReLU -> Linear to ``projection_dim`` for the contrastive loss.
    """

    def __init__(self, model_name="vit_base_patch16_224", embedding_dim=768, projection_dim=128, drop_rate=0.1):
        super().__init__()
        backbone = timm.create_model(model_name, pretrained=True, drop_rate=drop_rate)
        backbone.head = nn.Identity()  # expose features instead of classification logits
        self.backbone = backbone
        head_layers = [
            nn.Linear(embedding_dim, embedding_dim),
            nn.ReLU(inplace=True),
            nn.Linear(embedding_dim, projection_dim),
        ]
        self.projection_head = nn.Sequential(*head_layers)

    def forward(self, x):
        features = self.backbone(x)
        return self.projection_head(features)

In [None]:
#Supervised learning phase classes
class SupervisedDataset(Dataset):
    """
    Image/label dataset for supervised fine-tuning.

    Each item is ``(image_tensor, label_vector)``:
    - the image is loaded from the row's index path (joined onto ``root_dir``),
      converted to RGB and passed through ``transform``;
    - the labels are the row's values as a float32 vector.

    If the image cannot be read, a warning is printed and an all-zero
    ``(3, image_size, image_size)`` image is returned, with the size taken from the
    global ``config``. It comes with an all-zero label vector of the same length as the
    real labels, so the default collate can still stack it with genuine samples.
    """

    def __init__(self, data_frame, root_dir, transform):
        self.img_paths = [os.path.join(root_dir, path) for path in data_frame.index]
        # One row per image, one column per class.
        self.labels = torch.tensor(data_frame.values, dtype=torch.float32)
        self.transform = transform

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        label = self.labels[idx]
        try:
            # Only I/O and decoding are guarded (IOError and FileNotFoundError are both
            # OSError). The context manager closes the file handle even if decoding fails.
            with Image.open(self.img_paths[idx]) as img:
                image = np.array(img.convert("RGB"))
        except OSError:
            print(f"Warning: Could not load image at {self.img_paths[idx]}. Returning zeros.")
            size = config["image_size"]
            # zeros_like(label) instead of a hard-coded 14: matches the real label width.
            return torch.zeros((3, size, size)), torch.zeros_like(label)
        return self.transform(image=image)['image'], label


class FineTuningViT(nn.Module):
    """LoRA-adapted ViT backbone with a linear multi-label classification head.

    Args:
        lora_config: peft ``LoraConfig`` applied to the backbone via ``get_peft_model``.
        num_classes: number of output logits (default 14).
        drop_rate: dropout rate forwarded to ``timm.create_model``.

    The backbone is created with ``pretrained=False`` and ``num_classes=0``, so its
    weights are expected to come from :meth:`load_from_pretrained`.
    """

    def __init__(self, lora_config, num_classes=14, drop_rate=0.1):
        super().__init__()
        backbone = timm.create_model("vit_base_patch16_224", pretrained=False, num_classes=0, drop_rate=drop_rate)
        self.backbone = get_peft_model(backbone, lora_config)
        self.classifier = nn.Linear(self.backbone.embed_dim, num_classes)

    def load_from_pretrained(self, pretrained_path):
        """Load a backbone state dict saved by phase 1 into ``self.backbone``.

        The load is non-strict. When keys do not line up, the numbers of matched,
        missing and unexpected keys are printed. If no backbone key matches at all, the
        load is rejected; otherwise the randomly initialised backbone would silently
        stay in place.

        Raises:
            RuntimeError: if the checkpoint shares no keys with ``self.backbone``.
        """
        # weights_only=True: the checkpoint is a plain tensor state dict, so refuse
        # arbitrary pickled objects.
        state_dict = torch.load(pretrained_path, map_location=DEVICE, weights_only=True)
        result = self.backbone.load_state_dict(state_dict, strict=False)
        n_expected = len(self.backbone.state_dict())
        n_loaded = n_expected - len(result.missing_keys)
        if n_loaded == 0:
            raise RuntimeError(
                f"No keys in {pretrained_path} match the backbone "
                f"({len(result.unexpected_keys)} unexpected keys); refusing to continue "
                "with a randomly initialised backbone."
            )
        print(f"Loaded pre-trained backbone from {pretrained_path}")
        if result.missing_keys or result.unexpected_keys:
            print(
                f"  matched {n_loaded}/{n_expected} backbone keys; "
                f"missing: {len(result.missing_keys)}, unexpected: {len(result.unexpected_keys)}"
            )

    def forward(self, x):
        features = self.backbone(x)
        return self.classifier(features)

In [None]:
# Training entry points: contrastive pre-training (phase 1) and supervised fine-tuning (phase 2)
def contrastive_train(cfg, save_path):
    """PHASE 1: self-supervised contrastive pre-training of a LoRA-adapted ViT.

    Steps:
    - Build ``ContrastiveViT`` and wrap its backbone with LoRA adapters on the ``qkv``
      and ``proj`` modules.
    - Train it with ``NTXentLoss`` on pairs of augmented views of the images in the
      global ``train_df``, logging to Weights & Biases.
    - Write the PEFT-wrapped backbone's full ``state_dict`` (base weights plus LoRA
      adapters) to ``save_path``. The projection head is not saved.

    Args:
        cfg: configuration dict. Reads ``root_dir``, ``num_workers``, ``embedding_dim``,
            ``vit_dropout``, the ``lora_*`` keys, ``projection_dim``, ``temperature``
            and the ``*_contrastive`` keys.
        save_path: file the backbone state dict is saved to.

    Raises:
        ValueError: if ``train_df`` is too small to form a single full batch.
    """
    print(f"Starting Phase 1: Contrastive Pre-training on {DEVICE}")
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    all_transforms = get_transforms()
    train_dataset = ContrastiveViewDataset(train_df, cfg["root_dir"], all_transforms['contrastive'])
    # drop_last=True keeps every batch full-size; a 1-pair tail batch would leave
    # NT-Xent with no negatives.
    train_loader = DataLoader(train_dataset, batch_size=cfg["batch_size_contrastive"], shuffle=True, num_workers=cfg["num_workers"], pin_memory=True, drop_last=True)
    if len(train_loader) == 0:
        raise ValueError(
            f"Need at least {cfg['batch_size_contrastive']} training images for one "
            f"contrastive batch, got {len(train_dataset)}."
        )

    model = ContrastiveViT(
        embedding_dim=cfg["embedding_dim"],
        projection_dim=cfg["projection_dim"],
        drop_rate=cfg["vit_dropout"],
    )
    lora_config = LoraConfig(r=cfg["lora_r"], lora_alpha=cfg["lora_alpha"], target_modules=["qkv", "proj"], lora_dropout=cfg["lora_dropout"], bias="none")
    model.backbone = get_peft_model(model.backbone, lora_config)
    model.backbone.print_trainable_parameters()
    # Move to DEVICE only after wrapping, so the freshly created LoRA weights move too.
    model.to(DEVICE)

    # Only parameters that require grad (LoRA adapters + projection head) are optimised,
    # matching the phase-2 optimizer.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.AdamW(trainable_params, lr=cfg["lr_contrastive"], weight_decay=cfg["wd_contrastive"])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["epochs_contrastive"])
    loss_fn = NTXentLoss(temperature=cfg["temperature"], device=DEVICE)
    # Mixed precision only on CUDA; with enabled=False the scaler passes gradients through unchanged.
    scaler = torch.amp.GradScaler(enabled=(DEVICE.type == 'cuda'))

    wandb.init(project="ViT-CheXpert-Pipeline", name="Phase1-Contrastive", config=cfg)
    try:
        for epoch in range(cfg["epochs_contrastive"]):
            model.train()
            total_loss = 0.0
            pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{cfg['epochs_contrastive']}", unit="batch")

            for view1, view2 in pbar:
                view1, view2 = view1.to(DEVICE), view2.to(DEVICE)

                optimizer.zero_grad()
                with torch.autocast(device_type=DEVICE.type, enabled=(DEVICE.type == 'cuda')):
                    proj1, proj2 = model(view1), model(view2)
                    loss = loss_fn(proj1, proj2)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                batch_loss = loss.item()
                total_loss += batch_loss
                pbar.set_postfix({"Loss": batch_loss})

            # Read the LR actually used during this epoch *before* the scheduler advances it.
            epoch_lr = scheduler.get_last_lr()[0]
            scheduler.step()
            avg_loss = total_loss / len(train_loader)
            print(f"Epoch {epoch+1}, Avg Contrastive Loss: {avg_loss:.4f}, LR: {epoch_lr:.6f}")
            wandb.log({"epoch": epoch, "contrastive_loss": avg_loss, "lr_contrastive": epoch_lr})

        torch.save(model.backbone.state_dict(), save_path)
        print(f"Phase 1 finished. Pre-trained backbone saved to {save_path}")
    finally:
        # Close the W&B run even if training is interrupted, so it does not stay open in the kernel.
        wandb.finish()


def _evaluate_auroc(model, loader, metric):
    """Run ``model`` over ``loader`` and return ``metric.compute()`` (metric is reset first)."""
    model.eval()
    metric.reset()
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            with torch.autocast(device_type=DEVICE.type, enabled=(DEVICE.type == 'cuda')):
                logits = model(images)
            # Sigmoid in float32: under CUDA autocast the logits are half precision, and
            # half-precision probabilities saturate near 1, creating artificial AUROC ties.
            metric.update(torch.sigmoid(logits.float()), labels.long())
    return metric.compute()


def supervised_finetune(cfg, pretrained_path, best_model_path="./vit-finetuned-best.pt"):
    """PHASE 2: supervised multi-label fine-tuning from the phase-1 backbone.

    Steps:
    - Build ``FineTuningViT`` (LoRA backbone + linear head with one logit per entry of
      the global ``class_names``) and load the backbone from ``pretrained_path``.
    - Train it with BCE-with-logits on ``train_df``, using class-balanced sampling.
    - Evaluate per-class AUROC on ``val_df`` every epoch and log to Weights & Biases.
    - Save the model whenever the mean validation AUROC improves, and stop early after
      ``early_stopping_patience`` epochs without improvement.

    Args:
        cfg: configuration dict. Reads ``root_dir``, ``num_workers``, ``vit_dropout``,
            the ``lora_*`` keys, the ``*_finetune`` keys and ``early_stopping_patience``.
        pretrained_path: backbone state dict written by ``contrastive_train``.
        best_model_path: where the best full-model ``state_dict`` is saved.
    """
    print(f"Starting Phase 2: Supervised Fine-tuning on {DEVICE}")
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    all_transforms = get_transforms()
    train_dataset = SupervisedDataset(train_df, cfg["root_dir"], all_transforms['supervised'])
    val_dataset = SupervisedDataset(val_df, cfg["root_dir"], all_transforms['validation'])

    # Class-balanced sampling (see get_weighted_sampler) takes the place of shuffling.
    sampler = get_weighted_sampler(train_df)
    train_loader = DataLoader(train_dataset, batch_size=cfg["batch_size_finetune"], sampler=sampler, num_workers=cfg["num_workers"], pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=cfg["batch_size_finetune"], shuffle=False, num_workers=cfg["num_workers"], pin_memory=True)

    # Head width and metric size follow the label columns instead of a hard-coded 14.
    num_classes = len(class_names)
    lora_config = LoraConfig(r=cfg["lora_r"], lora_alpha=cfg["lora_alpha"], target_modules=["qkv", "proj"], lora_dropout=cfg["lora_dropout"], bias="none")
    model = FineTuningViT(lora_config, num_classes=num_classes, drop_rate=cfg["vit_dropout"])
    model.load_from_pretrained(pretrained_path)
    model.to(DEVICE)

    # Only parameters that require grad (LoRA adapters + classifier) are optimised.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.AdamW(trainable_params, lr=cfg["lr_finetune"], weight_decay=cfg["wd_finetune"])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["epochs_finetune"])
    loss_fn = nn.BCEWithLogitsLoss()
    metric = torchmetrics.AUROC(task="multilabel", num_labels=num_classes, average=None).to(DEVICE)
    scaler = torch.amp.GradScaler(enabled=(DEVICE.type == 'cuda'))

    wandb.init(project="ViT-CheXpert-Pipeline", name="Phase2-Finetune", config=cfg)
    try:
        best_val_auc = float("-inf")
        epochs_no_improve = 0
        for epoch in range(cfg["epochs_finetune"]):
            model.train()
            total_loss = 0.0
            pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{cfg['epochs_finetune']}", unit="batch")

            for images, labels in pbar:
                images, labels = images.to(DEVICE), labels.to(DEVICE)

                optimizer.zero_grad()
                with torch.autocast(device_type=DEVICE.type, enabled=(DEVICE.type == 'cuda')):
                    outputs = model(images)
                    loss = loss_fn(outputs, labels)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                batch_loss = loss.item()
                total_loss += batch_loss
                pbar.set_postfix({"Loss": batch_loss})

            avg_train_loss = total_loss / len(train_loader)

            val_aucs = _evaluate_auroc(model, val_loader, metric)
            # nanmean skips classes whose AUROC is NaN.
            mean_val_auc = torch.nanmean(val_aucs).item()

            # Read the LR actually used during this epoch *before* the scheduler advances it.
            epoch_lr = scheduler.get_last_lr()[0]
            scheduler.step()
            print(f"Epoch {epoch+1}, Train Loss: {avg_train_loss:.4f}, Val mAUROC: {mean_val_auc:.4f}, LR: {epoch_lr:.6f}")

            log_dict = {
                "epoch": epoch,
                "train_loss": avg_train_loss,
                "val_mAUC": mean_val_auc,
                "lr_finetune": epoch_lr,
            }
            for class_name, class_auc in zip(class_names, val_aucs.tolist()):
                log_dict[f"val_auc_{class_name}"] = class_auc
            wandb.log(log_dict)

            if mean_val_auc > best_val_auc:
                best_val_auc = mean_val_auc
                epochs_no_improve = 0
                torch.save(model.state_dict(), best_model_path)
                print(f"New best model saved with mAUROC: {best_val_auc:.4f}")
            else:
                epochs_no_improve += 1

            if epochs_no_improve >= cfg["early_stopping_patience"]:
                print(f"Early stopping triggered after {epochs_no_improve} epochs with no improvement.")
                break

        print("Phase 2 finished.")
    finally:
        # Close the W&B run even if training is interrupted, so it does not stay open in the kernel.
        wandb.finish()

In [None]:
# Phase 1: contrastive pre-training; writes the LoRA backbone checkpoint to PRETRAINED_BACKBONE_PATH.
contrastive_train(cfg=config, save_path=PRETRAINED_BACKBONE_PATH)
# Phase 2: supervised fine-tuning, initialised from that checkpoint.
supervised_finetune(cfg=config, pretrained_path=PRETRAINED_BACKBONE_PATH)