### decompose

In [None]:
import sys
sys.path.append('/media/disk1/chatgpt/zh/tabular_data')

In [1]:
from prompt_manager import get_k_shot_with_answer, view_instruction, row_instruction
import pandas as pd
from utils import parse_specific_composition, add_row_number
from langchain.chains import LLMChain
from langchain_openai import ChatOpenAI, OpenAI
from data_loader import TableFormat, TableLoader
from langchain.memory import ChatMessageHistory
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from sqlalchemy import create_engine
from executor import SQLManager
import sqlparse
embeddings = HuggingFaceBgeEmbeddings(
            model_name='BAAI/bge-large-en',
            model_kwargs={'device': 'cuda:2', 'trust_remote_code': True},
            encode_kwargs={'normalize_embeddings': True})

In [2]:
from langchain.prompts.prompt import PromptTemplate
from langchain.prompts.few_shot import FewShotPromptTemplate
from data_loader import TableFormat
query_examples = [
                  "what was the time difference between the first place finisher and the eighth place finisher?",
                  # "compare the chart positions between the us and the uk for the science of selling yourself short, where did it do better?",
                  "other than william stuart price, which other businessman was born in tulsa?",
                  "which canadian city had the most passengers traveling from manzanillo international airport in 2013?"
                # "what is the next most populous district after haridwar?",(70)
                  ]
new_query_examples = [
  # "what was the chart position of 'The Science of Selling Yourself Short' in the US?; what was the chart position of 'The Science of Selling Yourself Short' in the UK?;",
                      "what was the time for the first place finisher?; what was the time for the eighth place finisher?",
                      "was william stuart price born in tulsa?; who was born in tulsa?",
                      "how many passengers from each airline from canadian city? which canadian city had the most passengers?"
                    # "what are the districts after haridwar?; what is the next most populous district after haridwar?",
                    #   "When did polona hercog partner with alberta brianti?; When did polona hercog partner with stephanie vogt?",
                      ]
num_k = 3
inds = [1, 11, 86, 70, 42]
table_loader = TableLoader(table_name='wikitable', split='validation', use_sample=True, small_test=False)
normalised_data = [table_loader.normalize_table(table_loader.dataset[inds[i]]) for i in range(num_k)]
example_samples = [TableFormat(format='none', data=normalised_data[i], save_embedding=True, embeddings=embeddings).get_sample_data(sample_type='embedding', query=normalised_data[i]['query']) for i in range(num_k)]
examples = [TableFormat.format_nl_sep(example_samples[i], normalised_data[i]['table']['caption']) for i in range(num_k)]

examples_prompt = PromptTemplate(input_variables=["query", "table", "new_query"], template=
"""Sub-Table: {table}
Query: {query}
Decompose query: {new_query}""")

examples_dict = [{"query": query_examples[i],
                                    "table": examples[i],
                                    "new_query": new_query_examples[i]} for i in range(num_k)]
decompose_prompt_wiki = FewShotPromptTemplate(
    examples=examples_dict,
    example_prompt=examples_prompt,
    prefix="""You are capable of converting complex query into sub queries. Based on the table, decompose original query into at most 2 complete sub queries which can solve original query. Output new query directly.""",
    suffix=
    """Sub-Table: {table}
Query: {query}
Decompose query: """,
    input_variables=["query", "table"],
)



In [3]:
from langchain.prompts.prompt import PromptTemplate
from langchain.prompts.few_shot import FewShotPromptTemplate
from data_loader import TableFormat
query_examples = [
    # "after 2005 , the winner of the lifetime achievement award be andrew rule john silvester , sandra harvey lindsay simpson , marele day , shane maloney , and peter doyle",
                  "all 12 club play a total of 22 game for the wru division one east",
                #   "a gamecube game loss the award in each of the first 3 year",
                "from 1980 to 2011 , apoel bc lose more than 2 time as many game as it win",
                  "polona hercog 1890partner with alberta brianti after she have stephanie vogt as the partner",
                  ]
task_examples = ["query rewrite", "query decompose", "query ambiguity resolve"]
new_query_examples = [
    # "Who were the winners of the lifetime achievement award after 2005?;",
                      "How many clubs play for the wru division one east in total?; How many clubs play 22 game for the wru division one east?;",
                    #   "a gamecube game loss the award in each of the first 3 year",
                    "from 1980 to 2011 , how many games did apoel bc lose?; from 1980 to 2011 , how many games did apoel bc win?;",
                      "When did polona hercog partner with alberta brianti?; When did polona hercog partner with stephanie vogt?",
                      ]
