# Imported Libraries

In [None]:
# !python -m pip install -q "transformers == 4.25.1"

from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BartForSequenceClassification, BartTokenizer, GPTJForCausalLM
import torch
import pandas as pd
import numpy as np
import string
import re
import ast
import gc

# Free memory before loading models. Collect unreachable Python objects first
# so any CUDA tensors they still reference are handed back to PyTorch's
# caching allocator; only then can empty_cache() release those cached blocks.
gc.collect()
torch.cuda.empty_cache()

0

# Data Preprocessing

In [None]:
# Regex building blocks for split_into_sentences. They are raw strings so that
# escapes such as \s reach the regex engine unchanged, without Python's
# "invalid escape sequence" warning (a SyntaxWarning since Python 3.12).
alphabets = r"([A-Za-z])"
prefixes = r"(Mr|St|Mrs|Ms|Dr)[.]"
suffixes = r"(Inc|Ltd|Jr|Sr|Co)"
starters = r"(Mr|Mrs|Ms|Dr|He\s|She\s|It\s|They\s|Their\s|Our\s|We\s|But\s|However\s|That\s|This\s|Wherever)"
acronyms = r"([A-Z][.][A-Z][.](?:[A-Z][.])?)"
websites = r"[.](com|net|org|io|gov)"
digits = r"([0-9])"

# Abbreviations whose trailing period must not end a sentence (applied in order).
_PROTECTED_ABBREVIATIONS = ("Gen", "Sgt", "Lt", "Rep", "Sen", "Jan", "Feb", "Apr",
                            "Aug", "Sept", "Oct", "Nov", "Dec")


def split_into_sentences(text):
    """Split ``text`` into a list of stripped sentences using regex heuristics.

    Periods that should not end a sentence (titles, decimals, acronyms,
    websites, ellipses, common abbreviations) are first masked as ``<prd>``.
    Each remaining ``.``, ``?`` and ``!`` then gets a ``<stop>`` marker, and
    the text is split on those markers.

    Behavioral notes:
      * newlines are turned into spaces;
      * any text after the final terminator is discarded;
      * if there is no terminator at all, the whole normalised text is
        returned as a single-element list (``[""]`` for empty input);
      * ``"etc."`` loses its period.
    """
    text = " " + text + "  "
    text = text.replace("\n", " ")
    text = re.sub(prefixes, r"\1<prd>", text)
    text = re.sub(websites, r"<prd>\1", text)
    text = re.sub(digits + "[.]" + digits, r"\1<prd>\2", text)
    text = text.replace("...", "<prd><prd><prd>")
    text = text.replace("Ph.D.", "Ph<prd>D<prd>")
    text = text.replace("www.", "www<prd>")
    for abbreviation in _PROTECTED_ABBREVIATIONS:
        text = text.replace(abbreviation + ".", abbreviation + "<prd>")
    text = text.replace("etc.", "etc")
    text = re.sub(r"\s" + alphabets + r"[.] ", r" \1<prd> ", text)
    text = re.sub(acronyms + " " + starters, r"\1<stop> \2", text)
    text = re.sub(alphabets + "[.]" + alphabets + "[.]" + alphabets + "[.]", r"\1<prd>\2<prd>\3<prd>", text)
    text = re.sub(alphabets + "[.]" + alphabets + "[.]", r"\1<prd>\2<prd>", text)
    text = re.sub(" " + suffixes + "[.] " + starters, r" \1<stop> \2", text)
    text = re.sub(" " + suffixes + "[.]", r" \1<prd>", text)
    text = re.sub(" " + alphabets + "[.]", r" \1<prd>", text)
    # Move terminators outside closing quotes so the quote stays in its sentence.
    text = text.replace(".”", "”.")
    text = text.replace(".\"", "\".")
    text = text.replace("!\"", "\"!")
    text = text.replace("?\"", "\"?")
    text = text.replace(".", ".<stop>")
    text = text.replace("?", "?<stop>")
    text = text.replace("!", "!<stop>")
    text = text.replace("<prd>", ".")
    # The last chunk is whatever follows the final terminator; it is dropped.
    sentences = [s.strip() for s in text.split("<stop>")[:-1]]

    if not sentences:
        # No terminator found: treat the whole normalised text as one sentence.
        sentences = [text.strip()]

    return sentences

