In [1]:
import json
from torch.utils.data import Dataset, DataLoader

In [2]:
def load_tatqa_dataset(path):
    with open(path, "r") as f:
        return json.load(f)

train_data = load_tatqa_dataset("/cs/student/projects1/aibh/2024/tpatil/comp0087/FrugalML/datasets/TATQA/tatqa_dataset_train.json")

In [5]:
examples = train_data[0]

print(examples["table"]["table"])
# for example in examples:
#     # print((example["table"]).get("table", []))
#     print(example["table"]["table"])
#     # print(example["question"])
#     # print(example.keys())

[['', '2019 %', '2018 %', '2017 %'], ['Weighted average actuarial assumptions used at 31 March1:', '', '', ''], ['Rate of inflation2', '2.9', '2.9', '3.0'], ['Rate of increase in salaries', '2.7', '2.7', '2.6'], ['Discount rate', '2.3', '2.5', '2.6']]


In [3]:
for train_sample in train_data:
    table = train_sample['table']['table']
    paragraphs = train_sample['paragraphs']
    questions = train_sample['questions']

    for question_answer in questions:
        try:
            question = question_answer["question"].strip()
            answer = question_answer["answer"]
            derivation = question_answer["derivation"]
            answer_type = question_answer["answer_type"]
            answer_from = question_answer["answer_from"]
            facts = question_answer["facts"]
            answer_mapping = question_answer["mapping"]
            scale = question_answer["scale"]
        
        except RuntimeError as e :
            print(f"run time error:{e}")

In [14]:
def create_prompt_instance(table, paragraphs, question, answer, derivation=None):
    set_context = "You are an intelligent financial data analyst. You are given a table with financial data. You are also given a paragraph that provides some context about the data in the table. You are asked a question about the data in the table or paragraph. You are expected to answer the question based on the data in the table and the paragraph.\n"
    # table_prompt = "The first element of the table array contains the column names. In the following elements, the first element is the row name and the rest of the elements are the values in the row assigned to the respective columns."
    table_prompt = "The table provide the financial data. All the elements in the table are separated by \"|\". The first row of the table contains the column names. In the following rows, the first column contains the row name and the rest of the elements are the values in the row assigned to the respective columns. Interpret the table and use the data in it to calculate the answer to the provided quesions.\n"
    paragraph_prompt = "The paragraphs provides some context about the data in the table. It may contain information that is not present in the table. It may also contain some numbers which might require arithmatic processing to get the answer. There may be multiple paragraphs separated by keyword matching \"Paragraph [0-9]+:\". Interpret each paragraph and use the data and description in it to infer the answer to the provided quesions.\n"
    question_prompt = "The question is asked based on the data in the table and the paragraph. You are expected to answer the question based on the data in the table and the paragraph.\n"
    answer_prompt = "" #"You are expected to answer the question based on the data in the table and the paragraph. Provide only the answer to the question and do not repeat the question. Use the answers provided as labels to learn the correct way to answer the question.\n"
    # derivation_prompt = "The derivation provides the steps to calculate the answer to the question. You can learn how to use the derivation to calculate the answer to the question.\n"
    # answer_type_prompt = "The answer type is the type of the answer to the question, and informs how answer was constructed.\n"
    # answer_from_prompt = "The answer from is the source of the answer to the question. It can be from table, paragraph, or both.\n"
    # facts_prompt = "The facts are the information that is used to answer the question. It can be from table, paragraph, or both.\n"
    # answer_mapping_prompt = "The answer mapping is the mapping of the answer to the question. It can be from table, paragraph, or both. If it is from the table, you will be given the index of the relevant values in the table. If it is from the paragraph, you will be given the start and end indexes of the relevant phrase in the paragraph. All indexes begin with 0.\n"
    # scale_prompt = "The scale is the scale of the numerical answer to the question. It can be from table, paragraph, or both.\n"
    answer_instruction_prompt = "\nInstruction: Answer the question based on the data in the table and the paragraph. Provide only the answer to the question and do not repeat the question. Use the answers provided as labels to learn the correct way to answer the question.\n"

    # table_idx = 1
    # for table in tables:
    table_prompt += f"\nTable:\n"
    for row in table["table"]:
        # table_prompt += "|"
        table_prompt += "|".join([str(cell) for cell in row]) + " \n"
        # table_idx += 1
    
    for paragraph in paragraphs:
        paragraph_prompt += f"\nParagraph {paragraph['order']}:\n"
        paragraph_prompt += paragraph['text'] + " "

    question_prompt += f"\nQuestion:\n {question}"
    answer_prompt += f"\nAnswer:\n {answer}"
    # derivation_prompt += f"\nDerivation:\n {derivation}"
    # answer_type_prompt += f"\nAnswer Type:\n {answer_type}"
    # answer_from_prompt += f"\nAnswer From:\n {answer_from}"
    # facts_prompt += f"\nFacts:\n {facts}"
    # answer_mapping_prompt += f"\nAnswer Mapping:\n {answer_mapping}"
    # scale_prompt += f"\nScale:\n {scale}"

    return set_context, table_prompt, paragraph_prompt, question_prompt, answer_prompt, answer_instruction_prompt #, derivation_prompt, answer_type_prompt, answer_from_prompt, facts_prompt, answer_mapping_prompt, scale_prompt

