In [None]:
import json
import math
import random
import torch


# train_data = MyLLMDataloader(4, tokenizer, "cleaned_TeleQnA_train_context_gte.json", shuffle=True)
import pandas as pd
import os

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model
from tqdm.auto import tqdm


import matplotlib.pyplot as plt
import pandas as pd
import itertools

from statistics import mode


In [None]:
# Load the test questions: `test_data` keeps the raw JSON text and
# `orig_test_data` holds the parsed object.
with open("questions_new_final_backup.json", "r") as f:
    test_data = f.read()
    orig_test_data = json.loads(test_data)

In [None]:
from string import Template
# Question + options prompt with no retrieved context. Note that `$question` is
# substituted both before and after `$options`.
# NOTE(review): not used by any cell visible here — confirm before removing.
prompt_q_without_contex_train= Template('''Instruct: $question
$options
$question
''')


# Prompt used by MyLLMDataloader: question, abbreviation expansions, three
# retrieved-context blocks, the numbered options, and a trailing
# "Output: option " so the model's next token is read as the option number.
# The whitespace inside this literal (including trailing spaces) is part of the
# prompt text; keep it unchanged so prompts stay identical across runs.
prompt_without_contex_train= Template('''Instruct: $question
Abbreviations: $abbreviation
          
Considering the following contexts:
context 1: $context1
context 2: $context2
context 3: $context3      
                                                                    
$question
$options
Output: option ''')
# prompt_without_context = f'Hello {planet}'

In [None]:
def clean_question(question):
    """Remove "[3GPP Release N]" tags for N in 14..18 from ``question``.

    Tags are removed in ascending release order; any other text (including
    tags for other release numbers) is left untouched.
    """
    release_tags = [f"[3GPP Release {release}]" for release in range(14, 19)]
    cleaned = question
    for tag in release_tags:
        cleaned = cleaned.replace(tag, "")
    return cleaned

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
# Model under evaluation (presumably a fine-tuned phi-2 checkpoint — confirm);
# the tokenizer is always taken from the public "microsoft/phi-2" release.
BASE_MODEL_ID = "logs/a/model"
device = "cuda"

tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)

# From here on, newly created tensors default to `device` (this also affects
# tensors created in later cells).
torch.set_default_device(device)
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    trust_remote_code=True,
)
model.to(device)


In [None]:
class MyLLMDataloader:
    """Iterate over a JSON file of multiple-choice questions as tokenized prompt batches.

    For each question whose option texts contain neither "option" nor "above"
    (i.e. options that do not refer to each other), up to ``n_permutations``
    prompts are built with the options in randomly sampled orders, and for each
    such prompt ``option_maps`` records the original option number shown at
    each position. Other questions get one prompt in their original option
    order and contribute no ``option_maps`` entry.

    NOTE(review): when a batch mixes both kinds of question, ``option_maps`` has
    fewer rows than ``inp_ids``; downstream code that pairs them assumes a
    batch holds a single question (batch_size=1) — confirm before batching more.
    """

    def __init__(self, batch_size, tokenizer, data, shuffle = False, val= False, n_permutations=20):
        """
        Args:
            batch_size: number of questions per batch.
            tokenizer: tokenizer used to encode prompts; its ``pad_token`` is set
                to its ``eos_token`` (mutates the passed-in tokenizer).
            data: path to a JSON file mapping question keys to example dicts.
            shuffle: stored but not used by this class.
            val: stored but not used by this class.
            n_permutations: maximum number of option orderings sampled per
                permutable question.
        """
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.tokenizer.pad_token = self.tokenizer.eos_token
        with open(data, "r") as f:
            self.data = json.load(f)
        self.all_examples = list(self.data.keys())
        self.shuffle = shuffle
        self.val = val
        self.n_permutations = n_permutations

        self.n_data_points = math.ceil(len(self.data) / self.batch_size)
        self.indices = list(range(self.n_data_points))

    @staticmethod
    def _collect_options(example):
        """Return ``(text, number)`` pairs for each non-null "option N" field, in key order."""
        opts = []
        for key, value in example.items():
            if key.startswith("opt") and value is not None:
                opts.append((value, key.split("option ")[1]))
        return opts

    def _sample_permutations(self, opts):
        """Return up to ``self.n_permutations`` distinct orderings of ``opts``, sampled at random."""
        all_permutations = list(itertools.permutations(opts))
        return random.sample(all_permutations, min(self.n_permutations, len(all_permutations)))

    @staticmethod
    def _build_prompt(example, option_texts):
        """Fill the main prompt template for ``example`` with options shown in ``option_texts`` order."""
        # Position labels; raises IndexError for more than five options.
        option_header = ["option 1 ", "option 2 ", "option 3 ", "option 4 ", "option 5 "]
        options_with_header = "\n".join(
            option_header[position] + text for position, text in enumerate(option_texts)
        )
        return prompt_without_contex_train.substitute(
            question=clean_question(example["question"]),
            abbreviation='\n'.join(example["abbreviation"]),
            context1='\n'.join(example["context_qwen2"][:2]),
            context2='\n'.join(example["context_gle"]),
            context3='\n'.join(example["context_bm"][:2]),
            options=options_with_header,
        )

    def __getitem__(self, idx):
        """Build and tokenize batch ``idx``.

        Returns:
            dict with "inp_ids" and "inp_mask" (left-padded tensors from the
            tokenizer) and "option_maps" (list of per-prompt lists mapping
            displayed position -> original option number, for permuted prompts).
        """
        batch_start_id = idx * self.batch_size
        batch_end_id = min(len(self.data), batch_start_id + self.batch_size)
        prompts = []
        # Collected per batch (previously reset per example), so maps for every
        # permuted question in the batch are returned, in prompt order.
        option_maps = []

        for i in range(batch_start_id, batch_end_id):
            example = self.data[self.all_examples[i]]
            opts = self._collect_options(example)
            string_opts = ' '.join(text for text, _ in opts)

            if not ("option" in string_opts or "above" in string_opts):
                # Options are independent of their position, so vote over orderings.
                for option_set in self._sample_permutations(opts):
                    prompts.append(self._build_prompt(example, [text for text, _ in option_set]))
                    option_maps.append([int(number) for _, number in option_set])
            else:
                # Options like "All of the above" depend on order; keep the original.
                prompts.append(self._build_prompt(example, [text for text, _ in opts]))

        # Left padding keeps every prompt's final token in the last column;
        # restore the tokenizer's padding side even if tokenization fails.
        self.tokenizer.padding_side = "left"
        try:
            q_tokens = self.tokenizer(prompts, padding="longest", return_tensors="pt")
        finally:
            self.tokenizer.padding_side = "right"

        return {
            "inp_ids": q_tokens["input_ids"],
            "inp_mask": q_tokens["attention_mask"],
            "option_maps": option_maps,
        }

    def __iter__(self):
        self.idx = 0
        return self

    def __next__(self):
        if self.idx >= self.n_data_points:
            self.idx = 0
            raise StopIteration
        temp_idx = self.indices[self.idx]
        self.idx += 1
        return self[temp_idx]

    def __len__(self):
        """Number of batches."""
        return self.n_data_points
    



        

