In [9]:
import sys
import logging

import datasets
from datasets import Dataset
from datasets import load_dataset
from peft import LoraConfig
import torch
import transformers
from trl import SFTTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig
import json
import numpy as np

In [None]:
# Render the worked chain-of-thought examples as one plain-text block for a quick visual check.
with open('chaine_of_thought_examples.json','r') as f:
    examples = json.load(f)

template = '{question} Option 1 : {Option1} or Option 2 : {Option2}. lets think step by step. {answer} \n \n '

chunks = []
for example in examples:
    chunks.append(
        template.format(
            question=example['question'],
            Option1=example['Option1'],
            Option2=example['Option2'],
            answer=example['answer'],
        )
    )
text = "".join(chunks)

print(text)

In [None]:
# --- Model: local snapshot (hub id kept for reference) ---
checkpoint_path = 'models/phi3'  # "microsoft/Phi-3-medium-128k-instruct"

# Loading options intended for a fresh hub download.
# NOTE: these are currently NOT passed to from_pretrained below, so the local
# load uses the library defaults (e.g. default dtype, default use_cache).
model_kwargs = {
    "use_cache": False,
    "trust_remote_code": True,
    # "attn_implementation": "flash_attention_2",  # loading the model with flash-attention support
    "torch_dtype": torch.bfloat16,
    "device_map": None,
}
model = AutoModelForCausalLM.from_pretrained(checkpoint_path)  # add **model_kwargs when downloading

# --- Tokenizer: separate local snapshot (this rebinds checkpoint_path) ---
checkpoint_path = 'tokenizer/phi3'  # "microsoft/Phi-3-medium-128k-instruct"
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
tokenizer.model_max_length = 512  # was 2048

# Use unk rather than eos as the pad token to prevent endless generation,
# and keep pad_token_id in sync with the token chosen.
tokenizer.pad_token = tokenizer.unk_token
tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
tokenizer.padding_side = 'right'

# Move the weights to the Metal (MPS) device; as the last expression the
# returned module is shown as the cell output.
model.to('mps')

In [12]:

def make_Dataset(path, tokenizer, examples_path='chaine_of_thought_examples.json') -> Dataset:
    """Build a few-shot chain-of-thought prompt dataset from a QQA-style JSON file.

    Each item becomes one prompt string. The conversation is a system message,
    then every worked example from ``examples_path`` as a user/assistant turn
    pair, then the item's own question as the final user turn. It is rendered
    with ``tokenizer.apply_chat_template(..., add_generation_prompt=True)``.

    Args:
        path: JSON file holding a list of objects with ``question``,
            ``Option1``, ``Option2`` and ``answer`` keys.
        tokenizer: Tokenizer providing ``apply_chat_template``.
        examples_path: JSON file of worked examples with the same keys. Each
            example's ``answer`` is used verbatim as the assistant turn.
            Defaults to the file the notebook has always used.

    Returns:
        ``Dataset`` with a ``messages`` column (the rendered prompt string)
        and a ``labels`` column (the item's ``answer``).

    Side effect:
        Prints the length in characters of the longest prompt (0 when
        ``path`` holds no items).
    """
    system_message = 'you are an AI assistant trained to answere Questions. You will decide wether " Option 1 " or " Option 2 " is correct. Think step by step'
    template = '{question} Option 1 : {Option1} or Option 2 : {Option2}. lets think step by step.'

    def format_question(entry):
        # Same question wording for the worked examples and the target item.
        return template.format(question=entry['question'], Option1=entry['Option1'], Option2=entry['Option2'])

    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    with open(examples_path, 'r', encoding='utf-8') as f:
        examples = json.load(f)

    # The system message plus worked examples is identical for every item,
    # so build it once instead of once per item.
    few_shot = [{"role": "system", "content": " " + system_message}]
    for example in examples:
        few_shot.append({"role": "user", "content": " " + format_question(example)})
        few_shot.append({"role": "assistant", "content": " " + example['answer']})

    data_dict = {'messages': [], 'labels': []}
    for item in data:
        # New outer list per item so the shared prefix is never mutated.
        messages = few_shot + [{"role": "user", "content": " " + format_question(item)}]
        prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
        data_dict['messages'].append(prompt)
        data_dict['labels'].append(item["answer"])

    # Longest prompt in characters: a quick sanity check against the context budget.
    # default=0 avoids the ValueError np.max raised on an empty data file.
    print(max((len(m) for m in data_dict['messages']), default=0))
    return Dataset.from_dict(data_dict)

    

