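# Testing notebook

Explores the ALeRCE text-to-SQL datasets (the training CSV and the use-case Excel sheet), runs a gold SQL query against the database, and compares it with the query produced by the LLM pipeline.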

## Working with the CSV and Excel datasets

In [1]:
# Exploring the 'training' dataset: load the text-to-SQL training split
# and preview its first rows
import pandas as pd

df = pd.read_csv(filepath_or_buffer="txt2sql_alerce_train_v2.csv")
df.head(n=5)

Unnamed: 0,req_id,request,table_info,external_knowledge,domain_knowledge,gold_query,difficulty,type,nested_type,rephrased_request,rephrased_request_gpt-3.5-turbo-0125_t0.4,rephrased_request_gpt-4o-2024-05-13_t0.2
0,13,Give me all the SNe that were first detected b...,"['object', 'probability']",\n-- mjd date for December = 59914.0\n-- mjd d...,\n-- Super Nova (SNe) is a large explosion tha...,"\nSELECT\n object.oid, probability.class_na...",simple,object,none,,,
1,10,Get the object identifiers and probabilities i...,"['probability', 'object']",0,0,"\nSELECT\n sq1.oid, sq1.probability as SN_pro...",medium,object,simple,,,
2,15,"Get the object identifiers, probabilities in t...","['object', 'probability', 'detection', 'magstat']",\n-- mjd date for September 01 = 60188.0\n-- m...,\n-- A fast riser is defined as an object whos...,"\nSELECT\n sq.oid, sq.probability, sq.candi...",advanced,other,multi,,,
3,4,"Get the object identifier, candidate identifie...","['object', 'probability', 'magstat', 'detection']",\n-- mjd date for the start of the year 2019 =...,0,"\nSELECT\n sq.oid, sq.fid, sq.dmdt_first,\n ...",advanced,other,multi,,,
4,25,Query objects within 10 degress of the next po...,"['probability', 'object']",\n-- mjd date for February 01 = 59976.0\n-- mj...,0,"\nWITH catalog ( source_id, ra, dec) AS (\n ...",advanced,spatial,simple,,,


In [2]:
# Reading the excel file: load the use-case sheet and drop the two
# "Unnamed" columns (presumably index columns saved by earlier exports)
excel = pd.read_excel(
    "SQLusecases_alerce.xlsx",
    sheet_name="examples_alerce_usecasesV3_1",
).drop(columns=["Unnamed: 0.1", "Unnamed: 0"])
excel.head()

Unnamed: 0,req_id,request,table_info,external_knowledge,domain_knowledge,gold_query,difficulty,type,nested_type,Set,python_format
0,0,Get objects that are likely to be YSOs (possib...,"['probability', 'feature']",\n-- feature.name can be 'Multiband_period'\n-...,\n-- Multiband_period: Period obtained using t...,"\nSELECT\n oid, probability, value, name, fid...",advanced,object,tree,Train,"sub_query_1='''\nSELECT\n feature.oid, prob_o..."
1,1,Get all the objects classified as AGN with a p...,"['object', 'probability', 'feature', 'magstat']",\n-- object.ndet represents the number of dete...,\n-- Amplitude: Half of the difference between...,"\nSELECT\n sq.oid, sq.value, sq.name, sq.fid ...",advanced,object,tree,Test,\nsub_query_object='''\nSELECT\n object.oid...
2,2,Give me the objects classified as YSO by their...,['probability'],,,"\nSELECT\n oid, probability\nFROM\n prob...",simple,object,none,Train,"query='''\nSELECT\n oid, probability\nFROM\..."
3,3,Give me the objects classified as YSO by the l...,"['object', 'probability']","\n-- last june in mjd date: [start=60096.0, en...",,\nSELECT\n *\nFROM\n probability\nWHERE\...,simple,object,simple,Train,query=f'''\nSELECT\n *\nFROM\n probabili...
4,4,"Get the object identifier, candidate identifie...","['object', 'probability', 'magstat', 'detection']",\n-- mjd date for the start of the year 2019 =...,,"\nSELECT\n sq.oid, sq.fid, sq.dmdt_first,\n ...",advanced,other,multi,Train,# objects classified as SN II with probability...


## Tests

### Gold values for the query under test

In [8]:
# Select a query: the natural-language request whose gold answer is tested below
query = "Get the object identifier, candidate identifier, magnitudes, magnitude errors, and band identifiers as a function of time of the objects classified as SN II in the year 2019-2022, with probability larger than 0.6, initial rise rate greater than 0.5 mag/day in ZTF g and r-band and number of detections greater than 50."

# Look up the matching Excel row once instead of re-filtering for every column
gold_row = excel.loc[excel["request"] == query]

# `.item()` would also fail on zero or duplicate matches, but with an opaque
# "can only convert an array of size 1" message; say what actually went wrong.
if len(gold_row) != 1:
    raise ValueError(
        f"Expected exactly one Excel row for the selected request, found {len(gold_row)}"
    )

# Obtain the gold SQL query and Python query
sql_gold = gold_row["gold_query"].item()
python_gold = gold_row["python_format"].item()

# Obtain the necessary tables (stored as text in the sheet, e.g. "['object', ...]")
gold_tables = gold_row["table_info"].item()

# Print all in orderly fashion
print("Gold values\n")
print("Tables needed for the query:")
print(gold_tables + "\n")
print("SQL gold query:")
print(sql_gold)
print("Python gold query:\n")
print(python_gold)

Gold values

Tables needed for the query:
['object', 'probability', 'magstat', 'detection']

SQL gold query:

SELECT
  sq.oid, sq.fid, sq.dmdt_first,
  detection.candid, detection.fid as f_id,detection.magpsf, detection.sigmapsf_corr, detection.sigmapsf_corr_ext
FROM
  (
SELECT
  magstat.oid, magstat.fid, magstat.dmdt_first
FROM
  (
SELECT
    object.oid
FROM
    object INNER JOIN
    probability
    ON object.oid = probability.oid
WHERE
    probability.classifier_name='lc_classifier'
    AND probability.class_name='SNII'
    AND probability.probability > 0.6
    AND object.ndet > 50
) as obj_oids
    INNER JOIN
    magstat ON magstat.oid = obj_oids.oid
WHERE
  magstat.dmdt_first < -0.5
  AND (magstat.fid = 1 OR magstat.fid = 2)
) AS sq
  INNER JOIN detection
  ON sq.oid = detection.oid
WHERE
  detection.mjd > 58484.0
  AND detection.mjd < 59944.0
ORDER BY oid

Python gold query:

# objects classified as SN II with probability larger than 0.6 and with at least 50 detections
sub_query_o

In [9]:
# Running the gold query against the database
from secret.config import SQL_URL
import requests
import sqlalchemy as sa

# Setup params for query engine. Credentials come from the SQL_URL endpoint;
# fail fast on an unreachable or erroring endpoint instead of hanging or
# failing later with a confusing KeyError/JSON error.
response = requests.get(SQL_URL, timeout=30)
response.raise_for_status()
params = response.json()['params']

# URL.create escapes special characters in the password (e.g. '@', ':', '/'),
# which a hand-built f-string URL would not.
db_url = sa.engine.URL.create(
    drivername="postgresql+psycopg2",
    username=params['user'],
    password=params['password'],
    host=params['host'],
    database=params['dbname'],
)
engine = sa.create_engine(db_url)

# No bare `engine.begin()` here: outside a `with` block it either leaves a
# connection/transaction open (SQLAlchemy < 2.0) or does nothing (2.0).
# pandas manages the connection it opens from the engine.
pd.read_sql_query(sql_gold, con=engine)

Unnamed: 0,oid,fid,dmdt_first,candid,f_id,magpsf,sigmapsf_corr,sigmapsf_corr_ext
0,ZTF19aapafit,2,-0.558783,829383361815015003,1,17.737854,,
1,ZTF19aapafit,2,-0.558783,921321681815015003,1,18.644684,,
2,ZTF19aapafit,2,-0.558783,962155601815015013,1,20.914700,,
3,ZTF19aapafit,2,-0.558783,851385051815015002,2,17.666134,,
4,ZTF19aapafit,2,-0.558783,953181561815015001,2,19.305613,,
...,...,...,...,...,...,...,...,...
1532,ZTF22aavpkwo,1,-1.033192,2071244590015015007,1,19.303177,0.087055,0.088534
1533,ZTF22aavpkwo,1,-1.033192,2071181600015015005,2,18.701773,0.060684,0.062577
1534,ZTF22aavpkwo,1,-1.033192,2069173620015015008,2,18.709686,0.064962,0.066756
1535,ZTF22aavpkwo,1,-1.033192,2066178460015015011,1,19.141783,0.084445,0.085640


### Trying out the pipelines

In [11]:
import pandas as pd
from pathlib import Path
from pprint import pprint
# NOTE: this rebinds `engine`, replacing the one created in the gold-query cell.
# (Presumably a SQLAlchemy Engine as well; no bare `engine.begin()` is needed —
# outside a `with` block that call leaks a connection or does nothing.)
from main import run_pipeline, engine

# Model to use (alternative kept for quick switching)
#model = "claude-3-5-sonnet-20240620"
model = "gpt-4o-2024-08-06"

# Format for the pipeline (note: this name shadows the built-in `format()`)
format = "python"

# RAG parameters
max_tokens = 1000
size = 700
overlap = 300
quantity = 10

# Running the pipeline
# NOTE(review): the positional `False` is interpreted by main.run_pipeline; confirm its meaning there.
result, total_usage, prompts = run_pipeline(query, model, max_tokens, size, 
                                            overlap, quantity, format, False, 
                                            engine, rag_pipe=True, 
                                            self_corr=True)
print("Resulting table:")
print(result)
print("Total usage of the pipeline:")
pprint(total_usage)

# The prompts used will be saved in this file. Create the directory on a fresh
# checkout and write UTF-8 explicitly, since prompt text may contain non-ASCII.
prompts_path = Path("prompts/examples") / f"prompts_query_{model}.txt"
prompts_path.parent.mkdir(parents=True, exist_ok=True)
prompts_path.write_text(str(prompts), encoding="utf-8")

Tables needed: [object, probability, detection, magstat]
Difficulty: advanced
Resulting python query: 
# Sub-query to get object identifiers with the required classification and probability
sub_query_object = f'''
SELECT DISTINCT
    probability.oid
FROM
    probability
WHERE
    probability.classifier_name = 'lc_classifier'
    AND probability.class_name = 'SNII'
    AND probability.probability > 0.6
    AND probability.ranking = 1
'''

# Sub-query to filter objects based on the time range and number of detections
sub_query_time_ndet = f'''
SELECT
    object.oid
FROM
    object
WHERE
    object.firstmjd BETWEEN 58484 AND 59580  -- Corresponds to 2019-01-01 to 2022-12-31 in MJD
    AND object.ndet > 50
'''

# Sub-query to filter objects based on initial rise rate in ZTF g and r-band
sub_query_rise_rate_g = f'''
SELECT
    magstat.oid
FROM
    magstat
WHERE
    magstat.fid = 1  -- Assuming fid=1 corresponds to ZTF g-band
    AND magstat.dmdt_first > 0.5
'''

sub_query_rise_rate_r = f'''