In [None]:
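# Ran on Databricks; the `spark` session used below is not created in this notebook and is expected to be provided by the runtime.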

In [0]:
import seaborn as sns

# Fetch seaborn's "tips" example dataset and render it as the cell output.
df = sns.load_dataset(name='tips')
df

Unnamed: 0,total_bill,tip,sex,smoker,day,time,size
0,16.99,1.01,Female,No,Sun,Dinner,2
1,10.34,1.66,Male,No,Sun,Dinner,3
2,21.01,3.50,Male,No,Sun,Dinner,3
3,23.68,3.31,Male,No,Sun,Dinner,2
4,24.59,3.61,Female,No,Sun,Dinner,4
...,...,...,...,...,...,...,...
239,29.03,5.92,Male,No,Sat,Dinner,3
240,27.18,2.00,Female,Yes,Sat,Dinner,2
241,22.67,2.00,Male,Yes,Sat,Dinner,2
242,17.82,1.75,Male,No,Sat,Dinner,2


In [0]:
# Convert the pandas DataFrame loaded above into a Spark DataFrame.
import pandas as pd

# `df` is rebound in place, so an unguarded re-run of this cell would hand a Spark
# DataFrame to `spark.createDataFrame`, which rejects it. Converting only while `df`
# is still a pandas frame makes the cell safe to run more than once.
if isinstance(df, pd.DataFrame):
    df = spark.createDataFrame(df)
df.show()

In [0]:
# Print the column names and types of the Spark DataFrame.
df.printSchema()

In [0]:
# convert categorical data into numerical data for ml model
from pyspark.ml.feature import StringIndexer

In [0]:
# Peek at the first four rows of the string-valued columns that will be indexed.
categorical_preview = df.select('sex', 'smoker', 'day', 'time')
categorical_preview.show(n=4)

In [0]:
# Configure one indexer for all categorical string columns; each output column is
# the input column name with an `_n` suffix.
categorical_cols = ['sex', 'smoker', 'day', 'time']
indexer = StringIndexer(
    inputCols=categorical_cols,
    outputCols=[f'{col_name}_n' for col_name in categorical_cols],
)

In [0]:
# Fit the indexer on `df`, then apply the fitted model to the same frame to add
# the `*_n` index columns.
indexer_model = indexer.fit(df)
indexed_df = indexer_model.transform(df)
indexed_df.show()

In [0]:
from pyspark.ml.feature import VectorAssembler

In [0]:
# List the available columns before choosing the model inputs.
column_names = indexed_df.columns
print(column_names)

In [0]:
# Combine the two numeric columns and the four indexed categorical columns into a
# single vector column named `features`.
feature_cols = ['total_bill', 'tip', 'sex_n', 'smoker_n', 'day_n', 'time_n']
vectorAssembler = VectorAssembler(inputCols=feature_cols, outputCol='features')

In [0]:
# Apply the assembler to the indexed frame; the result carries the extra `features` column.
output = vectorAssembler.transform(dataset=indexed_df)
output.show()

In [0]:
# Keep only the model inputs (`features`) and the label column (`size`).
training_df = output.select('features', 'size')
training_df.show()

In [0]:
from pyspark.ml.classification import LogisticRegression
# Split the rows roughly 70/30 into training and test sets. The fixed seed pins the
# random split so that re-running the notebook on the same data yields the same
# partitions, and therefore comparable metrics, instead of a different split each run.
train_data, test_data = training_df.randomSplit([.7, .3], seed=42)

# Unfitted estimator: logistic regression on the `features` vector with `size` as label.
lr_estimator = LogisticRegression(featuresCol='features', labelCol='size')

# `LR` stays bound to the fitted model because later cells call `LR.evaluate(...)`;
# the estimator keeps its own name so one variable no longer means two different things.
LR = lr_estimator.fit(train_data)

In [0]:
# Evaluate the fitted model on the held-out rows. `pred` is the evaluation summary;
# its `predictions` frame is pulled into pandas for display.
pred = LR.evaluate(dataset=test_data)
scored_rows = pred.predictions
scored_rows.toPandas()

Unnamed: 0,features,size,rawPrediction,probability,prediction
0,"(12.69, 2.0, 0.0, 0.0, 0.0, 0.0)",2,"[-4.351676668767605, -1.0865988631994554, 6.10...","[2.2208959690678968e-05, 0.0005814782701798628...",2.0
1,"(13.37, 2.0, 0.0, 0.0, 0.0, 0.0)",2,"[-4.363726741096271, -1.730472350190885, 6.109...","[2.1372968730486736e-05, 0.0002974902261176327...",2.0
2,"(21.7, 4.3, 0.0, 0.0, 0.0, 0.0)",2,"[-4.590943989364682, -10.081918237222327, 5.62...","[1.0776843915262641e-05, 4.444181393244114e-08...",3.0
3,"(31.27, 5.0, 0.0, 0.0, 0.0, 0.0)",3,"[-4.784758299117603, -19.284706767751935, 5.54...","[2.3627124513471347e-06, 1.1916899099832572e-1...",4.0
4,"[8.77, 2.0, 0.0, 0.0, 1.0, 0.0]",2,"[-4.190030506618428, 1.5125118184115447, 6.088...","[3.128298285334602e-05, 0.009373263396022055, ...",2.0
5,"[9.68, 1.32, 0.0, 0.0, 1.0, 0.0]",2,"[-4.182621283781519, 0.7880389221046489, 6.259...","[2.6974001076155044e-05, 0.0038875471493995417...",2.0
6,"[9.94, 1.56, 0.0, 0.0, 1.0, 0.0]",2,"[-4.195535154350879, 0.4934350609849947, 6.204...","[2.791244079064217e-05, 0.003035230692311033, ...",2.0
7,"[13.94, 3.06, 0.0, 0.0, 1.0, 0.0]",2,"[-4.318333495086702, -3.59666191160791, 5.8783...","[2.9672629292272415e-05, 6.106242073631815e-05...",2.0
8,"[16.29, 3.71, 0.0, 0.0, 1.0, 0.0]",3,"[-4.382473871076162, -5.952942144307061, 5.742...","[2.7597380924371595e-05, 5.7388141846259995e-0...",2.0
9,"[18.04, 3.0, 0.0, 0.0, 1.0, 0.0]",2,"[-4.388911720457437, -7.466736054150511, 5.927...","[2.296898127818084e-05, 1.0579365203040922e-06...",2.0


In [0]:
# Show the overall accuracy together with the recall-by-label values from the summary.
(pred.accuracy, pred.recallByLabel)