In [None]:
def apply_chat_template(example, tokenizer):
    """Expose the pre-rendered prompt under the ``text`` key.

    ``make_Dataset`` already renders each conversation with the tokenizer's
    chat template, so this only copies ``example["messages"]`` into
    ``example["text"]``. ``tokenizer`` is accepted so that the ``fn_kwargs``
    passed through ``Dataset.map`` fit the signature, but it is unused.

    Returns the same (mutated) ``example`` mapping.
    """
    example["text"] = example["messages"]
    return example

# Build prompt datasets for both splits, then expose the prompt as "text".
train_dataset = make_Dataset('data/QQA/QQA_train.json', tokenizer)
test_dataset = make_Dataset('data/QQA/QQA_dev.json', tokenizer)
column_names = ['messages']


def _map_chat_template(dataset, description):
    """Run apply_chat_template over ``dataset``, dropping the ``messages`` column."""
    return dataset.map(
        apply_chat_template,
        fn_kwargs={"tokenizer": tokenizer},
        num_proc=10,
        remove_columns=column_names,
        desc=description,
    )


processed_train_dataset = _map_chat_template(train_dataset, "Applying chat template to train_sft")
processed_test_dataset = _map_chat_template(test_dataset, "Applying chat template to test_sft")


In [None]:
# Peek at one processed row: the rendered prompt under "text" and the gold answer under "labels".
processed_train_dataset[0]

In [18]:
def predict(index, dataset=None, device='mps', max_new_tokens=512, max_length=4096):
    """Generate an answer for one prompt row and score it against the gold label.

    Prints the full decoded output and the text after the last
    ``<|assistant|>`` marker (treated as the prediction). Also prints the
    gold answer.

    Args:
        index: Row index into ``dataset``.
        dataset: Indexable rows, each with a ``"text"`` prompt string and a
            ``"labels"`` gold-answer string. Defaults to the global
            ``processed_train_dataset`` (the original behaviour).
        device: Device the global ``model`` lives on. The prompt tensors are
            moved there. Defaults to ``'mps'``, as before.
        max_new_tokens: Generation budget passed to ``model.generate``.
        max_length: Prompt length cap in tokens. Longer prompts are truncated
            according to the tokenizer's truncation side, which may cut off
            the question itself.

    Returns:
        ``True`` if the gold answer string occurs in the prediction, else ``False``.
    """
    if dataset is None:
        dataset = processed_train_dataset
    row = dataset[index]

    # Call the tokenizer (rather than .encode) so an attention mask comes back
    # for generate(). truncation=True makes the truncation that max_length
    # previously triggered implicitly explicit, and silences the warning.
    encoded = tokenizer(row["text"], return_tensors="pt", truncation=True, max_length=max_length)
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    output = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.pad_token_id,
    )
    output_text = tokenizer.batch_decode(output)[0]
    print('\n \n----------------------------')
    print("output Text: \n" + output_text)
    # Prompt and completion are decoded together; keep what follows the last assistant marker.
    prediction = output_text.split("<|assistant|>")[-1]
    print("\n \n predicted answer : \n" + prediction)
    answer = row["labels"]
    print("\n \n the answer was : \n" + answer)

    # NOTE(review): plain substring match. A prediction that mentions both
    # options counts as correct; tighten once the answer format is pinned down.
    if answer in prediction:
        print("correct predicted")
        return True
    return False
    

In [None]:
# Spot-check generation on training row 13: prints the model output and the gold
# label, and returns whether the label string appears in the prediction.
predict(13)


In [None]:
# Spot-check generation on training row 5 (same checks as the cell above).
predict(5)

In [None]:
# Score every training row with predict() and report running progress and accuracy.
true_count = 0
total_count = 0
n_rows = len(processed_train_dataset)
for index in range(n_rows):
    res = predict(index)

    total_count += 1
    if res:
        true_count += 1

    print(f'{total_count} -- {total_count/n_rows} -- accuracy : {true_count / total_count}')