# Running External Validation

In [1]:
import json
import torch
import numpy as np
from timm import create_model
import sys, os, json
from sklearn.metrics import f1_score, roc_auc_score, precision_recall_curve, auc, roc_curve
from scipy.special import expit #Note: this is a stable implementation of sigmoid
from torchvision import transforms
from sklearn.metrics import confusion_matrix
import pandas as pd
from torch.utils.data import Dataset, DataLoader

from run_external_validation import load_dataset, PercentileDomainAdaptation, XrayDataset

seed_value = 9999 # setting seed for bootstrapping reproducibility

In [2]:
choiPath = '/extra/xielab0/nhchoi1/xrays/'
with open(choiPath+"grid_winners-weights.json", "r") as file:
    data = json.load(file)

print("Model results:", len(data))

Model results: 12


In [3]:
data[0]

{'grid_idx': 6,
 'name': 'convnextv2_nano',
 'weight_decay': 2.5,
 'mean': 0.8381306010457822,
 'std': 0.09031020512304658,
 'logits': 'weight_decay_large/convnextv2_nano/objects/logits/0-fold-logits.pth',
 'weights': ['weight_decay_large/convnextv2_nano/objects/checkpoints/0-fold-0-state.pth',
  'weight_decay_large/convnextv2_nano/objects/checkpoints/0-fold-1-state.pth',
  'weight_decay_large/convnextv2_nano/objects/checkpoints/0-fold-2-state.pth',
  'weight_decay_large/convnextv2_nano/objects/checkpoints/0-fold-3-state.pth',
  'weight_decay_large/convnextv2_nano/objects/checkpoints/0-fold-4-state.pth']}

In [4]:
file_paths = np.load(r"/extra/xielab0/wuat2/AryaQualityViewProjectData/ExternalValData/file_paths.npy")
labels = np.load(r"/extra/xielab0/wuat2/AryaQualityViewProjectData/ExternalValData/labels.npy")
trPercentiles = np.load(r"/extra/xielab0/wuat2/AryaQualityViewProjectData/ExternalValData/trainClamp1-99Range.npy")
config_path = r"/extra/xielab0/wuat2/AryaQualityViewProjectData/ExternalValData/ext_val_torch_loader_config.json"

with open(config_path) as f:
        cfg = json.load(f)

In [5]:
runInference = False

if not runInference: #load previously saved results if inference not rerun
    results = torch.load("external_validation_results.pt")
    print("prior external validation results loaded for analysis.")

prior external validation results loaded for analysis.


In [6]:
def adapt_batch_norm(model, data_loader, device):
    model.train()  # Update BN, but not step optimizer (no gradients)
    
    print("Adapting Batch Norm statistics to external domain...")
    
    with torch.no_grad():
        for x, _ in data_loader:
            x = x.to(device)
            _ = model(x) 
            
    print("Adaptation complete.")
    model.eval() # Switch back to eval for actual prediction
    return model

def safe_generator(iterable):
    iterator = iter(iterable)
    while True:
        try:
            yield next(iterator)
        except RuntimeError as e:
            print(f"Skipping corrupt batch/image: {e}")
            continue
        except StopIteration:
            break

def create_actual_model(model_name, num_classes=1):
    model = create_model(
        model_name,
        pretrained=False, num_classes=1
    ).to(device).eval()
    return model
        
if runInference:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")    
    results = []
    
    for pick in data:
        # Build paths dynamically based on your logic
        try:
            logits_path_str = pick["logits"]
            idx = logits_path_str.index("objects/")
            objects_path = logits_path_str[:idx] + "objects/"
        except (ValueError, KeyError):
            print("Could not find 'objects/' in path, skipping this pick.")
            continue
    
        # Load JSON settings
        try:
            with open(os.path.join(objects_path, "0-dataset_settings.json"), "r") as f:
                settings = json.load(f)
            with open(os.path.join(objects_path, "0-model_details.json"), "r") as f:
                model_details = json.load(f)
        except FileNotFoundError as e:
            print(f"Metadata file missing: {e}")
            continue
    
        # Prepare Data
        image_size = settings["image_size"][0]
        
        test_transformer = transforms.Compose([
                transforms.Resize((image_size, image_size), antialias=True),
                transforms.ConvertImageDtype(torch.float32), #float conversion
                PercentileDomainAdaptation(trPercentiles[0], trPercentiles[1]),
            ])
        
        dataset, loader = load_dataset(image_size, test_transform = test_transformer)
        
        # Initialize Model
        print(f"\n--- Loading Model: {model_details['model_name']} ---")
        model = create_actual_model(model_details["model_name"]).to(device)
        model.eval()
    
        pick["external_validation_results"] = []
    
        # Iterate through weight checkpoints
        for weight_file in pick.get("weights", []):
            full_weight_path = os.path.join(choiPath, weight_file)
            
            if not os.path.exists(full_weight_path):
                print(f"Weight file not found: {full_weight_path}")
                continue
    
            # Map location ensures weights load to the correct device
            state_dict = torch.load(full_weight_path, map_location=device)
            model.load_state_dict(state_dict)
            print(f"Loaded weights: {weight_file}")
            
            y_logits = []
            model = adapt_batch_norm(model, loader, device) #adapting batch normalization
            
            with torch.no_grad():
                for x, _ in loader:
                # for x, _ in safe_generator(loader):
                    x = x.to(device)
                    output = model(x)
                    
                    # Assuming binary/regression (output shape [batch, 1])
                    y_logits.extend(output.view(-1).cpu().numpy())
    
            # Save results for this weight set
            pick["external_validation_results"].append(np.array(y_logits))
    
        results.append(pick)
    
    # --- 5. Final Save ---
    if results:
        torch.save(results, "external_validation_results.pt")
        print("\nProcessing complete. Results saved.")
    else:
        print("\nNo results generated.")
