# Encode Training/Validation data using google cxr-foundation

In [1]:
import os
import ctypes
import tensorflow as tf
from scipy.special import expit
from sklearn.metrics import confusion_matrix
from sklearn.metrics import f1_score, roc_auc_score, precision_recall_curve, auc, roc_curve

# 1. Force-load the correct cuDNN 9.3.0 library
lib_path = "/home/wuat2/anaconda3/pkgs/cudnn-9.3.0.75-cuda12.6/lib/libcudnn.so.9"
try:
    ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)
    os.environ['LD_LIBRARY_PATH'] = "/home/wuat2/anaconda3/pkgs/cudnn-9.3.0.75-cuda12.6/lib:" + os.environ.get('LD_LIBRARY_PATH', '')
    print("✅ Successfully force-loaded cuDNN 9.3.0 into memory.")
except Exception as e:
    print(f"❌ Could not load library: {e}")

# 2. Configure Memory Growth IMMEDIATELY after import
# This MUST happen before the Conv2D test
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print(f"✅ Memory growth enabled for {len(gpus)} GPU(s).")
    except RuntimeError as e:
        print(f"❌ Configuration Error: {e}")

# 3. NOW run the Functional Test
# This will use the loaded cuDNN 9.3 and honor the memory growth setting
try:
    with tf.device('/device:GPU:0'):
        # Dummy op to trigger cuDNN kernels
        _ = tf.keras.layers.Conv2D(2, 3)(tf.random.normal((1, 28, 28, 1)))
    print("✅ cuDNN 9.3.0 is functional and verified for Google CXR Foundation Model.")
except Exception as e:
    print(f"❌ Functional Test Failed: {e}")

2026-01-26 18:09:16.768481: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-01-26 18:09:16.790723: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1769479756.816106 1046088 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1769479756.823909 1046088 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1769479756.842922 1046088 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

✅ Successfully force-loaded cuDNN 9.3.0 into memory.
✅ Memory growth enabled for 8 GPU(s).


I0000 00:00:1769479773.958139 1046088 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 39866 MB memory:  -> device: 0, name: NVIDIA L40, pci bus id: 0000:3d:00.0, compute capability: 8.9
I0000 00:00:1769479773.960654 1046088 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 43224 MB memory:  -> device: 1, name: NVIDIA L40, pci bus id: 0000:b2:00.0, compute capability: 8.9
I0000 00:00:1769479773.961530 1046088 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:2 with 9553 MB memory:  -> device: 2, name: NVIDIA GeForce RTX 2080 Ti, pci bus id: 0000:1a:00.0, compute capability: 7.5
I0000 00:00:1769479773.962386 1046088 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:3 with 9553 MB memory:  -> device: 3, name: NVIDIA GeForce RTX 2080 Ti, pci bus id: 0000:1b:00.0, compute capability: 7.5
I0000 00:00:1769479773.963239 1046088 gpu_device.cc:2019] Created device /job:

✅ cuDNN 9.3.0 is functional and verified for Google CXR Foundation Model.


In [2]:
# Need to set cudnn path here; DO NOT IMPORT TENSORFLOW HERE
from load_dataset import data_summary, make_dataset, set_seeds
import numpy as np
import pandas as pd
from timm.models import create_model, list_models
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset, Subset
from sklearn.model_selection import StratifiedKFold, KFold, train_test_split

SEED = 9999

In [None]:
base_dir = r"/extra/xielab0/wuat2/AryaQualityViewProjectData"

image_dir = os.path.join(base_dir, r"images")
encoded_dir = os.path.join(base_dir, r"encodedImgs")
ext_encoded_dir = os.path.join(base_dir, r"encodedExtImgs")

IMAGE_RESOLUTION = 384

dataset_settings = {
    "image_size": (IMAGE_RESOLUTION, IMAGE_RESOLUTION),
    "label_map": {
        'diagnostic quality': 0,
        'repeat needed': 1,
    }
}

dataset = make_dataset(base_dir, image_dir, 
                       image_size=dataset_settings["image_size"],
                       label_map=dataset_settings["label_map"])

In [None]:
from huggingface_hub import login
login()

In [None]:
from huggingface_hub import snapshot_download
import tensorflow_text as tf_text

snapshot_download(repo_id="google/cxr-foundation",local_dir='./content/hf',
                  allow_patterns=['elixr-c-v2-pooled/*', 'pax-elixr-b-text/*'])

if 'elixrc_model' not in locals():
    elixrc_model = tf.saved_model.load('./content/hf/elixr-c-v2-pooled')
    elixrc_infer = elixrc_model.signatures['serving_default']

if 'qformer_model' not in locals():
    qformer_model = tf.saved_model.load("./content/hf/pax-elixr-b-text")

In [6]:
# # USE TO FOR EMBEDDING INT VAL IMAGES USING GOOGLE CXR FOUNDATION MODEL; ELIXR-B

# def create_tf_example(image_bytes):
#     feature = {
#         'image/encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_bytes]))
#     }
#     example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
#     return example_proto.SerializeToString()
    
# for imgIndex in range(len(dataset.images)):
#     imageName = dataset.images[imgIndex]
#     saveName = imageName.replace('.png', '.npy').replace('/', '_')
#     # imageLabel = dataset.labels[imgIndex]
    
#     imagePath = os.path.join(image_dir, imageName)
    
#     try:
#         image_bytes = tf.io.read_file(imagePath).numpy()
#         serialized_example = create_tf_example(image_bytes)

#         # 3. ELIXR-C Inference (expects a 1D tensor of serialized strings)
#         elixrc_output = elixrc_infer(input_example=tf.constant([serialized_example]))
#         elixrc_embedding = elixrc_output['feature_maps_0'].numpy()

