In [1]:
from fink_mm.ztf_join_gcn import ztf_grb_filter
from fink_utils.science.utils import ang2pix
from pyspark.sql import functions as F
from pyspark.sql.functions import explode, col, pandas_udf
from fink_mm.utils.fun_utils import get_pixels

from fink_mm.ztf_join_gcn import remove_skymap

import numpy as np
import pandas as pd
import os
import io
from pyarrow import fs

import pyspark.sql.functions as F
from pyspark.sql import DataFrame

from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DoubleType, ArrayType, IntegerType

from fink_filters.classification import extract_fink_classification
from fink_utils.spark.utils import concat_col

from fink_mm.observatory import obsname_to_class, INSTR_FORMAT
from fink_mm.observatory.observatory import Observatory
from fink_mm.gcn_stream.gcn_reader import load_voevent_from_file, load_json_from_file
from fink_mm.init import init_logging
from enum import Enum

from fink_mm.utils.fun_utils import compute_rate, format_rate_results, get_association_proba, get_observatory
from fink_utils.spark.partitioning import convert_to_datetime

from fink_filters.filter_mm_module.filter import (
    f_grb_bronze_events,
    f_grb_silver_events,
    f_grb_gold_events,
    f_gw_bronze_events,
)

In [2]:
# Shared storage location, kept for reference:
# gcn_path = "/user/julien.peloton/fink_mm/gcn_storage"
month = "06"
gcn_path = "gcn_storage_fix"

# Load every GCN alert stored under year=2023, discard the date partition
# columns and leave out the alerts whose observatory is LVK.
gcn_alert = (
    spark.read.format("parquet")
    .option("mergeSchema", True)
    .load(f"{gcn_path}/year=2023/")
    .drop("year", "month", "day")
    .filter("observatory != 'LVK'")
)

In [70]:
# Sanity check: number of GCN alerts left once the LVK observatory is excluded.
gcn_alert.count()

624

In [3]:
# gcn_alert = gcn_alert.fillna({"gcn_status": "initial"}).drop("year").drop("month").drop("day")

In [4]:
# timecol = "triggerTimejd"
# converter = lambda x: convert_to_datetime(x)  # noqa: E731
# if "timestamp" not in gcn_alert.columns:
#     gcn_alert = gcn_alert.withColumn("timestamp", converter(gcn_alert[timecol]))

# if "year" not in gcn_alert.columns:
#     gcn_alert = gcn_alert.withColumn("year", F.date_format("timestamp", "yyyy"))

# if "month" not in gcn_alert.columns:
#     gcn_alert = gcn_alert.withColumn("month", F.date_format("timestamp", "MM"))

# if "day" not in gcn_alert.columns:
#     gcn_alert = gcn_alert.withColumn("day", F.date_format("timestamp", "dd"))

In [5]:
# gcn_alert.write.mode("append").partitionBy("year", "month", "day").parquet(
#     "gcn_storage_fix"
# )

In [6]:
# Root of the ZTF alert archive on the cluster.
ztf_path = "/user/julien.peloton"

# Load all the 2023 science alerts and discard the date partition columns.
ztf_alert = (
    spark.read.format("parquet")
    .option("mergeSchema", True)
    .load(ztf_path + "/archive/science/year=2023/")
    .drop("year", "month", "day")
)

In [7]:
# Strip the heavy / unused columns first, then apply the fink_mm alert filter.
# NOTE(review): the meaning of the positional thresholds (5, 2, 0.5, 2) is
# defined by ztf_grb_filter and not visible here — confirm before changing.
ztf_dataframe = ztf_grb_filter(
    ztf_alert.drop(
        "candid",
        "schemavsn",
        "publisher",
        "cutoutScience",
        "cutoutTemplate",
        "cutoutDifference",
        "month",
        "day",
    ),
    5,
    2,
    0.5,
    2,
)

# Attach the pixel index of each alert (third argument of ang2pix is 32) and
# expose the alert coordinates as flat columns.
ztf_dataframe = (
    ztf_dataframe
    .withColumn(
        "hpix",
        ang2pix(col("candidate.ra"), col("candidate.dec"), F.lit(32)),
    )
    .withColumn("ztf_ra", col("candidate.ra"))
    .withColumn("ztf_dec", col("candidate.dec"))
)