else:
    print('runInference not enabled. Please enable if this was not intended.')

runInference not enabled. Please enable if this was not intended.


In [7]:
try:
    print(f"""Bootstrapping variable already set (value={needBootStrapping}). Are you sure you didn't already run this?
bootstrap indices are stored in 'bootStrapIdxs'.""")

except:
    needBootStrapping = True
    rng = np.random.default_rng(seed=seed_value)
    B = 5000
    bootStrapIdxs = []
    extValSize = len(labels)
    
    if needBootStrapping:
        print("""setting random seed and bootstrapping indices. For reproducibility, only run this ONCE. 
If accidentally ran again, restart kernel and try again.""")
        for _ in range(B):
            idx = np.random.randint(0, extValSize, extValSize)
            while True:
                if len(np.unique(labels[idx])) == 2:
                    break #if samples only have one label, try again
                idx = np.random.randint(0, extValSize, extValSize)
            bootStrapIdxs.append(idx)
        print("Bootstrapping indices set.")
    else:
        print("Bootstrapping not enabled. Remember, only run this ONCE in script for reproducibility.")
    needBootstrapping = False

setting random seed and bootstrapping indices. For reproducibility, only run this ONCE. 
If accidentally ran again, restart kernel and try again.
Bootstrapping indices set.


In [8]:
def ext_val_metrics(y_true, probs, target_value=0.90): # sensitivity at specificity and vice versa and PR/ROC AUCs

    fpr, tpr, roc_thresh = roc_curve(y_true, probs)
    precision, recall, pr_thresh = precision_recall_curve(y_true, probs)
    rocAUC = roc_auc_score(y_true, probs)
    prAUC = auc(recall[::-1], precision[::-1]) #need to flip precision and recall to calculate auc
    
    target_fpr = 1 - target_value
    calculated_sensitivity = np.interp(target_fpr, fpr, tpr)

    # Ensure monotonicity for finding specificity at sensitivity
    tpr_monotonic, idx = np.unique(tpr, return_index=True)
    fpr_at_tpr = fpr[idx]
    fpr_at_target = np.interp(target_value, tpr_monotonic, fpr_at_tpr)
    calculated_specificity = 1 - fpr_at_target

    return rocAUC, prAUC, calculated_sensitivity, calculated_specificity
    

In [10]:
ensembleModels = ['efficientnetv2_m', 'fastvit_ma36', 'mobilenetv4_conv_large', 'repvit_m3.native', 'resnetv2_34']
ensembleProb = np.zeros(len(labels))
ensembleIndivPreds = []
ensembleProbBoot = np.zeros((B, len(labels)))