In [27]:
class TATQADataset(Dataset):
    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.qa_pairs = []

        for item in self.data:
        # item = self.data[idx]
            tables = item.get('table', [])
            paragraphs = item.get('paragraphs', [])
            questions = item.get('questions', [])

            for question_answer in questions:
                question = question_answer["question"].strip()
                answer = question_answer["answer"]
                # derivation = question_answer.get("derivation", "")
                # answer_type = question_answer.get("answer_type", "")
                # answer_from = question_answer.get("answer_from", "")
                # facts = question_answer.get("facts", "")
                # answer_mapping = question_answer.get("mapping", "")
                # scale = question_answer.get("scale", "")
                # rel_paragraphs = question_answer.get("rel_paragraphs", "")
                # req_comparison = question_answer.get("req_comparison", "")

                set_context, table_prompt, paragraph_prompt, question_prompt, answer_prompt, answer_instruction_prompt = create_prompt_instance(tables, paragraphs, question, answer)
                input_context = (set_context + table_prompt + paragraph_prompt + question_prompt + answer_instruction_prompt).strip()
                # label_text = (answer_prompt + derivation_prompt + answer_type_prompt + answer_from_prompt + facts_prompt + answer_mapping_prompt + scale_prompt).strip()
                label_text = answer_prompt.strip()

                # qa_pairs.append({
                #     "input_ids": inputs.input_ids.squeeze(),
                #     "attention_mask": inputs.attention_mask.squeeze(),
                #     "labels": labels
                # })

                self.qa_pairs.append((
                    input_context,
                    label_text
                ))
        # print(self.qa_pairs[-1])

    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        input_context, label_text = self.qa_pairs[idx]

        inputs = self.tokenizer( 
                    input_context, 
                    # max_length=self.max_length, 
                    # truncation=True, 
                    # padding="max_length", 
                    return_tensors="pt"
                )

        labels = self.tokenizer(
                    label_text, 
                    # max_length=self.max_length, 
                    # truncation=True, 
                    # padding="max_length", 
                    return_tensors="pt"
                )

        input_ids = inputs["input_ids"].squeeze(0)
        attention_mask = inputs["attention_mask"].squeeze(0)
        labels = labels["input_ids"].squeeze(0)


        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
            # "input_context": input_context,
            # "label_text": label_text
        }

In [5]:
from transformers import AutoModelForCausalLM, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
model_name = "microsoft/Phi-3.5-mini-instruct"
cache_dir = "/cs/student/projects1/aibh/2024/tpatil/.cache/huggingface"
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)

# tokenizer = AutoTokenizer.from_pretrained("gpt2")  # Replace with appropriate model
# model = AutoModelForCausalLM.from_pretrained("gpt2")

In [None]:
dataset = TATQADataset(train_data, tokenizer)
dataloader = DataLoader(
    dataset,
    batch_size=1, 
    shuffle=True
    )
# next(iter(dataloader))

{'input_ids': tensor([[  887,   526,   385, 13052,   296, 18161,   848,  3483,   858, 29889,
            887,   526,  2183,   263,  1591,   411, 18161,   848, 29889,   887,
            526,   884,  2183,   263, 14880,   393,  8128,   777,  3030,  1048,
            278,   848,   297,   278,  1591, 29889,   887,   526,  4433,   263,
           1139,  1048,   278,   848,   297,   278,  1591,   470, 14880, 29889,
            887,   526,  3806,   304,  1234,   278,  1139,  2729,   373,   278,
            848,   297,   278,  1591,   322,   278, 14880, 29889,    13,  1576,
           1591,  3867,   278, 18161,   848, 29889,  2178,   278,  3161,   297,
            278,  1591,   526, 13055,   491,   376, 29989,  1642,   450,   937,
           1948,   310,   278,  1591,  3743,   278,  1897,  2983, 29889,   512,
            278,  1494,  4206, 29892,   278,   937,  1897,  3743,   278,  1948,
           1024,   322,   278,  1791,   310,   278,  3161,   526,   278,  1819,
            297,   278,  19