# spark sharing


In [86]:
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.classification import LogisticRegressionModel
from pyspark.ml.classification import LogisticRegressionSummary
from pyspark.ml.tuning import ParamGridBuilder
from pyspark.ml.tuning import CrossValidator

from pyspark.sql import SparkSession
from pyspark.sql import Row
from pyspark.ml.feature import StringIndexer
from pyspark.ml.classification import LogisticRegression
from pyspark.ml import Pipeline
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.sql.functions import lit
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.feature import VectorSlicer
from pyspark.sql.functions import regexp_replace
# from pyspark.sql.functions import when
from pyspark.sql import functions as F
from pyspark.sql.functions import *
from pyspark.ml.linalg import Vector

from pyspark.ml.classification import GBTClassifier
from pyspark.ml.feature import StringIndexer, VectorIndexer

import numpy as np
import pandas as pd

In [None]:
# Estimator:
# An Estimator is an algorithm which can be fit on a DataFrame to produce a Transformer.
# E.g., a learning algorithm is an Estimator which trains on a DataFrame and produces a model.
#
# Transformer:
# A Transformer is an algorithm which can transform one DataFrame into another DataFrame.
# E.g., an ML model is a Transformer which transforms a DataFrame with features into a DataFrame with predictions.

## 1、basis

### 1.1 SparkSession是DataFrame和SQL功能的主入口

In [4]:
# Alternative sessions (uncomment one to use it instead of the Hive-enabled local session below):
# spark = SparkSession.builder.master("spark://10.0.120.106:7077").appName("spark_app").getOrCreate()
# spark = SparkSession.builder.master("local").appName("spark_app").getOrCreate()

# SparkSession is the main entry point for programming Spark with DataFrames.
# It can create DataFrames, register them as tables, run SQL over tables, cache tables
# and read parquet files. enableHiveSupport() lets the session use SQL or HQL.
# The former .config("", "") call set an empty key to an empty value and was dropped.
spark = (
    SparkSession.builder
    .master("local")
    .appName("spark_app")
    .enableHiveSupport()
    .getOrCreate()
)


In [5]:
spark

### 1.2 读取数据文件

In [6]:
# The file has a .txt extension but is comma-delimited, so the CSV reader can parse it
# once the separator is given. No header or schema is supplied, so the columns are
# named _c0.._c4 and every value is read as a string.
# The path is a raw string: the Windows backslashes stay literal instead of depending on
# Python passing unrecognised escapes (\s, \d, \m, \i) through unchanged, which is deprecated.
read_data = spark.read.csv(r'C:\spark\data\my_data\iris_data.txt', sep=',')
read_data.head(5)

[Row(_c0='5.1', _c1='3.5', _c2='1.4', _c3='0.2', _c4='Iris-setosa'),
 Row(_c0='4.9', _c1='3.0', _c2='1.4', _c3='0.2', _c4='Iris-setosa'),
 Row(_c0='4.7', _c1='3.2', _c2='1.3', _c3='0.2', _c4='Iris-setosa'),
 Row(_c0='4.6', _c1='3.1', _c2='1.5', _c3='0.2', _c4='Iris-setosa'),
 Row(_c0='5.0', _c1='3.6', _c2='1.4', _c3='0.2', _c4='Iris-setosa')]

In [7]:
## Read from / write to Hive (placeholder query; nothing is executed in this cell)
# df = spark.sql('select * from xxx')

### 1.3 拆分数据集

In [8]:
# 70/30 train/test split of the raw frame; the seed is fixed at 1.
train_data, test_data = read_data.randomSplit(weights=[0.7, 0.3], seed=1)

In [9]:
train_data

DataFrame[_c0: string, _c1: string, _c2: string, _c3: string, _c4: string]

In [10]:
test_data

DataFrame[_c0: string, _c1: string, _c2: string, _c3: string, _c4: string]

### 1.4 操作dataframe

#### A distributed collection of data grouped into named columns. A DataFrame is equivalent to a relational table in Spark SQL

In [11]:
# print up to the first 50 rows of train_data as a text table
# (show() displays row values; evaluate the bare DataFrame to see the schema instead)
train_data.show(50)

