In [None]:
import copy
import os

import albumentations as A
import numpy as np
import pandas as pd
import timm
import torch
import torchmetrics
import wandb
from albumentations.pytorch import ToTensorV2
from peft import LoraConfig, get_peft_model
from PIL import Image
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from tqdm.notebook import tqdm
# Authenticate with Weights & Biases without storing the API key in the notebook:
# the key is read from the WANDB_API_KEY environment variable. When the variable is
# unset, key=None is passed and wandb is left to resolve credentials on its own.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

In [None]:
# Hyperparameters and paths consumed by the rest of the notebook.
config = dict(
    # Data loading
    root_dir='',
    image_size=224,
    num_workers=2,
    batch_size=32,
    # Optimisation schedule
    learning_rate=1e-4,
    weight_decay=1e-5,
    epochs=15,
    early_stopping_patience=3,
    # Model & LoRA Config
    vit_dropout=0.1,
    lora_r=64,
    lora_alpha=256,
    lora_dropout=0.1,
)

# Prefer the GPU when one is visible to torch, otherwise stay on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")

In [None]:
# Load the train/validation label tables. The first CSV column becomes the index
# (it is later joined onto the dataset root to form image paths); the remaining
# columns are the label columns.
train_df = pd.read_csv('u1_train.csv', index_col=0)
val_df = pd.read_csv('u1_val.csv', index_col=0)

# Rewrite the dataset folder name stored in the paths. regex=False forces a literal
# match: the old name contains '.', which acts as a wildcard whenever pandas treats
# the pattern as a regular expression (the default in pandas releases before 2.0).
train_df.index = train_df.index.str.replace('CheXpert-v1.0-small', 'chexpert', regex=False)
val_df.index = val_df.index.str.replace('CheXpert-v1.0-small', 'chexpert', regex=False)


# Label names in column order; used later to name the per-class validation metrics.
class_names = train_df.columns.tolist()

print(f"Training data shape: {train_df.shape}")
print(f"Validation data shape: {val_df.shape}")

#Data augmentations
def get_transforms(image_size):
    """Build the albumentations pipelines for the training and validation phases.

    Args:
        image_size: Side length handed to ``A.Resize`` as both height and width.

    Returns:
        A dict with the keys ``'train'`` and ``'validation'``, each mapping to an
        ``A.Compose`` pipeline that ends with resize, normalisation and
        ``ToTensorV2``. Only the training pipeline applies random affine
        augmentations first.
    """
    # A single Normalize instance is shared by both pipelines.
    normalise = A.Normalize(mean=[0.506] * 3, std=[0.287] * 3)

    def closing_steps():
        # Common tail of both pipelines: resize -> normalise -> tensor.
        return [A.Resize(image_size, image_size), normalise, ToTensorV2()]

    # Pick either a rotation or a shear (never both) for half of the samples.
    rotate_or_shear = A.OneOf(
        [
            A.Affine(rotate=(-20, 20), p=0.5),
            A.Affine(shear=(-5, 5), p=0.5),
        ],
        p=0.5,
    )

    train_steps = [
        A.Affine(scale=(0.95, 1.05), p=0.5),
        rotate_or_shear,
        A.Affine(translate_percent=(-0.05, 0.05), p=0.5),
    ]
    train_steps.extend(closing_steps())

    pipelines = {}
    pipelines['train'] = A.Compose(train_steps)
    pipelines['validation'] = A.Compose(closing_steps())
    return pipelines

