In [1]:
import pandas as pd
import datasets
from bs4 import BeautifulSoup
from io import StringIO
import csv
import torch
import difflib
from transformers import (
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    AutoTokenizer,
    GenerationConfig,
    TextStreamer,
    pipeline,
)


from tqdm import tqdm
import pandas as pd
import re

# Load the evaluation set. Later cells read the 'html_content', 'question'
# and 'answer' columns of this frame.
data = pd.read_csv('../data/pseudo_json_dataset_V3_2048.csv')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Hub identifier of the checkpoint; the same id is reused below to load the
# model weights and the generation config so all three stay in sync.
model_id = 'RUCKBReasoning/TableLLM-7b'
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [3]:
# Load the checkpoint in half precision.
model = AutoModelForCausalLM.from_pretrained(
    model_id, device_map="auto", torch_dtype=torch.float16,
)
generation_config = GenerationConfig.from_pretrained(model_id)
print(generation_config)


# Upper bound on the number of tokens generated per prompt.
generation_config.max_new_tokens = 512
# Use real greedy decoding instead of emulating it with do_sample=True and a
# near-zero temperature: sampling without a seed is not reproducible, and
# dividing float16 logits by 1e-4 can overflow to inf/nan. The temperature is
# left at the checkpoint's value because it is ignored when do_sample is False.
generation_config.do_sample = False
# Streams decoded tokens to stdout, without echoing the prompt, when passed
# to a generate call.
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

Loading checkpoint shards: 100%|██████████| 3/3 [00:05<00:00,  1.69s/it]


GenerationConfig {
  "bos_token_id": 1,
  "eos_token_id": 2
}



In [4]:
# Text-generation pipeline around the already-loaded model and tokenizer.
# return_full_text=True keeps the prompt in the returned text, which is why
# the extract helpers below look for the "[/INST]" marker to find the answer.
# The EOS token id doubles as the padding token id.
llm = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    generation_config=generation_config,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
    num_return_sequences=1,
    return_full_text=True,
)


In [5]:
def extract(ip):
    """Return the text that follows the first "[/INST]" marker in ``ip``.

    Parameters
    ----------
    ip : str
        Text containing the prompt and the model's completion.

    Returns
    -------
    str or None
        Everything after the first "[/INST]" with leading and trailing
        whitespace stripped. When the marker is absent a notice is printed
        and None is returned.
    """
    marker = "[/INST]"
    _, found, tail = ip.partition(marker)
    if not found:
        print("No value found after '[/INST]'")
        return None
    # The marker already ends with "]", so nothing more must be skipped;
    # skipping one extra character would eat the first character of an
    # answer that directly follows the marker. strip() handles whitespace.
    return tail.strip()

def extract_second(ip):
    """Return the text that follows the second "[/INST]" marker in ``ip``.

    Intended for prompts that contain one worked example, where the first
    marker closes the example and the second closes the real question.

    Parameters
    ----------
    ip : str
        Text containing the prompt and the model's completion.

    Returns
    -------
    str or None
        Everything after the second "[/INST]" with leading and trailing
        whitespace stripped. When fewer than two markers are present a
        notice is printed and None is returned.
    """
    marker = "[/INST]"
    first = ip.find(marker)
    # Only look for a second marker once a first one exists, and start the
    # search right after it so the same occurrence is not matched twice.
    second = ip.find(marker, first + len(marker)) if first != -1 else -1
    if second == -1:
        print("Two occurrences of '[/INST]' not found")
        return None
    # len(marker) already covers the closing "]"; skipping one more character
    # would drop the first character of an answer that follows immediately.
    return ip[second + len(marker):].strip()


In [6]:
# Empty results table; the inference loop fills in one row per evaluated
# example, keyed by the example's index in `data`.
output_df = pd.DataFrame(
    columns=['csv', 'question', 'answer', 'prediction'],
)

In [7]:
def prompt_formatter(csv_data, question):
    """Build the zero-shot ``[INST] ... [/INST]`` prompt for one table question.

    Parameters
    ----------
    csv_data : str
        The table serialised as CSV text; inserted under ``### [Table]``.
    question : str
        The question to answer; inserted under ``### [Question]``.

    Returns
    -------
    str
        The prompt, ending with the ``[/INST]`` marker so that the model's
        completion starts right after it.
    """
    # NOTE(review): the sentence "You have to only give extract the exact
    # answer" reads as garbled, and every line after the first carries the
    # source indentation into the prompt. Both are kept verbatim because any
    # change to the prompt text changes what the model sees.
    template = """[INST]Offer a thorough and accurate solution that directly addresses the Question outlined in the [Question]. You have to only give extract the exact answer to the user.

    ### [Table]
    {csv_data}

    ### [Question]
    {question}

    ### [Solution]
    [/INST]"""
    return template.format(csv_data=csv_data, question=question)

