## Imports

In [None]:
import openai
import tiktoken

import pandas as pd
import json

from dateutil import parser as date_parser
from unidecode import unidecode

import os
import time
import re

In [None]:
# Prefer the OPENAI_API_KEY environment variable so a real key never has to be
# pasted into (and committed with) the notebook. The placeholder remains the
# fallback, so replacing it by hand still works as before.
openai.api_key = os.environ.get("OPENAI_API_KEY", "<your_open_api_key>")

## Table Generator

In [None]:
class TableGenerator_JSON():
    """Build "list all ..." prompts that request a JSON table and parse the replies.

    ``generate_prompt`` fills ``TEMPLATE`` with a table description, the
    normalised field names and one example row.  ``parse_llm_response`` turns
    the model's text reply into a ``pandas.DataFrame`` that contains only the
    requested fields.
    """

    # Placeholders, in order: table description, number of fields,
    # list of normalised field names, example row rendered as JSON members.
    TEMPLATE = """
    You are a retriever of facts.
    List all %s.
    The response will be formatted as JSON shown below.
    Each element of the response will contain %d fields: %s.
    Do not output any additional text that is not in JSON format.
        
    RESPONSE FORMAT:
    [{
        %s
    }]
    """

    # Single-pass character mapping used by _norm_field: separators become
    # "_" and punctuation is dropped (None means "delete the character").
    _FIELD_CHARS = str.maketrans({
        " ": "_", "-": "_", ",": "_",
        ".": None, "(": None, ")": None, ":": None,
        '"': None, "'": None, "/": None,
    })

    def _norm_field(self, s):
        """Return ``s`` lower-cased with separators as single underscores and punctuation removed."""
        return re.sub(r'_+', '_', s.lower().translate(self._FIELD_CHARS))

    def generate_prompt(self, table_description, fields, example):
        """Render ``TEMPLATE`` for one table.

        Args:
            table_description: Text inserted after "List all".
            fields: Column names; they are normalised with ``_norm_field``.
            example: Values of one example row, paired positionally with
                ``fields`` (extra items on either side are ignored by ``zip``).

        Returns:
            The prompt string.
        """
        num_fields = len(fields)
        fields = [self._norm_field(f) for f in fields]
        response_format = ', '.join(
            '"%s": "%s"' % (field, value) for field, value in zip(fields, example)
        )
        return self.TEMPLATE % (table_description, num_fields, fields, response_format)

    def _trim_to_json(self, response):
        """Strip text around the outermost ``[...]``; wrap a bare ``{...}`` in a list."""
        if not response.startswith("[") and "[" in response:
            response = response[response.find("["):]

        if not response.endswith("]") and "]" in response:
            response = response[:response.rfind("]") + 1]

        if '[' not in response and ']' not in response and '{' in response and '}' in response:
            response = '[' + response + ']'

        return response

    def _parse_loose(self, text):
        """Best-effort parse of ``{key: value, ...}`` groups from text that is not valid JSON.

        Values that themselves contain "," , "{" or "}" are still cut short;
        this is only a fallback for malformed replies.
        """
        rows = []
        for chunk in text.split("{")[1:]:
            body, closer, _ = chunk.partition("}")
            if not closer:
                continue  # object was never closed (e.g. truncated reply)
            row = {}
            for attr in body.split(","):
                # Split on the first ":" only so values such as "12:30"
                # keep everything after the key.
                key, sep, value = attr.partition(":")
                if sep:
                    row[key.replace('"', '').strip()] = value.replace('"', '').strip()
            rows.append(row)
        return rows

    def parse_llm_response(self, response, fields):
        """Convert the model's reply into a DataFrame restricted to ``fields``.

        The reply is first decoded as JSON (after trimming surrounding text);
        if that fails, a loose ``{...}`` parser is used instead.

        Args:
            response: Raw text returned by the model.
            fields: Column names; normalised with ``_norm_field`` before being
                matched against the keys found in the reply.

        Returns:
            A ``pandas.DataFrame`` with one row per parsed object that has at
            least one requested field.  Keys not in ``fields`` are dropped.
            The frame is empty when nothing usable could be parsed.
        """
        wanted = {self._norm_field(f) for f in fields}

        response = self._trim_to_json(response)
        try:
            rows = json.loads(response)
        except ValueError:  # includes json.JSONDecodeError
            rows = self._parse_loose(response)
        else:
            # Unwrap a {"some_key": [...]} envelope around the row list.
            if isinstance(rows, dict) and len(rows) == 1:
                inner = next(iter(rows.values()))
                if isinstance(inner, list):
                    rows = inner

        if isinstance(rows, dict):
            rows = [rows]  # a single object is treated as a one-row table
        elif not isinstance(rows, list):
            rows = []  # scalar JSON payloads carry no rows

        records = []
        for row in rows:
            if not isinstance(row, dict):
                continue
            record = {k: v for k, v in row.items() if k in wanted}
            # Emit each row exactly once, and only if it has a requested field.
            if record:
                records.append(record)

        return pd.DataFrame.from_records(records)