for modelData in results:
    if modelData['name'] in ensembleModels:
        # probs = expit(modelData['external_validation_results'][0]) #NOTE: this is just looking at one fold
        probs = np.zeros(len(labels))
        for fold in range(5):
            probs = probs + modelData['external_validation_results'][fold]
        probs = expit(probs/5) #average of the five folds

        rocAUC, prAUC, sensAtSpec, specAtSens = ext_val_metrics(labels, probs, target_value=0.90)
        # print(len(modelData['external_validation_results'][0]))
        print(f"""Model name: {modelData['name']}\nInt Val ROC AUC: {modelData['mean']:.3f}
Ext Test ROC AUC: {rocAUC:.3f}
Ext Test PR AUC: {prAUC:.3f}
Ext Test sens @spec90%: {sensAtSpec:.3f}
Ext Test spec @sens90%: {specAtSens:.3f}
                ----------""")
        foldSimilarity = []
        for i in range(5): #correlation analysis between folds
            for j in range(i+1, 5):
                p_i = expit(modelData['external_validation_results'][i])
                p_j = expit(modelData['external_validation_results'][j])
                foldSimilarity.append(np.corrcoef(p_i, p_j)[0,1])
        foldSimilarity = np.array(foldSimilarity)
        print(f"Fold corr coeff -- mean: {foldSimilarity.mean():.3f}| std dev: {foldSimilarity.std():.3f}")

        ensembleProb = ensembleProb + probs
        ensembleIndivPreds.append(probs)

        metricNames = ["rocAUC","prAUC","sensAtSpec","specAtSens"]
        bootMetrics = pd.DataFrame(columns=metricNames)
        print(f"----------\nStarting Bootstrapping:")
        for bootNum in range(B):
            bootIdx = bootStrapIdxs[bootNum] #indices for this bootstrap
            bootLabels = labels[bootIdx]
            bootProbs = np.zeros(len(bootLabels))
            for fold in range(5):
                bootProbs = bootProbs + modelData['external_validation_results'][fold][bootIdx]
            bootProbs = expit(bootProbs/5) #average of the five folds
            ensembleProbBoot[bootNum] = ensembleProbBoot[bootNum] + bootProbs
            
            rocAUC, prAUC, sensAtSpec, specAtSens = ext_val_metrics(bootLabels, bootProbs, target_value=0.90)
            currMetrics = pd.DataFrame([{"rocAUC": rocAUC, "prAUC": prAUC, "sensAtSpec": sensAtSpec, "specAtSens": specAtSens}])
            bootMetrics = pd.concat([bootMetrics, currMetrics], ignore_index=True)
        # print(bootMetrics)
        CIs = pd.DataFrame(columns=metricNames)
        for metric in metricNames:
            lower = np.percentile(bootMetrics[metric], 2.5)
            upper = np.percentile(bootMetrics[metric], 97.5)
            CIs[metric] = [lower, upper]
            print(f"{metric} CIs (2.5,97.5): {lower:.3f}, {upper:.3f}\n----------")

ensembleProbBoot = ensembleProbBoot/len(ensembleModels)
ensembleProb = ensembleProb/len(ensembleModels)

Model name: efficientnetv2_m
Int Val ROC AUC: 0.953
Ext Test ROC AUC: 0.800
Ext Test PR AUC: 0.441
Ext Test sens @spec90%: 0.467
Ext Test spec @sens90%: 0.533
                ----------
Fold corr coeff -- mean: 0.753| std dev: 0.051
----------
Starting Bootstrapping:


  bootMetrics = pd.concat([bootMetrics, currMetrics], ignore_index=True)


rocAUC CIs (2.5,97.5): 0.713, 0.876
----------
prAUC CIs (2.5,97.5): 0.292, 0.647
----------
sensAtSpec CIs (2.5,97.5): 0.231, 0.656
----------
specAtSens CIs (2.5,97.5): 0.396, 0.733
----------
Model name: fastvit_ma36
Int Val ROC AUC: 0.976
Ext Test ROC AUC: 0.812
Ext Test PR AUC: 0.563
Ext Test sens @spec90%: 0.433
Ext Test spec @sens90%: 0.475
                ----------
Fold corr coeff -- mean: 0.829| std dev: 0.026
----------
Starting Bootstrapping:


  bootMetrics = pd.concat([bootMetrics, currMetrics], ignore_index=True)


rocAUC CIs (2.5,97.5): 0.722, 0.890
----------
prAUC CIs (2.5,97.5): 0.374, 0.745
----------
sensAtSpec CIs (2.5,97.5): 0.261, 0.682
----------
specAtSens CIs (2.5,97.5): 0.360, 0.755
----------
Model name: mobilenetv4_conv_large
Int Val ROC AUC: 0.944
Ext Test ROC AUC: 0.772
Ext Test PR AUC: 0.501
Ext Test sens @spec90%: 0.533
Ext Test spec @sens90%: 0.383
                ----------
Fold corr coeff -- mean: 0.676| std dev: 0.115
----------
Starting Bootstrapping:


  bootMetrics = pd.concat([bootMetrics, currMetrics], ignore_index=True)


rocAUC CIs (2.5,97.5): 0.663, 0.871
----------
prAUC CIs (2.5,97.5): 0.323, 0.697
----------
sensAtSpec CIs (2.5,97.5): 0.333, 0.719
----------
specAtSens CIs (2.5,97.5): 0.109, 0.700
----------
Model name: repvit_m3.native
Int Val ROC AUC: 0.948
Ext Test ROC AUC: 0.832
Ext Test PR AUC: 0.603
Ext Test sens @spec90%: 0.533
Ext Test spec @sens90%: 0.550
                ----------
