In [None]:
import json
import re
import pandas as pd
from transformers import AutoTokenizer
from peft import AutoPeftModelForCausalLM
from transformers import TextStreamer, pipeline
from langchain.llms import HuggingFacePipeline
from ast import literal_eval

In [None]:
# Prompt-template table. best_match_prompt_llama filters it on the 'Relation'
# column and reads the matching 'PromptTemplate' entry.
question_prompts = pd.read_csv('question-prompts.csv')

In [None]:
# Which trained LoRA adapter to load: 'method1' selects the method 1 adapter,
# any other value (e.g. 'method2') selects the method 2 adapter of the paper.
llama_lora_injection_type = 'method1'

# Directory holding the trained LoRA injection model for the chosen method.
path = (
    "./llama_method1_injection"
    if llama_lora_injection_type == 'method1'
    else "./llama_method2_injection"
)

# Point cache_dir at the location where meta-llama/Llama-2-13b-chat-hf is kept.
model = AutoPeftModelForCausalLM.from_pretrained(
    path,
    load_in_4bit=True,
    cache_dir="./llama_base_directory",
    device_map='auto',
    use_cache=False,
)

# The tokenizer is read from the same adapter directory as the model.
tokenizer = AutoTokenizer.from_pretrained(path)

In [None]:
# Streams generated text to stdout while the pipeline runs.
streamer = TextStreamer(tokenizer)

# Deterministic (greedy) decoding is requested explicitly with do_sample=False.
# The previous temperature=0.0 is not a valid sampling temperature (it is
# rejected whenever sampling is enabled) and top_p has no effect without
# sampling, so both sampling knobs are dropped.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=150,
    do_sample=False,
    repetition_penalty=1.15,
    streamer=streamer
)

# LangChain wrapper so prompts can be run with a plain llm(prompt) call.
llm = HuggingFacePipeline(pipeline=pipe)

In [None]:
def llm_extract_lst(corrupt_lst):
    """Ask the LLM to rewrite a malformed list string as a valid Python list.

    Args:
        corrupt_lst: Text that could not be parsed as a list; it is
            interpolated into a few-shot reformatting prompt.

    Returns:
        Whatever ``llm`` returns for the prompt. The call is made inside the
        model's ``disable_adapter()`` context, i.e. with the adapter turned off.
    """
    repair_prompt = f'''<s>[INST] <<SYS>>
 Example 1: Wrong Format: ['People's Republic of China', 'Laos', 'Thailand', 'India', 'Bangladesh']"]. Correct Format: Answer: "People's Republic of China", "Laos", "Thailand", "India", "Bangladesh"] </s>
 Example 2: Wrong Format: ['Artibonite', 'Nord-Est Department', 'South Department', 'West Department', 'Centre Department', 'Grand'Anse Department', 'North Department']. Correct Format: Answer: "Artibonite", "Nord-Est Department", "South Department", "West Department", "Centre Department", "Grand'Anse Department", "North Department"] </s>
 Example 3: Wrong Format: ['book's and page's']. Correct Format: Answer: ["book's and page's"] </s>
 Your answer should only be a valid python list of string format. Do not give any explainations.
 <</SYS>>

 Use the examples to convert {corrupt_lst} into a correct python list. [/INST] Answer: '''

    wrapped_model = llm.pipeline.model
    with wrapped_model.disable_adapter():
        return llm(repair_prompt)

In [None]:
def best_match_prompt_llama(sample,options):
    """Build the Llama chat prompt that asks the model to pick answers from ``options``.

    Args:
        sample: Record providing ``sample['Relation']`` (used to select the
            prompt template) and ``sample['SubjectEntity']`` (substituted for
            the ``{subject_entity}`` placeholder in that template).
        options: Candidate answers, interpolated verbatim as the prompt context.

    Returns:
        The fully formatted prompt string.

    Raises:
        IndexError: If ``question_prompts`` has no row for the sample's relation.
    """
    relation_rows = question_prompts[question_prompts['Relation']==sample['Relation']]
    template = relation_rows['PromptTemplate'].tolist()[0]
    question = template.replace('{subject_entity}',sample['SubjectEntity'])
    return f'''<s>[INST] <<SYS>>
 You are a helpful, respectful and honest assistant. Your answers should be crisp, short and not repititive.
 Choose an answer from the options in the context.
 If you dont know the answer from the given context, answer should just be a python empty list.
 <</SYS>>
 context: {options}
 
 {question} [/INST] Answer:'''

