# Implementation of Internal State-based Uncertainty Estimation

1. 数据集预处理：将不同格式的数据集处理成input+gt的格式，方便判断模型的correctness。这一部分采用固定、不可调整的prompt，即 Context: Question: Options: Answer: 格式
2. 生成回复，为每个模型确定一个prompt，一个max_new_tokens数，然后生成回复
3. 计算回复部分的correctness指标，判断模型的回复是否正确
4. 计算uncertainty指标，包括PE, LN-PE, SAR, Ours
5. 计算AUROC，绘制AUROC/Correctness-Threshold曲线

In [27]:
import math
import os

os.environ['HF_DATASETS_OFFLINE'] = "1"
os.environ["TRANSFORMERS_OFFLINE"] = "1"
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
import transformer_lens
import datasets
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, FactoredMatrix
import einops
from fancy_einsum import einsum
from tqdm.auto import tqdm
import plotly
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import torch
import torch.nn as nn
import torch.nn.functional as F
from jaxtyping import Float
from functools import partial
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset, Dataset, Features, Array2D, Array3D
from typing import List, Tuple, Union
import os
import random
import numpy as np
from rouge import Rouge
from time import time
from sklearn.decomposition import PCA
from sklearn.metrics import roc_auc_score
from copy import deepcopy
import re
from sentence_transformers import SentenceTransformer
from sentence_transformers import util as st_util
from transformers import pipeline
from livelossplot import PlotLosses
from livelossplot.outputs import MatplotlibPlot
import math

# Session-wide switches applied once, right after the imports.
# Turn off the `datasets` library's caching for this session.
datasets.disable_caching()
# Disable autograd globally: everything below only runs forward passes.
torch.set_grad_enabled(False)


def print_sys_info():
    """Print available system memory, the current host name and the GPU status table.

    The third-party helpers are imported lazily inside the function so that
    they are only required when this diagnostic is actually called.
    """
    import psutil
    import socket
    import gpustat

    # Available memory converted from bytes to whole gigabytes (floored by the final `//`).
    available_gb = psutil.virtual_memory().available / 1024 / 1024 // 1024
    print("剩余内存: {} G".format(available_gb))
    print(f"当前主机名是:{socket.gethostname()}")
    gpustat.print_gpustat()

print_sys_info()

剩余内存: 597.0 G
当前主机名是:SH-IDC1-10-140-0-183
SH-IDC1-10-140-0-183      Mon Apr  1 14:02:47 2024  525.60.13
[0] NVIDIA A100-SXM4-80GB | 57°C, 100 % | 67905 / 81920 MB | gaopeng(67902M)
[1] NVIDIA A100-SXM4-80GB | 69°C, 100 % | 67151 / 81920 MB | gaopeng(67148M)
[2] NVIDIA A100-SXM4-80GB | 29°C,   0 % |     0 / 81920 MB |
[3] NVIDIA A100-SXM4-80GB | 27°C,   0 % |     0 / 81920 MB |
[4] NVIDIA A100-SXM4-80GB | 27°C,   0 % | 17329 / 81920 MB | guoyiqiu(17326M)
[5] NVIDIA A100-SXM4-80GB | 29°C,   0 % |     0 / 81920 MB |
[6] NVIDIA A100-SXM4-80GB | 45°C,  61 % | 68729 / 81920 MB | gaopeng(68726M)
[7] NVIDIA A100-SXM4-80GB | 39°C,  61 % | 68729 / 81920 MB | gaopeng(68726M)


In [4]:
# Model Config
# `model_name` is the checkpoint directory name looked up under $my_models_dir;
# `hooked_transformer_name` is the architecture name handed to TransformerLens.
model_name = "vicuna-7b-v1.1"
hooked_transformer_name = "llama-7b-hf"
# Raises KeyError if the `my_models_dir` environment variable is not set.
hf_model_path = os.path.join(os.environ["my_models_dir"], model_name)
hf_tokenizer = AutoTokenizer.from_pretrained(hf_model_path)
hf_model = AutoModelForCausalLM.from_pretrained(hf_model_path)

# Wrap the already-loaded HF weights and tokenizer in a HookedTransformer
# (bfloat16, left padding by default) so activations can be cached later on.
model = HookedTransformer.from_pretrained_no_processing(hooked_transformer_name, dtype='bfloat16', hf_model=hf_model, tokenizer=hf_tokenizer, default_padding_side='left')

