In [1]:
# Dask puts out more advisory logging than we care for.
# It takes some doing to quiet all of it, but this recipe works.
import dask
import logging
import dask_jobqueue
from dask.dataframe.utils import make_meta
from dask.distributed import Client

# Step 1: set the "logging.distributed" key in Dask's configuration to "critical".
dask.config.set({"logging.distributed": "critical"})

# Step 2: raise the level on the "distributed" logger itself.
# This also has to be done for the config setting above to be effective.
logger = logging.getLogger("distributed")
logger.setLevel(logging.CRITICAL)

import warnings

# Finally, suppress the specific warning about Dask dashboard port usage
warnings.filterwarnings("ignore", message="Port 8787 is already in use.")

from pathlib import Path

import numpy as np
import pandas as pd
from astropy.io import ascii
import matplotlib.pyplot as plt

from hats import read_hats

import lsdb

# Project-local helpers: filter-string builders and the HPMS pipeline module.
from catalog_filtering import bandFilterLenient, contains_PM
import hpms_pipeline as hpms

print("Imported libraries.")

Imported libraries.


In [2]:
# Directory variables
# Every catalog location is built relative to CATALOG_DIR.
CATALOG_DIR = Path("../../catalogs")
DES_NAME = "des_light"
DES_DIR = CATALOG_DIR / DES_NAME
DES_MARGIN_CACHE_NAME = "des_margin_cache_18_arcsec"
DES_MARGIN_CACHE_DIR = CATALOG_DIR / DES_MARGIN_CACHE_NAME
RESULTS_NAME = "des_hpms"
RESULTS_DIR = CATALOG_DIR / RESULTS_NAME

# Filter variables
# These settings are forwarded to bandFilterLenient to build `query_string`.
bandList = ['G','R','I','Z','Y']
class_star = None
spread_model = 0.05
magnitude_error = 0.05
check_flags = True
check_invalid_mags = True
query_string = bandFilterLenient(
    bandList,
    classStar=class_star,
    spreadModel=spread_model,
    magError=magnitude_error,
    flag=check_flags,
    invalidMags=check_invalid_mags
)
# Column names: one per band for each per-band quantity, plus position and object id.
des_cols = (
    [f'CLASS_STAR_{band}' for band in bandList] + 
    [f'FLAGS_{band}' for band in bandList] + 
    ['RA','DEC','COADD_OBJECT_ID'] + 
    [f'SPREAD_MODEL_{band}' for band in bandList] + 
    [f'WAVG_MAG_PSF_{band}' for band in bandList] + 
    [f'WAVG_MAGERR_PSF_{band}' for band in bandList]
)
# NOTE(review): the "_1" suffix does not appear in `des_cols`; presumably it is
# added by a crossmatch step elsewhere — confirm.
des_id_col = 'COADD_OBJECT_ID_1'

# Algorithm variables
k = 1
max_obj_deviation = 0.2
# NOTE(review): the original comment read "milliseconds per year"; presumably
# this is milliarcseconds per year (mas/yr) — confirm.
pm_speed_min = 2000
pm_speed_max = 10**5
# NOTE(review): the original comment read "two arcseconds"; if the unit is
# arcseconds, 7200 corresponds to two degrees — confirm the unit.
cone_search_rad = 7200
max_neighbor_dist = 18
xmatch_max_neighbors = 100
min_neighbors = 3

'''
TODO: Verify with Kostya that this is what is expected
'''
# Computing Variables:
# NOTE(review): memory_size, num_cores and walltime_per_worker_job below are
# unfilled placeholders and must be replaced with real values before use.
queue = "RM-shared" #SBATCH -p RM-shared
account_name = "jpassos"
memory_size = "x GB"
# `int` here is the builtin type, not a number, so the f-string on the next
# line renders as "--ntasks-per-node=<class 'int'>".
num_cores = int
job_extra = [f'--ntasks-per-node={num_cores}']
walltime_per_worker_job = "DD:HH:MM"
pre_worker_launch_commands = [
    "source ~/.bashrc",
    "conda activate lsdb-main"
]
print("Defined Vars")

Defined Vars


In [7]:
class Job(dask_jobqueue.slurm.SLURMJob):
    """SLURMJob subclass that fixes ``worker_process_threads`` at 3."""

    # Rewrite the default, which is a property equal to cores/processes
    worker_process_threads = 3

class Cluster(dask_jobqueue.SLURMCluster):
    """SLURMCluster subclass whose ``job_cls`` is the custom ``Job`` class."""

    job_cls = Job

