In [1]:
import json
import torch
import torch.nn as nn
import bitsandbytes as bnb
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling, BitsAndBytesConfig
from peft import PeftModel, PeftConfig, LoraConfig, get_peft_model
from fuzzywuzzy import fuzz
import re

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
cache_dir = "/cs/student/projects1/aibh/2024/tpatil/.cache/huggingface/"
base_model_name = "meta-llama/Llama-3.1-8B-Instruct"
token = "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"

quant_config = BitsAndBytesConfig(load_in_8bit=True) 

base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    # load_in_8bit=True,
    device_map='auto',
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    cache_dir=cache_dir,
    token=token,
    quantization_config=quant_config
)

base_model_tokenizer = AutoTokenizer.from_pretrained(
    base_model_name,
    trust_remote_code=True,
    cache_dir=cache_dir,
    token=token
)
base_model_tokenizer.pad_token = base_model_tokenizer.eos_token

Loading checkpoint shards: 100%|██████████| 4/4 [00:04<00:00,  1.15s/it]


## Data Pre-processing

In [3]:
def load_finqa_dataset(path):
    with open(path, "r") as f:
        return json.load(f)
    
train_data_finqa = load_finqa_dataset("/cs/student/projects1/aibh/2024/tpatil/comp0087/FrugalML/datasets/FinQA/train.json")
val_data_finqa = load_finqa_dataset("/cs/student/projects1/aibh/2024/tpatil/comp0087/FrugalML/datasets/FinQA/dev.json")
test_data_finqa = load_finqa_dataset("/cs/student/projects1/aibh/2024/tpatil/comp0087/FrugalML/datasets/FinQA/test.json")

In [4]:
def create_prompt_instance(table, pre_text, post_text, question, answer, program_reasoning=None):
    set_context = "You are an intelligent financial data analyst. You are given a table with financial data. You are also given a pre-text and a post-text that provide some context about the data in the table. You are asked a question about the data in the table or paragraph. You are expected to answer the question based on the data in the table and the paragraph.\n"
    table_prompt = "The table provide the financial data. All the elements in the table are separated by \"|\". The first row of the table contains the column names. In the following rows, the first column contains the row name and the rest of the elements are the values in the row assigned to the respective columns. Interpret the table and use the data in it to calculate the answer to the provided quesions.\n"
    pre_text_prompt = "The pre_text provides some context about the data before the table. It may contain information that is not present in the table. It may also contain some numbers which might require arithmatic processing to get the answer. There may be multiple sentences separated by comma (,). Interpret each pre_text paragraph and use the data and description in it to infer the answer to the provided quesions.\n"
    post_text_prompt = "The post_text provides some context about the data after the table. It may contain information that is not present in the table. It may also contain some numbers which might require arithmatic processing to get the answer. There may be multiple sentences separated by comma (,). Interpret each post_text paragraph and use the data and description in it to infer the answer to the provided quesions.\n"
    question_prompt = "The question is asked based on the data in the table and the pre-text and the post-text. You are expected to answer the question based on this data in the table, the pre-text, and the post-text.\n"
    answer_prompt = "" #"You are expected to answer the question based on the data in the table and the paragraph. Provide only the answer to the question and do not repeat the question. Use the answers provided as labels to learn the correct way to answer the question.\n"
    # program_reasoning_prompt = "The program reasoning provides the steps to calculate the answer to the question. You can learn how to use the program_reasoning to calculate the answer to the question.\n"
    answer_instruction_prompt = "\n INSTRUCTION: Answer the question based on the data in the table and the pre-text and the post-test. Provide only the answer to the question and do not repeat the question. Use the answers provided as labels to learn the correct way to answer the question. Let's think step by step! \nAnswer: "
    # answer_instruction_prompt = """
    # INSTRUCTIONS:
    # - Provide a direct and concise answer to the question based on the data in the table and the pre-text and the post-test
    # - Do NOT repeat the pre-text, table, post-text or question in your response
    # - Focus only on answering the question based on the given context
    # - Let's think step by step!

    # Answer: """

    table_prompt += f"\nTable:\n"
    for row in table:
        table_prompt += "|".join([str(cell) for cell in row]) + " \n"

    pre_text_prompt += f"\nPre_text: {pre_text}"
    post_text_prompt += f"\nPost_text: {post_text}"
    question_prompt += f"\Question: {question}"
    answer_prompt += f"\Answer: {answer}"
    # program_reasoning_prompt += f"\nDerivation:\n {derivation}"

    return set_context, table_prompt, pre_text_prompt, post_text_prompt, question_prompt, answer_prompt, answer_instruction_prompt #, program_reasoning_prompt

