In [None]:
import os
import itertools
import numpy as np
import pandas as pd
from PIL import Image

import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import roc_auc_score

import timm

# The 15 target labels; this order fixes the column order of every label vector
CLASSES = [
    "No Finding",
    "Atelectasis",
    "Cardiomegaly",
    "Effusion",
    "Infiltration",
    "Mass",
    "Nodule",
    "Pneumonia",
    "Pneumothorax",
    "Consolidation",
    "Edema",
    "Emphysema",
    "Fibrosis",
    "Pleural_Thickening",
    "Hernia",
]

# Read the label table and turn each "A|B|C" findings string into a list of names
df = pd.read_csv("/student/csc490_project/shared/labels.csv")
df["label_list"] = df["Finding Labels"].map(lambda findings: findings.split("|"))

# One multi-hot row per image, columns in CLASSES order
mlb = MultiLabelBinarizer(classes=CLASSES)
labels_array = mlb.fit_transform(df["label_list"])
df["labels"] = list(labels_array)

# Split by patient (first 70% / next 10% / remainder) so that all images of a
# given patient land in exactly one subset. The seed and the in-place shuffle
# determine the split, so they must not be changed or reordered.
unique_patients = df["Patient ID"].unique()
np.random.seed(42)
np.random.shuffle(unique_patients)

n_patients = len(unique_patients)
train_end = int(0.7 * n_patients)
val_end = int(0.8 * n_patients)


def _rows_for_patients(patient_ids):
    """Return the rows of ``df`` whose "Patient ID" is in ``patient_ids``, re-indexed from 0."""
    belongs = df["Patient ID"].isin(patient_ids)
    return df[belongs].reset_index(drop=True)


train_df = _rows_for_patients(unique_patients[:train_end])
val_df = _rows_for_patients(unique_patients[train_end:val_end])
test_df = _rows_for_patients(unique_patients[val_end:])

class ChestXrayDataset(Dataset):
    """Map-style dataset pairing each chest X-ray file with its label vector.

    Args:
        df: DataFrame with an "Image Index" column (file name relative to
            ``root_dir``) and a "labels" column (one label vector per row).
        root_dir: Directory that contains the image files.
        transform: Optional callable applied to the loaded PIL image.
    """

    def __init__(self, df, root_dir, transform=None):
        self.df, self.root_dir, self.transform = df, root_dir, transform

    def __len__(self):
        """Return the number of rows (images) in the backing DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, labels)`` for the row at position ``idx``.

        The image is opened from disk and converted to single-channel
        grayscale; the labels are returned as a float tensor. When a
        transform was supplied it is applied to the image only.
        """
        row = self.df.iloc[idx]
        full_path = os.path.join(self.root_dir, row["Image Index"])
        image = Image.open(full_path).convert("L")  # single-channel grayscale
        labels = torch.tensor(row["labels"], dtype=torch.float)
        if not self.transform:
            return image, labels
        return self.transform(image), labels

# Image preprocessing pipelines
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]


def _resize_tensor_normalize():
    """Return fresh instances of the steps shared by both pipelines."""
    return [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]


# Training pipeline: 3-channel grayscale, random flip and rotation, then the shared steps
train_transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=3),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
    ]
    + _resize_tensor_normalize()
)

# Validation / test pipeline: same as above without the random augmentations
val_transform = transforms.Compose(
    [transforms.Grayscale(num_output_channels=3)] + _resize_tensor_normalize()
)

# Datasets for the three splits (the test split reuses the validation pipeline)
img_dir = "/student/csc490_project/shared/preprocessed_images/preprocessed_images"
train_dataset = ChestXrayDataset(df=train_df, root_dir=img_dir, transform=train_transform)
val_dataset = ChestXrayDataset(df=val_df, root_dir=img_dir, transform=val_transform)
test_dataset = ChestXrayDataset(df=test_df, root_dir=img_dir, transform=val_transform)


def _ordered_loader(dataset):
    """Wrap ``dataset`` in a non-shuffling DataLoader (batch size 16, 4 workers)."""
    return DataLoader(dataset, batch_size=16, shuffle=False, num_workers=4)


# Loaders iterate in dataset order so predictions line up with the DataFrame rows
train_loader = _ordered_loader(train_dataset)
val_loader = _ordered_loader(val_dataset)
test_loader = _ordered_loader(test_dataset)