+---+---+---+---+---------------+
|_c0|_c1|_c2|_c3|            _c4|
+---+---+---+---+---------------+
|4.3|3.0|1.1|0.1|    Iris-setosa|
|4.4|2.9|1.4|0.2|    Iris-setosa|
|4.4|3.0|1.3|0.2|    Iris-setosa|
|4.4|3.2|1.3|0.2|    Iris-setosa|
|4.6|3.1|1.5|0.2|    Iris-setosa|
|4.6|3.2|1.4|0.2|    Iris-setosa|
|4.6|3.4|1.4|0.3|    Iris-setosa|
|4.6|3.6|1.0|0.2|    Iris-setosa|
|4.7|3.2|1.3|0.2|    Iris-setosa|
|4.8|3.0|1.4|0.1|    Iris-setosa|
|4.8|3.0|1.4|0.3|    Iris-setosa|
|4.8|3.1|1.6|0.2|    Iris-setosa|
|4.8|3.4|1.6|0.2|    Iris-setosa|
|4.8|3.4|1.9|0.2|    Iris-setosa|
|4.9|2.5|4.5|1.7| Iris-virginica|
|4.9|3.1|1.5|0.1|    Iris-setosa|
|4.9|3.1|1.5|0.1|    Iris-setosa|
|5.0|2.0|3.5|1.0|Iris-versicolor|
|5.0|2.3|3.3|1.0|Iris-versicolor|
|5.0|3.2|1.2|0.2|    Iris-setosa|
|5.0|3.3|1.4|0.2|    Iris-setosa|
|5.0|3.4|1.5|0.2|    Iris-setosa|
|5.0|3.5|1.3|0.3|    Iris-setosa|
|5.0|3.5|1.6|0.6|    Iris-setosa|
|5.0|3.6|1.4|0.2|    Iris-setosa|
|5.1|3.3|1.7|0.5|    Iris-setosa|
|5.1|3.4|1.5|0

In [12]:
train_data

DataFrame[_c0: string, _c1: string, _c2: string, _c3: string, _c4: string]

In [13]:
# get all columns' names
train_data.columns

['_c0', '_c1', '_c2', '_c3', '_c4']

In [14]:
# get the num of rows
train_data.count()

106

In [15]:
# filter: keep the rows whose _c3 equals the string '0.1'
# (the column is still string-typed at this point, hence the quoted literal)
train_data.filter(train_data['_c3'] == '0.1').show()

+---+---+---+---+-----------+
|_c0|_c1|_c2|_c3|        _c4|
+---+---+---+---+-----------+
|4.3|3.0|1.1|0.1|Iris-setosa|
|4.8|3.0|1.4|0.1|Iris-setosa|
|4.9|3.1|1.5|0.1|Iris-setosa|
|4.9|3.1|1.5|0.1|Iris-setosa|
|5.2|4.1|1.5|0.1|Iris-setosa|
+---+---+---+---+-----------+



In [16]:
# contains
train_data.filter(train_data._c4.contains('virginica')).show(5)

+---+---+---+---+--------------+
|_c0|_c1|_c2|_c3|           _c4|
+---+---+---+---+--------------+
|4.9|2.5|4.5|1.7|Iris-virginica|
|5.6|2.8|4.9|2.0|Iris-virginica|
|5.7|2.5|5.0|2.0|Iris-virginica|
|5.8|2.7|5.1|1.9|Iris-virginica|
|5.8|2.7|5.1|1.9|Iris-virginica|
+---+---+---+---+--------------+
only showing top 5 rows



In [17]:
# The available aggregate functions are avg, max, min, sum, count.
# _c3 is still a string column here, so max() on it would compare text
# (lexicographically '9.0' > '10.0'). Cast to double first so the maximum is numeric;
# the alias keeps the result header as max(_c3).
train_data.select(F.col('_c3').cast('double').alias('_c3')).agg({'_c3': 'max'}).show()

+--------+
|max(_c3)|
+--------+
|     2.5|
+--------+



In [18]:
train_data.agg({'_c3': 'avg'}).show()

+-----------------+
|         avg(_c3)|
+-----------------+
|1.161320754716981|
+-----------------+



In [19]:
train_data.agg({'_c3': 'sum'}).show()

+------------------+
|          sum(_c3)|
+------------------+
|123.09999999999998|
+------------------+



In [20]:
# Creates or replaces a local temporary view
train_data.createOrReplaceTempView('train_table')