num_k = 3
inds = [1, 124, 5]
table_loader = TableLoader(table_name='tabfact', split='validation', use_sample=True, small_test=False)
normalised_data = [table_loader.normalize_table(table_loader.dataset[inds[i]]) for i in range(num_k)]
example_samples = [TableFormat(format='none', data=normalised_data[i], save_embedding=False).get_sample_data(sample_type='random') for i in range(num_k)]
examples = [TableFormat.format_nl_sep(example_samples[i], normalised_data[i]['table']['caption']) for i in range(num_k)]

examples_prompt = PromptTemplate(input_variables=["query", "table", "new_query"], template=
"""Query: {query}
Table: {table}
New query: {new_query}""")

examples_dict = [{"query": query_examples[i],
                                    "table": examples[i],
                                    "new_query": new_query_examples[i]} for i in range(num_k)]
decompose_prompt = FewShotPromptTemplate(
    examples=examples_dict,
    example_prompt=examples_prompt,
    prefix="""You are capable of converting complex query into sub-queries. Based on the table, provide at most 2 sub-queries for knowledge that you need. Output new query directly.""",
    suffix=
    """Query: {query}
Table: {table}
New query: """,
    input_variables=["query", "table"],
)

# Sub-questions are separated by semicolons.
# answer_instruction = PromptTemplate(input_variables=["SQL", "table", "claim"], 
#                                     template="""
# Below is a sub-table generated by excuting the SQL. You need to understand the logic behind the SQL filtering and complete task using the final sub-table. 
# SQL Excuted: 
# ```{SQL}```
# Sub-table: {table}
# Query: {claim}
# answer the question given in the query. Only return the string instead of other format information. Do not repeat the question.
# """ )


In [4]:
task_name = 'tabfact'
split = 'test'
model_name = 'gpt-3.5-turbo-0125'
model = ChatOpenAI(model_name=model_name, openai_api_base="https://api.chatanywhere.com.cn/v1",
                       openai_api_key="sk-WZtqZEeuE0Xb6syVghDgAxdwe0ASWLkQRGxl61UI7B9RqNC4", temperature=0.7).bind(logprobs=True)
schema_information = pd.read_csv(f"result/aug/{task_name}_{split}_schema.csv", index_col='table_id')
aug_information = pd.read_csv(f"result/aug/{task_name}_{split}_summary.csv", index_col='table_id')
composition_information = pd.read_csv(f"result/aug/{task_name}_{split}_composition.csv", index_col='table_id')
engine = create_engine('sqlite:///db/sqlite/tabfact.db', echo=False)
manager = SQLManager(engine=engine)

In [5]:
table_loader = TableLoader(table_name=task_name, split='test', use_sample=False, small_test=False)
sample = table_loader.normalize_table(table_loader.dataset[2])

### step-back

In [6]:
from langchain.prompts.prompt import PromptTemplate
from langchain.prompts.few_shot import FewShotPromptTemplate
from data_loader import TableFormat
inds = [11,]
num_k = 1
table_loader = TableLoader(table_name='wikitable', split='validation', use_sample=True, small_test=False)
normalised_data = [table_loader.normalize_table(table_loader.dataset[inds[i]]) for i in range(num_k)]
example_samples = [TableFormat(format='none', data=normalised_data[i], save_embedding=False).get_sample_data(sample_type='random') for i in range(num_k)]
examples = [TableFormat.format_nl_sep(example_samples[i], normalised_data[i]['table']['caption']) for i in range(num_k)]
new_query_examples = [
    # "Which country uses the US dollar as its currency and has the Federal Reserve as its central bank?",
    "which business man was born in tulsa?",
    ]
examples_prompt = PromptTemplate(input_variables=["query", "new_query"], template=
"""
Query: {query}
Table: {table}
New query: {new_query}""")

examples_dict = [{"query": table_loader.dataset[inds[i]]['question'],
                  "table": examples[i],
                    "new_query": new_query_examples[i]} for i in range(num_k)]
step_back_prompt_wiki = FewShotPromptTemplate(
    examples=examples_dict,
    example_prompt=examples_prompt,
    prefix="""Based on the table, your task is to step back and paraphrase a question to a more generic step-back question, which is easier to answer.""",
    suffix=
    """
Query: {query}
Table: {table}
New query:""",
    input_variables=["query", "table"],
)

In [7]:
from langchain.prompts.prompt import PromptTemplate
from langchain.prompts.few_shot import FewShotPromptTemplate
from data_loader import TableFormat
inds = [8, 173,]
num_k = 2
table_loader = TableLoader(table_name='tabfact', split='validation', use_sample=True, small_test=False)
normalised_data = [table_loader.normalize_table(table_loader.dataset[inds[i]]) for i in range(num_k)]
example_samples = [TableFormat(format='none', data=normalised_data[i], save_embedding=False).get_sample_data(sample_type='random') for i in range(num_k)]
examples = [TableFormat.format_nl_sep(example_samples[i], normalised_data[i]['table']['caption']) for i in range(num_k)]
new_query_examples = [
    # "Which country uses the US dollar as its currency and has the Federal Reserve as its central bank?",
    "which college list be public?",
    "what was the works number used in 1883?"
    ]