# Load models and weights
# Resolve the device first so the checkpoints can be remapped while loading.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Short name -> timm architecture id. The insertion order is significant:
# the ensemble search enumerates combinations in this order.
_ARCHITECTURES = {
    'maxvit': 'maxvit_rmlp_base_rw_224.sw_in12k_ft_in1k',
    'densenet': 'densenet121',
    'coatnet': 'coatnet_2_rw_224.sw_in12k_ft_in1k',
    'vgg19': 'vgg19.tv_in1k',
    'swin': 'swin_large_patch4_window7_224',
    'convnext': 'convnext_large.fb_in22k',
}

# Every checkpoint is stored as "<dir>/no_augment_<architecture id>_model.pth".
_CHECKPOINT_DIR = '/student/csc490_project/shared/new_split_models'

models = {}
for _name, _arch in _ARCHITECTURES.items():
    # pretrained=False: the weights come from the project checkpoint below.
    model = timm.create_model(_arch, pretrained=False, num_classes=len(CLASSES))

    _checkpoint_path = os.path.join(_CHECKPOINT_DIR, f'no_augment_{_arch}_model.pth')
    # map_location="cpu" lets a checkpoint that was saved from a GPU load on a
    # CPU-only host; the model is moved to `device` right after.
    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    model.load_state_dict(torch.load(_checkpoint_path, map_location="cpu"))

    # Inference only: move to the target device and switch to eval mode.
    model.to(device)
    model.eval()
    models[_name] = model

