In [None]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
from datasets import Dataset, DatasetDict
from transformers import BertModel, AutoTokenizer
from sklearn.metrics import roc_curve, auc, f1_score, accuracy_score, roc_auc_score
import matplotlib.pyplot as plt

# Finding names used both as the CSV label columns (see clean_dataframe) and,
# copied, as the zero-shot candidate labels. Their order fixes the order of
# every packed label vector, score column and per-class metric below.
label_cols = [
    "Enlarged Cardiomediastinum",
    "Cardiomegaly",
    "Lung Opacity",
    "Lung Lesion",
    "Edema",
    "Consolidation",
    "Pneumonia",
    "Atelectasis",
    "Pneumothorax",
    "Pleural Effusion",
    "Pleural Other",
    "Fracture",
    "Support Devices",
]

def clean_dataframe(df: pd.DataFrame, label_columns=None) -> pd.DataFrame:
    """Normalise one report split into a ``text`` column plus a packed ``labels`` column.

    Missing report text becomes the literal string ``"NA"``; missing label
    values become ``0`` and every label column is cast to float. The label
    columns are then packed row-wise into a single ``labels`` column (a list
    of floats, in ``label_columns`` order) and dropped.

    Args:
        df: raw split as read from CSV. It is NOT modified; the cleaning is
            done on a copy.
        label_columns: names of the label columns to pack. Defaults to the
            module-level ``label_cols``.

    Returns:
        A new DataFrame with a fresh RangeIndex. Columns other than the label
        columns are kept, and the packed vectors are stored under ``labels``.
    """
    cols = label_cols if label_columns is None else list(label_columns)
    # Work on a copy so the caller's DataFrame is left untouched.
    out = df.copy()
    out["text"] = out["text"].fillna("NA")
    out[cols] = out[cols].fillna(0).astype(float)
    out["labels"] = out[cols].to_numpy().tolist()
    return out.drop(columns=cols).reset_index(drop=True)

def load_and_clean_all(train_path: str, dev_path: str, test_path: str) -> DatasetDict:
    """Read, clean and bundle the three CSV splits into a ``DatasetDict``.

    Every CSV is read with pandas and passed through ``clean_dataframe``;
    only after all three are cleaned is each one converted with
    ``Dataset.from_pandas`` (index not preserved). The dev split is stored
    under the key ``"validation"``.
    """
    split_paths = {"train": train_path, "validation": dev_path, "test": test_path}
    cleaned = {split: clean_dataframe(pd.read_csv(path)) for split, path in split_paths.items()}
    return DatasetDict(
        {split: Dataset.from_pandas(frame, preserve_index=False) for split, frame in cleaned.items()}
    )

# Absolute CSV paths for the three splits, loaded and cleaned in one go.
# NOTE(review): train_file points at the *test* CSV (identical to test_file).
# The code in this notebook only reads raw_datasets["test"], so the reported
# results are unaffected — confirm whether a separate train CSV was intended.
train_file, dev_file, test_file = (
    "/text_test_data_noimpression.csv",
    "/text_val_data_noimpression.csv",
    "/text_test_data_noimpression.csv",
)
raw_datasets = load_and_clean_all(train_file, dev_file, test_file)