examples_prompt = PromptTemplate(input_variables=["query", "new_query"], template=
"""
Query: {query}
Table: {table}
New query: {new_query}""")

examples_dict = [{"query": table_loader.dataset[inds[i]]['statement'],
                  "table": examples[i],
                    "new_query": new_query_examples[i]} for i in range(num_k)]
step_back_prompt = FewShotPromptTemplate(
    examples=examples_dict,
    example_prompt=examples_prompt,
    prefix="""Based on the table, your task is to step back and paraphrase a question to a more generic step-back question, which is easier to answer.""",
    suffix=
    """
Query: {query}
Table: {table}
New query:""",
    input_variables=["query", "table"],
)

In [8]:
def get_k_shot_with_answer_noinfo(k: int=1):
    sqls = ["SELECT MIN(points) FROM DF WHERE rider = 'roger dutton / tony wright';"
            ]
    thoughts = ["Based on the SQL query provided, the minimum number of points that Roger Dutton / Tony Wright received in the 1972 Isle of Man TT event was 3. 3 is the fewest points they received. "]
    tables = ["<table>\n<caption>1972 isle of man tt</caption>\n<thead>\n<tr><th>  MIN(points)</th></tr>\n</thead>\n<tbody>\n<tr><td>3            </td></tr>\n</tbody>\n</table>"]
    claims = ["2 be the fewest point that roger dutton / tony wright receive"]
    # inds from test split
    examples_prompt = PromptTemplate(input_variables=["SQL", "table", "claim", "thought", "output"], template=
    """
SQL Excuted: 
```{SQL}```
Sub-table: {table}
Query: {claim}
Thought: {thought}
Answer: {output}
    """)
    examples_dict = dict(zip(["SQL", "table", "claim", "thought", "output"], [sqls[0], tables[0], claims[0], thoughts[0], '0']))
    prompt_template = FewShotPromptTemplate(
        examples=[examples_dict],
        example_prompt=examples_prompt,
        prefix="""Below is a sub-table generated by excuting the corresponding SQL. You need to understand the logic behind the SQL filtering. Think step by step and verify whether the provided claim/query is true or false.
You should output in the following format:
Thought: your step by step thought
Answer: return 0 if the query is false, or 1 if the query true
Below is an example.""",
        suffix=
        """
SQL Excuted: 
```{SQL}```
Sub-table: {table}
Query: {query}""",
        input_variables=["table", "query", "SQL"],
)
    return prompt_template

def get_k_shot_with_answer(k: int=1):
    sqls = ["SELECT MIN(points) FROM DF WHERE rider = 'roger dutton / tony wright';"
            ]
    thoughts = ["Based on the SQL query provided, the minimum number of points that Roger Dutton / Tony Wright received in the 1972 Isle of Man TT event was 3. 3 is the fewest points they received. "]
    tables = ["<table>\n<caption>1972 isle of man tt</caption>\n<thead>\n<tr><th>  MIN(points)</th></tr>\n</thead>\n<tbody>\n<tr><td>3            </td></tr>\n</tbody>\n</table>"]
    claims = ["verify whether the provided claim/query is true or false. 2 be the fewest point that roger dutton / tony wright receive"]
    # inds from test split
    examples_prompt = PromptTemplate(input_variables=["SQL", "table", "claim", "thought", "output"], template=
    """
SQL Excuted: 
```{SQL}```
Sub-table: {table}
Query: {claim}
Thought: {thought}
Answer: {output}
    """)
    examples_dict = dict(zip(["SQL", "table", "claim", "thought", "output"], [sqls[0], tables[0], claims[0], thoughts[0], '0']))
    prompt_template = FewShotPromptTemplate(
        examples=[examples_dict],
        example_prompt=examples_prompt,
        prefix="""Below is a sub-table generated by excuting the corresponding SQL. You need to understand the logic behind the SQL filtering. Think step by step and verify whether the provided claim/query is true or false, return 0 if it's false, or 1 if it's true.
You should output in the following format:
Thought: your step by step thought
Answer: final answer
Below is an example.""",
        suffix=
        """
SQL Excuted: 
```{SQL}```
Sub-table: {table}
Extra information:
{information}

Query: {query}""",
        input_variables=["table", "query", "SQL", "information"],
)
    return prompt_template

In [9]:
# def get_k_shot_with_aug(k: int=2):
#     table_loader = TableLoader(table_name='tabfact', split='validation', use_sample=True, small_test=False)

