## Bibliotecas

In [1]:
import sys
sys.path.append('../../../')

In [2]:
## Spark SQL
from pyspark.sql import SparkSession
import pyspark.sql.functions as f

# Spark ML
from pyspark.ml.pipeline import Pipeline
from pyspark.ml.classification import LogisticRegression, RandomForestClassifier
from pyspark.ml.feature import VectorAssembler, StringIndexer, OneHotEncoder
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator, Evaluator
from pyspark.ml.util import JavaMLReadable, JavaMLWritable
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

In [3]:
import mlflow.pyspark.ml

In [4]:
spark = (
    SparkSession
    .builder
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    .config("spark.driver.memory", "6g")
    .getOrCreate()
)

## Spark ML

### Binary Classification

#### Data Split

In [5]:
df_train = spark.read.parquet('../../../data/raw/raw_train')
df_test = spark.read.parquet('../../../data/raw/raw_test')

#### Preprocessing

In [6]:
# df_train = df_train.withColumn("Survived", f.when(f.rand() >= 0.75, 2).otherwise(f.col('Survived')))

In [7]:
df_train.groupby('Survived').count().toPandas()

Unnamed: 0,Survived,count
0,1,244
1,0,379


In [12]:
from src.ml.preprocessing.preprocessing import SparkPreprocessor

In [13]:
from src.ml.preprocessing.normalization import SparkScaler

In [14]:
preproc = SparkPreprocessor({'zscore': 'Age'}, ['Sex', 'Pclass'])

In [15]:
df_preproc = preproc.execute(df_train, )

INFO:root:Normalizing


#### Model

In [8]:
from src.ml.model.trainer import SparkTrainer

Could not import lightgbm, required if using LGBMExplainableModel


In [9]:
from src.ml.model.metrics import BinaryEvaluator

In [10]:
model = SparkTrainer()

In [16]:
model.train(df_preproc, True, LogisticRegression, labelCol='Survived', family='binomial')

41 43
32 41
41 50
32 34
41 43
32 41
41 50
32 34
1.563720930232558 1.7734883720930232
1.4691535150645625 1.721664275466284
Confusion Matrix
+-------+---+---+
|Outcome|  0|  1|
+-------+---+---+
|      0| 41|  9|
|      1|  2| 32|
+-------+---+---+


Results
+------------+-----------+
|   Accuracy |   ROC AUC |
|   0.869048 |  0.933529 |
+------------+-----------+

+-----------+-------------+----------+----------+
|   Outcome |   Precision |   Recall |       F1 |
|         0 |    0.953488 | 0.82     | 0.88172  |
+-----------+-------------+----------+----------+
|         1 |    0.780488 | 0.941176 | 0.853333 |
+-----------+-------------+----------+----------+


<src.ml.model.wrapper.Wrapper at 0x20185c304c0>

In [17]:
lr = LogisticRegression(labelCol='Survived', family='binomial')

In [18]:
evaluator = BinaryEvaluator('f1', 'Survived', 0)

In [19]:
grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()

In [20]:
cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, parallelism=4)

In [None]:
cvModel = cv.fit(df_preproc)

In [25]:
cvModel.transform(df_preproc).toPandas()

Unnamed: 0,Survived,Pclass,Sex,Age,Sex_indexed,Pclass_indexed,Sex_ohe,Pclass_ohe,zscore_vec,zscore_scaled,features,rawPrediction,probability,prediction
0,1,2,female,34.0,1.0,2.0,"(0.0, 1.0)","(0.0, 0.0, 1.0)",[34.0],[2.3748458010440516],"[0.0, 1.0, 0.0, 0.0, 1.0, 2.3748458010440516]","[-0.8692756847271701, 0.8692756847271701]","[0.2954050392916281, 0.7045949607083719]",1.0
1,1,2,female,31.0,1.0,2.0,"(0.0, 1.0)","(0.0, 0.0, 1.0)",[31.0],[2.1653005833048704],"[0.0, 1.0, 0.0, 0.0, 1.0, 2.1653005833048704]","[-0.8929036886451078, 0.8929036886451078]","[0.29051097168711193, 0.7094890283128881]",1.0
2,1,1,male,36.0,0.0,1.0,"(1.0, 0.0)","(0.0, 1.0, 0.0)",[36.0],[2.514542612870172],"[1.0, 0.0, 0.0, 1.0, 0.0, 2.514542612870172]","[1.114011700851593, -1.114011700851593]","[0.7528762597670617, 0.24712374023293826]",0.0
3,1,3,male,29.0,0.0,0.0,"(1.0, 0.0)","(1.0, 0.0, 0.0)",[29.0],[2.0256037714787496],"[1.0, 0.0, 1.0, 0.0, 0.0, 2.0256037714787496]","[2.5200131417374476, -2.5200131417374476]","[0.9255329605899157, 0.07446703941008426]",0.0
4,0,2,male,18.0,0.0,2.0,"(1.0, 0.0)","(0.0, 0.0, 1.0)",[18.0],[1.257271306435086],"[1.0, 0.0, 0.0, 0.0, 1.0, 1.257271306435086]","[1.5281532911861775, -1.5281532911861775]","[0.8217359579853842, 0.17826404201461576]",0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
487,1,2,female,34.0,1.0,2.0,"(0.0, 1.0)","(0.0, 0.0, 1.0)",[34.0],[2.3748458010440516],"[0.0, 1.0, 0.0, 0.0, 1.0, 2.3748458010440516]","[-0.8692756847271701, 0.8692756847271701]","[0.2954050392916281, 0.7045949607083719]",1.0
488,1,2,female,30.0,1.0,2.0,"(0.0, 1.0)","(0.0, 0.0, 1.0)",[30.0],[2.09545217739181],"[0.0, 1.0, 0.0, 0.0, 1.0, 2.09545217739181]","[-0.9007796899510871, 0.9007796899510871]","[0.2888902972074874, 0.7111097027925126]",1.0
489,0,3,male,32.0,0.0,0.0,"(1.0, 0.0)","(1.0, 0.0, 0.0)",[32.0],[2.235148989217931],"[1.0, 0.0, 1.0, 0.0, 0.0, 2.235148989217931]","[2.5436411456553856, -2.5436411456553856]","[0.9271451579408136, 0.07285484205918635]",0.0
490,0,3,male,30.0,0.0,0.0,"(1.0, 0.0)","(1.0, 0.0, 0.0)",[30.0],[2.09545217739181],"[1.0, 0.0, 1.0, 0.0, 0.0, 2.09545217739181]","[2.527889143043427, -2.527889143043427]","[0.9260739719894621, 0.07392602801053794]",0.0


In [19]:
evaluator.evaluate(lr.fit(df_preproc).transform(df_preproc))

[0.8286713286713286, 0.762135922330097]

In [95]:
d = {key : {} for key in [1, 2, 3]}

In [96]:
d[1]

{}