In [None]:
import json
import os
from ast import literal_eval
from typing import List

import pandas as pd
import torch
from langchain.vectorstores import FAISS
from transformers import (
    DPRContextEncoder,
    DPRContextEncoderTokenizer,
    DPRQuestionEncoder,
    DPRQuestionEncoderTokenizer,
)

In [None]:
# Run on the currently selected CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = f'cuda:{torch.cuda.current_device()}'
else:
    device = 'cpu'

In [None]:
# Hugging Face model ids for the two DPR encoders (passages vs. questions).
context_model_path = "facebook/dpr-ctx_encoder-single-nq-base"
query_model_path = "facebook/dpr-question_encoder-single-nq-base"

# Download/cache location for the model files. It defaults to the original scratch
# directory, but can be redirected by setting the DPR_CACHE_DIR environment variable
# so the notebook is runnable on machines that do not have that path.
cache_dir = os.environ.get(
    "DPR_CACHE_DIR", "/fs/scratch/ban_bgsw_etm-team/nyn1kor_scratch/"
)

# Each encoder is switched to eval mode right after loading; the notebook only runs inference.
context_tokenizer = DPRContextEncoderTokenizer.from_pretrained(context_model_path, cache_dir=cache_dir)
context_model = DPRContextEncoder.from_pretrained(context_model_path, cache_dir=cache_dir).eval()
query_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(query_model_path, cache_dir=cache_dir)
query_model = DPRQuestionEncoder.from_pretrained(query_model_path, cache_dir=cache_dir).eval()