#     inds = [3, 6, 260, 33]
#     Output_examples = [
#                        'team, goals_for',
#                        'year, game, platform_s',
#                        'name, population_density_km_2_, population_2011_census_'
#                        'leading_scorer, score, date']
#     linking_examples = ['the team -> team; the most goal for -> goals_for',
#                         'gamecube -> platform_s; gamecube game -> game; the first 3 year -> year;',
#                         'alberta -> name; population density -> population_density_km_2_; 4257744 less people -> population_2011_census_; 2011 -> population_2011_census_'
#                         'jason richardson -> leading_scorer; leading scorer -> score; month -> date; 23 point per game -> leading_scorer'
#     ]
#     examples_prompt = PromptTemplate(input_variables=["table", "claim", "output", "linking"], template=
#     """
#     Table: {table}
#     Query: {claim}
#     Column linking: {linking}
#     Columns: {output}""")
#     num_k = 3
#     examples_dict = [{"table": TableFormat(format='none', data=table_loader.dataset[inds[i]], use_sampling=True).format_nl_sep(table_loader.dataset[inds[i]]['table']['caption']),
#                                         "claim": table_loader.dataset[inds[i]]['statement'],
#                                         "linking": linking_examples[i],
#                                         # "summary": summary_examples[i],
#                                         "output": Output_examples[i]} for i in range(num_k)]
#     prompt_template = FewShotPromptTemplate(
#         examples=examples_dict,
#         example_prompt=examples_prompt,
#         prefix=
#         """
#     Your task is accurately output columns related to the query or contain useful information about the query. This process involves linking similar words or semantically similar terms to columns in the table.
#     Approach this task as follows:
#     Read the question thoroughly and list every possible link from query term to column in Table.
#     Based on the column linking, output all useful columns at last. Make sure all columns in the link step are included and every column is in the Table.""",
#     # You are a brilliant table executor with the capabilities information retrieval, table parsing, table partition and semantic understanding who can understand the structural information of the table.
#     # Given the following table and query, you should output columns related to the query or contain useful information about the query. 
#     # Here are some examples:""",
#         suffix=
#         """
#     Table: {table}
#     Query: {claim}
#     Column linking:
#     """,
#         input_variables=["table", "claim"],
# )
#     return prompt_template

def get_k_shot_with_aug(k: int=2):
    table_loader = TableLoader(table_name='tabfact', split='validation', use_sample=True, small_test=False)
    table_loader_wiki = TableLoader(table_name='wikitable', split='train', use_sample=True, small_test=False)
    inds = [3, 6, 260, 33]
    Output_examples = [
                       'team, goals_for',
                       'year, game, platform_s',
                       'name, population_density_km_2_, population_2011_census_'
                       'leading_scorer, score, date']
    linking_examples = ['the team -> team, the most goal for -> goals_for',
                        'gamecube -> platform_s, gamecube game -> game, the first 3 year -> year',
                        'alberta -> name, population density -> population_density_km_2_, 4257744 less people -> population_2011_census_, 2011 -> population_2011_census_'
                        'jason richardson -> leading_scorer, month -> date, 23 point per game -> score'
    ]
    examples_prompt = PromptTemplate(input_variables=["table", "claim", "output", "linking"], template=
    """
    Table: {table}
    Query: {claim}
    Column linking: {linking}
    Columns: {output}""")
    num_k = 2
    normalised_data = [table_loader.normalize_table(table_loader.dataset[inds[i]]) for i in range(num_k)]
    example_samples = [TableFormat(format='none', data=normalised_data[i], save_embedding=False).get_sample_data(sample_type='random') for i in range(num_k)]
    examples = [TableFormat.format_html(example_samples[i], normalised_data[i]['table']['caption']) for i in range(num_k)]
    examples_dict = [{"table": examples[i],
                                        "claim": table_loader.dataset[inds[i]]['statement'],
                                        "linking": linking_examples[i],
                                        # "summary": summary_examples[i],
                                        "output": Output_examples[i]} for i in range(num_k)]
    examples_dict.extend([{"table": '<table>\n<caption>Hoot Kloot</caption>\n<thead>\n<tr><th> Number</th><th> Title</th><th> Directed_by_</th><th> Released_</th></tr>\n</thead>\n<tbody>\n<tr><td>1  </td><td>"Kloot\'s Kounty"           </td><td>Hawley Pratt  </td><td>1973       </td></tr>\n<tr><td>2  </td><td>"Apache on the County Seat"</td><td>Hawley Pratt  </td><td>1973       </td></tr>\n<tr><td>6  </td><td>"Stirrups and Hiccups"     </td><td>Gerry Chiniquy</td><td>1973       </td></tr>\n</tbody>\n</table>',
                                        "claim": table_loader_wiki.dataset[95]['question'],
                                        "linking": "the last title -> Released_, the last title-> Number, title -> Title",
                                        # "summary": summary_examples[i],
                                        "output": "Title, Released_, Number"}])