def calculate_text_error_types(text):
    """Return per-error-type annotation rates for one ``responses`` cell.

    ``text`` is the string form of a Python literal and is parsed with
    ``ast.literal_eval``. Each element is treated as one reviewer's response.
    An element counts as a reviewer only if its ``str()`` is longer than two
    characters, so an empty ``[]`` does not count. An element is tagged with
    an error type when its ``str()`` contains the quoted label (e.g.
    ``'Redundant'``). Presumably each element is a list of label strings;
    TODO confirm against the dataset.

    Returns:
        (redundant, off_prompt, self_contradiction, incoherent): the fraction
        of counted reviewers tagging each type, rounded to 2 decimals. All are
        0.0 when no reviewer gave a non-empty response, where the previous
        version raised ZeroDivisionError.
    """
    responses = ast.literal_eval(text)  # parse once, not once per error type

    reviewer_counts = sum(len(str(i)) > 2 for i in responses)

    def _rate(label):
        quoted = "'" + label + "'"
        count = sum(1 for i in responses if quoted in str(i))
        ratio = count / reviewer_counts if reviewer_counts else 0.0
        return np.round(ratio, 2)

    return _rate('Redundant'), _rate('Off-prompt'), _rate('Self-contradiction'), _rate('Incoherent')

# Tokenizer and sequence-classification model for the NLI scorer, both taken
# from the "facebook/bart-large-mnli" checkpoint (fetched from the hub on first use).
nli_tokenizer = BartTokenizer.from_pretrained(
    "facebook/bart-large-mnli"
)
nli_model = BartForSequenceClassification.from_pretrained(
    "facebook/bart-large-mnli"
)

# Logit column of each NLI class in the bart-large-mnli output:
# 0 = contradiction, 1 = neutral, 2 = entailment.
_NLI_LABEL_TO_INDEX = {'contradiction': 0, 'neutral': 1, 'entailment': 2}

def nli_final_comparison(premise, hypothesis, drop = None):
    """Score how ``hypothesis`` relates to ``premise`` with the BART MNLI model.

    Args:
        premise: premise text.
        hypothesis: hypothesis text.
        drop: optional class name ('contradiction', 'neutral' or
            'entailment') to exclude. The softmax is then taken over the two
            remaining logits, and the dropped class is reported as 0. Any
            other value (including None or the string 'None') keeps all
            three classes.

    Returns:
        (entailment, neutral, contradiction) probabilities as percentages
        (0-100), each rounded to 4 decimals.
    """
    tokens = nli_tokenizer(premise, hypothesis, return_tensors="pt")
    # Inference only: without no_grad every call records an autograd graph,
    # which wastes memory inside the per-sentence generation loop.
    with torch.no_grad():
        logits = nli_model(**tokens).logits

    dropped = _NLI_LABEL_TO_INDEX.get(drop)
    kept = [idx for idx in range(3) if idx != dropped]
    kept_probs = logits[:, kept].softmax(dim=1)[0].tolist()

    class_probs = [0, 0, 0]
    for idx, prob in zip(kept, kept_probs):
        class_probs[idx] = prob
    contradiction_prob, neutral_prob, entailment_prob = class_probs

    return np.round(entailment_prob*100, 4), np.round(neutral_prob*100, 4), np.round(contradiction_prob*100, 4)

def calculate_gen_text_nli_tag(the_prompt, the_gen_text):
    """Label the NLI relation of a generated text to its prompt.

    The prompt is the premise and the stripped generation is the hypothesis.

    Returns:
        (label, probs). ``label`` is 'ENT', 'NEU' or 'CON' when that class
        strictly beats both others, and 'ERR' otherwise (a tie for the top
        score). ``probs`` is the string "ENT - NEU - CON" of the percentage
        scores from ``nli_final_comparison``.
    """
    scores = dict(zip(('ENT', 'NEU', 'CON'),
                      nli_final_comparison(the_prompt, the_gen_text.strip())))

    nli_class = 'ERR'
    for label, score in scores.items():
        rivals = [other for other_label, other in scores.items() if other_label != label]
        if all(score > rival for rival in rivals):
            nli_class = label
            break

    nli_probs = ' - '.join(str(score) for score in scores.values())
    return nli_class, nli_probs


Downloading:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.15k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

