In [None]:
import os
import pandas as pd
import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import roc_auc_score

In [None]:
# -------------------------------
# CONFIG
# -------------------------------
# Side length, in pixels, of the square images fed to the network.
IMAGE_SIZE = 224
# Mini-batch size used by both the training and validation loaders.
BATCH_SIZE = 16
# Default width of the classifier's output layer.
NUM_CLASSES = 14
# Number of full passes over the training set.
EPOCHS = 5
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Label CSVs and the directory that image paths are resolved against.
TRAIN_CSV = "data/train_cheXbert.csv"
VALID_CSV = "data/valid.csv"
IMG_FOLDER = "data"

In [None]:
# -------------------------------
# LABELS
# -------------------------------
# Names of the CSV label columns, in the order used for the model's outputs
# and for the per-column AUC computation.
CHEXPERT_LABELS = [
    "Atelectasis",
    "Cardiomegaly",
    "Consolidation",
    "Edema",
    "Enlarged Cardiomediastinum",
    "Fracture",
    "Lung Lesion",
    "Lung Opacity",
    "No Finding",
    "Pleural Effusion",
    "Pleural Other",
    "Pneumonia",
    "Pneumothorax",
    "Support Devices",
]


In [None]:
# -------------------------------
# DATASET
# -------------------------------
class CheXpertDataset(Dataset):
    """Image / multi-label dataset driven by a CSV file.

    The CSV must contain a ``Path`` column (joined onto ``img_folder`` to
    locate each image) and one column per label. Missing labels (NaN) and
    uncertain labels (-1) are both mapped to 0.

    Args:
        csv_path: CSV file read with ``pd.read_csv``.
        img_folder: Directory that every ``Path`` entry is joined onto.
        transform: Optional callable applied to the RGB PIL image.
        label_cols: Names of the label columns. Defaults to
            ``CHEXPERT_LABELS`` when omitted.

    Each item is a ``(image, labels)`` pair where ``labels`` is a 1-D
    ``torch.float32`` tensor with one entry per label column.
    """

    def __init__(self, csv_path, img_folder, transform=None, label_cols=None):
        self.df = pd.read_csv(csv_path)
        self.img_folder = img_folder
        self.transform = transform
        self.label_cols = list(CHEXPERT_LABELS if label_cols is None else label_cols)

        # Replace NaNs and uncertain labels (-1) with 0
        self.df[self.label_cols] = self.df[self.label_cols].fillna(0).replace(-1, 0)

        # Materialise paths and labels once. Pulling a single row out of the
        # mixed-dtype frame yields an object-dtype Series, which torch.tensor
        # cannot convert; a dedicated float32 array avoids that and is cheaper
        # than a per-item DataFrame lookup.
        self.paths = self.df['Path'].tolist()
        self.labels = self.df[self.label_cols].to_numpy(dtype=np.float32)

    def __len__(self):
        return len(self.df)

    def _load_image(self, img_path):
        """Read ``img_path`` as an RGB PIL image and close the file handle."""
        # convert() returns a new, fully loaded image, so it stays usable
        # after the context manager has closed the underlying file.
        with Image.open(img_path) as img:
            return img.convert("RGB")

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_folder, self.paths[idx])
        image = self._load_image(img_path)
        if self.transform is not None:
            image = self.transform(image)
        # torch.tensor copies, so the cached label array is never aliased.
        labels = torch.tensor(self.labels[idx], dtype=torch.float32)
        return image, labels

In [None]:
# -------------------------------
# MODEL
# -------------------------------
class CheXpertModel(nn.Module):
    """DenseNet-121 backbone with a linear head emitting one logit per class.

    The forward pass returns raw logits (no sigmoid is applied here).

    Args:
        num_classes: Size of the replacement classifier layer.
        pretrained: When True (the default, matching the previous behaviour)
            the backbone is initialised from torchvision's pretrained
            DenseNet-121 weights. Pass False to skip the weight download,
            e.g. when a checkpoint is loaded with ``load_state_dict``
            right afterwards.
    """

    def __init__(self, num_classes=NUM_CLASSES, pretrained=True):
        super().__init__()
        if hasattr(models, "DenseNet121_Weights"):
            # Newer torchvision: the boolean `pretrained` flag is deprecated
            # in favour of an explicit weights enum.
            weights = models.DenseNet121_Weights.IMAGENET1K_V1 if pretrained else None
            base = models.densenet121(weights=weights)
        else:
            # Older torchvision releases only understand the boolean flag.
            base = models.densenet121(pretrained=pretrained)

        # Swap the stock classifier for one sized to this task.
        in_features = base.classifier.in_features
        base.classifier = nn.Linear(in_features, num_classes)
        self.backbone = base

    def forward(self, x):
        return self.backbone(x)