#     examples_dict.extend([{"table": """<table>
# <thead>
# <tr><th>  Pick</th><th>       Player</th><th>              Team</th><th>  Position</th><th>                    School</th></tr>
# </thead>
# <tbody>
# <tr><td>24    </td><td>Alan Zinter  </td><td>New York Mets     </td><td>C         </td><td>University of Arizona     </td></tr>
# <tr><td>18    </td><td>Willie Greene</td><td>Pittsburgh Pirates</td><td>SS        </td><td>Jones County HS (Gray, GA)</td></tr>
# <tr><td>23    </td><td>Mo Vaughn    </td><td>Boston Red Sox    </td><td>1B        </td><td>Seton Hall University     </td></tr>
# </tbody>
# </table>""",
#                                     "claim": "was kiki jones picked before or after greg gohr?",
#                                     "linking": "kiki jones -> Player, picked before -> Pick, after -> Pick, greg gohr -> Player",
#                                     # "summary": summary_examples[i],
#                                     "output": "Player, Pick"}])
    prompt_template = FewShotPromptTemplate(
        examples=examples_dict,
        example_prompt=examples_prompt,
        prefix=
        """
    Your task is accurately output columns related to the query or contain useful information about the query. This process involves linking similar words or semantically similar terms to columns in the table.
    Approach this task as follows，read the claim thoroughly and list every possible link from query term to column in Table. Then Based on the column linking, output all useful columns at last. Make sure all columns in the link step are included and every column is in the Table.""",
    # You are a brilliant table executor with the capabilities information retrieval, table parsing, table partition and semantic understanding who can understand the structural information of the table.
    # Given the following table and query, you should output columns related to the query or contain useful information about the query. 
    # Here are some examples:""",
        suffix=
        """
    Table: {table}
    Extra information: {aug}
    Query: {claim}""",
        input_variables=["table", "claim", "aug"],
)
    return prompt_template

pre_instruction_schema = PromptTemplate(input_variables=["table"], template="""
Instruction: Given the following table, you will add Metadata about the columns in the table.
Metadata includes:
- Numerical: consist digits and numerical symbols like decimal points or signs.
- Char: whether column content is a text or description.
- Date: whether the column content is datetime.

You need to output all the column names with metadata in angle brackets.
Example: name<Char> launched<Date> count<Numerical>

Table: {table}
Output:
""")

In [10]:
from utils import parse_output
def scene_A(query, sample, verbose=True):
    row_instruction = PromptTemplate(input_variables=["table", "claim", "aug"], 
                                 template="""Our ultimate goal is to answer query based on the original table. Now we have a sub-table with rows sampled from the original table below, you are required to infer the data distribution and format from the sample data of the sub-table. Carefully analyze the query, based on the augmentation information, write a SQLITE3 SELECT SQL statement using table DF that complete query. Directly Output SQL, do not add other string.
sub-table: {table}
Extra information: {aug}

Query: {claim}
SQL: """)
    # row_instruction = PromptTemplate(input_variables=["table", "claim", "aug"], 
    #                                 template="""
    # Our ultimate goal is to answer query based on the table. Below is a subtable with columns filtered, you are required to infer the data distribution and format from the sample data of the sub-table. Carefully analyze the query, based on the augmentation information, write a SQLITE3 SELECT SQL statement using table DF that complete query. Directly Output SQL, do not add other string.
    # sub-table: {table}
    # Query: {claim}
    # Extra information: {aug}
    # SQL: """)
    formatter = TableFormat(format='none', data=sample, save_embedding=True, embeddings=embeddings)
    k_shot_prompt = get_k_shot_with_aug()
    formatter.normalize_schema(schema_information.loc[sample['table']['id']]['schema'])
    sample_data = formatter.get_sample_data(sample_type='embedding', query=query)
    with get_openai_callback() as cb:
        llm_chain = LLMChain(llm=model, prompt=k_shot_prompt, verbose=verbose)
        summary_aug, column_aug = aug_information.loc[sample['table']['id']]['summary'], aug_information.loc[sample['table']['id']]['column_description'] 
        # if pd.isna(summary_aug):
        #     summary_aug = ''
        col_names, col_infos = parse_output(column_aug, pattern=r'([^<]*)<([^>]*)>')
        extra_col_info = []
        for i_c in range(len(col_names)):
            extra_col_info.append(f'{i_c + 1}. {col_names[i_c]}: {col_infos[i_c]}')
        
        stage_1_batch_pred = llm_chain.batch([dict({'table': TableFormat.format_html(data=sample_data, table_caption=sample['table']['caption']),
                                            'claim': query,
                                            'aug':  summary_aug + '\n'.join(extra_col_info)
                                            })], return_only_outputs=True)[0]['text']
        stage_1_batch_pred = stage_1_batch_pred.split(':')[-1]
        if verbose:
            print(stage_1_batch_pred)
        extra_cols = formatter.get_sample_column(embeddings, column_aug)
        # stage 2: SQL generation
        llm_chain = LLMChain(llm=model, prompt=row_instruction, verbose=verbose)
        columns = list(set([c.strip() for c in stage_1_batch_pred.split(',')] + extra_cols))
        
        try: 
            # formatter.all_data = formatter.all_data.loc[:, columns]
            sample_data = add_row_number(sample_data.loc[:, columns])
        except:
            pass
        extra_information = parse_specific_composition(composition_information.loc[sample['table']['id']]['composition'], sample_data.columns)
        extra_information.append('row_number: row number in the original table')
        stage_2_batch_pred = llm_chain.batch([dict({'table': TableFormat.format_html(data = sample_data, table_caption=sample['table']['caption']),
                                            'claim': query,
                                            'aug':  summary_aug + '\nColumn information:\n' + '\n'.join(extra_information)
                                            })], return_only_outputs=True)[0]['text'].replace("–", "-").replace("—", "-").replace("―", "-").replace("−", "-")
        # print(stage_2_batch_pred)
    # stage 3: SQL Excution
    try: 
        execute_data = manager.execute_from_df(stage_2_batch_pred, add_row_number(formatter.all_data), table_name='DF')
    except:
        execute_data = formatter.all_data
        stage_2_batch_pred = 'SELECT * from DF;'
    if len(execute_data) == 0:
        return query, stage_2_batch_pred, 'No data from database', cb.total_tokens
    return query, stage_2_batch_pred, TableFormat.format_html(data=execute_data), cb.total_tokens

