In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import mlflow

from sklearn.metrics import precision_recall_fscore_support, accuracy_score, roc_auc_score, average_precision_score
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import models
from torchvision.transforms import v2
from datasets import load_dataset, Image as HFImage, ClassLabel, Sequence
import cv2
import timm
from datetime import datetime
import os
from tqdm import tqdm



# Configurations

In [None]:
# Runtime configuration. `seed` fixes the global NumPy/PyTorch RNGs so the
# randomly initialised classification heads and DataLoader shuffling repeat
# across runs (cuDNN kernels on GPU may still be non-deterministic).
config = {
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "seed": 42,
}
np.random.seed(config["seed"])
torch.manual_seed(config["seed"])  # seeds the CPU and all CUDA generators

# Optimisation hyper-parameters; the whole dict is logged to MLflow per run.
training_params = {
    "batch_size": 16,
    "num_epochs": 100,
    "learning_rate": 1e-3,
    "weight_decay": 1e-4
}

# Class Blueprints

## Dataset related

In [None]:
class ApplyHE:
    def __init__(self):
        pass

    def __call__(self, img):
        img_np = np.array(img)

        # Convert to grayscale and ensure 2D if not already
        if len(img_np.shape) == 3: # Image has channels
            if img_np.shape[2] == 4: # RGBA to GRAY
                img_np = cv2.cvtColor(img_np, cv2.COLOR_RGBA2GRAY)
            elif img_np.shape[2] == 3: # RGB to GRAY
                img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
            elif img_np.shape[2] == 1: # Grayscale (H,W,1) to 2D (H,W)
                img_np = img_np.squeeze(axis=2) 
            else:
                # Handle other unexpected number of channels
                raise ValueError(f"Unsupported number of channels: {img_np.shape[2]} for image of shape {img_np.shape}")
        elif len(img_np.shape) != 2: # Not 2D (grayscale) or 3D (color/grayscale with 1 channel)
             raise ValueError(f"Unsupported image shape dimensions: {img_np.shape}")

        # Ensure the image is 8-bit unsigned integer type. 
        # np.array(PIL Image) typically returns uint8, but good to be explicit.
        if img_np.dtype != np.uint8:
            img_np = img_np.astype(np.uint8)

        eq_img = cv2.equalizeHist(img_np)

        eq_img_rgb = cv2.cvtColor(eq_img, cv2.COLOR_GRAY2RGB)
        return Image.fromarray(eq_img_rgb)

class ApplyCLAHE:
    def __init__(self, clip_limit=2.0, tile_grid_size=(8, 8)):
        self.clip_limit = clip_limit
        self.tile_grid_size = tile_grid_size

    def __call__(self, img):
        img_np = np.array(img)

        # Convert to grayscale and ensure 2D if not already
        if len(img_np.shape) == 3: # Image has channels
            if img_np.shape[2] == 4: # RGBA to GRAY
                img_np = cv2.cvtColor(img_np, cv2.COLOR_RGBA2GRAY)
            elif img_np.shape[2] == 3: # RGB to GRAY
                img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
            elif img_np.shape[2] == 1: # Grayscale (H,W,1) to 2D (H,W)
                img_np = img_np.squeeze(axis=2)
            else:
                # Handle other unexpected number of channels
                raise ValueError(f"Unsupported number of channels: {img_np.shape[2]} for image of shape {img_np.shape}")
        elif len(img_np.shape) != 2: # Not 2D (grayscale) or 3D (color/grayscale with 1 channel)
             raise ValueError(f"Unsupported image shape dimensions: {img_np.shape}")

        # Ensure the image is 8-bit unsigned integer type.
        if img_np.dtype != np.uint8:
            img_np = img_np.astype(np.uint8)

        clahe = cv2.createCLAHE(clipLimit=self.clip_limit, tileGridSize=self.tile_grid_size)
        clahe_img = clahe.apply(img_np)

        clahe_img_rgb = cv2.cvtColor(clahe_img, cv2.COLOR_GRAY2RGB)
        return Image.fromarray(clahe_img_rgb)

