In [4]:
import os
import pandas as pd
from databricks import sql
from dotenv import load_dotenv

import numpy as np
import torch
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
from sklearn.manifold import TSNE
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification

# -----------------------
# Config
# -----------------------
MAX_TSNE_SAMPLES = 2000   # the first N test reports are embedded for every model
BATCH_SIZE = 32           # texts per forward pass when embedding
MAX_LENGTH = 120          # tokenizer truncation length
FIG_DIR = "./tsne_figures"
os.makedirs(FIG_DIR, exist_ok=True)

_MODEL_ROOT = "./trained_models"


def _model_config(hf_slug: str, title_prefix: str) -> dict:
    """Return the tokenizer / baseline / contrastive paths for one checkpoint family.

    ``hf_slug`` is the directory-name prefix under ``./trained_models`` (the
    Hugging Face id with ``/`` replaced by ``_``); ``title_prefix`` is the name
    shown in figure titles.
    """
    stem = f"{_MODEL_ROOT}/{hf_slug}"
    return {
        "tokenizer_path": f"{stem}_tokenizer",
        "baseline_path": f"{stem}_baseline",
        "contrastive_path": f"{stem}_contrastive_encoder",
        "title_prefix": title_prefix,
    }


# Model + path config (matches your folder layout)
MODEL_CONFIGS = {
    "Bio_ClinicalBERT": _model_config("emilyalsentzer_Bio_ClinicalBERT", "Bio_ClinicalBERT"),
    "DeBERTa-v3-base": _model_config("microsoft_deberta-v3-base", "DeBERTa-v3-base"),
    "RadBERT": _model_config("zzxslp_RadBERT-RoBERTa-4m", "RadBERT"),
}

# -----------------------
# 1. Load test data
# -----------------------
load_dotenv()

_databricks_host = os.getenv("DATABRICKS_HOST")
if not _databricks_host:
    # Without this check the .replace() below fails with an opaque AttributeError on None.
    raise RuntimeError("DATABRICKS_HOST is not set (add it to the environment or to .env).")

# ORDER BY pins the row order. Without it SQL returns rows in an unspecified
# order, so the "first MAX_TSNE_SAMPLES rows" subset taken below could change
# between runs and the t-SNE figures would not be reproducible.
test_query = """
SELECT subject_id, study_id, findings, impression, label, confidence
FROM workspace.default.mimic_cxr_test_set_label_explanation_consensus_v1
WHERE findings IS NOT NULL
  AND impression IS NOT NULL
  AND label IN ('Normal', 'Abnormal')
ORDER BY subject_id, study_id
"""

connection = sql.connect(
    server_hostname=_databricks_host.replace("https://", "").replace("http://", "").rstrip("/"),
    # Overridable via DATABRICKS_HTTP_PATH; defaults to the previously hard-coded warehouse.
    http_path=os.getenv("DATABRICKS_HTTP_PATH", "/sql/1.0/warehouses/fe659a9780b351a1"),
    access_token=os.getenv("DATABRICKS_TOKEN"),
)
try:
    df_test = pd.read_sql(test_query, connection)
finally:
    # Close the warehouse session even if the query fails.
    connection.close()

# Model input = findings followed by impression.
df_test["Context"] = (
    df_test["findings"].fillna("") + " " + df_test["impression"].fillna("")
).str.strip()
label_map = {"Normal": 0, "Abnormal": 1}
df_test["Result"] = df_test["label"].map(label_map)
# The WHERE clause restricts labels to label_map's keys, so nothing may map to NaN.
assert df_test["Result"].notna().all(), "unexpected label values in test set"

test_texts_full = df_test["Context"].tolist()
test_labels_full = df_test["Result"].tolist()

print("Total test samples:", len(test_texts_full))

# Use the same (deterministic, thanks to ORDER BY) subset for all models
texts = test_texts_full[:MAX_TSNE_SAMPLES]
labels = test_labels_full[:MAX_TSNE_SAMPLES]

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)