In [8]:
gcn_dataframe = (
    gcn_alert
    # pixels attached to each event (same F.lit(32) argument as on the ZTF side)
    .withColumn(
        "hpix_circle",
        get_pixels(col("observatory"), col("raw_event"), F.lit(32)),
    )
    # remove the gw skymap to save memory before the join
    .withColumn(
        "raw_event",
        remove_skymap(col("observatory"), col("raw_event")),
    )
    # one row per (event, pixel) so the join can be an equality on `hpix`
    .withColumn("hpix", explode("hpix_circle"))
    # prefix the event coordinates to tell them apart from the ZTF ones
    .withColumnRenamed("ra", "gcn_ra")
    .withColumnRenamed("dec", "gcn_dec")
)

In [9]:
# Association criteria between a GCN event and a ZTF alert:
#   1. both fall in the same pixel (`hpix`),
#   2. the alert history start (candidate.jdstarthist) is later than the GCN trigger,
#   3. the alert date (candidate.jd) is less than 20 (JD units) after the trigger.
# Columns are accessed through their parent DataFrame because `hpix` exists on
# both sides of the join.
join_condition = [
    ztf_dataframe.hpix == gcn_dataframe.hpix,
    ztf_dataframe.candidate.jdstarthist > gcn_dataframe.triggerTimejd,
    ztf_dataframe.candidate.jd - gcn_dataframe.triggerTimejd < 20
]
df_join_mm = gcn_dataframe.join(ztf_dataframe, join_condition, "inner")

In [10]:
# Work on the join result under the name `df_grb`; the repartition and the
# count are left disabled.
df_grb = df_join_mm# .repartition(10000)
# df_grb.count()

In [11]:
# Prepare the c-prefixed columns (cmagpsf, cdiffmaglim, cjd, cfid) consumed by
# compute_rate below.
# NOTE(review): presumably concat_col merges candidate.<field> with the matching
# prv_candidates history — confirm against fink_utils.
df_grb = concat_col(df_grb, "magpsf")
df_grb = concat_col(df_grb, "diffmaglim")
df_grb = concat_col(df_grb, "jd")
df_grb = concat_col(df_grb, "fid")

# Rate computation from the current alert values and their histories; the raw
# result is stored in `c_rate` and then reshaped by format_rate_results.
df_grb = df_grb.withColumn(
    "c_rate",
    compute_rate(
        df_grb["candidate.magpsf"],
        df_grb["candidate.jdstarthist"],
        df_grb["candidate.jd"],
        df_grb["candidate.fid"],
        df_grb["cmagpsf"],
        df_grb["cdiffmaglim"],
        df_grb["cjd"],
        df_grb["cfid"],
    ),
)

df_grb = format_rate_results(df_grb, "c_rate")

# TODO : do something better with satellites
# df_grb = add_tracklet_information(df_grb)

# Placeholder: the tracklet column is an empty string for every row, it only
# exists because extract_fink_classification takes it as last argument.
df_grb = df_grb.withColumn("tracklet", F.lit(""))

# Fink classification of each alert (arguments are positional: keep the order).
df_grb = df_grb.withColumn(
    "fink_class",
    extract_fink_classification(
        df_grb["cdsxmatch"],
        df_grb["roid"],
        df_grb["mulens"],
        df_grb["snn_snia_vs_nonia"],
        df_grb["snn_sn_vs_all"],
        df_grb["rf_snia_vs_nonia"],
        df_grb["candidate.ndethist"],
        df_grb["candidate.drb"],
        df_grb["candidate.classtar"],
        df_grb["candidate.jd"],
        df_grb["candidate.jdstarthist"],
        df_grb["rf_kn_vs_nonkn"],
        df_grb["tracklet"],
    ),
)

# Address of the HDFS host handed to get_association_proba. `last_time` and
# `end_time` are not used in this cell; they are only read by later scratch cells.
hdfs_adress = "134.158.75.222"
last_time = "20230101"
end_time = "20231231"

# refine the association and compute the serendipitous probability
df_grb = df_grb.withColumn(
    "p_assoc",
    get_association_proba(
        df_grb["observatory"],
        df_grb["raw_event"],
        df_grb["ztf_ra"],
        df_grb["ztf_dec"],
        df_grb["start_vartime"],
        F.lit(hdfs_adress),
        df_grb["gcn_status"],
        F.lit(gcn_path)
    ),
)