In [None]:
class NIHChestXrayDataset(Dataset):
    """Multi-label PyTorch view over one split of a Hugging Face dataset.

    ``__getitem__`` yields ``(image, target)``. ``target`` is a float32
    multi-hot vector of length ``num_classes`` with 1.0 at every index listed
    in the row's ``'labels'`` field.
    """

    def __init__(self, hf_dataset, transform=None, num_classes=15):
        """
        Args:
            hf_dataset: Indexable split (e.g. ``dataset['train']``) whose rows
                provide ``'image'`` and ``'labels'`` (iterable of class indices).
            transform (callable, optional): Applied to the image when set.
            num_classes (int): Length of the multi-hot target vector.
        """
        self.hf_dataset = hf_dataset
        self.transform = transform
        self.num_classes = num_classes

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        row = self.hf_dataset[idx]
        image = self.transform(row['image']) if self.transform else row['image']

        target = torch.zeros(self.num_classes, dtype=torch.float32)
        for class_idx in row['labels']:
            target[class_idx] = 1.0
        return image, target

In [119]:
class FocalLoss(nn.Module):
    """Sigmoid focal loss for multi-label classification.

    For every logit/target element the loss is
    ``alpha_t * (1 - p_t) ** gamma * BCE(logit, target)``, where
    ``p = sigmoid(logit)``, ``p_t`` is ``p`` for positive targets and ``1 - p``
    for negative ones, and ``alpha_t`` is ``alpha`` for positives and
    ``1 - alpha`` for negatives.

    Args:
        alpha: weight of positive targets (negatives get ``1 - alpha``).
        gamma: focusing exponent; ``0`` reduces to alpha-weighted BCE.
        reduction: ``'mean'``, ``'sum'`` or ``'none'``.

    Raises:
        ValueError: if ``reduction`` is not one of the supported values.
    """

    _REDUCTIONS = ('none', 'mean', 'sum')

    def __init__(self, alpha=0.75, gamma=2.0, reduction='mean'):
        super().__init__()
        # Fail at construction: an unreduced tensor would otherwise only blow up
        # later, at loss.backward(), far from the typo that caused it.
        if reduction not in self._REDUCTIONS:
            raise ValueError(f"Unsupported reduction: {reduction!r}; expected one of {self._REDUCTIONS}")
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, targets):
        """Compute the focal loss between raw ``logits`` and 0/1 ``targets`` of the same shape."""
        # The with_logits form is the numerically stable -log(p_t).
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
        p = torch.sigmoid(logits)
        p_t = p * targets + (1 - p) * (1 - targets)
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)

        focal = alpha_t * (1 - p_t).pow(self.gamma) * bce

        if self.reduction == 'mean':
            return focal.mean()
        if self.reduction == 'sum':
            return focal.sum()
        return focal

## Model related

In [99]:
class ResNet50(nn.Module):
    """torchvision ResNet-50 (``weights="DEFAULT"``) used as a frozen feature
    extractor topped by a new trainable ``fc`` layer with ``num_classes`` outputs."""

    def __init__(self, num_classes=15):
        super().__init__()
        backbone = models.resnet50(weights="DEFAULT")

        # Disable gradients on every pretrained weight first; the replacement
        # head is created afterwards, so it is the only trainable part.
        # NOTE(review): requires_grad=False does not stop BatchNorm running
        # statistics from updating while the module is in train() mode — confirm intended.
        for weight in backbone.parameters():
            weight.requires_grad_(False)

        backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
        self.model = backbone
        self.num_classes = num_classes

    def forward(self, x):
        logits = self.model(x)
        return logits
        

In [102]:
class NextVit(nn.Module):
    """timm ``nextvit_base.bd_in1k`` with all pretrained weights frozen and a
    freshly initialised, trainable ``head.fc`` producing ``num_classes`` logits."""

    def __init__(self, num_classes=15):
        super().__init__()
        net = timm.create_model("nextvit_base.bd_in1k", pretrained=True)

        # Freeze the pretrained network first; the replacement classifier built
        # below is then the only module whose parameters remain trainable.
        for weight in net.parameters():
            weight.requires_grad = False

        net.head.fc = nn.Linear(net.head.fc.in_features, num_classes)
        self.model = net

    def forward(self, x):
        logits = self.model(x)
        return logits