# Resource request for each SLURM job.
# Memory is derived from the core count so the two values cannot drift apart.
processes = 1  # Single dask worker per slurm job
cores_per_job = 8
gb_per_core = 2  # PSC "regular memory" nodes provide fixed 2GB / core
gb_per_job = cores_per_job * gb_per_core  # 8 cores * 2 GB = 16 GB, same as before

cluster = Cluster(
    # Number of Dask workers per node
    processes=processes,
    # Regular memory node type on PSC bridges2
    queue="RM-shared",
    # dask_jobqueue requires cores and memory to be specified
    # We set them to match RM specs
    cores=cores_per_job,
    memory=f"{gb_per_job}GB",
    n_workers=4,
)

# Connect a distributed client to the cluster created above.
client = Client(cluster)

In [3]:
# Open the DES catalog at DES_DIR together with its margin cache.
des_dr2 = lsdb.read_hats(DES_DIR, margin_cache=DES_MARGIN_CACHE_DIR)
# Bare last expression: display the catalog's column/partition summary.
des_dr2

Unnamed: 0_level_0,CLASS_STAR_G,CLASS_STAR_R,CLASS_STAR_I,CLASS_STAR_Z,CLASS_STAR_Y,FLAGS_G,FLAGS_R,FLAGS_I,FLAGS_Z,FLAGS_Y,RA,DEC,COADD_OBJECT_ID,SPREAD_MODEL_G,SPREAD_MODEL_R,SPREAD_MODEL_I,SPREAD_MODEL_Z,SPREAD_MODEL_Y,WAVG_MAG_PSF_G,WAVG_MAG_PSF_R,WAVG_MAG_PSF_I,WAVG_MAG_PSF_Z,WAVG_MAG_PSF_Y,WAVG_MAGERR_PSF_G,WAVG_MAGERR_PSF_R,WAVG_MAGERR_PSF_I,WAVG_MAGERR_PSF_Z,WAVG_MAGERR_PSF_Y
npartitions=1582,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1
"Order: 4, Pixel: 0",double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],int16[pyarrow],int16[pyarrow],int16[pyarrow],int16[pyarrow],int16[pyarrow],double[pyarrow],double[pyarrow],int64[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow],double[pyarrow]
"Order: 5, Pixel: 8",...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
"Order: 3, Pixel: 743",...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
"Order: 1, Pixel: 47",...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...


## Obtaining Benchmarking Datasets

In [12]:
# Read only the `ra_gaia` and `dec_gaia` columns from the benchmark CSV.
gaia_high_pm_stars = pd.read_csv('../gaia_high_pm_stars.csv',usecols=['ra_gaia', 'dec_gaia'])

gaia_high_pm_stars

Unnamed: 0,ra_gaia,dec_gaia
0,1.383284,-37.367744
1,33.079599,3.567385
2,53.567196,-49.890084
3,62.611,-53.612997
4,77.959937,-45.043813
5,50.000344,-43.066553
6,5.03561,-64.869617
7,32.622946,-50.820906
8,32.624069,-50.820823
9,11.341389,-33.497993


In [6]:
def obtain_runtime_data(row, des_cols, cone_search_rad, query_string,
                  xmatch_max_neighbors, max_neighbor_dist, min_neighbors,
                  k, max_obj_deviation, id_col, client):
    """Run the HPMS pipeline, write its output to HATS, and report whether any row passes.

    Args:
        row: Accepted but not used by the current implementation.
        des_cols: Accepted but not used by the current implementation.
        cone_search_rad: Accepted but not used by the current implementation.
        query_string: Forwarded to ``hpms.execute_pipeline``.
        xmatch_max_neighbors: Forwarded to ``hpms.execute_pipeline``.
        max_neighbor_dist: Forwarded to ``hpms.execute_pipeline``.
        min_neighbors: Forwarded to ``hpms.execute_pipeline``.
        k: Forwarded to ``hpms.execute_pipeline``.
        max_obj_deviation: Forwarded to ``hpms.execute_pipeline`` and also used
            as the upper bound on ``kth_min_proj_error`` for the final check.
        id_col: Forwarded to ``hpms.execute_pipeline`` as the id column name.
        client: Object used as a context manager around the pipeline run.

    Returns:
        bool: True when at least one computed row has
        ``kth_min_proj_error < max_obj_deviation``, otherwise False.

    Notes:
        This function reads the notebook globals ``des_1``, ``HPMS_ONE_DEG_DIR``,
        ``HPMS_ONE_DEG_NAME`` and ``HPMS_filtered_catalog``. None of them is
        defined in the cells visible alongside this function, so they must
        already exist in the kernel before it is called.
    """
    # NOTE(review): if `client` is a dask.distributed Client, leaving this
    # `with` block presumably closes it, so the compute() below would no longer
    # run on the cluster and a second call would receive a closed client —
    # confirm this is intended.
    with client:
        hpms.execute_pipeline(
            des_1,
            query_string=query_string,
            xmatch_max_neighbors=xmatch_max_neighbors,
            max_neighbor_dist=max_neighbor_dist,
            min_neighbors=min_neighbors,
            k=k,
            max_obj_deviation=max_obj_deviation,
            # Use the caller-supplied id column instead of the global `des_id_col`.
            id_col=id_col,
        ).to_hats(base_catalog_path=HPMS_ONE_DEG_DIR, catalog_name=HPMS_ONE_DEG_NAME)

    df = HPMS_filtered_catalog.compute().reset_index(drop=True, level=0)

    # True as soon as any row falls under the deviation threshold.
    return not df.query(f'kth_min_proj_error < {max_obj_deviation}').empty