# -----------------------
# Helper: embeddings
# -----------------------
def get_sentence_embeddings(texts, labels, model, tokenizer, device,
                            batch_size=BATCH_SIZE, max_length=MAX_LENGTH):
    """Mean-pool the final-layer hidden states of ``model`` into one vector per text.

    Args:
        texts: Sequence of input strings (list, tuple, or anything sliceable).
        labels: Sequence of labels aligned one-to-one with ``texts``; returned
            as a NumPy array in the same order as the embeddings.
        model: Hugging Face model already moved to ``device`` (this notebook
            passes both ``AutoModel`` and ``AutoModelForSequenceClassification``).
        tokenizer: Tokenizer matching ``model``.
        device: Device the encoded batches are moved to.
        batch_size: Number of texts per forward pass.
        max_length: Token limit; longer texts are truncated.

    Returns:
        ``(emb_array, label_array)``: a float32 array of shape ``[N, H]`` and a
        label array of shape ``[N]``.

    Raises:
        ValueError: if ``texts`` is empty or its length differs from ``labels``.
    """
    if len(texts) != len(labels):
        raise ValueError(
            f"texts and labels must have the same length "
            f"(got {len(texts)} and {len(labels)})"
        )
    if len(texts) == 0:
        raise ValueError("texts is empty; nothing to embed")

    model.eval()

    all_embs = []
    all_labs = []

    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            stop = start + batch_size
            # Tokenizers require a list/tuple of str, so normalise e.g. Series slices.
            batch_texts = list(texts[start:stop])
            batch_labels = labels[start:stop]

            enc = tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            ).to(device)

            # Request hidden_states so the same code path works for both the base
            # encoder and the sequence-classification head (which has no
            # last_hidden_state attribute of its own).
            outputs = model(**enc, output_hidden_states=True)
            last_hidden = outputs.hidden_states[-1]  # [B, L, H]

            # Masked mean pooling: padding positions contribute nothing. The mask
            # is cast to the hidden dtype instead of relying on int*float promotion.
            mask = enc["attention_mask"].unsqueeze(-1).to(last_hidden.dtype)  # [B, L, 1]
            summed = (last_hidden * mask).sum(dim=1)                          # [B, H]
            lengths = mask.sum(dim=1).clamp(min=1)                            # [B, 1]
            sent_emb = summed / lengths                                        # [B, H]

            # float32 on CPU so NumPy/scikit-learn get a uniform dtype even for
            # half-precision models.
            all_embs.append(sent_emb.float().cpu())
            all_labs.extend(batch_labels)

    emb_array = torch.cat(all_embs, dim=0).numpy()  # [N, H]
    label_array = np.array(all_labs)
    return emb_array, label_array


# -----------------------
# Helper: run t-SNE + plot
# -----------------------
def tsne_and_save(emb_array, label_array, title, outfile, *,
                  perplexity=30, random_state=42):
    """Project embeddings to 3-D with t-SNE and save a Normal-vs-Abnormal scatter plot.

    Args:
        emb_array: ``[N, H]`` array of sentence embeddings.
        label_array: ``[N]`` array-like of labels; ``0`` is plotted as
            "Normal (0)", ``1`` as "Abnormal (1)"; other values are not drawn.
        title: Figure title, also used in the progress message.
        outfile: PNG path to write; its parent directory is created if missing.
        perplexity: t-SNE perplexity. It is lowered to ``N - 1`` when the
            subset is too small, because scikit-learn requires it to be < N.
        random_state: Seed for the randomly initialised t-SNE layout.

    Returns:
        The ``[N, 3]`` t-SNE coordinates.

    Raises:
        ValueError: if ``label_array`` is not 1-D with ``N`` entries, or ``N < 2``.
    """
    emb_array = np.asarray(emb_array)
    # np.asarray lets callers pass plain lists; ``some_list == 0`` would be a
    # single bool rather than an element-wise mask.
    labels_np = np.asarray(label_array)
    n_samples = emb_array.shape[0]
    if labels_np.ndim != 1 or labels_np.shape[0] != n_samples:
        raise ValueError(
            f"label_array must be 1-D with {n_samples} entries, got shape {labels_np.shape}"
        )
    if n_samples < 2:
        raise ValueError("t-SNE needs at least 2 samples")

    print(f"Running t-SNE for {title} (N={n_samples})...")
    tsne = TSNE(
        n_components=3,
        learning_rate="auto",
        init="random",
        perplexity=min(perplexity, n_samples - 1),
        random_state=random_state,
    )
    emb_3d = tsne.fit_transform(emb_array)  # [N, 3]

    fig = plt.figure(figsize=(8, 7))
    ax = fig.add_subplot(111, projection="3d")
    try:
        # marker=None keeps matplotlib's default marker for the Normal class.
        for value, legend_label, marker in ((0, "Normal (0)", None),
                                            (1, "Abnormal (1)", "^")):
            idx = labels_np == value
            ax.scatter(
                emb_3d[idx, 0],
                emb_3d[idx, 1],
                emb_3d[idx, 2],
                alpha=0.6,
                label=legend_label,
                marker=marker,
                s=10,
            )

        ax.set_title(title)
        ax.set_xlabel("t-SNE Dim 1")
        ax.set_ylabel("t-SNE Dim 2")
        ax.set_zlabel("t-SNE Dim 3")
        ax.legend()
        fig.tight_layout()

        out_dir = os.path.dirname(outfile)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(outfile, dpi=300)
    finally:
        # Release the figure even if plotting/saving fails, so repeated calls in
        # the model loop don't accumulate open figures.
        plt.close(fig)
    print(f"Saved: {outfile}")
    return emb_3d


