In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits made to
# imported modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import pipeasy_spark as ppz
import pyspark
from pyspark.ml.feature import (
    OneHotEncoder, StringIndexer, StandardScaler, OneHotEncoder, OneHotEncoderEstimator,
    VectorAssembler
)

In [3]:
# Reuse an already-running Spark session when there is one; otherwise start a
# new one registered under the application name 'titanic'.
session = (
    pyspark.sql.SparkSession.builder
    .appName('titanic')
    .getOrCreate()
)

# The dataset is tab-delimited with a header row; column types are inferred
# by Spark rather than declared up front.
titanic = (
    session.read
    .options(header=True, inferSchema=True, sep='\t')
    .csv('./datasets/titanic.csv')
)

In [4]:
# Peek at the first row to check the parsed column layout.
titanic.show(n=1)

+-----------+--------+------+--------------------+----+----+-----+-----+---------+----+-----+--------+
|PassengerId|Survived|Pclass|                Name| Sex| Age|SibSp|Parch|   Ticket|Fare|Cabin|Embarked|
+-----------+--------+------+--------------------+----+----+-----+-----+---------+----+-----+--------+
|          1|       0|     3|Braund, Mr. Owen ...|male|22.0|    1|    0|A/5 21171|7.25| null|       S|
+-----------+--------+------+--------------------+----+----+-----+-----+---------+----+-----+--------+
only showing top 1 row



In [5]:
# Keep the label plus the four columns used downstream, then discard every
# row that contains a null in any of them.
df = (
    titanic
    .select(['Survived', 'Name', 'Pclass', 'Sex', 'Age'])
    .na.drop()
)
df.show(n=1)

+--------+--------------------+------+----+----+
|Survived|                Name|Pclass| Sex| Age|
+--------+--------------------+------+----+----+
|       0|Braund, Mr. Owen ...|     3|male|22.0|
+--------+--------------------+------+----+----+
only showing top 1 row



In [6]:
# Per-column preprocessing recipe handed to pipeasy_spark: each key is a
# column of `df`, each value the ordered list of stages for that column.
column_stages = {
    # Left unmodified; `[]` is equally valid, as is omitting the key altogether.
    'Survived': None,
    'Name': [ppz.transformers.ColumnDropper()],
    'Pclass': [OneHotEncoder()],
    'Sex': [StringIndexer(), OneHotEncoderEstimator(dropLast=False)],
    'Age': [VectorAssembler(), StandardScaler()],
}
transformer = ppz.map_by_column(column_stages, target_name='Survived')

In [7]:
# Fit the per-column stages on `df`, then run the fitted transformer over the
# very same frame to obtain the model-ready columns.
trained_transformer = transformer.fit(df)
df_transformed = trained_transformer.transform(df)

# Preview the first 20 transformed rows.
df_transformed.show(n=20)

+--------+-------------+-------------+--------------------+--------------------+
|Survived|       Pclass|          Sex|                 Age|            features|
+--------+-------------+-------------+--------------------+--------------------+
|       0|    (3,[],[])|(2,[0],[1.0])|[1.5054181442954726]|(6,[3,5],[1.0,1.5...|
|       1|(3,[1],[1.0])|(2,[1],[1.0])| [2.600267703783089]|[0.0,1.0,0.0,0.0,...|
|       1|    (3,[],[])|(2,[1],[1.0])|[1.7791305341673767]|(6,[4,5],[1.0,1.7...|
|       1|(3,[1],[1.0])|(2,[1],[1.0])| [2.394983411379161]|[0.0,1.0,0.0,0.0,...|
|       0|    (3,[],[])|(2,[0],[1.0])| [2.394983411379161]|(6,[3,5],[1.0,2.3...|
|       0|(3,[1],[1.0])|(2,[0],[1.0])|[3.6951172632707054]|[0.0,1.0,0.0,1.0,...|
|       0|    (3,[],[])|(2,[0],[1.0])|[0.13685619493595...|(6,[3,5],[1.0,0.1...|
|       1|    (3,[],[])|(2,[1],[1.0])|[1.8475586316353527]|(6,[4,5],[1.0,1.8...|
|       1|(3,[2],[1.0])|(2,[1],[1.0])|[0.9579933645516644]|[0.0,0.0,1.0,0.0,...|
|       1|    (3,[],[])|(2,[

In [8]:
from pyspark.ml.classification import LogisticRegression

# Logistic regression that reads its inputs from the 'features' column and
# uses 'Survived' as the label.
logit = LogisticRegression(
    labelCol='Survived',
    featuresCol='features',
)

# Train on the full transformed frame.
predictor = logit.fit(df_transformed)

In [9]:
# Score the same frame the model was fitted on and preview 10 rows; these are
# in-sample predictions, not a held-out evaluation.
(
    predictor
    .transform(df_transformed)
    .show(n=10)
)

+--------+-------------+-------------+--------------------+--------------------+--------------------+--------------------+----------+
|Survived|       Pclass|          Sex|                 Age|            features|       rawPrediction|         probability|prediction|
+--------+-------------+-------------+--------------------+--------------------+--------------------+--------------------+----------+
|       0|    (3,[],[])|(2,[0],[1.0])|[1.5054181442954726]|(6,[3,5],[1.0,1.5...|[2.72954543088323...|[0.93874770446287...|       0.0|
|       1|(3,[1],[1.0])|(2,[1],[1.0])| [2.600267703783089]|[0.0,1.0,0.0,0.0,...|[-1.4995958294148...|[0.18248581215260...|       1.0|
|       1|    (3,[],[])|(2,[1],[1.0])|[1.7791305341673767]|(6,[4,5],[1.0,1.7...|[-0.2130400374162...|[0.44694051906310...|       1.0|
|       1|(3,[1],[1.0])|(2,[1],[1.0])| [2.394983411379161]|[0.0,1.0,0.0,0.0,...|[-1.5535336402292...|[0.17457648656092...|       1.0|
|       0|    (3,[],[])|(2,[0],[1.0])| [2.394983411379161]|(6,