In [8]:
def twoshot_prompt_formatter(csv_data, question):
    """Build a few-shot ``[INST] ... [/INST]`` prompt for one table question.

    Despite the name, the template embeds a single worked example (a small
    financial table with its answer) followed by the real table and question.
    The returned text therefore contains two ``[/INST]`` markers, and the
    model's completion starts after the second one.

    Parameters
    ----------
    csv_data : str
        The table serialised as CSV text; inserted under the second
        ``### [Table]`` heading.
    question : str
        The question to answer; inserted under the second
        ``### [Question]`` heading.

    Returns
    -------
    str
        The complete prompt, ending with the second ``[/INST]`` marker.
    """
    # The template text (including its leading indentation) is part of the
    # model input and is kept verbatim.
    template = """[INST]Offer a thorough and accurate solution that directly addresses the Question outlined in the [Question].

    ### [Table]
    1,-,2002,2001,2000
    2,net sales,$ 5742,$ 5363,$ 7983
    3,cost of sales,4139,4128,5817
    4,gross margin,$ 1603,$ 1235,$ 2166
    5,gross margin percentage,28% ( 28 % ),23% ( 23 % ),27% ( 27 % )


    ### [Question]
    what is the gross margin in 2002?

    ### [Solution]
    [/INST] $1603.

    [INST]Offer a thorough and accurate solution that directly addresses the Question outlined in the [Question].

    ### [Table]
    {csv_data}

    ### [Question]
    {question}

    ### [Solution]
    [/INST]"""
    return template.format(csv_data=csv_data, question=question)

In [9]:
# Run the few-shot prompt over a slice of the dataset and record, per example,
# the CSV-serialised table, the question, the reference answer and the
# model's prediction in `output_df`.
eval_rows = data[10:50]
# iterrows() is a generator without a length, so tqdm is told the total
# explicitly; otherwise it can only print a bare iteration counter.
for i, row in tqdm(eval_rows.iterrows(), total=len(eval_rows)):
    soup = BeautifulSoup(row['html_content'], 'html.parser')

    # Only the first <table> of the HTML snippet is used.
    table = soup.find('table')
    if table is None:
        # find() returns None when nothing matches; skip the example instead
        # of crashing part-way through a long inference run.
        print(f"{i}: no <table> found in html_content, skipping")
        continue

    # Serialise the table to CSV text in memory, one CSV row per <tr>, taking
    # the text of both header (<th>) and data (<td>) cells.
    # NOTE(review): csv.writer terminates rows with "\r\n" by default, so the
    # prompt contains carriage returns; left as is to keep the model input
    # unchanged.
    csv_buffer = StringIO()
    writer = csv.writer(csv_buffer)
    for trow in table.find_all('tr'):
        writer.writerow([cell.get_text() for cell in trow.find_all(['td', 'th'])])
    csv_data = csv_buffer.getvalue()

    ip = twoshot_prompt_formatter(csv_data, row['question'])

    # The generated text still contains the prompt, hence extract_second()
    # below, which keeps only what follows the prompt's second "[/INST]".
    output = llm(ip)[0]['generated_text']
    print(i)
    print(row['answer'])

    pred = extract_second(output)
    print(pred)
    print('***')
    # Rows are keyed by the example's index in `data`.
    output_df.at[i, 'csv'] = csv_data
    output_df.at[i, 'question'] = row['question']
    output_df.at[i, 'answer'] = row['answer']
    output_df.at[i, 'prediction'] = pred

1it [02:07, 127.95s/it]

10
193.3
The table shows the global payments data for May 31, 2005. The gross margin for that date is $193.30. Therefore, the answer to the question is $193.30.
***


2it [04:15, 127.79s/it]

11
39
To find the total for 2010, we need to look at the "total (in $ millions)" column in the table. From the given table, we can see that the total for 2010 is $39 million.
***


3it [05:26, 101.88s/it]

12
$ -
The credit facilities [a] for amount of commitment expiration per period after 2014 is $1900.
***


4it [07:19, 106.12s/it]

13
3.4% ( 3.4 % )
To determine the core price for 2006, we need to refer to the table provided.

According to the table, the core price for 2006 is $2.4.
***


5it [08:05, 84.69s/it] 

14
['-86 ( 86 )' '$ 56.53' '$ 999343' '$ 926512']
The value for canceled in the table is $999343.
***


6it [09:04, 75.82s/it]

15
2.3
The weighted average dilutive effect of equity awards for 2018 is 2.3.
***


7it [10:03, 70.15s/it]

16
$311
According to the table, the interest and dividend income for 2010 is $311.
***


8it [11:11, 69.49s/it]

17
$24490
The cash cash equivalents and marketable securities for 2008 is $24490.
***


9it [14:52, 116.96s/it]

18
$77.15
To find the average price paid per share in October 2017, we can refer to the table provided. 

In the table, the average price paid per share is listed under the "october 2017" column. The value in this column is $77.15.

Therefore, the average price paid per share in October 2017 is $77.15.
***


: 

In [None]:
# Persist the collected predictions (without the index column). `output_df`
# is created and filled by earlier cells, so those must have been run in the
# current kernel session; otherwise this line raises NameError.
output_df.to_csv('one_shot_test_v1.csv', index=False)

NameError: name 'output_df' is not defined

: 

In [None]:

    
import os
 
# os.getpid() returns the id of the process executing this code, i.e. the
# notebook's Python kernel rather than the Jupyter server.
pid = os.getpid()
pid

11719

: 