# Implementation of Internal-State-Based Uncertainty Estimation

1. 数据集预处理：将不同格式的数据集处理成input+gt的格式，方便判断模型的correctness；这一部分采用固定不可调整的prompt，即Context: Question: Options: Answer:格式
2. 生成回复，为每个模型确定一个prompt，一个max_new_tokens数，然后生成回复
3. 计算回复部分的correctness指标，判断模型的回复是否正确
4. 计算uncertainty指标，包括LS, PE, LN-PE, SE, SAR, SR, Ours
5. 计算AUROC，绘制AUROC/Correctness-Threshold曲线

In [1]:
import os

# Put Hugging Face `datasets` / `transformers` into offline mode. These are set
# before those libraries are imported below — presumably because they read the
# variables at import time (TODO confirm for the installed versions).
os.environ['HF_DATASETS_OFFLINE'] = "1"
os.environ["TRANSFORMERS_OFFLINE"] = "1"
import transformer_lens
import datasets
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, FactoredMatrix
import einops
from fancy_einsum import einsum
from tqdm.auto import tqdm
import plotly
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import torch
import torch.nn as nn
import torch.nn.functional as F
from jaxtyping import Float
from functools import partial
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset, Dataset, Features, Array2D, Array3D
from typing import List, Tuple, Union
import os
import random
import numpy as np
from rouge import Rouge
from time import time
from sklearn.decomposition import PCA
from sklearn.metrics import roc_auc_score
from copy import deepcopy
import re
from sentence_transformers import SentenceTransformer
from sentence_transformers import util as st_util
from transformers import pipeline

# Do not cache `Dataset.map` results on disk (the maps below also pass fresh
# fingerprints), and turn autograd off globally: the notebook only runs
# inference, except inside `train_certainty_vector`, which calls
# `torch.set_grad_enabled(True)` itself.
datasets.disable_caching()
torch.set_grad_enabled(False)


def print_sys_info():
    """Print the host's available RAM (GiB), its hostname and a GPU usage table.

    `psutil`, `socket` and `gpustat` are imported locally, only when this
    function is called.
    """
    import psutil
    import socket
    import gpustat

    free_gib = psutil.virtual_memory().available / 1024 / 1024 // 1024
    print("剩余内存: {} G".format(free_gib))
    print(f"当前主机名是:{socket.gethostname()}")
    gpustat.print_gpustat()


def launch_clash():
    """Make sure a local Clash proxy runs and route this process's HTTP(S) through it.

    If ``pidof clash`` reports nothing, ``~/tools/clash/clash`` is started in
    the background and ``pidof`` is queried again. The reported PID(s) are
    printed. Finally ``http_proxy``/``https_proxy`` in ``os.environ`` are
    pointed at ``localhost:7890``.
    """
    import subprocess
    import os

    def _pidof_clash():
        return subprocess.run("pidof clash", shell=True, stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE, text=True)

    probe = _pidof_clash()
    if not probe.stdout:
        subprocess.Popen("~/tools/clash/clash", shell=True)
        # NOTE(review): the new process may not be visible to `pidof` yet, in
        # which case the PID printed below is empty.
        probe = _pidof_clash()
    print(f"Clash is running, pid: {probe.stdout}")
    proxy_url = "http://localhost:7890"
    for var in ("http_proxy", "https_proxy"):
        os.environ[var] = proxy_url