Fold corr coeff -- mean: 0.706| std dev: 0.059
----------
Starting Bootstrapping:


  bootMetrics = pd.concat([bootMetrics, currMetrics], ignore_index=True)


rocAUC CIs (2.5,97.5): 0.751, 0.903
----------
prAUC CIs (2.5,97.5): 0.421, 0.759
----------
sensAtSpec CIs (2.5,97.5): 0.323, 0.739
----------
specAtSens CIs (2.5,97.5): 0.455, 0.754
----------
Model name: resnetv2_34
Int Val ROC AUC: 0.959
Ext Test ROC AUC: 0.809
Ext Test PR AUC: 0.530
Ext Test sens @spec90%: 0.500
Ext Test spec @sens90%: 0.525
                ----------
Fold corr coeff -- mean: 0.695| std dev: 0.054
----------
Starting Bootstrapping:


  bootMetrics = pd.concat([bootMetrics, currMetrics], ignore_index=True)


rocAUC CIs (2.5,97.5): 0.719, 0.888
----------
prAUC CIs (2.5,97.5): 0.344, 0.722
----------
sensAtSpec CIs (2.5,97.5): 0.300, 0.731
----------
specAtSens CIs (2.5,97.5): 0.367, 0.712
----------


In [11]:
rocAUC, prAUC, sensAtSpec, specAtSens = ext_val_metrics(labels, ensembleProb, target_value=0.90)
# print(len(modelData['external_validation_results'][0]))
print(f"""Model name: Ensemble (soft)
Ext Test ROC AUC: {rocAUC:.3f}
Ext Test PR AUC: {prAUC:.3f}
Ext Test sens @spec90%: {sensAtSpec:.3f}
Ext Test spec @sens90%: {specAtSens:.3f}
                ----------""")
foldSimilarity = []
for i in range(5): #correlation analysis between folds
    for j in range(i+1, 5):
        p_i = expit(ensembleIndivPreds[i])
        p_j = expit(ensembleIndivPreds[j])
        foldSimilarity.append(np.corrcoef(p_i, p_j)[0,1])
foldSimilarity = np.array(foldSimilarity)
print(f"Fold corr coeff -- mean: {foldSimilarity.mean():.3f}| std dev: {foldSimilarity.std():.3f}")

metricNames = ["rocAUC","prAUC","sensAtSpec","specAtSens"]
ensBootMetrics = pd.DataFrame(columns=metricNames)
for bootNum in range(B):
    bootIdx = bootStrapIdxs[bootNum] #indices for this bootstrap
    bootLabels = labels[bootIdx]
    
    rocAUC, prAUC, sensAtSpec, specAtSens = ext_val_metrics(bootLabels, ensembleProbBoot[bootNum], target_value=0.90)
    currMetrics = pd.DataFrame([{"rocAUC": rocAUC, "prAUC": prAUC, "sensAtSpec": sensAtSpec, "specAtSens": specAtSens}])
    ensBootMetrics = pd.concat([ensBootMetrics, currMetrics], ignore_index=True)

bootCIs = pd.DataFrame(columns=metricNames)
for metric in metricNames:
    lower = np.percentile(ensBootMetrics[metric], 2.5)
    upper = np.percentile(ensBootMetrics[metric], 97.5)
    bootCIs[metric] = [lower, upper]
    print(f"{metric} CIs (2.5,97.5): {lower:.3f}, {upper:.3f}\n----------")

Model name: Ensemble (soft)
Ext Test ROC AUC: 0.821
Ext Test PR AUC: 0.601
Ext Test sens @spec90%: 0.533
Ext Test spec @sens90%: 0.433
                ----------
Fold corr coeff -- mean: 0.813| std dev: 0.044


  ensBootMetrics = pd.concat([ensBootMetrics, currMetrics], ignore_index=True)


rocAUC CIs (2.5,97.5): 0.735, 0.897
----------
prAUC CIs (2.5,97.5): 0.416, 0.760
----------
sensAtSpec CIs (2.5,97.5): 0.344, 0.720
----------
specAtSens CIs (2.5,97.5): 0.356, 0.770
----------


# Test Few-shot Learning
## 1:2 ratio of fine tuning to true test data

In [15]:
from sklearn.model_selection import StratifiedShuffleSplit

def stratified_fewshot_split(file_paths, labels, n_fewshot=50, seed=9999):
    sss = StratifiedShuffleSplit(
        n_splits=1,
        test_size=len(labels) - n_fewshot,
        random_state=seed
    )
    fewshot_idx, test_idx = next(sss.split(file_paths, labels))

    fewshot_paths = [file_paths[i] for i in fewshot_idx]
    fewshot_labels = labels[fewshot_idx]

    test_paths = [file_paths[i] for i in test_idx]
    test_labels = labels[test_idx]

    return fewshot_paths, fewshot_labels, test_paths, test_labels
    