In [21]:
# Query the temporary view 'train_table' with SQL and keep only the Iris-setosa rows.
# spark.sql returns a DataFrame; show() prints its first 20 rows.
train_1 = spark.sql("select * from train_table where _c4=='Iris-setosa'")
train_1.show()

+---+---+---+---+-----------+
|_c0|_c1|_c2|_c3|        _c4|
+---+---+---+---+-----------+
|4.3|3.0|1.1|0.1|Iris-setosa|
|4.4|2.9|1.4|0.2|Iris-setosa|
|4.4|3.0|1.3|0.2|Iris-setosa|
|4.4|3.2|1.3|0.2|Iris-setosa|
|4.6|3.1|1.5|0.2|Iris-setosa|
|4.6|3.2|1.4|0.2|Iris-setosa|
|4.6|3.4|1.4|0.3|Iris-setosa|
|4.6|3.6|1.0|0.2|Iris-setosa|
|4.7|3.2|1.3|0.2|Iris-setosa|
|4.8|3.0|1.4|0.1|Iris-setosa|
|4.8|3.0|1.4|0.3|Iris-setosa|
|4.8|3.1|1.6|0.2|Iris-setosa|
|4.8|3.4|1.6|0.2|Iris-setosa|
|4.8|3.4|1.9|0.2|Iris-setosa|
|4.9|3.1|1.5|0.1|Iris-setosa|
|4.9|3.1|1.5|0.1|Iris-setosa|
|5.0|3.2|1.2|0.2|Iris-setosa|
|5.0|3.3|1.4|0.2|Iris-setosa|
|5.0|3.4|1.5|0.2|Iris-setosa|
|5.0|3.5|1.3|0.3|Iris-setosa|
+---+---+---+---+-----------+
only showing top 20 rows



In [22]:
# describe the dataframe by providing basic statistics for numeric and string columns
train_data.describe().show()

+-------+------------------+-------------------+------------------+------------------+--------------+
|summary|               _c0|                _c1|               _c2|               _c3|           _c4|
+-------+------------------+-------------------+------------------+------------------+--------------+
|  count|               106|                106|               106|               106|           106|
|   mean| 5.770754716981132|  3.071698113207549|3.6443396226415086| 1.161320754716981|          null|
| stddev|0.7997579373228328|0.44588162362012673|1.7494532581939912|0.7643946568198022|          null|
|    min|               4.3|                2.0|               1.0|               0.1|   Iris-setosa|
|    max|               7.9|                4.4|               6.9|               2.5|Iris-virginica|
+-------+------------------+-------------------+------------------+------------------+--------------+



In [23]:
# summary() reports the same statistics as describe() plus quartile rows (25%, 50%, 75%)
train_data.summary().show()

+-------+------------------+-------------------+------------------+------------------+--------------+
|summary|               _c0|                _c1|               _c2|               _c3|           _c4|
+-------+------------------+-------------------+------------------+------------------+--------------+
|  count|               106|                106|               106|               106|           106|
|   mean| 5.770754716981132|  3.071698113207549|3.6443396226415086| 1.161320754716981|          null|
| stddev|0.7997579373228328|0.44588162362012673|1.7494532581939912|0.7643946568198022|          null|
|    min|               4.3|                2.0|               1.0|               0.1|   Iris-setosa|
|    25%|               5.1|                2.8|               1.5|               0.3|          null|
|    50%|               5.7|                3.0|               4.1|               1.3|          null|
|    75%|               6.3|                3.4|               5.1|               

In [24]:
# summary for specific column and statistics
train_data.select('_c0').summary('mean').show()

+-------+-----------------+
|summary|              _c0|
+-------+-----------------+
|   mean|5.770754716981132|
+-------+-----------------+



In [25]:
train_data.select('_c0')

DataFrame[_c0: string]

In [26]:
train_data.select('_c0').show(10)

+---+
|_c0|
+---+
|4.3|
|4.4|
|4.4|
|4.4|
|4.6|
|4.6|
|4.6|
|4.6|
|4.7|
|4.8|
+---+
only showing top 10 rows



In [27]:
train_data.select('_c0').collect()

