In [9]:
import os
import warnings

import numpy as np
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import DoubleType, IntegerType, StringType

# Suppress every Python warning raised in this kernel from here on.
# NOTE(review): this also hides warnings that may point at real problems
# (e.g. deprecated APIs); consider narrowing the filter by category/module.
warnings.filterwarnings('ignore')

In [10]:
# Reuse the active Spark session if one exists, otherwise start a new one.
spark = SparkSession.builder.getOrCreate()

In [11]:
# Directory holding the alert parquet files. The default is the original
# author's local checkout; set the CBPF_DATA_DIR environment variable to run
# the notebook from another machine without editing this cell.
DATA_DIR = os.environ.get(
    'CBPF_DATA_DIR',
    '/home/data/phd/research_projects/fink_science/fink_science/cbpf_classifier/data/alerts/',
)

# os.path.join works whether or not DATA_DIR ends with a slash.
df = (spark
      .read
      .format('parquet')
      .load(os.path.join(DATA_DIR, 'small_sample.parquet'))
      )

In [12]:
# Pull the per-source and per-object fields used downstream out of the
# nested diaSource / diaObject structs (same columns, same order).
df_change = df.select(
    *[df.diaSource[field]
      for field in ('midpointTai', 'filterName', 'psFlux', 'psFluxErr')],
    *[df.diaObject[field]
      for field in ('mwebv', 'z_final', 'z_final_err',
                    'hostgal_zphot', 'hostgal_zphot_err')],
)

In [13]:
# Columns selected from the nested structs are named "<struct>.<field>"
# (e.g. "diaSource.midpointTai"); keep only the field part so downstream code
# can use plain names. toDF renames positionally in one projection, replacing
# a hand-maintained chain of withColumnRenamed calls (which had a duplicate).
# Names without a dot are left unchanged.
df_data = df_change.toDF(*[name.split('.', 1)[-1] for name in df_change.columns])

In [14]:
# Preview the first 20 rows, truncating long cell values (Spark's defaults).
df_data.show(n=20, truncate=True)

+-----------+----------+-----------------+------------------+--------------------+-------------------+--------------------+-------------------+--------------------+
|midpointTai|filterName|           psFlux|         psFluxErr|               mwebv|            z_final|         z_final_err|      hostgal_zphot|   hostgal_zphot_err|
+-----------+----------+-----------------+------------------+--------------------+-------------------+--------------------+-------------------+--------------------+
| 61125.1297|         z|  32415.529296875|  926.963134765625| 0.09092952311038971|               -9.0|                -9.0|               -9.0|                -9.0|
| 61125.1456|         z|       28024.9375| 535.8877563476562| 0.06947284936904907|0.08547317981719971| 0.02944999933242798|0.08547317981719971| 0.02944999933242798|
| 61125.1465|         z| 6534.21923828125| 1155.543701171875| 0.04673771560192108|               -9.0|                -9.0|               -9.0|                -9.0|
| 61125.20

In [15]:
def normalize_lc(lc_array: np.ndarray) -> np.ndarray:
    """
    Min-max normalize a light curve to the [0, 1] range.

    Parameters:
    ----------
    lc_array: 1D array-like
        Input light curve of an alert.

    Returns:
    --------
    result: np.ndarray
        New float64 array: (x - min) / (max - min). A flat (constant) light
        curve yields all zeros instead of NaNs; an empty input yields an
        empty array. NaNs in the input propagate to the output.
    """
    values = np.asarray(lc_array, dtype=float)

    if values.size == 0:
        # np.min / np.max raise on empty input; an empty curve stays empty.
        return np.zeros((0,))

    low = values.min()
    span = values.max() - low

    if span == 0:
        # Flat light curve: min-max scaling would compute 0 / 0 -> NaN.
        return np.zeros((values.shape[0],))

    return (values - low) / span

