Copyright 2024 The Chain-of-Table authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# Demo of Chain-of-Table

Paper: https://arxiv.org/abs/2401.04398

In [1]:
import pandas as pd

from utils.load_data import wrap_input_for_demo
from utils.llm import ChatGPT
from utils.helper import *
from utils.evaluate import *
from utils.chain import *
from operations import *
from utils.load_data import load_tabfact_dataset
import os

In [2]:
# User parameters
# Name of the chat model handed to the ChatGPT wrapper below.
model_name: str = "gpt-4o-mini"
# Read the key from the environment so it never lands in the notebook;
# this raises KeyError if OPENAI_API_KEY is not set.
openai_api_key: str = os.environ["OPENAI_API_KEY"]

# Default parameters
# Both paths are relative, so the notebook must be run from a directory that
# contains data/tabfact/.
dataset_path: str = "data/tabfact/test.jsonl"
raw2clean_path: str = "data/tabfact/raw2clean.jsonl"

In [3]:
# Load the samples from the two JSONL files configured above; `dataset` is
# indexed by position in the following cell.
dataset = load_tabfact_dataset(dataset_path, raw2clean_path)

Loading tabfact-test dataset: 100%|██████████| 2024/2024 [00:00<00:00, 53015.53it/s]


In [21]:
# Pick one example from the loaded dataset to walk through in this demo.
test_sample = dataset[75]
# Keep the untouched table (header row first, then data rows) for the final report.
table_text = test_sample['table_text']
# Ground truth used in the final report: a label of 1 means the statement is
# true, any other value is treated as false. The comparison already yields a
# bool, so no conditional expression is needed.
answer = test_sample['label'] == 1
test_sample