[Row(_c0='4.3'),
 Row(_c0='4.4'),
 Row(_c0='4.4'),
 Row(_c0='4.4'),
 Row(_c0='4.6'),
 Row(_c0='4.6'),
 Row(_c0='4.6'),
 Row(_c0='4.6'),
 Row(_c0='4.7'),
 Row(_c0='4.8'),
 Row(_c0='4.8'),
 Row(_c0='4.8'),
 Row(_c0='4.8'),
 Row(_c0='4.8'),
 Row(_c0='4.9'),
 Row(_c0='4.9'),
 Row(_c0='4.9'),
 Row(_c0='5.0'),
 Row(_c0='5.0'),
 Row(_c0='5.0'),
 Row(_c0='5.0'),
 Row(_c0='5.0'),
 Row(_c0='5.0'),
 Row(_c0='5.0'),
 Row(_c0='5.0'),
 Row(_c0='5.1'),
 Row(_c0='5.1'),
 Row(_c0='5.1'),
 Row(_c0='5.1'),
 Row(_c0='5.2'),
 Row(_c0='5.2'),
 Row(_c0='5.3'),
 Row(_c0='5.4'),
 Row(_c0='5.4'),
 Row(_c0='5.4'),
 Row(_c0='5.4'),
 Row(_c0='5.4'),
 Row(_c0='5.5'),
 Row(_c0='5.5'),
 Row(_c0='5.5'),
 Row(_c0='5.5'),
 Row(_c0='5.5'),
 Row(_c0='5.6'),
 Row(_c0='5.6'),
 Row(_c0='5.6'),
 Row(_c0='5.6'),
 Row(_c0='5.6'),
 Row(_c0='5.7'),
 Row(_c0='5.7'),
 Row(_c0='5.7'),
 Row(_c0='5.7'),
 Row(_c0='5.7'),
 Row(_c0='5.7'),
 Row(_c0='5.7'),
 Row(_c0='5.7'),
 Row(_c0='5.8'),
 Row(_c0='5.8'),
 Row(_c0='5.8'),
 Row(_c0='5.8'

In [28]:
train_data.select('_c0').head(5)

[Row(_c0='4.3'),
 Row(_c0='4.4'),
 Row(_c0='4.4'),
 Row(_c0='4.4'),
 Row(_c0='4.6')]

In [29]:
type(train_data.select('_c0').head(5))

list

In [30]:
train_data.select('_c0').take(5)

[Row(_c0='4.3'),
 Row(_c0='4.4'),
 Row(_c0='4.4'),
 Row(_c0='4.4'),
 Row(_c0='4.6')]

In [31]:
train_data.select('_c0').take(5)[0]

Row(_c0='4.3')

In [32]:
type(train_data.select('_c0').take(5)[0])

pyspark.sql.types.Row

In [33]:
train_data.select('_c0').take(5)[0][0]

'4.3'

In [34]:
type(train_data.select('_c0').take(5)[0][0])

str

### 1.5 操作column

In [35]:
# pick a column in a dataframe
train_data._c0

Column<b'_c0'>

In [36]:
train_data['_c0']

Column<b'_c0'>

In [37]:
# Demonstration: a Column cannot be materialised on its own. Calling collect() on it
# raises TypeError ('Column' object is not callable); select the column from its
# DataFrame and collect that instead.
# The error is caught and printed so that "Run All" is not aborted by this cell.
try:
    train_data['_c0'].collect()
except TypeError as err:
    print('TypeError: {}'.format(err))

TypeError: 'Column' object is not callable

In [38]:
# Take the two columns used for the column-arithmetic examples below.
# (A bare `operate_data` expression in the middle of the cell was removed: only the
# last expression of a cell is displayed, so it had no effect.)
operate_data = train_data.select(['_c0', '_c1'])
operate_data.show(5)

+---+---+
|_c0|_c1|
+---+---+
|4.3|3.0|
|4.4|2.9|
|4.4|3.0|
|4.4|3.2|
|4.6|3.1|
+---+---+
only showing top 5 rows



In [39]:
# sum
# Attribute access (operate_data._c0) and item access (operate_data['_c1']) both yield a Column.
# Adding them builds a new Column expression; no rows are processed in this cell.
# NOTE(review): both inputs are string-typed here; the numeric results shown when the
# expression is later materialised suggest Spark casts them implicitly — confirm.
add_col = operate_data._c0 + operate_data['_c1']
add_col

Column<b'(_c0 + _c1)'>

In [40]:
add_operate_data = operate_data.withColumn('sum', add_col)

In [41]:
add_operate_data.show(5)

+---+---+------------------+
|_c0|_c1|               sum|
+---+---+------------------+
|4.3|3.0|               7.3|
|4.4|2.9| 7.300000000000001|
|4.4|3.0|               7.4|
|4.4|3.2|7.6000000000000005|
|4.6|3.1| 7.699999999999999|
+---+---+------------------+
only showing top 5 rows



## 2、LR

In [42]:
train_data.show(10)

+---+---+---+---+-----------+
|_c0|_c1|_c2|_c3|        _c4|
+---+---+---+---+-----------+
|4.3|3.0|1.1|0.1|Iris-setosa|
|4.4|2.9|1.4|0.2|Iris-setosa|
|4.4|3.0|1.3|0.2|Iris-setosa|
|4.4|3.2|1.3|0.2|Iris-setosa|
|4.6|3.1|1.5|0.2|Iris-setosa|
|4.6|3.2|1.4|0.2|Iris-setosa|
|4.6|3.4|1.4|0.3|Iris-setosa|
|4.6|3.6|1.0|0.2|Iris-setosa|
|4.7|3.2|1.3|0.2|Iris-setosa|
|4.8|3.0|1.4|0.1|Iris-setosa|
+---+---+---+---+-----------+
only showing top 10 rows



In [43]:
# StringIndexer is an Estimator: fit() learns the distinct values of the species column
# (_c4) and returns a model (a Transformer) whose transform() appends the numeric
# 'indexedLabel' column.
# NOTE(review): index order presumably follows label frequency (the library default) — confirm.
stringIndexer = StringIndexer(inputCol='_c4', outputCol='indexedLabel')
stringIndexer_model = stringIndexer.fit(train_data)
train_data_2 = stringIndexer_model.transform(train_data)
train_data_2.show(20)

+---+---+---+---+---------------+------------+
|_c0|_c1|_c2|_c3|            _c4|indexedLabel|
+---+---+---+---+---------------+------------+
|4.3|3.0|1.1|0.1|    Iris-setosa|         0.0|
|4.4|2.9|1.4|0.2|    Iris-setosa|         0.0|
|4.4|3.0|1.3|0.2|    Iris-setosa|         0.0|
|4.4|3.2|1.3|0.2|    Iris-setosa|         0.0|
|4.6|3.1|1.5|0.2|    Iris-setosa|         0.0|
|4.6|3.2|1.4|0.2|    Iris-setosa|         0.0|
|4.6|3.4|1.4|0.3|    Iris-setosa|         0.0|
|4.6|3.6|1.0|0.2|    Iris-setosa|         0.0|
|4.7|3.2|1.3|0.2|    Iris-setosa|         0.0|
|4.8|3.0|1.4|0.1|    Iris-setosa|         0.0|
|4.8|3.0|1.4|0.3|    Iris-setosa|         0.0|
|4.8|3.1|1.6|0.2|    Iris-setosa|         0.0|
|4.8|3.4|1.6|0.2|    Iris-setosa|         0.0|
|4.8|3.4|1.9|0.2|    Iris-setosa|         0.0|
|4.9|2.5|4.5|1.7| Iris-virginica|         2.0|
|4.9|3.1|1.5|0.1|    Iris-setosa|         0.0|
|4.9|3.1|1.5|0.1|    Iris-setosa|         0.0|
|5.0|2.0|3.5|1.0|Iris-versicolor|         1.0|
|5.0|2.3|3.3|

In [44]:
# Check the label -> index mapping produced by the indexer.
# De-duplicate and sort inside Spark so only the distinct pairs are brought to the
# driver, instead of collecting every row and de-duplicating in Python.
label_index_pairs = (
    train_data_2.select('_c4', 'indexedLabel')
    .distinct()
    .orderBy('_c4')
    .collect()
)
[(pair['_c4'], pair['indexedLabel']) for pair in label_index_pairs]

[('Iris-setosa', 0.0), ('Iris-versicolor', 1.0), ('Iris-virginica', 2.0)]

In [45]:
# cast the four measurement columns from string to double, one column per iteration
for feature_name in ['_c0', '_c1', '_c2', '_c3']:
    train_data = train_data.withColumn(feature_name, train_data[feature_name].cast('double'))

In [46]:
##### pipeline #####

# Stage 1 packs the four measurement columns into a single vector column named "features".
# Stage 2 indexes the species string (_c4) into a numeric column named "label".
input_col = ['_c0', '_c1', '_c2', '_c3']
vecAssembler = VectorAssembler(inputCols=input_col, outputCol="features")
# Chaining both stages in a Pipeline replaces calling each transformer by hand.
stringIndexer = StringIndexer(inputCol="_c4", outputCol="label")
pipeline = Pipeline(stages=[vecAssembler, stringIndexer])
pipelineFit = pipeline.fit(train_data)
new_train_data = pipelineFit.transform(train_data)

In [47]:
new_train_data.show(20)

+---+---+---+---+---------------+-----------------+-----+
|_c0|_c1|_c2|_c3|            _c4|         features|label|
+---+---+---+---+---------------+-----------------+-----+
|4.3|3.0|1.1|0.1|    Iris-setosa|[4.3,3.0,1.1,0.1]|  0.0|
|4.4|2.9|1.4|0.2|    Iris-setosa|[4.4,2.9,1.4,0.2]|  0.0|
|4.4|3.0|1.3|0.2|    Iris-setosa|[4.4,3.0,1.3,0.2]|  0.0|
|4.4|3.2|1.3|0.2|    Iris-setosa|[4.4,3.2,1.3,0.2]|  0.0|
|4.6|3.1|1.5|0.2|    Iris-setosa|[4.6,3.1,1.5,0.2]|  0.0|
|4.6|3.2|1.4|0.2|    Iris-setosa|[4.6,3.2,1.4,0.2]|  0.0|
|4.6|3.4|1.4|0.3|    Iris-setosa|[4.6,3.4,1.4,0.3]|  0.0|
|4.6|3.6|1.0|0.2|    Iris-setosa|[4.6,3.6,1.0,0.2]|  0.0|
|4.7|3.2|1.3|0.2|    Iris-setosa|[4.7,3.2,1.3,0.2]|  0.0|
|4.8|3.0|1.4|0.1|    Iris-setosa|[4.8,3.0,1.4,0.1]|  0.0|
|4.8|3.0|1.4|0.3|    Iris-setosa|[4.8,3.0,1.4,0.3]|  0.0|
|4.8|3.1|1.6|0.2|    Iris-setosa|[4.8,3.1,1.6,0.2]|  0.0|
|4.8|3.4|1.6|0.2|    Iris-setosa|[4.8,3.4,1.6,0.2]|  0.0|
|4.8|3.4|1.9|0.2|    Iris-setosa|[4.8,3.4,1.9,0.2]|  0.0|
|4.9|2.5|4.5|1

In [48]:
new_train_data.select('features')

DataFrame[features: vector]

In [49]:
# Convert the first row's feature vector to a NumPy array.
# first() fetches a single row instead of collecting the whole frame to the driver.
new_train_data.select('features').first()[0].toArray()

array([4.3, 3. , 1.1, 0.1])

In [50]:
# Apply the same string -> double casts to the test split, then push it through the
# already-fitted pipeline (pipelineFit) so features and labels are encoded the same way.
for feature_name in ['_c0', '_c1', '_c2', '_c3']:
    test_data = test_data.withColumn(feature_name, test_data[feature_name].cast('double'))

new_test_data = pipelineFit.transform(test_data)

In [51]:
new_test_data.show(5)

+---+---+---+---+---------------+-----------------+-----+
|_c0|_c1|_c2|_c3|            _c4|         features|label|
+---+---+---+---+---------------+-----------------+-----+
|4.5|2.3|1.3|0.3|    Iris-setosa|[4.5,2.3,1.3,0.3]|  0.0|
|4.7|3.2|1.6|0.2|    Iris-setosa|[4.7,3.2,1.6,0.2]|  0.0|
|4.9|2.4|3.3|1.0|Iris-versicolor|[4.9,2.4,3.3,1.0]|  1.0|
|4.9|3.0|1.4|0.2|    Iris-setosa|[4.9,3.0,1.4,0.2]|  0.0|
|4.9|3.1|1.5|0.1|    Iris-setosa|[4.9,3.1,1.5,0.1]|  0.0|
+---+---+---+---+---------------+-----------------+-----+
only showing top 5 rows



#### train model

In [52]:
"""
class pyspark.ml.classification.LogisticRegression(featuresCol='features', labelCol='label', predictionCol='prediction', 
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-06, fitIntercept=True, threshold=0.5, thresholds=None,probabilityCol='probability', 
rawPredictionCol='rawPrediction', standardization=True, weightCol=None, aggregationDepth=2, family='auto', 
lowerBoundsOnCoefficients=None, upperBoundsOnCoefficients=None, lowerBoundsOnIntercepts=None, upperBoundsOnIntercepts=None)
"""

"\nclass pyspark.ml.classification.LogisticRegression(featuresCol='features', labelCol='label', predictionCol='prediction', \nmaxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-06, fitIntercept=True, threshold=0.5, thresholds=None,probabilityCol='probability', \nrawPredictionCol='rawPrediction', standardization=True, weightCol=None, aggregationDepth=2, family='auto', \nlowerBoundsOnCoefficients=None, upperBoundsOnCoefficients=None, lowerBoundsOnIntercepts=None, upperBoundsOnIntercepts=None)\n"

In [299]:
# No featuresCol/labelCol are passed, so the estimator uses its defaults and reads the
# 'features' and 'label' columns of the input frame.
# maxIter caps the optimiser at 20 iterations; regParam and elasticNetParam configure
# the regularisation.
lr = LogisticRegression(maxIter=20, regParam=0.3, elasticNetParam=0.7)
lr_model = lr.fit(new_train_data)
lr_model

LogisticRegression_4dccbdbda106ca8df82a

In [300]:
model_summary = lr_model.summary

In [301]:
# print the header, then the objective value of every training iteration on its own line
objectiveHistory = model_summary.objectiveHistory
print("objectiveHistory:", *objectiveHistory, sep="\n")

objectiveHistory:
1.0963514880151817
1.0893542842612507
1.054540991144702
1.0124314678821982
1.011233434646947
1.0074285594594459
1.0063296684786371
1.0055798406732783
1.0049471652881403
1.00441485771545
1.0039481425760803
1.0035264995749944
1.0029650783759592
1.0027539879272085
1.0024669435204465
1.0023585354698963
1.0020777724700385
1.000921768222353
1.000294291368903
0.999603477864346
0.9993452139639785


In [302]:
# one line per class: the position in the list is the label index
print("True positive rate by label:")
for label_idx, tpr in enumerate(model_summary.truePositiveRateByLabel):
    print("label {:d}: {}".format(label_idx, tpr))

True positive rate by label:
label 0: 1.0
label 1: 1.0
label 2: 0.71875


#### predict

In [303]:
prediction = lr_model.transform(new_test_data)

In [304]:
prediction.show(10)

+---+---+---+---+---------------+-----------------+-----+--------------------+--------------------+----------+
|_c0|_c1|_c2|_c3|            _c4|         features|label|       rawPrediction|         probability|prediction|
+---+---+---+---+---------------+-----------------+-----+--------------------+--------------------+----------+
|4.5|2.3|1.3|0.3|    Iris-setosa|[4.5,2.3,1.3,0.3]|  0.0|[0.62950089310642...|[0.47938582953639...|       0.0|
|4.7|3.2|1.6|0.2|    Iris-setosa|[4.7,3.2,1.6,0.2]|  0.0|[0.95831436415938...|[0.56412656968181...|       0.0|
|4.9|2.4|3.3|1.0|Iris-versicolor|[4.9,2.4,3.3,1.0]|  1.0|[-0.0714091938476...|[0.29502316801143...|       1.0|
|4.9|3.0|1.4|0.2|    Iris-setosa|[4.9,3.0,1.4,0.2]|  0.0|[0.92624403516029...|[0.55622529533334...|       0.0|
|4.9|3.1|1.5|0.1|    Iris-setosa|[4.9,3.1,1.5,0.1]|  0.0|[0.97689783827449...|[0.57148286882367...|       0.0|
|5.0|3.0|1.6|0.2|    Iris-setosa|[5.0,3.0,1.6,0.2]|  0.0|[0.87628458451282...|[0.54386120971375...|       0.0|
|

In [305]:
# rawPrediction may vary between algorithms, but it intuitively gives a measure of confidence in each possible label (where larger = more confident).
# first() fetches one row instead of collecting every prediction to the driver.
prediction.select('rawPrediction').first()[0]

DenseVector([0.6295, 0.2347, -0.2567])

In [306]:
# Class-probability vector of the first test row.
# first() fetches one row instead of collecting every prediction to the driver.
prediction.select('probability').first()[0]

DenseVector([0.4794, 0.323, 0.1976])

In [307]:
# metricName = f1|weightedPrecision|weightedRecall|accuracy  
# f1-score
evaluator_f1 = MulticlassClassificationEvaluator(predictionCol="prediction", labelCol='label', metricName='f1')

In [308]:
print('f1-score: {}'.format(evaluator_f1.evaluate(prediction)))

f1-score: 0.838581742186758


In [309]:
# accuracy
evaluator_acc = MulticlassClassificationEvaluator(predictionCol="prediction", labelCol='label', metricName='accuracy')

In [310]:
print('accuracy: {}'.format(evaluator_acc.evaluate(prediction)))

accuracy: 0.8409090909090909


#### grid search + cross validation

In [311]:
# Search grid for a fresh estimator:
# 2 regParam values x 4 maxIter values x 1 elasticNetParam value = 8 candidate settings.
lr_new = LogisticRegression()
grid = (
    ParamGridBuilder()
    .addGrid(lr_new.regParam, [0.3, 0.5])
    .addGrid(lr_new.maxIter, [10, 15, 50, 100])
    .addGrid(lr_new.elasticNetParam, [0.7])
    .build()
)

In [312]:
# Evaluator used to rank grid candidates; the arguments spell out the library defaults
# (prediction/label columns, f1 metric) so the selection criterion is visible.
evaluator_new = MulticlassClassificationEvaluator(
    predictionCol='prediction', labelCol='label', metricName='f1'
)

In [313]:
# Cross-validated grid search over `grid`, scored by `evaluator_new`.
# numFolds=3 is the library default, written out so the fold count is visible.
# A fixed seed makes the fold assignment, and therefore the selected model, repeatable;
# without it every run may pick a different best parameter set.
cv = CrossValidator(
    estimator=lr_new,
    estimatorParamMaps=grid,
    evaluator=evaluator_new,
    numFolds=3,
    seed=1,
)

In [314]:
# Run the cross-validated search on the training frame and keep only the winning model.
# NOTE: despite its name, cvModel holds .bestModel of the fitted cross-validator,
# not the cross-validation result object itself.
cvModel = cv.fit(new_train_data).bestModel

In [316]:
cvModel.getOrDefault('regParam')

0.3

In [317]:
cvModel.getOrDefault('maxIter')

50

In [318]:
cvModel.getOrDefault('elasticNetParam')

0.7

In [319]:
new_prediction = cvModel.transform(new_test_data)

In [320]:
new_prediction.show()

+---+---+---+---+---------------+-----------------+-----+--------------------+--------------------+----------+
|_c0|_c1|_c2|_c3|            _c4|         features|label|       rawPrediction|         probability|prediction|
+---+---+---+---+---------------+-----------------+-----+--------------------+--------------------+----------+
|4.5|2.3|1.3|0.3|    Iris-setosa|[4.5,2.3,1.3,0.3]|  0.0|[0.59097048069783...|[0.55794501229109...|       0.0|
|4.7|3.2|1.6|0.2|    Iris-setosa|[4.7,3.2,1.6,0.2]|  0.0|[0.62201012642163...|[0.56890526854945...|       0.0|
|4.9|2.4|3.3|1.0|Iris-versicolor|[4.9,2.4,3.3,1.0]|  1.0|[-0.1629133809135...|[0.34872996270770...|       1.0|
|4.9|3.0|1.4|0.2|    Iris-setosa|[4.9,3.0,1.4,0.2]|  0.0|[0.65722319573996...|[0.57751955497068...|       0.0|
|4.9|3.1|1.5|0.1|    Iris-setosa|[4.9,3.1,1.5,0.1]|  0.0|[0.67514680309141...|[0.58509688104800...|       0.0|
|5.0|3.0|1.6|0.2|    Iris-setosa|[5.0,3.0,1.6,0.2]|  0.0|[0.60590042385807...|[0.56495002579399...|       0.0|
|

In [321]:
print('f1-score: {}'.format(evaluator_f1.evaluate(new_prediction)))

f1-score: 0.9322456540324879


In [322]:
print('accuracy: {}'.format(evaluator_acc.evaluate(new_prediction)))

accuracy: 0.9318181818181818