In [None]:
df_scarecrow = pd.read_csv('Scarecrow_Initial_Dataset.csv')

# Keep only the GPT-3 generations, renumbered 0..n-1.
df_scarecrow = df_scarecrow[df_scarecrow['model'] == 'gpt3'].copy()
df_scarecrow = df_scarecrow.reset_index(drop = True)

# Annotator error rates per generation. Plain comprehensions replace
# np.vectorize, which calls the function an extra time on the first element
# to infer output dtypes (an extra BART forward pass below) and raises on
# zero-length input.
error_rate_columns = ['redundant', 'off_prompt', 'self_contradiction', 'incoherent']
error_rates = pd.DataFrame([calculate_text_error_types(responses) for responses in df_scarecrow['responses']],
                           columns = error_rate_columns, index = df_scarecrow.index)
for column in error_rate_columns:
    df_scarecrow[column] = error_rates[column]

# NLI label and "ENT - NEU - CON" probability string of each generation
# relative to its prompt.
nli_tags = pd.DataFrame([calculate_gen_text_nli_tag(prompt, generation)
                         for prompt, generation in zip(df_scarecrow['prompt'], df_scarecrow['generation'])],
                        columns = ['GPT3_NLI', 'GPT3_NLI_PROBS'], index = df_scarecrow.index)
df_scarecrow['GPT3_NLI'] = nli_tags['GPT3_NLI']
df_scarecrow['GPT3_NLI_PROBS'] = nli_tags['GPT3_NLI_PROBS']

df_scarecrow_temp1 = df_scarecrow[df_scarecrow['temperature'] == 1].copy()


In [None]:
# Long ("tidy") layout: each temperature-1 generation becomes three rows, one
# per NLI label (ENT, NEU, CON), with that label's percentage in NLI_Prob.
# Rows are collected first and the frame is built once. DataFrame.append was
# removed in pandas 2.0 and re-copied the whole frame on every call.
tidy_source_columns = ['id', 'gid', 'prompt', 'generation', 'model', 'p', 'frequency_penalty', 'responses',
                       'redundant', 'off_prompt', 'self_contradiction', 'incoherent']
tidy_rows = []

for _, source_row in df_scarecrow_temp1.iterrows():
    shared_fields = {column: source_row[column] for column in tidy_source_columns}

    # GPT3_NLI_PROBS is "ENT - NEU - CON", as built by calculate_gen_text_nli_tag.
    ent_prob, neu_prob, con_prob = (float(part) for part in source_row['GPT3_NLI_PROBS'].split(' - '))

    for nli_label, nli_prob in (('ENT', ent_prob), ('NEU', neu_prob), ('CON', con_prob)):
        tidy_rows.append({**shared_fields, 'NLI_Label': nli_label, 'NLI_Prob': nli_prob})

df_scarecrow_temp1_tidy = pd.DataFrame(tidy_rows, columns = tidy_source_columns + ['NLI_Label', 'NLI_Prob'])


# Descriptive Analysis for Error Types and NLI Classes

In [None]:
# Table 1: percentage of temperature-1 GPT-3 generations that received each
# NLI label, reported separately for each top-p value.

for p_position, p_value in enumerate((0.4, 0.96)):
    p_heading = '--- (p = ' + str(p_value) + ') --- \n'
    print(p_heading if p_position == 0 else '\n' + p_heading)

    p_rows = df_scarecrow_temp1[df_scarecrow_temp1['p'] == p_value]
    for nli_label in ('ENT', 'NEU', 'CON'):
        print('--- ' + nli_label + ' ---')
        print(100 * len(p_rows[p_rows['GPT3_NLI'] == nli_label]) / len(p_rows))


--- (p = 0.4) --- 

--- ENT ---
12.931034482758621
--- NEU ---
83.62068965517241
--- CON ---
3.4482758620689653

--- (p = 0.96) --- 

--- ENT ---
1.3793103448275863
--- NEU ---
86.20689655172414
--- CON ---
12.413793103448276


In [None]:
# Table 2: NLI label distribution (%) for each top-p value, overall ("All")
# and within error-type subsets. OP/SC/IN/RD hold the rows whose annotator
# rate for that error is >= 0.5; CO holds the rows with all four rates < 0.5.