# select only relevant columns
# ("hpix_circle" is listed twice below; the duplicate has no effect on the
# membership test used to build cols_fink.)
cols_to_remove = [
    "candidate",
    "prv_candidates",
    "timestamp",
    "hpix",
    "hpix_circle",
    "index",
    "fink_broker_version",
    "fink_science_version",
    "cmagpsf",
    "cdiffmaglim",
    "cjd",
    "cfid",
    "tracklet",
    "ivorn",
    "hpix_circle",
    "triggerTimejd",
]
# Every remaining top-level column is kept ...
cols_fink = [i for i in df_grb.columns if i not in cols_to_remove]
# ... together with a few fields pulled out of the `candidate` struct, which
# become flat columns named after the field (candid, fid, jdstarthist, ...).
cols_extra = [
    "candidate.candid",
    "candidate.fid",
    "candidate.jdstarthist",
    "candidate.rb",
    "candidate.jd",
    "candidate.magpsf",
    "candidate.sigmapsf"
]
# The filter on p_assoc is disabled, so rows with p_assoc == -1.0 are kept.
df_grb = df_grb.select(cols_fink + cols_extra) #.filter("p_assoc != -1.0")
df_grb = df_grb.withColumnRenamed("err_arcmin", "gcn_loc_error")

In [12]:
# Build the date partition columns from the alert `jd` column.
timecol = "jd"
converter = convert_to_datetime
if "timestamp" not in df_grb.columns:
    df_grb = df_grb.withColumn("timestamp", converter(df_grb[timecol]))

# Each calendar field is added only when the frame does not already carry it.
for part_name, date_pattern in (("year", "yyyy"), ("month", "MM"), ("day", "dd")):
    if part_name not in df_grb.columns:
        df_grb = df_grb.withColumn(
            part_name, F.date_format("timestamp", date_pattern)
        )

In [13]:
# Flag every association with the fink_filters multi-messenger filters; each
# filter result is stored in its own `is_*` column.
df_join = (
    df_grb
    .withColumn(
        "is_grb_bronze",
        f_grb_bronze_events(col("fink_class"), col("observatory"), col("rb")),
    )
    .withColumn(
        "is_grb_silver",
        f_grb_silver_events(
            col("fink_class"),
            col("observatory"),
            col("rb"),
            col("p_assoc"),
        ),
    )
    .withColumn(
        "is_grb_gold",
        f_grb_gold_events(
            col("fink_class"),
            col("observatory"),
            col("rb"),
            col("gcn_loc_error"),
            col("p_assoc"),
            col("rate"),
        ),
    )
    .withColumn(
        "is_gw_bronze",
        f_gw_bronze_events(col("fink_class"), col("observatory"), col("rb")),
    )
)

In [14]:
# Inspect the final column layout of the flagged associations.
df_join.printSchema()

root
 |-- observatory: string (nullable = true)
 |-- instrument: string (nullable = true)
 |-- event: string (nullable = true)
 |-- triggerId: string (nullable = true)
 |-- gcn_ra: double (nullable = true)
 |-- gcn_dec: double (nullable = true)
 |-- gcn_loc_error: double (nullable = true)
 |-- ackTime: timestamp (nullable = true)
 |-- triggerTimeUTC: timestamp (nullable = true)
 |-- raw_event: string (nullable = true)
 |-- gcn_status: string (nullable = true)
 |-- objectId: string (nullable = true)
 |-- cdsxmatch: string (nullable = true)
 |-- DR3Name: string (nullable = true)
 |-- Plx: float (nullable = true)
 |-- e_Plx: float (nullable = true)
 |-- gcvs: string (nullable = true)
 |-- vsx: string (nullable = true)
 |-- roid: integer (nullable = true)
 |-- rf_snia_vs_nonia: double (nullable = true)
 |-- snn_snia_vs_nonia: double (nullable = true)
 |-- snn_sn_vs_all: double (nullable = true)
 |-- mulens: double (nullable = true)
 |-- nalerthist: integer (nullable = true)
 |-- rf_kn_vs_n

In [82]:
# df_join.count()

In [None]:
# Output directory of the flagged associations.
write_path = "ztf_x_gcn_data"
# grbxztf_write_path = write_path + "/offline"

# Persist as parquet, one sub-directory per year / month / day.
df_join.write.partitionBy("year", "month", "day").parquet(write_path)

In [None]:
from astropy.time import Time