#         qformer_input = {
#             'image_feature': elixrc_embedding.tolist(),
#             'ids': np.zeros((1, 1, 128), dtype=np.int32).tolist(),
#             'paddings':np.zeros((1, 1, 128), dtype=np.float32).tolist(),
#         }
        
#         qformer_output = qformer_model.signatures['serving_default'](**qformer_input)
#         elixrb_embeddings = qformer_output['all_contrastive_img_emb']
        
#         # print("ELIXR-B - embedding shape: ", elixrb_embeddings.shape)

#         saveFileName = os.path.join(encoded_dir, saveName)
#         np.save(saveFileName, elixrb_embeddings)
#         # new_row_data = {'fileName': imageName, 'Encoded': elixrc_embedding, 'Label': imageLabel}
#         # transformedData.loc[imgIndex] = new_row_data
    
#     except FileNotFoundError:
#         print(f"Error: The file '{imagePath}' was not found.")
#     except Exception as e:
#         print(f"An error occurred: {e}")

# print("successfully embedded data with Google CXR Foundation Model. Img Tensor --> ELIXR-C --> ELIXR-B (embedding shape 1x32x128).")

In [44]:
# # # USE TO FOR EMBEDDING EXT VAL IMAGES USING GOOGLE CXR FOUNDATION MODEL; ELIXR-B
# def create_tf_example(image_bytes):
#     feature = {
#         'image/encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_bytes]))
#     }
#     example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
#     return example_proto.SerializeToString()

# os.makedirs(ext_encoded_dir, exist_ok=True)
# file_paths

# for imgIndex in range(len(file_paths)):
#     imagePath = file_paths[imgIndex]
#     imageName = os.path.basename(imagePath)
#     saveName = imageName.replace('.png', '.npy').replace('/', '_')
    
#     try:
#         image_bytes = tf.io.read_file(imagePath).numpy()
#         serialized_example = create_tf_example(image_bytes)

#         # 3. ELIXR-C Inference (expects a 1D tensor of serialized strings)
#         elixrc_output = elixrc_infer(input_example=tf.constant([serialized_example]))
#         elixrc_embedding = elixrc_output['feature_maps_0'].numpy()

#         qformer_input = {
#             'image_feature': elixrc_embedding.tolist(),
#             'ids': np.zeros((1, 1, 128), dtype=np.int32).tolist(),
#             'paddings':np.zeros((1, 1, 128), dtype=np.float32).tolist(),
#         }
        
#         qformer_output = qformer_model.signatures['serving_default'](**qformer_input)
#         elixrb_embeddings = qformer_output['all_contrastive_img_emb']
        
#         # print("ELIXR-B - embedding shape: ", elixrb_embeddings.shape)

#         saveFileName = os.path.join(ext_encoded_dir, saveName)
#         np.save(saveFileName, elixrb_embeddings)
#         # new_row_data = {'fileName': imageName, 'Encoded': elixrc_embedding, 'Label': imageLabel}
#         # transformedData.loc[imgIndex] = new_row_data
    
#     except FileNotFoundError:
#         print(f"Error: The file '{imagePath}' was not found.")
#     except Exception as e:
#         print(f"An error occurred: {e}")

# print("successfully embedded ext val data with Google CXR Foundation Model. Img Tensor --> ELIXR-C --> ELIXR-B (embedding shape 1x32x128).")








































successfully embedded ext val data with Google CXR Foundation Model. Img Tensor --> ELIXR-C --> ELIXR-B (embedding shape 1x32x128).


# Create model and define F1 Metrics 

In [4]:
class MaxF1Score(tf.keras.metrics.Metric):
    def __init__(self, name='max_f1_score', **kwargs):
        super(MaxF1Score, self).__init__(name=name, **kwargs)
        # Define the exact thresholds you used in your PyTorch loop (0.01 to 1.00)
        self.thresholds = tf.constant(tf.linspace(0.01, 1.00, 100))
        
        self.tp = self.add_weight(name='tp', shape=(100,), initializer='zeros')
        self.fp = self.add_weight(name='fp', shape=(100,), initializer='zeros')
        self.fn = self.add_weight(name='fn', shape=(100,), initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None): #Updates cm stats for all thresholds.
        
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.cast(y_pred, tf.float32)
        
        pred_binary = tf.cast(y_pred > self.thresholds, tf.float32)
        
        # 2. Broadcast ground truth to match (Batch_Size, 100)
        true_broadcast = tf.tile(y_true, [1, 100])
        
        # 3. Calculate TP, FP, FN for all thresholds
        # We sum over axis 0 (the batch dimension) to get totals per threshold
        true_positives = tf.reduce_sum(true_broadcast * pred_binary, axis=0)
        false_positives = tf.reduce_sum((1 - true_broadcast) * pred_binary, axis=0)
        false_negatives = tf.reduce_sum(true_broadcast * (1 - pred_binary), axis=0)
        
        # 4. Update the state variables
        self.tp.assign_add(true_positives)
        self.fp.assign_add(false_positives)
        self.fn.assign_add(false_negatives)

    def result(self):
        precision = tf.math.divide_no_nan(self.tp, self.tp + self.fp)
        recall = tf.math.divide_no_nan(self.tp, self.tp + self.fn)
        
        f1_scores = 2 * (precision * recall) / (precision + recall + 1e-10)
        
        # Return the single highest F1 score found across all 100 thresholds
        return tf.reduce_max(f1_scores)

    def reset_state(self):
        self.tp.assign(tf.zeros((100,)))
        self.fp.assign(tf.zeros((100,)))
        self.fn.assign(tf.zeros((100,)))