# Aux Models
# NLI classifier used by the semantic-entropy baseline; pinned to device 0.
se_bert_name = "microsoft/deberta-large-mnli"
nli_pipe = pipeline("text-classification", model=se_bert_name, device=0)

# Sentence-similarity model for the SAR baseline (currently disabled).
# sar_bert_name = 'cross-encoder/stsb-roberta-large'
# # sar_bert_name = 'sentence-transformers/all-MiniLM-L6-v2'
# sar_bert = SentenceTransformer(sar_bert_name)

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  return self.fget.__get__(instance, owner)()


Loaded pretrained model llama-7b-hf into HookedTransformer


In [55]:
class Timer:
    """Context manager that measures how long its ``with`` body takes.

    After the block exits, ``ts`` holds the start timestamp, ``te`` the end
    timestamp and ``t`` the difference ``te - ts``, all as returned by
    ``time()``.
    """

    def __enter__(self):
        # Stamp the start and return the instance so ``with Timer() as x`` works.
        self.ts = time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Implicitly returns None, so an exception raised in the body still propagates.
        finished = time()
        self.te = finished
        self.t = finished - self.ts
        
def _get_answer_prob(inp, out, prob):
    """Return the probability assigned to each answer token of ``out``.

    Args:
        inp: Prompt text; its token count (via the global ``model``) marks
            where the answer starts inside ``out``.
        out: Full text, expected to be the prompt followed by the answer.
        prob: Per-position probability rows for ``out``, indexed as
            ``[position, vocab_id]``.

    Returns:
        A list with one probability per answer token, or ``[]`` when ``out``
        tokenizes to exactly as many tokens as ``inp``.
    """
    prompt_len = len(model.to_str_tokens(inp))
    full_ids = model.to_tokens(out, move_to_device=False)[0].tolist()
    if len(full_ids) == prompt_len:
        return []

    answer_ids = full_ids[prompt_len:]
    # Row i is paired with token i + 1, hence the one-step shift of the slice.
    shifted_rows = prob[prompt_len - 1:-1, :]
    # Pick, for each of the first len(answer_ids) rows, the entry of the realised token.
    picked = shifted_rows[range(len(answer_ids)), answer_ids]
    return picked.tolist()
        
def get_sampled_answer_prob(example):
    """Attach per-token answer probabilities for every sampled output.

    Reads ``example['input']`` and ``example['washed_sampled_output']``, runs
    the global ``model`` once over the de-duplicated outputs (right padded),
    and stores one probability list per sampled output, in the original
    order, under ``example['sampled_answer_prob']``.

    Returns:
        The same ``example`` mapping, with the new key set.
    """
    sampled_outputs = example['washed_sampled_output']
    # Identical samples share one forward pass.
    unique_outputs = list(set(sampled_outputs))
    logits = model(unique_outputs, padding_side='right')  # logits: (bsz pos vocab)
    batch_prob = F.softmax(logits, dim=-1)

    # Map each distinct text to its row in the batch.
    row_of = {text: row for row, text in enumerate(unique_outputs)}
    prompt = example['input']
    example['sampled_answer_prob'] = [
        _get_answer_prob(prompt, text, batch_prob[row_of[text]])
        for text in sampled_outputs
    ]
    return example