In [None]:
# Debug accumulator: receives the prefix (text before the first "_") of every
# file name visited by gcn_from_hdfs, across all calls.
trigId = []


def gcn_from_hdfs(client, root_path, triggerId, triggerTime, gcn_status):
    """Fetch the stored GCN rows matching a trigger id and a status.

    The files found under ``root_path/year=YYYY/month=MM/day=DD`` (date taken
    from ``triggerTime``) are read one by one, in sorted name order, until a
    file contains at least one row having BOTH the requested ``triggerId`` and
    the requested ``gcn_status``.

    Parameters
    ----------
    client
        Object exposing ``walk(path)`` (yielding ``(path, dirs, files)``) and
        ``read(path)`` (a context manager yielding a reader with ``read()``),
        e.g. an ``hdfs.InsecureClient``.
    root_path : str
        Root directory of the GCN storage.
    triggerId : str
        Value looked up in the ``triggerId`` column.
    triggerTime : datetime-like
        Object with ``year``, ``month`` and ``day`` attributes, used to build
        the dated sub-directory.
    gcn_status : str
        Value looked up in the ``gcn_status`` column.

    Returns
    -------
    pandas.DataFrame
        The matching rows of the first file that contains any.

    Raises
    ------
    FileNotFoundError
        If no file of the dated directory holds a matching row.
    """
    path_date = os.path.join(
        root_path,
        f"year={triggerTime.year:04d}/month={triggerTime.month:02d}/day={triggerTime.day:02d}",
    )
    for dirpath, _, files in client.walk(path_date):
        for fname in sorted(files):
            trigId.append(fname.split("_")[0])
            with client.read(os.path.join(dirpath, fname)) as reader:
                pdf = pd.read_parquet(io.BytesIO(reader.read()))

            # Both criteria must hold on the SAME row. Testing each column
            # separately could select a file where the id and the status
            # appear on different rows and return an empty frame.
            matching = (pdf["triggerId"] == triggerId) & (
                pdf["gcn_status"] == gcn_status
            )
            if matching.any():
                return pdf[matching]

    raise FileNotFoundError(
        "File not found at these locations {} with triggerId = {} and gcn_status = {}".format(
            path_date, triggerId, gcn_status
        )
    )

In [None]:
from hdfs import InsecureClient
import json
# HDFS client reached over HTTP on port 50070 of the host stored in
# `hdfs_adress` (defined in an earlier cell); no user / root is set here.
hdfs_client = InsecureClient(f"http://{hdfs_adress}:50070")

In [None]:
# Look up the "initial" notice of S231005z among the files stored for 2023-10-05.
gcn_res = gcn_from_hdfs(
    client=hdfs_client,
    root_path="/user/julien.peloton/fink_mm/gcn_storage/raw",
    triggerId="S231005z",
    triggerTime=Time("2023-10-05").to_datetime(),
    gcn_status="initial",
)

In [None]:
# NOTE(review): `local_join` is not defined in any cell visible in this
# notebook — it must expose `triggerId` and `gcn_status` columns; confirm where
# it comes from. The lookup date is fixed to 2023-10-05 for every row.
for tr, g_stat in local_join[["triggerId", "gcn_status"]].values:
    print(tr, g_stat, sep="\n")
    gcn_res = gcn_from_hdfs(
        client=hdfs_client,
        root_path="/user/julien.peloton/fink_mm/gcn_storage/raw",
        triggerId=tr,
        triggerTime=Time("2023-10-05").to_datetime(),
        gcn_status=g_stat,
    )
    print(gcn_res)
    # Only the skymap of the last processed row survives the loop.
    skymap_str = json.loads(gcn_res["raw_event"].iloc[0])["event"]["skymap"]
    print()

In [None]:
# Display the skymap string extracted during the last iteration of the loop.
skymap_str

In [None]:
from astropy.time import Time


# Convert a Julian date into a Python datetime to check which calendar day
# (hence which year=/month=/day= directory) it falls on.
jd_test = Time(2460183.9071875, format="jd")
jd_test.to_datetime()

In [None]:
# Caution: this rebinds `gcn_path` and `gcn_alert`, which earlier cells use for
# the local "gcn_storage_fix" data set.
gcn_path = "/user/julien.peloton/fink_mm/gcn_storage/raw/year=2023/month=09"