In [None]:
class CheXpertDataset(Dataset):
    """Map-style dataset pairing every image path of a label table with its label row.

    Each call to __getitem__ returns an augmented image and its corresponding label.

    Args:
        data_frame: Table whose index holds image paths (relative to ``root_dir``)
            and whose values are the numeric labels, one column per label.
        root_dir: Directory prepended to every path in the index.
        transform: Callable invoked as ``transform(image=array)`` that returns a
            mapping with the transformed image under the key ``'image'``.
        image_size: Optional side length of the placeholder image returned when a
            file cannot be read. When omitted, the notebook-level
            ``config["image_size"]`` is looked up at that moment.
    """
    def __init__(self, data_frame, root_dir, transform, image_size=None):
        self.img_paths = [os.path.join(root_dir, path) for path in data_frame.index]
        self.labels = torch.tensor(data_frame.values, dtype=torch.float32)
        self.transform = transform
        self.image_size = image_size

    def __len__(self):
        """Return the number of samples (one per row of the label table)."""
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label_tensor)`` for sample ``idx``.

        An unreadable or missing image does not abort the epoch: a warning is
        printed and an all-zero image with an all-zero label vector is returned.
        """
        path = self.img_paths[idx]
        try:
            # The context manager closes the file handle as soon as the pixels
            # have been copied into the array.
            with Image.open(path) as img:
                image = np.array(img.convert("RGB"))
        except OSError:
            # OSError covers IOError (its alias) and FileNotFoundError (a subclass).
            print(f"Warning: Could not load image at {path}. Returning zeros.")
            size = self.image_size if self.image_size is not None else config["image_size"]
            # The placeholder label follows the width of the label table rather
            # than a hard-coded class count.
            return torch.zeros((3, size, size)), torch.zeros_like(self.labels[idx])

        # Kept outside the try block so that failures inside the augmentation
        # pipeline surface instead of being reported as a missing image.
        image_tensor = self.transform(image=image)['image']
        return image_tensor, self.labels[idx]


def get_weighted_sampler(data_frame):
    """
    Creates a WeightedRandomSampler to handle class imbalance.
    It gives more weight to samples from under-represented classes.

    Each label column gets the weight ``1 / (column sum)``; a sample's weight is
    the dot product of its label row with those class weights.

    Args:
        data_frame: Label table with one row per sample and one numeric column
            per class.

    Returns:
        A ``WeightedRandomSampler`` that draws ``len(data_frame)`` indices with
        replacement. Rows whose labels are all zero receive weight 0.
    """
    positives_per_class = data_frame.sum(axis=0)
    # A class without any positives would give 1/0 = inf, and the 0 * inf products
    # in the dot product below would turn every sample weight into NaN. Mask such
    # classes out and give them weight 0 instead.
    class_weights = (
        (1.0 / positives_per_class.where(positives_per_class > 0)).fillna(0.0).values
    )
    sample_weights = data_frame.dot(class_weights)
    return WeightedRandomSampler(
        weights=torch.tensor(sample_weights.values, dtype=torch.float),
        num_samples=len(sample_weights),
        replacement=True
    )

In [None]:
def _build_lora_vit(cfg, num_classes):
    """Create the pretrained ViT-B/16 classifier and wrap it with LoRA adapters."""
    backbone = timm.create_model(
        "vit_base_patch16_224",
        pretrained=True,
        num_classes=num_classes,
        drop_rate=cfg["vit_dropout"]
    )
    lora_config = LoraConfig(
        r=cfg["lora_r"],
        lora_alpha=cfg["lora_alpha"],
        target_modules=["qkv", "proj"],
        lora_dropout=cfg["lora_dropout"],
        bias="none",
        # get_peft_model freezes every parameter that is not part of an adapter.
        # The classification head is newly initialised for this task, so it has
        # to be listed here to stay trainable (and to be kept in the checkpoint).
        modules_to_save=["head"],
    )
    return get_peft_model(backbone, lora_config)


def _train_one_epoch(model, loader, optimizer, loss_fn, scaler, description):
    """Run one optimisation pass over `loader` and return the mean batch loss."""
    use_amp = DEVICE.type == 'cuda'
    model.train()
    total_loss = 0
    pbar = tqdm(loader, desc=description, unit="batch")

    for images, labels in pbar:
        images, labels = images.to(DEVICE), labels.to(DEVICE)

        optimizer.zero_grad()
        with torch.autocast(device_type=DEVICE.type, enabled=use_amp):
            outputs = model(images)
            loss = loss_fn(outputs, labels)

        # Scaled backward pass / optimizer step; the scaler is a no-op when disabled.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        batch_loss = loss.item()
        total_loss += batch_loss
        pbar.set_postfix({"Loss": batch_loss})

    return total_loss / len(loader)


def _evaluate_auroc(model, loader, metric):
    """Return the per-label AUROC tensor of `model` over `loader`."""
    use_amp = DEVICE.type == 'cuda'
    model.eval()
    metric.reset()
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            with torch.autocast(device_type=DEVICE.type, enabled=use_amp):
                outputs = torch.sigmoid(model(images))
            metric.update(outputs, labels.long())
    return metric.compute()


def _save_merged_checkpoint(model, path):
    """Save a standalone state dict with the LoRA weights folded into the base model."""
    # merge_and_unload() rewrites the model it is called on: the LoRA layers are
    # replaced by merged plain layers. Doing that to the live training model would
    # leave the optimizer holding parameters that are no longer part of the network,
    # so the merge is performed on a throwaway copy instead.
    merged_model = copy.deepcopy(model).merge_and_unload()
    torch.save(merged_model.state_dict(), path)
    del merged_model


def train_model(cfg):
    """Main function to run the training and validation loop.

    Fine-tunes a pretrained ViT-B/16 with LoRA adapters on the notebook-level
    ``train_df`` and evaluates on ``val_df`` after every epoch. The per-class and
    mean validation AUROC are logged to Weights & Biases, the best model (by mean
    AUROC) is written to ``./vit-lora-best-supervised-full.pth`` with the adapters
    merged in, and training stops early once the mean AUROC has not improved for
    ``cfg["early_stopping_patience"]`` consecutive epochs.

    Args:
        cfg: Mapping providing ``image_size``, ``root_dir``, ``batch_size``,
            ``num_workers``, ``vit_dropout``, ``lora_r``, ``lora_alpha``,
            ``lora_dropout``, ``learning_rate``, ``weight_decay``, ``epochs`` and
            ``early_stopping_patience``.

    Returns:
        None.
    """
    print(f"Starting Supervised Training on {DEVICE}")
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # One output / one AUROC per label column instead of a hard-coded class count.
    num_classes = len(class_names)

    transforms = get_transforms(cfg['image_size'])
    train_dataset = CheXpertDataset(train_df, cfg["root_dir"], transforms['train'])
    val_dataset = CheXpertDataset(val_df, cfg["root_dir"], transforms['validation'])

    sampler = get_weighted_sampler(train_df)
    train_loader = DataLoader(train_dataset, batch_size=cfg["batch_size"], sampler=sampler, num_workers=cfg["num_workers"], pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=cfg["batch_size"], shuffle=False, num_workers=cfg["num_workers"], pin_memory=True)

    model = _build_lora_vit(cfg, num_classes)
    model.to(DEVICE)
    model.print_trainable_parameters()

    # Only parameters left trainable by PEFT are handed to the optimizer.
    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=cfg["learning_rate"], weight_decay=cfg["weight_decay"])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg["epochs"])
    loss_fn = nn.BCEWithLogitsLoss()
    metric = torchmetrics.AUROC(task="multilabel", num_labels=num_classes, average=None).to(DEVICE)
    scaler = torch.amp.GradScaler(enabled=(DEVICE.type == 'cuda'))

    wandb.init(project="ViT-CheXpert-Supervised", name="Standard-ViT-LoRA-Training", config=cfg)
    best_val_auc = 0
    epochs_no_improve = 0

    try:
        for epoch in range(cfg["epochs"]):
            avg_train_loss = _train_one_epoch(
                model, train_loader, optimizer, loss_fn, scaler,
                description=f"Epoch {epoch+1}/{cfg['epochs']}",
            )

            val_aucs = _evaluate_auroc(model, val_loader, metric)
            # nanmean skips labels whose AUROC came out as NaN.
            mean_val_auc = torch.nanmean(val_aucs).item()

            scheduler.step()
            # After step(), get_last_lr() reports the rate the scheduler has just set.
            print(f"Epoch {epoch+1}, Train Loss: {avg_train_loss:.4f}, Val mAUROC: {mean_val_auc:.4f}, LR: {scheduler.get_last_lr()[0]:.6f}")

            log_dict = {
                "epoch": epoch,
                "train_loss": avg_train_loss,
                "val_mAUC": mean_val_auc,
                "learning_rate": scheduler.get_last_lr()[0]
            }
            for i, class_name in enumerate(class_names):
                log_dict[f"val_auc_{class_name}"] = val_aucs[i].item()
            wandb.log(log_dict)

            if mean_val_auc > best_val_auc:
                best_val_auc = mean_val_auc
                epochs_no_improve = 0

                # Merge the adapters into the base model and save the full state dict
                _save_merged_checkpoint(model, "./vit-lora-best-supervised-full.pth")

                print(f"New best model saved with mAUROC: {best_val_auc:.4f}")
            else:
                epochs_no_improve += 1

            if epochs_no_improve >= cfg["early_stopping_patience"]:
                print(f"Early stopping triggered after {epochs_no_improve} epochs with no improvement.")
                break

        print("Training finished.")
    finally:
        # Close the W&B run even when training is interrupted or raises.
        wandb.finish()



In [None]:
# Run the training/validation loop with the notebook-level `config` dictionary.
train_model(config)