In [None]:
%load_ext autoreload
%autoreload 2


from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm
from datasets import load_dataset
import functools
import pickle

from selfcheckgpt.modeling_selfcheck import SelfCheckMQAG, SelfCheckBERTScore, SelfCheckNgram
from sklearn.metrics import roc_auc_score
import statistics
import spacy

from result_collector import trex_data_to_question_template, answer_trivia, answer_trex, load_data, model_dir

import torch
import numpy as np

In [None]:
# Model under evaluation (Hugging Face Hub repo id = f"{org}/{model_name}").
org = "tiiuae"
model_name = "falcon-7b"
repo = f"{org}/{model_name}"

# Data related params
dataset_name = "trivia_qa"

# GPU index: selects the torch device and is also part of the output pickle filename.
gpu = "0"
device = torch.device(f"cuda:{gpu}" if torch.cuda.is_available() else "cpu")

# SelfCheckGPT: sampling temperature and number of sampled responses per question.
self_checkgpt_temperature = 1.0
selfcheckgpt_n_trials = 20

# Seed torch's RNGs (CPU and all CUDA devices) so the stochastic high-temperature
# samples -- and therefore the self-check scores and AUROCs -- are reproducible.
seed = 0
torch.manual_seed(seed)

In [None]:
# Load the evaluation set through the project's load_data helper. The evaluation loop
# uses len(dataset) and unpacks dataset[idx] into (question, answers) -- presumably
# that is the element format; TODO confirm against result_collector.load_data.
dataset = load_data(dataset_name)

In [None]:
# Load tokenizer and weights from the same cache directory (model_dir) so all
# downloaded files for this repo live in one configurable location.
tokenizer = AutoTokenizer.from_pretrained(repo, cache_dir=model_dir)
model = AutoModelForCausalLM.from_pretrained(
    repo,
    cache_dir=model_dir,
    torch_dtype=torch.bfloat16,
    # Allows the Hub repo's own modelling code to run -- only enable for trusted repos.
    trust_remote_code=True,
).to(device)

In [None]:
# selfcheckgpt scorers, created once and reused for every question in the loop.
selfcheck_bertscore = SelfCheckBERTScore(
    rescale_with_baseline=True,
)
selfcheck_ngram = SelfCheckNgram(
    n=1,  # n-gram order: 1 = unigram, 2 = bigram, etc.
)

In [None]:
@functools.lru_cache(maxsize=None)
def _load_spacy_pipeline(name="en_core_web_sm"):
    """Load (once) and return the spaCy pipeline used for sentence splitting.

    ``spacy.load`` reads the model from disk; caching it avoids reloading it for
    every question in the evaluation loop.
    """
    return spacy.load(name)


def _split_sentences(text, min_tokens=4):
    """Split ``text`` into stripped sentences with at least ``min_tokens`` spaCy tokens.

    If no sentence survives the length filter (e.g. a one- or two-word answer), the
    whole ``text`` is returned as the only sentence so the scorers never receive an
    empty list (which would crash ``statistics.mean`` and yield NaN n-gram scores).
    """
    doc = _load_spacy_pipeline()(text)
    sentences = [sent.text.strip() for sent in doc.sents if len(sent) >= min_tokens]
    return sentences or [text]


def _sample_responses(question, tokenizer):
    """Draw ``selfcheckgpt_n_trials`` sampled continuations of ``question``.

    Uses the module-level ``model``, ``device`` and ``self_checkgpt_temperature``.
    Each continuation is decoded without the prompt tokens and without special
    tokens (so an end-of-sequence marker does not leak into the text); newlines are
    replaced by spaces and the result is stripped.
    """
    encoded = tokenizer(question, return_tensors="pt")
    input_ids = encoded.input_ids.to(device)
    attention_mask = encoded.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)
    start_pos = input_ids.size(dim=-1)  # prompt length; generated tokens follow it

    samples = []
    for _ in range(selfcheckgpt_n_trials):
        model_outputs = model.generate(
            input_ids,
            attention_mask=attention_mask,
            do_sample=True,
            temperature=self_checkgpt_temperature,
            max_new_tokens=100,
            return_dict_in_generate=True,
        )
        generated_ids = model_outputs.sequences[0][start_pos:]
        text = tokenizer.decode(generated_ids, skip_special_tokens=True)
        samples.append(text.replace("\n", " ").strip())
    return samples