for p_position, p_value in enumerate((0.4, 0.96)):
    p_heading = '--- (p = ' + str(p_value) + ') --- \n'
    print(p_heading if p_position == 0 else '\n' + p_heading)

    p_rows = df_scarecrow_temp1[df_scarecrow_temp1['p'] == p_value]
    error_free = ((p_rows['off_prompt'] < 0.5) & (p_rows['self_contradiction'] < 0.5)
                  & (p_rows['incoherent'] < 0.5) & (p_rows['redundant'] < 0.5))
    subsets = (
        ('All', p_rows),
        ('CO', p_rows[error_free]),
        ('OP', p_rows[p_rows['off_prompt'] >= 0.5]),
        ('SC', p_rows[p_rows['self_contradiction'] >= 0.5]),
        ('IN', p_rows[p_rows['incoherent'] >= 0.5]),
        ('RD', p_rows[p_rows['redundant'] >= 0.5]),
    )

    for subset_tag, subset_rows in subsets:
        print('--- ' + subset_tag + ' ---')
        label_counts = subset_rows['GPT3_NLI'].value_counts()
        print(100 * label_counts / label_counts.sum())



--- (p = 0.4) --- 

--- All ---
NEU    83.620690
ENT    12.931034
CON     3.448276
Name: GPT3_NLI, dtype: float64
--- CO ---
NEU    95.000000
ENT     3.333333
CON     1.666667
Name: GPT3_NLI, dtype: float64
--- OP ---
NEU    87.5
CON    12.5
Name: GPT3_NLI, dtype: float64
--- SC ---
NEU    75.0
CON    25.0
Name: GPT3_NLI, dtype: float64
--- IN ---
NEU    100.0
Name: GPT3_NLI, dtype: float64
--- RD ---
NEU    68.888889
ENT    28.888889
CON     2.222222
Name: GPT3_NLI, dtype: float64

--- (p = 0.96) --- 

--- All ---
NEU    86.206897
CON    12.413793
ENT     1.379310
Name: GPT3_NLI, dtype: float64
--- CO ---
NEU    91.623037
CON     7.853403
ENT     0.523560
Name: GPT3_NLI, dtype: float64
--- OP ---
NEU    76.923077
CON    23.076923
Name: GPT3_NLI, dtype: float64
--- SC ---
NEU    80.0
CON    20.0
Name: GPT3_NLI, dtype: float64
--- IN ---
NEU    63.636364
CON    31.818182
ENT     4.545455
Name: GPT3_NLI, dtype: float64
--- RD ---
NEU    76.666667
CON    16.666667
ENT     6.666667
Name: GPT3_NLI, dtype: float64

In [None]:
# Figure 2: percentage of temperature-1 GPT-3 generations whose annotator
# rate for each error type is >= 0.5, for each top-p value.

error_type_columns = (('OP', 'off_prompt'), ('SC', 'self_contradiction'),
                      ('IN', 'incoherent'), ('RD', 'redundant'))

for p_position, p_value in enumerate((0.4, 0.96)):
    p_heading = '--- (p = ' + str(p_value) + ') --- \n'
    print(p_heading if p_position == 0 else '\n' + p_heading)

    p_rows = df_scarecrow_temp1[df_scarecrow_temp1['p'] == p_value]
    for error_tag, error_column in error_type_columns:
        print('--- ' + error_tag + ' ---')
        print(100 * len(p_rows[p_rows[error_column] >= 0.5]) / len(p_rows))

--- (p = 0.4) --- 

--- OP ---
6.896551724137931
--- SC ---
3.4482758620689653
--- IN ---
0.8620689655172413
--- RD ---
38.793103448275865

--- (p = 0.96) --- 

--- OP ---
17.93103448275862
--- SC ---
3.4482758620689653
--- IN ---
7.586206896551724
--- RD ---
10.344827586206897


# Experimental Evaluation of NLI Strategies versus Vanilla GPT-J

In [None]:
# GPT-J 6B generator, loaded from the hub revision "float16" with half-precision
# weights, plus its tokenizer.
gen_model = GPTJForCausalLM.from_pretrained(
    "EleutherAI/gpt-j-6B",
    revision="float16",
    torch_dtype=torch.float16,
)
gen_tokenizer = AutoTokenizer.from_pretrained(
    "EleutherAI/gpt-j-6B"
)