# -----------------------
# Main loop over models
# -----------------------
def _tsne_for_variant(model_key, cfg, tokenizer, *, tag, model_cls, path_key,
                      load_desc, title_suffix):
    """Load one checkpoint, embed the shared subset, save its t-SNE figure.

    The model is released before t-SNE runs, so at most one checkpoint occupies
    GPU memory at a time. Returns ``(embeddings, labels)``.
    """
    print(f"[{model_key}] Loading {load_desc}...")
    model = model_cls.from_pretrained(
        cfg[path_key],
        local_files_only=True,
    ).to(device)
    try:
        print(f"[{model_key}] Computing {tag} embeddings...")
        emb, lab = get_sentence_embeddings(texts, labels, model, tokenizer, device)
    finally:
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    safe_key = model_key.replace("/", "_").replace(" ", "_")
    outfile = os.path.join(FIG_DIR, f"{safe_key}_{tag}_tsne.png")
    tsne_and_save(emb, lab, f"{cfg['title_prefix']} – {title_suffix}", outfile)
    return emb, lab


for model_key, cfg in MODEL_CONFIGS.items():
    print("=" * 80)
    print(f"Processing model: {model_key}")
    print("=" * 80)

    # Shared tokenizer for baseline + contrastive. local_files_only matches the
    # model loads so a missing local tokenizer fails instead of reaching the Hub.
    tokenizer = AutoTokenizer.from_pretrained(
        cfg["tokenizer_path"],
        use_fast=False,
        local_files_only=True,
    )

    # ---------- Baseline (fine-tuned classifier, pooled hidden states) ----------
    emb_base, lab_base = _tsne_for_variant(
        model_key, cfg, tokenizer,
        tag="baseline",
        model_cls=AutoModelForSequenceClassification,
        path_key="baseline_path",
        load_desc="baseline model",
        title_suffix="Baseline (No Contrastive)",
    )

    # ---------- Contrastive Encoder ----------
    emb_con, lab_con = _tsne_for_variant(
        model_key, cfg, tokenizer,
        tag="contrastive",
        model_cls=AutoModel,
        path_key="contrastive_path",
        load_desc="contrastive encoder",
        title_suffix="With Contrastive Learning",
    )

print("\nAll t-SNE visualizations completed.")


  df_test = pd.read_sql(test_query, connection)