def get_uncertainty_score_se(example, nli_pipe, eps=1e-9):
    """Compute the semantic-entropy uncertainty score for one example.

    Sampled answers are grouped into meaning clusters: an answer joins the
    first cluster whose representative (its first member) it entails and is
    entailed by, according to ``nli_pipe``; otherwise it opens a new cluster.
    Each cluster's mass is ``eps`` plus the sum over its members of the
    product of that answer's token probabilities, and the score is
    ``-sum(p * log(p))`` over the cluster masses.

    Args:
        example: Mapping with ``washed_sampled_answer``; ``sampled_answer_prob``
            is computed through ``get_sampled_answer_prob`` when missing or empty.
        nli_pipe: Callable taking a list of two strings and returning two
            dicts carrying a ``'label'`` key.
        eps: Initial mass of every cluster, which keeps ``log`` finite.

    Returns:
        ``example`` with ``u_score_se`` and ``time_se`` (elapsed time of the
        clustering plus entropy step) added.
    """
    answers = example['washed_sampled_answer']
    if not example.get('sampled_answer_prob'):
        example = get_sampled_answer_prob(example)

    pair_template = "[CLS] {s1} [SEP] {s2} [CLS]"
    with Timer() as timer:
        # Bidirectional entailment clustering.
        clusters = [[answers[0]]]
        for candidate in answers[1:]:
            for cluster in clusters:
                representative = cluster[0]
                verdicts = nli_pipe([
                    pair_template.format(s1=candidate, s2=representative),
                    pair_template.format(s1=representative, s2=candidate),
                ])
                if verdicts[0]['label'] == 'ENTAILMENT' and verdicts[1]['label'] == 'ENTAILMENT':
                    cluster.append(candidate)
                    break
            else:
                # No existing cluster matched in both directions.
                clusters.append([candidate])

        # Semantic entropy over the cluster masses.
        cluster_mass = []
        for cluster in clusters:
            member_probs = (
                np.prod(example['sampled_answer_prob'][example['washed_sampled_answer'].index(member)])
                for member in cluster
            )
            # sum() starts from eps and adds left to right.
            cluster_mass.append(sum(member_probs, eps))
        example['u_score_se'] = -np.sum(np.log(cluster_mass) * cluster_mass)
    example['time_se'] = timer.t
    return example

In [52]:
# Load the cached generation results from a hard-coded local path.
test_dst = Dataset.load_from_disk("/mnt/petrelfs/guoyiqiu/coding/trainable_uncertainty/cached_results/short/allenai_sciq_validation_1000_vicuna-7b-v1.1")
# test_dst = test_dst.filter(lambda x: 'Esters can be formed by heating carboxylic acids and alcohols in the presence of' in x['question'])
# Keep only rows 100..149 for this quick experiment.
test_dst = test_dst.select(range(100,150))
# A time-based fingerprint is passed on every run — presumably to force recomputation
# instead of reusing an earlier map result; TODO confirm.
test_dst = test_dst.map(get_sampled_answer_prob, new_fingerprint=str(time()))

Map:   0%|          | 0/50 [00:00<?, ? examples/s]

In [56]:
# Run the semantic-entropy scorer over every row. The mapped dataset is not
# assigned back to `test_dst`; it is only the cell's displayed result.
test_dst.map(partial(get_uncertainty_score_se,nli_pipe=nli_pipe), new_fingerprint=str(time()))

Map:   0%|          | 0/50 [00:00<?, ? examples/s]

Dataset({
    features: ['question', 'distractor3', 'distractor1', 'distractor2', 'correct_answer', 'support', 'input', 'dst_template', 'options', 'gt', 'answer', 'washed_answer', 'output', 'washed_output', 'sampled_answer', 'washed_sampled_answer', 'sampled_output', 'washed_sampled_output', 'sampled_answer_prob', 'u_score_se', 'time_se'],
    num_rows: 50
})

In [64]:
# Load the MedQA-USMLE (4 options) dataset from a hard-coded local snapshot path.
dst = datasets.load_dataset("/mnt/petrelfs/guoyiqiu/coding/huggingface/datasets/GBaker___med_qa-usmle-4-options/default/0.0.0/0fb93dd23a7339b6dcd27e241cb9b5eca62d4d18")
# Bare expression: shows the DatasetDict summary as the cell output.
dst

DatasetDict({
    train: Dataset({
        features: ['question', 'answer', 'options', 'meta_info', 'answer_idx', 'metamap_phrases'],
        num_rows: 10178
    })
    test: Dataset({
        features: ['question', 'answer', 'options', 'meta_info', 'answer_idx', 'metamap_phrases'],
        num_rows: 1273
    })
})