In [105]:
class DINOv3(nn.Module):
    """DINOv3 ViT loaded through ``torch.hub`` with a frozen trunk and a new
    trainable linear ``head`` producing ``num_classes`` logits."""

    def __init__(self, num_classes=15, hub_model="dinov3_vitb16",
                 weights="../checkpoints/dinov3_vitb16_pretrain.pth"):
        """
        Args:
            num_classes: number of output logits.
            hub_model: entry-point name in the ``facebookresearch/dinov3`` hub repo.
            weights: checkpoint location forwarded as the entry point's
                ``weights`` argument (relative paths resolve against the CWD).
        """
        super().__init__()
        self.model = torch.hub.load("facebookresearch/dinov3",
                                    hub_model,
                                    source="github",
                                    weights=weights)
        for param in self.model.parameters():
            param.requires_grad = False

        # Size the head from the backbone rather than hard-coding 768 (the
        # ViT-B width), so other variants also work; fall back to 768 if the
        # attribute is missing.
        # NOTE(review): relies on the hub model's forward() routing its CLS
        # features through `.head` — confirm for any new entry point.
        in_feat = getattr(self.model, "embed_dim", 768)
        self.model.head = nn.Linear(in_feat, num_classes)
        self.num_classes = num_classes

    def forward(self, x):
        return self.model(x)

# Processing Dataset

In [76]:
# Baseline preprocessing shared by every split (no augmentation yet):
# resize -> CLAHE (3-channel PIL output) -> float tensor in [0, 1] -> ImageNet mean/std normalisation.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def build_baseline_tfms(image_size=(224, 224), clip_limit=2.0, tile_grid_size=(8, 8)):
    """Return a fresh torchvision v2 pipeline: resize, CLAHE, tensor conversion, normalisation."""
    return v2.Compose([
        v2.Resize(image_size),
        ApplyCLAHE(clip_limit=clip_limit, tile_grid_size=tile_grid_size),
        # v2.ToTensor() is deprecated; ToImage + ToDtype(scale=True) is the
        # documented replacement and yields the same [0, 1] float values.
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])


# Separate instances so training can gain augmentation without touching evaluation.
training_tfms = build_baseline_tfms()
validation_tfms = build_baseline_tfms()




In [None]:
# Load the NIH Chest X-ray (small) splits from the Hugging Face Hub and wrap
# them as multi-label datasets. The multi-hot target length comes from the
# `labels` feature, so it always matches the dataset's class vocabulary.
dataset = load_dataset("Sohaibsoussi/NIH-Chest-X-ray-dataset-small")
class_names = dataset['train'].features['labels'].feature.names
class_to_idx = {class_name: i for i, class_name in enumerate(class_names)}
NUM_CLASSES = len(class_names)
# NOTE(review): the model classes default to 15 outputs; pass num_classes=NUM_CLASSES
# when building them if the vocabulary ever differs.

training_ds = NIHChestXrayDataset(dataset['train'], transform=training_tfms, num_classes=NUM_CLASSES)
val_ds = NIHChestXrayDataset(dataset['validation'], transform=validation_tfms, num_classes=NUM_CLASSES)
test_ds = NIHChestXrayDataset(dataset['test'], transform=validation_tfms, num_classes=NUM_CLASSES)

# Only the training loader shuffles; evaluation order stays fixed.
train_loader = DataLoader(training_ds, batch_size=training_params['batch_size'], shuffle=True)
val_loader = DataLoader(val_ds, batch_size=training_params['batch_size'], shuffle=False)
test_loader = DataLoader(test_ds, batch_size=training_params['batch_size'], shuffle=False)

# Helper functions