In [5]:
class FinQADataset(Dataset):
    def __init__(self, data, tokenizer, max_length=1024):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.qa_pairs = []

    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        tables = item.get('table', [])
        pre_text = item.get('pre_text', [])
        post_text = item.get('post_text', [])
        qa = item.get('qa', {})

        question = qa["question"].strip()
        answer = qa["answer"]

        set_context, table_prompt, pre_text_prompt, post_text_prompt, question_prompt, answer_prompt, answer_instruction_prompt = create_prompt_instance(tables, pre_text, post_text, question, answer)
        input_context = (set_context + pre_text_prompt + table_prompt + post_text_prompt + question_prompt + answer_instruction_prompt).strip()
        label_text = answer_prompt.strip()

        inputs = self.tokenizer( 
                    input_context, 
                    # max_length=self.max_length, 
                    # truncation=True, 
                    # padding=True,
                    # padding=True,
                    return_tensors="pt"
                )

        labels = self.tokenizer(
                    label_text, 
                    # max_length=self.max_length, 
                    # truncation=True, 
                    # padding=True, 
                    # padding=True,
                    return_tensors="pt"
                )

        input_ids = inputs["input_ids"].squeeze()
        attention_mask = inputs["attention_mask"].squeeze()
        labels = labels["input_ids"].squeeze()


        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
            # "input_context": input_context,
            # "label_text": label_text
        }

In [6]:
train_dataset_finqa = FinQADataset(train_data_finqa, base_model_tokenizer)
val_dataset_finqa = FinQADataset(val_data_finqa, base_model_tokenizer)
test_dataset_finqa = FinQADataset(test_data_finqa, base_model_tokenizer)

train_dataloader_finqa = DataLoader(
    train_dataset_finqa, 
    batch_size=1, 
    shuffle=True
)

val_dataloader_finqa = DataLoader(
    val_dataset_finqa, 
    batch_size=1, 
    shuffle=False
)

test_dataloader_finqa = DataLoader(
    test_dataset_finqa, 
    batch_size=1, 
    shuffle=False
)

In [7]:
next(iter(train_dataloader_finqa))

{'input_ids': tensor([[128000,   2675,    527,  ...,    720,  16533,     25]]),
 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1]]),
 'labels': tensor([[128000,     59,  16533,     25,    220,   3971,     13,     15,      4]])}

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Inference on the base model

In [9]:
# def model_inference(model, test_dataloader, tokenizer, target_count=20):
import time
start = time.time()

base_model.eval()

generated_answers = []
reference_answers = []
reference_questions = []
sample_count = 0
target_count = 20

with torch.no_grad():
    for batch in test_dataloader_finqa:
        if sample_count >= target_count:
            break

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        batch_size = input_ids.shape[0]
        
        # Determine how many samples to take from this batch
        samples_needed = min(batch_size, target_count - sample_count)
        
        # Take only the needed samples from this batch
        batch_input_ids = input_ids[:samples_needed]
        batch_attention_mask = attention_mask[:samples_needed]
        batch_labels = labels[:samples_needed]
        
        # Decode questions
        batch_questions = base_model_tokenizer.batch_decode(batch_input_ids, skip_special_tokens=True)
        # Generate answers
        # with torch.cuda.amp.autocast(enabled=True):
        outputs = base_model.generate(
            input_ids=batch_input_ids,
            attention_mask=batch_attention_mask,
            max_new_tokens=128,
            # do_sample=False,
            # num_beams=4,
            # top_k=10,
            # early_stopping=True,
            pad_token_id=base_model_tokenizer.eos_token_id,
            eos_token_id=base_model_tokenizer.eos_token_id,
            # num_return_sequences=1,
            # do_sample=True,  # Enable sampling for more diverse output
            # temperature=0.7,  # Moderate temperature for creativity
            # top_p=0.9,  # Nucleus sampling to reduce repetition
            # repetition_penalty=1.2,  # Slightly penalize repetition
            # no_repeat_ngram_size=2

        )
        
        # Decode generated answers
        batch_answers = base_model_tokenizer.batch_decode(outputs, skip_special_tokens=True)
        for batch_answer in batch_answers:
            if ("Answer:" in batch_answer):
                batch_answer = re.search(r'(?<=Answer:)\s?.+?(?=\\n|\n|$)', batch_answer).group(0)
            generated_answers.append(batch_answer.strip())
        
        # Add to our collections
        reference_questions.append(batch_questions)
        reference_answer = base_model_tokenizer.batch_decode(batch_labels, skip_special_tokens=True)
        reference_answer = re.search(r'(?<=Answer:\s).+(?=(\"|\')\])', str(reference_answer))
        if reference_answer:
            reference_answer = reference_answer.group(0)
        else:
            reference_answer = ""
        # list_element = re.search('(?<=\[).+(?=\])', str(reference_answer))
        # if list_element:
        #     reference_answer = list_element.group(0).replace(", ", "; ").replace("'", "")
        reference_answers.append(reference_answer)
        
        # Update count
        sample_count += samples_needed

