In [None]:
import sys
import os

project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.insert(0, project_root)


In [2]:
import logging
import os
import random
from datetime import datetime

import numpy as np
import torch
from torch import optim
from torch.cuda.amp import GradScaler

from open_clip import get_tokenizer
from open_clip.transform import image_transform_v2, PreprocessCfg
import webdataset as wds
from torch.utils.data import DataLoader
import torchvision.transforms as T
import torch.nn.functional as F
from torchvision.datasets import Imagenette, ImageFolder, DTD
import torch.nn as nn

from open_clip import create_model_from_pretrained, get_tokenizer
from src.clipn import CLIPNAdapter, AltCLIPNAdapter

In [3]:
model, preprocess = create_model_from_pretrained('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
tokenizer = get_tokenizer('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
checkpoint = torch.load("/repo/ALTMEDCLIPN/checkpoints/PMC_model_cosine/latest_checkpoint.pth", weights_only=False)["model_state_dict"]

model = AltCLIPNAdapter(
    model,
    tokenizer,
    "text.transformer.encoder"
)

model.load_state_dict(checkpoint)

<All keys matched successfully>

In [4]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from datasets import load_dataset
import pandas as pd
from PIL import Image, ImageFile

# Allows loading of truncated images
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Load the test split of the dataset
full_dataset = load_dataset(
    "alkzar90/NIH-Chest-X-ray-dataset", 
    "image-classification", 
    # cache_dir="/repo/MEDCLIPN/data/NIH_chest", 
    split="test"
)

print("Dataset loaded successfully:")
print(full_dataset)

Dataset loaded successfully:
Dataset({
    features: ['image', 'labels'],
    num_rows: 23828
})


In [5]:
# Get the names of the pathology labels from the dataset's features
pathology_names = full_dataset.features['labels'].feature.names

# Convert the dataset to a pandas DataFrame
df = full_dataset.to_pandas()

# CORRECTED LOGIC:
# Create a binary column for each pathology.
# A row gets a '1' if the pathology's index is in the 'labels' list, and '0' otherwise.
for i, name in enumerate(pathology_names):
    df[name] = df['labels'].apply(lambda label_indices: 1 if i in label_indices else 0)

# Drop the original 'labels' column as it's no longer needed
df = df.drop(columns=['labels'])

print("DataFrame created with correct pathology columns:")
print(df.head())

DataFrame created with correct pathology columns:
                                               image  No Finding  Atelectasis  \
0  {'bytes': None, 'path': '/root/.cache/huggingf...           0            0   
1  {'bytes': None, 'path': '/root/.cache/huggingf...           0            0   
2  {'bytes': None, 'path': '/root/.cache/huggingf...           0            0   
3  {'bytes': None, 'path': '/root/.cache/huggingf...           0            0   
4  {'bytes': None, 'path': '/root/.cache/huggingf...           0            0   

   Cardiomegaly  Effusion  Infiltration  Mass  Nodule  Pneumonia  \
0             0         0             0     0       0          0   
1             0         0             0     0       0          0   
2             0         0             0     0       0          0   
3             0         0             1     0       0          0   
4             0         0             0     0       0          0   

   Pneumothorax  Consolidation  Edema  Emphysema  Fibr

In [6]:
# Use the pathology names from the dataset, excluding 'No Finding' for the split
NIH_PATHOLOGIES = pathology_names

# First 7 pathologies for in-distribution (ID)
ID_PATHOLOGIES = NIH_PATHOLOGIES[:10]
# Remaining pathologies for out-of-distribution (OOD)
OOD_PATHOLOGIES = NIH_PATHOLOGIES[10:]

print("In-Distribution Pathologies:")
print(ID_PATHOLOGIES)
print("\nOut-of-Distribution Pathologies:")
print(OOD_PATHOLOGIES)


In-Distribution Pathologies:
['No Finding', 'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration', 'Mass', 'Nodule', 'Pneumonia', 'Pneumothorax', 'Consolidation']

Out-of-Distribution Pathologies:
['Edema', 'Emphysema', 'Fibrosis', 'Pleural_Thickening', 'Hernia']


In [7]:
# Create the in-distribution (ID) dataframe
# Condition: No OOD pathologies are present (value is not 1)
ood_clean = ~df[OOD_PATHOLOGIES].eq(1).any(axis=1)
id_df = df[ood_clean].reset_index(drop=True)

# Create the out-of-distribution (OOD) dataframe
# Condition 1: At least one OOD pathology is present (value is 1)
ood_positive = df[OOD_PATHOLOGIES].eq(1).any(axis=1)
# Condition 2: No ID pathologies are present (value is not 1)
id_clean = ~df[ID_PATHOLOGIES].eq(1).any(axis=1)
ood_df = df[ood_positive & id_clean].reset_index(drop=True)

print(f"Original test set size: {len(df)}")
print(f"In-distribution (ID) set size: {len(id_df)}")
print(f"Out-of-distribution (OOD) set size: {len(ood_df)}")


Original test set size: 23828
In-distribution (ID) set size: 20599
Out-of-distribution (OOD) set size: 1044


In [8]:
class NIHChestXrayDataset(Dataset):
    def __init__(self, dataframe, pathologies, transform=None):
        self.transform = transform
        self.pathologies = pathologies
        self.classes = pathologies
        self.image_paths = dataframe['image'].apply(lambda x: x['path']).tolist()
        labels_df = dataframe[self.pathologies]
        self.labels = torch.tensor(labels_df.values, dtype=torch.float32)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_path = self.image_paths[idx]

        try:
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            return self.__getitem__((idx + 1) % len(self))

        if self.transform:
            image = self.transform(image)

        labels = self.labels[idx]

        return image, labels


In [9]:
# Create in-distribution dataset (id_dataset)
chest_in_dataset = NIHChestXrayDataset(
    dataframe=id_df,
    pathologies=ID_PATHOLOGIES,
    transform=preprocess
)

# Create out-of-distribution dataset (ood_dataset)
chest_out_dataset = NIHChestXrayDataset(
    dataframe=ood_df,
    pathologies=OOD_PATHOLOGIES,
    transform=preprocess
)

# Create data loaders
chest_in_loader = DataLoader(
    chest_in_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=1,
    pin_memory=True
)

chest_out_loader = DataLoader(
    chest_out_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=1,
    pin_memory=True
)

print("\nDatasets and DataLoaders created successfully:")
print(f"ID Dataset: {chest_in_dataset}")
print(f"OOD Dataset: {chest_out_dataset}")


Datasets and DataLoaders created successfully:
ID Dataset: <__main__.NIHChestXrayDataset object at 0x7ff2d02b1fc0>
OOD Dataset: <__main__.NIHChestXrayDataset object at 0x7ff2d0293d30>


In [10]:
# frac_atlas_dataset = ImageFolder("/repo/MEDCLIPN/data/FracAtlas_dataset/images", transform=preprocess)
# isic_dataset = ImageFolder("/repo/MEDCLIPN/data/ISIC2020_test_dataset", transform=preprocess)

# frac_atlas_dataloader = DataLoader(
#     frac_atlas_dataset,
#     batch_size = 32,
#     shuffle = False,
#     num_workers = 1,
#     pin_memory = True
# )

# isic_dataloader = DataLoader(
#     isic_dataset,
#     batch_size = 32,
#     shuffle = False,
#     num_workers=1,
#     pin_memory = True
# )

ood_loaders = {
    "chest_xray": chest_out_loader,
    # "frac_atlas": frac_atlas_dataloader,
    # "isic2020": isic_dataloader
}

In [11]:
def merge_yes_no_feature(dataset, model, tokenizer, device):
    txt = []
    N = len(dataset.classes)
    model.to(device)
    model.eval()
    if N:
        with open("/repo/MEDCLIPN/src/prompt/med_prompt.txt") as f:
            prompt_lis = f.readlines()
        num_prom = len(prompt_lis)
    for idx in range(num_prom):
        for name in dataset.classes:
            txt.append(tokenizer(prompt_lis[idx].replace("\n", "").format(name)).unsqueeze(0))
    txt = torch.cat(txt, dim=0)
    txt = txt.reshape(num_prom, len(dataset.classes), -1)
    text_inputs = txt.to(device)
    
    text_yes_ttl = torch.zeros(len(dataset.classes), 512).to(device)
    text_no_ttl = torch.zeros(len(dataset.classes), 512).to(device)
    
    with torch.no_grad():
        for i in range(num_prom):
            text_yes_i = model.encode_text(text_inputs[i], normalize=True)
            text_no_i = model.encode_text_no(text_inputs[i])
            text_no_i = F.normalize(text_no_i, dim=-1)
            
            text_yes_ttl += text_yes_i
            text_no_ttl += text_no_i
            
    return F.normalize(text_yes_ttl, dim=-1), F.normalize(text_no_ttl, dim=-1)

class ViT_Classifier(torch.nn.Module):
    def __init__(self, image_encoder, classification_head_yes, classification_head_no):
        super().__init__()
        self.image_encoder = image_encoder
        flag = True
        self.fc_yes = nn.Parameter(classification_head_yes, requires_grad=flag)    # num_classes  num_feat_dimension
        self.fc_no = nn.Parameter(classification_head_no, requires_grad=flag)      # num_classes  num_feat_dimension
        self.scale = 100. # this is from the parameter of logit scale in CLIPN
        
    def set_frozen(self, module):
        for module_name in module.named_parameters():
            module_name[1].requires_grad = False
    def set_learnable(self, module):
        for module_name in module.named_parameters():
            module_name[1].requires_grad = True
            
    def forward(self, x):
        inputs = self.image_encoder(x)
        inputs_norm = F.normalize(inputs, dim=-1)
        fc_yes = F.normalize(self.fc_yes, dim=-1)
        fc_no = F.normalize(self.fc_no, dim=-1)
        
        logits_yes = self.scale * inputs_norm @ fc_yes.T 
        logits_no = self.scale * inputs_norm @ fc_no.T
        return logits_yes, logits_no, inputs

In [12]:
yes_fff, no_fff = merge_yes_no_feature(chest_in_dataset, model, model.tokenizer, "cuda")
clipn_classifier = ViT_Classifier(model.visual, yes_fff, no_fff)
clipn_classifier.fc_yes.requires_grad = False
clipn_classifier.fc_no.requires_grad = False

In [None]:
from tqdm import tqdm
from sklearn import metrics
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve
from scipy import interpolate

def maybe_dictionarize(batch):
    if isinstance(batch, dict):
        return batch

    if len(batch) == 2:
        batch = {'images': batch[0], 'labels': batch[1]}
    elif len(batch) == 3:
        batch = {'images': batch[0], 'labels': batch[1], 'metadata': batch[2]}
    else:
        raise ValueError(f'Unexpected number of elements: {len(batch)}')

    return batch

to_np = lambda x: x.detach().cpu().numpy()
def max_logit_score(logits):
    return to_np(torch.max(logits, -1)[0])
def msp_score(logits):
    prob = torch.softmax(logits, -1)
    return to_np(torch.max(prob, -1)[0])
def energy_score(logits):
    return to_np(torch.logsumexp(logits, -1))


def cal_all_metric(id_dataset, model, epoch, ood_dataset=None, flag=True):
    model.eval()
    
    # ADDED: Make code device-agnostic (works on CPU or GPU)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # --- In-Distribution (ID) Data Evaluation ---
    id_logits_all = []
    id_labels_all = []
    
    ind_scores = {
        "MaxLogit": [],
        "MSP": [],
        "Energy": [],
        "CTW": [],
        "ATD": []
    }

    with torch.no_grad():
        for batch in tqdm(id_dataset, desc="Evaluating on ID data"):
            batch = maybe_dictionarize(batch)
            inputs = batch["images"].to(device)
            labels = batch['labels'].to(device)
            
            logits, logits_no, _ = model(inputs)
            
            # CHANGED: Collect raw logits and labels for proper multi-label evaluation
            id_logits_all.append(logits.detach().cpu())
            id_labels_all.append(labels.detach().cpu())

            # Collect OOD detection scores for ID samples
            ind_scores["MaxLogit"].extend(max_logit_score(logits))
            ind_scores["MSP"].extend(msp_score(logits))
            ind_scores["Energy"].extend(energy_score(logits))
            
            if flag:
                idex = torch.argmax(logits, -1).unsqueeze(-1)
                yesno = torch.softmax(torch.cat([logits.unsqueeze(-1), logits_no.unsqueeze(-1)], -1), dim=-1)[:, :, 0]
                yesno_s = torch.gather(yesno, dim=1, index=idex)
                ind_scores["CTW"].extend(to_np(yesno_s))
                ind_scores["ATD"].extend(to_np((yesno * torch.softmax(logits, -1)).sum(1)))

    # CHANGED: Concatenate all batch results once for efficiency
    id_logits_all = torch.cat(id_logits_all)
    id_labels_all = torch.cat(id_labels_all)

    # CHANGED: Calculate a proper multi-label classification metric (Area Under ROC Curve)
    # This correctly evaluates the model's ability to identify each of the multiple possible pathologies.
    id_probs = torch.sigmoid(id_logits_all)
    # Using 'micro' averaging treats each sample-label pair as an individual prediction, suitable for this task.
    id_auc_micro = roc_auc_score(to_np(id_labels_all), to_np(id_probs), average='micro')
    
    id_lis_epoch = [[epoch, id_auc_micro]]
    
    # --- Out-of-Distribution (OOD) Data Evaluation ---
    ood_lis_epoch = []
    if ood_dataset:
        for name, ood_data in ood_dataset.items():
            ood_scores = {
                "MaxLogit": [],
                "MSP": [],
                "Energy": [],
                "CTW": [],
                "ATD": []
            }
            with torch.no_grad():
                for batch in tqdm(ood_data, desc=f"Evaluating on OOD data: {name}"):
                    batch = maybe_dictionarize(batch)
                    inputs = batch["images"].to(device)
                    
                    logits, logits_no, _ = model(inputs)

                    ood_scores["MaxLogit"].extend(max_logit_score(logits))
                    ood_scores["MSP"].extend(msp_score(logits))
                    ood_scores["Energy"].extend(energy_score(logits))
            
                    if flag:
                        idex = torch.argmax(logits, -1).unsqueeze(-1)
                        yesno = torch.softmax(torch.cat([logits.unsqueeze(-1), logits_no.unsqueeze(-1)], -1), dim=-1)[:, :, 0]
                        yesno_s = torch.gather(yesno, dim=1, index=idex)
                        ood_scores["CTW"].extend(to_np(yesno_s))
                        ood_scores["ATD"].extend(to_np((yesno * torch.softmax(logits, -1)).sum(1)))

            # Calculate OOD detection metrics
            score_map = {"MSP": "MSP", "MaxLogit": "MaxLogit", "Energy": "Energy", "CTW": "CTW", "ATD": "ATD"}
            for score_key, report_name in score_map.items():
                if not flag and score_key in ["CTW", "ATD"]:
                    continue
                
                auc, fpr = cal_auc_fpr(ind_scores[score_key], ood_scores[score_key])
                ood_lis_epoch.append([epoch, report_name, name, auc, fpr])

    print("\n--- Evaluation Results ---")
    print(f"Epoch {epoch} In-Distribution (ID) Micro-AUC: {id_lis_epoch[0][1]:.4f}")
    for lis in ood_lis_epoch:
        print(f"Epoch {lis[0]}, Method: {lis[1]:<10}, OOD Set: {lis[2]:<10}, AUROC: {lis[3]:.4f}, FPR@95TPR: {lis[4]:.4f}")
        
    return id_lis_epoch, ood_lis_epoch


def cal_auc_fpr(ind_conf, ood_conf):
    conf = np.concatenate((ind_conf, ood_conf))
    ind_indicator = np.concatenate((np.ones_like(ind_conf), np.zeros_like(ood_conf)))
    auroc = metrics.roc_auc_score(ind_indicator, conf)
    fpr,tpr,thresh = roc_curve(ind_indicator, conf, pos_label=1)
    fpr = float(interpolate.interp1d(tpr, fpr)(0.95))
    return auroc, fpr

def cal_fpr_recall(ind_conf, ood_conf, tpr=0.95):
    conf = np.concatenate((ind_conf, ood_conf))
    ind_indicator = np.concatenate((np.ones_like(ind_conf), np.zeros_like(ood_conf)))
    fpr,tpr,thresh = roc_curve(ind_indicator, conf, pos_label=1)
    fpr = float(interpolate.interp1d(tpr, fpr)(0.95))
    return fpr, thresh

In [14]:
res = cal_all_metric(chest_in_loader, clipn_classifier, 1, ood_loaders)

Evaluating on ID data: 100%|██████████| 644/644 [04:35<00:00,  2.33it/s]
Evaluating on OOD data: chest_xray: 100%|██████████| 33/33 [00:13<00:00,  2.38it/s]


--- Evaluation Results ---
Epoch 1 In-Distribution (ID) Micro-AUC: 0.4188
Epoch 1, Method: MSP       , OOD Set: chest_xray, AUROC: 0.4814, FPR@95TPR: 0.9483
Epoch 1, Method: MaxLogit  , OOD Set: chest_xray, AUROC: 0.5001, FPR@95TPR: 0.9454
Epoch 1, Method: Energy    , OOD Set: chest_xray, AUROC: 0.5017, FPR@95TPR: 0.9444
Epoch 1, Method: CTW       , OOD Set: chest_xray, AUROC: 0.5033, FPR@95TPR: 0.9470
Epoch 1, Method: ATD       , OOD Set: chest_xray, AUROC: 0.5134, FPR@95TPR: 0.9377





In [15]:
import pandas as pd

def create_comparison_table(results):
    df = pd.DataFrame(results, columns=['epoch', 'method', 'dataset', 'auroc', 'fpr95'])
    
    # Create multi-index table
    metrics_table = df.pivot_table(
        index='method',
        columns='dataset',
        values=['auroc', 'fpr95'],
        aggfunc='first'
    )
    
    # Sort methods in meaningful order
    method_order = ['MSP', 'MaxLogit', 'Energy', 'CTW', 'ATD']
    metrics_table = metrics_table.reindex(method_order)
    
    # Format numbers and column names
    metrics_table = metrics_table.round(2)
    metrics_table.columns = pd.MultiIndex.from_tuples([
        (f"{metric.upper()} ({dataset})" if dataset else metric.upper(),)
        for metric, dataset in metrics_table.columns
    ])
    
    # Apply styling
    styled_df = metrics_table.style \
        .set_caption("OOD Detection Performance Comparison") \
        .format("{:.2f}") \
        .set_properties(**{'text-align': 'center'}) \
        .set_table_styles([{
            'selector': 'caption',
            'props': [('font-size', '16px'), ('font-weight', 'bold')]
        }])
    
    return styled_df

create_comparison_table(res[1])

Unnamed: 0_level_0,"('AUROC (chest_xray)',)","('FPR95 (chest_xray)',)"
method,Unnamed: 1_level_1,Unnamed: 2_level_1
MSP,0.48,0.95
MaxLogit,0.5,0.95
Energy,0.5,0.94
CTW,0.5,0.95
ATD,0.51,0.94