In [11]:
import concurrent.futures
from langchain_community.callbacks import get_openai_callback
def parallel_run(func, args_list):
    with concurrent.futures.ThreadPoolExecutor() as executor:
        results = [executor.submit(func, arg) for arg in args_list]
        return [future.result() for future in concurrent.futures.as_completed(results)]

def parallel_run_kwargs(func, args_list):
    with concurrent.futures.ThreadPoolExecutor() as executor:
        results = executor.map(lambda kwargs: func(**kwargs), args_list)
        return list(results)

In [12]:
from typing import List
import os
import json
def save_csv(input_list: List[List], label_list: List, file_path):
    import pandas as pd
    directory = os.path.dirname(file_path)
    if not os.path.exists(directory):
        os.makedirs(directory)

    assert len(input_list) == len(label_list)
    df = pd.DataFrame()
    for i in range(len(label_list)):
        df[label_list[i]] = pd.Series(input_list[i])
    if os.path.exists(file_path) and file_path.endswith('.csv'):
        df_origin = pd.read_csv(file_path)
        df = pd.concat([df_origin, df], axis=0)
    df.to_csv(file_path, index=False, encoding='utf-8')

### RUN

In [13]:
def eval_blury_string(pred_list):
    pred_label = []
    for pred in pred_list:
        predict_ans = pred.split('\n')[-1]
        if '0' in predict_ans:
            predict_ans = '0'
        elif '1' in predict_ans:
            predict_ans = '1'
        else:
            predict_ans = '2'
        pred_label.append(predict_ans)
    return pred_label

In [14]:
import concurrent.futures
from typing import List
import os
import json
from langchain_community.callbacks import get_openai_callback
def parallel_run(func, args_list):
    with concurrent.futures.ThreadPoolExecutor() as executor:
        results = [executor.submit(func, arg) for arg in args_list]
        return [future.result() for future in concurrent.futures.as_completed(results)]

def parallel_run_kwargs(func, args_list):
    with concurrent.futures.ThreadPoolExecutor() as executor:
        results = executor.map(lambda kwargs: func(**kwargs), args_list)
        return list(results)

def save_csv(input_list: List[List], label_list: List, file_path):
    import pandas as pd
    directory = os.path.dirname(file_path)
    if not os.path.exists(directory):
        os.makedirs(directory)

    assert len(input_list) == len(label_list)
    df = pd.DataFrame()
    for i in range(len(label_list)):
        df[label_list[i]] = pd.Series(input_list[i])
    if os.path.exists(file_path) and file_path.endswith('.csv'):
        df_origin = pd.read_csv(file_path)
        df = pd.concat([df_origin, df], axis=0)
    df.to_csv(file_path, index=False, encoding='utf-8')

