In [None]:
import os
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm
from datasets import load_dataset
import re

from selfcheckgpt.modeling_selfcheck import SelfCheckMQAG, SelfCheckBERTScore, SelfCheckNgram
from sklearn.metrics import roc_auc_score
import statistics
import spacy

import torch
import numpy as np
import pandas as pd

In [None]:
# Shard configuration: this process handles the `num_samples` questions in
# [start, end), i.e. the slice owned by GPU index `gpu`.
gpu = "0"
org = "tiiuae"
model_name = "falcon-7b"
repo = "/".join([org, model_name])
dataset_name = "trivia"
num_samples = 5
start = num_samples * int(gpu)
end = num_samples + start

In [None]:
# Use the GPU selected by `gpu` when CUDA is present; otherwise run on CPU.
device = (
    torch.device(f"cuda:{gpu}")
    if torch.cuda.is_available()
    else torch.device("cpu")
)

In [None]:
# Tokenizer and model share one cache directory so a fresh run does not split
# downloads between this path and the default Hugging Face cache.
# Override with the HALU_CACHE_DIR environment variable on other machines.
MODEL_CACHE_DIR = os.environ.get("HALU_CACHE_DIR", "/home/ec2-user/SageMaker/halu_code/cache/data")

tokenizer = AutoTokenizer.from_pretrained(repo, cache_dir=MODEL_CACHE_DIR)

model = AutoModelForCausalLM.from_pretrained(
    repo,
    cache_dir=MODEL_CACHE_DIR,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).to(device)

In [None]:
selfcheck_bertscore = SelfCheckBERTScore(rescale_with_baseline=True)
selfcheck_ngram = SelfCheckNgram(n=1)  # n=1 means Unigram, n=2 means Bigram, etc.

# `dataset` holds only this shard's rows [start, end) as (question_or_subject, target)
# pairs, so it must be indexed from 0 (or iterated), not from `start`.
if dataset_name in ["capitals", "place_of_birth", "founders"]:
    pd_frame = pd.read_csv(f'/home/ec2-user/SageMaker/halu_code/data/{dataset_name}.csv')
    # iloc slicing clamps to the frame length instead of raising IndexError.
    shard_frame = pd_frame.iloc[start:end]
    dataset = list(zip(shard_frame['subject'], shard_frame['target']))
elif dataset_name == "trivia":
    trivia_qa = load_dataset('trivia_qa', 'rc.nocontext', cache_dir='/home/ec2-user/SageMaker/halu_code/cache/data')
    train_split = trivia_qa['train']
    # Only materialize this shard's rows instead of walking the whole train split.
    shard_rows = train_split.select(range(start, min(end, len(train_split))))
    dataset = []
    for obs in tqdm(shard_rows):
        answer_info = obs['answer']
        # Any of these strings appearing in the response counts as correct.
        aliases = [
            *answer_info['aliases'],
            *answer_info['normalized_aliases'],
            answer_info['value'],
            answer_info['normalized_value'],
        ]
        dataset.append((obs['question'], aliases))
else:
    raise ValueError(f"Unknown dataset_name: {dataset_name!r}")
print("Loaded training data")

# NOTE(review): not referenced anywhere else in this notebook; `start`/`end`
# above define the shard actually processed.
num_samples_per_gpu = 10 #1000
start_pos = int(gpu) * num_samples_per_gpu

In [None]:
def get_next_token(x):
    """Run one forward pass of the global `model` on `x` without autograd.

    Despite the name, this returns the full logits tensor of the forward pass,
    not a token id; the caller selects the next token from it.
    """
    with torch.no_grad():
        model_output = model(x)
    return model_output.logits
    
