Copyright 2024 The Chain-of-Table authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# Demo of Chain-of-Table

Paper: [Chain-of-Table: Evolving Tables in the Reasoning Chain for Table Understanding](https://arxiv.org/abs/2401.04398)

In [1]:
import pandas as pd

from utils.load_data import wrap_input_for_demo
from utils.llm import ChatGPT
from utils.helper import *
from utils.evaluate import *
from utils.chain import *
from operations import *
from utils.load_data import load_tabfact_dataset
import os

In [2]:
# User parameters
# Model name handed to the ChatGPT wrapper when the client is constructed.
model_name: str = "gpt-4o-mini"
# The API key is read from the environment so it is never stored in the
# notebook; this lookup raises KeyError if OPENAI_API_KEY is not set.
openai_api_key: str = os.environ["OPENAI_API_KEY"]

# Default parameters
# Both paths are relative, so they resolve against the kernel's working
# directory. NOTE(review): raw2clean presumably maps raw statements/tables to
# cleaned versions -- confirm against utils.load_data.
dataset_path: str = "data/tabfact/test.jsonl"
raw2clean_path: str = "data/tabfact/raw2clean.jsonl"

In [3]:
# Load the dataset from the two configured paths (the test-split file and the
# raw2clean file). The result is indexed by position when a sample is picked.
dataset = load_tabfact_dataset(dataset_path, raw2clean_path)

Loading tabfact-test dataset: 100%|██████████| 2024/2024 [00:00<00:00, 55996.35it/s]


In [9]:
# Pick one example (index 2) to walk through, keep its original table for
# display at the end, and turn the label into a boolean ground truth
# (label == 1 means the statement is true).
test_sample = dataset[2]
table_text = test_sample["table_text"]
answer = bool(test_sample["label"] == 1)
test_sample  # last expression: shown as the cell output

{'statement': 'the wildcats lost one game in september and two games in november',
 'label': 1,
 'table_caption': '1947 kentucky wildcats football team',
 'table_text': [['game',
   'date',
   'opponent',
   'result',
   'wildcats points',
   'opponents',
   'record'],
  ['1', 'sept 20', 'ole miss', 'loss', '7', '14', '0 - 1'],
  ['2', 'sept 27', 'cincinnati', 'win', '20', '0', '1 - 1'],
  ['3', 'oct 4', 'xavier', 'win', '20', '7', '2 - 1'],
  ['4', 'oct 11', '9 georgia', 'win', '26', '0', '3 - 1 , 20'],
  ['5', 'oct 18', '10 vanderbilt', 'win', '14', '0', '4 - 1 , 14'],
  ['6', 'oct 25', 'michigan state', 'win', '7', '6', '5 - 1 , 13'],
  ['7', 'nov 1', '18 alabama', 'loss', '0', '13', '5 - 2'],
  ['8', 'nov 8', 'west virginia', 'win', '15', '6', '6 - 2'],
  ['9', 'nov 15', 'evansville', 'win', '36', '0', '7 - 2'],
  ['10', 'nov 22', 'tennessee', 'loss', '6', '13', '7 - 3']],
 'table_id': '1-24560733-1.html.csv',
 'id': 'test-2',
 'chain': [],
 'cleaned_statement': 'the wildcat lose 1

In [10]:
# Build the LLM client from the model name and API key set in the parameters cell.
gpt_llm = ChatGPT(model_name=model_name, key=openai_api_key)

In [11]:
# Run the dynamic operation chain on the selected sample. The call returns the
# processed sample together with a log of the chain; debug=True is passed so
# the step-by-step trace is printed as the cell output.
proc_sample, dynamic_chain_log = dynamic_chain_exec_one_sample(
    sample=test_sample,
    llm=gpt_llm,
    debug=True,
)


Act Chain:  []
Kept Act Chain:  []
Skip Act Chain:  []
Last Operation:  <init>
Possible Next Operations:  ['add_column', 'select_row', 'select_column', 'group_column', 'sort_column']
explanation='To verify the statement, we need to focus on the games played in September and November and check the results to count the losses in these months.' operationchain=[<Operation.SELECT_ROW: 'SELECT_ROW'>, <Operation.SELECT_COLUMN: 'SELECT_COLUMN'>, <Operation.GROUP_COLUMN: 'GROUP_COLUMN'>, <Operation.END: 'END'>]
select_row
Act Chain:  ['f_select_row(row 1, row 2, row 3)']
Kept Act Chain:  ['f_select_row(row 1, row 2, row 3)']
Skip Act Chain:  []
Last Operation:  select_row
Possible Next Operations:  ['select_column', 'group_column', 'sort_column', '<END>']
explanation='To verify the statement about the number of games lost in specific months, we need to focus on the game dates and their results.' operationchain=[<Operation.END: 'END'>]
<END>


In [12]:
# Run the final query on the processed sample, passing the table info derived
# from that same sample.
output_sample = simple_query(
    sample=proc_sample,
    table_info=get_table_info(proc_sample),
    debug=True,
)
# Collect the per-step table log that the display cell iterates over.
cotable_log = get_table_log(output_sample)

explanation='By analyzing the dates and results, the Wildcats lost one game in September (Row 1) and two games in November (Row 2 and Row 3).' answer=<Answers.TRUE: 'TRUE'>
Table: table caption : 1947 kentucky wildcats football team
col : game | date | opponent | result | wildcats points | opponents | record
row 1 : 1 | sept 20 | ole miss | loss | 7 | 14 | 0 - 1
row 2 : 7 | nov 1 | 18 alabama | loss | 0 | 13 | 5 - 2
row 3 : 10 | nov 22 | tennessee | loss | 6 | 13 | 7 - 3
Statement: the wildcats lost one game in september and two games in november
Answer: YES
Explanation: By analyzing the dates and results, the Wildcats lost one game in September (Row 1) and two games in November (Row 2 and Row 3).


In [13]:
# Print the reasoning trace: the statement, the original table, the table after
# each logged operation, and finally the predicted vs. ground-truth verdict.
print(f'Statements: {output_sample["statement"]}\n')
print(f'Table: {output_sample["table_caption"]}')
# `table_text` is the original table captured when the sample was selected.
print(f"{pd.DataFrame(table_text[1:], columns=table_text[0])}\n")

for table_info in cotable_log:
    if not table_info["act_chain"]:
        # Nothing has been applied yet for this log entry; nothing to show.
        continue

    # Use a step-local name here. Rebinding `table_text` inside the loop would
    # overwrite the notebook-level original table, so re-running this cell
    # would print the last intermediate table under the "Table:" heading.
    step_table = table_info["table_text"]
    table_action = table_info["act_chain"][-1]

    if "skip" in table_action:
        # Skipped operations are not displayed.
        continue

    if "query" in table_action:
        # The query step carries the final verdict rather than a new table.
        result = table_info["cotable_result"]
        if result == "YES":
            print(f"-> {table_action}\nThe statement is True\n")
        else:
            print(f"-> {table_action}\nThe statement is False\n")
        continue

    # Any other operation: show the table as it stands after this step.
    print(f"-> {table_action}\n{pd.DataFrame(step_table[1:], columns=step_table[0])}")
    if "group_sub_table" in table_info:
        # A grouping step additionally carries (column name, [(value, count), ...]).
        group_column, group_info = table_info["group_sub_table"]
        group_headers = ["Group ID", group_column, "Count"]
        group_rows = [
            [
                f"Group {group_no}",
                "[Empty Cell]" if value.strip() == "" else value,
                str(count),
            ]
            for group_no, (value, count) in enumerate(group_info, start=1)
        ]
        print(f"{pd.DataFrame(group_rows, columns=group_headers)}")
    print()

print(f"We Answered With: {cotable_log[-1]['cotable_result']}")
print(f"Groundtruth: The statement is {answer}")

Statements: the wildcats lost one game in september and two games in november

Table: 1947 kentucky wildcats football team
  game     date        opponent result wildcats points opponents      record
0    1  sept 20        ole miss   loss               7        14       0 - 1
1    2  sept 27      cincinnati    win              20         0       1 - 1
2    3    oct 4          xavier    win              20         7       2 - 1
3    4   oct 11       9 georgia    win              26         0  3 - 1 , 20
4    5   oct 18   10 vanderbilt    win              14         0  4 - 1 , 14
5    6   oct 25  michigan state    win               7         6  5 - 1 , 13
6    7    nov 1      18 alabama   loss               0        13       5 - 2
7    8    nov 8   west virginia    win              15         6       6 - 2
8    9   nov 15      evansville    win              36         0       7 - 2
9   10   nov 22       tennessee   loss               6        13       7 - 3

-> f_select_row(row 0, row 6,