In [1]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
import warnings

warnings.filterwarnings('ignore')

In [26]:
spark = (SparkSession
         .builder
         .getOrCreate()
        )
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
DATA_DIR='/home/data/phd/research_projects/fink_science/fink_science/cbpf_classifier/data/alerts/'
df = (spark
         .read
         .format('parquet')
         .load(DATA_DIR+'small_sample.parquet')
     )


                                                                                

In [27]:
from fink_utils.spark.utils import concat_col
from pyspark.sql.functions import lit

cols = ['midPointTai', 'psFlux', 'psFluxErr', 'filterName']

for col_ in cols:
    df = concat_col(df, col_, prefix='c', current='diaSource', history='prvDiaSources')
    
colsc = ['c'+i for i in cols]

colnames = [
    'cmidPointTai', 'cpsFlux', 'cpsFluxErr', 'cfiltername', 
    'diaObject.mwebv', 'diaObject.z_final', 'diaObject.z_final_err',
    'diaObject.hostgal_zphot', 'diaObject.hostgal_zphot_err', 
    lit('/home/data/phd/research_projects/fink_science/fink_science/cbpf_classifier/models/model_test_meta_ragged_1det_after')
]

#df_sub = df.select(colnames)

from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import DoubleType, StringType, IntegerType

import tensorflow as tf
import numpy as np
import pandas as pd
from tensorflow_addons import optimizers
from utilities import normalize_lc

tf.optimizers.RectifiedAdam = optimizers.RectifiedAdam

@pandas_udf(IntegerType(), PandasUDFType.SCALAR)
def predict_nn(
        midpointTai: pd.Series, psFlux: pd.Series, psFluxErr: pd.Series,
        filterName: pd.Series, mwebv: pd.Series, z_final: pd.Series,
        z_final_err: pd.Series, hostgal_zphot: pd.Series,
        hostgal_zphot_err: pd.Series,
        model
        ) -> pd.Series:
    """
    Return predctions from a model given inputs as pd.Series

    Parameters:
    -----------
    midpointTai: spark DataFrame Column
        SNID JD Time (float)
    psFlux: spark DataFrame Column
        flux from LSST (float)
    psFluxErr: spark DataFrame Column
        flux error from LSST (float)
    filterName:
        (string)
    mwebv:
        (float)
    z_final: spark DataFrame Column
        redshift of a given event (float)
    z_final_err: spark DataFrame Column
        redshift error of a given event (float)       
    hostgal_zphot: spark DataFrame Column
        photometric redshift of host galaxy (float)
    hostgal_zphot_err: spark DataFrame Column
        error in photometric redshift of host galaxy (float)
    model: spark DataFrame Column
        path to pre-trained Hierarchical Classifier model. (string)

    Returns:
    --------
    preds: pd.Series
        predictions of a broad class in an pd.Series format (pd.Series[float])
    """

    filter_dict = {'u':1, 'g':2, 'r':3, 'i':4, 'z':5, 'Y':6}
    
    class_dict = {
        0: 111,
        1: 112,
        2: 113,
        3: 114,
        4: 115,
        5: 121,
        6: 122,
        7: 123,
        8: 124,
        9: 131,
        10: 132,
        11: 133,
        12: 134,
        13: 135,
        14: 211,
        15: 212,
        16: 213,
        17: 214,
        18: 221
    }
        
    bands = []
    lcs = []
    meta = []

    for i, mjds in enumerate(midpointTai):
        
        if len(mjds) > 0:
            bands.append(np.array(
                [filter_dict[f] for f in filterName.values[i]]
            ).astype(np.int16))        
            lc = np.concatenate(
                [mjds[:,None], psFlux.values[i][:,None], psFluxErr.values[i][:,None]], axis=-1
                )
            
            
            
            if not np.isnan(mwebv.values[i]):
                
                lcs.append(normalize_lc(lc).astype(np.float32))
                
                # print([
                #             mwebv.values[i], z_final.values[i],
                #             z_final_err.values[i], hostgal_zphot.values[i],
                #             hostgal_zphot_err.values[i]
                #         ])
                
                meta.append([
                            mwebv.values[i], z_final.values[i],
                            z_final_err.values[i], hostgal_zphot.values[i],
                            hostgal_zphot_err.values[i]
                        ])
           
    
    X = {
        'meta': np.array(meta),
        'band': tf.RaggedTensor.from_row_lengths(
            values=tf.concat(bands, axis=0),
            row_lengths=[a.shape[0] for a in bands]
        ),

        'lc': tf.RaggedTensor.from_row_lengths(
            values=tf.concat(lcs, axis=0),
            row_lengths=[a.shape[0] for a in lcs]
        )
    }
    for i, x in enumerate(X['meta'][:,3]):
        if x < 0:
            X['meta'][i,1:] = -1
        else:
            X['meta'][i,1:] = x

    NN = tf.keras.models.load_model(model.values[0], custom_objects={'RectifiedAdam': optimizers.RectifiedAdam})
    preds = NN.predict(X)
    
    
    return pd.Series([class_dict[p.argmax()] for p in preds])

