In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch.nn.functional as F
import numpy as np
from scipy import stats
from scipy.stats import wasserstein_distance, entropy
import pickle

In [None]:
# Note: the size of the logits (vocabulary) dimension differs between models: 50272 for facebook/opt-125m, 50257 for gpt2-xl.

# Cache of (tokenizer, model) pairs keyed by model name, so that repeated calls
# (e.g. one per text in a batch) do not reload the weights every time.
_LOADED_MODELS = {}


def _load_tokenizer_and_model(model_name):
    """Return the (tokenizer, model) pair for `model_name`, loading it at most once."""
    if model_name not in _LOADED_MODELS:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)
        _LOADED_MODELS[model_name] = (tokenizer, model)
    return _LOADED_MODELS[model_name]


def get_aligned_logits(text, model_name='facebook/opt-125m'):
    """Run `text` through a causal LM and return logits aligned with the next tokens.

    Args:
        text: The string to encode and score.
        model_name: Name passed to `from_pretrained` for both the tokenizer
            and the model.

    Returns:
        A tuple `(logits, tokens)` where `logits` is the model output with its
        last sequence position dropped and `tokens` is the encoded input with
        its first token dropped, so that `logits[:, t]` pairs with
        `tokens[:, t]`.
    """
    tokenizer, model = _load_tokenizer_and_model(model_name)

    input_ids = tokenizer.encode(text, return_tensors='pt')
    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(input_ids)
    logits = outputs.logits

    # Drop the last logit position (nothing follows it) and the first token
    # (nothing precedes it) so both tensors have the same sequence length.
    logits_t0_to_k_minus_2 = logits[:, :-1, :]
    tokens_t1_to_k_minus_1 = input_ids[:, 1:]

    return logits_t0_to_k_minus_2, tokens_t1_to_k_minus_1

def logits_to_probabilities(logits,temp):
    """Convert logits to probabilities with a temperature-scaled softmax.

    The logits are multiplied by `1/temp` and normalised with a softmax over
    the last dimension.
    """
    inverse_temp = 1/temp
    scaled_logits = logits*inverse_temp
    return F.softmax(scaled_logits, dim=-1)

def get_rank(probs, token_index):
    """Return the 1-based rank of `probs[token_index]` within `probs`.

    The rank is one plus the number of entries strictly greater than the
    entry at `token_index`, so tied entries share the same rank.
    """
    strictly_greater = sum(
        1
        for position in range(len(probs))
        if probs[position] > probs[token_index]
    )
    return 1 + strictly_greater

def process_text(text, model_name='facebook/opt-125m',temp=1, initial_cutoff=0, final_cutoff=500):
    """Compute per-position statistics of `text` under a causal language model.

    For each position `i` in `[initial_cutoff, min(sequence_length, final_cutoff))`
    the next-token distribution returned by `logits_to_probabilities` is
    summarised.

    Args:
        text: The string to analyse.
        model_name: Model identifier forwarded to `get_aligned_logits`.
        temp: Temperature forwarded to `logits_to_probabilities` and also used
            as the exponent `1/temp` of the temperature norm.
        initial_cutoff: First position to process.
        final_cutoff: Exclusive upper bound on the processed positions.

    Returns:
        Five lists with one entry per processed position:
        - rank of the observed token (1 = most probable, ties share a rank),
        - probability of the observed token (a 0-dim tensor),
        - entropy of the distribution (`scipy.stats.entropy`),
        - variance of the log-probability: sum_j p_j * (log p_j + entropy)**2,
        - temperature norm: sum_j p_j ** (1/temp).
    """
    logits, input_ids = get_aligned_logits(text, model_name)
    probabilities = logits_to_probabilities(logits,temp)

    chosen_token_rank_list = []
    chosen_token_prob_list = []
    node_entropy_list=[]
    node_variance_list=[]
    node_tempnorm_list=[]

    inverse_temp = 1/temp

    for i in range(initial_cutoff,min(probabilities.size(1),final_cutoff)):
        all_token_probs = probabilities[0, i, :]
        token_id = int(input_ids[0, i])
        # Work on a float64 numpy copy so the statistics below are vectorised.
        probs = np.asarray(all_token_probs.tolist(), dtype=np.float64)

        chosen_token_prob = all_token_probs[token_id]
        # Rank = 1 + number of tokens with strictly higher probability.
        chosen_token_rank = 1 + int(np.count_nonzero(probs > probs[token_id]))
        node_entropy = stats.entropy(probs)

        # Probabilities can underflow to exactly 0; log(0) = -inf would turn
        # the sum into NaN. Such terms contribute nothing (0 * log(0) -> 0),
        # so restrict the sum to the support of the distribution.
        support = probs > 0
        centred_log_probs = np.log(probs[support]) + node_entropy
        node_variance = np.sum(probs[support] * centred_log_probs**2)

        # NOTE(review): `probs` already comes from a softmax taken at
        # temperature `temp`, so the temperature is applied a second time
        # here — confirm this is intended.
        tempnorm = float(np.sum(probs**inverse_temp))

        chosen_token_rank_list.append(chosen_token_rank)
        chosen_token_prob_list.append(chosen_token_prob)
        node_entropy_list.append(node_entropy)
        node_variance_list.append(node_variance)
        node_tempnorm_list.append(tempnorm)
    return chosen_token_rank_list, chosen_token_prob_list, node_entropy_list, node_variance_list, node_tempnorm_list

