# Compare sqlite, PostgreSQL and parquet versions of truth tables

Owner: **Joanne Bogart [@jrbogart](https://github.com/LSSTDESC/DC2-analysis/issues/new?body=@jrbogart)**  
Last Verified to Run: 

This notebook makes various comparisons between different manifestations (sqlite, parquet, PostgreSQL) of the "same" truth table

__Logistics__: This notebook is intended to be run through the JupyterHub NERSC interface available here: https://jupyter.nersc.gov. To setup your NERSC environment, please follow the instructions available here: 
https://confluence.slac.stanford.edu/display/LSSTDESC/Using+Jupyter+at+NERSC
### Prerequisites
* For access to PostgreSQL see [Getting Started with PostgreSQL at NERSC](https://confluence.slac.stanford.edu/x/s4joE), especially the "Preliminaries" section


In [None]:
import os.path
import psycopg2
import sqlite3
import pyarrow.parquet as pq
import pyarrow as pa                 # maybe not necessary
import numpy as np
%matplotlib inline 
import matplotlib.pyplot as plt
import pandas as pd

In [None]:
class Catalog:
    """
    Store information needed to access a catalog in each of the three formats
    (PostgreSQL, sqlite, parquet).

    The sqlite root directory, the parquet root directory and the PostgreSQL
    connection are class-level settings shared by every instance; set them
    with the ``set_*`` static methods before using the ``*_abspath``
    properties or ``pg_connection()``.

    Parameters
    ----------
    pg_table : name of the table in PostgreSQL
    sqlite_path : path of the sqlite file, relative to the sqlite root
    sqlite_table : name of the table inside the sqlite file
    parquet_path : path of the parquet file, relative to the parquet root
    table_type : one of Catalog.Summary, Catalog.Variability, Catalog.Auxiliary
    name : free-form label for the catalog
    """
    _sqlite_root = None
    # This was previously declared as `_pq_root`, which nothing read or wrote;
    # the setter and `parquet_abspath` both use `_parquet_root`.
    _parquet_root = None
    _pg_conn = None
    Summary = 1
    Variability = 2
    Auxiliary = 3

    @staticmethod
    def set_sqlite_root(r):
        """Set the directory that sqlite paths are resolved against."""
        Catalog._sqlite_root = r

    @staticmethod
    def set_parquet_root(r):
        """Set the directory that parquet paths are resolved against."""
        Catalog._parquet_root = r

    @staticmethod
    def set_pg_connection(c):
        """Store the PostgreSQL connection shared by all catalogs."""
        Catalog._pg_conn = c

    @staticmethod
    def pg_connection():
        """Return the shared PostgreSQL connection (None if never set)."""
        return Catalog._pg_conn

    def __init__(self, pg_table, sqlite_path, sqlite_table, parquet_path,
                 table_type=Summary, name=''):
        self._pg_table = pg_table
        self._sqlite_path = sqlite_path
        self._sqlite_table = sqlite_table
        self._parquet_path = parquet_path
        self._table_type = table_type
        self._name = name
        # An out-of-range type is only reported, not rejected; the instance
        # is still created.
        if table_type < Catalog.Summary or table_type > Catalog.Auxiliary:
            print("Unknown table type")

    @property
    def pg_table(self):
        """Name of the table in PostgreSQL."""
        return self._pg_table

    @property
    def sqlite_table(self):
        """Name of the table inside the sqlite file."""
        return self._sqlite_table

    @property
    def sqlite_abspath(self):
        """sqlite file path joined onto the class-level sqlite root."""
        return os.path.join(Catalog._sqlite_root, self._sqlite_path)

    @property
    def parquet_abspath(self):
        """parquet file path joined onto the class-level parquet root."""
        return os.path.join(Catalog._parquet_root, self._parquet_path)

    @property
    def table_type(self):
        """Table type code passed to the constructor."""
        return self._table_type

    @property
    def name(self):
        """Label passed to the constructor."""
        return self._name
    
 

    

Make the PostgreSQL db connection

In [None]:
# Connection parameters, kept as individual globals as well as in dbconfig.
dbname, dbuser, dbhost = 'desc_dc2_drp', 'desc_dc2_drp_user', 'nerscdb03.nersc.gov'
dbconfig = dict(dbname=dbname, user=dbuser, host=dbhost)

# Open the connection and register it as the one shared by all catalogs.
pg_conn = psycopg2.connect(**dbconfig)
Catalog.set_pg_connection(pg_conn)

In [None]:
# Root directories that the relative sqlite / parquet paths below are joined onto.
Catalog.set_sqlite_root('/global/cfs/cdirs/lsst/shared/DC2-prod/Run3.1i/truth')
Catalog.set_parquet_root('/global/cscratch1/sd/jrbogart/desc/truth/pq')

# The six lists below are parallel: entry i of each list describes catalog i.
pg_tables = ['agn_truth.truth_summary', 'agn_truth.agn_auxiliary_info',
            'agn_truth.agn_variability_truth', 'lensed_host_truth.truth_summary',
            'lensed_agn_truth.truth_summary',
            'lensed_agn_truth.lensed_agn_variability_truth',
            'lensed_sne_truth.truth_summary',
            'lensed_sne_truth.lensed_sn_variability_truth']

sqlite_paths = ['agntruth/agn_truth_cat.db', 'agntruth/agn_truth_cat.db',
               'agntruth/agn_variability_truth_cat.db',
               'lensed_hosttruth/lensed_host_truth_cat.db',
               'lensed_agntruth/lensed_agn_truth_cat.db',
               'lensed_agntruth/lensed_agn_variability_truth_cat.db',
               'lensed_snetruth/lensed_sne_truth_cat.db',
               'lensed_snetruth/lensed_sne_truth_cat.db']
sqlite_tables = ['truth_summary', 'agn_auxiliary_info', 'agn_variability_truth',
                'truth_summary', 'truth_summary','lensed_agn_variability_truth',
                'truth_summary', 'lensed_sn_variability_truth']

parquet_paths = ['agn_truth_summary.parquet', 'agn_auxiliary_info.parquet',
                'agn_variability_truth.parquet', 'lensed_host_truth_summary.parquet',
                'lensed_agn_truth_summary.parquet',
                'lensed_agn_variability_truth.parquet',
                'lensed_sne_truth_summary.parquet',
                'lensed_sn_variability_truth.parquet']
table_types = [Catalog.Summary, Catalog.Auxiliary, Catalog.Variability, Catalog.Summary,
              Catalog.Summary, Catalog.Variability, Catalog.Summary, Catalog.Variability]
# 'agn variability' was previously misspelled 'agn variabity'.
names = ['agn summary', 'agn aux', 'agn variability', 'lensed host summary', 'lensed agn summary',
        'lensed agn variability', 'lensed sn summary', 'lensed sn variability']
n_tables = len(parquet_paths)

# zip() silently truncates to the shortest input, so check the lengths first
# to keep a mismatch between the parallel lists from going unnoticed.
spec_columns = (pg_tables, sqlite_paths, sqlite_tables,
                parquet_paths, table_types, names)
if any(len(col) != n_tables for col in spec_columns):
    raise ValueError('catalog description lists must all have the same length')

# Each zipped tuple is in Catalog.__init__ positional order.
catalogs = [Catalog(*spec) for spec in zip(*spec_columns)]

Convenience utilities

In [None]:
#def get_pg_columns(conn, catalog, select_list, alias_list=None, where=None):
def get_pg_columns(catalog, select_list, alias_list=None, where=None):
    '''
    Query the catalog's PostgreSQL table and return the rows as a data frame.

    Parameters
    ----------
    catalog : object whose pg_table attribute names the table to query
    select_list : list of strings, joined with commas to form the SELECT list
    alias_list : optional list of data frame column names; when empty or
                 None the entries of select_list are used instead
    where : optional string appended after " WHERE "

    The query runs on the connection returned by Catalog.pg_connection().
    '''
    cols = ','.join(select_list)
    statement = f"""SELECT {cols} FROM {catalog.pg_table}"""
    if where:
        statement = statement + ' WHERE ' + where

    with Catalog.pg_connection().cursor() as cursor:
        cursor.execute(statement)
        rows = cursor.fetchall()

    frame_columns = alias_list if alias_list else select_list
    return pd.DataFrame(rows, columns=frame_columns)

In [None]:
def get_sqlite_columns(catalog, select_list, alias_list=None, where=None):
    '''
    Returns data frame version of db select output for the catalog's
    sqlite table.

    Parameters
    ----------
    catalog : object providing sqlite_abspath (file to open) and
              sqlite_table (table to query)
    select_list : list of strings, joined with commas to form the SELECT list
    alias_list : optional list of data frame column names; when empty or
                 None the entries of select_list are used instead
    where : optional string appended after " WHERE "
    '''
    cols = ','.join(select_list)
    q = f"""SELECT {cols} FROM {catalog.sqlite_table}"""
    if where:
        q = ''.join([q,' WHERE ',where])

    # A connection is opened per call, so close it when done (also if the
    # query raises); previously it was left open.
    conn = sqlite3.connect(catalog.sqlite_abspath)
    try:
        cursor = conn.cursor()
        cursor.execute(q)
        records = cursor.fetchall()
    finally:
        conn.close()

    if alias_list:
        return pd.DataFrame(records, columns=alias_list)
    return pd.DataFrame(records, columns=select_list)

In [None]:
def get_parquet_columns(catalog, column_list):
    '''
    Read the requested columns of the catalog's parquet file into a data frame.

    Parameters
    ----------
    catalog : object whose parquet_abspath attribute is the file to read
    column_list : list of column names to load
    '''
    # Pass `columns` by keyword: recent pyarrow releases appear to accept it
    # only as a keyword argument (confirm against the installed version),
    # and the keyword form is valid either way.
    table = pq.read_table(catalog.parquet_abspath, columns=column_list)
    return table.to_pandas()
    

Standard queries, to be applied to each representation, should include
* getting length of id column (all tables have an id column)
* comparing values for a couple rows.  For sqlite and parquet, expect rows will be returned in the same order.  For PostgreSQL might have to match ids (for variability tables would have to match (id, obshistid) for uniqueness)
* plotting ra, dec (summary tables)
* for variability tables, plot a light curve
* histogramming a column or two

In [None]:
# Compare row counts: count(id) in PostgreSQL and sqlite, and the length of
# the id column read from parquet.
count_select = ['count(id) as count_id']
count_alias = ['count_id']

for i in range(n_tables):
    current = catalogs[i]

    pg_df = get_pg_columns(current, count_select, alias_list=count_alias)
    print(f"\nFor catalog {catalogs[i].name}\n     pg count={pg_df['count_id'][0]}")    

    sq_df = get_sqlite_columns(current, count_select, alias_list=count_alias)
    pq_df = get_parquet_columns(current, ['id'])

    print(f" sqlite count={sq_df['count_id'][0]}")
    print(f"parquet count={len(pq_df['id'])}")

### Compare Rows
To find column names, can open parquet file.  Then pq_file.schema.column(i) returns information about the ith column.   Includes field `name` and `physical_type`

For truth summary and auxiliary tables, id is unique.  For variability tables, use
(id, obsHistID)   (the field is lowercase in the PostgreSQL db, but it shouldn't care if uppercase is passed in)

To start, just print out values for all columns for each format (except `coord` in PostgreSQL, since the others don't have it) for a row or two and compare visually. Later we might want to check for within-tolerance agreement for floating point fields and identity for the others.

### Plot coverage
For each table of type Catalog.Summary, plot (ra, dec)

In [None]:
def plot_coverage(catalog):
    '''
    Scatter-plot (ra, dec) of a truth summary catalog as read from each of
    PostgreSQL, sqlite and parquet, in three side-by-side panels.

    For a catalog whose table_type is not Catalog.Summary a message is
    printed and nothing is plotted.
    '''
    if catalog.table_type != Catalog.Summary:
        print('Only plot coverage for truth summary catalogs')
        return

    columns = ['ra', 'dec']

    # (subplot position, panel title, data); all three reads happen here,
    # before any plotting starts.
    panels = (
        (131, 'PostgreSQL', get_pg_columns(catalog, columns)),
        (132, 'SQLite', get_sqlite_columns(catalog, columns)),
        (133, 'Parquet', get_parquet_columns(catalog, columns)),
    )

    plt.figure(figsize=(15, 6))
    plt.suptitle(f'{catalog.name} Coverage')
    for position, label, frame in panels:
        plt.subplot(position).set_title(label)
        plt.xlabel('ra')
        plt.ylabel('dec')
        plt.scatter(np.array(frame['ra']), np.array(frame['dec']), s=0.6)

In [None]:
# catalogs[0] is the 'agn summary' catalog (type Summary)
plot_coverage(catalogs[0])

In [None]:
# catalogs[1] is 'agn aux' (type Auxiliary), so plot_coverage only prints
# its "Only plot coverage for truth summary catalogs" message here.
plot_coverage(catalogs[1])

In [None]:
# catalogs[3] is the 'lensed host summary' catalog (type Summary)
plot_coverage(catalogs[3])

In [None]:
# catalogs[4] is the 'lensed agn summary' catalog (type Summary)
plot_coverage(catalogs[4])

In [None]:
# catalogs[6] is the 'lensed sn summary' catalog (type Summary)
plot_coverage(catalogs[6])

### Histogram a column
Find delta flux readings for stars which are expected to be in the field of view for a particular visit. This sort of query returns practically instantly. In the `stellar_variability_truth` both of the columns mentioned in the `WHERE` clause - `id` and `obshistid` - are indexed.

### Light curves
Find a reasonable candidate (maybe one each agn, lensed agn, lensed sn?) and compare light curves

In [None]:
def format_cone_search(coord_column, ra, dec, radius):
    '''
    Build the text of a conesearch(...) condition for use in a WHERE clause.

    Parameters
    coord_column:  name of column of type earth in the table
    ra:  ra value at center of cone (degrees)
    dec:  dec value at center of cone (degrees)
    radius: radius of cone (arcseconds)

    Returns
    String of the form conesearch(<coord_column>,'<ra>','<dec>','<radius>');
    the three numeric arguments are embedded as single-quoted text.
    '''
    return f"""conesearch({coord_column},'{ra}','{dec}','{radius}')"""

In [None]:
# Pick a location that probably gets lots of visits.
# ra and dec are in degrees and radius in arcseconds (the units documented
# by format_cone_search).
ra = 53.0      
dec = -28.1
radius = 80   
truth_schema = 'agn_truth'
# Join the summary table to the variability table on id. After the GROUP BY
# there is one row per object, carrying its largest |delta_flux| and the
# number of variability rows (count of bandpass).
tbl_spec = f"""SELECT S.id, S.ra, S.dec, max(abs(V.delta_flux)),count(V.bandpass) AS visit_count 
           FROM {truth_schema}.truth_summary AS S JOIN 
           {truth_schema}.agn_variability_truth AS V ON S.id=V.id """
# Restrict to variable objects inside the cone.
where = "WHERE " + format_cone_search('S.coord', ra, dec, radius) + " AND S.is_variable=1 "
group_by = " GROUP BY S.id,S.ra,S.dec"
q = tbl_spec + where + group_by
print(q)

# This takes a couple minutes to complete
# dbconn and records are reused by later cells.
dbconn = Catalog.pg_connection()
with dbconn.cursor() as cursor:
    %time cursor.execute(q)
    records = cursor.fetchall()


In [None]:
# Wrap the rows fetched by the previous cell's query in a data frame. The
# five names are assigned positionally to its SELECT expressions
# (id, ra, dec, max(abs(delta_flux)), count(bandpass)).
df_lengths = pd.DataFrame(records, columns=['id', 'ra','dec', 'max_delta_flux','count'])
df_lengths

Similar to above, but this time don't count visits. Get the delta_flux values instead

In [None]:

# Same join and cone restriction as the visit-count query, but without
# aggregation: one row per (object, variability row).
# `columns` is both the SELECT list here and the data frame column names in
# the next cell.
columns = ['S.id', 'ra', 'dec', 'bandpass', 'delta_flux']
col_list = (',').join(columns)
tbl_spec = f"""SELECT {col_list} 
           FROM {truth_schema}.truth_summary AS S JOIN 
           {truth_schema}.agn_variability_truth AS V ON S.id=V.id """
where = "WHERE " + format_cone_search('S.coord', ra, dec, radius) + " and S.is_variable=1 "
q = tbl_spec + where
print(q)

# Reuses the dbconn connection obtained in the earlier cell.
with dbconn.cursor() as cursor:
    %time cursor.execute(q)
    records_lc = cursor.fetchall()

In [None]:
# Data frame of the per-row results; the column names are the SELECT list
# entries, so the id column is literally named 'S.id'.
df_cone_lcs = pd.DataFrame(records_lc, columns=columns)
df_cone_lcs

### Plot light curves for one star
Pick the fourth object from the results of the first query in this section (id=2643077650549) since max_delta_flux is large
#### Get the data
Get delta_flux and time values for the plot and some summary information about the star. Use `ORDER BY` clause so that data are presented conveniently for plotting.

In [None]:
# NOTE: this global shadows the builtin id(); later cells refer to it by
# this name, so it is left unchanged here.
id = 2643077650549
var_tbl = 'agn_variability_truth'
# Fetch every variability row for the chosen object, ordered by band and
# then by time so each band's points come out in plotting order.
lc_q = f"""SELECT bandpass,mjd,delta_flux FROM {truth_schema}.{var_tbl}
       WHERE id='{id}' ORDER BY bandpass, mjd;"""
print(lc_q)
with dbconn.cursor() as cursor:
    %time cursor.execute(lc_q)
    lc_records = cursor.fetchall()
# Number of rows returned
print(len(lc_records))
df_single_lc = pd.DataFrame(lc_records, columns=['bandpass','mjd','delta_flux'])
df_single_lc


Save ra, dec for the object into variables to be used below

In [None]:
# Look up the position of the chosen object in the summary table.
sum_tbl = 'truth_summary'
sum_q = f"""SELECT ra,dec FROM {truth_schema}.{sum_tbl} 
         WHERE id='{id}';"""
print(sum_q)
with dbconn.cursor() as cursor:
    %time cursor.execute(sum_q)
    # Only the first matching row is fetched.
    sum_record = cursor.fetchone()
# Columns come back in SELECT order: ra, then dec. Kept for plot titles below.
lc_ra = sum_record[0]
lc_dec = sum_record[1]
print(f'ra={lc_ra}, dec={lc_dec}')

#### Plotting

In [None]:
from astropy.time import Time
def plot_band_lc(axes, times, fluxes, params):
    '''
    Scatter-plot fluxes against times on the given axes.

    Parameters
    ----------
    axes : object providing a scatter(x, y, **kwargs) method
    times : sequence of x values
    fluxes : sequence of y values
    params : dict of keyword arguments forwarded to axes.scatter

    Returns nothing; the value returned by axes.scatter is discarded.
    '''
    x_values = np.asarray(times)
    y_values = np.asarray(fluxes)
    axes.scatter(x_values, y_values, **params)

In [None]:
def plot_level(axes, yvalue, params):
    '''
    Draw a horizontal line at yvalue spanning the axes' current x limits.

    Parameters
    ----------
    axes : object providing get_xlim() and plot(x, y, **kwargs)
    yvalue : y coordinate of the line
    params : dict of keyword arguments forwarded to axes.plot

    Returns nothing; the value returned by axes.plot is discarded.
    '''
    left, right = axes.get_xlim()
    x_ends = np.asarray([left, right])
    y_ends = np.asarray([yvalue, yvalue])
    axes.plot(x_ends, y_ends, **params)

In [None]:
def format_title(id, ra, dec, band=None, object='star'):
    '''
    Build a plot title for an object's light curve(s).

    Parameters
    ----------
    id : object identifier shown in the title
    ra, dec : coordinates shown in the title
    band : band name for a single-band title; None gives the all-bands
           ("Per-band") title
    object : word describing the kind of object (default 'star')

    The parameter names id and object shadow builtins but are kept because
    callers pass them by position and by keyword (object=...).
    '''
    if band is None:
        # This title used to hard-code "star" and ignore `object`; with the
        # default object='star' the text is unchanged.
        return f'Per-band light curves for {object} {id} at (ra,dec)=({ra}, {dec})'
    return f'Light curve for {object} {id}, band={band} at (ra,dec)=({ra}, {dec})'

In [None]:
def plot_object(title, the_data, band=None):
    '''
    Plot r, g and i light 'curves' (delta_flux vs. mjd as scatter plots)
    for an object, or only the requested band.

    Parameters
    -----------
    title : string used as the plot title
    the_data : data frame which must include columns bandpass, mjd, delta_flux
    band : 'r', 'g' or 'i' to plot that band alone; None plots all three
           and adds a legend.  Any other value plots no points.
    '''
    # Drop rows whose delta_flux is NaN, then split what is left by band.
    valid = the_data[~np.isnan(the_data.delta_flux)]
    per_band = {
        'r': (valid[valid.bandpass == "r"],
              {'marker' : 'o', 'label' : 'r band', 'color' : 'red'}),
        'g': (valid[valid.bandpass == "g"],
              {'marker' : 'o', 'label' : 'g band', 'color' : 'green'}),
        'i': (valid[valid.bandpass == "i"],
              {'marker' : 'o', 'label' : 'i band', 'color' : 'orange'}),
    }

    fig, axes = plt.subplots(figsize=(12,8))
    plt.title(title)
    plt.xlabel('Julian date')
    plt.ylabel('Delta flux')

    # Bands are drawn in r, g, i order.
    for code, (rows, style) in per_band.items():
        if band is None or band == code:
            plot_band_lc(axes, list(rows['mjd']), list(rows['delta_flux']), style)

    if band is None:
        plt.legend()

In [None]:
# One figure per band for the object selected above (id, lc_ra, lc_dec),
# using the PostgreSQL light-curve data in df_single_lc.
for band in ('r','g','i'):
    title = format_title(id, lc_ra, lc_dec, band, object='agn')
    plot_object(title, df_single_lc, band)

In [None]:
# Now do it again for sqlite data and parquet data

#### Is it in the object table?
First get truth information for our chosen star, then try to find a match, restricting to point sources within a few arcseconds

In [None]:
# Fetch position and g/r/i fluxes of the chosen object from the summary table.
truth_q = f"SELECT ra,dec,flux_g,flux_r,flux_i from {truth_schema}.truth_summary where id='{id}'"
print(truth_q)
with dbconn.cursor() as cursor:
    %time cursor.execute(truth_q)
    truth_records = cursor.fetchall()
#truth_records
# Number of matching rows
print(len(truth_records))
# The None assignment is overwritten by the next statement.
truth_df = None
truth_df = pd.DataFrame(truth_records, columns=['ra', 'dec', 'flux_g', 'flux_r', 'flux_i'])
# The shape is evaluated but neither printed nor assigned; the data frame on
# the last line is the cell's final expression.
truth_df.shape
truth_df