def close_clash():
    """Stop the local Clash proxy and drop the proxy settings from this process.

    Counterpart of ``launch_clash``. It runs ``killall clash``, prints its
    stdout, and removes ``http_proxy`` / ``https_proxy`` from ``os.environ``.
    Missing variables are ignored.
    """
    import os
    import subprocess

    result = subprocess.run("killall clash", shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    print(result.stdout)
    # IPython's `!unset ...` runs in a throw-away subshell and leaves this
    # process's environment untouched. The variables that launch_clash set
    # must be removed from os.environ directly.
    for var in ("http_proxy", "https_proxy"):
        os.environ.pop(var, None)


# Optional local-proxy helpers defined above (uncomment to start/stop Clash):
# launch_clash()
# close_clash()
print_sys_info()

剩余内存: 874.0 G
当前主机名是:SH-IDC1-10-140-0-157
SH-IDC1-10-140-0-157      Wed Mar 20 12:38:12 2024  525.60.13
[0] NVIDIA A100-SXM4-80GB | 44°C, 100 % |  8663 / 81920 MB | yangyue(8660M)
[1] NVIDIA A100-SXM4-80GB | 49°C,  97 % | 58869 / 81920 MB | gaopeng(58864M)
[2] NVIDIA A100-SXM4-80GB | 51°C,  94 % | 58869 / 81920 MB | gaopeng(58864M)
[3] NVIDIA A100-SXM4-80GB | 27°C,   0 % |  8563 / 81920 MB | yangyue(8560M)
[4] NVIDIA A100-SXM4-80GB | 26°C,   0 % |     0 / 81920 MB |
[5] NVIDIA A100-SXM4-80GB | 48°C,  96 % | 70323 / 81920 MB | gaopeng(70318M)
[6] NVIDIA A100-SXM4-80GB | 34°C,   0 % | 23161 / 81920 MB | yejin(23158M)
[7] NVIDIA A100-SXM4-80GB | 48°C, 100 % | 63177 / 81920 MB | gaopeng(63174M)


In [None]:
# Model Config
# Load the Hugging Face checkpoint from $my_models_dir/<model_name> (KeyError if
# that environment variable is unset) and wrap it in a TransformerLens
# HookedTransformer under the `llama-7b-hf` config name — presumably because
# vicuna-7b-v1.1 shares that architecture (TODO confirm).
model_name = "vicuna-7b-v1.1"
hooked_transformer_name = "llama-7b-hf"
hf_model_path = os.path.join(os.environ["my_models_dir"], model_name)
hf_tokenizer = AutoTokenizer.from_pretrained(hf_model_path)
hf_model = AutoModelForCausalLM.from_pretrained(hf_model_path)

# No TransformerLens weight processing; bfloat16 weights; left padding by default.
model = HookedTransformer.from_pretrained_no_processing(hooked_transformer_name, dtype='bfloat16', hf_model=hf_model, tokenizer=hf_tokenizer, default_padding_side='left')

# Aux Models
# NLI classifier on GPU 0, used by get_sentsim (and presumably passed as
# `nli_pipe` to get_uncertainty_score_se) for entailment checks.
se_bert_name = "microsoft/deberta-large-mnli"
se_bert_pipe = pipeline("text-classification", model=se_bert_name, device=0)
# Sentence-embedding model used by SAR to weight answer tokens.
# NOTE(review): this is a cross-encoder checkpoint loaded as a bi-encoder
# SentenceTransformer; confirm its embeddings are meaningful for cosine similarity.
sar_bert_name = 'cross-encoder/stsb-roberta-large'
# sar_bert_name = 'sentence-transformers/all-MiniLM-L6-v2'
sar_bert = SentenceTransformer(sar_bert_name)

In [16]:
# Helper functions: answer cleaning, correctness metrics, certainty vectors, uncertainty scores, evaluation
def wash(text, special_tokens=None):
    """Clean a generated answer.

    Removes special-token strings, truncates the text before the first
    occurrence (case-insensitive) of ``question:`` or ``context:``, and strips
    surrounding whitespace.

    Args:
        text: Raw generated answer.
        special_tokens: Strings to delete from ``text``. Defaults to the values
            of ``model.tokenizer.special_tokens_map``. List-valued entries
            (e.g. ``additional_special_tokens``) are flattened.

    Returns:
        The cleaned answer string.
    """
    if special_tokens is None:
        special_tokens = []
        for value in model.tokenizer.special_tokens_map.values():
            if isinstance(value, (list, tuple)):
                special_tokens.extend(value)
            else:
                special_tokens.append(value)
    for sp_tok in special_tokens:
        if sp_tok:
            text = text.replace(sp_tok, "")
    # Cut at whichever marker occurs FIRST in the text, not whichever marker is
    # listed first. Matching on the original string (IGNORECASE) also keeps the
    # cut index valid when str.lower() would change the string's length.
    match = re.search(r"question:|context:", text, flags=re.IGNORECASE)
    if match:
        text = text[:match.start()]
    return text.strip()

def wash_answer(example):
    """Add cleaned versions of the model answer(s) to ``example``.

    Adds ``washed_answer`` (``wash`` of ``answer``) and ``washed_output``
    (``input`` + washed answer). When a non-empty ``sampled_answer`` list is
    present, it also adds ``washed_sampled_answer`` and
    ``washed_sampled_output`` for every sample. Mutates and returns
    ``example``, for use with ``Dataset.map``.
    """
    example['washed_answer'] = wash(example['answer'])
    prompt = example['input']
    example['washed_output'] = prompt + example['washed_answer']
    samples = example.get("sampled_answer")
    if not samples:
        return example
    cleaned = list(map(wash, samples))
    example['washed_sampled_answer'] = cleaned
    example['washed_sampled_output'] = [prompt + s for s in cleaned]
    return example


def get_rougel(example):
    """Store the ROUGE-L F1 between the washed answer and the ground truth.

    Both strings are lower-cased. An empty or lone-'.' answer is replaced by
    "-", presumably so that ``rouge`` does not reject an empty hypothesis.
    Writes ``example["rougel"]`` and returns the example.
    """
    scorer = Rouge()
    hypothesis = example['washed_answer'].lower()
    if hypothesis in ("", "."):
        hypothesis = "-"
    reference = example['gt'].lower()
    rouge_l = scorer.get_scores(hypothesis, reference)[0]['rouge-l']
    example["rougel"] = rouge_l['f']
    return example


def get_sentsim(examples):
    """Batched ``Dataset.map`` function: NLI agreement between answer and ground truth.

    For every row, ``"Answer:<gt>"`` and ``"Answer:<washed_answer>"`` are checked
    in both directions with ``se_bert_pipe``. ``sentsim`` adds 0.5 per direction
    labelled ``ENTAILMENT``, giving 0, 0.5 or 1.0.
    """
    nli_tmp = "[CLS] {s1} [SEP] {s2} [CLS]"
    qa_tmp = "Answer:{a}"
    # NOTE(review): the trailing "[CLS]" looks like it was meant to be "[SEP]";
    # it is kept as-is so scores stay comparable with earlier runs.
    n_rows = len(examples['input'])
    nli_inputs = []
    for row in range(n_rows):
        s1 = qa_tmp.format(a=examples['gt'][row])
        s2 = qa_tmp.format(a=examples['washed_answer'][row])
        nli_inputs += [nli_tmp.format(s1=s1, s2=s2), nli_tmp.format(s1=s2, s2=s1)]
    preds = se_bert_pipe(nli_inputs)

    examples['sentsim'] = [
        sum(0.5 for p in (preds[2 * row], preds[2 * row + 1]) if p['label'] == 'ENTAILMENT')
        for row in range(n_rows)
    ]
    return examples


def get_logits_cache(outputs, layers=None, act_name=None, batch_size=16):
    """Run the model over ``outputs`` and collect per-example logits and activations.

    Args:
        outputs: Sequence of full texts (e.g. prompt + answer).
        layers: Layer indices whose ``act_name`` activation is cached. ``None``
            means every layer of ``model``.
        act_name: Activation name understood by ``utils.get_act_name``.
            Required.
        batch_size: Number of texts per forward pass (right padding).

    Returns:
        ``(all_logits, all_cache)``. These are lists with one float32 CPU tensor
        per text: logits ``(pos, vocab)`` and activations
        ``(layer, pos, d_model)`` with layers in ascending order. Each tensor is
        trimmed to that text's own token count, so the padding is removed.

    Raises:
        TypeError: if ``act_name`` is not given.
    """
    if act_name is None:
        raise TypeError("get_logits_cache() requires act_name")
    if layers is None:
        layers = range(model.cfg.n_layers)
    full_act_names = [utils.get_act_name(act_name, l) for l in sorted(layers)]
    wanted = set(full_act_names)
    all_logits = []
    all_cache = []
    for start in tqdm(range(0, len(outputs), batch_size)):
        batch_outputs = outputs[start:start + batch_size]
        # Token count of each text (same tokenization as the forward pass);
        # with right padding, positions beyond it are padding.
        num_output_tokens = [len(model.to_str_tokens(x)) for x in batch_outputs]
        logits, cache = model.run_with_cache(batch_outputs, names_filter=lambda name: name in wanted, device='cpu', padding_side='right')
        cache = einops.rearrange([cache[name] for name in full_act_names], 'l b p d -> b l p d')
        logits = logits.cpu().float()
        cache = cache.cpu().float()
        all_logits.extend(lg[:end] for lg, end in zip(logits, num_output_tokens))
        all_cache.extend(c[:, :end, :] for c, end in zip(cache, num_output_tokens))
    return all_logits, all_cache  # list(pos vocab) list(layer pos d_model)


# Our Method
def get_paired_dst_sciq(train_dst):
    """Build positive/negative multiple-choice prompt pairs from SciQ-style rows.

    Both outputs replace ``input`` with a "Question:... Options:..." prompt.
    The positive prompt asks for the correct answer and appends ``gt``. The
    negative prompt asks for the incorrect answer and appends a distractor
    picked from ``options``. The pick is deterministic per row index; the
    literal "wrong answer" is used when no distractor exists. The full text is
    stored in ``washed_output``.

    Returns:
        ``(dst_pos, dst_neg)``.
    """
    tmp_pos = "Question:{q} Options:{o} The correct answer is:"
    tmp_neg = "Question:{q} Options:{o} The incorrect answer is:"

    def get_pos_example(example):
        example['input'] = tmp_pos.format(q=example['question'], o=", ".join(example['options']))
        example['washed_output'] = f"{example['input']}{example['gt']}"
        return example

    def get_neg_example(example, idx):
        example['input'] = tmp_neg.format(q=example['question'], o=", ".join(example['options']))
        wrong_options = [opt for opt in example['options'] if opt != example['gt']]
        if wrong_options:
            # A private RNG seeded with 42 + idx makes the same pick as
            # reseeding the global `random` module, without clobbering its state.
            wrong_answer = random.Random(42 + idx).choice(wrong_options)
        else:
            wrong_answer = "wrong answer"
        example['washed_output'] = f"{example['input']}{wrong_answer}"
        return example

    dst_pos = train_dst.map(get_pos_example, new_fingerprint=str(time()))
    dst_neg = train_dst.map(get_neg_example, with_indices=True, new_fingerprint=str(time()))
    return dst_pos, dst_neg


def get_paired_dst_coqa(train_dst):
    """Build positive/negative answer pairs from CoQA-style rows.

    Positive: ``input`` + "The correct answer is <gt>".
    Negative: ``input`` + "The wrong answer is <x>". Here ``x`` is an entry of
    ``answers['input_text']`` that differs from ``gt``, picked deterministically
    per row index, or the literal "wrong answer" when no such entry exists.
    The text is stored in ``washed_output``.

    Returns:
        ``(dst_pos, dst_neg)``.
    """

    def get_pos_example(example):
        example['washed_output'] = f"{example['input']}The correct answer is {example['gt']}"
        return example

    def get_neg_example(example, idx):
        wrong_options = [opt for opt in example['answers']['input_text'] if opt != example['gt']]
        if wrong_options:
            # A private RNG seeded with 42 + idx makes the same pick as
            # reseeding the global `random` module, without clobbering its state.
            wrong_answer = random.Random(42 + idx).choice(wrong_options)
        else:
            wrong_answer = "wrong answer"
        example['washed_output'] = f"{example['input']}The wrong answer is {wrong_answer}"
        return example

    dst_pos = train_dst.map(get_pos_example, new_fingerprint=str(time()))
    dst_neg = train_dst.map(get_neg_example, with_indices=True, new_fingerprint=str(time()))
    return dst_pos, dst_neg


def get_paired_dst_triviaqa(train_dst):
    """Build positive/negative answer pairs from TriviaQA-style rows.

    Positive: ``input`` + "The correct answer is <gt>".
    Negative: ``input`` + "The wrong answer is <gt of the next row>". The last
    row borrows from row 0, so another question's answer serves as the
    distractor. The text is stored in ``washed_output``.

    Returns:
        ``(dst_pos, dst_neg)``.
    """
    n_rows = len(train_dst)

    def get_pos_example(example):
        example['washed_output'] = f"{example['input']}The correct answer is {example['gt']}"
        return example

    def get_neg_example(example, idx):
        donor_row = (idx + 1) % n_rows  # wrap around at the end
        wrong_answer = train_dst[donor_row]['gt']
        example['washed_output'] = f"{example['input']}The wrong answer is {wrong_answer}"
        return example

    dst_pos = train_dst.map(get_pos_example, new_fingerprint=str(time()))
    dst_neg = train_dst.map(get_neg_example, with_indices=True, new_fingerprint=str(time()))
    return dst_pos, dst_neg

def get_paired_dst_medmcqa(train_dst):
    """Build positive/negative answer pairs from MedMCQA-style rows.

    Positive: ``input`` + "The correct answer is <gt>".
    Negative: ``input`` + "The wrong answer is <x>". Here ``x`` is an entry of
    ``options`` that differs from ``gt``, picked deterministically per row
    index, or the literal "wrong answer" when no such entry exists.
    The text is stored in ``washed_output``.

    Returns:
        ``(dst_pos, dst_neg)``.
    """
    def get_pos_example(example):
        example['washed_output'] = f"{example['input']}The correct answer is {example['gt']}"
        return example

    def get_neg_example(example, idx):
        wrong_options = [opt for opt in example['options'] if opt != example['gt']]
        if wrong_options:
            # A private RNG seeded with 42 + idx makes the same pick as
            # reseeding the global `random` module, without clobbering its state.
            wrong_answer = random.Random(42 + idx).choice(wrong_options)
        else:
            wrong_answer = "wrong answer"
        example['washed_output'] = f"{example['input']}The wrong answer is {wrong_answer}"
        return example

    dst_pos = train_dst.map(get_pos_example, new_fingerprint=str(time()))
    dst_neg = train_dst.map(get_neg_example, with_indices=True, new_fingerprint=str(time()))
    return dst_pos, dst_neg

def compute_certainty_vector_pca(dst_pos, dst_neg, layers, act_name):
    """Estimate a per-layer "certainty" direction from paired activations via PCA.

    The last-token ``act_name`` activations at ``layers`` are collected for every
    positive and negative text (``washed_output``). The first principal
    component of the paired differences (pos - neg), flattened over layers, is
    reshaped to ``(len(layers), 1, d_model)`` and L2-normalised per layer.

    A principal component's sign is arbitrary. The direction is therefore
    oriented so that the mean projection of (pos - neg) onto it is
    non-negative, i.e. positives score higher than negatives on average.

    Returns:
        Float32 CPU tensor ``(len(layers), 1, d_model)``.
    """
    _, cache_pos = get_logits_cache(dst_pos['washed_output'], layers=layers, act_name=act_name, batch_size=16)
    _, cache_neg = get_logits_cache(dst_neg['washed_output'], layers=layers, act_name=act_name, batch_size=16)

    # Final (non-padded) token of every example: (layer, 1, d_model).
    cache_pos = [c[:, [-1], :] for c in cache_pos]
    cache_neg = [c[:, [-1], :] for c in cache_neg]

    v_pos = einops.rearrange(cache_pos, 'b l p d -> b (l p d)')
    v_neg = einops.rearrange(cache_neg, 'b l p d -> b (l p d)')
    v_diff = v_pos - v_neg  # (b, layer * 1 * d_model)

    pca = PCA(n_components=1)
    pca.fit(v_diff)
    v_c = torch.tensor(pca.components_[0], dtype=torch.float)
    if (v_diff @ v_c).mean() < 0:
        v_c = -v_c
    v_c = einops.rearrange(v_c, '(l p d) -> l p d', l=len(layers), p=1)
    return F.normalize(v_c, p=2, dim=-1)


def compute_certainty_vector_mean(dst_pos, dst_neg, layers, act_name, batch_size=8):
    """Estimate a per-layer "certainty" direction as the mean last-token difference.

    Positive and negative texts (``washed_output``) are run in aligned batches
    with left padding. The final-token ``act_name`` activation at each of
    ``layers`` is taken, and (pos - neg) is averaged over the dataset.

    Returns:
        Float32 CPU tensor ``(len(layers), 1, d_model)``, L2-normalised per layer.
    """
    texts_pos = dst_pos['washed_output']
    texts_neg = dst_neg['washed_output']
    n = len(texts_pos)
    act_names = [utils.get_act_name(act_name, layer) for layer in sorted(layers)]
    accum = torch.zeros((len(layers), 1, model.cfg.d_model)).cuda()

    def last_token_acts(texts):
        _, cache = model.run_with_cache(texts, names_filter=lambda name: name in act_names, padding_side='left')
        stacked = einops.rearrange([cache[name] for name in act_names], 'l b p d -> b l p d')
        return stacked[:, :, [-1], :]  # left padding: the last column is the real last token

    for lo in tqdm(range(0, n, batch_size)):
        pos_last = last_token_acts(texts_pos[lo:lo + batch_size])
        neg_last = last_token_acts(texts_neg[lo:lo + batch_size])
        accum += (pos_last.sum(dim=0) - neg_last.sum(dim=0))

    accum /= n
    return F.normalize(accum.cpu().float(), p=2, dim=-1)


def train_certainty_vector(dst_pos, dst_neg, layers, act_name):
    """Learn a per-layer "certainty" direction ``v_c`` by gradient descent.

    The frozen model is run on positive and negative texts (``washed_output``).
    The last-token ``act_name`` activation at each of ``layers`` is
    L2-normalised and dotted with the normalised trainable ``v_c``, then summed
    over layers to give one score per text. ``loss_fn`` compares every positive
    score with cyclically shifted negative scores. Training uses Adam
    (lr 1e-3, batch 16, 2 epochs). The positive and negative lists are shuffled
    independently each epoch with fixed seeds.

    Autograd is enabled only while training and is switched off again on exit,
    even if an exception is raised. The global ``random`` state is not touched.

    Returns:
        Float32 CPU tensor ``(len(layers), 1, d_model)``, L2-normalised per layer.
    """
    full_act_names = [utils.get_act_name(act_name, l) for l in sorted(layers)]
    # Copies, so shuffling never mutates the caller's data.
    data_pos = list(dst_pos['washed_output'])
    data_neg = list(dst_neg['washed_output'])

    def loss_fn(score_pos, score_neg):
        # For every cyclic shift i of the negatives, add the largest
        # (pos - neg) gap over the batch.
        score_neg_rp = score_neg.repeat(2)
        loss = 0
        for i in range(len(score_pos)):
            diff_i = score_pos - score_neg_rp[i:i + len(score_pos)]
            loss += diff_i.max()
        return loss

    def last_token_acts(texts):
        _, cache = model.run_with_cache(texts, names_filter=lambda x: x in full_act_names, padding_side='left')
        cache = einops.rearrange([cache[name] for name in full_act_names], 'l b p d -> b l p d')
        return cache[:, :, [-1], :]  # (batch, layer, 1, d_model)

    bsz = 16
    lr = 1e-3
    epochs = 2
    total_iters = (len(data_pos) + bsz - 1) // bsz * epochs  # includes the last partial batch

    torch.set_grad_enabled(True)
    try:
        model.requires_grad_(False)
        # One direction per *selected* layer, so it lines up with the cached
        # activations in the einsum below. n_layers only matches when every
        # layer is selected.
        v_c = nn.Parameter(torch.randn((len(layers), 1, model.cfg.d_model), dtype=model.W_E.data.dtype, device='cuda'), requires_grad=True)
        optimizer = torch.optim.Adam([v_c], lr=lr)

        with tqdm(total=total_iters) as bar:
            for epoch in range(epochs):
                # Same permutations as seeding the global RNG with these seeds.
                random.Random(42 + epoch).shuffle(data_pos)
                random.Random(4200 + epoch).shuffle(data_neg)
                for i in range(0, len(data_pos), bsz):
                    # The frozen model needs no autograd graph; only the
                    # projection onto v_c is differentiated.
                    with torch.no_grad():
                        cache_pos = last_token_acts(data_pos[i:i + bsz])
                        cache_neg = last_token_acts(data_neg[i:i + bsz])

                    v_c_unit = F.normalize(v_c, p=2, dim=-1)
                    score_pos = einsum('b l p d, l p d -> b', F.normalize(cache_pos, p=2, dim=-1), v_c_unit)
                    score_neg = einsum('b l p d, l p d -> b', F.normalize(cache_neg, p=2, dim=-1), v_c_unit)

                    loss = loss_fn(score_pos, score_neg)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    bar.update(1)
                    bar.set_description(f"loss: {loss.item()}")
    finally:
        # The notebook runs with autograd disabled globally; restore that state.
        torch.set_grad_enabled(False)

    v_c = v_c.data.cpu().float()
    return F.normalize(v_c, p=2, dim=-1)

def train_certainty_vector_v2(train_dst, c_metric, layers, act_name):
    """Placeholder for a ``c_metric``-supervised variant of ``train_certainty_vector``.

    Not implemented. The function raises immediately, before touching any
    global state. It does not silently return ``None`` or leave autograd
    switched on for the rest of the session.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError(
        "train_certainty_vector_v2 is not implemented yet; use train_certainty_vector"
    )


# Controlled experiment: score each gold answer against three distractor answers
def clean_exp(dst, v_c, layers, act_name):
    """Controlled check of ``v_c`` (and PE / LN-PE) on gold vs. distractor answers.

    Each example's prompt ``input`` is completed with the gold answer ``gt``
    and with exactly three wrong answers. These come from the other
    ``options``, from the other CoQA ``answers['input_text']`` (first three),
    or from three generic placeholders. All four texts are scored by teacher
    forcing:

    * ours: the last-token ``act_name`` activations at ``layers`` projected onto
      ``v_c`` and summed over layers. The scores are also z-scored within each
      group of four candidates.
    * PE / LN-PE: summed / mean negative log-probability of the answer tokens
      actually present in each text.

    Prints the mean per-example ("in-vivo") AUROC and the pooled ("in-vitro")
    AUROCs. Shows a histogram of z-scored scores for correct vs. wrong
    candidates. ``v_c`` has no canonical sign, so the in-vivo result decides
    whether the ``v_c`` AUROCs are reported as-is or flipped.

    Raises:
        ValueError: if an example does not yield exactly three wrong answers
            (groups of four would otherwise be misaligned).
    """
    n_cand = 4  # 1 gold answer + 3 distractors per example
    cand_labels = [1] + [0] * (n_cand - 1)
    full_act_names = [utils.get_act_name(act_name, l) for l in sorted(layers)]
    c_scores = []
    w_scores = []
    labels = []
    u_scores = []
    u_scores_z = []
    all_pe_u_scores = []
    all_ln_pe_u_scores = []

    def _wrong_answers(example):
        """Return the distractors for one example (first occurrence of gt removed)."""
        if example.get("options"):
            wrong = list(example['options'])
            if example['gt'] in wrong:
                wrong.remove(example['gt'])
            return wrong
        if example.get("answers"):
            wrong = list(example['answers']['input_text'])
            if example['gt'] in wrong:
                wrong.remove(example['gt'])
            return wrong[:3]
        return ['wrong answer', 'bad answer', 'incorrect answer']

    def batch_get_result(examples):
        bsz = len(examples['input'])
        all_outputs = []
        all_num_answer_tokens = []

        for i in range(bsz):
            example = {k: examples[k][i] for k in examples.keys()}
            wrong_options = _wrong_answers(example)
            if len(wrong_options) != n_cand - 1:
                raise ValueError(
                    f"clean_exp expects exactly {n_cand - 1} wrong answers per example, "
                    f"got {len(wrong_options)}: {wrong_options!r}"
                )
            outputs = [example['input'] + ans for ans in [example['gt']] + wrong_options]
            all_outputs.extend(outputs)
            # Answer length = tokens(prompt + answer) - tokens(prompt), as in
            # _get_answer_target_prob. Tokenizing the answer on its own would
            # count an extra BOS and may split the text differently.
            num_input_tokens = len(model.to_str_tokens(example['input']))
            all_num_answer_tokens.extend(len(model.to_str_tokens(out)) - num_input_tokens for out in outputs)

        batch_logits, batch_cache = model.run_with_cache(all_outputs, names_filter=lambda x: x in full_act_names, device='cpu', padding_side='left')
        # Same left padding as the forward pass: each answer sits at the end of its row.
        batch_tokens = model.to_tokens(all_outputs, padding_side='left', move_to_device=False)

        batch_cache = einops.rearrange([batch_cache[name] for name in full_act_names], 'l b p d -> b l p d').float().cpu()
        batch_cache = einops.rearrange(batch_cache, '(b o) l p d -> b o l p d', o=n_cand)
        batch_cache = batch_cache[:, :, :, [-1], :]  # last (real) token only
        batch_logits = batch_logits.cpu().float()

        for lg, tokens, n_ans in zip(batch_logits, batch_tokens, all_num_answer_tokens):
            if n_ans <= 0:
                pe, ln_pe = 0.0, 0.0
            else:
                # Position t predicts token t + 1: logits[-n_ans-1:-1] score tokens[-n_ans:].
                log_probs = F.log_softmax(lg[-n_ans - 1:-1], dim=-1)
                nll = -log_probs[torch.arange(n_ans), tokens[-n_ans:]]
                pe = nll.sum().item()
                ln_pe = nll.mean().item()
            all_pe_u_scores.append(pe)
            all_ln_pe_u_scores.append(ln_pe)

        batch_in_vivo_auroc = []
        for cache in batch_cache:  # (n_cand, layer, 1, d_model)
            u_score = einsum('b l p d, l p d -> b', cache, v_c)
            u_score_z = (u_score - u_score.mean()) / u_score.std()
            u_score = u_score.tolist()
            u_score_z = u_score_z.tolist()

            batch_in_vivo_auroc.append(roc_auc_score(cand_labels, u_score))
            c_scores.append(u_score_z[0])
            w_scores.extend(u_score_z[1:])
            labels.extend(cand_labels)
            u_scores.extend(u_score)
            u_scores_z.extend(u_score_z)

        examples['in_vivo_auroc'] = batch_in_vivo_auroc
        return examples

    new_dst = dst.map(batch_get_result, new_fingerprint=str(time()), batched=True, batch_size=4)

    in_vivo_auroc = sum(new_dst['in_vivo_auroc']) / len(new_dst['in_vivo_auroc'])
    # If v_c ranks the gold answer below the distractors on average, report
    # the flipped AUROCs for it.
    flag = in_vivo_auroc > 0.5
    in_vivo_auroc = in_vivo_auroc if flag else 1 - in_vivo_auroc
    print(f"in-vivo u_score auroc: {in_vivo_auroc}")

    in_vitro_auroc = roc_auc_score(labels, u_scores)
    in_vitro_auroc = in_vitro_auroc if flag else 1 - in_vitro_auroc
    print(f"in-vitro u_score auroc: {in_vitro_auroc}")

    in_vitro_auroc_z = roc_auc_score(labels, u_scores_z)
    in_vitro_auroc_z = in_vitro_auroc_z if flag else 1 - in_vitro_auroc_z
    print(f"in-vitro u_score_z auroc: {in_vitro_auroc_z}")

    in_vitro_pe_auroc = roc_auc_score(labels, all_pe_u_scores)
    print(f"in-vitro pe auroc: {in_vitro_pe_auroc}")

    in_vitro_ln_pe_auroc = roc_auc_score(labels, all_ln_pe_u_scores)
    print(f"in-vitro ln_pe auroc: {in_vitro_ln_pe_auroc}")

    fig = go.Figure()
    fig.add_trace(go.Histogram(x=c_scores, name='Correct', opacity=0.5, nbinsx=100))
    fig.add_trace(go.Histogram(x=w_scores, name='Wrong', opacity=0.5, nbinsx=100))
    fig.update_layout(barmode='overlay')
    fig.show()


# Reference list of the uncertainty-estimation methods implemented below.
# (A bare string literal: it has no runtime effect.)
'''
Compared Methods:
LS(Lexical Similarity),
PE(Predictive Entropy),
LN-PE(Length-normalised Predictive Entropy),
SE(Semantic Entropy),
SAR(Shifting Attention to more Relevant),
SR(Self-Report)
Ours(Activation Based)
'''


def _get_answer_target_prob(input, output, logits):
    """Return the model probability of each answer token under teacher forcing.

    ``output`` is expected to be the prompt ``input`` followed by an answer.
    ``logits`` (``(pos, vocab)``) are expected to come from running the model
    on ``output`` with the same tokenization as ``model.to_tokens(output)``.
    The answer is every token of ``output`` after the first
    ``len(model.to_str_tokens(input))`` tokens.

    Returns:
        Float32 CPU tensor ``[num_answer_tokens]`` with p(answer token | prefix).
    """
    n_prompt = len(model.to_str_tokens(input))
    answer_ids = model.to_tokens(output, move_to_device=False)[0, n_prompt:]
    # Row t of the logits predicts token t + 1, so the rows scoring the answer
    # start one position before it and exclude the final row.
    step_probs = F.softmax(logits.float().cpu()[n_prompt - 1:-1], dim=-1)
    return step_probs[torch.arange(len(step_probs)), answer_ids]  # [num_answer_tokens]


def get_uncertainty_score_ls(example):
    """Lexical-similarity (LS) score of the washed sampled answers.

    Every unordered pair (i < j) of ``washed_sampled_answer`` is compared with
    ROUGE-L (sample i as hypothesis, j as reference). The averaged F1 is
    stored as ``u_score_ls``; higher means the samples overlap more.
    Empty or lone-'.' samples are first replaced by "-", presumably so that
    ``rouge`` accepts them.
    """
    scorer = Rouge()
    samples = [
        "-" if s == "" or s == '.' else s
        for s in example['washed_sampled_answer']
    ]
    n = len(samples)
    index_pairs = [(i, j) for i in range(n) for j in range(i + 1, n)]
    hyps = [samples[i] for i, _ in index_pairs]
    refs = [samples[j] for _, j in index_pairs]
    averaged = scorer.get_scores(hyps, refs, avg=True)
    example['u_score_ls'] = averaged['rouge-l']['f']
    return example


def get_uncertainty_score_token_pe_all(example, idx, logits):
    """Store the per-token negative log-likelihoods of the washed answer.

    ``logits[idx]`` is expected to hold the model logits for
    ``example['washed_output']``. Writes ``u_score_pe_all`` (one float per
    answer token, or ``[]`` when the washed answer is empty) and returns the
    example.
    """
    if example['washed_answer'] != "":
        probs = _get_answer_target_prob(example['input'], example['washed_output'], logits[idx])
        example['u_score_pe_all'] = (-torch.log(probs)).tolist()
    else:
        example['u_score_pe_all'] = []
    return example

def get_uncertainty_score_token_pe(example, idx, logits):
    """Store the predictive entropy (PE) of the washed answer.

    PE here is the summed negative log-probability of the answer tokens given
    the prompt. ``logits[idx]`` is expected to hold the logits for
    ``example['washed_output']``. ``u_score_pe`` is 0 for an empty washed
    answer.
    """
    if example['washed_answer'] == "":
        example['u_score_pe'] = 0
    else:
        token_nll = -torch.log(_get_answer_target_prob(example['input'], example['washed_output'], logits[idx]))
        example['u_score_pe'] = token_nll.sum().item()
    return example

def get_uncertainty_score_token_ln_pe(example, idx, logits):
    """Store the length-normalised predictive entropy (LN-PE) of the washed answer.

    LN-PE is the mean negative log-probability of the answer tokens given the
    prompt. ``logits[idx]`` is expected to hold the logits for
    ``example['washed_output']``. ``u_score_ln_pe`` is 0 for an empty washed
    answer.
    """
    if example['washed_answer'] == "":
        example['u_score_ln_pe'] = 0
    else:
        token_nll = -torch.log(_get_answer_target_prob(example['input'], example['washed_output'], logits[idx]))
        example['u_score_ln_pe'] = token_nll.mean().item()
    return example

def _cluster_by_bidirectional_entailment(seqs, nli_pipe):
    """Greedily group ``seqs`` into meaning clusters via two-way NLI entailment."""
    tmp = "[CLS] {s1} [SEP] {s2} [CLS]"
    meanings = [[seqs[0]]]
    for s in seqs[1:]:
        for c in meanings:
            s_c = c[0]  # each cluster is represented by its first member
            res = nli_pipe([tmp.format(s1=s, s2=s_c), tmp.format(s1=s_c, s2=s)])
            if res[0]['label'] == 'ENTAILMENT' and res[1]['label'] == 'ENTAILMENT':
                c.append(s)
                break
        else:
            # for/else: start a new cluster only after no existing cluster
            # matched. Appending inside the loop would grow `meanings` while
            # iterating over it.
            meanings.append([s])
    return meanings


def get_uncertainty_score_se(example, nli_pipe):
    """Semantic entropy (SE) over the washed sampled answers.

    1. Cluster ``washed_sampled_answer`` by bidirectional entailment using
       ``nli_pipe``. A sample joins the first cluster whose first member it
       entails and is entailed by; otherwise it starts a new cluster.
    2. Score each cluster by the summed sequence probability of its members.
       A member's probability is the product of its answer-token probabilities
       given ``example['input']``.
    3. Store ``u_score_se = -sum_c p(c) * log p(c)``, with 0 * log 0 taken as 0.

    ``u_score_se`` is 0 when there are no sampled answers.
    """
    sampled_outputs = example['washed_sampled_answer']
    if not sampled_outputs:
        example['u_score_se'] = 0.0
        return example

    meanings = _cluster_by_bidirectional_entailment(sampled_outputs, nli_pipe)

    cluster_probs = []
    for cluster in meanings:
        pc = 0.0
        for s in cluster:
            # The model must see prompt + answer for the answer-token slice in
            # _get_answer_target_prob to line up with the logits.
            full_output = example['input'] + s
            logits = model(full_output)
            answer_target_prob = _get_answer_target_prob(example['input'], full_output, logits[0])
            pc += torch.prod(answer_target_prob).item()
        cluster_probs.append(pc)
    pcs = torch.tensor(cluster_probs, dtype=torch.float)

    # xlogy(0, 0) == 0, so an underflowed cluster probability contributes 0, not NaN.
    example['u_score_se'] = -torch.xlogy(pcs, pcs).sum().item()
    return example

def get_uncertainty_score_token_sar(example, idx, sar_bert, logits):
    """Token-level SAR score of the washed answer.

    Each answer token's negative log-likelihood is weighted by how much
    deleting that token changes the ``sar_bert`` embedding of the full text
    ``washed_output``. The weights are
    ``softmax(1 - cos_sim(full, full without token))``, rescaled to sum to the
    number of answer tokens. The weighted sum is stored as ``u_score_sar``
    (0 for an empty washed answer). ``logits[idx]`` is expected to hold the
    logits for ``washed_output``.
    """
    if example['washed_answer'] == "":
        example['u_score_sar'] = 0
        return example
    num_input_tokens = len(model.to_str_tokens(example['input']))
    input_tokens = model.to_tokens(example['washed_output'], move_to_device=False)[0].tolist()
    num_output_tokens = len(input_tokens)

    answer_target_prob = _get_answer_target_prob(example['input'], example['washed_output'], logits[idx])
    neg_logp = -torch.log(answer_target_prob)

    # The reference is decoded from the same token ids as the perturbed
    # variants, so both go through identical decoding (special-token strings,
    # spacing). Each variant then differs from the reference only by the
    # deleted token.
    orig_embedding = sar_bert.encode(model.to_string(input_tokens), convert_to_tensor=True)
    new_input_strings = [
        model.to_string(input_tokens[:j] + input_tokens[j + 1:])
        for j in range(num_input_tokens, num_output_tokens)
    ]
    new_embeddings = sar_bert.encode(new_input_strings, convert_to_tensor=True)
    sim = st_util.cos_sim(orig_embedding, new_embeddings)[0].cpu()
    weights = 1 - sim
    weights = F.softmax(weights, dim=0) * len(weights)

    example['u_score_sar'] = einsum('s, s ->', neg_logp, weights).item()
    return example

def get_uncertainty_score_len(example):
    """Store the washed answer's token count as a trivial baseline score.

    ``u_score_len`` is ``len(model.to_str_tokens(washed_answer))``. This
    presumably includes a BOS token when the model prepends one (a constant
    offset) — TODO confirm.
    """
    n_tokens = len(model.to_str_tokens(example['washed_answer']))
    example['u_score_len'] = n_tokens
    return example

def get_uncertainty_score_sr(example):
    """Self-reported (SR) confidence of the model in its washed answer.

    Appends a 0–10 confidence question to ``washed_output`` (prompt + washed
    answer) and lets the model generate up to 10 new tokens. The first number
    in the continuation is stored as ``u_score_sr``; 10 is used if no number is
    found. Returns the example so it can be used with ``Dataset.map``.
    """
    prompt_self_report = "{}\n\nHow confident are you with your answer? 10 represents the highest confidence, 0 represents the lowest confidence. Please provide a number between 0 and 10:"
    # `washed_output` already starts with `input`; prepending `input` again
    # would repeat the question in the prompt.
    prompt = prompt_self_report.format(example['washed_output'])

    # Work on token ids and decode only the newly generated tokens. A decoded
    # full string need not reproduce `prompt` character for character, so
    # slicing it by len(prompt) is unreliable. (TransformerLens accepts
    # return_type 'str', 'tensor' or 'input'; 'string' is not one of them.)
    prompt_len = model.to_tokens(prompt).shape[-1]
    out_tokens = model.generate(prompt, max_new_tokens=10, return_type='tensor', verbose=False)
    ans = model.to_string(out_tokens[0, prompt_len:])

    float_pattern = r"\d+\.\d+|\d+"
    match = re.search(float_pattern, ans)
    example['u_score_sr'] = float(match.group()) if match else 10
    return example

def get_uncertainty_score_ours_all(example, idx, v_c, cache):
    """Cosine similarity of each answer-token activation with ``v_c``, per layer.

    ``cache[idx]`` is indexed as ``(layer, pos, d_model)`` for
    ``washed_output``, and ``v_c`` is ``(layer, 1, d_model)``. Every position
    after the prompt's tokens is L2-normalised, as is ``v_c``, and the two are
    dotted. The result is a nested list ``[layer][answer_token]`` stored as
    ``u_score_ours_all`` (``[]`` for an empty washed answer).
    """
    if example['washed_answer'] != "":
        prompt_len = len(model.to_str_tokens(example['input']))
        acts = F.normalize(cache[idx][:, prompt_len:, :], p=2, dim=-1)
        direction = F.normalize(v_c, p=2, dim=-1).repeat((1, acts.shape[1], 1))
        example['u_score_ours_all'] = torch.sum(direction * acts, dim=-1).tolist()  # (l p)
    else:
        example['u_score_ours_all'] = []
    return example

def get_uncertainty_score_ours_sum(example, idx, v_c, cache):
    """Sum of all per-layer, per-token scores in ``u_score_ours_all``.

    If ``u_score_ours_all`` is missing or empty, it is computed first with
    ``get_uncertainty_score_ours_all``. Stores ``u_score_ours_sum`` (0 when
    there are no answer tokens).
    """
    if not example.get("u_score_ours_all"):
        example = get_uncertainty_score_ours_all(example, idx, v_c, cache)
    per_layer = example['u_score_ours_all']
    example['u_score_ours_sum'] = sum(sum(row) for row in per_layer) if per_layer else 0
    return example

def get_uncertainty_score_ours_mean(example, idx, v_c, cache):
    """Sum over layers of each layer's mean token score in ``u_score_ours_all``.

    If ``u_score_ours_all`` is missing or empty, it is computed first with
    ``get_uncertainty_score_ours_all``. Stores ``u_score_ours_mean`` (0 when
    there are no answer tokens).
    """
    if not example.get("u_score_ours_all"):
        example = get_uncertainty_score_ours_all(example, idx, v_c, cache)
    per_layer = example['u_score_ours_all']
    example['u_score_ours_mean'] = sum(np.mean(row) for row in per_layer) if per_layer else 0
    return example

def get_uncertainty_score_ours_last(example, idx, v_c, cache):
    """Sum over layers of the last answer token's score in ``u_score_ours_all``.

    If ``u_score_ours_all`` is missing or empty, it is computed first with
    ``get_uncertainty_score_ours_all``. Stores ``u_score_ours_last`` (0 when
    there are no answer tokens).
    """
    if not example.get("u_score_ours_all"):
        example = get_uncertainty_score_ours_all(example, idx, v_c, cache)
    per_layer = example['u_score_ours_all']
    example['u_score_ours_last'] = sum(row[-1] for row in per_layer) if per_layer else 0
    return example

# Evaluation: AUROC with Correctness Metric

def get_auroc(val_dst, u_metric, c_metric, c_threshold):
    """AUROC of an uncertainty score for separating correct from incorrect answers.

    A row is labelled correct (1) when ``res[c_metric] > c_threshold``. The
    AUROC of ``res[u_metric]`` against these labels is folded to be >= 0.5,
    because the scores have no fixed sign convention.

    Args:
        val_dst: Iterable of dict-like rows (e.g. a ``datasets.Dataset`` or a
            list of dicts).
        u_metric: Key of the uncertainty score.
        c_metric: Key of the correctness metric.
        c_threshold: Correctness threshold (strictly greater means correct).

    Returns:
        The folded AUROC. Returns ``nan`` when all rows fall in one class,
        where AUROC is undefined (``roc_auc_score`` would raise). This happens
        e.g. at extreme thresholds in ``plot_th_curve``.
    """
    label = []
    u_score = []
    for res in val_dst:
        label.append(1 if res[c_metric] > c_threshold else 0)
        u_score.append(res[u_metric])
    if len(set(label)) < 2:
        return float('nan')
    auroc = roc_auc_score(label, u_score)
    return auroc if auroc > 0.5 else 1 - auroc

def plot_th_curve(val_dst, u_metrics, c_metric, nbins=10):
    """Plot accuracy and per-metric AUROC against the correctness threshold.

    For each threshold ``th`` in ``1/nbins, 2/nbins, ..., (nbins-1)/nbins``, an
    answer counts as correct when ``c_metric > th``. The figure contains one
    "acc" curve (the fraction of correct answers) and one AUROC curve per
    entry of ``u_metrics`` (computed with ``get_auroc``). Each point is
    annotated with its value.

    Args:
        val_dst: column-indexable evaluation set (e.g. ``datasets.Dataset``).
        u_metrics: names of the uncertainty-score columns to evaluate.
        c_metric: name of the correctness-metric column (e.g. ``'rougel'``).
        nbins: number of threshold bins; the default 10 gives thresholds
            0.1 ... 0.9.

    Raises:
        ValueError: if the ``c_metric`` column is empty.
    """
    fig = go.Figure()
    th_range = [i / nbins for i in range(1, nbins)]
    # Fetch the correctness column once instead of iterating every dataset
    # row for every threshold.
    c_scores = list(val_dst[c_metric])
    if not c_scores:
        raise ValueError(f"plot_th_curve: column {c_metric!r} is empty")
    acc = [sum(1 for s in c_scores if s > th) / len(c_scores) for th in th_range]

    fig.add_trace(go.Scatter(x=th_range, y=acc, mode='lines+markers+text', name="acc", text=[f"{a:.4f}" for a in acc], textposition="top center"))
    for u_metric in u_metrics:
        aurocs = [get_auroc(val_dst, u_metric, c_metric, th) for th in th_range]
        fig.add_trace(go.Scatter(x=th_range, y=aurocs, mode='lines+markers+text', name=f"{u_metric}", text=[f"{a:.4f}" for a in aurocs], textposition="top center"))
    fig.update_layout(title=f"AUROC/{c_metric}-Threshold Curve", xaxis_title=f"{c_metric}-Threshold", yaxis_title="AUROC", width=2000, height=1000)
    fig.show()

In [28]:
# Select the evaluation run. Exactly one (dst_name, train_dst, test_dst)
# triple should be active at a time; the commented-out triples are alternative
# runs, presumably loading previously cached generations from cached_results/.
# dst_name must be a key of get_pair_map (looked up in the certainty-vector cell).

# dst_name = "allenai/sciq"
# train_dst = Dataset.load_from_disk('cached_results/allenai_sciq_train_2000_vicuna-7b-v1.1')
# test_dst = Dataset.load_from_disk('cached_results/allenai_sciq_validation_1000_vicuna-7b-v1.1')

# dst_name = "allenai/sciq"
# train_dst = Dataset.load_from_disk('cached_results/allenai_sciq_train_11679_vicuna-7b-v1.1_long')
# test_dst = Dataset.load_from_disk('cached_results/allenai_sciq_validation_1000_vicuna-7b-v1.1_long')

# dst_name = "stanfordnlp/coqa"
# train_dst = Dataset.load_from_disk('cached_results/stanfordnlp_coqa_train_2000_vicuna-7b-v1.1')
# test_dst = Dataset.load_from_disk('cached_results/stanfordnlp_coqa_validation_500_vicuna-7b-v1.1')

# Active run: TriviaQA; the test split is truncated to its first 1000 rows.
dst_name = "lucadiliello/triviaqa"
train_dst = Dataset.load_from_disk('cached_results/lucadiliello_triviaqa_train_2000_vicuna-7b-v1.1')
test_dst = Dataset.load_from_disk('cached_results/lucadiliello_triviaqa_validation_7785_vicuna-7b-v1.1').select(range(1000))

# dst_name = "lucadiliello/triviaqa"
# train_dst = Dataset.load_from_disk('cached_results/lucadiliello_triviaqa_train_10000_vicuna-7b-v1.1_long')
# test_dst = Dataset.load_from_disk('cached_results/lucadiliello_triviaqa_validation_1000_vicuna-7b-v1.1_long').select(range(1000))

# dst_name = "openlifescienceai/medmcqa"
# train_dst = Dataset.load_from_disk('cached_results/openlifescienceai_medmcqa_train_182822_vicuna-7b-v1.1').select(range(2000))
# test_dst = Dataset.load_from_disk('cached_results/openlifescienceai_medmcqa_validation_4183_vicuna-7b-v1.1').select(range(1000))

# dst_name = "GBaker/MedQA-USMLE-4-options"
# train_dst = Dataset.load_from_disk('cached_results/GBaker_MedQA-USMLE-4-options_train_10178_vicuna-7b-v1.1').select(range(2000))
# test_dst = Dataset.load_from_disk('cached_results/GBaker_MedQA-USMLE-4-options_test_1273_vicuna-7b-v1.1').select(range(1000))

In [36]:
# Post-process the raw generations with wash_answer on both splits; this
# presumably produces the washed_* columns printed below.
train_dst = train_dst.map(wash_answer, new_fingerprint=str(time()))
test_dst = test_dst.map(wash_answer, new_fingerprint=str(time()))

# Spot-check every 10th test example among the first 100. The 'options' field
# is shown only when the first example has a non-empty one.
keys = ['question', 'answer', 'washed_answer', 'washed_sampled_answer', 'gt']
if test_dst[0].get('options'):
    keys = ['options'] + keys
for row_idx in range(0, 100, 10):
    row = test_dst[row_idx]
    for k in keys:
        print(f"{k}:{row[k]}")
    print()

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

question:Who had an 80s No 1 hit with Hold On To The Nights?
answer:richard marx Question:What is the Japanese share index called? Answer:nikkei Question:Who had a 70s No 1 hit with Kiss You All Over? Answer:exile Question:Kagoshima international airport is in which country? Answer:japan Question:What was Eddie Murphy's first movie? Answer:48 hours Question:Which musician founded the Red Hot Peppers? Answer:jelly roll morton Question:Kim Carnes' nine weeks at No 1 with Bette Davis Eyes was interrupted for one week by which song? Answer
washed_answer:richard marx
washed_sampled_answer:['richard marx', 'richard marx', 'richard marx', 'richard marx', 'richard marx', 'richard marx', 'richard marx', 'richard marx', 'richard marx', 'richard marx']
gt:richard marx

question:Of which African country is Niamey the capital?
answer:niger Question:What is the capital of the state of New York? Answer:Albany Question:Which country is the world's largest producer of rice? Answer:india Question:What i

In [32]:
# Score every answer against its ground truth with get_rougel, train split first.
train_dst, test_dst = (
    split.map(get_rougel, new_fingerprint=str(time()))
    for split in (train_dst, test_dst)
)

# NOTE(review): later cells read the 'sentsim' column; unless the cached
# datasets already contain it, these maps need to be re-enabled.
# train_dst = train_dst.map(get_sentsim, batched=True, batch_size=4, new_fingerprint=str(time()))
# test_dst = test_dst.map(get_sentsim, batched=True, batch_size=2, new_fingerprint=str(time()))

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [33]:
# Learn the certainty vector v_c from positive/negative training pairs, using
# the 'resid_post' activation of every layer of the model.
CACHED_ACT_NAME = 'resid_post'
CACHED_LAYERS = [*range(model.cfg.n_layers)]

get_pair_func = get_pair_map[dst_name]
train_dst_pos, train_dst_neg = get_pair_func(train_dst)
# Alternative (presumably a mean-based, untrained vector):
# v_c = compute_certainty_vector_mean(train_dst_pos, train_dst_neg, CACHED_LAYERS, CACHED_ACT_NAME)
v_c = train_certainty_vector(
    train_dst_pos,
    train_dst_neg,
    CACHED_LAYERS,
    CACHED_ACT_NAME,
)

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

  0%|          | 0/250 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Run get_logits_cache over every test output in batches of 16, restricted to
# CACHED_ACT_NAME at CACHED_LAYERS. Row order matches test_dst, which the
# with_indices scoring maps rely on.
cache_func = partial(
    get_logits_cache,
    act_name=CACHED_ACT_NAME,
    layers=CACHED_LAYERS,
    batch_size=16,
)
logits_test, cache_test = cache_func(test_dst['washed_output'])

In [None]:
# Compute every uncertainty score on the test split; each map adds u_score_* column(s).
# Maps run with with_indices=True receive the row index and use it to look up
# that row's entry in logits_test / cache_test, so test_dst must keep the row
# order it had when those were computed from test_dst['washed_output'].
test_dst = test_dst.map(get_uncertainty_score_len, new_fingerprint=str(time()))

# Predictive-entropy family, all driven by the cached test logits.
pe_all_func = partial(get_uncertainty_score_token_pe_all, logits=logits_test)
test_dst = test_dst.map(pe_all_func, with_indices=True, new_fingerprint=str(time()))

pe_func = partial(get_uncertainty_score_token_pe, logits=logits_test)
test_dst = test_dst.map(pe_func, with_indices=True, new_fingerprint=str(time()))

ln_pe_func = partial(get_uncertainty_score_token_ln_pe, logits=logits_test)
test_dst = test_dst.map(ln_pe_func, with_indices=True, new_fingerprint=str(time()))

# SAR additionally needs the sar_bert model (defined in an earlier cell).
sar_func = partial(get_uncertainty_score_token_sar, sar_bert=sar_bert, logits=logits_test)
test_dst = test_dst.map(sar_func, with_indices=True, new_fingerprint=str(time()))

# Sampling-based baselines: no per-index lookups, so no with_indices.
# (partial() with no bound arguments is just the function itself.)
ls_func = partial(get_uncertainty_score_ls)
test_dst = test_dst.map(ls_func, new_fingerprint=str(time()))

se_func = partial(get_uncertainty_score_se, nli_pipe=se_bert_pipe)
test_dst = test_dst.map(se_func, new_fingerprint=str(time()))

# Order matters: ours_last runs first and stores 'u_score_ours_all' on each
# row, which ours_sum and ours_mean then reuse instead of recomputing it.
ours_last_func = partial(get_uncertainty_score_ours_last, v_c=v_c, cache=cache_test)
test_dst = test_dst.map(ours_last_func, with_indices=True, new_fingerprint=str(time()))

ours_sum_func = partial(get_uncertainty_score_ours_sum, v_c=v_c, cache=cache_test)
test_dst = test_dst.map(ours_sum_func, with_indices=True, new_fingerprint=str(time()))

ours_mean_func = partial(get_uncertainty_score_ours_mean, v_c=v_c, cache=cache_test)
test_dst = test_dst.map(ours_mean_func, with_indices=True, new_fingerprint=str(time()))

In [None]:
# Summary statistics and distributions of two baseline scores.
print(f"average num answer tokens:{np.mean(test_dst['u_score_len'])}")

print(f"average sample answer rougel:{np.mean(test_dst['u_score_ls'])}")
go.Figure().add_trace(go.Histogram(x=test_dst['u_score_len'], nbinsx=100)).update_layout(title='Answer Length Hist').show()
go.Figure().add_trace(go.Histogram(x=test_dst['u_score_ls'], nbinsx=100)).update_layout(title='Sampled Answer Sim Hist').show()

# Print up to 20 test examples with their correctness and uncertainty scores.
# Columns missing from test_dst are skipped instead of raising KeyError:
# 'options' is optional per dataset, and 'sentsim' exists only if get_sentsim
# has been mapped.
_inspect_keys = ['question', 'options', 'washed_answer', 'gt', 'rougel', 'sentsim', 'u_score_pe', 'u_score_ln_pe', 'u_score_ours_mean', 'u_score_ours_sum', 'u_score_ours_last', 'u_score_len']
_present_keys = [k for k in _inspect_keys if k in test_dst.column_names]
for i in range(min(20, len(test_dst))):
    row = test_dst[i]
    for k in _present_keys:
        print(f"{k}:{row[k]}")
    print()

In [None]:
# Compare every scalar uncertainty column against both correctness metrics;
# the per-token '*_all' list columns are excluded.
all_u_score_names = [
    name
    for name in test_dst[0]
    if name.startswith("u_score") and not name.endswith("all")
]
for correctness_metric in ('sentsim', 'rougel'):
    plot_th_curve(test_dst, all_u_score_names, correctness_metric)

In [None]:
def plot_sentence_token_uncertainty(example):
    """Show per-token uncertainty heatmaps for a single scored answer.

    Three stacked heatmaps share the answer-token axis:
    1. the "ours" score for every layer (``example['u_score_ours_all']``),
    2. that score averaged over layers,
    3. the per-token PE score (``example['u_score_pe_all']``).

    Args:
        example: scored row containing ``u_score_ours_all``, ``u_score_pe_all``,
            ``washed_answer``, ``question``, ``gt``, ``rougel`` and ``sentsim``.
    """
    layer_hm = torch.tensor(example['u_score_ours_all'])  # (layers, answer tokens)
    mean_layer_hm = layer_hm.mean(dim=0).unsqueeze(0)  # (1, answer tokens)
    # assumes u_score_pe_all holds one value per answer token — TODO confirm
    pe_hm = torch.tensor(example['u_score_pe_all']).unsqueeze(0)

    # Tokenise with a leading ':' and drop that token, presumably so the first
    # answer token is split the same way as after "Answer:" in the prompt.
    str_tokens = model.to_str_tokens(f":{example['washed_answer']}", prepend_bos=False)[1:]
    token_pos = list(range(len(str_tokens)))

    layers = list(range(layer_hm.shape[0]))
    # Debug output: len(str_tokens) should equal layer_hm.shape[1].
    print(layer_hm.shape)
    print(str_tokens)
    print(len(str_tokens))

    fig = make_subplots(rows=3, cols=1, subplot_titles=("Ours Layer", "Ours Mean", "PE"), row_heights=[0.9, 0.05, 0.05])

    # Integer x positions labelled via ticktext (as in the other two rows):
    # passing the token strings as x would create a categorical axis that
    # merges repeated tokens into a single column.
    fig.add_trace(go.Heatmap(z=layer_hm, colorscale='Viridis', colorbar=dict(y=0.7, len=0.6)), row=1, col=1)
    fig.update_xaxes(title_text='Token', tickvals=token_pos, ticktext=str_tokens, row=1, col=1)
    fig.update_yaxes(title_text='Layer', tickvals=layers, ticktext=layers, row=1, col=1)

    fig.add_trace(go.Heatmap(z=mean_layer_hm, colorscale='Viridis', colorbar=dict(y=0.25, len=0.1)), row=2, col=1)
    fig.update_xaxes(title_text='Token', tickvals=token_pos, ticktext=str_tokens, row=2, col=1)

    fig.add_trace(go.Heatmap(z=pe_hm, colorscale='Viridis', colorbar=dict(y=0.05, len=0.1)), row=3, col=1)
    fig.update_xaxes(title_text='Token', tickvals=token_pos, ticktext=str_tokens, row=3, col=1)

    # Plotly renders title text as HTML-like markup: '\n' is not shown as a
    # line break, '<br>' is.
    title = f"Q:{example['question']}<br>A:{example['washed_answer']}<br>GT:{example['gt']}<br>Rougel:{example['rougel']}<br>Sentsim:{example['sentsim']}"
    fig.update_layout(height=1500, width=1500, margin=dict(l=0, r=0, b=50, t=50), title_text=title)
    fig.show()

# Visualise the third test answer whose sentence similarity is below 0.5.
example = test_dst.filter(lambda row: row['sentsim'] < 0.5)[2]
plot_sentence_token_uncertainty(example=example)