In [None]:
def process_text_batch(texts, model_name='facebook/opt-125m',temp=1, initial_cutoff=0, final_cutoff=500):
  """Run `process_text` on every entry of `texts`.

  Prints the index of each text as it is processed and returns a dict mapping
  that index to a dict holding the five per-position lists produced by
  `process_text`.
  """
  batch_results = {}
  for idx in range(len(texts)):
    print(idx)
    ranks, probs, entropies, variances, tempnorms = process_text(
      texts[idx], model_name, temp, initial_cutoff, final_cutoff
    )
    batch_results[idx] = {
      'chosen_token_rank_list': ranks,
      'chosen_token_prob_list': probs,
      'node_entropy_list': entropies,
      'node_variance_list': variances,
      'node_tempnorm_list': tempnorms,
    }
  return batch_results

In [None]:
def fast_detect_score(chosen_token_prob_list, node_entropy_list, node_variance_list, initial_cutoff=30, final_cutoff=200):
  """Compute a normalised log-probability score over a window of positions.

  Over positions `[initial_cutoff, min(len(chosen_token_prob_list), final_cutoff))`
  the function sums the log of the chosen-token probabilities, the entropies
  and the variances, then returns

      (sum_log_prob + sum_entropy) / sqrt(sum_variance)

  Args:
    chosen_token_prob_list: Probability of the observed token per position.
    node_entropy_list: Entropy of the predictive distribution per position.
    node_variance_list: Variance term per position.
    initial_cutoff: First position included in the sums.
    final_cutoff: Exclusive upper bound on the positions included.

  Returns:
    `(fd_score, final_variance, final_log_prob, final_entropy)`. `fd_score`
    is NaN when the summed variance is not strictly positive (for example
    when the window is empty because the text has at most `initial_cutoff`
    positions), instead of raising `ZeroDivisionError`.
  """
  final_variance=0
  final_log_prob=0
  final_entropy=0
  for i in range(initial_cutoff, min(len(chosen_token_prob_list), final_cutoff)):
    final_variance+=node_variance_list[i]
    final_log_prob+=np.log(chosen_token_prob_list[i])
    final_entropy+=node_entropy_list[i]

  # `not (x > 0)` also catches a NaN variance, for which the score is undefined.
  if not final_variance > 0:
    fd_score=float('nan')
  else:
    fd_score=(final_log_prob+final_entropy)/((final_variance)**0.5)
  return fd_score, final_variance, final_log_prob, final_entropy

In [None]:
def temptest_score(chosen_token_log_prob_list, node_tempnorm_list, temp, initial_cutoff=30, final_cutoff=200):
  """Compute per-position temperature-test scores and their mean.

  For every position in `[initial_cutoff, min(len(chosen_token_log_prob_list), final_cutoff))`
  the score is `-((1/temp) - 1) * log_prob + log(tempnorm)`.

  Returns:
    `(temptest_scores, final_temptest_score)`: the list of per-position
    scores and their `np.mean`.
  """
  window = range(initial_cutoff, min(len(chosen_token_log_prob_list), final_cutoff))
  temptest_scores = [
    (-((1/temp)-1)*chosen_token_log_prob_list[pos] + np.log(node_tempnorm_list[pos]))
    for pos in window
  ]
  return temptest_scores, np.mean(temptest_scores)

In [None]:
def scoring_text_batch(text_batch_data ,temp=1, initial_cutoff=30, final_cutoff=200):
  """Score every entry of `text_batch_data` with both detectors.

  Args:
    text_batch_data: Dict keyed by the integers `0 .. len-1`; each value is a
      dict holding `'chosen_token_prob_list'`, `'node_entropy_list'`,
      `'node_variance_list'` and `'node_tempnorm_list'`.
    temp: Temperature forwarded to `temptest_score`.
    initial_cutoff: First position scored, forwarded to both scorers.
    final_cutoff: Exclusive upper bound on scored positions, forwarded to
      both scorers.

  Returns:
    Dict mapping each index to `{'fd_score': ..., 'final_temptest_score': ...}`.
  """
  scoring_data={}
  for i in range(len(text_batch_data)):
    chosen_token_prob_list=text_batch_data[i]['chosen_token_prob_list']
    node_entropy_list=text_batch_data[i]['node_entropy_list']
    node_variance_list=text_batch_data[i]['node_variance_list']
    node_tempnorm_list=text_batch_data[i]['node_tempnorm_list']
    fd_score, final_variance, final_log_prob, final_entropy=fast_detect_score(chosen_token_prob_list, node_entropy_list, node_variance_list, initial_cutoff, final_cutoff)
    # `temptest_score` expects log-probabilities, while the stored list holds
    # raw probabilities, so convert before scoring.
    chosen_token_log_prob_list=[np.log(float(prob)) for prob in chosen_token_prob_list]
    temptest_scores, final_temptest_score=temptest_score(chosen_token_log_prob_list, node_tempnorm_list, temp, initial_cutoff, final_cutoff)
    scoring_data[i]={}
    scoring_data[i]['fd_score']=fd_score
    scoring_data[i]['final_temptest_score']=final_temptest_score
  return scoring_data