def generate_response(x, max_length=100, pbar=False):
    """Greedily decode up to `max_length` new tokens after the prompt ids `x`.

    Args:
        x: prompt token ids, as produced by `answer_question` (presumably shape
           (1, seq_len) on the model's device; only batch row 0 is used).
        max_length: maximum number of new tokens to generate; must be >= 1.
        pbar: wrap the decoding loop in a tqdm progress bar.

    Returns:
        (token_ids, logits): the generated ids as a numpy array (including the
        EOS token if one was produced) and the squeezed logits of the last
        forward pass.

    Raises:
        ValueError: if `max_length` < 1 (there would be nothing to return).

    Decoding stops at the tokenizer's EOS token, but only once step > 5, so at
    least 7 tokens are generated when `max_length` allows it.
    """
    if max_length < 1:
        raise ValueError(f"max_length must be >= 1, got {max_length}")
    # Resolve the EOS id once. `tokenizer.encode(eos_text)[0]` picks up the BOS id
    # on tokenizers that prepend BOS when encoding.
    eos_id = tokenizer.eos_token_id
    response = []
    steps = tqdm(range(max_length)) if pbar else range(max_length)
    for step in steps:
        logits = get_next_token(x)
        # Index batch 0 / last position explicitly: `.squeeze()` would also drop
        # the sequence axis when the prompt is a single token.
        next_token = logits[0, -1].argmax()
        x = torch.cat([x, next_token.view(1, -1)], dim=1)
        response.append(next_token)
        if step > 5 and next_token.item() == eos_id:
            break
    return torch.stack(response).cpu().numpy(), logits.squeeze()

def answer_question(question, tokenizer, max_length=100, pbar=False):
    """Greedily answer `question` with the global model.

    Returns:
        (generated_ids, last_logits, prompt_length) where `prompt_length` is the
        number of prompt tokens, i.e. the offset at which generation begins.
    """
    encoded = tokenizer(question, return_tensors='pt')
    prompt_ids = encoded.input_ids.to(device)
    prompt_length = prompt_ids.shape[-1]
    generated_ids, last_logits = generate_response(prompt_ids, max_length=max_length, pbar=pbar)
    return generated_ids, last_logits, prompt_length

_SPACY_PIPELINES = {}


def _get_spacy_pipeline(name="en_core_web_sm"):
    """Load a spaCy pipeline once per kernel and reuse it on later calls."""
    if name not in _SPACY_PIPELINES:
        _SPACY_PIPELINES[name] = spacy.load(name)
    return _SPACY_PIPELINES[name]


def generate_responses(question, str_response, tokenizer, temperature, n_trials=3):
    """Sample `n_trials` answers and SelfCheck the greedy answer against them.

    Args:
        question: prompt text.
        str_response: the greedy (zero-temperature) decoded answer being checked.
        tokenizer: tokenizer matching the global `model`.
        temperature: sampling temperature for the stochastic samples.
        n_trials: number of sampled passages; must be > 1.

    Returns:
        (hitemp_str_responses, bert_overall, bert_average, ngram_overall):
        the sampled passages, plus three one-element lists holding
          - the SelfCheckBERTScore of the whole response scored as one sentence,
          - the mean SelfCheckBERTScore over the response's spaCy sentences,
          - the raw SelfCheckNgram result.
    """
    assert n_trials > 1

    inputs = tokenizer(question, return_tensors="pt").input_ids.to(device)
    prompt_len = inputs.size(dim=-1)

    hitemp_str_responses = []
    for _ in range(n_trials):
        model_outputs = model.generate(
            inputs,
            do_sample=True,
            temperature=temperature,
            max_new_tokens=100,
            return_dict_in_generate=True,
        )
        generated_tokens_ids = model_outputs.sequences[0]
        # skip_special_tokens matches how the greedy `str_response` is decoded, so
        # EOS/pad markers don't leak into the passages it is compared against.
        response = tokenizer.decode(
            generated_tokens_ids[prompt_len:], skip_special_tokens=True
        ).replace("\n", " ").strip()
        hitemp_str_responses.append(response)

    # Whole response scored as a single "sentence".
    overall_bertscore = selfcheck_bertscore.predict(
        sentences=[str_response],
        sampled_passages=hitemp_str_responses,
    )

    nlp = _get_spacy_pipeline()
    sentences = [sent.text.strip() for sent in nlp(str_response).sents if len(sent) > 3]
    if not sentences:
        # Short answers (e.g. "Paris.") have no sentence longer than 3 tokens;
        # statistics.mean([]) would raise, so score the whole response instead.
        sentences = [str_response]
    all_bertscores = selfcheck_bertscore.predict(
        sentences=sentences,
        sampled_passages=hitemp_str_responses,
    )
    average_bertscore = statistics.mean(all_bertscores)

    sent_scores_ngram = selfcheck_ngram.predict(
        sentences=sentences,
        passage=str_response,
        sampled_passages=hitemp_str_responses,
    )

    return hitemp_str_responses, [overall_bertscore[0]], [average_bertscore], [sent_scores_ngram]