In [None]:
# Our Method OLD
def compute_certainty_vector_mean(train_dst, layers, act_name, batch_size=8):
    """Estimate a per-layer direction as the mean activation difference between paired texts.

    A "positive" text (ending in the ground-truth answer) and a "negative"
    text (ending in a wrong answer) are built for every row of ``train_dst``.
    Both are run through the global ``model``; the ``act_name`` activation at
    the last position of each requested layer is collected, and the summed
    difference ``positive - negative`` is divided by the number of rows and
    L2-normalised along the feature dimension.

    Args:
        train_dst: Dataset whose rows carry the fields needed by the builder
            selected for the current dataset (``question``/``options``/``gt``
            for sciq; ``input``/``gt`` plus ``options`` or ``answers`` otherwise).
        layers: Iterable of layer indices; activations are stacked in sorted order.
        act_name: Activation name resolved per layer via ``utils.get_act_name``.
        batch_size: Number of texts per forward pass.

    Returns:
        A CPU float tensor of shape ``(len(layers), 1, model.cfg.d_model)``.

    NOTE(review): the builder is chosen through ``dst_name``, which is neither
    a parameter nor defined in the visible part of the file; it must exist as
    a global before this function is called.
    NOTE(review): an empty ``train_dst`` makes ``data_size`` zero, so the
    division below would not yield a usable vector.
    """
    def get_paired_dst_sciq(train_dst):
        # sciq: rebuild the prompt from question + options with a
        # "correct"/"incorrect" cue and append the answer directly after it.
        tmp_pos = "Question:{q} Options:{o} The correct answer is:"
        tmp_neg = "Question:{q} Options:{o} The incorrect answer is:"

        # sciq_train_dst = sciq_train_dst.filter(lambda x: x['rougel'] > 0.5)

        def get_pos_example(example):
            example['input'] = tmp_pos.format(q=example['question'], o=", ".join(example['options']))
            example['washed_output'] = f"{example['input']}{example['gt']}"
            return example

        def get_neg_example(example, idx):
            example['input'] = tmp_neg.format(q=example['question'], o=", ".join(example['options']))
            wrong_options = [opt for opt in example['options'] if opt != example['gt']]
            if wrong_options:
                # Reseeding the global RNG per row makes the pick depend only on idx.
                random.seed(42 + idx)
                wrong_answer = random.choice(wrong_options)
            else:
                wrong_answer = "wrong answer"
            example['washed_output'] = f"{example['input']}{wrong_answer}"
            return example

        dst_pos = train_dst.map(get_pos_example, new_fingerprint=str(time()))
        dst_neg = train_dst.map(get_neg_example, with_indices=True, new_fingerprint=str(time()))
        return dst_pos, dst_neg


    def get_paired_dst_coqa(train_dst):
        # coqa: keep the existing prompt; wrong answers are drawn from the
        # row's own `answers['input_text']` entries that differ from gt.
        def get_pos_example(example):
            example['washed_output'] = f"{example['input']}The correct answer is {example['gt']}"
            return example

        def get_neg_example(example, idx):
            wrong_options = [opt for opt in example['answers']['input_text'] if opt != example['gt']]
            if wrong_options:
                random.seed(42 + idx)
                wrong_answer = random.choice(wrong_options)
            else:
                wrong_answer = "wrong answer"
            example['washed_output'] = f"{example['input']}The wrong answer is {wrong_answer}"
            return example

        dst_pos = train_dst.map(get_pos_example, new_fingerprint=str(time()))
        dst_neg = train_dst.map(get_neg_example, with_indices=True, new_fingerprint=str(time()))
        return dst_pos, dst_neg


    def get_paired_dst_triviaqa(train_dst):
        # triviaqa: rows have no distractors, so the wrong answer is the
        # ground truth of the next row (wrapping around at the end).
        def get_pos_example(example):
            example['washed_output'] = f"{example['input']}The correct answer is {example['gt']}"
            return example

        def get_neg_example(example, idx):
            next_idx = idx + 1 if idx + 1 < len(train_dst) else 0
            wrong_answer = train_dst[next_idx]['gt']
            example['washed_output'] = f"{example['input']}The wrong answer is {wrong_answer}"
            return example

        dst_pos = train_dst.map(get_pos_example, new_fingerprint=str(time()))
        dst_neg = train_dst.map(get_neg_example, with_indices=True, new_fingerprint=str(time()))
        return dst_pos, dst_neg


    def get_paired_dst_medmcqa(train_dst):
        # medmcqa / MedQA: keep the existing prompt; wrong answers come from
        # the row's `options` that differ from gt.
        def get_pos_example(example):
            example['washed_output'] = f"{example['input']}The correct answer is {example['gt']}"
            return example

        def get_neg_example(example, idx):
            wrong_options = [opt for opt in example['options'] if opt != example['gt']]
            if wrong_options:
                random.seed(42 + idx)
                wrong_answer = random.choice(wrong_options)
            else:
                wrong_answer = "wrong answer"
            example['washed_output'] = f"{example['input']}The wrong answer is {wrong_answer}"
            return example

        dst_pos = train_dst.map(get_pos_example, new_fingerprint=str(time()))
        dst_neg = train_dst.map(get_neg_example, with_indices=True, new_fingerprint=str(time()))
        return dst_pos, dst_neg

    # Dataset identifier -> pair builder; an unknown identifier raises KeyError.
    func_map = {
        'allenai/sciq': get_paired_dst_sciq,
        'stanfordnlp/coqa': get_paired_dst_coqa,
        'lucadiliello/triviaqa': get_paired_dst_triviaqa,
        'openlifescienceai/medmcqa': get_paired_dst_medmcqa,
        'GBaker/MedQA-USMLE-4-options': get_paired_dst_medmcqa
    }
    func = func_map[dst_name]
    dst_pos, dst_neg = func(train_dst)

    data_pos = dst_pos['washed_output']
    data_neg = dst_neg['washed_output']
    data_size = len(data_pos)
    full_act_names = [utils.get_act_name(act_name, l) for l in sorted(layers)]
    # Running sum of (positive - negative) last-position activations, one row per layer.
    v_c = torch.zeros((len(layers), 1, model.cfg.d_model)).cuda()

    for i in tqdm(range(0, data_size, batch_size)):
        batch_pos = data_pos[i:i + batch_size]
        batch_neg = data_neg[i:i + batch_size]

        _, cache_pos = model.run_with_cache(batch_pos, names_filter=lambda x: x in full_act_names, padding_side='left')  # logits: (bsz pos vocab) cache: dict
        _, cache_neg = model.run_with_cache(batch_neg, names_filter=lambda x: x in full_act_names, padding_side='left')  # logits: (bsz pos vocab) cache: dict

        # Stack the per-layer activations and move the batch axis to the front.
        cache_pos = einops.rearrange([cache_pos[name] for name in full_act_names], 'l b p d -> b l p d')
        cache_neg = einops.rearrange([cache_neg[name] for name in full_act_names], 'l b p d -> b l p d')

        # Keep only the final position (the position axis is kept with size 1).
        cache_pos = cache_pos[:, :, [-1], :]
        cache_neg = cache_neg[:, :, [-1], :]

        v_c += (cache_pos.sum(dim=0) - cache_neg.sum(dim=0))

    # Sum -> mean over all pairs.
    v_c /= data_size

    v_c = v_c.cpu().float()
    v_c = F.normalize(v_c, p=2, dim=-1)
    return v_c


