# Compare sqlite, PostgreSQL and parquet versions of truth tables

Owner: **Joanne Bogart [@jrbogart](https://github.com/LSSTDESC/DC2-analysis/issues/new?body=@jrbogart)**  
Last Verified to Run: 

This notebook makes various comparisons between different representations (SQLite, Parquet, PostgreSQL) of the "same" truth tables.

__Logistics__: This notebook is intended to be run through the JupyterHub NERSC interface available here: https://jupyter.nersc.gov. To set up your NERSC environment, please follow the instructions available here: 
https://confluence.slac.stanford.edu/display/LSSTDESC/Using+Jupyter+at+NERSC
### Prerequisites
* For access to PostgreSQL see [Getting Started with PostgreSQL at NERSC](https://confluence.slac.stanford.edu/x/s4joE), especially the "Preliminaries" section


In [None]:
import os.path
import psycopg2
import sqlite3
import pyarrow.parquet as pq
import pyarrow as pa                 # maybe not necessary
import numpy as np
%matplotlib inline 
import matplotlib.pyplot as plt
import pandas as pd

In [None]:
class Catalog:
    """
    Store information needed to access a catalog in each of the three formats
    (SQLite, Parquet, PostgreSQL).

    Settings shared by every catalog -- the SQLite and Parquet root
    directories and the PostgreSQL connection -- are class attributes set
    through the static methods below.  Each instance holds the PostgreSQL
    table name, the SQLite file path (relative to the SQLite root) and table
    name, the Parquet file path (relative to the Parquet root), a table type
    (one of Catalog.Summary, Catalog.Variability, Catalog.Auxiliary) and a
    display name.
    """
    _sqlite_root = None
    # Read by parquet_abspath; stays None until set_parquet_root is called.
    _parquet_root = None
    _pg_conn = None

    # Table types
    Summary = 1
    Variability = 2
    Auxiliary = 3

    @staticmethod
    def set_sqlite_root(r):
        """Set the directory that SQLite paths are relative to."""
        Catalog._sqlite_root = r

    @staticmethod
    def set_parquet_root(r):
        """Set the directory that Parquet paths are relative to."""
        Catalog._parquet_root = r

    @staticmethod
    def set_pg_connection(c):
        """Register the PostgreSQL connection shared by all catalogs."""
        Catalog._pg_conn = c

    @staticmethod
    def pg_connection():
        """Return the shared PostgreSQL connection (None if not yet set)."""
        return Catalog._pg_conn

    def __init__(self, pg_table, sqlite_path, sqlite_table, parquet_path,
                 table_type=Summary, name=''):
        self._pg_table = pg_table
        self._sqlite_path = sqlite_path
        self._sqlite_table = sqlite_table
        self._parquet_path = parquet_path
        self._table_type = table_type
        self._name = name
        if table_type < Catalog.Summary or table_type > Catalog.Auxiliary:
            # Only warn; the catalog is still constructed.
            print("Unknown table type")

    @property
    def pg_table(self):
        """Schema-qualified PostgreSQL table name."""
        return self._pg_table

    @property
    def sqlite_table(self):
        """Table name inside the SQLite file."""
        return self._sqlite_table

    @property
    def sqlite_abspath(self):
        """SQLite file path joined onto the shared SQLite root."""
        return os.path.join(Catalog._sqlite_root, self._sqlite_path)

    @property
    def parquet_abspath(self):
        """Parquet file path joined onto the shared Parquet root."""
        return os.path.join(Catalog._parquet_root, self._parquet_path)

    @property
    def table_type(self):
        """One of Catalog.Summary, Catalog.Variability, Catalog.Auxiliary."""
        return self._table_type

    @property
    def name(self):
        """Human-readable catalog name."""
        return self._name
    
 

    

Make the PostgreSQL db connection

In [None]:
# Connection parameters for the desc_dc2_drp PostgreSQL database at NERSC.
# No password is passed here, so libpq must obtain one elsewhere (presumably
# ~/.pgpass, set up per the Prerequisites section above -- confirm).
dbname = 'desc_dc2_drp'
dbuser = 'desc_dc2_drp_user'
dbhost = 'nerscdb03.nersc.gov'
dbconfig = {'dbname' : dbname, 'user' : dbuser, 'host' : dbhost}
pg_conn = psycopg2.connect(**dbconfig)

# Share this single connection with every Catalog (used by get_pg_columns).
Catalog.set_pg_connection(pg_conn)

In [None]:
Catalog.set_sqlite_root('/global/cfs/cdirs/lsst/shared/DC2-prod/Run3.1i/truth')
Catalog.set_parquet_root('/global/cscratch1/sd/jrbogart/desc/truth/pq')

# The lists below are parallel: entry i of each describes the same table.
pg_tables = ['agn_truth.truth_summary', 'agn_truth.agn_auxiliary_info',
            'agn_truth.agn_variability_truth', 'lensed_host_truth.truth_summary',
            'lensed_agn_truth.truth_summary',
            'lensed_agn_truth.lensed_agn_variability_truth',
            'lensed_sne_truth.truth_summary',
            'lensed_sne_truth.lensed_sn_variability_truth']

sqlite_paths = ['agntruth/agn_truth_cat.db', 'agntruth/agn_truth_cat.db',
               'agntruth/agn_variability_truth_cat.db',
               'lensed_hosttruth/lensed_host_truth_cat.db',
               'lensed_agntruth/lensed_agn_truth_cat.db',
               'lensed_agntruth/lensed_agn_variability_truth_cat.db',
               'lensed_snetruth/lensed_sne_truth_cat.db',
               'lensed_snetruth/lensed_sne_truth_cat.db']
sqlite_tables = ['truth_summary', 'agn_auxiliary_info', 'agn_variability_truth',
                'truth_summary', 'truth_summary','lensed_agn_variability_truth',
                'truth_summary', 'lensed_sn_variability_truth']

parquet_paths = ['agn_truth_summary.parquet', 'agn_auxiliary_info.parquet',
                'agn_variability_truth.parquet', 'lensed_host_truth_summary.parquet',
                'lensed_agn_truth_summary.parquet',
                'lensed_agn_variability_truth.parquet',
                'lensed_sne_truth_summary.parquet',
                'lensed_sn_variability_truth.parquet']
table_types = [Catalog.Summary, Catalog.Auxiliary, Catalog.Variability, Catalog.Summary,
              Catalog.Summary, Catalog.Variability, Catalog.Summary, Catalog.Variability]
names = ['agn summary', 'agn aux', 'agn variability', 'lensed host summary', 'lensed agn summary',
        'lensed agn variability', 'lensed sn summary', 'lensed sn variability']
n_tables = len(parquet_paths)

# Fail loudly if one list was edited without the others; otherwise a longer
# list would be silently truncated and entries would be mismatched.
if any(len(spec) != n_tables
       for spec in (pg_tables, sqlite_paths, sqlite_tables, table_types, names)):
    raise ValueError('Catalog specification lists have inconsistent lengths')

catalogs = [Catalog(pg_table, sqlite_path, sqlite_table, parquet_path,
                    table_type, cat_name)
            for pg_table, sqlite_path, sqlite_table, parquet_path, table_type, cat_name
            in zip(pg_tables, sqlite_paths, sqlite_tables, parquet_paths,
                   table_types, names)]

Convenience utilities

In [None]:
def get_pg_columns(catalog, select_list, alias_list=None, other=None):
    '''
    Return the output of a PostgreSQL SELECT on catalog.pg_table as a data frame.

    Parameters
    ----------
    catalog : Catalog whose pg_table is queried
    select_list : list of column expressions to select
    alias_list : optional list of data frame column names; defaults to select_list
    other : optional SQL text (e.g. WHERE / ORDER BY clauses) appended to the query

    Uses the connection registered with Catalog.set_pg_connection.  If the
    query fails, the connection's transaction is rolled back before the error
    is re-raised, so the shared connection stays usable for later queries.
    '''
    cols = ','.join(select_list)
    q = f"""SELECT {cols} FROM {catalog.pg_table}"""
    if other:
        q = ' '.join([q, other])

    conn = Catalog.pg_connection()
    try:
        with conn.cursor() as cursor:
            cursor.execute(q)
            records = cursor.fetchall()
    except psycopg2.Error:
        # After an error psycopg2 leaves the transaction aborted; every later
        # statement on this connection would fail until it is rolled back.
        conn.rollback()
        raise

    return pd.DataFrame(records, columns=alias_list if alias_list else select_list)

In [None]:
def get_sqlite_columns(catalog, select_list, alias_list=None, other=None):
    '''
    Return the output of a SELECT on catalog.sqlite_table as a data frame.

    Parameters
    ----------
    catalog : Catalog whose SQLite file (sqlite_abspath) and table are queried
    select_list : list of column expressions to select
    alias_list : optional list of data frame column names; defaults to select_list
    other : optional SQL text (e.g. WHERE / ORDER BY clauses) appended to the query

    A new connection is opened for each call and is always closed before
    returning, even if the query fails.
    '''
    cols = ','.join(select_list)
    q = f"""SELECT {cols} FROM {catalog.sqlite_table}"""
    if other:
        q = ' '.join([q, other])

    conn = sqlite3.connect(catalog.sqlite_abspath)
    try:
        records = conn.execute(q).fetchall()
    finally:
        # sqlite3's connection context manager only commits or rolls back;
        # it does not close, so close explicitly to release the file handle.
        conn.close()

    return pd.DataFrame(records, columns=alias_list if alias_list else select_list)

In [None]:
def get_parquet_columns(catalog, column_list=None):
    '''
    Read columns of the catalog's Parquet file (parquet_abspath) into a
    pandas data frame.

    column_list : list of column names to read; None (the default) reads
                  all columns.
    '''
    # Pass columns by keyword: recent pyarrow versions declare it keyword-only.
    table = pq.read_table(catalog.parquet_abspath, columns=column_list)
    return table.to_pandas()
    

Standard queries to be applied to each representation should include:
* getting length of id column (all tables have an id column)
* comparing values for a couple of rows.  For SQLite and Parquet, expect rows to be returned in the same order.  For PostgreSQL we may have to match ids (for variability tables, match on (id, obshistid) for uniqueness)
* plotting ra, dec (summary tables)
* for variability tables, plot a light curve
* histogramming a column or two

In [None]:
# Compare counts.  Fetching counts is fast for parquet but rather slow 
# for large tables in the databases.

for catalog in catalogs:
    pg_df = get_pg_columns(catalog, ['count(id) as count_id'],
                           alias_list=['count_id'])
    pg_count = pg_df['count_id'][0]
    print(f"\nFor catalog {catalog.name}\n     pg count={pg_count}")

    sq_df = get_sqlite_columns(catalog, ['count(id) as count_id'],
                               alias_list=['count_id'])
    sq_count = sq_df['count_id'][0]

    pq_df = get_parquet_columns(catalog, ['id'])
    # SQL count(id) skips NULL ids; Series.count() skips missing values,
    # so the parquet number counts the same thing.
    pq_count = pq_df['id'].count()

    print(f" sqlite count={sq_count}")
    print(f"parquet count={pq_count}")
    if not (pg_count == sq_count == pq_count):
        print('   ** counts differ **')

### Compare Rows
To find column names, open the parquet file (e.g. with `pq.ParquetFile`).  Then `pq_file.schema.column(i)` returns information about the ith column, including the fields `name` and `physical_type`.

For truth summary and auxiliary tables, id is unique.  For variability tables, use
(id, obsHistID).  The column name is lowercase in the PostgreSQL database, but because PostgreSQL folds unquoted identifiers to lowercase, passing `obsHistID` unquoted works too.

To start, just print out the values of all columns in each format (except `coord` in PostgreSQL, since the other formats don't have it) for a row or two and compare them visually.  Later we might want to check floating-point fields for agreement within a tolerance and the other fields for exact equality.

### Plot coverage
For each table of type Catalog.Summary, plot (ra, dec)

In [None]:
def plot_coverage(catalog):
    """
    Draw side-by-side (ra, dec) scatter plots of a truth summary catalog as
    read from PostgreSQL, SQLite and Parquet.

    Catalogs whose table_type is not Catalog.Summary are rejected with a
    message and nothing is fetched or plotted.
    """
    wanted = ['ra', 'dec']

    if not catalog.table_type == Catalog.Summary:
        print('Only plot coverage for truth summary catalogs')
        return

    # Fetch from all three sources first (PostgreSQL, SQLite, Parquet) ...
    panels = [('PostgreSQL', get_pg_columns(catalog, wanted)),
              ('SQLite', get_sqlite_columns(catalog, wanted)),
              ('Parquet', get_parquet_columns(catalog, wanted))]

    # ... then draw one panel per source in a 1x3 grid.
    plt.figure(figsize=(15, 6))
    plt.suptitle(f'{catalog.name} Coverage')
    for position, (source, frame) in enumerate(panels, start=131):
        plt.subplot(position).set_title(source)
        plt.xlabel('ra')
        plt.ylabel('dec')
        plt.scatter(np.array(frame['ra']), np.array(frame['dec']), s=0.6)

In [None]:
def _report_closeness(ref, other, ref_label, other_label, rtol, atol):
    '''
    Print whether two id-sorted data frames with id, ra, dec columns agree.

    Rows are compared by position, so both frames must already be sorted by
    id.  Row counts and ids are checked first so that np.allclose is only
    applied to arrays describing the same objects in the same order.
    '''
    if len(ref) != len(other):
        print(f'{ref_label}, {other_label} mismatch: '
              f'{len(ref)} vs {len(other)} rows')
        return
    if not np.array_equal(np.asarray(ref['id']), np.asarray(other['id'])):
        print(f'{ref_label}, {other_label} mismatch: ids differ')
        return
    close = all(np.allclose(np.asarray(ref[col]), np.asarray(other[col]),
                            rtol=rtol, atol=atol)
                for col in ('ra', 'dec'))
    if close:
        print(f'{ref_label} and {other_label} datasets are sufficiently close')
    else:
        print(f'{ref_label}, {other_label} mismatch')


def compare_coverage(catalog, rtol=1e-5, atol=1e-8, verbose=False):
    '''
    Numerically compare (id, ra, dec) of a truth summary catalog across
    PostgreSQL, SQLite and Parquet, printing the results.

    Each source is sorted by id, then SQLite is compared with PostgreSQL and
    with Parquet: first the row counts, then the ids, then ra and dec using
    np.allclose with the given rtol and atol.

    Parameters
    ----------
    catalog : Catalog of type Catalog.Summary; other types are rejected
    rtol, atol : tolerances passed to np.allclose
    verbose : if True, print the first (smallest-id) entry from each source
    '''
    columns = ['ra', 'dec', 'id']

    if catalog.table_type != Catalog.Summary:
        print('Only compare coverage for truth summary catalogs')
        return

    # reset_index so that label 0 / position 0 really is the smallest id.
    def _sorted(df):
        return df.sort_values(by=['id']).reset_index(drop=True)

    pg_data = _sorted(get_pg_columns(catalog, columns))
    sq_data = _sorted(get_sqlite_columns(catalog, columns))
    parquet_data = _sorted(get_parquet_columns(catalog, columns))

    if verbose:
        fmt = '{} first entry: id={}, ra={}, dec={}'
        for label, data in (('SQLite  ', sq_data), ('Postgres', pg_data),
                            ('Parquet ', parquet_data)):
            if len(data) > 0:
                print(fmt.format(label, data['id'].iloc[0], data['ra'].iloc[0],
                                 data['dec'].iloc[0]))
    print('\nCatalog ', catalog.name)
    _report_closeness(sq_data, pg_data, 'SQLite', 'PostgreSQL', rtol, atol)
    _report_closeness(sq_data, parquet_data, 'SQLite', 'Parquet', rtol, atol)

In [None]:
# catalogs[0] is the 'agn summary' catalog
plot_coverage(catalog=catalogs[0])

In [None]:
# catalogs[3] is the 'lensed host summary' catalog
plot_coverage(catalog=catalogs[3])

In [None]:
# catalogs[4] is the 'lensed agn summary' catalog
plot_coverage(catalog=catalogs[4])

In [None]:
# catalogs[6] is the 'lensed sn summary' catalog
plot_coverage(catalog=catalogs[6])

#### Compare ra, dec numerically

In [None]:
# ra and dec are double precision, so use tolerances a little tighter than
# compare_coverage's defaults.  Only summary catalogs have ra/dec to compare.
for c in catalogs:
    if c.table_type != Catalog.Summary:
        continue
    compare_coverage(c, rtol=1e-6, atol=1e-9)

### Light curves
Compare light curves for AGNs.  The same could be done for lensed agn and lensed SNe.

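First find the variable AGNs within a cone around a chosen position and, for each, the maximum absolute delta flux and the number of variability readings in `agn_variability_truth`. (In the analogous star query on the `stellar_variability_truth` table, PostgreSQL returns results practically instantly because both columns mentioned in the `WHERE` clause - `id` and `obshistid` - are indexed. The AGN query below joins the summary and variability tables and takes a couple of minutes.)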

In [None]:
def format_cone_search(coord_column, ra, dec, radius):
    '''
    Build a conesearch(...) condition for use in a PostgreSQL WHERE clause.

    Parameters
    coord_column:  name of the column of type earth in the table
    ra:  ra at the center of the cone (degrees)
    dec:  dec at the center of the cone (degrees)
    radius: radius of the cone (arcseconds)

    Returns
    The condition as a string, with ra, dec and radius each single-quoted
    '''
    return f"""conesearch({coord_column},'{ra}','{dec}','{radius}')"""

In [None]:
# Query: for each variable AGN (S.is_variable=1) inside the cone, report its
# id and position, the largest |delta_flux| over its variability rows, and the
# number of those rows with a non-null bandpass, largest max_delta_flux first.
# pick a location that probably gets lots of visits
ra = 53.0      # degrees (cone center)
dec = -28.1    # degrees (cone center)
radius = 80    # arcseconds
truth_schema = 'agn_truth'
tbl_spec = f"""SELECT S.id, S.ra, S.dec, max(abs(V.delta_flux)) as max_delta_flux,count(V.bandpass) AS visit_count 
           FROM {truth_schema}.truth_summary AS S JOIN 
           {truth_schema}.agn_variability_truth AS V ON S.id=V.id """
where = "WHERE " + format_cone_search('S.coord', ra, dec, radius) + " AND S.is_variable=1 "
group_by = " GROUP BY S.id,S.ra,S.dec ORDER BY max_delta_flux DESC"
q = tbl_spec + where + group_by
print(q)

# This takes a couple minutes to complete
# Uses the connection registered with Catalog.set_pg_connection; `records`
# is consumed by the next cell.
dbconn = Catalog.pg_connection()
with dbconn.cursor() as cursor:
    # %time is an IPython magic that reports how long the query takes
    %time cursor.execute(q)
    records = cursor.fetchall()


In [None]:
# One row per AGN returned by the cone-search query, in the query's order
df_lengths = pd.DataFrame(data=records,
                          columns=['id', 'ra', 'dec', 'max_delta_flux', 'visit_count'])
df_lengths

### Plot light curves for one AGN
Pick the first object, i.e. the one with the largest maximum delta flux (results are ordered by `max_delta_flux` descending)


In [None]:
# Take id and position from the first row (largest max_delta_flux)
lc_id, lc_ra, lc_dec = (df_lengths[col][0] for col in ('id', 'ra', 'dec'))


In [None]:
# Show which object the light-curve comparison below uses
print('id is: ', lc_id)

#### Get the data
For each data source, get the bandpass, MJD and delta_flux values for the selected object, ordered by bandpass and then MJD for convenient plotting.

This query runs quickly in PostgreSQL (the variability table is indexed by id),
but the same query made on the SQLite and parquet files is slow.

In [None]:
name = 'agn variability'

# Look up the catalog by display name; cat is None if none matches.
cat = next((c for c in catalogs if c.name == name), None)

#PostgreSQL
if cat is not None:
    print('Found catalog')
    cols = ['bandpass', 'MJD', 'delta_flux']
    # cols and other are reused for the SQLite query below
    other = f"WHERE id='{lc_id}' ORDER BY bandpass, MJD"
    pg_df = get_pg_columns(cat, cols, other=other)
    print('PostgreSQL data shape: ', pg_df.shape)
else:
    print('No such catalog "', name, '"')
    print('Fix before proceeding with remainder of notebook')

In [None]:
# SQLite: same columns and WHERE / ORDER BY text as the PostgreSQL query
sq_df = get_sqlite_columns(catalog=cat, select_list=cols, other=other)

In [None]:
# (rows, columns) of the SQLite light-curve data
print('SQLite data shape: ', sq_df.shape)

In [None]:
#Parquet
# get_parquet_columns has no row filter, so read the needed columns for the
# whole table and select the object's rows in pandas.
pq_cols = ['id', 'bandpass', 'MJD', 'delta_flux']
pq_df_orig = get_parquet_columns(cat, pq_cols)
print('Parquet data shape initially: ', pq_df_orig.shape)
# A boolean mask works whatever the dtype of id is and needs no quoting of
# lc_id, unlike DataFrame.query with an interpolated string literal.
pq_df = pq_df_orig[pq_df_orig['id'] == lc_id]
print('Parquet data shape after filtering: ', pq_df.shape)

In [None]:
# Keep the sorted frame so pq_df is ordered like the PostgreSQL and SQLite
# results (ORDER BY bandpass, MJD) for the plots below, then display it.
pq_df = pq_df.sort_values(by=['bandpass', 'MJD'])
pq_df

#### Plotting

In [None]:
def plot_band_lc(times, fluxes, params):
    '''
    Scatter-plot one light curve on the current matplotlib axes.

    times : sequence of observation times (x values)
    fluxes : sequence of flux values (y values), same length as times
    params : dict of keyword arguments passed through to plt.scatter

    Returns the artist created by plt.scatter so callers can adjust it.
    '''
    return plt.scatter(np.asarray(times), np.asarray(fluxes), **params)

In [None]:
def format_title(id, ra, dec, band=None, object='star'):
    '''
    Build a plot title for an object's light curve(s).

    id, ra, dec : identifier and position of the object, inserted verbatim
    band : if None, the title describes per-band light curves; otherwise it
           names the single band plotted
    object : kind of object named in the title (default 'star'); now used
             in both title forms

    The parameter names id and object shadow builtins but are kept because
    callers pass object= by keyword.
    '''
    if band is None:
        return f'Per-band light curves for {object} {id} at (ra,dec)=({ra}, {dec})'
    return f'Light curve for {object} {id}, band={band} at (ra,dec)=({ra}, {dec})'

In [None]:
def plot_object(title, the_data, data_source, band):
    '''
    Plot requested band: one light-curve panel per data source, side by side.

    Parameters
    -----------
    title : string, used as the figure title
    the_data : list of data frames, one per data source, each with columns
               bandpass, MJD and delta_flux
    data_source : list of data source names, parallel to the_data; used as
                  panel titles
    band : Must be one of ['r', 'i','g', 'u', 'y', 'z']

    Rows whose delta_flux is missing (NaN) are not plotted.  An unknown band
    prints a message and plots nothing.
    '''
    color_dict = {'r': 'red', 'g': 'green', 'i': 'orange', 'u': 'magenta',
                  'y': 'blue', 'z': 'black'}
    if band not in color_dict:
        print('Unknown band "',band,'"')
        return

    n_panels = len(data_source)
    plt.figure(figsize=(18, 6))
    plt.suptitle(title)
    for panel, (source_name, df) in enumerate(zip(data_source, the_data), start=1):
        # (nrows, ncols, index) form also works for more than 9 panels
        plt.subplot(1, n_panels, panel).set_title(source_name)

        good_d = df[df['delta_flux'].notna()]
        band_d = good_d[good_d['bandpass'] == band]
        # The x values come from the MJD column
        plt.xlabel('MJD')
        plt.ylabel('Delta flux')
        params = {'marker': 'o', 'label': f'band {band}', 'color': color_dict[band],
                  's': 1.0}

        plot_band_lc(list(band_d['MJD']), list(band_d['delta_flux']), params)

In [None]:
# One figure per band, each comparing the three data sources side by side
for band in ('i', 'r', 'g', 'u'):
    title = format_title(lc_id, lc_ra, lc_dec, band=band, object='agn')
    plot_object(title=title, the_data=[pg_df, sq_df, pq_df],
                data_source=('PostgreSQL', 'SQLite', 'Parquet'), band=band)