## Imports

In [1]:
import openai
import tiktoken

import pandas as pd
import json

from dateutil import parser as date_parser
from unidecode import unidecode

import os
import time
import re

## Table Generator

In [2]:
class TableGenerator_JSON():
    """Build "list all ..." prompts and turn the model's JSON answer into a table.

    ``generate_prompt`` fills ``TEMPLATE`` with a table description, the
    normalized field names and one example row. ``parse_llm_response`` accepts
    the raw model text, extracts the JSON array from it (falling back to a
    lenient hand-rolled parser when the text is not valid JSON) and returns a
    ``pandas.DataFrame`` restricted to the requested fields.
    """

    # Prompt skeleton; the placeholders are, in order: table description,
    # number of fields, list of normalized field names, example row rendered
    # as `"field": "value"` pairs.
    TEMPLATE = """
    You are a retriever of facts.
    List all %s.
    The response will be formatted as JSON shown below.
    Each element of the response will contain %d fields: %s.
    Do not output any additional text that is not in JSON format.
        
    RESPONSE FORMAT:
    [{
        %s
    }]
    """

    # Single-pass character mapping used by `_norm_field`: separators become
    # underscores, punctuation is dropped.
    _FIELD_CHAR_MAP = str.maketrans({
        " ": "_", "-": "_", ",": "_",
        ".": None, "(": None, ")": None, ":": None,
        '"': None, "'": None, "/": None,
    })

    def _norm_field(self, s):
        """Return ``s`` lower-cased with separators as single underscores.

        Spaces, hyphens and commas become ``_``; ``. ( ) : " ' /`` are removed;
        runs of underscores are collapsed to one.
        """
        return re.sub('_+', '_', s.lower().translate(self._FIELD_CHAR_MAP))

    def generate_prompt(self, table_description, fields, example):
        """Render ``TEMPLATE`` for one table.

        Args:
            table_description: Text inserted after "List all".
            fields: Column names; they are normalized with ``_norm_field``.
            example: Values of one example row, paired positionally with
                ``fields`` (extra items on either side are ignored by ``zip``).

        Returns:
            The prompt string.
        """
        norm_fields = [self._norm_field(f) for f in fields]
        response_format = ', '.join(
            '"%s": "%s"' % (field, value)
            for field, value in zip(norm_fields, example)
        )
        # `norm_fields` is formatted with %s, i.e. as a Python list literal.
        return self.TEMPLATE % (
            table_description, len(fields), norm_fields, response_format
        )

    @staticmethod
    def _trim_to_json(response):
        """Strip text around the outermost ``[...]``; wrap a bare object in ``[]``."""
        if not response.startswith("[") and "[" in response:
            response = response[response.find("["):]

        if not response.endswith("]") and "]" in response:
            response = response[:response.rfind("]") + 1]

        has_brackets = '[' in response or ']' in response
        if not has_brackets and '{' in response and '}' in response:
            response = '[' + response + ']'
        return response

    @staticmethod
    def _salvage_records(response):
        """Leniently parse ``{k: v, ...}`` groups out of text that is not valid JSON.

        Each ``{...}`` group becomes one dict; attributes are split on commas
        and each attribute on its first colon, with double quotes stripped.
        Values that themselves contain commas are still split (known limit of
        this best-effort fallback).
        """
        records = []
        for chunk in response.split("{")[1:]:
            parts = chunk.split("}")
            if len(parts) <= 1:
                # No closing brace: the object was cut off, skip it.
                continue
            record = {}
            for attr in parts[0].split(","):
                # Split on the first colon only so values such as "10:30"
                # are kept whole.
                key, sep, value = attr.partition(":")
                if sep:
                    record[key.replace('"', '').strip()] = value.replace('"', '').strip()
            records.append(record)
        return records

    def parse_llm_response(self, response, fields):
        """Convert the model's answer into a DataFrame with the requested fields.

        Args:
            response: Raw text returned by the model.
            fields: Column names; normalized with ``_norm_field`` before being
                matched against the keys of each parsed record.

        Returns:
            A ``pandas.DataFrame`` with one row per parsed record, containing
            only keys that match a normalized field name.
        """
        allowed = {self._norm_field(f) for f in fields}

        cleaned = self._trim_to_json(response)
        try:
            records = json.loads(cleaned)
            # Unwrap answers shaped like {"results": [...]}.
            if isinstance(records, dict) and len(records) == 1:
                records = next(iter(records.values()))
        except ValueError:
            # json.JSONDecodeError is a ValueError: fall back to lenient parsing.
            records = self._salvage_records(cleaned)

        if isinstance(records, dict):
            records = [records]
        elif not isinstance(records, list):
            records = []

        # Exactly one output row per parsed record (rows that are not JSON
        # objects carry no fields and are skipped).
        rows = [
            {k: v for k, v in row.items() if k in allowed}
            for row in records
            if isinstance(row, dict)
        ]
        return pd.DataFrame.from_records(rows)