def gen_text_from_vanilla_gptj(initial_prompt, top_p_parameter, temperature_parameter):
    """Generate an unfiltered GPT-J continuation of ``initial_prompt``.

    A '.' is appended to the prompt when its last character is not
    punctuation. The text-generation pipeline runs on device 0 with the given
    top-p and temperature, min_length=30 and max_length=256. The prompt is
    removed from the output, and the continuation is passed through
    ``split_into_sentences`` and re-joined with single spaces, which drops any
    trailing unfinished sentence.
    """
    prompt = initial_prompt if initial_prompt[-1] in string.punctuation else initial_prompt + '.'

    generator = pipeline(task = 'text-generation', model = gen_model, tokenizer = gen_tokenizer,
                         framework = 'pt', device = 0, top_p = top_p_parameter,
                         temperature = temperature_parameter, pad_token_id = 50256)

    full_text = generator(prompt, min_length = 30, max_length = 256)[0].get('generated_text')
    continuation = full_text[len(prompt):].strip()

    return " ".join(split_into_sentences(continuation))

def _nli_stop_marker(top_p, mode):
    """Return the sentinel appended to the text when NLI generation gives up.

    Only the two top-p values used in this notebook are recognised (0.4 ->
    LOW_P, 0.96 -> HIGH_P); any other value gives ' |ERROR| '. Any mode other
    than 'ENT' or 'CON' is reported as 'NEU'.
    """
    if top_p == 0.4:
        band = 'LOW_P'
    elif top_p == 0.96:
        band = 'HIGH_P'
    else:
        return ' |ERROR| '
    label = mode if mode in ('ENT', 'CON') else 'NEU'
    return ' |NLI-' + band + '-STOP-' + label + '| '


def _passes_nli_filter(existing_sent, candidate_sent, mode, neutral_threshold):
    """Return True if ``candidate_sent`` is acceptable after ``existing_sent`` under ``mode``."""
    entailment, neutral, contradiction = nli_final_comparison(existing_sent, candidate_sent)
    if mode == 'ENT':
        return entailment >= contradiction
    if mode == 'CON':
        return contradiction >= entailment
    return neutral >= neutral_threshold


def gen_text_from_nli_gptj(initial_prompt, top_p_parameter, temperature_parameter, mode = 'NEU', neutral_threshold = 85.0):
    """Generate a GPT-J continuation whose sentences pass an NLI filter.

    GPT-J is sampled repeatedly. Each generated sentence is accepted only if
    it passes the filter against every sentence accumulated so far. The first
    rejected sentence also discards the rest of that generation.

    Args:
        initial_prompt: prompt text; a '.' is appended if it does not end in
            punctuation.
        top_p_parameter: nucleus-sampling p (0.4 / 0.96 select the stop marker).
        temperature_parameter: sampling temperature.
        mode: 'ENT' accepts when entailment >= contradiction, 'CON' when
            contradiction >= entailment, and anything else (NEU) when the
            neutral score is >= ``neutral_threshold``.
        neutral_threshold: NEU acceptance threshold, in PERCENT, because
            ``nli_final_comparison`` returns scores scaled to 0-100. The
            previous hard-coded ``neutral < 0.85`` compared a percentage
            against what looks like a fraction, and rejected almost nothing.

    Generation stops once the text has grown by >= 256 characters and has
    >= 3 sentences, or after 7 consecutive rounds that add no sentence. In the
    latter case a ``|NLI-...-STOP-...|`` marker is appended.

    Returns:
        (continuation, avg_failed): the text generated after the prompt, and
        the mean of the recorded failed-round counts (0 if none were recorded).
    """
    if initial_prompt[-1] not in string.punctuation:
        initial_prompt = initial_prompt + '.'

    text_generator = pipeline(task = 'text-generation', model = gen_model, tokenizer = gen_tokenizer, framework = 'pt', device = 0, top_p = top_p_parameter, temperature = temperature_parameter, pad_token_id = 50256)

    initial_prompt_text_len = len(initial_prompt)

    def _is_long_enough(text):
        return len(text) >= initial_prompt_text_len + 256 and len(split_into_sentences(text)) >= 3

    prompt = initial_prompt
    failed_update_counter = 0
    failed_update_array = []

    while True:
        prev_prompt_sent_len = len(split_into_sentences(prompt))

        if _is_long_enough(prompt):
            break

        generated_text = text_generator(prompt, min_length = 16, max_length = 128)
        generated_text = generated_text[0].get('generated_text')
        generated_text = generated_text[len(prompt):].strip()

        prompt_sent_list = split_into_sentences(prompt)

        for candidate_sent in split_into_sentences(generated_text):
            if not all(_passes_nli_filter(existing_sent, candidate_sent, mode, neutral_threshold)
                       for existing_sent in prompt_sent_list):
                break

            prompt_sent_list.append(candidate_sent)
            prompt = " ".join(prompt_sent_list)

            if _is_long_enough(prompt):
                break

        prompt = " ".join(prompt_sent_list)

        if prev_prompt_sent_len >= len(split_into_sentences(prompt)):
            # This round added no sentence.
            failed_update_counter = failed_update_counter + 1

            if failed_update_counter >= 7:
                failed_update_array.append(failed_update_counter)
                failed_update_counter = 0
                prompt = prompt + _nli_stop_marker(top_p_parameter, mode)
                break
        else:
            failed_update_array.append(failed_update_counter)
            failed_update_counter = 0

    # After the first round the prompt is the re-joined output of
    # split_into_sentences, which may differ in length from the raw prompt
    # (stripped whitespace, "etc." -> "etc", a dropped trailing fragment).
    # Strip that normalised prefix; fall back to the raw length otherwise.
    normalized_initial = " ".join(split_into_sentences(initial_prompt))
    if prompt.startswith(normalized_initial):
        continuation = prompt[len(normalized_initial) + 1:]
    else:
        continuation = prompt[initial_prompt_text_len + 1:]

    if failed_update_array:
        avg_num_of_failed_generations = sum(failed_update_array) / len(failed_update_array)
    else:
        avg_num_of_failed_generations = 0

    return continuation, avg_num_of_failed_generations