## 调整extrainformation的位置

In [15]:
from utils import parse_output
answer_instruction = PromptTemplate(input_variables=["SQL", "table", "claim"], 
                                    template="""
Below is a sub-table generated by excuting the SQL. You need to understand the logic behind the SQL filtering and complete task using the final sub-table. 
SQL Excuted: 
```{SQL}```
Sub-table: 
{table}
Query: {claim}
Please provide a clear, complete statement in response to the question. If you cannot answer the question based on the sub-table, just say 'Cannot get answer from sub-table'
""" )
def scene_B(query, sample, verbose=False):
    row_instruction = PromptTemplate(input_variables=["table", "claim", "aug"], 
                                 template="""Our ultimate goal is to answer query based on the original table. Now we have a sub-table with rows sampled from the original table, you are required to infer the data distribution and format from the sample data of the sub-table. Carefully analyze the query, based on the augmentation information, write a SQLITE3 SELECT SQL statement using table DF that complete query. Directly Output SQL, do not add other string.
sub-table: {table}
Extra information: {aug}

Query: {claim}
SQL: """)
    formatter = TableFormat(format='none', data=sample, save_embedding=True, embeddings=embeddings)
    formatter.normalize_schema(schema_information.loc[sample['table']['id']]['schema'])
    sample_data = formatter.get_sample_data(sample_type='embedding', query=query)
    # get columns
    k_shot_prompt = get_k_shot_with_aug()
    with get_openai_callback() as cb:
        llm_chain = LLMChain(llm=model, prompt=k_shot_prompt, verbose=verbose)
        summary_aug, column_aug = aug_information.loc[sample['table']['id']]['summary'], aug_information.loc[sample['table']['id']]['column_description'] 
        # if pd.isna(summary_aug):
        #     summary_aug = ''
        col_names, col_infos = parse_output(column_aug, pattern=r'([^<]*)<([^>]*)>')
        extra_col_info = []
        for i_c in range(len(col_names)):
            extra_col_info.append(f'{i_c + 1}. {col_names[i_c]}: {col_infos[i_c]}')
        stage_1_batch_pred = llm_chain.batch([dict({'table': TableFormat.format_html(data=sample_data, table_caption=sample['table']['caption']),
                                            'claim': query,
                                            'aug':  '\n'.join(extra_col_info)
                                            })], return_only_outputs=True)[0]['text']
        stage_1_batch_pred = stage_1_batch_pred.split(':')[-1]
        
        extra_cols = formatter.get_sample_column(embeddings, column_aug)
        # stage 2: SQL generation
        llm_chain = LLMChain(llm=model, prompt=row_instruction, verbose=verbose)
        columns = list(set([c.strip() for c in stage_1_batch_pred.split(',')] + extra_cols))
        
        try: 
            sample_data = add_row_number(sample_data.loc[:, columns])
        except:
            pass
        extra_information = (parse_specific_composition(composition_information.loc[sample['table']['id']]['composition'], sample_data.columns))
        extra_information.append('row_number: row number in the table')
        stage_2_batch_pred = llm_chain.batch([dict({'table': TableFormat.format_html(data=sample_data, table_caption=sample['table']['caption']),
                                            'claim': query,
                                            'aug':  summary_aug + '\n Column information:' + '\n'.join(extra_information)
                                            })], return_only_outputs=True)[0]['text'].replace("–", "-").replace("—", "-").replace("―", "-").replace("−", "-")
    
        
        # stage 3: SQL Excution
        try: 
            execute_data= manager.execute_from_df(stage_2_batch_pred, add_row_number(formatter.all_data), table_name='DF')
        except:
            execute_data = formatter.all_data
            stage_2_batch_pred = 'SELECT * from DF;'
        llm_chain = LLMChain(llm=model, prompt=answer_instruction, verbose=verbose)
        response = llm_chain.batch([dict({'table': TableFormat.format_html(execute_data),
                                                'claim': query,
                                                'SQL':  stage_2_batch_pred
                                                })], return_only_outputs=True)[0]['text']
    # print("total_tokens:", cb.total_tokens)
    return response, cb.total_tokens

In [16]:
from langchain_openai import AzureChatOpenAI
import os

os.environ["AZURE_OPENAI_API_KEY"] = "0c75de50975e4f278b882fe90da47f2f"
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://ces.openai.azure.com"
os.environ["AZURE_OPENAI_API_VERSION"] = "2024-02-01"
os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"] = "gpt-35-turbo"
model = AzureChatOpenAI(
    openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
    azure_deployment=os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"],
    temperature=0.01
)


### query aug

In [32]:


