### steering effectiveness

In [None]:
from evaluate import generate_and_save_completions_for_dataset
from pipeline.model_utils.model_factory import construct_model_base
from data.load_datasets import load_data
from pipeline.utils.hook_utils import get_activation_addition_input_pre_hook
import gc
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import json

# Sweep the steering coefficient alpha for the selected direction of each
# (model, dataset) pair and collect completions on the validation split for
# answerable and unanswerable prompts. Results are written to a single JSON.
datasets = ["squad", "repliqa", "nq", "musique"]
# -2.0, -1.75, ..., 2.0 (17 values); 0.25 is exact in binary, so both
# endpoints are included.
alphas = np.arange(-2.0, 2.0 + 0.25, 0.25)

model_paths = {
    "llama3": "meta-llama/meta-llama-3-8b-instruct",
    "gemma3": "google/gemma-3-12b-it",
}
output_path = 'analysis/causal_interventions/completions_under_interventions.json'
# Create the output directory before the (long) sweep so a missing directory
# fails immediately instead of after all generations have run.
os.makedirs(os.path.dirname(output_path), exist_ok=True)

all_rows = []
for model_name in ["llama3", "gemma3"]:
    # Model construction does not depend on the dataset, so load it once per
    # model rather than once per (model, dataset) pair.
    model_base = construct_model_base(model_paths[model_name])

    for data_name in datasets:
        print(f"Running on dataset: {data_name}")
        path = f'data/abstain_aware_prompt/{data_name}'
        ans_val, unans_val = load_data(path, "val")

        dirs_path = f'pipeline/runs/{model_name}/{data_name}'
        with open(f'{dirs_path}/select_by_steering/direction_metadata.json', 'r') as f:
            direction_metadata = json.load(f)
        pos = direction_metadata['pos']
        layer = direction_metadata['layer']

        candidate_directions = torch.load(f'{dirs_path}/mean_diffs.pt')
        # NOTE(review): assumes mean_diffs.pt is indexed [pos, layer, ...] -- confirm.
        dir_vector = candidate_directions[pos, layer]

        for alpha in alphas:
            # Plain Python float for the hook coefficient and the JSON output.
            coeff = float(alpha)
            fwd_pre_hooks = [(model_base.model_block_modules[layer], get_activation_addition_input_pre_hook(vector=dir_vector, coeff=coeff))]
            fwd_hooks = []

            ans_completions = generate_and_save_completions_for_dataset(model_base, fwd_pre_hooks, fwd_hooks, ans_val)
            unans_completions = generate_and_save_completions_for_dataset(model_base, fwd_pre_hooks, fwd_hooks, unans_val)

            # NOTE: the completions key differs per row ("ans_completions" vs
            # "unans_completions"); kept as-is in case downstream scripts
            # (presumably eval_causal_interventions.py) rely on it -- confirm
            # before renaming.
            all_rows.append({"model_name": model_name, "alpha": coeff, "ans_completions": ans_completions, "prompt_type": "Answerable", "dataset": data_name.capitalize()})
            all_rows.append({"model_name": model_name, "alpha": coeff, "unans_completions": unans_completions, "prompt_type": "Unanswerable", "dataset": data_name.capitalize()})

    # Release this model before the next one is loaded so two multi-billion
    # parameter models are never held in memory at the same time.
    del model_base
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

with open(output_path, 'w') as f:
    json.dump(all_rows, f, indent=4)

In [3]:
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import json

# Plot abstention rate against steering scale alpha, one panel per dataset.
# Run eval_causal_interventions.py first to produce the input file; each row
# must provide "model", "dataset", "prompt_type", "alpha" and "percentage".

with open('analysis/causal_interventions/abstention_rates_by_alpha.json', 'r') as f:
    all_rows = json.load(f)

df = pd.DataFrame(all_rows)

# Map stored (capitalized) keys to display names, using the datasets'
# official spellings.
df["dataset"] = df["dataset"].replace({
    "Squad": "SQuAD",
    "Repliqa": "RepLiQA",
    "Nq": "NQ",
    "Musique": "MuSiQue"
})
df["model"] = df["model"].replace({
    "llama3": "Llama 3",
    "gemma3": "Gemma 3"
})

# One line (hue) per (model, prompt type) combination.
df["label"] = df["model"] + " – " + df["prompt_type"]

sns.set(style="whitegrid", context="talk")
custom_palette = ["#7a6bbf", "#ff9c42", "#5cb85c", "#e15759"]
sns.set_palette(custom_palette)

g = sns.FacetGrid(
    df,
    col="dataset",
    hue="label",
    col_wrap=2,
    height=3.5,
    aspect=1.2,
    sharey=True,
    margin_titles=False,
)

