In [1]:
from openai import OpenAI
from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)  # 
import pandas as pd
import json
import copy
import tqdm
import numpy as np
import re
from collections import Counter

In [3]:
# Locations of the two validation splits: model-agnostic first, model-aware second.
path_val1, path_val2 = (
    "../data/SHROOM_dev-v2/val.model-agnostic.json",
    "../data/SHROOM_dev-v2/val.model-aware.v2.json",
)

In [30]:
# Read both validation splits, then stack them into one frame with a fresh
# 0..N-1 index (rows of df1 come first, rows of df2 follow).
df1, df2 = (pd.read_json(path) for path in (path_val1, path_val2))
df = pd.concat((df1, df2), ignore_index=True)
df

Unnamed: 0,hyp,ref,src,tgt,model,task,labels,label,p(Hallucination)
0,Resembling or characteristic of a weasel.,tgt,The writer had just entered into his eighteent...,Resembling a weasel (in appearance).,,DM,"[Hallucination, Not Hallucination, Not Halluci...",Not Hallucination,0.2
1,Alternative form of sheath knife,tgt,Sailors ' and fishermen 's <define> sheath - k...,.,,DM,"[Hallucination, Hallucination, Hallucination, ...",Hallucination,0.8
2,(obsolete) A short period of time.,tgt,"As to age , Bead could not form any clear impr...","(poetic) An instant, a short moment.",,DM,"[Not Hallucination, Not Hallucination, Not Hal...",Not Hallucination,0.0
3,(slang) An incel.,tgt,Because redpillers are usually normies or <def...,"(incel, _, slang) A man of a slightly lower ra...",,DM,"[Not Hallucination, Not Hallucination, Halluci...",Not Hallucination,0.2
4,"An island in Lienchiang County, Taiwan.",tgt,On the second day of massive live - fire drill...,"An island in Dongyin, Lienchiang, Taiwan, in t...",,DM,"[Not Hallucination, Not Hallucination, Not Hal...",Not Hallucination,0.0
...,...,...,...,...,...,...,...,...,...
995,Using a gas-fired device is a way to stop peop...,either,Doonii fayyadamuun meeshaa geejibuun namootaba...,Using ships to transport goods is by far the m...,facebook/nllb-200-distilled-600M,MT,"[Hallucination, Hallucination, Hallucination, ...",Hallucination,1.0
996,Since Montevideo is located south of the equat...,either,وبما أن مونتيفيديو موجودة في جنوب خط الاستواء،...,"Since Montevideo is south of the Equator, it i...",facebook/nllb-200-distilled-600M,MT,"[Not Hallucination, Not Hallucination, Not Hal...",Not Hallucination,0.0
997,He was also a supporter of Singapore's Deputy ...,either,Gin-abiabi hiya han Deputy Prime Minister han ...,He was greeted by Singapore's Deputy Prime Min...,facebook/nllb-200-distilled-600M,MT,"[Hallucination, Hallucination, Not Hallucinati...",Hallucination,0.8
998,The concept of a worm is that the parasitic or...,either,འབུ་ཞེས་པའི་ཐ་སྙད་དེ་ཉིད་འབུ་སྲིན་དཔྱད་རིག་པ་བ...,The term bug is used by entomologists in a for...,facebook/nllb-200-distilled-600M,MT,"[Hallucination, Hallucination, Hallucination, ...",Hallucination,1.0


In [5]:
# Do not hardcode the API key in the notebook: when `api_key` is omitted the
# client reads it from the OPENAI_API_KEY environment variable.
client = OpenAI()

@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def completion_with_backoff(**kwargs):
    """Create a chat completion, retrying failed calls with random exponential backoff.

    The tenacity decorator waits a randomized, exponentially growing delay
    (bounded by 1 and 60 seconds) between attempts and gives up after 6 attempts.

    Args:
        **kwargs: forwarded unchanged to ``client.chat.completions.create``.

    Returns:
        Whatever ``client.chat.completions.create`` returns.
    """
    return client.chat.completions.create(**kwargs)

## OpenAI rejects requests with repetitive prompts, which can sometimes be found in our training dataset
# Regex matching a single word (word characters, apostrophes allowed inside the word).
# Raw string: "\w" inside a plain string literal is an invalid escape sequence.
word_regex = re.compile(r"(\w[\w']*\w|\w)")
def fix_repetitive(sample):
    """Truncate a text right after the first occurrence of an excessively repeated word.

    A text counts as repetitive when its most frequent word occurs at least ten
    times more often than the median word frequency.

    Args:
        sample: the text to check.

    Returns:
        The text cut right after the first whole-word occurrence of the repeated
        word, or the text unchanged when it is not repetitive or contains no words.
    """
    matches = list(word_regex.finditer(sample))
    if not matches:
        # Nothing to count (empty or punctuation-only text): most_common(1)[0] would raise.
        return sample
    word_counter = Counter(m.group(0) for m in matches)
    median = np.median(list(word_counter.values()))
    word, cnt = word_counter.most_common(1)[0]
    if (cnt / median) >= 10: # if count is ten times more than the median value it's repetitive
        # Cut at the first *token* equal to the word; a plain substring search could
        # hit the word inside an earlier, longer word (e.g. "a" inside "banana").
        first = next(m for m in matches if m.group(0) == word)
        return sample[:first.end()]
    return sample