In [26]:
%%time




CPU times: user 19.6 s, sys: 2.69 s, total: 22.3 s
Wall time: 6min 28s


In [33]:
# Compare the row count of the written HPMS catalog with that of the input catalog.
# NOTE(review): `HPMS_ONE_DEG_DIR` and `des_1` are not defined in the cells
# visible here — confirm where they are created so a fresh kernel can run this.
print('Rows of HPMS Filtered:',len(lsdb.read_hats(HPMS_ONE_DEG_DIR)))
print('Rows of Reg:', len(des_1))

Rows of HPMS Filtered: 7132
Rows of Reg: 420911


In [None]:
class RuntimeStats:
    """Size, timing and density summary for one pipeline run.

    Attributes:
        length: Number of objects in the run's input (stored as given).
        filtered_length: Number of objects after filtering (stored as given).
        wall_time: Wall-clock duration of the run, in seconds.
        area: ``pi * radius * radius``, i.e. the area of a circle of the given
            radius, in the square of whatever unit ``radius`` is expressed in.
        avg_density: ``length / area``, or 0 when the area is 0.
    """

    def __init__(self, length=0, filtered_length=0, wall_time=0, radius=0):
        # Circle area for the given radius.
        circle_area = np.pi * radius * radius

        # A zero-area run (the default) reports density 0 instead of dividing by zero.
        if circle_area != 0:
            density = length / circle_area
        else:
            density = 0

        self.length = length
        self.filtered_length = filtered_length
        self.wall_time = wall_time  # seconds
        self.area = circle_area
        self.avg_density = density

# Recorded runs; positional arguments are (length, filtered_length, wall_time, radius).
# The first record's counts and 388 s wall time match the outputs printed earlier
# in this notebook (420911 rows, 7132 filtered rows, 6min 28s).
des_one_stats = RuntimeStats(420911, 7132, 388, 1)
des_five_stats = RuntimeStats(7747405, 105865, 868, 5)

# Reference figures for the complete catalog; not used by the visible cells.
des_complete_area = 5430 # Degrees Squared
des_complete_avg_density = 127255 # Objs per degree squared

# Runs to include in the plot.
runs = [des_one_stats, des_five_stats]

def plot_runtime_per_object_vs_avg_density(runs):
    """Plot wall time per object against average density, one point per run.

    Args:
        runs: Iterable of objects exposing ``avg_density``, ``wall_time`` and
            ``length`` attributes. Points are drawn, and connected, in the
            order the runs are given.

    Returns:
        None. The figure is shown via ``plt.show()``.
    """
    densities = []
    runtime_per_object = []
    for run in runs:
        densities.append(run.avg_density)
        # A run with no objects contributes 0 rather than dividing by zero.
        if run.length != 0:
            runtime_per_object.append(run.wall_time / run.length)
        else:
            runtime_per_object.append(0)

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(densities, runtime_per_object, marker='o')
    ax.set_xlabel('Average Density (objects per square degree)')
    ax.set_ylabel('Wall Time per Object (s)')
    ax.set_title('Normalized Runtime vs. Average Density')
    ax.grid(True)
    plt.show()

# Draw the normalized-runtime figure for the recorded runs.
plot_runtime_per_object_vs_avg_density(runs)