Total test samples: 4472
Using device: cuda
Processing model: Bio_ClinicalBERT
[Bio_ClinicalBERT] Loading baseline model...
[Bio_ClinicalBERT] Computing baseline embeddings...
Running t-SNE for Bio_ClinicalBERT – Baseline (No Contrastive) (N=2000)...
Saved: ./tsne_figures\Bio_ClinicalBERT_baseline_tsne.png
[Bio_ClinicalBERT] Loading contrastive encoder...
[Bio_ClinicalBERT] Computing contrastive embeddings...
Running t-SNE for Bio_ClinicalBERT – With Contrastive Learning (N=2000)...
Saved: ./tsne_figures\Bio_ClinicalBERT_contrastive_tsne.png
Processing model: DeBERTa-v3-base
[DeBERTa-v3-base] Loading baseline model...
[DeBERTa-v3-base] Computing baseline embeddings...
Running t-SNE for DeBERTa-v3-base – Baseline (No Contrastive) (N=2000)...
Saved: ./tsne_figures\DeBERTa-v3-base_baseline_tsne.png
[DeBERTa-v3-base] Loading contrastive encoder...
[DeBERTa-v3-base] Computing contrastive embeddings...
Running t-SNE for DeBERTa-v3-base – With Contrastive Learning (N=2000)...
Saved: ./tsne_fi

In [5]:
import os
from PIL import Image
import matplotlib.pyplot as plt

FIG_DIR = "./tsne_figures"
OUT_FILE = os.path.join(FIG_DIR, "unified_tsne_comparison.png")

# Ordered layout (row-wise): one row per encoder, baseline left, contrastive right.
# Panel titles use the same model names as the per-model figure titles and the
# metrics summary table, so the three artefacts agree.
IMAGE_ORDER = [
    ("Bio_ClinicalBERT_baseline_tsne.png", "Bio_ClinicalBERT – No Contrastive"),
    ("Bio_ClinicalBERT_contrastive_tsne.png", "Bio_ClinicalBERT – With Contrastive"),
    ("DeBERTa-v3-base_baseline_tsne.png", "DeBERTa – No Contrastive"),
    ("DeBERTa-v3-base_contrastive_tsne.png", "DeBERTa – With Contrastive"),
    ("RadBERT_baseline_tsne.png", "RadBERT – No Contrastive"),
    ("RadBERT_contrastive_tsne.png", "RadBERT – With Contrastive"),
]

# Load images
images = []
titles = []

for fname, title in IMAGE_ORDER:
    path = os.path.join(FIG_DIR, fname)
    if not os.path.exists(path):
        raise FileNotFoundError(f"Missing image: {path}")
    # Image.open is lazy and keeps the file handle open; copy() loads the pixels
    # so the file can be closed right away (open handles can lock files on Windows).
    with Image.open(path) as img:
        images.append(img.copy())
    titles.append(title)

# Create figure: 3 rows × 2 columns
fig, axes = plt.subplots(3, 2, figsize=(16, 22))

for ax, img, title in zip(axes.flatten(), images, titles):
    ax.imshow(img)
    ax.set_title(title, fontsize=16)
    ax.axis("off")

fig.suptitle(
    "Latent Space Visualization (t-SNE, Test Set)\nBaseline vs Contrastive Learning",
    fontsize=20,
    y=0.98,
)

fig.tight_layout()
# Leave headroom for the two-line suptitle, which tight_layout does not reserve.
fig.subplots_adjust(top=0.95)
fig.savefig(OUT_FILE, dpi=300)
plt.close(fig)

print(f"\nUnified figure saved to:\n{OUT_FILE}")



Unified figure saved to:
./tsne_figures/unified_tsne_comparison.png


In [6]:
import os
import glob
import pandas as pd

METRICS_DIR = "trained_model_metrics"
OUT_CSV = os.path.join(METRICS_DIR, "summary_table.csv")


def short_model_name(full_name: str) -> str:
    """Return a short display name for a Hugging Face model id.

    Substrings are tested in the order listed below and the first hit wins;
    ids that match none of them are returned unchanged.
    """
    display_names = (
        ("Bio_ClinicalBERT", "Bio_ClinicalBERT-document"),
        ("deberta-v3-base", "DeBERTa-v3-base-document"),
        ("RadBERT-RoBERTa-4m", "RadBERT-4m-document"),
    )
    return next(
        (pretty for needle, pretty in display_names if needle in full_name),
        full_name,
    )

