In [1]:
import pandas as pd
import numpy as np
import json
import os
import glob

import pickle
from pydantic import BaseModel
from openai import OpenAI

from tenacity import (
                        retry,
                        stop_after_attempt,
                        wait_random_exponential
)

from tqdm import tqdm

In [2]:
pd.set_option('display.max_rows', None)

In [3]:
path = '/home/manoelflorencio/cta_for_jd/LakeBench'
os.chdir(path)
print(os.getcwd())

/home/manoelflorencio/cta_for_jd/LakeBench


In [4]:
@retry(wait=wait_random_exponential(min=1,max=60), stop=stop_after_attempt(6))
def execute_prompt(client, system_msg, user_msg):
    completion = client.chat.completions.create(
                                            model="gpt-4o",
                                            messages=[
                                                        {
                                                            "role": "system", 
                                                             "content": f"{system_msg}"
                                                        },
                                                        {
                                                            "role": "user",
                                                            "content": f"{user_msg}"
                                                        }
                                                    ]
                                            )
    return completion

In [5]:
client = OpenAI()

In [6]:
prompts = np.load('Table_JOIN_Prompts/table_JD_with_columns_SG_CSV0000000000000925.npy')
prompts.shape

(1255, 2)

In [7]:
table_cartesians = pd.read_csv('table_cartesians.csv')
table_cartesians = table_cartesians[(table_cartesians['LEFT_TABLE'] == 'SG_CSV0000000000000925.csv') | (table_cartesians['RIGHT_TABLE'] == 'SG_CSV0000000000000925.csv')]
table_cartesians.shape

(1255, 2)

In [8]:
def generate_predictions(system_msg, user_msg, client):
    
    result = execute_prompt(client, system_msg, user_msg)
    jd_prediction = result.choices[0].message.content.split('Answer: ')[-1].strip()
    
    return jd_prediction

In [9]:
prompts[0][1]

'Target Table description: The table SG_CSV0000000000001178.csv presents information on the financial years, types of facilities, and their respective categories, either sold or rented, from 2006 to 2008, covering various types of community and service centers such as childcare centers, education centers, social service centers, and medical service centers.\n                  Target Table columns descriptions: [\'The "financial_year" column contains values indicating the year associated with financial data, reflecting the period during which transactions or records were categorized, specifically spanning various years such as 2006 and 2007.\'\n "The \'facility_type\' column describes the specific type of facility involved in each entry, such as childcare centres, education centres, social service centres, or other similar community and service-oriented establishments."\n "The \'category\' column describes the status of a facility in terms of its occupancy, indicating whether the facili

In [11]:
answers = np.array([])
# answers = np.load('answers.npy')
cont, step = len(answers), 500

for i in tqdm(range(cont,len(prompts))):

    system_msg, user_msg = prompts[i][0], prompts[i][1]
    result = generate_predictions(system_msg, user_msg, client)
    answers = np.append(answers, result)
    
    cont += 1
    if(cont % step):
        np.save('answersSG_CSV0000000000000925.npy', answers)
np.save('answersSG_CSV0000000000000925.npy', answers)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1255/1255 [32:01<00:00,  1.53s/it]


In [10]:
answers = np.load('answers.npy')

In [18]:
len(answers)

10000

In [12]:
table_cartesians = table_cartesians.iloc[:10000,:]
table_cartesians['JOINABLE'] = answers

In [13]:
table_cartesians

Unnamed: 0,LEFT_TABLE,RIGHT_TABLE,JOINABLE
967,SG_CSV0000000000001178.csv,SG_CSV0000000000000925.csv,No
2221,SG_CSV0000000000000451.csv,SG_CSV0000000000000925.csv,No
3474,SG_CSV0000000000000147.csv,SG_CSV0000000000000925.csv,No
4726,SG_CSV0000000000000048.csv,SG_CSV0000000000000925.csv,No
5977,SG_CSV0000000000001638.csv,SG_CSV0000000000000925.csv,No
7227,SG_CSV0000000000000744.csv,SG_CSV0000000000000925.csv,No
8476,SG_CSV0000000000001161.csv,SG_CSV0000000000000925.csv,No
9724,SG_CSV0000000000001130.csv,SG_CSV0000000000000925.csv,No
10971,SG_CSV0000000000001358.csv,SG_CSV0000000000000925.csv,No
12217,SG_CSV0000000000000270.csv,SG_CSV0000000000000925.csv,No


In [14]:
table_cartesians.to_csv('table_cartesians_JD_with_columns_descriptions_SG_CSV0000000000000925.csv', index=False)