Downloading:   0%|          | 0.00/836 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/12.1G [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/619 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/798k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.37M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/4.04k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/357 [00:00<?, ?B/s]

In [None]:
# Evaluation sample: every temperature-1 GPT-3 row with p = 0.4, followed by a
# fixed-seed random sample of 116 temperature-1 rows with p = 0.96.
low_p_rows = df_scarecrow[(df_scarecrow['p'] == 0.4) & (df_scarecrow['temperature'] == 1)]
high_p_rows = df_scarecrow[(df_scarecrow['p'] == 0.96) & (df_scarecrow['temperature'] == 1)].sample(116, random_state = 22532)
df_scarecrow_sample = pd.concat([low_p_rows, high_p_rows]).reset_index(drop = True)

# Rows 0-24 and 116-140, minus the listed offensive-content rows.
# NOTE(review): 116 presumably marks where the p = 0.96 rows start; confirm.
scarecrow_subset_indices = list(range(0, 25)) + list(range(116, 141))
scarecrow_offensive_content_indices = [3, 15, 125, 129]
list_of_indices = [i for i in scarecrow_subset_indices if i not in scarecrow_offensive_content_indices]

# (display label, NLI mode; None means vanilla GPT-J), in generation order.
generation_strategies = (('Vanilla GPTJ', None), ('NLI GPTJ - ENT', 'ENT'),
                         ('NLI GPTJ - NEU', 'NEU'), ('NLI GPTJ - CON', 'CON'))
p_bands = (('Low p', 0.4), ('High p', 0.96))

for i in list_of_indices:
    prompt_text = df_scarecrow_sample.iloc[i]['prompt']

    # Generate in the original order (each strategy at low p, then high p).
    # Sampling presumably consumes a shared RNG, so the order is kept.
    generations = {}
    for strategy_label, nli_mode in generation_strategies:
        for band_label, band_p in p_bands:
            if nli_mode is None:
                generations[strategy_label, band_label] = gen_text_from_vanilla_gptj(prompt_text, band_p, 1.0)
            else:
                generations[strategy_label, band_label] = gen_text_from_nli_gptj(prompt_text, band_p, 1.0, nli_mode)

    print('-----------------', i, '-----------------')
    print('Prompt: ', prompt_text)

    for band_label, _ in p_bands:
        for strategy_label, _ in generation_strategies:
            print(strategy_label + ' (' + band_label + '): ', generations[strategy_label, band_label])
