# Chain of Table Notebook

<a href="https://colab.research.google.com/github/run-llama/llama-hub/blob/main/llama_hub/llama_packs/tables/chain_of_table/chain_of_table.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In this notebook we highlight our implementation of the ["Chain of Table" paper by Wang et al.](https://arxiv.org/pdf/2401.04398v1.pdf).

Chain-of-Table proposes the following: given a user query over tabular data, plan a sequence of tabular operations that transforms the table until it contains the information needed to answer the query. The intermediate table is explicitly passed along and modified at every step of the chain (unlike chain-of-thought/ReAct, which reason through generic free-form thoughts).

The paper defines a fixed set of tabular operations:
- `f_add_column`
- `f_select_row`
- `f_select_column`
- `f_group_by`
- `f_sort_by`
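To make the idea of a chain concrete, here is a minimal sketch of two of these operations applied in sequence. It uses plain Python lists of dicts instead of pandas so that it runs without dependencies. The function signatures are assumptions made for illustration, not the paper's or the pack's definitions, and the plan is written by hand, whereas in Chain-of-Table the LLM chooses each step. The sample rows are borrowed from the Movie Awards example later in this notebook.

```python
from typing import Callable

Row = dict[str, str]
Table = list[Row]


def f_select_row(table: Table, keep: Callable[[Row], bool]) -> Table:
    """Keep only the rows for which `keep` returns True (illustrative signature)."""
    return [row for row in table if keep(row)]


def f_select_column(table: Table, columns: list[str]) -> Table:
    """Keep only the named columns, in the given order (illustrative signature)."""
    return [{col: row[col] for col in columns} for row in table]


table: Table = [
    {"Award": "Academy Awards, 1972", "Category": "Best Picture",
     "Nominee": "Phillip D'Antoni", "Result": "Won"},
    {"Award": "Academy Awards, 1972", "Category": "Best Director",
     "Nominee": "William Friedkin", "Result": "Won"},
    {"Award": "BAFTA, 1972", "Category": "Best Actor",
     "Nominee": "Gene Hackman", "Result": "Won"},
]


def is_best_director_1972(row: Row) -> bool:
    return row["Award"] == "Academy Awards, 1972" and row["Category"] == "Best Director"


# Hand-written chain for "Who won Best Director in the 1972 Academy Awards?"
# Each step takes the previous table and returns a smaller one.
step1 = f_select_row(table, is_best_director_1972)
step2 = f_select_column(step1, ["Nominee", "Result"])
print(step2)
```

The point of the design is that every step yields a table again, so the final answer can be read off a small, focused table rather than the full input.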

Our implementation follows the prompts described in the paper, with adaptations to get it working. It is marked as beta, so there may still be kinks to work through. Do you have suggestions or contributions on how to improve the robustness? Let us know!

## Download Data

We use the [WikiTableQuestions dataset](https://ppasupat.github.io/WikiTableQuestions/) (Pasupat and Liang 2015) as our test dataset.

WikiTableQuestions is a question-answering dataset over various semi-structured tables taken from Wikipedia. These tables range in size from a few rows/columns to many rows. Some columns may contain multi-part information as well (e.g. a temperature column may contain both Fahrenheit and Celsius).
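As an illustration of what "multi-part" means in practice, the sketch below splits a cell such as `17.3\n(63.1)` into its two numbers. The helper name `split_two_part_cell` and the regular expression are hypothetical and not part of the dataset tooling or the pack. The pattern accepts either a real newline or a literal backslash-`n` between the two parts, since which one appears depends on how the CSV was exported.

```python
import re

# Matches a primary number, an optional separator, and a secondary number in
# parentheses. The separator may be a real newline or the two characters "\" + "n".
_TWO_PART = re.compile(
    r"^\s*(-?\d+(?:\.\d+)?)\s*(?:\\n|\n)?\s*\(\s*(-?\d+(?:\.\d+)?)\s*\)\s*$"
)


def split_two_part_cell(cell: str):
    """Return (primary, secondary) as floats, or None if the cell has another shape."""
    match = _TWO_PART.match(cell)
    if match is None:
        return None
    return float(match.group(1)), float(match.group(2))


print(split_two_part_cell("17.3\\n(63.1)"))  # literal backslash-n separator
print(split_two_part_cell("0.2\n(0.008)"))   # real newline separator
print(split_two_part_cell("13.9"))           # single-value cell, returns None
```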

In [None]:
# !wget "https://github.com/ppasupat/WikiTableQuestions/releases/download/v1.0.2/WikiTableQuestions-1.0.2-compact.zip" -O data.zip
# !unzip data.zip

Let's visualize a sample df.

In [1]:
import pandas as pd

df = pd.read_csv("data/raw//WikiTableQuestions/csv/200-csv/3.csv")
df

Unnamed: 0,Year,Winner,Jockey,Trainer,Owner,Breeder
0,1919,Sir Barton,Johnny Loftus,H. Guy Bedwell,J. K. L. Ross,
1,1930,Gallant Fox,Earl Sande,Jim Fitzsimmons,Belair Stud,Belair Stud
2,1935,Omaha,Willie Saunders,Jim Fitzsimmons,Belair Stud,Belair Stud
3,1937,War Admiral,Charley Kurtsinger,George H. Conway,Samuel D. Riddle,Samuel D. Riddle
4,1941,Whirlaway,Eddie Arcaro,Ben A. Jones,Calumet Farm,Calumet Farm
5,1943,Count Fleet,Johnny Longden,Don Cameron,Fannie Hertz,Fannie Hertz
6,1946,Assault,Warren Mehrtens,Max Hirsch,King Ranch,King Ranch
7,1948,Citation,Eddie Arcaro,Horace A. Jones,Calumet Farm,Calumet Farm
8,1973,Secretariat,Ron Turcotte,Lucien Laurin,Meadow Stable,Meadow Stud
9,1977,Seattle Slew,Jean Cruguet,"William H. Turner, Jr.",Karen L. Taylor,Ben S. Castleman


## Load Pack / Setup

Now we call `download_llama_pack` to load the Chain-of-Table LlamaPack (you can also import the module directly if you are using the llama-hub package).

Optionally, we also set up observability/tracing so we can inspect the intermediate steps.

In [3]:
# Option: if developing with the llama_hub package
# from llama_hub.llama_packs.tables.chain_of_table.base import (
#     ChainOfTableQueryEngine,
#     serialize_table
# )

# Option: download llama_pack
from llama_index.core.llama_pack import download_llama_pack

download_llama_pack(
    "ChainOfTablePack",
    "./chain_of_table_pack",
    # skip_load=True,
    # leave the below line commented out if using the notebook on main
    # llama_hub_url="https://raw.githubusercontent.com/run-llama/llama-hub/jerry/add_chain_of_table/llama_hub"
)
from chain_of_table_pack.llama_index.packs.tables.chain_of_table.base import (
    ChainOfTableQueryEngine,
    serialize_table,
)

In [4]:
import os
from llama_index.llms.openai import OpenAI
from configparser import ConfigParser

# Read the OpenAI API key from a local INI file: section [openai], key "apikey"
config = ConfigParser()
config.read("conf/conf.ini")
os.environ["OPENAI_API_KEY"] = config["openai"]["apikey"]

llm = OpenAI(model="gpt-4-1106-preview")

### Optional: Setup Observability

Here we will use our Arize Phoenix integration to view traces through the query engine.

In [7]:
import phoenix as px
import llama_index
from llama_index.core import set_global_handler

px.launch_app()
set_global_handler("arize_phoenix")

🌍 To view the Phoenix app in your browser, visit http://localhost:6006/
📺 To view the Phoenix app in a notebook, run `px.active_session().view()`
📖 For more information on how to use Phoenix, check out https://docs.arize.com/phoenix


## Try out some Queries

Now let's try out our `ChainOfTableQueryEngine`!

We run it over a few different tables.

### Example: Movie Awards Table

In [8]:
import pandas as pd

df = pd.read_csv("data/raw//WikiTableQuestions/csv/200-csv/11.csv")

In [9]:
df

Unnamed: 0,Award,Category,Nominee,Result
0,"Academy Awards, 1972",Best Picture,Phillip D'Antoni,Won
1,"Academy Awards, 1972",Best Director,William Friedkin,Won
2,"Academy Awards, 1972",Best Actor,Gene Hackman,Won
3,"Academy Awards, 1972",Best Adapted Screenplay,Ernest Tidyman,Won
4,"Academy Awards, 1972",Film Editing,Gerald B. Greenberg,Won
5,"Academy Awards, 1972",Best Supporting Actor,Roy Scheider,Nominated
6,"Academy Awards, 1972",Best Cinematography,Owen Roizman,Nominated
7,"Academy Awards, 1972",Best Sound,Theodore Soderberg\nChristopher Newman,Nominated
8,"American Cinema Editors, 1972",Best Edited Feature Film,Gerald B. Greenberg,Nominated
9,"BAFTA, 1972",Best Actor,Gene Hackman,Won


In [10]:
query_engine = ChainOfTableQueryEngine(df, llm=llm, verbose=True)

In [11]:
response = query_engine.query("Who won best Director in the 1972 Academy Awards?")

> Iteration: 0
> Current table:
col : Award | Category | Nominee | Result
row 1 : Academy Awards, 1972 | Best Picture | Phillip D'Antoni | Won
row 2 : Academy Awards, 1972 | Best Director | William Friedkin | Won
row 3 : Academy Awards, 1972 | Best Actor | Gene Hackman | Won
row 4 : Academy Awards, 1972 | Best Adapted Screenplay | Ernest Tidyman | Won
row 5 : Academy Awards, 1972 | Film Editing | Gerald B. Greenberg | Won
row 6 : Academy Awards, 1972 | Best Supporting Actor | Roy Scheider | Nominated
row 7 : Academy Awards, 1972 | Best Cinematography | Owen Roizman | Nominated
row 8 : Academy Awards, 1972 | Best Sound | Theodore Soderberg\nChristopher Newman | Nominated
row 9 : American Cinema Editors, 1972 | Best Edited Feature Film | Gerald B. Greenberg | Nominated
row 10 : BAFTA, 1972 | Best Actor | Gene Hackman | Won
row 11 : BAFTA, 1972 | Best Film Editing | Gerald B. Greenberg | Won
row 12 : BAFTA, 1972 | Best Direction | William Friedkin | Nominated
row 13 

In [12]:
str(response.response)

'assistant: William Friedkin.'

### Example: Yearly Temperature and Precipitation

This table is interesting because the cells in the first three rows contain two values each (e.g. °C/°F or mm/inches).

Let's see if chain-of-table can handle this question.

In [14]:
import pandas as pd

df = pd.read_csv("data/raw//WikiTableQuestions/csv/200-csv/42.csv")

In [15]:
df

Unnamed: 0,Month,Jan,Feb,Mar,Apr,May,Jun,Jul,Aug,Sep,Oct,Nov,Dec,Year
0,Average high °C (°F),17.3\n(63.1),19.5\n(67.1),22.6\n(72.7),25.9\n(78.6),27.2\n(81),29.3\n(84.7),31.8\n(89.2),31.4\n(88.5),28.9\n(84),25.5\n(77.9),21.7\n(71.1),19.2\n(66.6),24.76\n(76.57)
1,Average low °C (°F),7.9\n(46.2),9.4\n(48.9),12.5\n(54.5),17.6\n(63.7),19.2\n(66.6),21.6\n(70.9),23.8\n(74.8),22.5\n(72.5),20.7\n(69.3),16.5\n(61.7),14.1\n(57.4),8.5\n(47.3),15.94\n(60.69)
2,Precipitation mm (inches),235.9\n(9.287),129.2\n(5.087),82.8\n(3.26),33.6\n(1.323),4.7\n(0.185),0.2\n(0.008),0.0\n(0),0.2\n(0.008),3.2\n(0.126),58.0\n(2.283),107.4\n(4.228),214.5\n(8.445),857.3\n(33.752)
3,Avg. precipitation days,13.9,11.4,8.6,3.6,2.4,0.1,0.0,0.1,1.8,4.9,8.0,11.8,63.7


In [16]:
query_engine = ChainOfTableQueryEngine(df, llm=llm, verbose=True)

In [17]:
response = query_engine.query("What was the precipitation in inches during June?")

> Iteration: 0
> Current table:
col : Month | Jan | Feb | Mar | Apr | May | Jun | Jul | Aug | Sep | Oct | Nov | Dec | Year
row 1 : Average high °C (°F) | 17.3\n(63.1) | 19.5\n(67.1) | 22.6\n(72.7) | 25.9\n(78.6) | 27.2\n(81) | 29.3\n(84.7) | 31.8\n(89.2) | 31.4\n(88.5) | 28.9\n(84) | 25.5\n(77.9) | 21.7\n(71.1) | 19.2\n(66.6) | 24.76\n(76.57)
row 2 : Average low °C (°F) | 7.9\n(46.2) | 9.4\n(48.9) | 12.5\n(54.5) | 17.6\n(63.7) | 19.2\n(66.6) | 21.6\n(70.9) | 23.8\n(74.8) | 22.5\n(72.5) | 20.7\n(69.3) | 16.5\n(61.7) | 14.1\n(57.4) | 8.5\n(47.3) | 15.94\n(60.69)
row 3 : Precipitation mm (inches) | 235.9\n(9.287) | 129.2\n(5.087) | 82.8\n(3.26) | 33.6\n(1.323) | 4.7\n(0.185) | 0.2\n(0.008) | 0.0\n(0) | 0.2\n(0.008) | 3.2\n(0.126) | 58.0\n(2.283) | 107.4\n(4.228) | 214.5\n(8.445) | 857.3\n(33.752)
row 4 : Avg. precipitation days | 13.9 | 11.4 | 8.6 | 3.6 | 2.4 | 0.1 | 0.0 | 0.1 | 1.8 | 4.9 | 8.0 | 11.8 | 63.7


[0m[1;3;38;5;200m> New Operation + Args: f_select_row([

In [18]:
str(response)

'assistant: 0.008 inches'

#### Try out a Baseline

As a baseline, let's see whether our LLM can answer the question directly when we dump the whole table into the prompt.

We can construct this concisely using our query pipeline syntax (you can, of course, just call the prompt/LLM directly).

In [16]:
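# Use the `llama_index.core` namespace, consistent with the imports earlier in this notebook
from llama_index.core.prompts import PromptTemplate
from llama_index.core.query_pipeline import QueryPipeline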

prompt_str = """\
Here's a serialized table.

{serialized_table}

Given this table please answer the question: {question}
Answer: """
prompt = PromptTemplate(prompt_str)
prompt_c = prompt.as_query_component(partial={"serialized_table": serialize_table(df)})

In the response below, the LLM identifies the right row but mistakenly reports 0.2 (the millimeter value) as the inches, instead of 0.008.

In [18]:
qp = QueryPipeline(chain=[prompt_c, llm])
response = qp.run("What was the precipitation in inches during June?")
print(str(response))

assistant: The precipitation in inches during June is given in row 3 under the "Jun" column. According to the table, it is 0.2 inches (0.008).
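To see what the baseline LLM is actually given, it can help to build the prompt by hand. The sketch below is an approximation under stated assumptions: `serialize_rows` is a hypothetical stand-in that mimics the `col : ... / row N : ...` layout visible in the verbose output above, and it is not the pack's `serialize_table`. The prompt text is the same template used in the baseline cell.

```python
PROMPT_TEMPLATE = """\
Here's a serialized table.

{serialized_table}

Given this table please answer the question: {question}
Answer: """


def serialize_rows(columns: list[str], rows: list[list[str]]) -> str:
    """Render a table as text: one 'col' line, then one numbered 'row' line per row."""
    lines = ["col : " + " | ".join(columns)]
    for index, row in enumerate(rows, start=1):
        lines.append(f"row {index} : " + " | ".join(str(cell) for cell in row))
    return "\n".join(lines)


# A two-column slice of the climate table, enough to show the ambiguity.
columns = ["Month", "Jun"]
rows = [
    ["Precipitation mm (inches)", "0.2\\n(0.008)"],
    ["Avg. precipitation days", "0.1"],
]

serialized = serialize_rows(columns, rows)
prompt_text = PROMPT_TEMPLATE.format(
    serialized_table=serialized,
    question="What was the precipitation in inches during June?",
)
print(prompt_text)
```

Printed this way, the two-part cell shows up as a single opaque string, and the model has to infer which number matches which unit from the row label alone.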


### Example: Football Game Schedule

In [19]:
import pandas as pd

df = pd.read_csv("data/raw//WikiTableQuestions/csv/203-csv/114.csv")
df

Unnamed: 0,Week,Date,TV Time,Opponent,Result,Game site,Record,Attendance
0,1,"September 7, 1998",ABC 7:00 pm MT,New England Patriots,W 27–21,Mile High Stadium (ABC),1–0,74745
1,2,"September 13, 1998",FOX 2:00 pm MT,Dallas Cowboys,W 42–23,Mile High Stadium (FOX),2–0,75013
2,3,"September 20, 1998",CBS 2:00 pm MT,at Oakland Raiders,W 34–17,Oakland-Alameda County Coliseum (CBS),3–0,56578
3,4,"September 27, 1998",CBS 11:00 am MT,at Washington Redskins,W 38–16,FedEx Field (CBS),4–0,71880
4,5,"October 4, 1998",FOX 2:00 pm MT,Philadelphia Eagles,W 41–16,Mile High Stadium (FOX),5–0,73218
5,6,"October 11, 1998",CBS 2:00 pm MT,at Seattle Seahawks,W 21–16,Kingdome (CBS),6–0,66258
6,7,Bye,Bye,Bye,Bye,Bye,Bye,Bye
7,8,"October 25, 1998",CBS 2:00 pm MT,Jacksonville Jaguars,W 37–24,Mile High Stadium (CBS),7–0,75217
8,9,"November 1, 1998",CBS 11:00 am MT,at Cincinnati Bengals,W 33–26,Cinergy Field (CBS),8–0,59974
9,10,"November 8, 1998",CBS 2:00 pm MT,San Diego Chargers,W 27–10,Mile High Stadium (CBS),9–0,74925


In [20]:
query_engine = ChainOfTableQueryEngine(df, llm=llm, verbose=True)
response = query_engine.query("Which televised ABC game had the greatest attendance?")

> Iteration: 0
> Current table:
col : Week | Date | TV Time | Opponent | Result | Game site | Record | Attendance
row 1 : 1 | September 7, 1998 | ABC 7:00 pm MT | New England Patriots | W 27–21 | Mile High Stadium (ABC) | 1–0 | 74,745
row 2 : 2 | September 13, 1998 | FOX 2:00 pm MT | Dallas Cowboys | W 42–23 | Mile High Stadium (FOX) | 2–0 | 75,013
row 3 : 3 | September 20, 1998 | CBS 2:00 pm MT | at Oakland Raiders | W 34–17 | Oakland-Alameda County Coliseum (CBS) | 3–0 | 56,578
row 4 : 4 | September 27, 1998 | CBS 11:00 am MT | at Washington Redskins | W 38–16 | FedEx Field (CBS) | 4–0 | 71,880
row 5 : 5 | October 4, 1998 | FOX 2:00 pm MT | Philadelphia Eagles | W 41–16 | Mile High Stadium (FOX) | 5–0 | 73,218
row 6 : 6 | October 11, 1998 | CBS 2:00 pm MT | at Seattle Seahawks | W 21–16 | Kingdome (CBS) | 6–0 | 66,258
row 7 : 7 | Bye | Bye | Bye | Bye | Bye | Bye | Bye
row 8 : 8 | October 25, 1998 | CBS 2:00 pm MT | Jacksonville Jaguars | W 37–24 | Mile High Sta

AttributeError: 'NoneType' object has no attribute 'group'
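The traceback is not shown here, so the cause is a guess. This message is what Python raises when `.group()` is called on the `None` that `re.match` or `re.search` returns for a non-match, which would fit an LLM reply that did not follow the expected operation format. The sketch below shows the defensive pattern. The regular expression and the `parse_operation` helper are hypothetical and may differ from how the pack parses operations.

```python
import re

# Hypothetical pattern for operation strings such as "f_select_row([1, 2])".
_OPERATION = re.compile(r"^\s*(f_\w+)\((.*)\)\s*$", re.DOTALL)


def parse_operation(text: str):
    """Return (name, raw_args), or None when the text is not a well-formed call."""
    match = _OPERATION.match(text)
    if match is None:
        # Calling match.group(...) at this point would raise
        # AttributeError: 'NoneType' object has no attribute 'group'.
        return None
    return match.group(1), match.group(2)


print(parse_operation("f_select_row([1, 2])"))
print(parse_operation("f_select_row(["))  # truncated reply, returns None
```

With a `None` result the caller can retry the LLM call or stop the chain, instead of failing mid-query.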

In [145]:
print(str(response))

assistant: The answer is: ABC 7:00 pm MT with an attendance of 78,100.


#### Baseline

Once again, we consider a simple QA prompt baseline and get the wrong answer.

In [140]:
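# Use the `llama_index.core` namespace, consistent with the imports earlier in this notebook
from llama_index.core.prompts import PromptTemplate
from llama_index.core.query_pipeline import QueryPipeline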

prompt_str = """\
Here's a serialized table.

{serialized_table}

Given this table please answer the question: {question}
Answer: """
prompt = PromptTemplate(prompt_str)
prompt_c = prompt.as_query_component(partial={"serialized_table": serialize_table(df)})
qp = QueryPipeline(chain=[prompt_c, llm])
response = qp.run("Which televised ABC game had the greatest attendance?")
print(str(response))

assistant: According to the table, there are two games that were televised on ABC:

1. Week 1: September 7, 1998, against the New England Patriots with an attendance of 74,745.
2. Week 16: December 21, 1998, against the Miami Dolphins with an attendance of 74,363.

The game with the greatest attendance among the ABC televised games is the Week 1 game against the New England Patriots, with an attendance of 74,745.
