## Imports

In [None]:
import openai
import tiktoken

import pandas as pd
import json

from dateutil import parser as date_parser
from unidecode import unidecode

import os
import time
import re

In [None]:
# Prefer the OPENAI_API_KEY environment variable so the secret never has to be
# stored in the notebook; otherwise fall back to the placeholder, which must be
# replaced by hand before running the experiment.
openai.api_key = os.environ.get("OPENAI_API_KEY", "<your_open_api_key>")

## Utils

In [None]:
class ParsingError(Exception):
    """Raised when a cell response cannot be turned into a ``{field: value}`` dict."""


class WrongKeyError(Exception):
    """Raised for a generated row whose key tuple is absent from the reference table."""

## Table Generator

In [None]:
class TableGenerator_JSON():
    """Build LLM prompts for table generation and parse the JSON-like replies.

    ``KEYS_TEMPLATE`` asks the model to list every key (entity) of the table;
    ``CELL_TEMPLATE`` asks for one attribute of one row. Column names are
    normalised with ``_norm_field`` before they are put into prompts or used to
    look values up in parsed responses.
    """
    KEYS_TEMPLATE = """
    You are a retriever of facts.
    We want to create a table with the detailed information about %s.
    %s.
    List all %s entities for the table. 
    The response will be formatted as JSON list shown below.
    
    RESPONSE FORMAT:
    [{
        %s
    }]
    """

    CELL_TEMPLATE = """
    You are a retriever of facts.
    We want to create a table with the detailed information about %s.
    Columns in the table are %s.
    %s.  
    For the table row whose key is %s what is the value of attribute %s.
    The response will be formatted as JSON dictionary shown below.
    Pay special attention to wrap all property names and values in double quotes!
    
    RESPONSE FORMAT:
    {
        %s
    }
    """

    # Single-character substitutions applied by ``_norm_field`` after
    # lower-casing: separators become "_" and punctuation is dropped.
    _FIELD_TRANSLATION = str.maketrans({
        " ": "_", "-": "_", ",": "_",
        ".": None, "(": None, ")": None, ":": None,
        '"': None, "'": None, "/": None,
    })

    def _norm_field(self, s):
        """Return ``s`` as a lower-case snake_case identifier.

        Spaces, hyphens and commas become underscores; ``. ( ) : " ' /`` are
        removed; runs of underscores collapse into one.
        """
        s = s.lower().translate(self._FIELD_TRANSLATION)
        return re.sub('_+', '_', s)

    def _key_columns(self, keys):
        """Return the sentence naming the key column(s) for use in a prompt."""
        if len(keys) == 1:
            return "The key column in the table is %s" % keys[0]
        else:
            return "The key columns in the table are %s" % ", ".join(keys)

    def generate_keys_prompt(self, query, keys):
        """Return the prompt asking the model to list every key tuple of the table.

        Args:
            query: table title / topic inserted into the prompt.
            keys: raw key column names; they are normalised and used both in the
                prompt text and in the JSON response-format example.
        """
        keys = [self._norm_field(k) for k in keys]
        key_columns = self._key_columns(keys)
        response_format = ', '.join('"%s": "%s"' % (key, key) for key in keys)
        return self.KEYS_TEMPLATE % (query, key_columns, ", ".join(keys), response_format)

    def _crop_to_json_list(self, text):
        """Trim text around the outermost ``[...]``; wrap a bare ``{...}`` in brackets."""
        if not text.startswith("[") and "[" in text:
            text = text[text.find("["):]
        if not text.endswith("]") and "]" in text:
            text = text[:text.rfind("]") + 1]
        if '[' not in text and ']' not in text and '{' in text and '}' in text:
            text = '[' + text + ']'
        return text

    def _lenient_parse_objects(self, text):
        """Best-effort parse of ``{k: v, ...}`` fragments that are not valid JSON.

        Each object is split on commas and each pair at its FIRST colon, so values
        containing colons (times, URLs) are kept whole. Double quotes are stripped.
        Values containing commas are still split apart (inherent to this fallback).
        """
        objects = []
        for chunk in text.split("{")[1:]:
            parts = chunk.split("}")
            if len(parts) < 2:
                continue  # object never closed
            fields = {}
            for attr in parts[0].split(","):
                name, sep, value = attr.partition(":")
                if sep:
                    fields[name.replace('"', '').strip()] = value.replace('"', '').strip()
            objects.append(fields)
        return objects

    def parse_keys_response(self, response, keys):
        """Parse the key-listing response into a list of ``{norm_key: value}`` dicts.

        The response is cropped to its outermost ``[...]`` (or wrapped in one when
        only a bare ``{...}`` is present) and decoded as JSON; a one-entry dict is
        unwrapped to its value. If JSON decoding fails, the cropped text is parsed
        leniently by ``_lenient_parse_objects``.

        Args:
            response: raw model output.
            keys: raw key column names; each parsed item is reduced to these
                (normalised) keys.

        Raises:
            KeyError: if a parsed item lacks one of the requested key columns.
        """
        cropped = self._crop_to_json_list(response)
        try:
            response_json = json.loads(cropped)
            if isinstance(response_json, dict) and len(response_json) == 1:
                response_json = next(iter(response_json.values()))
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            response_json = self._lenient_parse_objects(cropped)

        norm_keys = [self._norm_field(key) for key in keys]
        return [{key: item[key] for key in norm_keys} for item in response_json]

    def generate_cell_prompt(self, query, keys, fields, field, fetched_key):
        """Return the prompt asking for attribute ``field`` of the row ``fetched_key``.

        Args:
            query: table title / topic inserted into the prompt.
            keys: raw key column names (normalised here).
            fields: raw names of all table columns (normalised here).
            field: raw name of the requested column.
            fetched_key: row dict indexed by *normalised* key names, as returned
                by ``parse_keys_response``.
        """
        keys = [self._norm_field(k) for k in keys]
        key_columns = self._key_columns(keys)
        all_columns = ", ".join(self._norm_field(f) for f in fields)

        # Describe the row by its key values, e.g. "(country = France)".
        row_key = '(%s)' % ', '.join("%s = %s" % (key, fetched_key[key]) for key in keys)

        field = self._norm_field(field)
        # NOTE(review): CELL_TEMPLATE already wraps this in braces, so the prompt shows
        # nested braces; kept unchanged so prompts match previously recorded runs.
        response_format = '{"%s": "value of %s"}' % (field, field)

        return self.CELL_TEMPLATE % (query, all_columns, key_columns, row_key, field, response_format)

    def parse_cell_response(self, response):
        """Parse a single-cell response into ``{field: value}``.

        The text between the first ``{`` and the last ``}`` is decoded as JSON. If
        that fails, it is read leniently as one ``key: value`` pair split at the
        first colon, so values such as ``10:30`` or URLs survive intact.

        Raises:
            ParsingError: if no ``key: value`` pair can be recovered.
        """
        cell_content = response[response.find("{"):response.rfind("}") + 1]
        try:
            return json.loads(cell_content)
        except ValueError:
            name, sep, value = cell_content[1:-1].partition(":")
            if not sep:
                raise ParsingError() from None
            return {name.replace('"', '').strip(): value.replace('"', '').strip()}

    def create_dataframe(self, rows, columns, keys, df_ref):
        """Assemble generated ``rows`` into a table labelled like ``df_ref``.

        Args:
            rows: list of dicts keyed by normalised column names.
            columns: raw column names, in output order.
            keys: raw key column names used to drop duplicate rows.
            df_ref: reference table; only its column labels are used.
        """
        norm_columns = [self._norm_field(col) for col in columns]
        if rows:
            df = pd.DataFrame.from_dict(rows)
        else:
            # An empty row list gives a frame with no columns, so the selection
            # below would raise KeyError; produce an empty table instead.
            df = pd.DataFrame(columns=norm_columns)
        df = df[norm_columns]
        # NOTE(review): labels are replaced positionally, which assumes ``columns``
        # lists the reference columns in the same order as ``df_ref`` — confirm.
        df.columns = df_ref.columns
        return df.drop_duplicates(subset=keys)