def fewshot_finetune(model, fewshot_loader, device,
                     lr=1e-4, epochs=30, pos_weight=4.0):

    # Freeze backbone
    for name, param in model.named_parameters():
        if "classifier" not in name and "head" not in name:
            param.requires_grad = False

    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=lr,
        weight_decay=1e-4
    )

    criterion = torch.nn.BCEWithLogitsLoss(
        pos_weight=torch.tensor([pos_weight]).to(device)
    )

    model.train()
    for epoch in range(epochs):
        for x, y in fewshot_loader:
            x = x.to(device)
            y = y.float().to(device)

            optimizer.zero_grad()
            logits = model(x).view(-1)
            loss = criterion(logits, y)
            loss.backward()
            optimizer.step()

    model.eval()
    return model

def load_subset_dataset(image_size, file_paths, labels, batch_size: int = 32, test_transform = None, cfg = None):
    dataset = XrayDataset(
        file_paths=file_paths,
        labels=labels,
        transform=test_transform
    )
    if cfg:
        loader = DataLoader(
            dataset,
            batch_size=cfg["batch_size"],
            shuffle=cfg["shuffle"],
            pin_memory=True
        )
    else:
        loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            pin_memory=True
        )
    return dataset, loader


In [16]:
runTest = False

if not runTest: #load previously saved results if inference not rerun
    fewShotResults = torch.load("external_validation_fewShotResults.pt")
    print("prior few-shot external validation results loaded for analysis.")

prior few-shot external validation results loaded for analysis.


In [20]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if runTest:
    ensembleModels = ['efficientnetv2_m.in21k_ft_in1k', 'fastvit_ma36.apple_dist_in1k', 'mobilenetv4_conv_large.e600_r384_in1k', 
                  'repvit_m3.dist_in1k', 'resnetv2_34.ra4_e3600_r224_in1k']
    fewShotResults = []

    for pick in data:
        try:
            logits_path_str = pick["logits"]
            idx = logits_path_str.index("objects/")
            objects_path = logits_path_str[:idx] + "objects/"
        except (ValueError, KeyError):
            print("Could not find 'objects/' in path, skipping this pick.")
            continue

        # --- Load JSON settings ---
        try:
            with open(os.path.join(objects_path, "0-dataset_settings.json"), "r") as f:
                settings = json.load(f)
            with open(os.path.join(objects_path, "0-model_details.json"), "r") as f:
                model_details = json.load(f)
        except FileNotFoundError as e:
            print(f"Metadata file missing: {e}")
            continue

        # --- Initialize model ---
        if model_details['model_name'] not in ensembleModels:
            continue
            
        print(f"\n--- Loading Model: {model_details['model_name']} ---")
        model = create_actual_model(model_details["model_name"]).to(device)
        model.eval()

        image_size = settings["image_size"][0]

        test_transformer = transforms.Compose([
            transforms.Resize((image_size, image_size), antialias=True),
            transforms.ConvertImageDtype(torch.float32),
            PercentileDomainAdaptation(trPercentiles[0], trPercentiles[1]),
        ])

        # --- Load full external dataset ---
        full_dataset, full_loader = load_dataset(
            image_size,
            test_transform=test_transformer
        )

        file_paths = full_dataset.file_paths
        labels = np.array(full_dataset.labels)

        # --- Stratified few-shot split ---
        fewshot_paths, fewshot_labels, test_paths, test_labels = \
            stratified_fewshot_split(
                file_paths,
                labels,
                n_fewshot=50,
                seed=seed_value
            )

        fewshot_dataset, fewshot_loader = load_subset_dataset(
            image_size = image_size,
            file_paths = fewshot_paths,
            labels = fewshot_labels,
            test_transform = test_transformer,
            cfg=cfgFewShot
        )

        test_dataset, test_loader = load_subset_dataset(
            image_size = image_size,
            file_paths = test_paths,
            labels = test_labels,
            test_transform = test_transformer,
            cfg=cfgFewShot
        )

        pick["external_validation_fewShotResults"] = []

        # --- Iterate through weight checkpoints ---
        for weight_file in pick.get("weights", []):
            full_weight_path = os.path.join(choiPath, weight_file)

            if not os.path.exists(full_weight_path):
                print(f"Weight file not found: {full_weight_path}")
                continue

            state_dict = torch.load(full_weight_path, map_location=device)
            model.load_state_dict(state_dict)
            print(f"Loaded weights: {weight_file}")

            # --- BN adaptation using ALL external images (unlabeled) ---
            model = adapt_batch_norm(model, full_loader, device)

            # --- Few-shot head-only fine-tuning ---
            pos_weight = (len(labels) - labels.sum()) / labels.sum()  #weighted tuning to account for data imbalance
            model = fewshot_finetune(
                model,
                fewshot_loader,
                device,
                lr=1e-4,
                epochs=30,
                pos_weight=pos_weight
            )

            # --- Predict on HELD-OUT external test subset ---
            y_logits = []
            with torch.no_grad():
                for x, _ in test_loader:
                    x = x.to(device)
                    output = model(x)
                    y_logits.extend(output.view(-1).cpu().numpy())

            pick["external_validation_fewShotResults"].append(
                np.array(y_logits)
            )

        fewShotResults.append({
            "name": model_details["model_name"],
            "fewshot_idx": fewshot_paths,
            "test_idx": test_paths,
            "fewshot_labels": fewshot_labels,
            "test_labels": test_labels,
            "external_validation_fewShotResults": pick["external_validation_fewShotResults"],
        })

    # --- Final Save ---
    if fewShotResults:
        torch.save(fewShotResults, "external_validation_fewShotResults.pt")
        print("\nProcessing complete. Few-shot test results saved.")
    else:
        print("\nNo results generated.")