In [None]:
# -------------------------------
# METRICS
# -------------------------------
def compute_auc(y_true, y_pred):
    """Return the mean per-column ROC AUC, ignoring columns that cannot be scored.

    Args:
        y_true: 2-D array of ground-truth labels, one column per class.
        y_pred: 2-D array of predicted scores with the same layout.

    Returns:
        The mean AUC over all scorable columns, or NaN when no column can
        be scored (for example when every column holds a single class).
    """
    aucs = []
    for i in range(y_true.shape[1]):
        # ROC AUC is undefined unless both classes occur in the column.
        if np.unique(y_true[:, i]).size < 2:
            continue
        try:
            aucs.append(roc_auc_score(y_true[:, i], y_pred[:, i]))
        except ValueError:
            # Best effort: skip columns sklearn rejects (e.g. NaN scores)
            # while letting unrelated errors and interrupts propagate.
            continue
    if not aucs:
        # Avoid the "Mean of empty slice" RuntimeWarning from an all-NaN mean.
        return np.nan
    return np.mean(aucs)

In [None]:
# -------------------------------
# TRAINING
# -------------------------------
def train_one_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: Network to train; switched to training mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimiser stepped once per batch.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.

    Returns:
        Sum of the per-batch loss values divided by ``len(loader)``.
    """
    model.train()
    loss_total = 0.0
    for batch_imgs, batch_targets in tqdm(loader):
        batch_imgs = batch_imgs.to(DEVICE)
        batch_targets = batch_targets.to(DEVICE)

        # Clear gradients left over from the previous step before backprop.
        optimizer.zero_grad()
        batch_loss = criterion(model(batch_imgs), batch_targets)
        batch_loss.backward()
        optimizer.step()

        loss_total = loss_total + batch_loss.item()
    return loss_total / len(loader)

def validate(model, loader):
    """Score ``model`` on ``loader`` and return the mean AUC from ``compute_auc``.

    The model is put in eval mode and run without gradient tracking. A
    sigmoid is applied to the model outputs before they are collected.
    """
    model.eval()
    pred_chunks, label_chunks = [], []
    with torch.no_grad():
        for batch_imgs, batch_labels in loader:
            logits = model(batch_imgs.to(DEVICE))
            probs = torch.sigmoid(logits)
            # Move to host memory so the batches can be stacked with NumPy.
            pred_chunks.append(probs.cpu().numpy())
            label_chunks.append(batch_labels.numpy())
    stacked_preds = np.vstack(pred_chunks)
    stacked_labels = np.vstack(label_chunks)
    return compute_auc(stacked_labels, stacked_preds)

In [None]:
# Transforms
# One preprocessing pipeline is shared by training and validation: resize to a
# square, convert to a tensor, then normalise each channel. No random
# augmentation is applied.
# NOTE(review): the mean/std values look like the usual ImageNet statistics,
# which would match the pretrained backbone — confirm.
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                            [0.229, 0.224, 0.225])
])

# Load data
# Both splits resolve image paths against IMG_FOLDER; only the training
# loader shuffles.
train_ds = CheXpertDataset(TRAIN_CSV, IMG_FOLDER, transform=transform)
valid_ds = CheXpertDataset(VALID_CSV, IMG_FOLDER, transform=transform)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
valid_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

# Model, optimizer, loss
# The model emits raw logits: BCEWithLogitsLoss consumes them directly, and
# validate() applies the sigmoid itself before computing AUC.
model = CheXpertModel().to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.BCEWithLogitsLoss()

# Training loop
# Each epoch: one training pass, one validation pass, then a checkpoint.
for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    train_loss = train_one_epoch(model, train_loader, optimizer, criterion)
    val_auc = validate(model, valid_loader)
    print(f"Train Loss: {train_loss:.4f} | Val AUC: {val_auc:.4f}")

    # Save checkpoint
    # Only the model's state_dict is written (no optimizer state), one file
    # per epoch in the current working directory.
    torch.save(model.state_dict(), f"checkpoint_epoch{epoch+1}.pth")