g.map_dataframe(
    sns.lineplot,
    x="alpha",
    y="percentage",
    marker="o",
)

for ax in g.axes.flat:
    # Reference lines at alpha = 0 (presumably the unsteered model) and alpha = 1.
    ax.axvline(0, linestyle="--", color="gray", linewidth=1)
    ax.axvline(1, linestyle=":", color="black", linewidth=1)

    # Per-panel axis labels are replaced by shared figure-level labels below.
    ax.set_xlabel("")
    ax.set_ylabel("")
    title = ax.get_title().replace("dataset = ", "")
    ax.set_title(title, fontsize=22)
    for spine in ax.spines.values():
        spine.set_visible(True)

# Gather legend entries from every panel, not only the first, so a label that
# is absent from one dataset still appears. Order is first appearance; the
# first handle seen for each label is kept.
legend_entries = {}
for ax in g.axes.flat:
    for handle, label in zip(*ax.get_legend_handles_labels()):
        legend_entries.setdefault(label, handle)

g.fig.legend(
    list(legend_entries.values()),
    list(legend_entries),
    loc="lower center",
    ncol=2,
    bbox_to_anchor=(0.53, -0.21),
    frameon=True,
    fontsize=19,
)

# tight_layout recomputes all subplot margins (a preceding subplots_adjust
# would be overridden); the suptitle and shared axis labels are then placed
# with explicit figure coordinates.
g.tight_layout()

g.fig.suptitle("Direction Scaling Effects Across Datasets", fontsize=26, x=0.53, y=1.04)
g.fig.text(0.53, -0.02, "Direction Scale (α)", ha='center', fontsize=24)
g.fig.text(-0.01, 0.5, "Abstention Rate (%)", va='center', rotation='vertical', fontsize=24)
g.savefig('plots/causal_interventions.pdf', format='pdf', dpi=300, bbox_inches='tight')
plt.show()

### unanswerability score distributions

In [4]:
import torch
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from tqdm import tqdm
from evaluate import project_onto_dom, get_hidden_vector
from data.load_datasets import load_data
from pipeline.model_utils.model_factory import construct_model_base


# Histograms (with KDE) of unanswerability scores -- projections of the hidden
# state at (pos, layer) onto a direction -- for answerable / unanswerable
# validation prompts of each dataset, one panel per direction.
model_path = "meta-llama/meta-llama-3-8b-instruct"
model_name = "llama3"
model_base = construct_model_base(model_path)
layer = 16
pos = -1
# Threshold drawn on each direction's panel; keys must cover dirs_to_run.
thresholds = {'squad': -0.4, 'nq': -0.24}
dirs_to_run = ['squad', 'nq']
palette = sns.color_palette("tab10", 8)

rep_ans_prompts, rep_unans_prompts = load_data('data/abstain_aware_prompt/repliqa', "val")
squad_ans_prompts, squad_unans_prompts = load_data('data/abstain_aware_prompt/squad', "val")
nq_ans_prompts, nq_unans_prompts = load_data('data/abstain_aware_prompt/nq', "val")
musique_ans_prompts, musique_unans_prompts = load_data('data/abstain_aware_prompt/musique', "val")

# (legend label, prompts, colour) for every histogram drawn on each panel.
prompt_groups = [
    ('RepLiQA Answerable', rep_ans_prompts, palette[0]),
    ('RepLiQA Unanswerable', rep_unans_prompts, palette[1]),
    ('SQuAD Answerable', squad_ans_prompts, palette[2]),
    ('SQuAD Unanswerable', squad_unans_prompts, palette[3]),
    ('NQ Answerable', nq_ans_prompts, palette[4]),
    ('NQ Unanswerable', nq_unans_prompts, palette[5]),
    ('MuSiQue Answerable', musique_ans_prompts, palette[6]),
    ('MuSiQue Unanswerable', musique_unans_prompts, palette[7]),
]

# Load every direction up front, cast to float32 for the projection.
dir_vectors = {}
for dir_data in dirs_to_run:
    dirs_path = f'pipeline/runs/{model_name}/{dir_data}'
    candidate_directions = torch.load(f'{dirs_path}/mean_diffs.pt')
    dir_vectors[dir_data] = candidate_directions[pos, layer].to(dtype=torch.float32)

# get_hidden_vector does not take the direction, so run it once per prompt and
# project the result onto every direction, instead of repeating the forward
# pass per direction. all_scores[dir_data][label] -> list of scores.
all_scores = {dir_data: {} for dir_data in dirs_to_run}
for label, prompts, _ in prompt_groups:
    for dir_data in dirs_to_run:
        all_scores[dir_data][label] = []
    for prompt in tqdm(prompts, desc=f"Processing {label}"):
        hidden = get_hidden_vector(prompt, model_base, pos, layer)
        for dir_data, dir_vector in dir_vectors.items():
            all_scores[dir_data][label].append(project_onto_dom(hidden, dir_vector).item())