class DPREmbeddings:
    """Embedding adapter exposing ``embed_documents`` / ``embed_query`` over the DPR encoders.

    Documents are encoded with the module-level ``context_tokenizer`` / ``context_model``
    and queries with ``query_tokenizer`` / ``query_model``. All vectors are returned as
    plain Python lists of floats, matching the declared return annotations.
    """

    def encode(self, tokenizer, model, texts, max_length=300, batch_size=64):
        """Encode ``texts`` in batches and return one pooled vector per text.

        Args:
            tokenizer: Tokenizer called on each batch of strings.
            model: Encoder whose output exposes ``pooler_output``.
            texts: Sequence of strings to encode.
            max_length: Truncation length (in tokens) passed to the tokenizer.
            batch_size: Number of texts per forward pass; must be >= 1.

        Returns:
            A list with one ``List[float]`` per input text (empty for empty input).

        Raises:
            ValueError: If ``batch_size`` is smaller than 1 (the batching loop
                could never advance).
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        model.to(device)
        outputs = []

        # Inference only: skip gradient bookkeeping.
        with torch.no_grad():
            for start_idx in range(0, len(texts), batch_size):
                batch_texts = texts[start_idx:start_idx + batch_size]

                # Tokenize the batch, padding to its longest member.
                inputs = tokenizer(
                    batch_texts,
                    padding="longest",
                    truncation=True,
                    max_length=max_length,
                    return_tensors="pt",
                ).to(device)

                # Forward pass; move the pooled vectors back to the CPU and convert
                # them to nested Python lists so the result matches List[List[float]].
                pooled = model(**inputs).pooler_output
                outputs.extend(pooled.cpu().tolist())

        return outputs

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed passages with the context encoder (newlines are flattened to spaces)."""
        texts = [text.replace("\n", " ") for text in texts]
        return self.encode(context_tokenizer, context_model, texts)

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string with the question encoder."""
        pooled_output = self.encode(
            query_tokenizer, query_model, [text], batch_size=1
        )
        return pooled_output[0]

In [None]:
def split_text(text, max_chunk_length=300, overlap=50):
    """Split ``text`` into overlapping word-level chunks, paragraph by paragraph.

    Paragraphs are separated by blank lines (``'\\n\\n'``). Each paragraph is split
    on whitespace and cut into windows of at most ``max_chunk_length`` words, where
    consecutive windows share ``overlap`` words. Chunks never span two paragraphs.

    Args:
        text: Input text.
        max_chunk_length: Maximum number of words per chunk; must be >= 1.
        overlap: Number of words shared by consecutive chunks; must satisfy
            ``0 <= overlap < max_chunk_length``.

    Returns:
        List of chunk strings (words re-joined with single spaces). Empty
        paragraphs contribute nothing.

    Raises:
        ValueError: If the window parameters would make the window stand still,
            move backwards, or skip words.
    """
    if max_chunk_length < 1:
        raise ValueError(f"max_chunk_length must be >= 1, got {max_chunk_length}")
    if not 0 <= overlap < max_chunk_length:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < max_chunk_length, "
            f"got overlap={overlap}, max_chunk_length={max_chunk_length}"
        )

    step = max_chunk_length - overlap
    chunks = []
    for paragraph in text.split('\n\n'):
        words = paragraph.split()
        for start in range(0, len(words), step):
            chunks.append(' '.join(words[start:start + max_chunk_length]))
            # Once a window reaches the end of the paragraph, any further window
            # would only repeat words already contained in this one.
            if start + max_chunk_length >= len(words):
                break

    return chunks

In [None]:
def get_k_contexts(q, wiki_txt, k):
    """Return the ``k`` chunks of ``wiki_txt`` most similar to the question ``q``.

    The text is chunked with ``split_text``, indexed in a throw-away FAISS store
    using ``DPREmbeddings``, and searched with ``q``.

    Args:
        q: Question string used as the search query.
        wiki_txt: Source text to retrieve from. Anything that is not a non-blank
            string (``None``, ``''``, whitespace, or a NaN left by pandas for a
            missing value) is treated as "no text".
        k: Number of chunks to retrieve.

    Returns:
        The retrieved chunk texts joined by newlines, or ``""`` when there is no
        usable text.
    """
    # Guard before chunking: non-string values cannot be split, and blank text
    # would produce zero chunks, which cannot be indexed.
    if not isinstance(wiki_txt, str) or not wiki_txt.strip():
        return ""

    wiki_chunked = split_text(wiki_txt)
    vectorstore = FAISS.from_texts(wiki_chunked, DPREmbeddings())
    top_contexts = vectorstore.similarity_search(q, k=k)
    return "\n".join(d.page_content for d in top_contexts)

In [None]:
# Lookup table of prompt templates; the prompt builders in this notebook filter it on
# the 'Relation' column and read the matching 'PromptTemplate' value.
question_prompts = pd.read_csv('question-prompts.csv')

In [None]:
def _read_jsonl_lines(path):
    """Return the non-blank lines of a JSON Lines file as raw (unparsed) strings."""
    # Explicit UTF-8 so the result does not depend on the platform's default encoding.
    with open(path, 'r', encoding='utf-8') as handle:
        return [line for line in handle if line.strip()]


train_json_list = _read_jsonl_lines('traindata_with_context.jsonl')
val_json_list = _read_jsonl_lines('valdata_with_context.jsonl')
test_json_list = _read_jsonl_lines('testdata_with_context.jsonl')

In [None]:
# Parse every JSON line into a dict, one list per split. Blank lines are skipped
# because json.loads would raise on them.
train_df_list = [json.loads(json_line) for json_line in train_json_list if json_line.strip()]
val_df_list = [json.loads(json_line) for json_line in val_json_list if json_line.strip()]
test_df_list = [json.loads(json_line) for json_line in test_json_list if json_line.strip()]

In [None]:
# Build the train/validation frames and shuffle every row with a fixed seed so the
# order is reproducible between runs.
train_df = pd.DataFrame(train_df_list).sample(frac=1, random_state=8)
val_df = pd.DataFrame(val_df_list).sample(frac=1, random_state=8)

In [None]:
def format_instruction_llama(sample):
    """Build a context-free LLaMA [INST] training prompt for one sample.

    The question is the ``PromptTemplate`` registered in ``question_prompts`` for
    ``sample['Relation']`` with ``{subject_entity}`` replaced by
    ``sample['SubjectEntity']``; the target is ``sample['ObjectEntities']``.
    """
    matching_rows = question_prompts[question_prompts['Relation'] == sample['Relation']]
    template = matching_rows['PromptTemplate'].tolist()[0]
    question = template.replace('{subject_entity}', sample['SubjectEntity'])

    # Each piece below is one line of the prompt; the leading spaces are part of it.
    return (
        "<s>[INST] <<SYS>>\n"
        " You are a helpful, respectful and honest assistant. Your answers should be crisp, short and not repititive.\n"
        " Give valid wikipedia page titles in the answer.\n"
        " <</SYS>>\n"
        f" {question} [/INST] Answer: {sample['ObjectEntities']} </s>"
    )

In [None]:
def get_question(sample):
    """Return the natural-language question for ``sample``.

    Looks up the first ``PromptTemplate`` in ``question_prompts`` whose ``Relation``
    equals ``sample['Relation']`` and substitutes ``sample['SubjectEntity']`` for
    the ``{subject_entity}`` placeholder.
    """
    relation_mask = question_prompts['Relation'] == sample['Relation']
    templates = question_prompts[relation_mask]['PromptTemplate'].tolist()
    question = templates[0].replace('{subject_entity}', sample['SubjectEntity'])
    return f'{question}'

In [None]:
def get_prompt_with_context_llama(sample,k):
    """Build a LLaMA [INST] training prompt that includes retrieved context.

    The top ``k`` chunks of ``sample['subject_text']`` for ``sample['question']``
    are fetched with ``get_k_contexts`` and placed in the ``context:`` line; the
    target answer is ``sample['ObjectEntities']``.
    """
    retrieved = get_k_contexts(sample['question'], sample['subject_text'], k)

    # One entry per prompt line; leading spaces and the whitespace-only spacer
    # line are part of the prompt text.
    prompt_lines = [
        "<s>[INST] <<SYS>>",
        " You are a helpful, respectful and honest assistant. Your answers should be crisp, short and not repititive.",
        " Give valid wikipedia page titles in the answer. The answer should be in a python list of string format.",
        " If you dont know the answer from both the given context and your past knowledge, answer should just be a python empty list.",
        " <</SYS>>",
        f" context: '{retrieved}'",
        "    ",
        f" {sample['question']} [/INST] Answer: {sample['ObjectEntities']} </s>",
    ]
    return "\n".join(prompt_lines)

In [None]:
def get_prompt_without_context_llama(sample, k=None):
    """Build a LLaMA [INST] training prompt with an empty ``context`` field.

    Args:
        sample: Mapping/row providing ``'question'`` and ``'ObjectEntities'``.
        k: Ignored. Accepted only so this builder can be called with the same
            ``(sample, k)`` arguments as the with-context builder.

    Returns:
        The full prompt string, ending with the expected answer and ``</s>``.
    """
    # One entry per prompt line; leading spaces and the whitespace-only spacer
    # line are part of the prompt text.
    prompt_lines = [
        "<s>[INST] <<SYS>>",
        " You are a helpful, respectful and honest assistant. Your answers should be crisp, short and not repititive.",
        " Give valid wikipedia page titles in the answer. The answer should be in a python list of string format.",
        " If you dont know the answer from both the given context and your past knowledge, answer should just be a python empty list.",
        " <</SYS>>",
        " context: ''",
        "    ",
        f" {sample['question']} [/INST] Answer: {sample['ObjectEntities']} </s>",
    ]
    return "\n".join(prompt_lines)

In [None]:
def best_match_prompt_llama(query,options,answer):
    """Build a LLaMA [INST] prompt asking which context option is equivalent to ``query``.

    ``options`` is interpolated verbatim into the ``context:`` line and ``answer``
    is rendered as a one-element Python-list literal after ``Answer:``.
    """
    # One entry per prompt line; leading spaces and the whitespace-only spacer
    # line are part of the prompt text.
    prompt_lines = (
        "<s>[INST] <<SYS>>",
        " You are a helpful, respectful and honest assistant. Your answers should be crisp, short and not repititive.",
        " If you dont know the answer from the given context, answer should just be a python empty list.",
        " <</SYS>>",
        f" context: {options}",
        " ",
        f" Which among the context options is equivalent to {query}? [/INST] Answer: ['{answer}'] </s>",
    )
    return "\n".join(prompt_lines)

In [None]:
def best_match_prompt_beluga(sample,options,answer):
    """Build a Beluga-style prompt asking the model to pick an answer from ``options``.

    Args:
        sample: Either a plain query string (the notebook's candidate loops pass
            one object-entity string per call) or a mapping/row with ``'Relation'``
            and ``'SubjectEntity'``. A string is turned into the same
            "Which among the context options is equivalent to ...?" question used by
            ``best_match_prompt_llama``; a mapping is turned into its relation
            question via ``get_question``.
        options: Text interpolated verbatim into the ``context:`` line.
        answer: Expected answer, appended after ``Answer:``.

    Returns:
        The full prompt string.
    """
    if isinstance(sample, str):
        # A bare query string cannot be indexed by 'Relation'; ask the
        # equivalence question about it instead.
        question = f"Which among the context options is equivalent to {sample}?"
    else:
        question = get_question(sample)

    prompt_lines = [
        "### System:",
        "You are a helpful, respectful and honest assistant. Your answers should be crisp, short and not repititive.",
        "Choose an answer from the options in the context.",
        "If you dont know the answer from the given context, answer should just be a python empty list.",
        "",
        "### User:",
        f"context: {options}",
        "",
        f"{question}",
        "",
        "### Assistant",
        f"Answer: {answer}",
    ]
    return "\n".join(prompt_lines)

In [None]:
def get_prompt_with_context_beluga(sample, k=2):
    """Build a Beluga-style training prompt that includes retrieved context.

    Args:
        sample: Mapping/row providing ``'question'``, ``'subject_text'`` and
            ``'ObjectEntities'``.
        k: Number of context chunks to retrieve with ``get_k_contexts``
            (which requires it). Defaults to 2, the value the notebook passes.

    Returns:
        The full prompt string, ending with the expected answer.
    """
    context = get_k_contexts(sample['question'], sample['subject_text'], k)

    prompt_lines = [
        "### System:",
        "You are a helpful, respectful and honest assistant. Your answers should be crisp, short and not repititive.",
        "Give valid wikipedia page titles in the answer. The answer should be in a python list of string format.",
        "If you dont know the answer from both the given context and your past knowledge, answer should just be a python empty list.",
        "",
        "### User:",
        f"context: '{context}'",
        "",
        f"{sample['question']}",
        "",
        "### Assistant",
        f"Answer: {sample['ObjectEntities']}",
    ]
    return "\n".join(prompt_lines)

In [None]:
def get_prompt_without_context_beluga(sample, k=None):
    """Build a Beluga-style training prompt with an empty ``context`` field.

    No retrieval is performed: the prompt never contains retrieved text, so only
    ``sample['question']`` and ``sample['ObjectEntities']`` are read.

    Args:
        sample: Mapping/row providing ``'question'`` and ``'ObjectEntities'``.
        k: Ignored. Accepted only so this builder can be called with the same
            ``(sample, k)`` arguments as the with-context builder.

    Returns:
        The full prompt string, ending with the expected answer.
    """
    prompt_lines = [
        "### System:",
        "You are a helpful, respectful and honest assistant. Your answers should be crisp, short and not repititive.",
        "Give valid wikipedia page titles in the answer. The answer should be in a python list of string format.",
        "If you dont know the answer from both the given context and your past knowledge, answer should just be a python empty list.",
        "",
        "### User:",
        "context: ''",
        "",
        f"{sample['question']}",
        "",
        "### Assistant",
        f"Answer: {sample['ObjectEntities']}",
    ]
    return "\n".join(prompt_lines)

In [None]:
model = 'llama' #change to beluga if generating data for beluga model


def _build_candidate_prompts(candidates_csv, prompt_fn):
    """Create one best-match prompt per non-empty object entity in a candidates CSV.

    Args:
        candidates_csv: Path to a CSV whose 'ObjectEntities', 'WikiTitles' and
            'OrigAnsWikiTitle' cells hold Python-literal lists stored as strings.
        prompt_fn: Callable invoked as ``prompt_fn(query, str(options), answer)``.

    Returns:
        ``(prompts_df, candidates_df, skipped)``: a one-column ('prompt_with_context')
        frame of prompts, the raw candidates frame, and the number of rows skipped
        because their three lists had different lengths.
    """
    candidates = pd.read_csv(candidates_csv)
    prompts = []
    skipped = 0
    for _, row in candidates.iterrows():
        queries = literal_eval(row['ObjectEntities'])
        options = literal_eval(row['WikiTitles'])
        answers = literal_eval(row['OrigAnsWikiTitle'])
        # The lists are paired by position, so a length mismatch makes the row unusable.
        if len(queries) != len(answers) or len(queries) != len(options):
            skipped += 1
            continue
        for query, option, answer in zip(queries, options, answers):
            if query.strip() != '':
                prompts.append(prompt_fn(query, str(option), answer))
    # Collect in a list and build the frame once instead of appending row by row.
    return pd.DataFrame({'prompt_with_context': prompts}), candidates, skipped


# Any value other than 'llama' selects the beluga prompt builders and file names.
variant = 'llama' if model == 'llama' else 'beluga'
if variant == 'llama':
    with_context_fn = get_prompt_with_context_llama
    without_context_fn = get_prompt_without_context_llama
    best_match_fn = best_match_prompt_llama
else:
    with_context_fn = get_prompt_with_context_beluga
    without_context_fn = get_prompt_without_context_beluga
    best_match_fn = best_match_prompt_beluga

train_df['question'] = train_df.apply(get_question, axis=1)
val_df['question'] = val_df.apply(get_question, axis=1)

# 1) Prompts with the top-2 retrieved contexts; the builder is called as fn(sample, k).
train_df['prompt_with_context'] = train_df.apply(lambda x: with_context_fn(x, 2), axis=1)
val_df['prompt_with_context'] = val_df.apply(lambda x: with_context_fn(x, 2), axis=1)
train_df.to_csv(f'train_with_top2_context_{variant}.csv')
val_df.to_csv(f'val_with_top2_context_{variant}.csv')

# 2) Prompts with an empty context. The same column is overwritten on purpose: the
#    with-context version has already been saved to disk above. The builder takes
#    only the sample, so no k is passed.
train_df['prompt_with_context'] = train_df.apply(without_context_fn, axis=1)
val_df['prompt_with_context'] = val_df.apply(without_context_fn, axis=1)
train_df.to_csv(f'train_without_context_{variant}.csv')
val_df.to_csv(f'val_without_context_{variant}.csv')

# 3) Best-match prompts over wiki title candidates.
#    train_df_with_candidates.csv / val_df_with_candidates.csv are generated by running
#    the fetch wiki options notebook with train.jsonl / val.jsonl.
#    `df` and `count` stay bound at module level (last candidates frame / skipped rows).
train_df_just_with_prompt, df, count = _build_candidate_prompts(
    'train_df_with_candidates.csv', best_match_fn
)
train_df_just_with_prompt.to_csv(f'train_with_wiki_candidates_{variant}.csv')

val_df_just_with_prompt, df, count = _build_candidate_prompts(
    'val_df_with_candidates.csv', best_match_fn
)
val_df_just_with_prompt.to_csv(f'val_with_wiki_candidates_{variant}.csv')

# 4) Reload the three saved variants per split and stack them.
train_df_with_context = pd.read_csv(f'train_with_top2_context_{variant}.csv')
train_df_without_context = pd.read_csv(f'train_without_context_{variant}.csv')
train_df_with_wiki = pd.read_csv(f'train_with_wiki_candidates_{variant}.csv')

val_df_with_context = pd.read_csv(f'val_with_top2_context_{variant}.csv')
val_df_without_context = pd.read_csv(f'val_without_context_{variant}.csv')
val_df_with_wiki = pd.read_csv(f'val_with_wiki_candidates_{variant}.csv')

# ignore_index gives the stacked frame unique row labels instead of three
# repeated 0..n ranges.
train_df_final = pd.concat(
    [train_df_with_context, train_df_without_context, train_df_with_wiki],
    ignore_index=True,
)
val_df_final = pd.concat(
    [val_df_with_context, val_df_without_context, val_df_with_wiki],
    ignore_index=True,
)

# 5) Keep only the prompt text, under the column name 'prompt'.
train_df_prompt = train_df_final[['prompt_with_context']].rename(columns={'prompt_with_context': 'prompt'})
val_df_prompt = val_df_final[['prompt_with_context']].rename(columns={'prompt_with_context': 'prompt'})
train_df_prompt.to_csv('train_df_combined.csv')
val_df_prompt.to_csv('val_df_combined.csv')

In [None]:
# NOTE(review): this cell repeats the candidate-prompt loop from the previous cell, but
# `df_just_with_prompt` and `best_match_prompt` are not defined anywhere in the visible
# notebook, so it raises NameError as soon as a row yields a non-empty query. It also
# relies on `df` and `count` still being bound from the previous cell. It looks like a
# leftover scratch cell - confirm and remove, or define the missing names.
for row in df.iterrows():
    # Each of these cells is parsed from its string form with literal_eval.
    queries = literal_eval(row[1]['ObjectEntities'])
    options = literal_eval(row[1]['WikiTitles'])
    answers = literal_eval(row[1]['OrigAnsWikiTitle'])
    # Rows whose three lists differ in length are counted and skipped.
    if len(queries)!= len(answers) or len(queries)!=len(options):
        count+=1
        continue
    else:
        for i, query in enumerate(queries):
            if query.strip()!='':
                df_just_with_prompt.loc[len(df_just_with_prompt.index)] = [best_match_prompt(query,str(options[i]),answers[i])]

In [None]:
# Write the test split as JSON Lines, replacing each sample's raw 'subject_text'
# with its top-3 retrieved context chunks under 'top3context'.
with open('test_with_top3_context.jsonl', 'w', encoding='utf-8') as outfile:
    for json_str in test_json_list:
        if not json_str.strip():
            # Tolerate blank lines in the input file.
            continue
        sample = json.loads(json_str)
        # Same relation-template lookup used for the train/val questions.
        question = get_question(sample)
        # pop() removes the bulky source text from the output record and tolerates
        # samples that have no 'subject_text' (get_k_contexts returns '' for None).
        subject_text = sample.pop('subject_text', None)
        sample['top3context'] = get_k_contexts(question, subject_text, 3)
        json.dump(sample, outfile)
        outfile.write('\n')