def collect_predictions(loader, model_dict=None, run_device=None):
    """
    Collects sigmoid predictions and labels from the provided DataLoader.

    Every batch is run through every model under ``torch.no_grad()``. The
    function does not call ``.eval()`` or ``.to()`` on the models, so they
    must already be prepared by the caller.

    Args:
        loader (Iterable): Yields ``(images, labels)`` tensor batches,
            e.g. a DataLoader.
        model_dict (dict, optional): Model name -> model. Defaults to the
            module-level ``models`` dictionary.
        run_device (torch.device, optional): Device the image batches are
            moved to. Defaults to the module-level ``device``.

    Returns:
        Tuple[Dict[str, np.ndarray], np.ndarray]:
            - Dictionary of model name to predictions, concatenated over
              all batches in iteration order.
            - Numpy array of true labels, concatenated the same way.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    if model_dict is None:
        model_dict = models
    if run_device is None:
        run_device = device

    preds_by_model = {name: [] for name in model_dict}
    label_batches = []
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(run_device)
            for name, net in model_dict.items():
                probs = torch.sigmoid(net(images))
                preds_by_model[name].append(probs.cpu().numpy())
            # Labels are only needed on the host, so they are never sent to
            # the device; .cpu() is a no-op unless the loader yields them there.
            label_batches.append(labels.cpu().numpy())

    if not label_batches:
        raise ValueError("loader produced no batches; nothing to predict on.")

    all_preds = {name: np.concatenate(chunks) for name, chunks in preds_by_model.items()}
    all_labels = np.concatenate(label_batches)
    return all_preds, all_labels

# Get predictions
# NOTE(review): train_preds / train_labels are not used by the ensemble search
# below; they are still computed so the names remain available afterwards.
train_preds, train_labels = collect_predictions(train_loader)
val_preds, val_labels = collect_predictions(val_loader)
test_preds, test_labels = collect_predictions(test_loader)


def _mean_auroc(y_true, y_score):
    """Return the unweighted mean of the per-class AUROCs (one class per column)."""
    return np.mean([
        roc_auc_score(y_true[:, c], y_score[:, c]) for c in range(y_true.shape[1])
    ])


# Evaluate equal-weighted ensemble combinations
results = []

for r in range(2, len(models) + 1):
    for combination in itertools.combinations(models.keys(), r):
        # Equal weights that sum to 1 for this ensemble size
        weights = np.ones(len(combination)) / len(combination)

        combined_val = sum(w * val_preds[name] for w, name in zip(weights, combination))
        combined_test = sum(w * test_preds[name] for w, name in zip(weights, combination))

        val_auroc = _mean_auroc(val_labels, combined_val)
        test_auroc = _mean_auroc(test_labels, combined_test)

        results.append({
            'combination': combination,
            'weights': weights,
            'val_auroc': val_auroc,
            'test_auroc': test_auroc
        })

# Sort combinations by validation AUROC. Ranking (and picking the best
# ensemble) on the test split would make the reported test AUROC optimistically
# biased, so the test score is only reported, never used for selection.
results = sorted(results, key=lambda x: x['val_auroc'], reverse=True)

# Display results
print("\nAll Combinations (Equal Weights Only):")
for res in results:
    print(f"{res['combination']}: Test AUROC = {res['test_auroc']:.4f}, Val AUROC = {res['val_auroc']:.4f}")

# Print best result (best by validation AUROC; its test AUROC is the held-out estimate)
best = results[0]
print(f"\nBest Combination: {best['combination']}")
print(f"Best Weights (equal): {best['weights']}")
print(f"Best Test AUROC: {best['test_auroc']:.4f}")
print(f"Best Val AUROC: {best['val_auroc']:.4f}")


All Combinations (Equal Weights Only):
('maxvit', 'densenet', 'coatnet', 'swin', 'convnext'): Test AUROC = 0.8562, Val AUROC = 0.8435
('maxvit', 'densenet', 'coatnet', 'convnext'): Test AUROC = 0.8557, Val AUROC = 0.8427
('maxvit', 'densenet', 'coatnet', 'vgg19', 'swin', 'convnext'): Test AUROC = 0.8555, Val AUROC = 0.8434
('maxvit', 'densenet', 'coatnet', 'vgg19', 'convnext'): Test AUROC = 0.8551, Val AUROC = 0.8428
('maxvit', 'coatnet', 'swin', 'convnext'): Test AUROC = 0.8550, Val AUROC = 0.8416
('maxvit', 'densenet', 'swin', 'convnext'): Test AUROC = 0.8550, Val AUROC = 0.8421
('maxvit', 'coatnet', 'vgg19', 'swin', 'convnext'): Test AUROC = 0.8544, Val AUROC = 0.8419
('maxvit', 'densenet', 'coatnet', 'swin'): Test AUROC = 0.8544, Val AUROC = 0.8420
('maxvit', 'densenet', 'vgg19', 'swin', 'convnext'): Test AUROC = 0.8542, Val AUROC = 0.8426
('maxvit', 'densenet', 'convnext'): Test AUROC = 0.8541, Val AUROC = 0.8407
('maxvit', 'coatnet', 'convnext'): Test AUROC = 0.8538, Val AUROC =

In [None]:
import numpy as np
import torch
from sklearn.metrics import roc_auc_score
from torch.utils.data import DataLoader
from torchvision import transforms

# ----------------------------------------------------
# Define Test-Time Augmentation (TTA) Transforms
# ----------------------------------------------------
# One pipeline per augmentation strategy applied at inference time


def _tta_pipeline(*augmentations):
    """Build one TTA pipeline.

    Steps, in order: 3-channel grayscale, resize to 224x224, the given
    augmentation(s) (possibly none), tensor conversion, normalization.
    """
    return transforms.Compose([
        transforms.Grayscale(num_output_channels=3),
        transforms.Resize((224, 224)),
        *augmentations,
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ])


tta_transforms = [
    # Always flip horizontally (p=1.0)
    _tta_pipeline(transforms.RandomHorizontalFlip(p=1.0)),
    # Slight brightness / contrast variation
    _tta_pipeline(transforms.ColorJitter(brightness=0.1, contrast=0.1)),
    # Minor translation, no rotation
    _tta_pipeline(transforms.RandomAffine(degrees=0, translate=(0.02, 0.02))),
    # No augmentation
    _tta_pipeline(),
]

# ----------------------------------------------------
# Function to Collect TTA Predictions from Models
# ----------------------------------------------------
def collect_tta_predictions(df, root_dir, models, device, tta_transforms, batch_size=16, num_workers=4):
    """
    Collect predictions from multiple models using Test-Time Augmentation (TTA).

    For every transform a fresh, non-shuffling DataLoader over
    ``ChestXrayDataset(df, root_dir, transform)`` is built, every model is run
    on every batch under ``torch.no_grad()``, and the per-transform sigmoid
    outputs are averaged. The function does not call ``.eval()`` or ``.to()``
    on the models, so they must already be prepared by the caller.

    Args:
        df (pd.DataFrame): DataFrame containing image paths and labels.
        root_dir (str): Path to image directory.
        models (dict): Dictionary of model name to model object.
        device (torch.device): Device to run inference on (e.g., 'cuda' or 'cpu').
        tta_transforms (list): List of torchvision transforms for TTA.
        batch_size (int): Batch size for DataLoader.
        num_workers (int): Number of DataLoader worker processes.

    Returns:
        all_preds_avg (dict): Dictionary of model name to average TTA predictions [num_samples x num_classes].
        all_labels (np.ndarray): Ground truth labels [num_samples x num_classes].

    Raises:
        ValueError: If ``tta_transforms`` is empty, or if the labels read
            under one transform differ from those read under the first.
    """
    if len(tta_transforms) == 0:
        raise ValueError("tta_transforms must contain at least one transform.")

    # Per model: one [num_samples x num_classes] array per TTA transform
    runs_by_model = {name: [] for name in models}
    # Labels from the first transform; later passes are checked against them
    all_labels = None

    for tform in tta_transforms:
        # shuffle=False keeps the sample order identical across transforms,
        # which is what makes the element-wise averaging below valid.
        loader = DataLoader(
            ChestXrayDataset(df, root_dir, tform),
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers
        )

        batch_preds = {name: [] for name in models}
        batch_labels = []

        with torch.no_grad():
            for images, labels in loader:
                images = images.to(device)

                for name, model in models.items():
                    probs = torch.sigmoid(model(images)).cpu().numpy()
                    batch_preds[name].append(probs)

                # Labels are only used on the host, so they are never sent
                # to the device.
                batch_labels.append(labels.cpu().numpy())

        for name in models:
            runs_by_model[name].append(np.concatenate(batch_preds[name], axis=0))

        labels_this_tta = np.concatenate(batch_labels, axis=0)
        if all_labels is None:
            all_labels = labels_this_tta
        elif not np.array_equal(labels_this_tta, all_labels):
            # np.array_equal is also False on a shape mismatch.
            raise ValueError("Label mismatch across TTAs — check data consistency.")

    # Stack to [num_ttas, num_samples, num_classes] and average over the TTA axis
    all_preds_avg = {
        name: np.mean(np.stack(runs, axis=0), axis=0)
        for name, runs in runs_by_model.items()
    }

    return all_preds_avg, all_labels

# ----------------------------------------------------
# Run TTA-Based Ensemble Inference
# ----------------------------------------------------

# List of best performing models to use for ensemble
best_models = ['maxvit', 'densenet', 'coatnet', 'swin', 'convnext']

# Assign equal weights for averaging predictions. The weight is derived from
# the list length so the weights still sum to 1 if `best_models` is edited.
equal_weights = np.ones(len(best_models)) / len(best_models)

# Select only the models specified in `best_models`
selected_models = {name: models[name] for name in best_models}

# Get averaged TTA predictions and labels
test_preds_tta, test_labels_tta = collect_tta_predictions(
    df=test_df,
    root_dir=img_dir,
    models=selected_models,
    device=device,
    tta_transforms=tta_transforms
)

# Compute final ensemble predictions by weighted average
ensemble_test_preds_tta = sum(
    w * test_preds_tta[name] for w, name in zip(equal_weights, best_models)
)

# Compute AUROC for each class (one label column per class) and overall mean
per_class_aurocs_tta = [
    roc_auc_score(test_labels_tta[:, i], ensemble_test_preds_tta[:, i])
    for i in range(test_labels_tta.shape[1])
]
mean_test_auroc_tta = np.mean(per_class_aurocs_tta)

# ----------------------------------------------------
# Print TTA Evaluation Results
# ----------------------------------------------------
print("\nTTA-Based Ensemble Evaluation")
print(f"Used Models: {best_models}")
print(f"Equal Weights: {equal_weights}")
print(f"\nMean Test AUROC (TTA): {mean_test_auroc_tta:.4f}")
print("\nPer-Class Test AUROC (TTA):")
# Label columns follow CLASSES order, so the two sequences pair up directly
for cls, auc in zip(CLASSES, per_class_aurocs_tta):
    print(f"{cls}: {auc:.4f}")


TTA-Based Ensemble Evaluation
Used Models: ['maxvit', 'densenet', 'coatnet', 'swin', 'convnext']
Equal Weights: [0.2 0.2 0.2 0.2 0.2]

Mean Test AUROC (TTA): 0.8571

Per-Class Test AUROC (TTA):
No Finding: 0.8018
Atelectasis: 0.8369
Cardiomegaly: 0.9168
Effusion: 0.8953
Infiltration: 0.7375
Mass: 0.8829
Nodule: 0.8105
Pneumonia: 0.7887
Pneumothorax: 0.8962
Consolidation: 0.8200
Edema: 0.9157
Emphysema: 0.9445
Fibrosis: 0.8474
Pleural_Thickening: 0.8447
Hernia: 0.9185