In [128]:
import numpy

pyspark.__version__

'2.4.7'

In [33]:
df = df.withColumn('pBroadClass', predict_nn(*colnames))
df.select('pBroadClass').show()

2022-09-13 18:55:27.597620: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2022-09-13 18:55:27.597641: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
[Stage 51:>                                                         (0 + 1) / 1]2022-09-13 18:55:28.621747: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2022-09-13 18:55:28.621764: W tensorflow/stream_executor/cuda/cuda_driver.cc:269] failed call to cuInit: UNKNOWN ERROR (303)
2022-09-13 18:55:28.621776: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (andre-PC): /proc/driver/nvidia/version doe

Py4JJavaError: An error occurred while calling o883.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 51.0 failed 1 times, most recent failure: Lost task 0.0 in stage 51.0 (TID 51, localhost, executor driver): java.lang.IllegalArgumentException
	at java.nio.ByteBuffer.allocate(ByteBuffer.java:334)
	at org.apache.arrow.vector.ipc.message.MessageSerializer.readMessage(MessageSerializer.java:543)
	at org.apache.arrow.vector.ipc.message.MessageChannelReader.readNext(MessageChannelReader.java:58)
	at org.apache.arrow.vector.ipc.ArrowStreamReader.readSchema(ArrowStreamReader.java:132)
	at org.apache.arrow.vector.ipc.ArrowReader.initialize(ArrowReader.java:181)
	at org.apache.arrow.vector.ipc.ArrowReader.ensureInitialized(ArrowReader.java:172)
	at org.apache.arrow.vector.ipc.ArrowReader.getVectorSchemaRoot(ArrowReader.java:65)
	at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:162)
	at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:122)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:410)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at org.apache.spark.sql.execution.python.ArrowEvalPythonExec$$anon$2.<init>(ArrowEvalPythonExec.scala:102)
	at org.apache.spark.sql.execution.python.ArrowEvalPythonExec.evaluate(ArrowEvalPythonExec.scala:100)
	at org.apache.spark.sql.execution.python.EvalPythonExec$$anonfun$doExecute$1.apply(EvalPythonExec.scala:127)
	at org.apache.spark.sql.execution.python.EvalPythonExec$$anonfun$doExecute$1.apply(EvalPythonExec.scala:89)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:823)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:823)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:346)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:310)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:346)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:310)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:346)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:310)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:346)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:310)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:123)
	at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:408)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:414)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:750)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1925)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1913)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1912)
	at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1912)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:948)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:948)
	at scala.Option.foreach(Option.scala:257)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:948)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2146)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2095)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2084)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:759)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2061)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2082)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2101)
	at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:365)
	at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:38)
	at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collectFromPlan(Dataset.scala:3389)
	at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2550)
	at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2550)
	at org.apache.spark.sql.Dataset$$anonfun$52.apply(Dataset.scala:3370)
	at org.apache.spark.sql.execution.SQLExecution$$anonfun$withNewExecutionId$1.apply(SQLExecution.scala:80)
	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:127)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:75)
	at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$withAction(Dataset.scala:3369)
	at org.apache.spark.sql.Dataset.head(Dataset.scala:2550)
	at org.apache.spark.sql.Dataset.take(Dataset.scala:2764)
	at org.apache.spark.sql.Dataset.getRows(Dataset.scala:254)
	at org.apache.spark.sql.Dataset.showString(Dataset.scala:291)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.lang.Thread.run(Thread.java:750)