In [None]:
def lst_regex(input_string):
    """Parse the first ``[...]`` group found in ``input_string`` into a list.

    The text between the first pair of square brackets is parsed as Python
    literals. Several comma-separated items become a list of those items; a
    single item becomes a one-element list holding its ``str()`` form.

    Args:
        input_string: Text (typically model output) that may contain a list.

    Returns:
        The extracted list, or ``[]`` when no non-empty bracket group exists.

    Raises:
        SyntaxError, ValueError: If the bracket contents are not valid Python
            literals. Callers rely on this to trigger their fallback path.
    """
    extracted_list = []
    pattern = r'\[([^\]]+)\]'
    matches = re.findall(pattern, input_string)
    if matches:
        # The text comes from a language model and is untrusted: literal_eval
        # only accepts literals, so it can never execute code the way eval()
        # could. Stripping keeps "[ 'a', 'b' ]" parseable on older Pythons.
        parsed = literal_eval(matches[0].strip())
        if type(parsed) is tuple:
            extracted_list = list(parsed)
        else:
            extracted_list = [str(parsed)]
    return extracted_list


def extract_list_from_string(input_string):
    """Extract a list from model output, falling back to an LLM repair pass.

    First tries ``lst_regex`` directly. If that raises, the text is sent to
    ``llm_extract_lst`` for reformatting and the result is parsed again.

    Args:
        input_string: Raw text that is expected to contain a list.

    Returns:
        The parsed list, or ``[]`` when the repaired text cannot be parsed
        either.
    """
    try:
        return lst_regex(input_string)
    except Exception:
        # Catch Exception rather than using a bare except so that
        # KeyboardInterrupt / SystemExit still stop a long-running cell.
        llm_lst = llm_extract_lst(input_string)
        print("llm_lst", llm_lst, type(llm_lst))
        try:
            return lst_regex(llm_lst)
        except Exception:
            # Best effort: an unparseable repair counts as "no answer".
            return []

In [None]:
def iterate_dataframe_as_generator(df):
    """Lazily yield each row of ``df`` (as given by ``df.iterrows()``), dropping the index label."""
    yield from (record for _label, record in df.iterrows())

df = pd.read_csv("llama_with_wikidata_info.csv")

# Both columns are parsed from their CSV string form back into Python objects.
for list_column in ("WikiTitles", "ObjectEntities"):
    df[list_column] = df[list_column].apply(literal_eval)

# Single-use row iterator over the parsed frame.
df_generator = iterate_dataframe_as_generator(df)

In [None]:
# For every row, query the model once per candidate group in "WikiTitles" and
# collect the parsed answers; rows that fail are reported and skipped.
jsondata = []
for row in df_generator:
    try:
        OrigAnsWikiTitles = []
        for object_options in row["WikiTitles"]:
            if object_options and len(object_options)>0 and object_options != [None]:
                formatted_sample = best_match_prompt_llama(row, object_options)
                res = llm(formatted_sample)
                OrigAnsWikiTitles.append(extract_list_from_string(res))
            else:
                # No usable candidates for this slot: keep an empty answer so
                # the result stays aligned with row["WikiTitles"].
                OrigAnsWikiTitles.append([])

        row["OrigAnsWikiTitle"] = OrigAnsWikiTitles
        jsondata.append(row)
    except Exception as e:
        # Report the failing row itself. The previous handler referenced an
        # undefined name, so any error here raised NameError and aborted the run.
        print(f"Exception occurred for record: {row.get('SubjectEntity')}")
        print(f"Error message: {e}")
        continue

In [None]:
# Persist the successfully processed rows collected in jsondata.
stage2_results = pd.DataFrame(jsondata)
stage2_results.to_csv("llama_stage2.csv")