### generate candidate directions

In [6]:
# Dataset whose train/val splits are used below to generate and select the direction.
# Choose one of: "squad", "repliqa", "nq", "musique".
# The name is interpolated into the data folder (data/abstain_aware_prompt/<data_name>)
# and into the run folder (pipeline/runs/<model_name>/<data_name>).
data_name = "squad"

# Model to analyse. Choose one of: "llama3", "gemma3"
# (these are the keys of the `model_path` lookup in the next cell).
model_name = "llama3"

In [None]:
from pipeline.generate_directions import generate_directions
from pipeline.model_utils.model_factory import construct_model_base
from data.load_datasets import load_data

# Dataset location handed to load_data() in the cells below.
path = f'data/abstain_aware_prompt/{data_name}'

# Map the short model name to the identifier given to construct_model_base();
# an unsupported name raises KeyError.
if model_name == "llama3":
    model_path = "meta-llama/meta-llama-3-8b-instruct"
elif model_name == "gemma3":
    model_path = "google/gemma-3-12b-it"
else:
    raise KeyError(model_name)

model_base = construct_model_base(model_path)

In [None]:
# Run folder for this (model, dataset) pair; later cells read and write under it.
dirs_path = f'pipeline/runs/{model_name}/{data_name}'

# Answerable / unanswerable examples of the train and validation splits.
ans_train, unans_train = load_data(path, "train")
ans_val, unans_val = load_data(path, "val")

candidate_directions = generate_directions(
    model_base,
    unans_train,
    ans_train,
    dirs_path,
    batch_size=8,
)

### select unanswerability direction

In [None]:
from pipeline.select_by_steering import select_direction

# Validation split used to choose among the candidate directions.
ans_val, unans_val = load_data(path, "val")

# Returns the position and layer indices of the selected direction plus the direction itself.
pos, layer, best_dir = select_direction(
    model_base,
    unans_val,
    ans_val,
    candidate_directions,
    f'{dirs_path}/select_by_steering',
    batch_size=4,
)


### find threshold

In [None]:
from pipeline.utils.threshold_utils import get_threshold_by_curve
import json

# Read the metadata file under select_by_steering/ so the threshold can be added to it.
with open(f'{dirs_path}/select_by_steering/direction_metadata.json', 'r') as f:
    dir_metadata = json.load(f)

# Direction stored at the selected (position, layer) index.
dir_vector = candidate_directions[pos, layer]
(
    fpr,
    tpr,
    roc_auc,
    best_roc_idx,
    threshold,
) = get_threshold_by_curve(dir_vector, model_base, pos, layer, ans_val, unans_val)
print(f"threshold for {data_name} is {threshold:.2f}")

# The threshold is persisted as a two-decimal string.
dir_metadata.update({'threshold': f"{threshold:.2f}"})
with open(f'{dirs_path}/select_by_steering/direction_metadata.json', 'w') as f:
    json.dump(dir_metadata, f, indent=4)

### classify unanswerability

In [None]:
from evaluate import evaluate_by_projecting

# Choose eval dataset: "squad", "repliqa", "nq", "musique"
eval_data = "musique"
eval_path = f'data/abstain_aware_prompt/{eval_data}'
ans_test, unans_test = load_data(eval_path, "test")

# Evaluate on the test split using the direction and threshold from the cells above.
evaluate_by_projecting(
    f'{dirs_path}/evaluations/eval_on_{eval_data}',
    ans_test,
    unans_test,
    model_base,
    dir_vector,
    pos,
    layer,
    threshold,
)


### threshold calibration

In [None]:
# Recompute the threshold on the validation split of the evaluation dataset.
# NOTE: this overwrites `ans_val`, `unans_val` and `threshold` from the earlier cells.
ans_val, unans_val = load_data(eval_path, "val")

with open(f'{dirs_path}/select_by_steering/direction_metadata.json', 'r') as f:
    dir_metadata = json.load(f)

dir_vector = candidate_directions[pos, layer]
(
    fpr,
    tpr,
    roc_auc,
    best_roc_idx,
    threshold,
) = get_threshold_by_curve(dir_vector, model_base, pos, layer, ans_val, unans_val)
print(f"threshold for {eval_data} is {threshold:.2f}")

# Stored next to the original threshold under a per-dataset key.
dir_metadata.update({f'threshold_on_{eval_data}': f"{threshold:.2f}"})
with open(f'{dirs_path}/select_by_steering/direction_metadata.json', 'w') as f:
    json.dump(dir_metadata, f, indent=4)