## Experiment Runner

In [None]:
class ExperimentRunner():
    """Generate tables cell by cell with an OpenAI completion model and save them.

    For each task, the model is first asked for the table's key tuples. Every
    distinct key tuple that also appears in the reference CSV then has each of
    its non-key cells requested with a separate prompt. Prompts, responses and
    output table paths are collected in ``self.result``.
    """
    MODEL = "gpt-3.5-turbo-instruct-0914"
    NOTE = 'cell_by_cell'
    # Token budget shared by prompt and completion; the completion may use
    # whatever the prompt leaves free.
    MAX_LEN = 4097

    def __init__(self, table_generator, metadata_path):
        """Load task metadata and create a timestamped result folder.

        Args:
            table_generator: prompt builder / response parser
                (e.g. ``TableGenerator_JSON``).
            metadata_path: JSON file whose top-level object is indexed by task id;
                presumably stringified integers ("0", "1", ...) — TODO confirm
                against the benchmark config.
        """
        with open(metadata_path, "rb") as f:
            self.metadata = json.load(f)

        self.table_generator = table_generator
        self.encoding = tiktoken.encoding_for_model(self.MODEL)

        self.result_folder = "%s_%s_%s" % (self.MODEL.replace('-', '_'),
                                           self.NOTE,
                                           time.strftime("%Y%m%d-%H%M%S"))

        print("Experiment result folder: %s" % self.result_folder)

        os.makedirs(self.result_folder)
        os.makedirs("%s/tables" % self.result_folder)

        self.result = {}

    def normalize_key(self, value, is_date=False):
        """Canonicalise a key value so generated and reference keys can be compared.

        NaN and the placeholders 'none', 'n/a', 'nan', '-' map to ''. With
        ``is_date``, the value is first parsed as a date and returned as the
        ``str`` of the parsed datetime; unparseable values fall through to the
        generic handling. Strings are lower-cased, '&' becomes 'and', 'united
        states' / 'united kingdom' become 'usa' / 'uk', accents are transliterated
        and every non-alphanumeric character is dropped. Other values are
        returned as ``str(value)``.
        """
        if value != value:  # catches NaN, which is not equal to itself
            return ''

        if is_date:
            # Best-effort: any parsing failure just falls through to the generic
            # normalisation below. Parse the text form so that a year read as an
            # int from the CSV matches the same year given as a string.
            # NOTE(review): dateutil fills missing date parts (e.g. for a bare
            # year) from the current date by default — confirm this is acceptable.
            try:
                return str(date_parser.parse(str(value)))
            except Exception:
                try:
                    return str(pd.to_datetime(value))
                except Exception:
                    pass

        if isinstance(value, str):
            value = value.lower()

            if value in ('none', 'n/a', 'nan', '-'):
                return ''

            value = value.replace('&', 'and')

            country_aliases = {'united states': 'usa', 'united kingdom': 'uk'}
            if value in country_aliases:
                return country_aliases[value]

            value = unidecode(value)
            return ''.join(c for c in value if c.isalnum())

        return str(value)

    def normalize_primary_columns(self, df, norm_columns, date_columns, primary_columns):
        """Normalise ``norm_columns`` of ``df`` in place and return its key tuples.

        Args:
            df: reference table; the listed columns are overwritten.
            norm_columns: columns to pass through ``normalize_key``.
            date_columns: columns to normalise as dates.
            primary_columns: key columns whose values form each returned tuple.

        Returns:
            One tuple of normalised key values per row of ``df``.
        """
        for col in norm_columns:
            # ``is_date`` must be passed by keyword: the second positional parameter
            # of ``Series.apply`` is not forwarded to the function.
            df[col] = df[col].apply(self.normalize_key, is_date=col in date_columns)
        return [tuple(r) for r in df[primary_columns].to_numpy()]

    def _complete(self, prompt):
        """Send ``prompt`` to ``MODEL`` at temperature 0 and return the stripped text."""
        max_tokens = self.MAX_LEN - len(self.encoding.encode(prompt))
        result = openai.Completion.create(engine=self.MODEL, prompt=prompt,
                                          temperature=0, max_tokens=max_tokens)
        return result['choices'][0]['text'].strip()

    def fetch_data(self, idx):
        """Generate, save and return the table for task ``idx``.

        Args:
            idx: key of the task in ``self.metadata``.

        Returns:
            The generated ``DataFrame``, or ``None`` if any step after building the
            keys prompt fails (the exception class name is printed instead).

        Raises:
            KeyError: if ``idx`` or a required task field is missing from the
                metadata (these lookups happen before the guarded section).
        """
        task = self.metadata[idx]

        task_name = task['name']
        print("Fetching data for %s" % task_name)

        query = task['table_title']
        keys = task['keys']
        columns = task['columns']
        date_columns = task['dateColumns']

        keys_prompt = self.table_generator.generate_keys_prompt(query, keys)
        self.result[idx] = {'keys_prompt': keys_prompt}

        try:
            keys_response = self._complete(keys_prompt)
            self.result[idx]['keys_response'] = [keys_response]

            parsed_keys_response = self.table_generator.parse_keys_response(keys_response, keys)

            print("Fetched %d key instances" % len(parsed_keys_response))

            self.result[idx]['cell_prompts'] = []
            self.result[idx]['cell_responses'] = []
            rows = []

            df_ref = pd.read_csv(task['path'])
            # A set gives O(1) membership tests in the loop below. Normalising
            # df_ref's values in place is harmless: only its column labels are
            # used afterwards.
            ref_entities = set(self.normalize_primary_columns(df_ref, columns, date_columns, keys))

            norm_keys = [self.table_generator._norm_field(k) for k in keys]
            norm_date_cols = [self.table_generator._norm_field(c) for c in date_columns]

            keys_already_checked = set()

            for key_instance in parsed_keys_response:
                keys_tuple = tuple(self.normalize_key(key_instance[nk], nk in norm_date_cols)
                                   for nk in norm_keys)

                # Query each distinct key only once.
                if keys_tuple in keys_already_checked:
                    continue
                keys_already_checked.add(keys_tuple)
                is_known_key = keys_tuple in ref_entities

                row = key_instance.copy()
                for col in columns:
                    if col in keys:
                        continue
                    try:
                        # Cells of keys absent from the reference table are not
                        # requested; they are recorded as WrongKeyError instead.
                        if not is_known_key:
                            raise WrongKeyError()

                        cell_prompt_i = self.table_generator.generate_cell_prompt(query, keys, columns, col, key_instance)
                        self.result[idx]['cell_prompts'].append(cell_prompt_i)

                        cell_response = self._complete(cell_prompt_i)
                        self.result[idx]['cell_responses'].append(cell_response)

                        row.update(self.table_generator.parse_cell_response(cell_response))
                    except Exception as ie:
                        # Store the failure reason in the cell so the row is still kept.
                        print(ie.__class__.__name__)
                        row[self.table_generator._norm_field(col)] = ie.__class__.__name__
                rows.append(row)

            df = self.table_generator.create_dataframe(rows, columns, keys, df_ref)

            table_path = "%s/tables/%s.csv" % (self.result_folder, task_name)
            self.result[idx]['table_path'] = table_path
            df.to_csv(table_path, index=False)

            print("Created table with %d rows" % len(df))

            return df
        except Exception as e:
            print(e.__class__.__name__)
            return None

    def save_result(self):
        """Write ``self.result`` as indented JSON to ``<result_folder>/result.json``."""
        with open("%s/result.json" % self.result_folder, "w") as outfile:
            outfile.write(json.dumps(self.result, indent=4))

## Test

In [None]:
# Number of benchmark tasks to run; tasks are addressed by the metadata keys
# "0", "1", ..., str(NUM_TABLES - 1).
NUM_TABLES = 100

tg = TableGenerator_JSON()

runner = ExperimentRunner(tg, metadata_path="../benchmark/cfg.json")

print("\n====================\n")

try:
    for i in range(NUM_TABLES):
        print("Table # %d" % (i + 1))
        idx = str(i)
        table = runner.fetch_data(idx)
        print("\n====================\n")
finally:
    # Persist the prompts/responses collected so far even if the loop aborts,
    # e.g. on a KeyError for a missing task, which fetch_data does not catch.
    runner.save_result()