def answer_trivia(question, targets, tokenizer, temperature):
    """Greedily answer `question`, SelfCheck it, and grade it against `targets`.

    `correct` is True when any alias in `targets` appears (case-insensitively)
    as a substring of the greedy decoded answer.

    Returns:
        (response_ids, str_response, logits, start_pos, correct,
         hitemp_str_responses, bert_overall, bert_average, ngram_overall)
    """
    response, logits, start_pos = answer_question(question, tokenizer)
    str_response = tokenizer.decode(response, skip_special_tokens=True)
    selfcheck_results = generate_responses(question, str_response, tokenizer, temperature)
    lowered_answer = str_response.lower()
    correct = any(alias.lower() in lowered_answer for alias in targets)
    return (response, str_response, logits, start_pos, correct, *selfcheck_results)

def answer_capitals(source, target, tokenizer, temperature):
    """Ask for the capital of `source` and grade the answer against `target`.

    Delegates to `answer_trivia` with a one-element target list, which applies
    the same case-insensitive substring check. Returns the same 9-tuple as
    `answer_trivia`.
    """
    question = f"What is the capital of {source}?"
    return answer_trivia(question, [target], tokenizer, temperature)

def answer_birth_place(source, target, tokenizer, temperature):
    """Ask where `source` was born and grade the answer against `target`.

    Delegates to `answer_trivia` with a one-element target list, which applies
    the same case-insensitive substring check. Returns the same 9-tuple as
    `answer_trivia`.
    """
    question = f"Where was {source} born?"
    return answer_trivia(question, [target], tokenizer, temperature)

def answer_founders(source, target, tokenizer, temperature):
    """Ask who founded `source` and grade the answer against `target`.

    Delegates to `answer_trivia` with a one-element target list, which applies
    the same case-insensitive substring check. Returns the same 9-tuple as
    `answer_trivia`.
    """
    question = f"Who founded {source}?"
    return answer_trivia(question, [target], tokenizer, temperature)

In [None]:
# Sampling temperature for the stochastic passages SelfCheck compares against.
temperature = 1.0

# Per-question raw results, one list entry per successfully processed question.
selfcheck_dict = {
    'question': [],
    'response': [],
    'str_response': [],
    'start_pos': [],
    'correct': [],
    'hitemp_str_responses': [],
    'selfcheck_scores_bert_overall': [],
    'selfcheck_scores_bert_average': [],
    'selfcheck_ngram_overall': [],
}

# Per-question scalar scores fed to roc_auc_score, aligned with correct_arr.
selfcheck_arr_overall = []
selfcheck_arr_average = []
selfcheck_ngram_average = []  # one doc-level n-gram score per question
correct_arr = []

# Each dataset needs its own question template; "founders" must use the
# "Who founded ...?" prompt, not the birth-place one.
ANSWER_FUNCS = {
    "trivia": answer_trivia,
    "capitals": answer_capitals,
    "place_of_birth": answer_birth_place,
    "founders": answer_founders,
}
if dataset_name not in ANSWER_FUNCS:
    raise ValueError(f"No answer function for dataset_name {dataset_name!r}")