In [None]:
# Same evaluation as before, now with the threshold from the calibration cell;
# results go to a separate calibrated_threshold/ folder.
evaluate_by_projecting(
    f'{dirs_path}/evaluations/calibrated_threshold/eval_on_{eval_data}',
    ans_test,
    unans_test,
    model_base,
    dir_vector,
    pos,
    layer,
    threshold,
)

## figures

In [7]:
import os
import json
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

# Display names of the datasets; the lower-cased name is the run-folder name.
datasets = ["SQuAD", "RepLiQA", "NQ", "MuSiQue"]
dataset_order = datasets
model_names = ["llama3", "gemma3"]
methods = ["Direction", "DirectionRefined"]

formal_names = {"llama3": "Llama 3", "gemma3": "Gemma 3"}
method_titles = {
    "Direction": "Direction",
    "DirectionRefined": "Direction (Calibrated Threshold)"
}

# Horizontal figure coordinates of the per-method column titles.
method_x_positions = [0.25, 0.72]

# metric name -> {(model name, method): train-by-eval pivot table}
all_results = {"F1": {}, "Recall": {}}


def collect_result_rows(root, method):
    """Gather the stored evaluation scores for one model and one method.

    Args:
        root: Run folder of a model (``pipeline/runs/<model>``).
        method: ``"Direction"`` reads ``evaluations/``; any other value reads
            ``evaluations/calibrated_threshold/``.

    Returns:
        One dict per (train dataset, eval dataset) pair whose result file
        exists and contains the expected keys. Scores are scaled to percent.
        Missing or malformed result files are skipped.
    """
    rows = []
    for train_ds in datasets:
        for eval_ds in datasets:
            if method == "Direction":
                results_path = os.path.join(
                    root, train_ds.lower(),
                    f"evaluations/eval_on_{eval_ds.lower()}/evaluation_results.json"
                )
            else:
                results_path = os.path.join(
                    root, train_ds.lower(),
                    f"evaluations/calibrated_threshold/eval_on_{eval_ds.lower()}/evaluation_results.json"
                )

            if not os.path.exists(results_path):
                continue

            with open(results_path, "r") as f:
                data = json.load(f)

            try:
                f1 = data["overall"]["f1"] * 100
                recall = data["unanswerable"]["recall"] * 100
            except (KeyError, TypeError):
                # Best effort: a result file without the expected fields is ignored.
                continue

            rows.append({
                "Train Dataset": train_ds,
                "Eval Dataset": eval_ds,
                "F1": f1,
                "Recall": recall,
            })
    return rows


def pivot_metric(rows, metric):
    """Turn result rows into a train-by-eval table of *metric*.

    The table is reindexed to ``dataset_order`` on both axes, so pairs without
    a result appear as NaN.
    """
    return (
        pd.DataFrame(rows)
          .pivot(index="Train Dataset", columns="Eval Dataset", values=metric)
          .reindex(index=dataset_order, columns=dataset_order)
    )


# The loop variable is deliberately not called `model_name` (and no `path`
# variable is assigned here) so that running this cell does not overwrite the
# notebook-level configuration used by the pipeline cells.
for results_model in model_names:
    results_root = f"pipeline/runs/{results_model}"
    for method in methods:
        rows = collect_result_rows(results_root, method)
        if not rows:
            # No evaluations found: pivoting an empty frame would raise, and the
            # plotting code already copes with a missing (model, method) entry.
            continue
        for metric in ("F1", "Recall"):
            all_results[metric][(results_model, method)] = pivot_metric(rows, metric)

