In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import sys
import os
import traceback
import time
import logging

base_dir = os.path.abspath('..')
sys.path.append(base_dir)

from dl import authClient as ac, queryClient as qc
from tqdm import tqdm
from utils.analyze_lensing import integrated_event_duration_posterior
from utils import filtering
from utils.kde_label import cluster_label_dataframe
from utils.mc_backgrounds import synthesize_background
from tqdm import tqdm
from dl.helpers.utils import convert

# Use a white background both for on-screen axes and for saved figures.
plt.rcParams.update({
    'axes.facecolor': 'white',
    'savefig.facecolor': 'white',
})

# Input directory (earlier results) and output locations for this run,
# all built relative to base_dir.
read_dir = os.path.join(base_dir, "results/12aug2024/")
fig_dir = os.path.join(base_dir, "plots/2sep2024_MC/")
results_dir = os.path.join(base_dir, "results/2sep2024_MC/")
log_dir = os.path.join(base_dir, "logs/2sep2024_MC")
log_file = os.path.join(log_dir, "MC_analysis.log")

# Make sure every output directory exists before anything is written to it.
for d in (fig_dir, results_dir, log_dir):
    os.makedirs(d, exist_ok=True)

# Send INFO-and-above log records to the run's log file.
logging.basicConfig(level=logging.INFO, filename=log_file)

In [2]:
def make_query_string(i_batch, batch_size):
    """Build the SQL text that fetches the measurements for one batch of objects.

    The batch is selected by ``row_number`` in ``mydb://numbered_stable_stars_sep2``:
    batch ``i_batch`` covers rows ``i_batch * batch_size + 1`` through
    ``(i_batch + 1) * batch_size`` (inclusive, via ``BETWEEN``).

    Parameters
    ----------
    i_batch : int
        Zero-based batch index.
    batch_size : int
        Number of rows per batch.

    Returns
    -------
    str
        A query joining ``nsc_dr2.meas`` with ``nsc_dr2.exposure`` restricted to
        the object IDs of the requested batch.
    """
    # Inclusive row-number window for this batch.
    first_row = i_batch * batch_size + 1
    last_row = (i_batch + 1) * batch_size

    id_subquery = f"""
    SELECT objectid 
        FROM mydb://numbered_stable_stars_sep2 
        WHERE row_number BETWEEN {first_row} AND {last_row}
    """
    return f"""
    SELECT m.objectid, m.filter, m.mag_auto, m.magerr_auto, m.mjd, m.exposure, e.exptime
        FROM nsc_dr2.meas AS m
        INNER JOIN nsc_dr2.exposure AS e
        ON e.exposure = m.exposure
        WHERE m.objectid IN ({id_subquery})
    """

def submit_job(i_batch, batch_size):
    """Submit the measurement query for one batch as an asynchronous job.

    Parameters
    ----------
    i_batch : int
        Zero-based batch index; also echoed in the progress message.
    batch_size : int
        Number of rows per batch.

    Returns
    -------
    The value returned by ``qc.query`` for the asynchronous submission; callers
    pass it to ``qc.status`` / ``qc.results`` as the job identifier.
    """
    job = qc.query(sql=make_query_string(i_batch, batch_size),
                   async_=True,
                   timeout=3600)
    # Announce only after the submission call has returned.
    print(f"Submit job #{i_batch}")
    return job

In [3]:
# Fixed seed so the random injections can be reproduced.
seed = 9047851
rng = np.random.default_rng(seed=seed)

# KDE bandwidth handed to cluster_label_dataframe.
bw = 0.13
# 50 log-spaced values spanning 1e-4 .. 1e4 (passed to synthesize_background).
taus = np.geomspace(1e-4, 1e4, num=50)
# Keyword arguments forwarded to filtering.lens_filter.
params = {"achromatic": True, "factor_of_two": True}

# Event-rate table from the earlier run; the row labelled "rate" is used.
er_df = pd.read_parquet(os.path.join(read_dir, "event_rates.parquet"))
rates = er_df.loc["rate"]

# Record the full configuration (including the initial RNG state) in the log.
logging.info(f"Bandwidth = {bw}, params = {params}, taus = {taus}, rng_seed = {seed}, rng_state = {rng.bit_generator.state}")

In [None]:
# ---------------------------------------------------------------------------
# Batch driver: for every batch of objects, fetch its light curves with an
# asynchronous query, inject synthetic events, KDE-label the result, classify
# and filter the light curves, and write per-batch parquet files.
# While batch i is being processed, the query for batch i + 1 is already
# running on the server.
# ---------------------------------------------------------------------------
batch_size = int(1e5)
n_objects = 5279477
# Ceiling division. The previous `int(n_objects / batch_size) + 1` produced an
# extra, empty batch whenever n_objects was an exact multiple of batch_size.
num_batches = (n_objects + batch_size - 1) // batch_size
poll_rate = 10  # seconds to sleep between job-status polls
batch_nums = np.arange(0, num_batches)
filters = ['u', 'g', 'r', 'i', 'z', 'Y', 'VR']
# Statuses for which the job has not finished yet. A freshly submitted job can
# still be "QUEUED"; waiting only on "EXECUTING" let such a job fall straight
# through to the "unexpected status" branch and the batch was skipped.
pending_statuses = ("QUEUED", "EXECUTING")
next_job_id = None