{'statement': 'the geo id for johnstown is 3810340820',
 'label': 0,
 'table_caption': 'list of townships in north dakota',
 'table_text': [['township',
   'county',
   'pop (2010)',
   'land ( sqmi )',
   'water (sqmi)',
   'latitude',
   'longitude',
   'geo id',
   'ansi code'],
  ['jackson',
   'sargent',
   '33',
   '35.809',
   '0.0',
   '46.066276',
   '- 97.945530',
   '3808140460',
   '1036797'],
  ['james hill',
   'mountrail',
   '32',
   '31.82',
   '4.243',
   '48.423125',
   '- 102.429934',
   '3806140500',
   '1037048'],
  ['james river valley',
   'dickey',
   '40',
   '28.597',
   '0.0',
   '46.246641',
   '- 98.188329',
   '3802140540',
   '1036767'],
  ['janke',
   'logan',
   '28',
   '35.995',
   '0.163',
   '46.415512',
   '- 99.131701',
   '3804740620',
   '1037193'],
  ['jefferson',
   'pierce',
   '45',
   '35.069',
   '1.125',
   '48.232149',
   '- 100.182370',
   '3806940700',
   '1759556'],
  ['jim river valley',
   'stutsman',
   '38',
   '34.134',
   '1.74

In [22]:
# Build the LLM client from the model name and API key set in the
# configuration cell; it is passed to the chain executor in the next cell.
gpt_llm = ChatGPT(
    model_name=model_name,
    key=openai_api_key,
)

In [23]:
# Run the dynamic operation chain on the selected sample. The call returns the
# processed sample plus a log of the chain; only `proc_sample` is used by the
# following cell. With debug=True the helper prints its intermediate state
# (act chain, candidate next operations, ...) as shown in the cell output.
proc_sample, dynamic_chain_log = dynamic_chain_exec_one_sample(
    sample=test_sample, llm=gpt_llm, debug=True
)


Act Chain:  []
Kept Act Chain:  []
Skip Act Chain:  []
Last Operation:  <init>
Possible Next Operations:  ['add_column', 'select_row', 'select_column', 'group_column', 'sort_column']
possible_statement_interpretations=['The geo id for the township of johnstown is 3810340820.'] explanation="We need to check the 'geo id' value for the row where the 'township' is 'johnstown'." operationchain=['select_row', 'select_column', 'END']
select_row
Act Chain:  ['f_select_row(row 1)']
Kept Act Chain:  ['f_select_row(row 1)']
Skip Act Chain:  []
Last Operation:  select_row
Possible Next Operations:  ['select_column', 'group_column', 'sort_column', '<END>']
possible_statement_interpretations=['The geo id for the township named Johnstown is 3810340820'] explanation="The row containing Johnstown has been selected, so the next step is to select the 'geo id' column to verify if it matches the given value." operationchain=['select_column']
select_column
Act Chain:  ['f_select_row(row 1)', 'f_select_colum

In [24]:
# Ask the final question against the processed sample. The result is read
# below via its "statement" and "table_caption" keys.
output_sample = simple_query(
    sample=proc_sample,
    table_info=get_table_info(proc_sample),
    debug=True

)
# Per-step log of the chain; the report cell reads "act_chain", "table_text",
# "cotable_result" and (optionally) "group_sub_table" from each entry.
cotable_log = get_table_log(output_sample)

explanation='The table shows that the geo id for johnstown is 3803540940, which is different from the claimed 3810340820.' answer=<Answers.FALSE: 'FALSE'>
Table: table caption : list of townships in north dakota
col : township | geo id
row 1 : johnstown | 3803540940
Statement: the geo id for johnstown is 3810340820
Answer: NO
Explanation: The table shows that the geo id for johnstown is 3803540940, which is different from the claimed 3810340820.


In [25]:
# Report the full reasoning trace for the sample: the statement, the original
# table, each kept operation with the table it produced, and the final verdict
# next to the ground truth.
print(f'Statements: {output_sample["statement"]}\n')
print(f'Table: {output_sample["table_caption"]}')
# `table_text` is the original table from the sample-selection cell
# (first row is the header, remaining rows are data).
print(f"{pd.DataFrame(table_text[1:], columns=table_text[0])}\n")
for table_info in cotable_log:
    # Entries with an empty act chain carry no operation to show.
    if table_info["act_chain"]:
        # Use a loop-local name for the intermediate table. Reassigning the
        # notebook-level `table_text` here would make a second run of this
        # cell print the last intermediate table as the "original" one.
        step_table = table_info["table_text"]
        # Only the most recent action of the chain is reported for each entry.
        table_action = table_info["act_chain"][-1]
        if "skip" in table_action:
            continue
        if "query" in table_action:
            # The query step carries the verdict; "YES" means the statement holds.
            result = table_info["cotable_result"]
            if result == "YES":
                print(f"-> {table_action}\nThe statement is True\n")
            else:
                print(f"-> {table_action}\nThe statement is False\n")
        else:
            print(f"-> {table_action}\n{pd.DataFrame(step_table[1:], columns=step_table[0])}")
            if 'group_sub_table' in table_info:
                # A grouping step also stores (column name, [(value, count), ...]).
                group_column, group_info = table_info["group_sub_table"]
                group_headers = ["Group ID", group_column, "Count"]
                group_rows = []
                for i, (v, count) in enumerate(group_info):
                    # Make blank cells visible in the printed group table.
                    if v.strip() == "":
                        v = "[Empty Cell]"
                    group_rows.append([f"Group {i+1}", v, str(count)])
                print(f"{pd.DataFrame(group_rows, columns=group_headers)}")
            print()

print(f"We Answered With: {cotable_log[-1]['cotable_result']}")
print(f"Groundtruth: The statement is {answer}")

Statements: the geo id for johnstown is 3810340820

Table: list of townships in north dakota
             township       county pop (2010) land ( sqmi ) water (sqmi)  \
0             jackson      sargent         33        35.809          0.0   
1          james hill    mountrail         32         31.82        4.243   
2  james river valley       dickey         40        28.597          0.0   
3               janke        logan         28        35.995        0.163   
4           jefferson       pierce         45        35.069        1.125   
5    jim river valley     stutsman         38        34.134        1.746   
6             johnson        wells         36        35.299        0.908   
7           johnstown  grand forks         79        36.199          0.0   
8            joliette      pembina         67        70.044        0.771   

    latitude     longitude      geo id ansi code  
0  46.066276   - 97.945530  3808140460   1036797  
1  48.423125  - 102.429934  3806140500   103