## Experiment Runner

In [3]:
class ExperimentRunner():
    """Run table-generation tasks against an OpenAI-compatible chat endpoint.

    For each task in the metadata file the runner builds a prompt from the
    task's reference CSV, queries ``MODEL``, parses the answer into a
    DataFrame, writes it to ``<result_folder>/tables/<task name>.csv`` and
    keeps the prompt/response (and any error) in ``self.result``.
    """

    # NOTE: these two assignments execute when the class body is evaluated and
    # configure the `openai` module globally.
    # The key is empty in source; presumably it is filled in before running.
    openai.api_key = ""
    openai.api_base = "https://api.deepinfra.com/v1/openai"
    MODEL = "meta-llama/Meta-Llama-3.1-70B-Instruct"
    # Free-form tag that becomes part of the result folder name.
    NOTE = 'full_table_first_example'
    # Token budget: the completion gets MAX_LEN minus the prompt's token count.
    MAX_LEN = 3900

    def __init__(self, table_generator, metadata_path):
        """Load task metadata and create a fresh, timestamped result folder.

        Args:
            table_generator: Object providing ``generate_prompt`` and
                ``parse_llm_response``.
            metadata_path: Path to a JSON file; each entry is read in
                ``fetch_data`` for the keys ``name``, ``table_title``,
                ``columns``, ``path`` and ``keys``.

        Raises:
            FileExistsError: If the result folder already exists.
        """
        with open(metadata_path, "rb") as f:
            self.metadata = json.load(f)

        self.table_generator = table_generator
        # Used only to count prompt tokens.
        # NOTE(review): cl100k_base is an OpenAI encoding, so the count for
        # MODEL is presumably an approximation — confirm.
        self.encoding = tiktoken.get_encoding("cl100k_base")

        model_tag = self.MODEL.split("/")[-1].replace('-', '_')
        timestamp = time.strftime("%Y%m%d-%H%M%S")
        self.result_folder = "DATA/%s_%s_%s" % (model_tag, self.NOTE, timestamp)

        print("Experiment result folder: %s" % self.result_folder)

        os.makedirs(self.result_folder)
        os.makedirs("%s/tables" % self.result_folder)

        # idx -> {'prompt', 'response', 'table_path', 'error'} (keys present
        # depending on how far the task got).
        self.result = {}

    def fetch_data(self, idx):
        """Generate, parse and save the table for metadata entry ``idx``.

        Args:
            idx: Key into ``self.metadata``.

        Returns:
            The generated DataFrame, or ``None`` when the request, parsing or
            saving failed (the failure is printed and stored under
            ``self.result[idx]['error']``).
        """
        task = self.metadata[idx]

        task_name = task['name']
        print("Fetching data for %s" % task_name)

        query, columns = task['table_title'], task['columns']

        df_ref = pd.read_csv(task['path'])
        row_ref = df_ref.iloc[0]
        # The first reference row is the example shown to the model.
        # `.iloc` makes the positional access explicit; integer keys on a
        # label-indexed Series are deprecated in recent pandas.
        example = [row_ref.iloc[i] for i in range(df_ref.shape[1])]

        prompt = self.table_generator.generate_prompt(query, columns, example)

        self.result[idx] = {'prompt': prompt}

        try:
            max_tokens = self.MAX_LEN - len(self.encoding.encode(prompt))
            result = openai.ChatCompletion.create(
                model=self.MODEL,
                messages=[{"role": "user", "content": prompt}],
                temperature=0,
                max_tokens=max_tokens
            )
            response = result["choices"][0]["message"]["content"].strip()

            # The entry was just created above, so this is always the first
            # (and only) response for `idx`.
            self.result[idx]['response'] = [response]

            df = self.table_generator.parse_llm_response(response, columns)

            # Raises ValueError when the parsed table does not have exactly
            # the reference columns; handled below like any other failure.
            df.columns = df_ref.columns
            df = df.drop_duplicates(subset=task['keys'])

            table_path = "%s/tables/%s.csv" % (self.result_folder, task_name)
            self.result[idx]['table_path'] = table_path
            df.to_csv(table_path, index=False)

            print("Created table with %d rows" % len(df))

            return df
        except Exception as e:
            # Best effort: one failing task must not stop the whole run, but
            # keep the reason so it ends up in result.json.
            error = "%s: %s" % (e.__class__.__name__, e)
            self.result[idx]['error'] = error
            print(error)
            return None

    def save_result(self):
        """Write ``self.result`` as indented JSON to ``<result_folder>/result.json``."""
        with open("%s/result.json" % self.result_folder, "w") as outfile:
            json.dump(self.result, outfile, indent=4)