def load_and_augment(path: str) -> pd.DataFrame:
    """Read one per-model metrics CSV and return the display columns for the summary table.

    The CSV must provide ``model_name``, ``method``, ``tn``, ``fp``, ``recall``,
    ``accuracy`` and ``f1`` columns. The rate columns are treated as fractions
    and converted to percentages rounded to 2 decimals.

    Returns:
        DataFrame with columns ``Model``, ``Loss``, ``accuracy``, ``specificity``,
        ``sensitivity``, ``f1``.
    """
    df = pd.read_csv(path)

    # Specificity = TN / (TN + FP)
    df["specificity"] = df["tn"] / (df["tn"] + df["fp"])

    # Sensitivity = recall (TP / (TP + FN))
    df["sensitivity"] = df["recall"]

    # Display labels for model and training objective.
    df["Model"] = df["model_name"].apply(short_model_name)
    # Methods without a friendly name keep their raw value instead of silently
    # becoming NaN (which would make those rows indistinguishable in the table).
    df["Loss"] = df["method"].map({
        "Baseline": "baseline loss",
        "Contrastive_Last_Layer": "contrastive loss",
    }).fillna(df["method"])

    # Convert to percentages and round (like the paper)
    for col in ["accuracy", "specificity", "sensitivity", "f1"]:
        df[col] = (df[col] * 100).round(2)

    return df[["Model", "Loss", "accuracy", "specificity", "sensitivity", "f1"]]

# ---------------------- main ---------------------- #
# sorted() makes the load order (and the log below) independent of the
# filesystem's directory-listing order.
metric_paths = sorted(glob.glob(os.path.join(METRICS_DIR, "*_metrics.csv")))
if not metric_paths:
    # Otherwise pd.concat([]) fails with the unhelpful "No objects to concatenate".
    raise FileNotFoundError(f"No '*_metrics.csv' files found in {METRICS_DIR!r}")

all_dfs = []
for path in metric_paths:
    print(f"Loading {path}...")
    all_dfs.append(load_and_augment(path))

# One row per (model, loss); sorted for a stable table order, with a fresh 0..n-1 index.
summary_df = (
    pd.concat(all_dfs, ignore_index=True)
    .sort_values(["Model", "Loss"])
    .reset_index(drop=True)
)

print("\nSummary table:")
print(summary_df)

summary_df.to_csv(OUT_CSV, index=False)
print(f"\nSaved summary CSV to: {OUT_CSV}")


Loading trained_model_metrics\emilyalsentzer_Bio_ClinicalBERT_metrics.csv...
Loading trained_model_metrics\microsoft_deberta-v3-base_metrics.csv...
Loading trained_model_metrics\zzxslp_RadBERT-RoBERTa-4m_metrics.csv...

Summary table:
                       Model              Loss  accuracy  specificity  \
0  Bio_ClinicalBERT-document     baseline loss     95.59        92.52   
1  Bio_ClinicalBERT-document  contrastive loss     95.48        94.56   
2   DeBERTa-v3-base-document     baseline loss     95.84        92.89   
3   DeBERTa-v3-base-document  contrastive loss     94.99        93.65   
4        RadBERT-4m-document     baseline loss     95.86        93.65   
5        RadBERT-4m-document  contrastive loss     95.73        96.07   

   sensitivity     f1  
0        96.89  96.87  
1        95.87  96.76  
2        97.08  97.05  
3        95.55  96.41  
4        96.79  97.05  
5        95.59  96.92  

Saved summary CSV to: trained_model_metrics\summary_table.csv


In [7]:
# Show the summary table built above; as the cell's last expression it gets
# Jupyter's rich (HTML) rendering rather than print()'s plain text.
summary_df

Unnamed: 0,Model,Loss,accuracy,specificity,sensitivity,f1
0,Bio_ClinicalBERT-document,baseline loss,95.59,92.52,96.89,96.87
1,Bio_ClinicalBERT-document,contrastive loss,95.48,94.56,95.87,96.76
2,DeBERTa-v3-base-document,baseline loss,95.84,92.89,97.08,97.05
3,DeBERTa-v3-base-document,contrastive loss,94.99,93.65,95.55,96.41
4,RadBERT-4m-document,baseline loss,95.86,93.65,96.79,97.05
5,RadBERT-4m-document,contrastive loss,95.73,96.07,95.59,96.92