In [5]:
# model creation adapted from Google CXR foundation model github documentation:
# https://github.com/Google-Health/imaging-research/blob/master/cxr-foundation/CXR_Foundation_Demo.ipynb
# note: using Adam vs LARS for stability within small-data regimes and consistency across other models tested

def create_model(heads,
                 token_num=32, # ELIXR-B uses 32x128
                 embeddings_size=128,
                 learning_rate=1e-4, # Standard for AdamW
                 end_lr_factor=0.1,
                 dropout=0.3,
                 decay_steps=1000,
                 loss_weights=None,
                 hidden_layer_sizes=[256, 128], # 4 factor lower sizes to accommodate small dataset
                 weight_decay=1e-4, # Standard for AdamW
                 activation='Sigmoid',
                 loss='binary_crossentropy',
                 seed=None) -> tf.keras.Model:

    # 1. Flattened input for ELIXR-B
    inputs = tf.keras.Input(shape=(token_num * embeddings_size,))
    
    # 2. Reshape and Pooling
    inputs_reshape = tf.keras.layers.Reshape((token_num, embeddings_size))(inputs)
    hidden = tf.keras.layers.GlobalAveragePooling1D()(inputs_reshape)

    # 3. Narrower MLP Layers
    for size in hidden_layer_sizes:
        hidden = tf.keras.layers.Dense(
            size,
            activation='relu',
            kernel_initializer=tf.keras.initializers.HeUniform(seed=seed),
            kernel_regularizer=tf.keras.regularizers.l2(l2=weight_decay)
        )(hidden)
        hidden = tf.keras.layers.BatchNormalization()(hidden)
        hidden = tf.keras.layers.Dropout(dropout, seed=seed)(hidden)

    # 4. Multi-head Output
    output_raw = tf.keras.layers.Dense(
        units=len(heads),
        activation=activation,
        kernel_initializer=tf.keras.initializers.HeUniform(seed=seed)
    )(hidden)

    outputs = {}
    for i, head in enumerate(heads):
        # Create label dictionary entries
        outputs[head] = tf.keras.layers.Lambda(
            lambda x: x[..., i:i + 1], name=head)(output_raw)

    model = tf.keras.Model(inputs, outputs)

    learning_rate_fn = tf.keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=tf.cast(learning_rate, tf.float32),
        decay_steps=tf.cast(decay_steps, tf.float32),
        alpha=tf.cast(end_lr_factor, tf.float32))

    model.compile(
        optimizer=tf.keras.optimizers.AdamW(learning_rate=learning_rate_fn, weight_decay=weight_decay),
        loss='binary_crossentropy',
          weighted_metrics=[
            MaxF1Score(),
            tf.keras.metrics.AUC(name='auc_roc'),
            tf.keras.metrics.AUC(curve='PR', name='auc_pr')
          ]
    )
    return model

## Build arrays containing all data for 5-fold validation

In [6]:
features = []
labels = [] #should be exactly the same order as the original pytorch dataset

head_name = 'repeatNeeded'
for imgIdx in range(len(dataset.images)):
    imageName = dataset.images[imgIdx]
    saveName = imageName.replace('.png', '.npy').replace('/', '_')
    embeddedDataPath = os.path.join(encoded_dir, saveName)
    if os.path.exists(embeddedDataPath):
        emb = np.load(embeddedDataPath)
        features.append(emb.flatten())
        labels.append(dataset.labels[imgIdx])
    else:
        print(f'data file {embeddedDataPath} cannot be found. Please revise loop logic.')

features = np.array(features).astype('float32')
labels = np.array(labels).astype('float32')

# Train model

In [11]:
import pickle
# from sklearn.model_selection import StratifiedKFold, KFold, train_test_split
idxFolder = r'/home/wuat2/xray-quality/reruns/fastvit_ma36/objects'

SEED = 9999 #not needed right now since loading from saved indices; here for consistency
BATCH_SIZE = 32
EPOCHS = 100
numFolds = 5

foldHistories = []

def format_output(x, y):
    y_reshaped = {k: tf.expand_dims(v, axis=-1) for k, v in y.items()}
    return x, y_reshaped
    
for i in range(numFolds): #load the indices for 5 fold CV to ensure consistency
    #build training and validation data
    trainFileName = f"fold_{i}_train_indices.arr"
    testFileName = f"fold_{i}_test_indices.arr"
    trainFilePath = os.path.join(idxFolder, trainFileName)
    testFilePath = os.path.join(idxFolder, testFileName)
    
    archiveTrain = np.load(trainFilePath, allow_pickle=True)
    archiveTest = np.load(testFilePath, allow_pickle=True)

    pklTrain = archiveTrain[f'fold_{i}_train_indices/data.pkl']
    pklTest = archiveTest[f'fold_{i}_test_indices/data.pkl']

    trainIdxs = pickle.loads(pklTrain)
    testIdxs = pickle.loads(pklTest)

    train_dataset = tf.data.Dataset.from_tensor_slices((features[trainIdxs], {head_name: labels[trainIdxs]}))
    train_dataset = train_dataset.shuffle(len(trainIdxs)).batch(32).prefetch(tf.data.AUTOTUNE)
    test_dataset = tf.data.Dataset.from_tensor_slices((features[testIdxs], {head_name: labels[testIdxs]}))
    test_dataset = test_dataset.batch(32).prefetch(tf.data.AUTOTUNE)

    train_dataset = train_dataset.map(format_output)
    test_dataset = test_dataset.map(format_output)

    #generate model
    tf.keras.backend.clear_session()
    foundationMLP = create_model([head_name], token_num=32, embeddings_size=128, dropout=0.3, hidden_layer_sizes=[256, 128], seed=SEED)

    history = foundationMLP.fit(
        x=train_dataset,
        validation_data=test_dataset,
        epochs=100,
        verbose=1
    )

    raw_probs = foundationMLP.predict(test_dataset)
    history.history['int_val_probs'] = raw_probs
    history.history['int_val_gts'] = labels[testIdxs]
    history.history['int_val_indices'] = testIdxs

    foldHistories.append(history)

    saveWeightPath = os.path.join(base_dir, 'foundationCXRMLP', f'{i}_CXRWeights.weights.h5')
    foundationMLP.save_weights(saveWeightPath)