In [None]:
def forward_pass(model, batch):
    """Run ``model`` on a tokenized batch and return the output logits.

    ``batch["inp_ids"]`` and ``batch["inp_mask"]`` are moved to ``model.device``
    and passed as ``input_ids`` / ``attention_mask``.
    """
    target_device = model.device
    outputs = model(
        input_ids=batch["inp_ids"].to(target_device),
        attention_mask=batch["inp_mask"].to(target_device),
    )
    return outputs.logits

In [None]:
def _vote_over_permutations(model, item, option_ids):
    """Score each permuted prompt in ``item`` separately and majority-vote the original option number.

    Returns a 0-d tensor holding the most frequent original option number.
    """
    option_maps = torch.tensor(item["option_maps"])
    n_options = option_maps.shape[1]
    # Only score digits for options actually shown; otherwise argmax can pick a
    # position past the end of the map and gather() raises.
    candidate_ids = option_ids[:n_options]
    last_logits = []
    single = item.copy()
    for ids_row, mask_row in zip(item["inp_ids"], item["inp_mask"]):
        single["inp_ids"] = torch.unsqueeze(ids_row, 0)
        single["inp_mask"] = torch.unsqueeze(mask_row, 0)
        with torch.inference_mode():
            with torch.autocast(device_type=device, dtype=torch.float16):
                logits = forward_pass(model, single)
            # Keep only the last-position candidate logits instead of the full
            # (1, seq, vocab) tensor for every permutation.
            last_logits.append(logits[:, -1, candidate_ids])
    positions = torch.cat(last_logits, dim=0).argmax(dim=1)
    # Translate displayed position -> original option number for each permutation.
    votes = option_maps.gather(1, positions.unsqueeze(1).to(option_maps.device)).squeeze(1)
    return torch.mode(votes).values


def _predict_unpermuted(model, item, option_ids):
    """Return argmax over the option-digit logits at the last position, as a 1-based option number."""
    with torch.inference_mode():
        with torch.autocast(device_type=device, dtype=torch.float16):
            logits = forward_pass(model, item)
        return logits[:, -1, option_ids].argmax(dim=1) + 1


def inference(model, testLoader, output_csv="law_ans.csv"):
    """Predict one answer ID per batch of ``testLoader`` and checkpoint results to CSV.

    Args:
        model: causal LM called through ``forward_pass``.
        testLoader: iterable with ``len()`` yielding dicts with "inp_ids",
            "inp_mask" and "option_maps" (presumably a MyLLMDataloader with
            batch_size=1 — each batch must yield a single prediction).
        output_csv: path rewritten after every batch with the answers so far.

    Returns:
        dict ``{"Answer_ID": [int, ...]}`` with one prediction per batch.
    """
    my_ans = {"Answer_ID": []}
    model.eval()
    # Token id of each answer digit; the prompt ends with "Output: option ", so
    # the next-token logits over these ids score options 1..5.
    option_ids = [tokenizer(o).input_ids[0] for o in ["1", "2", "3", "4", "5"]]
    pbar = tqdm(total=len(testLoader), ncols=100)
    try:
        for item in testLoader:
            if len(item["option_maps"]) > 0:
                preds = _vote_over_permutations(model, item, option_ids)
            else:
                preds = _predict_unpermuted(model, item, option_ids)
            my_ans["Answer_ID"].append(preds.item())

            # Rewritten every batch so partial results survive an interrupted run.
            pd.DataFrame(my_ans).to_csv(output_csv)
            pbar.set_description(f"Prediction: {preds}")
            pbar.update(1)
    finally:
        pbar.close()
    return my_ans

    