In [17]:
def apply_filter_name(filter_name: pd.Series, filter_map: dict = None) -> pd.Series:
    """
    Map photometric filter names to integer band codes.

    Parameters:
    ----------
    filter_name: pd.Series
        Filter name of each alert (e.g. 'u', 'g', 'r', 'i', 'z', 'Y').
    filter_map: dict, optional
        Mapping from filter name to code. Defaults to
        {'u': 1, 'g': 2, 'r': 3, 'i': 4, 'z': 5, 'Y': 6}.

    Returns:
    --------
    result: pd.Series
        Band codes aligned with the input index. Names missing from
        ``filter_map`` become NaN. The input Series is not modified.
    """
    if filter_map is None:
        filter_map = {'u': 1, 'g': 2, 'r': 3, 'i': 4, 'z': 5, 'Y': 6}

    # Series.map returns a new Series; the previous version discarded that
    # result (and skipped inputs of length >= 10), so nothing was ever mapped.
    return pd.Series(filter_name).map(filter_map)

In [18]:
@pandas_udf(DoubleType(), PandasUDFType.SCALAR)
def predict_nn(
        midpointTai: pd.Series, psFlux: pd.Series, psFluxErr: pd.Series,
        filterName: pd.Series,
        mwebv: pd.Series, z_final: pd.Series,
        z_final_err: pd.Series, hostgal_zphot: pd.Series,
        hostgal_zphot_err: pd.Series
    ) -> pd.Series:
    """
    Return predictions from a model given inputs as pd.Series.

    NOTE(review): work in progress. No model is applied yet; the UDF
    currently returns ``apply_filter_name(filterName)`` cast to float, and
    every other input column is ignored.

    Parameters:
    -----------
    midpointTai: spark DataFrame Column
        SNID JD Time (float)
    psFlux: spark DataFrame Column
        flux from LSST (float)
    psFluxErr: spark DataFrame Column
        flux error from LSST (float)
    filterName: spark DataFrame Column
        name of the photometric filter (string)
    mwebv: spark DataFrame Column
        (float)
    z_final: spark DataFrame Column
        redshift of a given event (float)
    z_final_err: spark DataFrame Column
        redshift error of a given event (float)
    hostgal_zphot: spark DataFrame Column
        photometric redshift of host galaxy (float)
    hostgal_zphot_err: spark DataFrame Column
        error in photometric redshift of host galaxy (float)

    Returns:
    --------
    preds: pd.Series
        One float per input row (pd.Series[float]), matching the DoubleType
        declared for this UDF.
    """
    # The UDF is declared as DoubleType, so hand Arrow a float64 Series
    # rather than relying on an implicit conversion of the mapped codes.
    return apply_filter_name(filterName).astype(float)

In [22]:
# Preview df_data with an extra mapFilterName column.
# NOTE(review): this copies filterName unchanged; no mapping is applied here.
# The saved output below shows nulls, which this code cannot produce, so that
# output is stale (left over from an earlier version) — re-run the cell.
df_data.withColumn('mapFilterName', df_data['filterName']).show()


+-----------+----------+-----------------+------------------+--------------------+-------------------+--------------------+-------------------+--------------------+-------------+
|midpointTai|filterName|           psFlux|         psFluxErr|               mwebv|            z_final|         z_final_err|      hostgal_zphot|   hostgal_zphot_err|mapFilterName|
+-----------+----------+-----------------+------------------+--------------------+-------------------+--------------------+-------------------+--------------------+-------------+
| 61125.1297|         z|  32415.529296875|  926.963134765625| 0.09092952311038971|               -9.0|                -9.0|               -9.0|                -9.0|         null|
| 61125.1456|         z|       28024.9375| 535.8877563476562| 0.06947284936904907|0.08547317981719971| 0.02944999933242798|0.08547317981719971| 0.02944999933242798|         null|
| 61125.1465|         z| 6534.21923828125| 1155.543701171875| 0.04673771560192108|               -9.0|   