In [7]:
from pyspark.sql import SparkSession
from pyspark import SparkConf
from pyspark.sql.functions import udf
import pickle  
from pyspark.sql.functions import udf
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, FloatType
import pandas as pd
from pyspark.sql.functions import pandas_udf

# Configure the connection to the standalone cluster master and size the
# driver/executor resources for this application.
sparkConf = (
    SparkConf()
    .setMaster("spark://spark-master:7077")
    .setAppName("MLExample")
    .setAll([
        ("spark.driver.memory", "2g"),
        ("spark.executor.cores", "1"),
        ("spark.driver.cores", "1"),
    ])
)
# SparkSession is the entry point to the Spark SQL engine; getOrCreate() hands
# back the currently active session if one already exists.
spark = SparkSession.builder.config(conf=sparkConf).getOrCreate()

# Load the pre-trained model from disk. The `with` block guarantees the file
# handle is closed once unpickling finishes.
# NOTE(review): pickle.load can execute arbitrary code while deserializing —
# only load model files that come from a trusted source.
MODEL_PATH = '/home/jovyan/data/lr_model.pkl'
with open(MODEL_PATH, 'rb') as model_file:
    model_lr = pickle.load(model_file)

# Broadcast the model so each executor keeps one cached read-only copy
# rather than having it shipped along with every task.
# The (misspelled) name `braodcast_model` is kept because predict_udf reads it.
sc = spark.sparkContext
braodcast_model = sc.broadcast(model_lr)

# Explicit schema for prediction_set.csv: eight numeric feature columns followed
# by a string `class` column (presumably the label — it is dropped before
# scoring). Every field is declared nullable.
dataSchema = StructType([
    StructField(column_name, column_type, True)
    for column_name, column_type in (
        ("ntp", IntegerType()),
        ("pgc", IntegerType()),
        ("dbp", IntegerType()),
        ("tsft", IntegerType()),
        ("si", IntegerType()),
        ("bmi", FloatType()),
        ("dpf", FloatType()),
        ("age", IntegerType()),
        ("class", StringType()),
    )
])

# Read the CSV (first line is a header) with the explicit schema above, then
# drop the `class` column so only the feature columns remain.
df = (
    spark.read
    .format("csv")
    .schema(dataSchema)
    .option("header", "true")
    .load("/home/jovyan/data/prediction_set.csv")
    .drop("class")
)
df.printSchema()
df.show(10)

@udf('integer')
def predict_udf(*cols):
    """Score a single row with the broadcast model.

    Args:
        *cols: The row's feature values. They are assumed to be in the same
            column order the model was trained on (TODO confirm against the
            training notebook).

    Returns:
        The model's prediction for the row as an ``int``, or ``None`` when any
        feature value is null.
    """
    # Spark hands null cells to Python UDFs as None. Estimators such as
    # scikit-learn's typically reject missing values, which would fail the
    # whole task, so emit a null prediction for that row instead.
    if any(value is None for value in cols):
        return None
    # predict() takes a 2-D (n_samples, n_features) input and returns one value
    # per sample (assumes a scikit-learn-style estimator). Take the single
    # element explicitly: int() on a size-1 array is deprecated since NumPy 1.25.
    prediction = braodcast_model.value.predict([list(cols)])
    return int(prediction[0])

# Pass every feature column, in schema order, to the UDF and append the result
# as a trailing `prediction` column.
list_of_columns = df.columns
df_prediction = df.select("*", predict_udf(*list_of_columns).alias("prediction"))

df_prediction.show()

root
 |-- ntp: integer (nullable = true)
 |-- pgc: integer (nullable = true)
 |-- dbp: integer (nullable = true)
 |-- tsft: integer (nullable = true)
 |-- si: integer (nullable = true)
 |-- bmi: float (nullable = true)
 |-- dpf: float (nullable = true)
 |-- age: integer (nullable = true)

+---+---+---+----+---+----+-----+---+
|ntp|pgc|dbp|tsft| si| bmi|  dpf|age|
+---+---+---+----+---+----+-----+---+
|  1|126| 60|   0|  0|30.1|0.349| 47|
|  1| 93| 70|  31|  0|30.4|0.315| 23|
| 12| 84| 72|  31|  0|29.7|0.297| 46|
|  0|139| 62|  17|210|22.1|0.207| 21|
|  0| 97| 64|  36|100|36.8|  0.6| 25|
|  8|120|  0|   0|  0|30.0|0.183| 38|
|  1| 97| 70|  15|  0|18.2|0.147| 21|
|  6|107| 88|   0|  0|36.8|0.727| 31|
|  0|189|104|  25|  0|34.3|0.435| 41|
|  2| 83| 66|  23| 50|32.2|0.497| 22|
+---+---+---+----+---+----+-----+---+
only showing top 10 rows

+---+---+---+----+---+----+-----+---+----------+
|ntp|pgc|dbp|tsft| si| bmi|  dpf|age|prediction|
+---+---+---+----+---+----+-----+---+----------+
|  1|

In [9]:
# Shut down the Spark session (this also stops its underlying SparkContext),
# releasing the cluster resources held by this notebook.
spark.stop()