def plot_metric_grid(metric_name: str, save_suffix: str):
    """Draw a 2x2 grid of train/eval heatmaps for one metric and save it.

    Grid rows follow ``model_names`` and grid columns follow ``methods``. Each
    cell shows the table stored in ``all_results[metric_name]`` for that
    (model, method) pair; a cell without a stored table is left without a
    heatmap. The figure is written to ``plots/`` as PNG and PDF, then shown.

    Args:
        metric_name: Key into ``all_results``.
        save_suffix: Text inserted at the end of the output file names.
    """
    fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 8.5), sharex=True, sharey=True)

    # One header per method, placed above its column of heatmaps.
    for x_pos, method in zip(method_x_positions, methods):
        fig.text(x_pos, 0.94, method_titles[method], fontsize=18, ha='center')

    for row, model_name in enumerate(model_names):
        for col, method in enumerate(methods):
            ax = axes[row][col]
            table = all_results[metric_name].get((model_name, method))
            if table is not None:
                sns.heatmap(
                    table,
                    ax=ax,
                    cmap="Blues",
                    vmin=0, vmax=100,
                    annot=True,
                    fmt=".1f",
                    annot_kws={"size": 13}
                )

            ax.set_title(formal_names[model_name], fontsize=16)

            # Enlarge tick labels only on the outer edges of the grid.
            if col == 0:
                ax.set_yticklabels(ax.get_yticklabels(), fontsize=14)
            if row == 1:
                ax.set_xticklabels(ax.get_xticklabels(), fontsize=14)

    # Axis labels: bottom row and left column carry a label, the rest are blank.
    for row, axes_row in enumerate(axes):
        for col, ax in enumerate(axes_row):
            if row == 1:
                ax.set_xlabel("Evaluation Dataset", fontsize=16, labelpad=10)
            else:
                ax.set_xlabel("")
            if col == 0:
                ax.set_ylabel("Training Dataset", fontsize=16, labelpad=10)
            else:
                ax.set_ylabel("")

    # Leave the top 7% of the figure free for the method headers.
    plt.tight_layout(rect=[0, 0, 1, 0.93])
    os.makedirs("plots", exist_ok=True)
    plt.savefig(f"plots/methods_by_model_heatmaps_{save_suffix}.png", bbox_inches='tight', dpi=300)
    plt.savefig(f"plots/methods_by_model_heatmaps_{save_suffix}.pdf", format='pdf', bbox_inches='tight', dpi=300)
    plt.show()

### Unanswerable prompts recall heatmap

In [None]:
# Heatmaps of recall on unanswerable prompts.
plot_metric_grid(metric_name="Recall", save_suffix="recall")

### F1 scores heatmap

In [None]:
# Heatmaps of the overall F1 score.
plot_metric_grid(metric_name="F1", save_suffix="f1")

### ROC curves

In [None]:
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from pipeline.utils.threshold_utils import get_threshold_by_curve
from pipeline.model_utils.model_factory import construct_model_base
from data.load_datasets import load_data
import json

def plot_roc(ax, curves, model_label):
    """Plot every ROC curve in *curves* on *ax* and mark its selected point.

    Args:
        ax: Matplotlib axes to draw on.
        curves: Mapping from a curve name to a dict; the keys read here are
            "fpr", "tpr", "auc", "idx" and "thr".
        model_label: Text used as the axes title.
    """
    sns.set(style="whitegrid", rc={"lines.linewidth": 2})
    sns.set_palette(["#7a6bbf", "#ff9c42", "#5cb85c", "#e15759"])

    for curve_name, curve in curves.items():
        sns.lineplot(
            x=curve["fpr"],
            y=curve["tpr"],
            label=f"{curve_name} (AUC = {curve['auc']:.2f})",
            ax=ax,
        )
        # Mark the point on the curve stored under "idx".
        marker_idx = curve["idx"]
        ax.scatter(
            curve["fpr"][marker_idx],
            curve["tpr"][marker_idx],
            s=60,
            marker='o',
            edgecolors='black',
            label=f"{curve_name} threshold = {curve['thr']:.2f}",
        )

    # Diagonal reference line from (0, 0) to (1, 1).
    ax.plot([0, 1], [0, 1], linestyle='--', color='gray', linewidth=1)

    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_title(f'{model_label}', fontsize=20)
    ax.set_xlabel('False Positive Rate', fontsize=18)
    ax.set_ylabel('True Positive Rate', fontsize=18)
    ax.legend(loc='lower right', frameon=True, fontsize=14)
    ax.grid(True, linestyle='--', linewidth=0.5)

# Load the stored ROC data of both models and draw one panel per model.
import os  # imported here because this cell's own imports do not include it

fig, axs = plt.subplots(nrows=2, figsize=(7, 10))

with open("pipeline/runs/llama3/roc_curves_data.json", "r") as f:
    llama_curves = json.load(f)
with open("pipeline/runs/gemma3/roc_curves_data.json", "r") as f:
    gemma_curves = json.load(f)

plot_roc(axs[0], llama_curves, "Llama 3")
plot_roc(axs[1], gemma_curves, "Gemma 3")
# The two panels are stacked, so only the bottom one keeps its x-axis label.
axs[0].set_xlabel("")

plt.tight_layout()
# savefig does not create missing folders; make sure plots/ exists so this
# cell also works when the heatmap cells have not been run first.
os.makedirs("plots", exist_ok=True)
plt.savefig('plots/roc_curves.pdf', format='pdf', dpi=300, bbox_inches='tight')
plt.show()