Caused by: java.lang.IllegalArgumentException
	at java.nio.ByteBuffer.allocate(ByteBuffer.java:334)
	at org.apache.arrow.vector.ipc.message.MessageSerializer.readMessage(MessageSerializer.java:543)
	at org.apache.arrow.vector.ipc.message.MessageChannelReader.readNext(MessageChannelReader.java:58)
	at org.apache.arrow.vector.ipc.ArrowStreamReader.readSchema(ArrowStreamReader.java:132)
	at org.apache.arrow.vector.ipc.ArrowReader.initialize(ArrowReader.java:181)
	at org.apache.arrow.vector.ipc.ArrowReader.ensureInitialized(ArrowReader.java:172)
	at org.apache.arrow.vector.ipc.ArrowReader.getVectorSchemaRoot(ArrowReader.java:65)
	at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:162)
	at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:122)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:410)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at org.apache.spark.sql.execution.python.ArrowEvalPythonExec$$anon$2.<init>(ArrowEvalPythonExec.scala:102)
	at org.apache.spark.sql.execution.python.ArrowEvalPythonExec.evaluate(ArrowEvalPythonExec.scala:100)
	at org.apache.spark.sql.execution.python.EvalPythonExec$$anonfun$doExecute$1.apply(EvalPythonExec.scala:127)
	at org.apache.spark.sql.execution.python.EvalPythonExec$$anonfun$doExecute$1.apply(EvalPythonExec.scala:89)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:823)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:823)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:346)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:310)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:346)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:310)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:346)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:310)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:346)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:310)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:123)
	at org.apache.spark.executor.Executor$TaskRunner$$anonfun$10.apply(Executor.scala:408)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1360)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:414)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	... 1 more


In [None]:
# df.columns

In [None]:
# df.select(df['diaSource']['diaSourceId']).show()

In [None]:
# df.select((df
#            .diaSource
#            .diaSourceId
#           )
#          ).show()

### pypsark DataFrame format (only showed features used by classifier)

```
root
 |-- alertId: long (nullable = true)
 |-- diaSource: struct (nullable = true)
 |    |-- midPointTai: double (nullable = true)
 |    |-- filterName: string (nullable = true)
 |    |-- psFlux: float (nullable = true)
 |    |-- psFluxErr: float (nullable = true)
 |-- prvDiaSources: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- midPointTai: double (nullable = true)
 |    |    |-- filterName: string (nullable = true)
 |    |    |-- psFlux: float (nullable = true)
 |    |    |-- psFluxErr: float (nullable = true)
 |-- prvDiaForcedSources: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- midPointTai: double (nullable = true)
 |    |    |-- filterName: string (nullable = true)
 |    |    |-- psFlux: float (nullable = true)
 |    |    |-- psFluxErr: float (nullable = true)
 |-- prvDiaNondetectionLimits: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- midPointTai: double (nullable = true)
 |    |    |-- filterName: string (nullable = true)
 |-- diaObject: struct (nullable = true)
 |    |-- mwebv: float (nullable = true)
 |    |-- z_final: float (nullable = true)
 |    |-- z_final_err: float (nullable = true)
 |    |-- hostgal_zspec: float (nullable = true)
 |    |-- hostgal_zspec_err: float (nullable = true)
 |    |-- hostgal_zphot: float (nullable = true)
 |    |-- hostgal_zphot_err: float (nullable = true)
```