# clean_exp exp
def clean_exp(dst, v_c, layers, act_name):
    """Check how well the direction ``v_c`` separates correct from wrong answers.

    For every row, the prompt is completed once with the ground truth and with
    wrong answers. The ``act_name`` activations at the last position of each
    layer are projected onto ``v_c`` to score every completion. AUROCs are
    printed for the projection score (per example and pooled, raw and
    z-scored) and for two logit-based baselines, and overlaid histograms of
    the z-scored scores for correct vs. wrong completions are shown.

    Args:
        dst: Dataset with ``input`` and ``gt``, plus optionally ``options`` or
            ``answers`` (with ``input_text``) as the source of wrong answers.
        v_c: Direction tensor contracted against activations laid out as
            ``(layer, position, feature)``.
        layers: Layer indices whose activations are cached, in sorted order.
        act_name: Activation name resolved per layer via ``utils.get_act_name``.

    Returns:
        None. Results are printed and plotted only.

    NOTE(review): the rearranges below hard-code 4 completions per example
    (1 correct + 3 wrong). Only the ``answers`` branch truncates to 3 wrong
    options; the ``options`` branch assumes the row already yields exactly 3.
    """
    fig = go.Figure()
    # Accumulators filled as a side effect of the batched map below.
    c_scores = []
    w_scores = []
    labels = []
    u_scores = []
    u_scores_z = []
    all_pe_u_scores = []
    all_ln_pe_u_scores = []

    def batch_get_result(examples):
        """Score one batch of rows and append to the enclosing accumulators."""
        all_outputs = []
        all_num_answer_tokens = []
        # NOTE(review): computed here and read into `num_input_tokens` below, but never used.
        all_num_input_tokens = list(map(len, model.to_str_tokens(examples['input'])))
        bsz = len(examples['input'])

        for i in range(bsz):
            # Rebuild the i-th row from the column-oriented batch.
            example = {k: examples[k][i] for k in examples.keys()}
            if example.get("options"):
                wrong_options = [opt for opt in example['options']]
                # Drop the first option equal to gt (safe: we break right after removing).
                for opt in wrong_options:
                    if opt == example['gt']:
                        wrong_options.remove(opt)
                        break
            elif example.get("answers"):
                wrong_options = [opt for opt in example['answers']['input_text']]
                for opt in wrong_options:
                    if opt == example['gt']:
                        wrong_options.remove(opt)
                        break
                wrong_options = wrong_options[:3]
            else:
                wrong_options = ['wrong answer', 'bad answer', 'incorrect answer']
            correct_output = example['input'] + example['gt']
            wrong_outputs = [example['input'] + opt for opt in wrong_options]
            # The correct completion always comes first within each group.
            all_outputs.extend([correct_output] + wrong_outputs)
            # NOTE(review): token counts come from tokenizing each answer on its own;
            # whether that count includes a BOS token depends on to_str_tokens' default — confirm.
            num_answer_tokens = list(map(len, model.to_str_tokens([example['gt']] + wrong_options)))
            all_num_answer_tokens.append(num_answer_tokens)

        full_act_names = [utils.get_act_name(act_name, l) for l in sorted(layers)]

        batch_logits, batch_cache = model.run_with_cache(all_outputs, names_filter=lambda x: x in full_act_names,
                                                         device='cpu',
                                                         padding_side='left')  # logits: (bsz pos vocab) cache: dict
        # Stack layers, then split the flat batch into (example, completion).
        batch_cache = einops.rearrange([batch_cache[name] for name in full_act_names],
                                       'l b p d -> b l p d').float().cpu()
        batch_cache = einops.rearrange(batch_cache, '(b o) l p d -> b o l p d', o=4)
        # Keep only the last position (position axis kept with size 1).
        batch_cache = batch_cache[:, :, :, [-1], :]

        batch_logits = batch_logits.cpu().float()
        batch_logits = einops.rearrange(batch_logits, '(b o) p v -> b o p v', o=4)

        # Logit-based baselines over the answer span of every completion.
        for i, lg_4 in enumerate(batch_logits):
            num_answer_tokens = all_num_answer_tokens[i]
            num_input_tokens = all_num_input_tokens[i]  # unused
            for j, lg in enumerate(lg_4):
                output = all_outputs[i * 4 + j]  # unused
                # Logits at the positions preceding each of the trailing answer tokens.
                answer_lg = lg[-num_answer_tokens[j] - 1:-1]
                answer_prob = F.softmax(answer_lg, dim=-1)
                # NOTE(review): this takes the maximum probability over the vocabulary at
                # each position, not the probability of the actual answer token.
                answer_target_prob = answer_prob.max(dim=-1).values
                pe = -torch.log(answer_target_prob).sum().item()
                # print(f"pe:{pe}")
                ln_pe = -torch.log(answer_target_prob).mean().item()
                # print(f"ln_pe:{ln_pe}")
                all_pe_u_scores.append(pe)
                all_ln_pe_u_scores.append(ln_pe)

        batch_in_vivo_auroc = []
        for i in range(bsz):
            cache = batch_cache[i]
            # Here the leading 'b' axis is the 4 completions of example i.
            u_score = einsum('b l p d, l p d -> b', cache, v_c)
            # Standardise within the example's own completions.
            u_score_z = (u_score - u_score.mean()) / u_score.std()

            u_score = u_score.tolist()
            u_score_z = u_score_z.tolist()

            # "In vivo": AUROC within a single example (correct completion is index 0).
            in_vivo_auroc = roc_auc_score([1, 0, 0, 0], u_score)
            batch_in_vivo_auroc.append(in_vivo_auroc)
            # if u_score[0] > max(u_score[1:]):
            #     batch_in_vivo_auroc.append(1)
            # else:
            #     batch_in_vivo_auroc.append(0)

            c_scores.append(u_score_z[0])
            w_scores.extend(u_score_z[1:])
            labels.extend([1, 0, 0, 0])

            # assert len(u_score) == 4, f"{len(u_score)} {example['options']}"
            u_scores.extend(u_score)
            u_scores_z.extend(u_score_z)

        examples['in_vivo_auroc'] = batch_in_vivo_auroc
        return examples

    new_dst = dst.map(batch_get_result, new_fingerprint=str(time()), batched=True, batch_size=4)

    # Mean per-example AUROC. If it is not above 0.5 the direction is treated as
    # sign-flipped, and the same flip is applied to the pooled projection AUROCs below.
    in_vivo_auroc = sum(new_dst['in_vivo_auroc']) / len(new_dst['in_vivo_auroc'])
    flag = in_vivo_auroc > 0.5
    in_vivo_auroc = in_vivo_auroc if flag else 1 - in_vivo_auroc
    print(f"in-vivo u_score auroc: {in_vivo_auroc}")

    # "In vitro": AUROC over all completions of all examples pooled together.
    in_vitro_auroc = roc_auc_score(labels, u_scores)
    in_vitro_auroc = in_vitro_auroc if flag else 1 - in_vitro_auroc
    print(f"in-vitro u_score auroc: {in_vitro_auroc}")

    in_vitro_auroc_z = roc_auc_score(labels, u_scores_z)
    in_vitro_auroc_z = in_vitro_auroc_z if flag else 1 - in_vitro_auroc_z
    print(f"in-vitro u_score_z auroc: {in_vitro_auroc_z}")

    # The baseline AUROCs are reported as computed, without the flip.
    in_vitro_pe_auroc = roc_auc_score(labels, all_pe_u_scores)
    print(f"in-vitro pe auroc: {in_vitro_pe_auroc}")

    in_vitro_ln_pe_auroc = roc_auc_score(labels, all_ln_pe_u_scores)
    print(f"in-vitro ln_pe auroc: {in_vitro_ln_pe_auroc}")

    fig.add_trace(go.Histogram(x=c_scores, name='Correct', opacity=0.5, nbinsx=100))
    fig.add_trace(go.Histogram(x=w_scores, name='Wrong', opacity=0.5, nbinsx=100))
    fig.update_layout(barmode='overlay')
    fig.show()