In [None]:
def initialize_optimizer(optimizer_name, parameters, learning_rate, weight_decay=None):
    """Build a torch optimizer by name.

    Args:
        optimizer_name: "Adam", "AdamW" or "SGD" (case-sensitive).
        parameters: iterable of parameters (or param-group dicts) to optimise.
        learning_rate: learning rate for every group.
        weight_decay: weight decay to apply. ``None`` means
            ``training_params["weight_decay"]`` for AdamW and 0 for Adam/SGD.

    Returns:
        The constructed ``torch.optim.Optimizer``.

    Raises:
        ValueError: if ``optimizer_name`` is not supported.
    """
    if optimizer_name == "Adam":
        return optim.Adam(parameters, lr=learning_rate,
                          weight_decay=0.0 if weight_decay is None else weight_decay)
    if optimizer_name == "AdamW":
        if weight_decay is None:
            weight_decay = training_params["weight_decay"]
        return optim.AdamW(parameters, lr=learning_rate, weight_decay=weight_decay)
    if optimizer_name == "SGD":
        return optim.SGD(parameters, lr=learning_rate,
                         weight_decay=0.0 if weight_decay is None else weight_decay)
    raise ValueError(f"Unsupported optimizer: {optimizer_name}")

def initialize_model(model_name, num_classes=15):
    """Build one of the notebook's backbones by (case-insensitive) name and move it to ``config['device']``.

    Args:
        model_name: "resnet", "nextvit" or "dinov3".
        num_classes: number of output logits of the new classification head.

    Raises:
        ValueError: if ``model_name`` is not recognised.
    """
    key = model_name.lower()
    if key == "resnet":
        model = ResNet50(num_classes=num_classes)
    elif key == "nextvit":
        model = NextVit(num_classes=num_classes)
    elif key == "dinov3":
        model = DINOv3(num_classes=num_classes)
    else:
        raise ValueError(f"Model {model_name} not recognized. Use 'resnet', 'nextvit' or 'dinov3'")
    return model.to(config['device'])

def initialize_criterion(loss_fn='bce', pos_weight=None):
    """Return the training loss by name.

    Args:
        loss_fn: 'bce' (``BCEWithLogitsLoss``) or 'focal loss' (``FocalLoss``).
        pos_weight: optional per-class positive weight for 'bce'. A tensor or
            sequence is accepted and converted to a float32 tensor. It is
            registered on the loss module, so move the criterion to the model's
            device if the weight lives elsewhere.

    Raises:
        ValueError: for an unknown ``loss_fn``.
    """
    if loss_fn == 'bce':
        # Explicit None check: truthiness of a multi-element tensor raises
        # "Boolean value of Tensor ... is ambiguous", and a 0-valued weight is falsy.
        if pos_weight is None:
            return nn.BCEWithLogitsLoss()
        return nn.BCEWithLogitsLoss(pos_weight=torch.as_tensor(pos_weight, dtype=torch.float32))
    if loss_fn == 'focal loss':
        return FocalLoss()
    raise ValueError("Use only 'bce' with or without pos_weight or 'focal loss' only")