def create_sample(row):
    """Build the Context/Sentence pair for one dataset row.

    The context is the row's 'src' field for the "PG" task and its 'tgt' field
    for every other task (i.e. "MT" or "DM"). The sentence is the row's 'hyp'
    field passed through fix_repetitive.
    """
    context_column = 'src' if row["task"] == "PG" else 'tgt'
    return {
        "context": row[context_column],
        "sentence": fix_repetitive(row['hyp']),
    }

# def answer_to_label(answer:str):
#     formatted = answer.strip().lower()
#     assert formatted in ["yes","no"], f"Got wrong answer: '{answer}'"
#     return "Hallucination" if formatted == "no" else "Not Hallucination"

def answer_to_label(answer:float):
    """Turn a consistency score into a label: "Hallucination" below 0.5, "Not Hallucination" otherwise."""
    if answer < 0.5:
        return "Hallucination"
    return "Not Hallucination"

def detect(samples):
    """Ask the chat model to rate Context/Sentence consistency for a batch of pairs.

    Args:
        samples: list of dicts describing Context/Sentence pairs (the notebook
            passes the output of create_sample). Neither the list nor its dicts
            are modified.

    Returns:
        A tuple ``(results, completion)``. ``results`` holds one
        ``(label, score)`` tuple per sample, in the same order as ``samples``;
        ``completion`` is the raw API response.

    Raises:
        ValueError: if the reply does not contain exactly one answer per pair id.
    """
    batch_instruction = {
        "role": "system",
        "content":
        'You will be provided with a Sentence \
        and your task is to rate the consistency of that sentence to \
        that of the provided Context. Your answer must be only \
        a number between 0.0 and 1.0 rounded to the nearest two \
        decimal places where 0.0 represents no consistency and \
        1.0 represents perfect consistency and similarity. \
        Reply with a valid JSON in following format: {"answers":{"<pair_id>": <float>} }. Example: {"answers":{"0":0.7,"12":0.33} }. Array of answers should contain reply for each Context/Sentence pair.'
    }
    # Tag each pair with its position so the reply can be matched back to its sample.
    # Copies are used so the caller's dicts stay untouched.
    pair_ids = [str(i) for i in range(len(samples))]
    pairs = [{**s, "pair_id": pair_id} for pair_id, s in zip(pair_ids, samples)]
    # Go through the retrying wrapper so a transient API error does not abort a long run.
    completion = completion_with_backoff(
        model="gpt-3.5-turbo-1106",
        messages=[
            batch_instruction,
            {"role": "user", "content": "Pairs: " + json.dumps(pairs,indent=1)}
        ],
        temperature=0,
        response_format = {"type": "json_object"},
        seed=42
    )
    answers = json.loads(completion.choices[0].message.content)["answers"]
    if set(answers) != set(pair_ids):
        raise ValueError(
            f"Reply holds answers for ids {sorted(answers)}, expected {pair_ids}"
        )
    # Look answers up by pair id instead of trusting the order of the reply:
    # a reordered reply would otherwise silently attach scores to the wrong samples.
    scores = [float(answers[pair_id]) for pair_id in pair_ids]
    return [(answer_to_label(score), score) for score in scores], completion

In [6]:
batch_size = 32  # turned out to be the best size; larger ones result in worse detections
# Consecutive row slices of df, each holding at most batch_size rows.
batched_df = [
    df.iloc[start:start + batch_size]
    for start in range(0, len(df), batch_size)
]

In [7]:
start_batch = 0  # index of the first batch to process; set > 0 to resume an interrupted run
if start_batch == 0:
    df_labels = []  # fresh run: one list of (label, score) tuples per batch
else:
    # Resuming: keep the results of the batches that were already processed.
    # Resetting the list here would leave fewer results than rows in df.
    assert len(df_labels) >= start_batch, (
        f"Cannot resume at batch {start_batch}: only {len(df_labels)} batches were processed"
    )
    df_labels = df_labels[:start_batch]
remaining_batches = batched_df[start_batch:]
for batch in tqdm.tqdm(remaining_batches,total=len(remaining_batches)):
    samples = batch.apply(create_sample,axis=1).to_list()
    labels,completion = detect(samples)
    assert len(labels) == len(samples), f"Returned result contains {len(labels)}, not {len(samples)}"
    df_labels.append(labels) # store result per batch

100%|██████████| 32/32 [02:33<00:00,  4.80s/it]


In [31]:
# Flatten the per-batch results into one (label, score) sequence in row order,
# then split it into the two parallel lists.
flat_results = [pair for batch_result in df_labels for pair in batch_result]
final_labels = [label for label, _ in flat_results]
final_p = [score for _, score in flat_results]
df["label"] = final_labels
# The stored value is the complement of the consistency score returned by the model.
df["p(Hallucination)"] = 1-np.array(final_p)

In [32]:
# df was built with ignore_index=True, so its rows are numbered 0..N-1 with the
# rows of df1 first. df2.index restarts at 0 and would therefore select the
# leading rows of df again instead of the rows of df2, so split by position.
n_agnostic = len(df1)
df.iloc[:n_agnostic].to_json("labeled-val.model-agnostic.json",orient="records")
df.iloc[n_agnostic:].to_json("labeled-val.model-aware.v2.json",orient="records")