def generate_responses(question, str_response, tokenizer):
    """Sample responses to ``question`` and SelfCheck-score ``str_response`` against them.

    Relies on module-level state from earlier cells: ``model``, ``device``,
    ``selfcheckgpt_n_trials``, ``self_checkgpt_temperature``,
    ``selfcheck_bertscore`` and ``selfcheck_ngram``.

    Args:
        question: Prompt string fed to the model when sampling.
        str_response: The response being checked (the evaluation loop passes the
            string returned by its question-answering helper).
        tokenizer: Tokenizer matching ``model``.

    Returns:
        A 4-tuple ``(hitemp_str_responses, selfcheck_scores_bert_overall,
        selfcheck_scores_bert_average, selfcheck_ngram_overall)``: the list of
        sampled response strings, followed by three single-element lists holding
        the BERTScore-based score of the whole response, the mean of the
        per-sentence BERTScore-based scores, and the raw ``SelfCheckNgram.predict``
        result.
    """
    hitemp_str_responses = _sample_responses(question, tokenizer)

    # Score the whole response as a single "sentence".
    overall_bertscore = selfcheck_bertscore.predict(
        sentences=[str_response],
        sampled_passages=hitemp_str_responses,
    )

    # Score sentence by sentence, then average.
    sentences = _split_sentences(str_response)
    all_bertscores = selfcheck_bertscore.predict(
        sentences=sentences,
        sampled_passages=hitemp_str_responses,
    )
    average_bertscore = statistics.mean(all_bertscores)

    sent_scores_ngram = selfcheck_ngram.predict(
        sentences=sentences,
        passage=str_response,
        sampled_passages=hitemp_str_responses,
    )

    # Single-element lists keep the return shape callers index with [0].
    return (
        hitemp_str_responses,
        [overall_bertscore[0]],
        [average_bertscore],
        [sent_scores_ngram],
    )


In [None]:
# Per-question results filled in by the evaluation loop: every field maps to a list
# with one entry per processed question.
selfcheck_dict = {
    field: []
    for field in (
        'question',
        'response',
        'str_response',
        'start_pos',
        'correct',
        'hitemp_str_responses',
        'selfcheck_scores_bert_overall',
        'selfcheck_scores_bert_average',
        'selfcheck_ngram_overall',
    )
}

# Flat score / label lists consumed by the AUROC computation.
selfcheck_arr_overall, selfcheck_arr_average, selfcheck_ngram_average, correct_arr = [], [], [], []

# Choose the helper that asks the model a question for the configured dataset.
if dataset_name in trex_data_to_question_template:
    question_asker = functools.partial(
        answer_trex,
        question_template=trex_data_to_question_template[dataset_name],
    )
elif dataset_name == "trivia_qa":
    question_asker = answer_trivia
else:
    raise ValueError(f"Unknown dataset name {dataset_name}.")


In [None]:
for idx in tqdm(range(len(dataset))):
    question, answers = dataset[idx]
    response, str_response, logits, start_pos, correct = question_asker(question, answers, model, tokenizer)

    # TriviaQA questions are used verbatim as the sampling prompt; the other
    # (trex_data_to_question_template) datasets are rendered through their template.
    if dataset_name == "trivia_qa":
        sampling_prompt = question
    else:
        sampling_prompt = trex_data_to_question_template[dataset_name].substitute(source=question)

    (
        hitemp_str_responses,
        selfcheck_scores_bert_overall,
        selfcheck_scores_bert_average,
        selfcheck_ngram_overall,
    ) = generate_responses(sampling_prompt, str_response, tokenizer)

    # Keep everything needed to re-analyse the run offline.
    record = {
        'question': question,
        'response': response,
        'str_response': str_response,
        'start_pos': start_pos,
        'correct': correct,
        'hitemp_str_responses': hitemp_str_responses,
        'selfcheck_scores_bert_overall': selfcheck_scores_bert_overall,
        'selfcheck_scores_bert_average': selfcheck_scores_bert_average,
        'selfcheck_ngram_overall': selfcheck_ngram_overall,
    }
    for key, value in record.items():
        selfcheck_dict[key].append(value)

    # Orient all three detector scores the same way: higher = more consistent with
    # the samples, which should indicate a correct answer (label 1 in correct_arr).
    # selfcheckgpt's BERTScore predict() returns 1 - BERTScore, so flip it back.
    selfcheck_arr_overall.append(1.0 - selfcheck_scores_bert_overall[0])
    selfcheck_arr_average.append(1.0 - selfcheck_scores_bert_average[0])
    # exp(-avg_neg_logprob) is the geometric-mean unigram probability of the response
    # under the samples -- already "higher = more consistent". Flipping it with
    # 1.0 - exp(...) would orient it opposite to the BERTScore scores above.
    ngram_avg_neg_logprob = selfcheck_ngram_overall[0]['doc_level']['avg_neg_logprob']
    selfcheck_ngram_average.append(np.exp(-ngram_avg_neg_logprob))
    correct_arr.append(int(correct))
    

In [None]:
# AUROC of each self-check score against answer correctness (label 1 = correct).
auroc_inputs = {
    "overall": selfcheck_arr_overall,
    "average": selfcheck_arr_average,
    "ngram": selfcheck_ngram_average,
}
for score_label, scores in auroc_inputs.items():
    roc_score = roc_auc_score(correct_arr, scores)
    print(f"AUROC for self check {score_label}: {roc_score}")

In [None]:
# Persist all per-question records; the GPU index is part of the filename.
pickle_path = f"selfcheck_{model_name}_{dataset_name}_{gpu}.pickle"
with open(pickle_path, "wb") as outfile:
    pickle.dump(selfcheck_dict, outfile)

In [None]:
# Peek at the sampled (high-temperature) responses for the first question.
first_hitemp_responses = selfcheck_dict["hitemp_str_responses"][0]
first_hitemp_responses

In [None]:
# Show the response that question 0's sampled responses were scored against
# (the sampled responses themselves are displayed by the previous cell).
selfcheck_dict['str_response'][0]