In [None]:
def train_epoch(train_loader, model, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Puts ``model`` in training mode and prints the first param group's learning
    rate. For every batch it moves inputs/targets to ``device`` and performs
    zero_grad -> forward -> loss -> backward -> step.

    Returns:
        float: unweighted mean of the per-batch loss values.
    """
    model.train()
    print(f"Current LR: {optimizer.param_groups[0]['lr']:.5f}")

    batch_losses = []
    for inputs, targets in tqdm(train_loader, desc='Training'):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    return float(np.mean(batch_losses))


In [None]:
def evaluate(val_loader, model, criterion, device, threshold=0.5):
    """Evaluate a multi-label classifier on ``val_loader`` without gradient tracking.

    Args:
        val_loader: iterable of ``(features, multi_hot_labels)`` batches.
        model: module returning one logit per class.
        criterion: loss applied to ``(logits, labels)``.
        device: device the batches are moved to.
        threshold: probability above which a label counts as predicted positive.

    Returns:
        tuple ``(losses, avg_loss, metrics)``:
            losses: per-batch loss values.
            avg_loss: per-sample (batch-size weighted) mean loss.
            metrics: micro-averaged ``precision``/``recall``/``f1`` of the
                thresholded predictions, plus micro ``pr_auc`` and ``roc_auc``
                computed from the probabilities.
    """
    model.eval()
    all_probs, all_labels, losses, batch_sizes = [], [], [], []

    with torch.no_grad():
        for features, labels in tqdm(val_loader, desc='Validation'):
            features, labels = features.to(device), labels.to(device)
            outputs = model(features)

            losses.append(criterion(outputs, labels).item())
            # Weight by batch size so a short final batch does not skew the mean.
            batch_sizes.append(labels.shape[0])

            all_probs.append(torch.sigmoid(outputs).cpu().numpy())
            all_labels.append(labels.cpu().numpy())

    y_prob = np.concatenate(all_probs, axis=0)
    y_true = np.concatenate(all_labels, axis=0).astype(int)
    y_pred = (y_prob > threshold).astype(int)

    # Pass the 2-D multi-label indicator matrices directly. Flattening them and
    # asking for average='micro' makes sklearn average over both the 0 and the
    # 1 "class", which collapses precision/recall/F1 into plain element-wise
    # accuracy, dominated by the many negative labels.
    pr, rc, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average='micro', zero_division=0
    )
    pr_auc = average_precision_score(y_true, y_prob, average='micro')
    roc_auc = roc_auc_score(y_true, y_prob, average='micro')
    avg_loss = float(np.average(losses, weights=batch_sizes))
    return losses, avg_loss, {"precision": pr, "recall": rc, "f1": f1, "pr_auc": pr_auc, "roc_auc": roc_auc}

In [115]:
def early_stopping(metrics_in_epochs, gap):
    """Decide whether to stop on a lower-is-better metric history.

    Returns True once the earliest minimum sits at least ``gap`` entries from
    the end of ``metrics_in_epochs``, counting the minimum itself. That is,
    after ``gap - 1`` consecutive entries without a new minimum. Ties keep the
    earliest minimum. An empty history raises ``ValueError`` (from ``np.argmin``).
    """
    epochs_seen = len(metrics_in_epochs)
    best_epoch = int(np.argmin(metrics_in_epochs))
    return bool(epochs_seen - best_epoch >= gap)

In [None]:
def _snapshot_state_dict(model):
    """Return a detached CPU copy of ``model.state_dict()``.

    ``state_dict()`` returns references to the live parameter/buffer tensors,
    so keeping it directly would silently follow every later optimizer step.
    """
    return {name: tensor.detach().cpu().clone() for name, tensor in model.state_dict().items()}


def train_modelcv(train_loader, val_loader, model, criterion, optimizer, num_epochs, device, scheduler=None, patience=10):
    """Train ``model`` and keep the weights from the epoch with the lowest validation loss.

    Args:
        train_loader, val_loader: batch iterables for training / validation.
        model, criterion, optimizer: objects passed through to ``train_epoch``/``evaluate``.
        num_epochs: maximum number of epochs.
        device: device batches are moved to.
        scheduler: optional scheduler stepped with the validation loss each
            epoch (assumes a metric-driven scheduler such as ``ReduceLROnPlateau``).
        patience: ``gap`` handed to ``early_stopping`` on the validation-loss history.

    Returns:
        tuple ``(best_weights, best_metrics, best_epoch, train_losses, val_losses)``:
            best_weights: independent copy of the state dict at the best epoch
                (the initial weights if no epoch improved).
            best_metrics: ``evaluate`` metrics of the best epoch, or None.
            best_epoch: index of the best epoch, or -1.
            train_losses: mean training loss per epoch.
            val_losses: per epoch, the list of per-batch validation losses.
    """
    train_losses = []
    val_losses = []
    val_avg_loss = []
    lowest_loss = float('inf')
    best_metrics = None
    best_epoch = -1
    # Start from a valid snapshot so the return value is defined even if no
    # epoch improves (e.g. NaN losses) or num_epochs == 0.
    best_weights = _snapshot_state_dict(model)

    for epoch in range(num_epochs):
        print(f"Epoch {epoch}/{num_epochs-1}")
        print("-" * 10)

        train_loss = train_epoch(train_loader, model, criterion, optimizer, device)
        train_losses.append(train_loss)

        val_loss, avg_loss, metrics = evaluate(val_loader, model, criterion, device)
        val_losses.append(val_loss)

        print("Performance ROC-AUC: ", metrics['roc_auc'])
        val_avg_loss.append(avg_loss)

        if scheduler:
            scheduler.step(avg_loss)

        # Strict '<' keeps the earliest epoch on ties, matching early_stopping's argmin.
        if avg_loss < lowest_loss:
            best_weights = _snapshot_state_dict(model)
            lowest_loss = avg_loss
            best_metrics = metrics
            best_epoch = epoch
            print(f"Current Best is epoch {best_epoch} with loss: {lowest_loss}")

        if early_stopping(val_avg_loss, patience):
            print("Early stopping triggered.")
            break

    return best_weights, best_metrics, best_epoch, train_losses, val_losses

In [None]:
def plot_train_val_loss(train_loss, val_loss, save=False, show=False, path="loss_curve.png"):
    """Plot per-epoch training and validation loss curves.

    Args:
        train_loss: one entry per epoch, either a scalar or a sequence of
            per-batch values (sequences are averaged).
        val_loss: same convention as ``train_loss``.
        save: write the figure to ``path`` and return that path.
        show: display the figure.
        path: output file used when ``save`` is True.

    Returns:
        ``path`` if ``save`` else None.
    """
    # Reduce per-batch lists to one value per epoch. Plotting a list of lists
    # directly would draw one line per batch position (or fail if ragged).
    train_curve = [float(np.mean(epoch_vals)) for epoch_vals in train_loss]
    val_curve = [float(np.mean(epoch_vals)) for epoch_vals in val_loss]

    fig, ax = plt.subplots()
    ax.plot(train_curve, label='train loss')
    ax.plot(val_curve, label='val loss')
    ax.set_title("Train and Validation Loss")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.legend()

    saved_path = None
    if save:
        # Save before show(): inline/interactive backends may close the figure
        # on show(), and a later savefig would then write a blank image.
        fig.savefig(path)
        saved_path = path
    if show:
        plt.show()
    plt.close(fig)
    return saved_path

# Training

In [None]:
# Every MLflow run started below is recorded under the "CXR" experiment.
mlflow.set_experiment(experiment_name='CXR')

In [None]:
# Train each backbone (frozen pretrained trunk + new trainable head), keep the
# epoch with the lowest validation loss, then score those weights on the test split.
model_list = ['resnet', 'nextvit', 'dinov3']

for model_name in model_list:
    run_name = f"{model_name}_AdamW_{training_params['learning_rate']}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    with mlflow.start_run(run_name=run_name) as run:
        mlflow.log_params(training_params)
        mlflow.log_param("model", model_name)

        model = initialize_model(model_name)
        criterion = initialize_criterion('bce')
        optimizer = initialize_optimizer('AdamW', model.parameters(), training_params['learning_rate'])

        best_weights, best_metrics, best_epoch, train_losses, val_losses = train_modelcv(
            train_loader,
            val_loader,
            model,
            criterion,
            optimizer,
            num_epochs=training_params['num_epochs'],
            device=config['device']
        )
        # train_modelcv starts best_metrics at None and only replaces it on an
        # improving epoch; log_metrics(None) would abort the whole sweep.
        if best_metrics is not None:
            mlflow.log_metrics(best_metrics)
            mlflow.log_metric("best_epoch", best_epoch)

        # val_losses holds each epoch's per-batch losses; collapse every epoch to
        # its mean so the curve lines up with the per-epoch train_losses.
        val_epoch_losses = [float(np.mean(epoch_losses)) for epoch_losses in val_losses]

        # Save the loss plot locally, log it as an artifact, and always remove the local file.
        plot_path = plot_train_val_loss(train_losses, val_epoch_losses, save=True, show=False)
        try:
            mlflow.log_artifact(plot_path)
        finally:
            os.remove(plot_path)

        # run inference
        print("=" * 10)
        print("\n\n")
        print("===RUNNING INFERENCE ON TEST SET===")

        model.load_state_dict(best_weights)
        _, _, test_metrics = evaluate(test_loader, model, criterion, device=config['device'])
        mlflow.log_metrics({
            'test_precision': test_metrics['precision'],
            'test_recall': test_metrics['recall'],
            'test_f1': test_metrics['f1'],
            'test_prauc': test_metrics['pr_auc'],
            'test_rocauc': test_metrics['roc_auc']
        })

        print("Inference results:\n", test_metrics)

        