# Raw GCN alerts stored for September 2023.
gcn_alert = (
    spark.read.format("parquet")
    .option("mergeSchema", True)
    .load(gcn_path)
)

In [None]:
# Collect the Spark DataFrame on the driver as a pandas DataFrame.
gcn_local = gcn_alert.toPandas()

In [None]:
# Build the observatory object from the first row of `gcn_local`
# (observatory name + raw event payload).
obs = get_observatory(gcn_local["observatory"].values[0], gcn_local["raw_event"].values[0])

In [None]:
# The second element returned by get_trigger_time() is read as a Julian date
# and converted to a Python datetime.
t_obs = Time(obs.get_trigger_time()[1], format="jd").to_datetime()

In [None]:
# List what is stored for the day of `t_obs` under the raw GCN storage.
root = "/user/julien.peloton/fink_mm/gcn_storage/raw"
path_date = os.path.join(
    root,
    f"year={t_obs.year:04d}/month={t_obs.month:02d}/day={t_obs.day:02d}",
)
# Every value printed below comes from the loop itself, so the cell no longer
# relies on `r` / `l` left behind by another cell.
for p, dirnames, filenames in hdfs_client.walk(path_date):
    print("----")
    print(p)
    print()
    print(dirnames)
    print()
    print(filenames)
    print("----")

In [None]:
# gcn_from_hdfs takes (client, root_path, triggerId, triggerTime, gcn_status):
# the storage root was missing from the call.
gcn_from_hdfs(hdfs_client, root, "S230923bk", t_obs, "initial")

In [None]:
from hdfs import InsecureClient

# Re-create the client, this time with an explicit user and a root directory
# (this replaces the `hdfs_client` built earlier in the notebook).
hdfs_client = InsecureClient(
    f"http://{hdfs_adress}:50070",
    user="hdfs",
    root="/user/julien.peloton",
)

In [None]:
# NOTE(review): this call does not line up with the gcn_from_hdfs defined in
# this notebook, whose parameters are (client, root_path, triggerId,
# triggerTime, gcn_status). As written, "714809315" is taken as root_path,
# "initial" as triggerId, and the string `last_time` as triggerTime, which has
# no .year / .month / .day attributes. It looks like a call to an older
# signature — confirm the intended storage root and trigger date.
gcn_from_hdfs(hdfs_client, "714809315", "initial", last_time, end_time)

In [None]:
# Display the current value of `last_time` (set earlier in the notebook).
last_time

In [None]:
# Dump each entry yielded while walking the raw GCN storage: the path, then the
# second and third elements of the tuple, separated by blank lines.
for p, r, l in hdfs_client.walk("/user/julien.peloton/fink_mm/gcn_storage/raw"):
    print("----", p, "", r, "", l, "----", sep="\n")

In [None]:
# NOTE(review): presumably lists the part-files found under the raw GCN
# storage directory — confirm against the hdfs client documentation.
hdfs_client.parts("/user/julien.peloton/fink_mm/gcn_storage/raw")

In [None]:
import signal
import pyarrow as pa
import pyarrow.parquet as pq
import os
import time
import pandas as pd

from gcn_kafka import Consumer
import logging

from pyarrow.fs import FileSystem

import fink_mm.gcn_stream.gcn_reader as gr
from fink_mm.init import get_config, init_logging, return_verbose_level
from fink_mm.utils.fun_utils import get_hdfs_connector
from fink_mm.observatory import TOPICS, TOPICS_FORMAT
from fink_client.scripts.fink_datatransfer import my_assign
from astropy.time import Time

In [None]:
# Connector used as `filesystem=` by the parquet write of the next cell.
# NOTE(review): arguments look like (host, port, user) — confirm against
# fink_mm.utils.fun_utils.get_hdfs_connector.
gcn_fs = get_hdfs_connector(hdfs_adress, 8020, "roman.le-montagner")

In [None]:
# Convert the pandas result to an Arrow table and write it as a parquet data
# set under "toto_gcn", through the `gcn_fs` filesystem.
table = pa.Table.from_pandas(gcn_res)

# File names look like "toto_<unix time>_{i}"; the literal "{i}" placeholder is
# left in the template for pyarrow to fill in.
pq.write_to_dataset(
    table,
    root_path="toto_gcn",
    basename_template=f"toto_{time.time()}_{{i}}",
    existing_data_behavior="overwrite_or_ignore",
    filesystem=gcn_fs,
)