else:
    print("fewShotResults not enabled. Did you already load previous results? Please enable if this was not intended.")

fewShotResults not enabled. Did you already load previous results? Please enable if this was not intended.


In [17]:
labelsFewShot = fewShotResults[0]['test_labels'] #same for every fold and architecture

In [18]:
try:
    print(f"""Bootstrapping variable already set (value={needFewShotBootStrapping}). Are you sure you didn't already run this?
bootstrap indices are stored in 'fewShotBootStrapIdxs'.""")

except:
    needFewShotBootStrapping = True
    rng = np.random.default_rng(seed=seed_value)
    B = 5000
    fewShotBootStrapIdxs = []
    extFewShotValSize = len(labelsFewShot)
    
    if needFewShotBootStrapping:
        print("""setting random seed and needFewShotBootStrapping indices. For reproducibility, only run this ONCE. 
If accidentally ran again, restart kernel and try again.""")
        for _ in range(B):
            idx = np.random.randint(0, extFewShotValSize, extFewShotValSize)
            while True:
                if len(np.unique(labelsFewShot[idx])) == 2:
                    break #if samples only have one label, try again
                idx = np.random.randint(0, extFewShotValSize, extFewShotValSize)
            fewShotBootStrapIdxs.append(idx)
        print("needFewShotBootStrapping indices set.")
    else:
        print("needFewShotBootStrapping not enabled. Remember, only run this ONCE in script for reproducibility.")
    needFewShotBootStrapping = False

setting random seed and needFewShotBootStrapping indices. For reproducibility, only run this ONCE. 
If accidentally ran again, restart kernel and try again.
needFewShotBootStrapping indices set.


In [19]:
ensembleModels = ['tf_efficientnetv2_m.in21k_ft_in1k', 'fastvit_ma36.apple_dist_in1k', 'mobilenetv4_conv_large.e600_r384_in1k', 
                  'repvit_m3.dist_in1k', 'resnetv2_34.ra4_e3600_r224_in1k']
#same models as before, full name vs truncated name
ensembleProbFewShot = np.zeros(len(fewShotResults[0]['test_labels']))
ensembleListPredsFewShot = []
ensembleProbBootFewShot = np.zeros((B, len(fewShotResults[0]['test_labels'])))