## Test

In [4]:
# Run the first 100 benchmark tasks (metadata keys "0" .. "99") and persist
# the collected prompts/responses once all of them have been attempted.
tg = TableGenerator_JSON()
runner = ExperimentRunner(tg, metadata_path="DATA/benchmark/cfg.json")

separator = "\n====================\n"
print(separator)

for table_number in range(1, 101):
    print("Table # %d" % table_number)
    # Metadata is keyed by the zero-based index rendered as a string.
    idx = str(table_number - 1)
    table = runner.fetch_data(idx)
    print(separator)

runner.save_result()

Experiment result folder: DATA/Meta_Llama_3.1_70B_Instruct_full_table_first_example_20240922-100716


Table # 1
Fetching data for republican_straw_polls_2012
Created table with 40 rows


Table # 2
Fetching data for russia_demographics_1946_2012
Created table with 21 rows


Table # 3
Fetching data for belgium_demographics_1900_2011
Created table with 33 rows


Table # 4
Fetching data for australia_demographics_1900_2010
Created table with 33 rows


Table # 5
Fetching data for new_brunswick_parishes_2006_2011
Created table with 55 rows


Table # 6
Fetching data for ice_hockey_2006
Created table with 75 rows


Table # 7
Fetching data for biathlon_sprint_standings_2009_10
Created table with 10 rows


Table # 8
Fetching data for anaheim_ducks_draft_picks_1998_2013
Created table with 66 rows


Table # 9
Fetching data for south_african_class_15f_4_8_2
Created table with 121 rows


Table # 10
Fetching data for tour_de_france_2009
Created table with 108 rows


Table # 11
Fetching data for men_b

Created table with 20 rows


Table # 76
Fetching data for academy_award_best_actress_2000s
Created table with 50 rows


Table # 77
Fetching data for mens_walking_20km_record_1911_2007
Created table with 30 rows


Table # 78
Fetching data for booknotes_interviews_1996
Created table with 36 rows


Table # 79
Fetching data for bafta_best_actor_leading_role_2000s
Created table with 10 rows


Table # 80
Fetching data for troublemaker_song_release_history
Created table with 6 rows


Table # 81
Fetching data for wind_power_kansas_2001_2011
Created table with 11 rows


Table # 82
Fetching data for fljotsdalshreppur_population_1998_2011
Created table with 14 rows


Table # 83
Fetching data for social_credit_party_1951_1984
Created table with 12 rows


Table # 84
Fetching data for soekarno_hatta_airport_2001_2012
Created table with 12 rows


Table # 85
Fetching data for ulrike_maier_season_standings_1985_1994
Created table with 10 rows


Table # 86
Fetching data for black_dog_barking_charts
Crea