for i_batch in tqdm(batch_nums):
    try:
        # Use the job prefetched by the previous iteration, if there is one.
        job_id = next_job_id
        if job_id is None:
            job_id = submit_job(i_batch, batch_size)

        # Block until the job leaves the queued/executing states, then keep the
        # final status so every branch below sees the same value.
        status = qc.status(job_id)
        while status in pending_statuses:
            time.sleep(poll_rate)
            status = qc.status(job_id)

        # Prefetch the next batch so it runs while this one is processed.
        if i_batch != num_batches - 1:
            next_job_id = submit_job(i_batch + 1, batch_size)

        if status == "COMPLETED":
            # Get the data, filter out bands with fewer than 3 samples, sort by time
            lcs = convert(qc.results(job_id))
            print(f"Processing batch {i_batch}")
            lcs = lcs.groupby(by=["objectid", "filter"],
                              sort=False).filter(lambda x: len(x) > 2)
            lcs = lcs.sort_values(by="mjd")

            # Randomly inject synthetic "flares"
            lcs = synthesize_background(lcs, rates, taus, rng)

            # KDE label the result
            cl = cluster_label_dataframe(lcs, bandwidth=bw)

            # Classify every object's light curve and save the per-object class.
            lc_class_grouped = cl.groupby(by="objectid",
                                          sort=False,
                                          as_index=False,
                                          group_keys=False)
            lightcurve_class_df = lc_class_grouped.apply(filtering.lightcurve_classifier)
            lightcurve_class_fname = os.path.join(results_dir,
                                                  f"batch{i_batch}_lc_class_mc.parquet")
            # Keep the first column's name, name the second "lightcurve_class".
            lightcurve_class_df.columns = [lightcurve_class_df.columns[0], "lightcurve_class"]
            lightcurve_class_df.to_parquet(lightcurve_class_fname)

            # Keep only the objects for which filtering.unstable_filter returns True.
            filtered_df = cl.groupby(by="objectid", sort=False).filter(filtering.unstable_filter)

            # Sort by MJD, then keep the objects accepted by filtering.lens_filter
            # and save their light curves.
            filtered_df = filtered_df.sort_values(by="mjd")
            g = filtered_df.groupby("objectid", sort=False)
            background_df = g.filter(lambda group: filtering.lens_filter(group, **params))
            fname = f"mc_lightcurves_batch{i_batch}.parquet"
            background_df.to_parquet(os.path.join(results_dir, fname))

            # Objects accepted by filtering.unimodal_filter (same grouping as above).
            unimodal_df = g.filter(filtering.unimodal_filter)

            # Count number of stars and number of observations in each filter
            unimodal_filter_counts = unimodal_df["filter"].value_counts()
            background_filter_counts = background_df["filter"].value_counts()
            # One column per filter; row 0 = unimodal, row 1 = background.
            agg_data = {f: np.array([unimodal_filter_counts.get(f, default=0),
                                     background_filter_counts.get(f, default=0)]) for f in filters}
            n_background = background_df["objectid"].unique().size
            n_unimodal = unimodal_df["objectid"].unique().size
            agg_data["n_objects"] = np.array([n_unimodal, n_background])
            idx = pd.MultiIndex.from_product([[i_batch], ["Unimodal", "Background"]])
            aggregate_df = pd.DataFrame(data=agg_data, index=idx)
            aggregate_df.to_parquet(os.path.join(results_dir, f"batch{i_batch}_aggregates.parquet"))
            logging.info(f"Processed batch #{i_batch}")
        elif status == "ERROR":
            # The server reported a failure: record it and move on. The
            # prefetched job (if any) stays valid for the next iteration.
            logging.error(f"ERROR Batch {i_batch}: {qc.error(job_id)}")
        else:
            # Any status other than COMPLETED / ERROR once polling has stopped.
            err_str = f"""
            ERROR Batch {i_batch}: Something unexpected occurred. 
            job_id = {job_id}, status = {status}, 
            type = {type(job_id).__name__}, error = {qc.error(job_id)}
            """
            logging.error(err_str)

    except Exception as e:
        print(f"Exception {e} Occurred batch #{i_batch}")
        err_str = f"""
        ERROR Batch {i_batch}: Something unexpected occurred. 
        job_id = {job_id}
        """
        logging.error(err_str)
        logging.exception("Stack trace:")
        # Drop the prefetched job so the next iteration submits a fresh one.
        # There may be nothing to abort (first batch failed on submission), and
        # the abort call itself may fail; neither may stop the remaining batches.
        if next_job_id is not None:
            try:
                qc.abort(next_job_id)
            except Exception:
                logging.exception(f"Could not abort job {next_job_id}")
        next_job_id = None

  0%|          | 0/53 [00:00<?, ?it/s]

Submit job #0
Submit job #1


  2%|▏         | 1/53 [32:48<28:26:04, 1968.55s/it]

Submit job #2

Total time: 1m:25.32s for 778.05 MB
Processing batch 1


  4%|▍         | 2/53 [57:09<23:39:22, 1669.85s/it]

Submit job #3

Total time: 1m:55.24s for 1017.38 MB
Processing batch 2