Epoch 1/100
[1m19/19[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 290ms/step - auc_pr: 0.4537 - auc_roc: 0.4025 - loss: 1.2418 - max_f1_score: 0.6663 - val_auc_pr: 0.5500 - val_auc_roc: 0.5622 - val_loss: 0.7705 - val_max_f1_score: 0.6636
Epoch 2/100
[1m19/19[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step - auc_pr: 0.5720 - auc_roc: 0.5372 - loss: 0.9589 - max_f1_score: 0.6782 - val_auc_pr: 0.6807 - val_auc_roc: 0.7124 - val_loss: 0.7766 - val_max_f1_score: 0.7135
Epoch 3/100
[1m19/19[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step - auc_pr: 0.6503 - auc_roc: 0.6662 - loss: 0.8196 - max_f1_score: 0.6994 - val_auc_pr: 0.7304 - val_auc_roc: 0.7752 - val_loss: 0.7912 - val_max_f1_score: 0.7286
Epoch 4/100
[1m19/19[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step - auc_pr: 0.6609 - auc_roc: 0.6498 - loss: 0.8405 - max_f1_score: 0.6849 - val_auc_pr: 0.7498 - val_auc_roc: 0.8071 - val_loss: 0.8115 - val_max_f1_score: 0.7692
Epoch 5/10




[1m19/19[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 516ms/step - auc_pr: 0.5091 - auc_roc: 0.4523 - loss: 1.1464 - max_f1_score: 0.6689 - val_auc_pr: 0.4598 - val_auc_roc: 0.3959 - val_loss: 0.7754 - val_max_f1_score: 0.6667
Epoch 2/100
[1m19/19[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step - auc_pr: 0.5761 - auc_roc: 0.5931 - loss: 0.9017 - max_f1_score: 0.6681 - val_auc_pr: 0.5998 - val_auc_roc: 0.5540 - val_loss: 0.7821 - val_max_f1_score: 0.6728
Epoch 3/100
[1m19/19[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step - auc_pr: 0.6434 - auc_roc: 0.6627 - loss: 0.8187 - max_f1_score: 0.6394 - val_auc_pr: 0.6872 - val_auc_roc: 0.6791 - val_loss: 0.7977 - val_max_f1_score: 0.6800
Epoch 4/100
[1m19/19[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step - auc_pr: 0.6376 - auc_roc: 0.6740 - loss: 0.8304 - max_f1_score: 0.6859 - val_auc_pr: 0.7478 - val_auc_roc: 0.7315 - val_loss: 0.8172 - val_max_f1_score: 0.6970
Epoch 5/100
[1m19/19

In [7]:
toSaveHistory = False
if toSaveHistory:
    clean_histories = [h.history for h in foldHistories]
    
    save_path = os.path.join(base_dir, 'cxrfoundation_train_int_val_fold_histories.pkl')
    
    with open(save_path, 'wb') as f:
        pickle.dump(clean_histories, f)
    
    print(f"Successfully saved histories to: {save_path}")

In [8]:
import pickle
loadHistory = True

if loadHistory:
    load_path = os.path.join(base_dir, 'cxrfoundation_train_int_val_fold_histories.pkl')
    
    with open(load_path, 'rb') as f:
        loaded_histories = pickle.load(f)
    
    # Access data just like before (but now it's a list of dicts, not History objects)
    fold_0_data = loaded_histories[0]
    print(loaded_histories[0].keys())

dict_keys(['auc_pr', 'auc_roc', 'loss', 'max_f1_score', 'val_auc_pr', 'val_auc_roc', 'val_loss', 'val_max_f1_score', 'int_val_probs', 'int_val_gts', 'int_val_indices'])


In [74]:
def max_f1_score(y_true, probs): #fixed to be specific to each fold; takes GT and post-sigmoid confidences as probs
    max_f1 = 0
    best_thresh = 0

    for thresh in np.arange(0.01, 1.01, 0.01):
        y_pred = (probs.flatten() > thresh).astype(int)

        local_f1 = f1_score(y_true, y_pred)
        if local_f1 > max_f1:
            max_f1 = local_f1
            best_thresh = thresh

    return max_f1, best_thresh #maximum score, threshold

In [79]:
foundationF1s = []
foundationAccs = []

for i in range(5):
    # 1. Get raw data
    raw_gts = loaded_histories[i]['int_val_gts']
    raw_probs = loaded_histories[i]['int_val_probs']['repeatNeeded']
    
    # 2. Call function (This was already working because you flattened inside the call)
    max_f1, best_thresh = max_f1_score(raw_gts, raw_probs)
    foundationF1s.append(max_f1)
    
    # 3. CRITICAL FIX HERE: Flatten both arrays before comparison
    # Without .flatten(), (N, 1) > threshold creates (N, 1)
    # And if gts is (N,), comparison becomes (N, N) matrix broadcasting!
    foundPreds = (raw_probs.flatten() > best_thresh).astype(int)
    gts_flat = raw_gts.flatten().astype(int)
    
    # 4. Calculate Accuracy
    foundAccs = (foundPreds == gts_flat).mean()
    foundationAccs.append(foundAccs)

foundationF1s = np.array(foundationF1s)
foundationAccs = np.array(foundationAccs)

print(f"Foundation F1s: , {foundationF1s.mean():.3f}, {foundationF1s.std():.3f}")
print(f"Foundation Accs:, {foundationAccs.mean():.3f}, {foundationAccs.std():.3f}")

Foundation F1s: , 0.770, 0.017
Foundation Accs:, 0.764, 0.021


In [9]:
print("NOTE: does not work with history objects. please save and load back in as pkl.")
def val_metrics(y_true, probs, target_value=0.90): # sensitivity at specificity and vice versa and PR/ROC AUCs

    fpr, tpr, roc_thresh = roc_curve(y_true, probs)
    precision, recall, pr_thresh = precision_recall_curve(y_true, probs)
    rocAUC = roc_auc_score(y_true, probs)
    prAUC = auc(recall[::-1], precision[::-1]) #need to flip precision and recall to calculate auc
    
    target_fpr = 1 - target_value
    calculated_sensitivity = np.interp(target_fpr, fpr, tpr)

    # Ensure monotonicity for finding specificity at sensitivity
    tpr_monotonic, idx = np.unique(tpr, return_index=True)
    fpr_at_tpr = fpr[idx]
    fpr_at_target = np.interp(target_value, tpr_monotonic, fpr_at_tpr)
    calculated_specificity = 1 - fpr_at_target

    return rocAUC, prAUC, calculated_sensitivity, calculated_specificity

foundationMetrics = {'roc auc':[], 'pr auc':[], 'sens@spec90':[], 'spec@sens90':[]}
for fold in range(len(loaded_histories)):
    rocAUC, prAUC, calculated_sensitivity, calculated_specificity = val_metrics(loaded_histories[fold]['int_val_gts'],loaded_histories[fold]['int_val_probs']['repeatNeeded'])
    foundationMetrics['roc auc'].append(rocAUC)
    foundationMetrics['pr auc'].append(prAUC)
    foundationMetrics['sens@spec90'].append(calculated_sensitivity)
    foundationMetrics['spec@sens90'].append(calculated_specificity)
    
foundationMetrics['roc auc'] = np.array(foundationMetrics['roc auc'])
foundationMetrics['pr auc'] = np.array(foundationMetrics['pr auc'])
foundationMetrics['sens@spec90'] = np.array(foundationMetrics['sens@spec90'])
foundationMetrics['spec@sens90'] = np.array(foundationMetrics['spec@sens90'])

print(f"Foundation CXR MLP| "
      f"avg roc auc: {foundationMetrics['roc auc'].mean():.3f}±{foundationMetrics['roc auc'].std():.3f}| "
      f"avg pr auc: {foundationMetrics['pr auc'].mean():.3f}±{foundationMetrics['pr auc'].std():.3f}| "
      f"avg sens@spec90: {foundationMetrics['sens@spec90'].mean():.3f}±{foundationMetrics['sens@spec90'].std():.3f}| "
      f"avg spec@sens90: {foundationMetrics['spec@sens90'].mean():.3f}±{foundationMetrics['spec@sens90'].std():.3f}")


NOTE: does not work with history objects. please save and load back in as pkl.
Foundation CXR MLP| avg roc auc: 0.824±0.017| avg pr auc: 0.830±0.028| avg sens@spec90: 0.565±0.036| avg spec@sens90: 0.476±0.065


# External Validation Testing (Zero-Shot)

In [10]:
labels = np.load(r"/extra/xielab0/wuat2/AryaQualityViewProjectData/ExternalValData/labels.npy")
file_paths = np.load(r"/extra/xielab0/wuat2/AryaQualityViewProjectData/ExternalValData/file_paths.npy")

In [11]:
extFeatures = []
extLabels = [] #should be exactly the same order as the original pytorch dataset

head_name = 'repeatNeeded'
for imgIdx in range(len(file_paths)):
    imageName = os.path.basename(file_paths[imgIdx])
    saveName = imageName.replace('.png', '.npy').replace('/', '_')
    embeddedDataPath = os.path.join(ext_encoded_dir, saveName)
    if os.path.exists(embeddedDataPath):
        emb = np.load(embeddedDataPath)
        extFeatures.append(emb.flatten())
        extLabels.append(dataset.labels[imgIdx])
    else:
        print(f'data file {embeddedDataPath} cannot be found. Please revise loop logic.')

extFeatures = np.array(extFeatures).astype('float32')
extLabels = np.array(extLabels).astype('float32')

In [13]:
SEED = 9999 #not needed right now since loading from saved indices; here for consistency
BATCH_SIZE = 32
EPOCHS = 100
numFolds = 5

def format_output(x, y):
    y_reshaped = {k: tf.expand_dims(v, axis=-1) for k, v in y.items()}
    return x, y_reshaped
    
foundCXRLogits = np.zeros((len(extLabels),1), dtype=np.float32)
foldLogits = []
ext_val_dataset = tf.data.Dataset.from_tensor_slices((extFeatures, {head_name: extLabels}))
ext_val_dataset = ext_val_dataset.map(format_output)
ext_val_dataset = ext_val_dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
for i in range(numFolds):
    #generate model
    tf.keras.backend.clear_session()
    #use no activation --> can combine logits later to use sigmoid of average like what was done in pytorch models
    
    foundationMLP = create_model([head_name], token_num=32, embeddings_size=128, dropout=0.3, hidden_layer_sizes=[256, 128], activation=None, loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),seed=SEED)

    weightPath = os.path.join(base_dir, 'foundationCXRMLP', f'{i}_CXRWeights.weights.h5')
    foundationMLP.load_weights(weightPath, skip_mismatch=True)

    raw_logits = foundationMLP.predict(ext_val_dataset)[head_name]
    foundCXRLogits = foundCXRLogits + raw_logits
    foldLogits.append(raw_logits)
    
foundCXRLogits = (foundCXRLogits/5)
foundCXRProbs = expit(foundCXRLogits)

    # print(history.history.keys())
foldSimilarity = []
for i in range(5): #correlation analysis between folds
    for j in range(i+1, 5):
        p_i = expit(foldLogits[i]).flatten()
        p_j = expit(foldLogits[j]).flatten()
        foldSimilarity.append(np.corrcoef(p_i, p_j)[0,1])
foldSimilarity = np.array(foldSimilarity)

print(f"Fold corr coeff -- mean: {foldSimilarity.mean():.3f}| std dev: {foldSimilarity.std():.3f}")
rocAUC, prAUC, sensAtSpec, specAtSens = val_metrics(extLabels, foundCXRProbs, target_value=0.90)

print(f"""Model name: CXR Foundation MLP \n
Ext Test ROC AUC: {rocAUC:.3f}
Ext Test PR AUC: {prAUC:.3f}
Ext Test sens @spec90%: {sensAtSpec:.3f}
Ext Test spec @sens90%: {specAtSens:.3f}
                ----------""")

[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 109ms/step
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 114ms/step
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 113ms/step
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 139ms/step
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 113ms/step
Fold corr coeff -- mean: 0.884| std dev: 0.040
Model name: CXR Foundation MLP 

Ext Test ROC AUC: 0.527
Ext Test PR AUC: 0.413
Ext Test sens @spec90%: 0.032
Ext Test spec @sens90%: 0.113
                ----------


In [14]:
try:
    print(f"""Bootstrapping variable already set (value={needBootStrapping}). Are you sure you didn't already run this?
bootstrap indices are stored in 'bootStrapIdxs'.""")

except:
    needBootStrapping = True
    rng = np.random.default_rng(seed=SEED)
    B = 5000
    bootStrapIdxs = []
    extValSize = len(extLabels)
    
    if needBootStrapping:
        print("""setting random seed and bootstrapping indices. For reproducibility, only run this ONCE. 
If accidentally ran again, restart kernel and try again.""")
        for _ in range(B):
            idx = np.random.randint(0, extValSize, extValSize)
            while True:
                if len(np.unique(extLabels[idx])) == 2:
                    break #if samples only have one label, try again
                idx = np.random.randint(0, extValSize, extValSize)
            bootStrapIdxs.append(idx)
        print("Bootstrapping indices set.")
    else:
        print("Bootstrapping not enabled. Remember, only run this ONCE in script for reproducibility.")
    needBootstrapping = False

setting random seed and bootstrapping indices. For reproducibility, only run this ONCE. 
If accidentally ran again, restart kernel and try again.
Bootstrapping indices set.


In [16]:
metricNames = ["rocAUC","prAUC","sensAtSpec","specAtSens"]
bootMetrics = pd.DataFrame(columns=metricNames)
print(f"----------\nStarting Bootstrapping:")
for bootNum in range(B):
    bootIdx = bootStrapIdxs[bootNum] #indices for this bootstrap
    bootLabels = extLabels[bootIdx]
    bootProbs = np.zeros(len(bootLabels))
    for fold in range(5):
        bootProbs = bootProbs + foldLogits[fold].flatten()[bootIdx]
    bootProbs = expit(bootProbs/5) #average of the five folds
    
    rocAUC, prAUC, sensAtSpec, specAtSens = val_metrics(bootLabels, bootProbs, target_value=0.90)
    currMetrics = pd.DataFrame([{"rocAUC": rocAUC, "prAUC": prAUC, "sensAtSpec": sensAtSpec, "specAtSens": specAtSens}])
    bootMetrics = pd.concat([bootMetrics, currMetrics], ignore_index=True)
# print(bootMetrics)
CIs = pd.DataFrame(columns=metricNames)
for metric in metricNames:
    lower = np.percentile(bootMetrics[metric], 2.5)
    upper = np.percentile(bootMetrics[metric], 97.5)
    CIs[metric] = [lower, upper]
    print(f"{metric} CIs (2.5,97.5): {lower:.3f}, {upper:.3f}\n----------")

----------
Starting Bootstrapping:


  bootMetrics = pd.concat([bootMetrics, currMetrics], ignore_index=True)


rocAUC CIs (2.5,97.5): 0.435, 0.621
----------
prAUC CIs (2.5,97.5): 0.322, 0.521
----------
sensAtSpec CIs (2.5,97.5): 0.000, 0.164
----------
specAtSens CIs (2.5,97.5): 0.000, 0.359
----------


# External Validation Testing (Few-Shot)

In [21]:
# extValSet; fewshot_idx shows paths
# fewShotResults = torch.load(r'/home/wuat2/xray-quality/external_validation_fewShotResults.pt')
# domainAdaptIdxs = fewShotResults[0]['fewshot_idx']
# textDAIdxs = fewShotResults[0]['test_idx']

labels = np.load(r"/extra/xielab0/wuat2/AryaQualityViewProjectData/ExternalValData/labels.npy")
file_paths = np.load(r"/extra/xielab0/wuat2/AryaQualityViewProjectData/ExternalValData/file_paths.npy")

extTuneImgPaths = torch.load(r'/home/wuat2/xray-quality/external_validation_fewShotResults.pt')[0]['fewshot_idx']
extTuneLabels = torch.load(r'/home/wuat2/xray-quality/external_validation_fewShotResults.pt')[0]['fewshot_labels']

extTestImgPaths = torch.load(r'/home/wuat2/xray-quality/external_validation_fewShotResults.pt')[0]['test_idx']
extTestLabels = torch.load(r'/home/wuat2/xray-quality/external_validation_fewShotResults.pt')[0]['test_labels']

In [40]:
fineTuneFeatures = []
fineTuneLabels = []
fewShowTestFeatures = []
fewShotTestLabels = []

head_name = 'repeatNeeded'
for imgIdx in range(len(extTuneImgPaths)):
    imageName = os.path.basename(extTuneImgPaths[imgIdx])
    saveName = imageName.replace('.png', '.npy').replace('/', '_')
    embeddedDataPath = os.path.join(ext_encoded_dir, saveName)
    if os.path.exists(embeddedDataPath):
        emb = np.load(embeddedDataPath)
        fineTuneFeatures.append(emb.flatten())
        fineTuneLabels.append(extTuneLabels[imgIdx])
    else:
        print(f'data file {embeddedDataPath} cannot be found. Please revise loop logic.')
        
fineTuneFeatures = np.array(fineTuneFeatures).astype('float32')
fineTuneLabels = np.array(fineTuneLabels).astype('float32')

for imgIdx in range(len(extTestImgPaths)):
    imageName = os.path.basename(extTestImgPaths[imgIdx])
    saveName = imageName.replace('.png', '.npy').replace('/', '_')
    embeddedDataPath = os.path.join(ext_encoded_dir, saveName)
    if os.path.exists(embeddedDataPath):
        emb = np.load(embeddedDataPath)
        fewShowTestFeatures.append(emb.flatten())
        fewShotTestLabels.append(extTestLabels[imgIdx])
    else:
        print(f'data file {embeddedDataPath} cannot be found. Please revise loop logic.')

fewShowTestFeatures = np.array(fewShowTestFeatures).astype('float32')
fewShotTestLabels = np.array(fewShotTestLabels).astype('float32')

In [54]:
# uses a smaller LR and smaller weight decay due to small MLP; original causes it to break
EPOCHS = 30
BATCH_SIZE = 32
LR = 1e-3
WEIGHT_DECAY = 1e-4
POS_WEIGHT = 4.0
SEED = 9999

fewShotData = torch.load(r'/home/wuat2/xray-quality/external_validation_fewShotResults.pt')[0]

print(f"Preparing Few-Shot Data: {len(fineTuneLabels)} training, {len(fewShotTestLabels)} testing.")

def format_example(x, y):
    return x, tf.expand_dims(y, -1)

train_ds = tf.data.Dataset.from_tensor_slices((fineTuneFeatures, fineTuneLabels))
train_ds = train_ds.shuffle(len(fineTuneFeatures)).batch(BATCH_SIZE).map(format_example).prefetch(tf.data.AUTOTUNE)

test_ds = tf.data.Dataset.from_tensor_slices((fewShowTestFeatures, fewShotTestLabels))
test_ds = test_ds.batch(BATCH_SIZE).map(format_example).prefetch(tf.data.AUTOTUNE)

tf.keras.utils.set_random_seed(SEED)
tf.keras.backend.clear_session()

foldLogits = []
fineTunedLogits = np.zeros((len(fewShotTestLabels),1), dtype=np.float32)
for fold in range(5):
    finetune_model = create_model(
        [head_name], 
        token_num=32, 
        embeddings_size=128, 
        dropout=0.3, 
        hidden_layer_sizes=[256, 128], 
        activation=None, # Logits output
        seed=SEED
    )
    weightPath = os.path.join(base_dir, 'foundationCXRMLP', f'{fold}_CXRWeights.weights.h5')
    finetune_model.load_weights(weightPath, skip_mismatch=True)
    
    def weighted_bce_loss(y_true, y_pred):
        # y_pred are logits
        return tf.nn.weighted_cross_entropy_with_logits(
            labels=tf.cast(y_true, tf.float32),
            logits=y_pred,
            pos_weight=POS_WEIGHT,
        )
    
    optimizer = tf.keras.optimizers.AdamW(
        learning_rate=LR, 
        weight_decay=WEIGHT_DECAY
    )
    
    finetune_model.compile(
        optimizer=optimizer,
        loss=weighted_bce_loss,
        metrics=[
            tf.keras.metrics.AUC(name='auc', from_logits=True),
            tf.keras.metrics.BinaryAccuracy(name='acc', threshold=0.0) # 0.0 is threshold for logits
        ]
    )
    
    # --- 4. Training Loop ---
    print("Starting Fine-Tuning...")
    history = finetune_model.fit(
        train_ds,
        validation_data=test_ds,
        epochs=EPOCHS,
        verbose=1
    )
    
    # --- 5. Final Evaluation ---
    print("Evaluating on Test Set...")
    # Predict logits
    pred_logits = finetune_model.predict(test_ds)
    # Handle dict output if your model returns {head_name: logits}
    if isinstance(pred_logits, dict):
        pred_logits = pred_logits[head_name]

    fineTunedLogits = fineTunedLogits + pred_logits
    foldLogits.append(fineTunedLogits)
    
fineTunedLogits = fineTunedLogits/5
fewShotFoundProbs = tf.nn.sigmoid(fineTunedLogits).numpy().flatten()

foldSimilarity = []
for i in range(5): #correlation analysis between folds
    for j in range(i+1, 5):
        p_i = expit(foldLogits[i]).flatten()
        p_j = expit(foldLogits[j]).flatten()
        foldSimilarity.append(np.corrcoef(p_i, p_j)[0,1])
foldSimilarity = np.array(foldSimilarity)

print(f"Fold corr coeff -- mean: {foldSimilarity.mean():.3f}| std dev: {foldSimilarity.std():.3f}")
rocAUC, prAUC, sens, spec = val_metrics(fewShotTestLabels, fewShotFoundProbs, target_value=0.90)

print("="*40)
print(f"Few-Shot Results ({len(fineTuneLabels)} shots):")
print(f"ROC AUC: {rocAUC:.3f}")
print(f"PR AUC:  {prAUC:.3f}")
print(f"Sens@Spec90: {sens:.3f}")
print(f"Spec@Sens90: {spec:.3f}")
print("="*40)

Preparing Few-Shot Data: 50 training, 100 testing.
Starting Fine-Tuning...
Epoch 1/30
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 4s/step - acc: 0.6692 - auc: 0.7213 - loss: 1.3509 - val_acc: 0.6500 - val_auc: 0.4297 - val_loss: 2.0834
Epoch 2/30
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 42ms/step - acc: 0.6112 - auc: 0.6726 - loss: 1.1780 - val_acc: 0.7200 - val_auc: 0.4891 - val_loss: 2.0562
Epoch 3/30
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 40ms/step - acc: 0.6646 - auc: 0.7815 - loss: 1.0479 - val_acc: 0.7800 - val_auc: 0.5369 - val_loss: 2.1103
Epoch 4/30
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 41ms/step - acc: 0.7508 - auc: 0.8336 - loss: 0.9391 - val_acc: 0.7800 - val_auc: 0.5756 - val_loss: 2.1641
Epoch 5/30
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 40ms/step - acc: 0.7404 - auc: 0.9259 - loss: 0.7214 - val_acc: 0.7800 - val_auc: 0.5925 - val_loss: 2.2172
Epoch 6/30
[1m2/2[0m

In [65]:
np.save('foundationZeroShotPreds.npy', foundCXRProbs)
np.save('foundationFewShotPreds.npy', fewShotFoundProbs)
np.save('foundationZeroShotGTs.npy', extLabels)
np.save('foundationFewShotGTs.npy', fewShotTestLabels)

In [56]:
try:
    print(f"""Bootstrapping variable already set (value={needExtValBootStrapping}). Are you sure you didn't already run this?
bootstrap indices are stored in 'bootExtValStrapIdxs'.""")

except:
    needExtValBootStrapping = True
    rng = np.random.default_rng(seed=SEED)
    B = 5000
    bootExtValStrapIdxs = []
    extValSize = len(fewShotTestLabels)
    
    if needExtValBootStrapping:
        print("""setting random seed and bootstrapping indices. For reproducibility, only run this ONCE. 
If accidentally ran again, restart kernel and try again.""")
        for _ in range(B):
            idx = np.random.randint(0, extValSize, extValSize)
            while True:
                if len(np.unique(fewShotTestLabels[idx])) == 2:
                    break #if samples only have one label, try again
                idx = np.random.randint(0, extValSize, extValSize)
            bootExtValStrapIdxs.append(idx)
        print("needExtValBootStrapping indices set.")
    else:
        print("needExtValBootStrapping not enabled. Remember, only run this ONCE in script for reproducibility.")
    needExtValBootStrapping = False

setting random seed and bootstrapping indices. For reproducibility, only run this ONCE. 
If accidentally ran again, restart kernel and try again.
Bootstrapping indices set.


In [58]:
metricNames = ["rocAUC","prAUC","sensAtSpec","specAtSens"]
bootMetrics = pd.DataFrame(columns=metricNames)
print(f"----------\nStarting Bootstrapping:")
for bootNum in range(B):
    bootIdx = bootExtValStrapIdxs[bootNum] #indices for this bootstrap
    bootLabels = fewShotTestLabels[bootIdx]
    bootProbs = np.zeros(len(bootLabels))
    for fold in range(5):
        bootProbs = bootProbs + foldLogits[fold].flatten()[bootIdx]
    bootProbs = expit(bootProbs/5) #average of the five folds
    
    rocAUC, prAUC, sensAtSpec, specAtSens = val_metrics(bootLabels, bootProbs, target_value=0.90)
    currMetrics = pd.DataFrame([{"rocAUC": rocAUC, "prAUC": prAUC, "sensAtSpec": sensAtSpec, "specAtSens": specAtSens}])
    bootMetrics = pd.concat([bootMetrics, currMetrics], ignore_index=True)
# print(bootMetrics)
CIs = pd.DataFrame(columns=metricNames)
for metric in metricNames:
    lower = np.percentile(bootMetrics[metric], 2.5)
    upper = np.percentile(bootMetrics[metric], 97.5)
    CIs[metric] = [lower, upper]
    print(f"{metric} CIs (2.5,97.5): {lower:.3f}, {upper:.3f}\n----------")

----------
Starting Bootstrapping:


  bootMetrics = pd.concat([bootMetrics, currMetrics], ignore_index=True)


rocAUC CIs (2.5,97.5): 0.460, 0.739
----------
prAUC CIs (2.5,97.5): 0.142, 0.446
----------
sensAtSpec CIs (2.5,97.5): 0.000, 0.435
----------
specAtSens CIs (2.5,97.5): 0.150, 0.462
----------