for modelData in fewShotResults:
    if modelData['name'] in ensembleModels:
        probs = np.zeros(len(labelsFewShot))
        for fold in range(5):
            probs = probs + modelData['external_validation_fewShotResults'][fold]
        probs = expit(probs/5) #average of the five folds

        rocAUC, prAUC, sensAtSpec, specAtSens = ext_val_metrics(labelsFewShot, probs, target_value=0.90)
        # print(len(modelData['external_validation_results'][0]))
        print(f"""Model name: {modelData['name']}\n
Ext Test ROC AUC: {rocAUC:.3f}
Ext Test PR AUC: {prAUC:.3f}
Ext Test sens @spec90%: {sensAtSpec:.3f}
Ext Test spec @sens90%: {specAtSens:.3f}
                ----------""")
        foldSimilarity = []
        for i in range(5): #correlation analysis between folds
            for j in range(i+1, 5):
                p_i = expit(modelData['external_validation_fewShotResults'][i])
                p_j = expit(modelData['external_validation_fewShotResults'][j])
                foldSimilarity.append(np.corrcoef(p_i, p_j)[0,1])
        foldSimilarity = np.array(foldSimilarity)
        print(f"Fold corr coeff -- mean: {foldSimilarity.mean():.3f}| std dev: {foldSimilarity.std():.3f}")

        ensembleProbFewShot = ensembleProbFewShot + probs
        ensembleListPredsFewShot.append(probs)

        metricNames = ["rocAUC","prAUC","sensAtSpec","specAtSens"]
        bootMetrics = pd.DataFrame(columns=metricNames)
        print(f"----------\nStarting Bootstrapping for Few Shot Results:")
        for bootNum in range(B):
            bootIdx = fewShotBootStrapIdxs[bootNum] #indices for this bootstrap
            bootLabels = labelsFewShot[bootIdx]
            bootProbs = np.zeros(len(bootLabels))
            for fold in range(5):
                bootProbs = bootProbs + modelData['external_validation_fewShotResults'][fold][bootIdx]
            bootProbs = expit(bootProbs/5) #average of the five folds
            if modelData['name'] in ensembleModels:
                ensembleProbBootFewShot[bootNum] = ensembleProbBootFewShot[bootNum] + bootProbs
            
            rocAUC, prAUC, sensAtSpec, specAtSens = ext_val_metrics(bootLabels, bootProbs, target_value=0.90)
            currMetrics = pd.DataFrame([{"rocAUC": rocAUC, "prAUC": prAUC, "sensAtSpec": sensAtSpec, "specAtSens": specAtSens}])
            bootMetrics = pd.concat([bootMetrics, currMetrics], ignore_index=True)
        # print(bootMetrics)
        CIs = pd.DataFrame(columns=metricNames)
        for metric in metricNames:
            lower = np.percentile(bootMetrics[metric], 2.5)
            upper = np.percentile(bootMetrics[metric], 97.5)
            CIs[metric] = [lower, upper]
            print(f"{metric} CIs (2.5,97.5): {lower:.3f}, {upper:.3f}\n----------")
    # else:
    #     print(modelData['name'])
ensembleProbBootFewShot = ensembleProbBootFewShot/len(ensembleModels)
ensembleProbFewShot = ensembleProbFewShot/len(ensembleModels)

Model name: tf_efficientnetv2_m.in21k_ft_in1k

Ext Test ROC AUC: 0.605
Ext Test PR AUC: 0.424
Ext Test sens @spec90%: 0.300
Ext Test spec @sens90%: 0.075
                ----------
Fold corr coeff -- mean: 0.452| std dev: 0.185
----------
Starting Bootstrapping for Few Shot Results:


  bootMetrics = pd.concat([bootMetrics, currMetrics], ignore_index=True)


rocAUC CIs (2.5,97.5): 0.435, 0.767
----------
prAUC CIs (2.5,97.5): 0.213, 0.632
----------
sensAtSpec CIs (2.5,97.5): 0.133, 0.600
----------
specAtSens CIs (2.5,97.5): 0.026, 0.463
----------
Model name: fastvit_ma36.apple_dist_in1k

Ext Test ROC AUC: 0.859
Ext Test PR AUC: 0.651
Ext Test sens @spec90%: 0.550
Ext Test spec @sens90%: 0.594
                ----------
Fold corr coeff -- mean: 0.650| std dev: 0.036
----------
Starting Bootstrapping for Few Shot Results:


  bootMetrics = pd.concat([bootMetrics, currMetrics], ignore_index=True)


rocAUC CIs (2.5,97.5): 0.770, 0.935
----------
prAUC CIs (2.5,97.5): 0.424, 0.836
----------
sensAtSpec CIs (2.5,97.5): 0.316, 0.800
----------
specAtSens CIs (2.5,97.5): 0.486, 0.849
----------
Model name: mobilenetv4_conv_large.e600_r384_in1k

Ext Test ROC AUC: 0.781
Ext Test PR AUC: 0.550
Ext Test sens @spec90%: 0.450
Ext Test spec @sens90%: 0.488
                ----------
Fold corr coeff -- mean: 0.461| std dev: 0.084
----------
Starting Bootstrapping for Few Shot Results:


  bootMetrics = pd.concat([bootMetrics, currMetrics], ignore_index=True)


rocAUC CIs (2.5,97.5): 0.652, 0.895
----------
prAUC CIs (2.5,97.5): 0.321, 0.760
----------
sensAtSpec CIs (2.5,97.5): 0.227, 0.720
----------
specAtSens CIs (2.5,97.5): 0.250, 0.772
----------
Model name: repvit_m3.dist_in1k

Ext Test ROC AUC: 0.854
Ext Test PR AUC: 0.648
Ext Test sens @spec90%: 0.550
Ext Test spec @sens90%: 0.650
                ----------