from langchain.chains import LLMChain
from langchain_openai import ChatOpenAI, OpenAI
import datetime
from FlagEmbedding import FlagReranker
from openai import BadRequestError, RateLimitError
from tqdm.notebook import tqdm
table_loader = TableLoader(table_name='tabfact', split='test', use_sample=False, small_test=True)
# model = ChatOpenAI(model_name='gpt-3.5-turbo-0125', openai_api_base="https://api.chatanywhere.com.cn/v1",
#                        openai_api_key="sk-WZtqZEeuE0Xb6syVghDgAxdwe0ASWLkQRGxl61UI7B9RqNC4", temperature=0.01)
save_path = f"result/answer/tabfact_test_05-10_14-41-49.csv"
reranker = FlagReranker('BAAI/bge-reranker-large', use_fp16=True)

muilti_answer_instruction = PromptTemplate(input_variables=["information", "claim"], 

# no cot prompt, not used
template = """
Below is a sub-table generated by excuting the corresponding SQL. You need to understand the logic behind the SQL filtering. Complete task with the help of extra information below.

SQL Excuted: 
```{SQL}```
Sub-table:
{table}
Extra information:
{information}

Query: {query}
Think step by step and answer the last question given in the query. Only return the string instead of other format information. Do not repeat the question.
""" )
# Task: answer the last question given in the query. Only return the string instead of other format information. Do not repeat the question.
# Task: verify whether the provided claim/query is true or false, return 0 if it's false, or 1 if it's true. Please think step by step and return 0/1 at last.


tokens = []
outputs = []
labels = []
ids = []
samplings = []
i = 1982
# with tqdm(total=len(table_loader.dataset)-1990, desc=f"Processing",ncols=150) as pbar:
#     while i < len(table_loader.dataset):
#         try:
sample = table_loader.normalize_table(
                    table_loader.dataset[i])
all_tokens = 0
all_queries = []
formatter = TableFormat(format='none', data=sample, save_embedding=False)
sample_data = formatter.get_sample_data(sample_type='random', query=sample['query'])
with get_openai_callback() as cb:
    llm_chain = LLMChain(llm=model, prompt=step_back_prompt, verbose=False)
    batch_pred = llm_chain.batch([{"query": sample['query'], "table": TableFormat.format_html(sample_data)}], return_only_outputs=True)
    all_queries.append(batch_pred[0]['text'].strip())
    llm_chain = LLMChain(llm=model, prompt=decompose_prompt, verbose=False)
    batch_pred = llm_chain.batch([{"query": sample['query'], "table": TableFormat.format_html(sample_data)}], return_only_outputs=True)
    all_queries.extend(batch_pred[0]['text'].split(';'))
    print(all_queries)
all_tokens += cb.total_tokens
args_list = [{"query": q, "sample": sample} for q in all_queries if reranker.compute_score([(q, sample['query'])], normalize=True) < 0.95]
print(len(args_list))
ans_from_B = parallel_run_kwargs(scene_B, args_list)
results = [res[0] for res in ans_from_B if res[0] != 'Cannot get answer from sub-table']
all_tokens += sum([res[1] for res in ans_from_B])
#With answer
# results= []
with get_openai_callback() as cb:
    imp_input = scene_A(sample['query'], sample, True)
    llm_chain = LLMChain(llm=model, prompt=get_k_shot_with_answer(), verbose=True)
    batch_pred = llm_chain.batch([{"query": sample['query'],"SQL": imp_input[1], "table": imp_input[2], "information": '\n'.join(results)}], return_only_outputs=True)
    print(batch_pred[0])
    # all_tokens += cb.total_tokens
    # print('ALL TOKENS', all_tokens)
    ids.append(i)
    labels.append(sample['query'])
    outputs.append(batch_pred[0]['text'])
    


----------using 4*GPUs----------
["What is the call sign for Astral Media's radio station?", "What is the call sign for Astral Media's radio station?", " What is the branding for Astral Media's radio station?"]
1


[1m> Entering new LLMChain chain...[0m
Prompt after formatting:
[32;1m[1;3m
    Your task is accurately output columns related to the query or contain useful information about the query. This process involves linking similar words or semantically similar terms to columns in the table.
    Approach this task as follows，read the claim thoroughly and list every possible link from query term to column in Table. Then Based on the column linking, output all useful columns at last. Make sure all columns in the link step are included and every column is in the Table.


    Table: <table>
<caption>1986 - 87 north west counties football league</caption>
<thead>
<tr><th>  position</th><th>               team</th><th>  played</th><th>  drawn</th><th>  lost</th><th>  goals_for</th><t

In [33]:
outputs

["Thought: The sub-table does not contain the branding for Astral Media's radio station. Therefore, it is not possible to verify the claim based on the provided sub-table.\nAnswer: 0"]