n_dirs = len(dirs_to_run)
# squeeze=False keeps a 2-D axes array even when only one direction is plotted.
fig, axes = plt.subplots(n_dirs, 1, figsize=(9, 7 * n_dirs), sharex=True, squeeze=False)

for ax, dir_data in zip(axes[:, 0], dirs_to_run):
    for label, _, color in prompt_groups:
        # stat="density" normalises each group on its own, so groups of
        # different sizes are comparable and the y-axis matches its "Density"
        # label (histplot's default stat is "count").
        sns.histplot(
            all_scores[dir_data][label], color=color, label=label, kde=True, bins=30, alpha=0.2,
            stat="density", line_kws={'alpha': 1, 'linewidth': 2}, edgecolor=(0, 0, 0, 0.2), ax=ax
        )

    thresh = thresholds[dir_data]
    ax.axvline(thresh, color='black', linestyle='dashed', label=f'Threshold = {thresh}')
    ax.set_title(f"Direction from {dir_data.upper()}", fontsize=20)
    ax.set_ylabel("Density", fontsize=16)
    ax.tick_params(axis='both', which='major', labelsize=12)
    ax.set_xlabel(r"$\phi_{\mathrm{unans}}$", fontsize=16)

# Legend derived from prompt_groups so labels and colours cannot drift apart.
custom_lines = [Line2D([0], [0], color=color, lw=2) for _, _, color in prompt_groups]
custom_lines.append(Line2D([0], [0], color='black', linestyle='dashed', lw=2))
labels = [group_label for group_label, _, _ in prompt_groups] + ['Threshold']
fig.legend(custom_lines, labels, fontsize=14, loc='lower center', ncol=3, bbox_to_anchor=(0.5, -0.01))

fig.suptitle("Distributions of Unanswerability Scores", fontsize=24, y=0.99)
fig.tight_layout()
fig.subplots_adjust(bottom=0.12)

fig.savefig(f"plots/{model_name}_unanswerability_scores.png", dpi=300, bbox_inches='tight')
fig.savefig(f"plots/{model_name}_unanswerability_scores.pdf", format='pdf', dpi=300, bbox_inches='tight')
plt.show()

### combining directions

In [None]:
from evaluate import evaluate_by_projecting_2layers
from pipeline.utils.threshold_utils import get_threshold_by_curve_2layers
from pipeline.model_utils.model_factory import construct_model_base
import json
import torch
from data.load_datasets import load_data

# Combine the first two directions listed for `data_name`, calibrate a
# threshold on a validation split, then evaluate the two-direction detector
# on the test split of `eval_data`.
model_name = "llama3"
model_path = "meta-llama/meta-llama-3-8b-instruct"
model_base = construct_model_base(model_path)

data_name = "squad"  # direction source: one of 'squad', 'repliqa', 'nq', 'musique'
eval_data = 'nq'     # evaluation set: one of 'squad', 'repliqa', 'nq', 'musique'
# Validation data used for threshold calibration; set to eval_data to
# calibrate on the evaluation dataset instead of the direction source.
calib_data = data_name
dirs_path = f'pipeline/runs/{model_name}/{data_name}'

# NOTE(review): takes entries 0 and 1 of direction_evaluations.json; this
# presumably means the file is ranked best-first -- confirm.
with open(f'{dirs_path}/select_by_steering/direction_evaluations.json', 'r') as f:
    direction_evaluations = json.load(f)
top_entry, second_entry = direction_evaluations[0], direction_evaluations[1]
pos1, layer1 = top_entry['position'], top_entry['layer']
pos2, layer2 = second_entry['position'], second_entry['layer']

candidate_directions = torch.load(f'{dirs_path}/mean_diffs.pt')
dir_vector1, dir_vector2 = candidate_directions[pos1, layer1], candidate_directions[pos2, layer2]

path = f'data/abstain_aware_prompt/{calib_data}'
ans_val, unans_val = load_data(path, "val")

# Returns ROC statistics (judging by the names) together with the chosen threshold.
fpr, tpr, roc_auc, best_roc_idx, threshold = get_threshold_by_curve_2layers(
    dir_vector1, dir_vector2, model_base,
    pos1, pos2, layer1, layer2,
    ans_val, unans_val,
)

ans_test, unans_test = load_data(f'data/abstain_aware_prompt/{eval_data}', "test")

eval_path = f'{dirs_path}/evaluations/two_vectors/eval_on_{eval_data}'
evaluate_by_projecting_2layers(
    eval_path, ans_test, unans_test, model_base,
    dir_vector1, dir_vector2, pos1, pos2, layer1, layer2, threshold,
)