'''
Compared Baseline Methods:
PE(Predictive Entropy),
LN-PE(Length-normalised Predictive Entropy),
SAR(Shifting Attention to more Relevant),
LS(Lexical Similarity),
SE(Semantic Entropy),
SR(Self-Report),
Ours(Activation Based)
'''

# Evaluation: AUROC with Correctness Metric
def get_auroc(val_dst, u_metric, c_metric, c_th):
    """Return the AUROC of an uncertainty score against thresholded correctness.

    A row is labelled correct (1) when its ``c_metric`` value is strictly
    greater than ``c_th`` and incorrect (0) otherwise. Because uncertainty
    scores may be oriented either way, the result is folded so that it is
    never below 0.5.

    Args:
        val_dst: Mapping (e.g. a dataset) from column name to a sequence of values.
        u_metric: Column holding the uncertainty scores.
        c_metric: Column holding the correctness metric.
        c_th: Correctness threshold.

    Returns:
        The folded AUROC, or ``nan`` when the threshold leaves fewer than two
        classes. AUROC is undefined in that case and ``roc_auc_score`` would
        raise, which used to abort a whole threshold sweep.
    """
    label = [1 if c > c_th else 0 for c in val_dst[c_metric]]
    u_score = val_dst[u_metric]
    if len(set(label)) < 2:
        # All rows fall on one side of the threshold: nothing to rank.
        return float("nan")
    auroc = roc_auc_score(label, u_score)
    return auroc if auroc > 0.5 else 1 - auroc