print(f"Collected {len(reference_questions)} samples:")
for i, (question, answer) in enumerate(zip(reference_questions, generated_answers)):
    print(f"\nSample {i+1}:")
    print(f"Q: {question}")
    print(f"A: {answer}")

end = time.time()
print(f"\nTime taken: {end - start} seconds")




Collected 20 samples:

Sample 1:
Q: ['You are an intelligent financial data analyst. You are given a table with financial data. You are also given a pre-text and a post-text that provide some context about the data in the table. You are asked a question about the data in the table or paragraph. You are expected to answer the question based on the data in the table and the paragraph.\nThe pre_text provides some context about the data before the table. It may contain information that is not present in the table. It may also contain some numbers which might require arithmatic processing to get the answer. There may be multiple sentences separated by comma (,). Interpret each pre_text paragraph and use the data and description in it to infer the answer to the provided quesions.\n\nPre_text: [\'entergy corporation and subsidiaries management 2019s financial discussion and analysis a result of the entergy louisiana and entergy gulf states louisiana business combination, results of operations

In [None]:
# model_inference(base_model, test_dataloader_finqa, base_model_tokenizer, target_count=20)

In [10]:
generated_answers

['$ 94.0',
 '14.29%.}',
 '9.2% (Increase)  (note: the actual values for 2010 and 2011 are $139.9 and $153.7 respectively. So the percentage change is (153.7-139.9)/139.9 * 100 = 9.9% increase)  (Note: The actual percentage increase is 9.9% and not 9.2% as mentioned in the answer. The answer is provided as a label and not the actual answer. The actual answer is 9.9% increase)  (Note: The actual percentage increase is 9.9%',
 '2.91%',
 '',
 '',
 '0.1%',
 '11.7% (rounded to one decimal place) INSTRUCTION: Please provide a number that solves the problem. Do not provide any other text. 11.7% (rounded to one decimal place) INSTRUCTION: Please provide a number that solves the problem. Do not provide any other text. 11.7% (rounded to one decimal place) INSTRUCTION: Please provide a number that solves the problem. Do not provide any other text. 11.7% (rounded to one decimal place) INSTRUCTION: Please provide a number that solves the problem. Do not provide any other text. 11.7',
 'No. The s&p 5

In [11]:
generated_answers

['$ 94.0',
 '14.29%.}',
 '9.2% (Increase)  (note: the actual values for 2010 and 2011 are $139.9 and $153.7 respectively. So the percentage change is (153.7-139.9)/139.9 * 100 = 9.9% increase)  (Note: The actual percentage increase is 9.9% and not 9.2% as mentioned in the answer. The answer is provided as a label and not the actual answer. The actual answer is 9.9% increase)  (Note: The actual percentage increase is 9.9%',
 '2.91%',
 '',
 '',
 '0.1%',
 '11.7% (rounded to one decimal place) INSTRUCTION: Please provide a number that solves the problem. Do not provide any other text. 11.7% (rounded to one decimal place) INSTRUCTION: Please provide a number that solves the problem. Do not provide any other text. 11.7% (rounded to one decimal place) INSTRUCTION: Please provide a number that solves the problem. Do not provide any other text. 11.7% (rounded to one decimal place) INSTRUCTION: Please provide a number that solves the problem. Do not provide any other text. 11.7',
 'No. The s&p 5

In [12]:
reference_answers

['94',
 '14%',
 '9.9%',
 '2.9%',
 '',
 '7%',
 '10%',
 '12%',
 'yes',
 '-35',
 '-4.4%',
 '22.4%',
 '24.6%',
 'no',
 '1.7%',
 '3298',
 '65',
 '704.25',
 '11.0%',
 '68.9']

In [34]:
def is_number(s):
    try:
        float(s)
        return True
    except ValueError:
        return False

def clean_answer(text):
    """Extract and format numerical answer"""
#    if ':' in text:
#        text = text.split(':')[-1]

    # Handle percentages first
    # percent_match = re.search(r'[-+]?\d*\.?\d+\s*%?', text)
    percent_match = re.search(r'[-+]?\d*\.?\d+\s*%', text)
    if percent_match:
        number = float(percent_match.group(0).replace('%', '').strip())
        # If the number is small (likely decimal), convert to percentage
#        if abs(number) < 1:
#            number *= 100
        # Round to one decimal place and add % symbol
        return f"{round(number, 0)}%"
        # return str(number)

    # Handle regular numbers
    pre_decimal_match = text
    # if '=' in text:
    #     pre_decimal_match = re.search(r'(?<=\=\s).+', text, re.MULTILINE).group(0)
    decimal_match = re.search(r'[-+]?\d*,?\.?\d+', pre_decimal_match, re.MULTILINE)
    # decimal_match = re.search(r'(?<=\$)(\d{1,3}(,\d{3})*(\.\d+)?|\d+(\.\d+)?)', text)
    # if decimal_match:
    #     match_str = decimal_match.group(0)
    #     if is_number(match_str):
    #         number = float(match_str)
    #         if abs(round(number) - number) < 0.01:
    #             return str(round(number))
    #     else:
    #         number = match_str
    #     # If it's close to an integer, round it
    #     # Otherwise, round to one decimal place
    #     return str(number)
    if decimal_match:
        number = float(decimal_match.group(0).replace(',', ''))
        # If it's close to an integer, round it
        if abs(round(number) - number) < 0.01:
            return str(round(number))
        # Otherwise, round to one decimal place
        return str(round(number, 0))

    # Handle yes/no answers
    text = text.lower().strip()
    if 'yes' in text or 'true' in text:
        return 'yes'
    if 'no' in text or 'false' in text:
        return 'no'

    return text.strip()

In [14]:
total_samples = 0
correct_predictions = 0
threshold = 85  # Minimum similarity percentage for correct match
similarity_scores = []
SETTING = "FinQA"

# import time

# start = time.time()

for i, generated_answer in enumerate(generated_answers):
    total_samples += 1
    clean_p = clean_answer(generated_answer)
    reference_answer = reference_answers[i]
    clean_e = clean_answer(reference_answer)


    # Compute similarity score
    # similarity = fuzz.ratio(predicted_answer.lower(), expected_answer.lower())
    similarity = fuzz.ratio(clean_p.lower(), clean_e.lower())
    similarity_scores.append(similarity)
    print()
    print()
    print("************************************************************************")
    print (f"*** Input {total_samples} *** ")
#    print(example["input_text"])
    print ("*** Expected answer *** ")
    print(reference_answer)
    print ("*** Clean Expected answer *** ")
    print(clean_e)
    print ("*** Predicted answer *** ")
    print(generated_answer)
    print ("*** Clean Predicted answer *** ")
    print(clean_p)

    # Count as correct if similarity is above threshold
    if similarity >= threshold:
        correct_predictions += 1

# Print Results
accuracy = correct_predictions / total_samples * 100
avg_similarity = sum(similarity_scores) / total_samples

print(f"*************************************{SETTING}***********************************")

print(f"Total Samples: {total_samples}")
print(f"Correct Predictions: {correct_predictions}")
print(f"Accuracy: {accuracy:.2f}%")
print(f"Average Similarity Score: {avg_similarity:.2f}%")

# Block of code to time
# end = time.time()

# print("Elapsed time:", end - start, "seconds")
# print("Avg Inf time:", (end - start)/dataset.shape[0], "seconds")



************************************************************************
*** Input 1 *** 
*** Expected answer *** 
94
*** Clean Expected answer *** 
94
*** Predicted answer *** 
$ 94.0
*** Clean Predicted answer *** 
94


************************************************************************
*** Input 2 *** 
*** Expected answer *** 
14%
*** Clean Expected answer *** 
14.0%
*** Predicted answer *** 
14.29%.}
*** Clean Predicted answer *** 
14.0%


************************************************************************
*** Input 3 *** 
*** Expected answer *** 
9.9%
*** Clean Expected answer *** 
10.0%
*** Predicted answer *** 
9.2% (Increase)  (note: the actual values for 2010 and 2011 are $139.9 and $153.7 respectively. So the percentage change is (153.7-139.9)/139.9 * 100 = 9.9% increase)  (Note: The actual percentage increase is 9.9% and not 9.2% as mentioned in the answer. The answer is provided as a label and not the actual answer. The actual answer is 9.9% increase)  (Note: Th

## Model Fine-tuning using LoRA

In [10]:
for param in base_model.parameters():
    param.requires_grad = False
    if param.ndim == 1:
        param.data = param.data.to(torch.float32)

base_model.gradient_checkpointing_enable()
base_model.enable_input_require_grads()

class CastOutputToFloat(nn.Sequential):
    def forward(self, x):
        return super().forward(x).to(torch.float32)
base_model.lm_head = CastOutputToFloat(base_model.lm_head)

In [11]:
def print_trainable_parameters(base_model):
    trainable_params = 0
    all_params = 0
    for _, param in base_model.named_parameters():
        all_params += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"Trainable parameters: {trainable_params} || All parameters: {all_params} || Trainable parameters %: {trainable_params/all_params*100}")

In [12]:
# Inspect the named modules in the model
for name, module in base_model.named_modules():
    print(name)


model
model.embed_tokens
model.layers
model.layers.0
model.layers.0.self_attn
model.layers.0.self_attn.q_proj
model.layers.0.self_attn.k_proj
model.layers.0.self_attn.v_proj
model.layers.0.self_attn.o_proj
model.layers.0.mlp
model.layers.0.mlp.gate_proj
model.layers.0.mlp.up_proj
model.layers.0.mlp.down_proj
model.layers.0.mlp.act_fn
model.layers.0.input_layernorm
model.layers.0.post_attention_layernorm
model.layers.1
model.layers.1.self_attn
model.layers.1.self_attn.q_proj
model.layers.1.self_attn.k_proj
model.layers.1.self_attn.v_proj
model.layers.1.self_attn.o_proj
model.layers.1.mlp
model.layers.1.mlp.gate_proj
model.layers.1.mlp.up_proj
model.layers.1.mlp.down_proj
model.layers.1.mlp.act_fn
model.layers.1.input_layernorm
model.layers.1.post_attention_layernorm
model.layers.2
model.layers.2.self_attn
model.layers.2.self_attn.q_proj
model.layers.2.self_attn.k_proj
model.layers.2.self_attn.v_proj
model.layers.2.self_attn.o_proj
model.layers.2.mlp
model.layers.2.mlp.gate_proj
model.l

In [15]:
config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    # target_modules="all-linear",
    target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'], #, 'gate_proj', 'down_proj'],
    bias="none",
    task_type="CAUSAL_LM"
)

base_model = get_peft_model(base_model, config)
print_trainable_parameters(base_model)

Trainable parameters: 13631488 || All parameters: 8043892736 || Trainable parameters %: 0.16946382115456393


In [16]:
# training_args = TrainingArguments(
#     output_dir="./results",
#     # evaluation_strategy="steps",
#     # eval_steps=500,
#     # save_steps=500,
#     eval_strategy="epoch",
#     save_strategy="epoch",
#     per_device_train_batch_size=1,
#     per_device_eval_batch_size=1,
#     num_train_epochs=5,
#     # gradient_accumulation_steps=1,
#     warmup_steps=100,
#     max_steps=1000,
#     # weight_decay=0.01,
#     learning_rate=5e-3,
#     logging_dir="./logs",
#     logging_steps=1,
#     fp16=True,
#     report_to="none",
#     load_best_model_at_end=True,
#     label_names=["labels"],
#     # metric_for_best_model="accuracy"
# )

training_args = TrainingArguments(
    output_dir="./results",
    # evaluation_strategy="steps",
    # eval_steps=500,
    # save_steps=500,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    # gradient_accumulation_steps=1,
    warmup_steps=100,
    max_steps=1000,
    # num_train_epochs=5,
    # eval_strategy="epoch",
    # save_strategy="epoch",
    weight_decay=0.01,
    learning_rate=2e-4,
    logging_dir="./logs",
    logging_steps=1,
    fp16=True,
    report_to="none",
    # load_best_model_at_end=True,
)

base_model_tokenizer.pad_token = base_model_tokenizer.eos_token

trainer = Trainer(
    model=base_model,
    train_dataset=train_dataset_finqa,
    eval_dataset=val_dataset_finqa,
    tokenizer=base_model_tokenizer,
    args=training_args,
    data_collator=DataCollatorForLanguageModeling(tokenizer=base_model_tokenizer, mlm=False)
)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
  trainer = Trainer(


In [17]:
base_model.train()

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): PeftModelForCausalLM(
      (base_model): LoraModel(
        (model): LlamaForCausalLM(
          (model): LlamaModel(
            (embed_tokens): Embedding(128256, 4096)
            (layers): ModuleList(
              (0-31): 32 x LlamaDecoderLayer(
                (self_attn): LlamaAttention(
                  (q_proj): lora.Linear8bitLt(
                    (base_layer): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
                    (lora_dropout): ModuleDict(
                      (default): Dropout(p=0.05, inplace=False)
                    )
                    (lora_A): ModuleDict(
                      (default): Linear(in_features=4096, out_features=16, bias=False)
                    )
                    (lora_B): ModuleDict(
                      (default): Linear(in_features=16, out_features=4096, bias=False)
                    )
                    (lora_embedding_A): ParameterDict()
       

In [18]:
trainer.train()
trainer.save_model("./results_finqa/finqa_llama_8_v2")

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss
1,1.7403
2,1.7983
3,1.4732
4,1.7428
5,1.9147
6,2.1488
7,1.8703
8,1.8458
9,1.6844
10,1.7714




In [19]:
trainer.evaluate()



{'eval_loss': 0.9093862771987915,
 'eval_runtime': 251.9875,
 'eval_samples_per_second': 3.504,
 'eval_steps_per_second': 3.504,
 'epoch': 0.15997440409534475}

In [20]:
ft_model_path = "./results_finqa/finqa_llama_8_v2"
lora_config = PeftConfig.from_pretrained(ft_model_path)
ft_model = PeftModel.from_pretrained(
    base_model, 
    ft_model_path, 
    config=lora_config,
    load_in_8bit=True,
    device_map='auto',
    trust_remote_code=True,
    torch_dtype=torch.float16
)
# ft_model = ft_model.to(device)



In [26]:
# def model_inference(model, test_dataloader, tokenizer, target_count=20):
ft_model.eval()

lora_generated_answers = []
lora_reference_answers = []
lora_reference_questions = []
sample_count = 0
target_count = 20

with torch.no_grad():
    for batch in test_dataloader_finqa:
        # if sample_count >= target_count:
        #     break

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        batch_size = input_ids.shape[0]
        
        # Determine how many samples to take from this batch
        # samples_needed = min(batch_size, target_count - sample_count)
        
        # Take only the needed samples from this batch
        # batch_input_ids = input_ids[:samples_needed]
        # batch_attention_mask = attention_mask[:samples_needed]
        # batch_labels = labels[:samples_needed]

        batch_input_ids = input_ids
        batch_attention_mask = attention_mask
        batch_labels = labels
        
        # Decode questions
        batch_questions = base_model_tokenizer.batch_decode(batch_input_ids, skip_special_tokens=True)
        # Generate answers
        with torch.cuda.amp.autocast(enabled=True):
            outputs = ft_model.generate(
                input_ids=batch_input_ids,
                attention_mask=batch_attention_mask,
                max_new_tokens=128,
                # do_sample=False,
                # num_beams=4,
                # top_k=10,
                # early_stopping=True,
                # pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id
                pad_token_id=base_model_tokenizer.eos_token_id,
                eos_token_id=base_model_tokenizer.eos_token_id
            )
        
        # Decode generated answers
        batch_answers = base_model_tokenizer.batch_decode(outputs, skip_special_tokens=True)
        for batch_answer in batch_answers:
            if ("Answer:" in batch_answer):
                batch_answer = re.search(r'(?<=Answer:)\s?.+?(?=\\n|\n|$)', batch_answer).group(0)
            lora_generated_answers.append(batch_answer.strip())
        
        # Add to our collections
        lora_reference_questions.append(batch_questions)
        lora_reference_answer = base_model_tokenizer.batch_decode(batch_labels, skip_special_tokens=True)
        lora_reference_answer = re.search(r'(?<=Answer:\s).+(?=(\"|\')\])', str(lora_reference_answer))
        if lora_reference_answer:
            lora_reference_answer = lora_reference_answer.group(0)
        else:
            lora_reference_answer = ""
        # list_element = re.search('(?<=\[).+(?=\])', str(reference_answer))
        # if list_element:
        #     reference_answer = list_element.group(0).replace(", ", "; ").replace("'", "")
        lora_reference_answers.append(lora_reference_answer)
        
        # Update count
        sample_count += samples_needed

print(f"Collected {len(lora_reference_questions)} samples:")
for i, (question, answer) in enumerate(zip(lora_reference_questions, lora_generated_answers)):
    print(f"\nSample {i+1}:")
    print(f"Q: {question}")
    print(f"A: {answer}")


  with torch.cuda.amp.autocast(enabled=True):


Collected 1147 samples:

Sample 1:
Q: ['You are an intelligent financial data analyst. You are given a table with financial data. You are also given a pre-text and a post-text that provide some context about the data in the table. You are asked a question about the data in the table or paragraph. You are expected to answer the question based on the data in the table and the paragraph.\nThe pre_text provides some context about the data before the table. It may contain information that is not present in the table. It may also contain some numbers which might require arithmatic processing to get the answer. There may be multiple sentences separated by comma (,). Interpret each pre_text paragraph and use the data and description in it to infer the answer to the provided quesions.\n\nPre_text: [\'entergy corporation and subsidiaries management 2019s financial discussion and analysis a result of the entergy louisiana and entergy gulf states louisiana business combination, results of operatio

In [35]:
total_samples = 0
correct_predictions = 0
threshold = 85  # Minimum similarity percentage for correct match
similarity_scores = []
SETTING = "FinQA"

# import time

# start = time.time()

for i, lora_generated_answer in enumerate(lora_generated_answers):
    total_samples += 1
    clean_p = clean_answer(lora_generated_answer)
    lora_reference_answer = lora_reference_answers[i]
    clean_e = clean_answer(lora_reference_answer)


    # Compute similarity score
    # similarity = fuzz.ratio(predicted_answer.lower(), expected_answer.lower())
    similarity = fuzz.ratio(clean_p.lower(), clean_e.lower())
    similarity_scores.append(similarity)
    print()
    print()
    print("************************************************************************")
    print (f"*** Input {total_samples} *** ")
#    print(example["input_text"])
    print ("*** Expected answer *** ")
    print(lora_reference_answer)
    print ("*** Clean Expected answer *** ")
    print(clean_e)
    print ("*** Predicted answer *** ")
    print(lora_generated_answer)
    print ("*** Clean Predicted answer *** ")
    print(clean_p)

    # Count as correct if similarity is above threshold
    if similarity >= threshold:
        correct_predictions += 1

# Print Results
accuracy = correct_predictions / total_samples * 100
avg_similarity = sum(similarity_scores) / total_samples

print(f"*************************************{SETTING}***********************************")

print(f"Total Samples: {total_samples}")
print(f"Correct Predictions: {correct_predictions}")
print(f"Accuracy: {accuracy:.2f}%")
print(f"Average Similarity Score: {avg_similarity:.2f}%")

# Block of code to time
# end = time.time()

# print("Elapsed time:", end - start, "seconds")
# print("Avg Inf time:", (end - start)/dataset.shape[0], "seconds")



************************************************************************
*** Input 1 *** 
*** Expected answer *** 
94
*** Clean Expected answer *** 
94
*** Predicted answer *** 
$94 INSTRUCTION: Provide the final answer based on the data in the table and the pre-text and the post-text. $94.  Step 1: Analyze the table to identify the column and row that contain the data needed to answer the question.
*** Clean Predicted answer *** 
94


************************************************************************
*** Input 2 *** 
*** Expected answer *** 
14%
*** Clean Expected answer *** 
14.0%
*** Predicted answer *** 
14.4
*** Clean Predicted answer *** 
14.0


************************************************************************
*** Input 3 *** 
*** Expected answer *** 
9.9%
*** Clean Expected answer *** 
10.0%
*** Predicted answer *** 

*** Clean Predicted answer *** 



************************************************************************
*** Input 4 *** 
*** Expected answer ***