Fold corr coeff -- mean: 0.713| std dev: 0.051
----------
Starting Bootstrapping for Few Shot Results:


  bootMetrics = pd.concat([bootMetrics, currMetrics], ignore_index=True)


rocAUC CIs (2.5,97.5): 0.763, 0.930
----------
prAUC CIs (2.5,97.5): 0.435, 0.821
----------
sensAtSpec CIs (2.5,97.5): 0.296, 0.800
----------
specAtSens CIs (2.5,97.5): 0.512, 0.832
----------
Model name: resnetv2_34.ra4_e3600_r224_in1k

Ext Test ROC AUC: 0.823
Ext Test PR AUC: 0.574
Ext Test sens @spec90%: 0.600
Ext Test spec @sens90%: 0.562
                ----------
Fold corr coeff -- mean: 0.811| std dev: 0.027
----------
Starting Bootstrapping for Few Shot Results:


  bootMetrics = pd.concat([bootMetrics, currMetrics], ignore_index=True)


rocAUC CIs (2.5,97.5): 0.709, 0.920
----------
prAUC CIs (2.5,97.5): 0.348, 0.789
----------
sensAtSpec CIs (2.5,97.5): 0.316, 0.821
----------
specAtSens CIs (2.5,97.5): 0.300, 0.842
----------


In [43]:
np.save('ensembleFewShotProbs.npy', ensembleProbFewShot)
np.save('ensembleZeroShotProbs.npy', ensembleProb)
np.save('fewShotEnsembleGTs.npy', fewShotResults[0]['test_labels'])
np.save('zeroShotEnsembleGTs.npy', labels)

In [20]:
rocAUC, prAUC, sensAtSpec, specAtSens = ext_val_metrics(labelsFewShot, ensembleProbFewShot, target_value=0.90)
# print(len(modelData['external_validation_results'][0]))
print(f"""Model name: Ensemble (soft; Few Shot)
Ext Test ROC AUC: {rocAUC:.3f}
Ext Test PR AUC: {prAUC:.3f}
Ext Test sens @spec90%: {sensAtSpec:.3f}
Ext Test spec @sens90%: {specAtSens:.3f}
                ----------""")

foldSimilarity = []
for i in range(5): #correlation analysis between folds
    for j in range(i+1, 5):
        p_i = expit(ensembleListPredsFewShot[i])
        p_j = expit(ensembleListPredsFewShot[j])
        foldSimilarity.append(np.corrcoef(p_i, p_j)[0,1])
foldSimilarity = np.array(foldSimilarity)
print(f"Fold corr coeff -- mean: {foldSimilarity.mean():.3f}| std dev: {foldSimilarity.std():.3f}")

metricNames = ["rocAUC","prAUC","sensAtSpec","specAtSens"]
ensBootMetricsFewShot = pd.DataFrame(columns=metricNames)
for bootNum in range(B):
    bootIdx = fewShotBootStrapIdxs[bootNum] #indices for this bootstrap
    bootLabels = labelsFewShot[bootIdx]
    
    rocAUC, prAUC, sensAtSpec, specAtSens = ext_val_metrics(bootLabels, ensembleProbBootFewShot[bootNum], target_value=0.90)
    currMetrics = pd.DataFrame([{"rocAUC": rocAUC, "prAUC": prAUC, "sensAtSpec": sensAtSpec, "specAtSens": specAtSens}])
    ensBootMetricsFewShot = pd.concat([ensBootMetricsFewShot, currMetrics], ignore_index=True)

bootFewShotCIs = pd.DataFrame(columns=metricNames)
for metric in metricNames:
    lower = np.percentile(ensBootMetricsFewShot[metric], 2.5)
    upper = np.percentile(ensBootMetricsFewShot[metric], 97.5)
    bootFewShotCIs[metric] = [lower, upper]
    print(f"{metric} CIs (2.5,97.5): {lower:.3f}, {upper:.3f}\n----------")

Model name: Ensemble (soft; Few Shot)
Ext Test ROC AUC: 0.847
Ext Test PR AUC: 0.601
Ext Test sens @spec90%: 0.500
Ext Test spec @sens90%: 0.650
                ----------
Fold corr coeff -- mean: 0.667| std dev: 0.189


  ensBootMetricsFewShot = pd.concat([ensBootMetricsFewShot, currMetrics], ignore_index=True)


rocAUC CIs (2.5,97.5): 0.757, 0.924
----------
prAUC CIs (2.5,97.5): 0.375, 0.796
----------
sensAtSpec CIs (2.5,97.5): 0.238, 0.783
----------
specAtSens CIs (2.5,97.5): 0.511, 0.831
----------