class CLSEmbeddingZeroShot:
    """Zero-shot multi-label scorer built on BERT sentence embeddings.

    Every candidate label string and every input text is encoded by each
    checkpoint in ``model_names``. The score of a (text, label) pair is the
    cosine similarity of their L2-normalised pooled vectors, so scores lie in
    [-1, 1].

    Args:
        model_names: checkpoint names loaded with ``AutoTokenizer`` and
            ``BertModel.from_pretrained``.
        candidate_labels: label strings to score each text against; their
            order defines the score columns.
        device: CUDA device index; a negative value, or no CUDA, selects CPU.
    """

    def __init__(self, model_names, candidate_labels, device=-1):
        use_cuda = device >= 0 and torch.cuda.is_available()
        self.device = torch.device(f"cuda:{device}" if use_cuda else "cpu")
        self.models = {}
        self.tokenizers = {}
        self.label_embs = {}  # model name -> (n_labels, hidden) normalised CPU tensor
        self.candidate_labels = candidate_labels
        for name in model_names:
            tok = AutoTokenizer.from_pretrained(name)
            model = BertModel.from_pretrained(name).to(self.device)
            model.eval()  # inference mode for dropout layers
            self.tokenizers[name] = tok
            self.models[name] = model
            # Label embeddings are computed once per model and reused for every text.
            self.label_embs[name] = self._embed(tok, model, list(candidate_labels))

    @staticmethod
    def _max_length(tok, model):
        """Return the longest token sequence ``model`` can accept from ``tok``.

        ``tok.model_max_length`` may be an unset huge sentinel for some
        checkpoints, so it is capped by the model's position-embedding table
        size. Falls back to 512 when neither value is available.
        """
        limits = [
            getattr(tok, "model_max_length", None),
            getattr(getattr(model, "config", None), "max_position_embeddings", None),
        ]
        limits = [n for n in limits if isinstance(n, int) and n > 0]
        return min(limits) if limits else 512

    @staticmethod
    def _pool(out):
        """Return the pooler output, or the [CLS] hidden state when it is None."""
        # BertModel outputs always *have* a pooler_output attribute; it is None
        # when the pooling layer is absent, so test the value, not hasattr.
        pooled = getattr(out, "pooler_output", None)
        return pooled if pooled is not None else out.last_hidden_state[:, 0, :]

    def _embed(self, tok, model, texts, batch_size=32):
        """Encode ``texts`` in mini-batches; return L2-normalised CPU vectors (len(texts), hidden)."""
        max_length = self._max_length(tok, model)
        chunks = []
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            # truncation + max_length: reports longer than the position table
            # would otherwise make the forward pass fail.
            inputs = tok(
                batch,
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            ).to(self.device)
            with torch.no_grad():
                vec = self._pool(model(**inputs))
            chunks.append(F.normalize(vec, p=2, dim=1).cpu())
        return torch.vstack(chunks)

    def classify_batch(self, texts, batch_size=32):
        """Score every text against every candidate label with each model.

        Args:
            texts: iterable of strings (materialised into a list once).
            batch_size: number of texts encoded per forward pass.

        Returns:
            dict mapping model name -> numpy array of cosine similarities with
            shape (len(texts), len(candidate_labels)).

        Raises:
            ValueError: if ``batch_size`` < 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        texts = list(texts)
        results = {}
        for name, model in self.models.items():
            label_emb = self.label_embs[name]
            if not texts:
                results[name] = np.empty((0, label_emb.shape[0]), dtype=np.float32)
                continue
            text_emb = self._embed(self.tokenizers[name], model, texts, batch_size)
            # Both sides are unit-norm, so the dot product is the cosine similarity.
            results[name] = (text_emb @ label_emb.T).numpy()
        return results

    @staticmethod
    def ensemble_scores(batch_scores, method="average"):
        """Combine per-model score matrices element-wise.

        Args:
            batch_scores: dict of model name -> equally shaped score arrays,
                e.g. the output of ``classify_batch``.
            method: ``"average"`` (mean over models) or ``"max"``.

        Raises:
            ValueError: for an unknown ``method`` or an empty ``batch_scores``.
        """
        if method not in ("average", "max"):
            raise ValueError(f"method must be 'average' or 'max', got {method!r}")
        if not batch_scores:
            raise ValueError("batch_scores is empty")
        stacked = np.stack(list(batch_scores.values()), axis=0)
        return stacked.mean(axis=0) if method == "average" else stacked.max(axis=0)

    @staticmethod
    def evaluate(y_true, y_scores, threshold=0.5):
        """Compute thresholded and ranking metrics for multi-label scores.

        Args:
            y_true: (n_samples, n_labels) ground-truth indicator matrix.
            y_scores: (n_samples, n_labels) scores, columns aligned with y_true.
            threshold: scores >= threshold count as positive predictions.

        Returns:
            dict with ``accuracy`` (sklearn's exact-match accuracy on the
            label matrix), ``macro_f1``, ``micro_f1``, ``macro_auc``,
            ``micro_auc`` and ``auc_class_{i}`` per label. Any AUC that sklearn
            cannot compute (e.g. only one class present) is reported as -1.
        """
        y_true = np.asarray(y_true)
        y_scores = np.asarray(y_scores)
        y_pred = (y_scores >= threshold).astype(int)
        m = {
            'accuracy': accuracy_score(y_true, y_pred),
            'macro_f1': f1_score(y_true, y_pred, average='macro'),
            'micro_f1': f1_score(y_true, y_pred, average='micro'),
        }
        # Computed separately: macro AUC fails if any single label lacks a
        # class, while micro AUC may still be well defined in that case.
        for avg in ('macro', 'micro'):
            try:
                m[f'{avg}_auc'] = roc_auc_score(y_true, y_scores, average=avg)
            except ValueError:
                m[f'{avg}_auc'] = -1
        for i in range(y_true.shape[1]):
            try:
                m[f'auc_class_{i}'] = roc_auc_score(y_true[:, i], y_scores[:, i])
            except ValueError:
                m[f'auc_class_{i}'] = -1
        return m

def _macro_average_roc(y_true, scores):
    """Macro-average the per-label ROC curves of one model.

    Each label's ROC curve is linearly interpolated onto the union of all
    labels' FPR points, and the interpolated TPRs are averaged.

    Labels whose ground truth contains a single class are skipped. For them
    ``roc_curve`` returns NaN rates (with a warning), which would otherwise
    turn the whole averaged curve, and its AUC, into NaN.

    Args:
        y_true: (n_samples, n_labels) ground truth, expected to be binary 0/1.
        scores: (n_samples, n_labels) scores, columns aligned with y_true.

    Returns:
        (fpr_grid, mean_tpr) arrays of equal length.

    Raises:
        ValueError: if no label has both classes present.
    """
    curves = []
    for i in range(y_true.shape[1]):
        if np.unique(y_true[:, i]).size < 2:
            continue
        fpr, tpr, _ = roc_curve(y_true[:, i], scores[:, i])
        curves.append((fpr, tpr))
    if not curves:
        raise ValueError("no label has both positive and negative examples")
    all_fpr = np.unique(np.concatenate([fpr for fpr, _ in curves]))
    mean_tpr = np.mean([np.interp(all_fpr, fpr, tpr) for fpr, tpr in curves], axis=0)
    return all_fpr, mean_tpr


if __name__ == "__main__":
    model_list = [
        "dmis-lab/biobert-v1.1",
        "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext",
        "emilyalsentzer/Bio_ClinicalBERT",
    ]
    # Short display names for the ROC legend; also fixes the plotting order.
    name_map = {
        "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext": "biomedbert",
        "dmis-lab/biobert-v1.1": "biobert",
        "emilyalsentzer/Bio_ClinicalBERT": "clinicalbert",
    }
    candidate_labels = label_cols.copy()

    texts = raw_datasets['test']['text']
    y_true = np.vstack(raw_datasets['test']['labels'])

    device = 0 if torch.cuda.is_available() else -1
    zs = CLSEmbeddingZeroShot(model_list, candidate_labels, device=device)

    batch_scores = zs.classify_batch(texts)
    # Optional ensemble (disabled):
    # ensemble_scores = zs.ensemble_scores(batch_scores, method='average')

    for name, scores in batch_scores.items():
        m = zs.evaluate(y_true, scores)
        print(f"\n模型: {name}\n", m)  # "模型" means "Model"

    # Module-level on purpose: the next notebook cell reads y_score_dict and y_true.
    y_score_dict = dict(batch_scores)
    # y_score_dict["Ensemble"] = ensemble_scores

    plt.figure(figsize=(7, 7))
    for full_name, short_name in name_map.items():
        all_fpr, mean_tpr = _macro_average_roc(y_true, batch_scores[full_name])
        model_auc = auc(all_fpr, mean_tpr)
        plt.plot(all_fpr, mean_tpr, lw=2, label=f"{short_name} (AUC {model_auc:.3f})")

    plt.plot([0, 1], [0, 1], "--", color="gray", lw=1)
    plt.title("Macro-average ROC Curves Across Zero-shot Models")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.legend(loc="lower right")
    plt.grid(True)
    plt.tight_layout()
    plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, average_precision_score

# Which model's per-class metrics to plot, and its display name for the title.
model_key = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"
model_label = "biomedbert"

scores = y_score_dict[model_key]  # shape (n_samples, n_labels), from the previous cell
# Score columns follow candidate_labels == label_cols, so reuse that order
# instead of maintaining a second copy of the list.
class_names = list(label_cols)
n_labels = scores.shape[1]

roc_aucs = []
pr_aucs = []
for i in range(n_labels):
    y_i, s_i = y_true[:, i], scores[:, i]
    if np.unique(y_i).size < 2:
        # Both AUCs are undefined with a single class present (roc_auc_score
        # raises); record NaN so one degenerate label doesn't abort the plot.
        roc_aucs.append(np.nan)
        pr_aucs.append(np.nan)
        continue
    roc_aucs.append(roc_auc_score(y_i, s_i))
    pr_aucs.append(average_precision_score(y_i, s_i))

x = np.arange(n_labels)
width = 0.35

fig, ax = plt.subplots(figsize=(12, 6))
bars1 = ax.bar(x - width/2, roc_aucs, width, label="ROC AUC", color="skyblue")
bars2 = ax.bar(x + width/2, pr_aucs,  width, label="PR AUC",  color="gold")

for bar in (*bars1, *bars2):
    h = bar.get_height()
    # Undefined (NaN) metrics are labelled "n/a" at the baseline.
    text, y_pos = ("n/a", 0.0) if np.isnan(h) else (f"{h:.3f}", h)
    ax.annotate(text,
                xy=(bar.get_x() + bar.get_width()/2, y_pos),
                xytext=(0, 3), textcoords="offset points",
                ha="center", va="bottom", fontsize=8)

ax.set_xticks(x)
ax.set_xticklabels(class_names, rotation=45, ha="right")
ax.set_ylabel("AUC")
ax.set_title(f"Per-class ROC AUC and PR AUC Comparison ({model_label})")
ax.legend()
ax.grid(axis="y", linestyle="--", alpha=0.5)
plt.tight_layout()
plt.show()