## Experiment Runner

In [None]:
class ExperimentRunner():
    """Run table-generation tasks against an OpenAI completion model and store the results.

    For every task the runner builds a prompt from the first row of the
    reference CSV, queries the model, parses the reply into a DataFrame,
    writes it under ``<result_folder>/tables`` and records the prompt,
    response and table path in ``self.result``.
    """

    MODEL = "gpt-3.5-turbo-instruct-0914"
    NOTE = "full_table_first_example"
    # Token budget shared by prompt and completion
    # (presumably the model's context window; verify when changing MODEL).
    MAX_LEN = 4097

    def __init__(self, table_generator, metadata_path):
        """Load the task metadata and create a fresh, timestamped result folder.

        Args:
            table_generator: Object providing ``generate_prompt`` and
                ``parse_llm_response``.
            metadata_path: Path of the JSON file describing the tasks.

        Raises:
            FileExistsError: If the result folder already exists.
        """
        with open(metadata_path, "rb") as f:
            self.metadata = json.load(f)

        self.table_generator = table_generator
        self.encoding = tiktoken.encoding_for_model(self.MODEL)

        self.result_folder = "%s_%s_%s" % (self.MODEL.replace('-', '_'),
                                           self.NOTE,
                                           time.strftime("%Y%m%d-%H%M%S"))

        print("Experiment result folder: %s" % self.result_folder)

        os.makedirs(self.result_folder)
        os.makedirs("%s/tables" % self.result_folder)

        self.result = {}

    def fetch_data(self, idx):
        """Generate, parse and save the table for the task stored under ``idx``.

        The task entry is read for the keys 'name', 'table_title', 'columns',
        'file' and 'keys'.  The reference CSV is loaded from
        ``../benchmark/tables/<file>`` and its first row is used as the
        example in the prompt.

        Args:
            idx: Key of the task in ``self.metadata``.

        Returns:
            The generated DataFrame, or ``None`` if querying, parsing or
            saving failed.  In that case the error is printed and stored
            under ``self.result[idx]['error']``.
        """
        task = self.metadata[idx]

        task_name = task['name']
        print("Fetching data for %s" % task_name)

        query, columns = task['table_title'], task['columns']

        df_ref = pd.read_csv(os.path.join("../benchmark/tables", task['file']))
        # Positional access: plain row[i] with an integer on a label-indexed
        # Series is deprecated in pandas 2.x.
        example = df_ref.iloc[0].tolist()

        prompt = self.table_generator.generate_prompt(query, columns, example)

        self.result[idx] = {'prompt': prompt}

        try:
            max_tokens = self.MAX_LEN - len(self.encoding.encode(prompt))
            if max_tokens <= 0:
                # Fail before spending an API call that cannot succeed.
                raise ValueError("prompt leaves no room for a completion "
                                 "(max_tokens=%d)" % max_tokens)

            result = openai.Completion.create(engine=self.MODEL, prompt=prompt, temperature=0, max_tokens=max_tokens)
            response = result['choices'][0]['text'].strip()

            self.result[idx].setdefault('response', []).append(response)

            df = self.table_generator.parse_llm_response(response, columns)

            # Raises (and is reported below) when the number of parsed columns
            # differs from the reference table.
            df.columns = df_ref.columns
            df = df.drop_duplicates(subset=task['keys'])

            table_path = "%s/tables/%s.csv" % (self.result_folder, task_name)
            self.result[idx]['table_path'] = table_path
            df.to_csv(table_path, index=False)

            print("Created table with %d rows" % len(df))

            return df
        except Exception as e:
            # Best effort: one failing task must not stop the experiment, but
            # keep the full message so the failure can be diagnosed later.
            error = "%s: %s" % (e.__class__.__name__, e)
            print(error)
            self.result[idx]['error'] = error
            return None

    def save_result(self):
        """Write ``self.result`` as indented JSON to ``<result_folder>/result.json``."""
        with open("%s/result.json" % self.result_folder, "w") as outfile:
            json.dump(self.result, outfile, indent=4)

## Test

In [None]:
# Number of benchmark tasks to run; tasks are looked up under the keys "0", "1", ...
NUM_TABLES = 100

tg = TableGenerator_JSON()

runner = ExperimentRunner(tg, metadata_path="../benchmark/cfg.json")

print("\n====================\n")

try:
    for i in range(NUM_TABLES):
        print("Table # %d" % (i+1))
        idx = "%d" % i
        table = runner.fetch_data(idx)
        print("\n====================\n")
finally:
    # Persist whatever was collected even if a task raises (e.g. a missing
    # metadata key or reference CSV), instead of losing the whole run.
    runner.save_result()