def plot_th_curve(test_dst, u_metrics, c_metric, nbins=10):
    """Plot accuracy and per-metric AUROC as functions of the correctness threshold.

    Thresholds are ``1/nbins, 2/nbins, ..., (nbins-1)/nbins``. For each one the
    figure shows the fraction of rows whose ``c_metric`` exceeds it ("acc")
    and, for every uncertainty metric, the AUROC from ``get_auroc``. The figure
    is written as a PNG under ``eval_results/`` (the file name uses the globals
    ``dst_name`` and ``model_name``), displayed, and returned.

    Args:
        test_dst: Mapping/dataset from column name to a sequence of values.
        u_metrics: Names of the uncertainty-score columns to plot.
        c_metric: Name of the correctness-metric column.
        nbins: Number of equal threshold steps on (0, 1).

    Returns:
        The plotly figure, so callers can post-process or re-display it.
    """
    fig = go.Figure()
    th_range = [i / nbins for i in range(1, nbins)]
    c_metrics = test_dst[c_metric]

    # Fraction of rows counted as correct at each threshold.
    accs = [sum(1 for c in c_metrics if c > th) / len(c_metrics) for th in th_range]
    fig.add_trace(go.Scatter(x=th_range, y=accs, mode='lines+markers+text', name="acc", text=[f"{a:.4f}" for a in accs], textposition="top center"))

    for u_metric in u_metrics:
        aurocs = [get_auroc(test_dst, u_metric, c_metric, th) for th in th_range]
        fig.add_trace(go.Scatter(x=th_range, y=aurocs, mode='lines+markers+text', name=f"{u_metric}", text=[f"{a:.4f}" for a in aurocs], textposition="top center"))
    fig.update_layout(title=f"AUROC/{c_metric}-Threshold Curve", xaxis_title=f"{c_metric}-Threshold", yaxis_title="AUROC", width=2000, height=1000)

    out_path = f"eval_results/{dst_name}_{model_name}_{c_metric}_th_curve.png"
    # Create the target directory first; this also covers a dataset name that
    # contains "/" and therefore maps to a sub-directory.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig.write_image(out_path)
    fig.show()
    return fig