answer_func = ANSWER_FUNCS[dataset_name]

In [None]:
# `dataset` already contains only this shard's rows, so iterate it directly;
# indexing it with absolute positions start..end fails for every gpu > 0.
for question, answer in tqdm(dataset):
    try:
        (response, str_response, logits, start_pos, correct,
         hitemp_str_responses, selfcheck_scores_bert_overall,
         selfcheck_scores_bert_average, selfcheck_ngram_overall) = answer_func(
            question, answer, tokenizer, temperature)
    except Exception as exc:
        # Best-effort: skip this question, but report why instead of failing
        # silently (and let KeyboardInterrupt through).
        print(f"Skipping question {question!r}: {type(exc).__name__}: {exc}")
        continue
    selfcheck_dict['question'].append(question)
    selfcheck_dict['response'].append(response)
    selfcheck_dict['str_response'].append(str_response)
    selfcheck_dict['start_pos'].append(start_pos)
    selfcheck_dict['correct'].append(correct)
    selfcheck_dict['hitemp_str_responses'].append(hitemp_str_responses)
    selfcheck_dict['selfcheck_scores_bert_overall'].append(selfcheck_scores_bert_overall)
    selfcheck_dict['selfcheck_scores_bert_average'].append(selfcheck_scores_bert_average)
    selfcheck_dict['selfcheck_ngram_overall'].append(selfcheck_ngram_overall)

    selfcheck_arr_overall.append(1.0 - selfcheck_scores_bert_overall[0])  # bert score flipped
    selfcheck_arr_average.append(1.0 - selfcheck_scores_bert_average[0])  # bert score flipped
    # avg_neg_logprob grows as the response becomes less likely under the
    # samples' n-gram model, so exp(-x) is higher for more consistent answers.
    # That is the same direction as the flipped BERTScores above, which is what
    # scoring against `correct` with roc_auc_score expects.
    selfcheck_ngram_average.append(np.exp(-selfcheck_ngram_overall[0]['doc_level']['avg_neg_logprob']))
    correct_arr.append(int(correct))
    

In [None]:
# AUROC of each consistency score at predicting `correct` (1 = correct answer).
auroc_inputs = {
    "self check overall": selfcheck_arr_overall,
    "self check average": selfcheck_arr_average,
    "self check ngram": selfcheck_ngram_average,
}
observed_labels = sorted(set(correct_arr))
if len(observed_labels) < 2:
    # roc_auc_score raises ValueError when only one class is present, which is
    # likely with only a handful of questions per shard.
    print(f"AUROC undefined: need both correct and incorrect answers, got labels {observed_labels}")
else:
    for label, scores in auroc_inputs.items():
        roc_score = roc_auc_score(correct_arr, scores)
        print(f"AUROC for {label}: {roc_score}")

In [None]:
# Persist this shard's raw results. Off by default so re-running the notebook
# does not overwrite an earlier pickle for the same model/dataset/gpu.
SAVE_RESULTS = False
if SAVE_RESULTS:
    import pickle
    with open(f"selfcheck_{model_name}_{dataset_name}_{gpu}.pickle", "wb") as outfile:
        pickle.dump(selfcheck_dict, outfile)

In [None]:
# Sampled high-temperature passages for the first processed question.
first_samples = selfcheck_dict["hitemp_str_responses"][0]
first_samples

In [None]:
# Greedy answer next to its sampled passages for the first processed question,
# to eyeball what SelfCheck is comparing.
{
    'question': selfcheck_dict['question'][0],
    'str_response': selfcheck_dict['str_response'][0],
    'correct': selfcheck_dict['correct'][0],
    'hitemp_str_responses': selfcheck_dict['hitemp_str_responses'][0],
}