In [None]:
# Every scalar uncertainty column: keys starting with "u_score", excluding the
# per-token variants whose names end in "all".
u_metrics = [k for k in test_dst[0].keys() if k.startswith("u_score") and not k.endswith("all")]
# plot_th_curve displays the figure itself, so .show() is not called on its
# return value here.
fig = plot_th_curve(test_dst, u_metrics, 'rougel')

In [None]:
# Visualise the per-token uncertainty of one answer, and the same scores
# averaged over sentence-like spans delimited by '.' tokens.

# Alternative selection, disabled: it was immediately overwritten by the next
# line while still costing a full filter pass over the dataset.
# example = test_dst.filter(lambda x: x['rougel'] < 0.1)[1]
example = test_dst[2]
print(f"gt:{example['gt']}")
print(f"options:{example['options']}")

# The answer is tokenized behind a leading ":" and the first token is dropped again.
str_tokens = model.to_str_tokens(f":{example['washed_answer']}", prepend_bos=False)[1:]
token_scores = example['u_score_pe_all']
positions = list(range(len(str_tokens)))
fig = make_subplots(rows=2, cols=1, subplot_titles=("Token Level", "Sentence Level"), row_heights=[0.5, 0.5])

fig.add_trace(go.Scatter(x=positions, y=token_scores, mode='lines+markers'), row=1, col=1)
fig.update_xaxes(title_text='Token', tickvals=positions, ticktext=str_tokens, row=1, col=1)

# Span boundaries: the start, every '.' token, and the true end of the sequence.
# Using len(str_tokens) instead of -1 as the end keeps the last span's length
# positive, so the final sentence is averaged like the others and an answer
# without any '.' no longer fails.
indices = [0] + [i for i, x in enumerate(str_tokens) if x == '.'] + [len(str_tokens)]
print(len(indices))
sentence_u_score_pe_all = []
for start, end in zip(indices, indices[1:]):
    span_scores = token_scores[start:end]
    if len(span_scores) == 0:
        # Empty span (e.g. a '.' at position 0): nothing to average.
        continue
    sentence_score = sum(span_scores) / len(span_scores)
    # Repeat the span mean once per token so both subplots share the same x axis.
    sentence_u_score_pe_all.extend([sentence_score] * len(span_scores))

for i, sentence in enumerate(example['washed_answer'].split(".")):
    print(i+1,sentence.replace("\n",' ').strip())

fig.add_trace(go.Scatter(x=positions, y=sentence_u_score_pe_all, mode='lines+markers'), row=2, col=1)
fig.update_xaxes(title_text='Sentence', tickvals=positions, ticktext=str_tokens, row=2, col=1)

fig.update_layout(height=1000, width=2500, margin=dict(l=0, r=0, b=50, t=50), title_text=example['washed_answer'])
fig.show()