## Introduction to PySpark

**Course Structure:**
* Chapter 1. Getting to know PySpark
* Chapter 2. Manipulating data
* Chapter 3. Getting started with machine learning pipelines
* Chapter 4. Model tuning and selection

### CHAPTER 1. Getting to know PySpark

#### Part 1.1 What is Spark

* It is a platform for cluster computing
* It allows you to spread data and computations across a cluster of multiple nodes (separate computers)
* Splitting the work this way makes it easier to handle very large datasets
* Each node works on its own subset of the total data and part of total calculations

**Key considerations:**
* Is my data too big to work with on a single machine?
* Can my calculations be easily parallelized?

**Using Spark in Python:**
* The first step is connecting to a cluster (usually hosted on a remote machine that is connected to all the other nodes)
* One computer, called the *master*, manages splitting up the data and the computations and sends them to the other computers in the cluster, called *workers*
* You create the connection by creating an instance of the *'SparkContext'* class
* The connection's attributes can be specified in an object created with the *'SparkConf()'* constructor


In [1]:
# launch Spark server with Spark Connect
# !$HOME/sbin/start-connect-server.sh --packages org.apache.spark:spark-connect_2.12:$SPARK_VERSION

# check the 'JAVA_HOME' and 'SPARK_HOME' environment variables
import os
import sys

java_home = os.environ.get("JAVA_HOME")
spark_home = os.environ.get("SPARK_HOME")

# make the Spark driver and workers use the same Python interpreter as this notebook
os.environ['PYSPARK_PYTHON'] = sys.executable
os.environ['PYSPARK_DRIVER_PYTHON'] = sys.executable
print(java_home)
print(spark_home)

E:\JAVA\jdk-21.0.0
E:\SPARK\spark-3.5.0-bin-hadoop3\spark-3.5.0-bin-hadoop3


In [2]:
# Code for part 1.1 

# examining the Spark context
from pyspark import SparkConf, SparkContext

# create the Spark context
conf = SparkConf().setMaster("local").setAppName("Spark Example App")
sc = SparkContext(conf=conf)

print(sc)
print(sc.version)




<SparkContext master=local appName=Spark Example App>
3.5.0


#### Part 1.2 Using DataFrames

* Spark's core data structure is the **Resilient Distributed Dataset (RDD)**
* The Spark DataFrame is built on top of RDDs and is designed to behave like a SQL table
* To work with DataFrames, you first create a *'SparkSession'* object from your *'SparkContext'*
* *'SparkContext'* is the connection to the cluster
* *'SparkSession'* is the interface to that connection
* NOTE: Since Spark 2.0, creating a **SparkSession** creates a **SparkContext** internally and exposes it as the session's *'sparkContext'* attribute

In [3]:
# Code for part 1.2

# creating a SparkSession
from pyspark.sql import SparkSession


# create my_spark
my_spark = SparkSession.builder.appName("Spark Example").getOrCreate()
print(my_spark)

# read flights.csv and register it as a temporary view named 'flights'
path = '21_datasets/flights.csv'
flights_df = my_spark.read.csv(path, header=True, inferSchema=True)
tableName = 'flights'
flights_df.createOrReplaceTempView(tableName)

# viewing tables
print(my_spark.catalog.listTables())


<pyspark.sql.session.SparkSession object at 0x000001B318403B10>
[Table(name='flights', catalog=None, namespace=[], description=None, tableType='TEMPORARY', isTemporary=True)]


In [4]:
# Code for part 1.2 (continue)

# are you query-ious?

# define the query
query = "SELECT * FROM flights LIMIT 10"

# get the first 10 rows of flights
flights10 = my_spark.sql(query)

# show the results
flights10.show()


+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
|year|month|day|dep_time|dep_delay|arr_time|arr_delay|carrier|tailnum|flight|origin|dest|air_time|distance|hour|minute|
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
|2014|   12|  8|     658|       -7|     935|       -5|     VX| N846VA|  1780|   SEA| LAX|     132|     954|   6|    58|
|2014|    1| 22|    1040|        5|    1505|        5|     AS| N559AS|   851|   SEA| HNL|     360|    2677|  10|    40|
|2014|    3|  9|    1443|       -2|    1652|        2|     VX| N847VA|   755|   SEA| SFO|     111|     679|  14|    43|
|2014|    4|  9|    1705|       45|    1839|       34|     WN| N360SW|   344|   PDX| SJC|      83|     569|  17|     5|
|2014|    3|  9|     754|       -1|    1015|        1|     AS| N612AS|   522|   SEA| BUR|     127|     937|   7|    54|
|2014|    1| 15|    1037|        7|    1

In [5]:
# Code for part 1.2 (continue)

# pandafy a spark DataFrame

# get the query
query = "SELECT origin, dest, COUNT(*) AS N FROM flights GROUP BY origin, dest"

# run the query
flight_counts = my_spark.sql(query)

# convert the results to a pandas DataFrame
pd_counts = flight_counts.toPandas()
print(pd_counts.head())

  origin dest    N
0    SEA  RNO    8
1    SEA  DTW   98
2    SEA  CLE    2
3    SEA  LAX  450
4    PDX  SEA  144


In [6]:
# Code for part 1.2 (continue)

# put some spark in your data
import pandas as pd
import numpy as np

# create pd_temp
pd_temp = pd.DataFrame(np.random.random(10))

# create spark_temp_df from pd_temp
spark_temp_df = my_spark.createDataFrame(pd_temp)

# examine the tables in the catalog
print(my_spark.catalog.listTables())

# add spark_temp_df to the catalog
tableName='temp'
spark_temp_df.createOrReplaceTempView(tableName)

# examine the tables in the catalog again
print(my_spark.catalog.listTables())
# show() prints the table itself and returns None, so it is not wrapped in print()
spark_temp_df.show()


[Table(name='flights', catalog=None, namespace=[], description=None, tableType='TEMPORARY', isTemporary=True)]


[Table(name='flights', catalog=None, namespace=[], description=None, tableType='TEMPORARY', isTemporary=True), Table(name='temp', catalog=None, namespace=[], description=None, tableType='TEMPORARY', isTemporary=True)]
+-------------------+
|                  0|
+-------------------+
|0.06337023132307473|
| 0.7772454844841593|
| 0.3611821467765055|
| 0.7310070835574528|
| 0.5382083427812854|
|0.31519329604281254|
| 0.4251697598541002|
| 0.2547094371450006|
|0.10345785116274009|
| 0.5289974773641084|
+-------------------+



In [7]:
# Code for part 1.2 (continue)

# dropping the middle man

# path to the airports data
file_path = '21_datasets/airports.csv'

# read in the airports data
airports = my_spark.read.csv(file_path, header=True)

# show the data
airports.show()
print(my_spark.catalog.listTables())

# note: this new DataFrame has not been registered as a table yet, so listTables() does not show it

+---+--------------------+----------------+-----------------+----+---+---+
|faa|                name|             lat|              lon| alt| tz|dst|
+---+--------------------+----------------+-----------------+----+---+---+
|04G|   Lansdowne Airport|      41.1304722|      -80.6195833|1044| -5|  A|
|06A|Moton Field Munic...|      32.4605722|      -85.6800278| 264| -5|  A|
|06C| Schaumburg Regional|      41.9893408|      -88.1012428| 801| -6|  A|
|06N|     Randall Airport|       41.431912|      -74.3915611| 523| -5|  A|
|09J|Jekyll Island Air...|      31.0744722|      -81.4277778|  11| -4|  A|
|0A9|Elizabethton Muni...|      36.3712222|      -82.1734167|1593| -4|  A|
|0G6|Williams County A...|      41.4673056|      -84.5067778| 730| -5|  A|
|0G7|Finger Lakes Regi...|      42.8835647|      -76.7812318| 492| -5|  A|
|0P2|Shoestring Aviati...|      39.7948244|      -76.6471914|1000| -5|  U|
|0S9|Jefferson County ...|      48.0538086|     -122.8106436| 108| -8|  A|
|0W3|Harford County Ai...

### CHAPTER 2. Manipulating data

#### Part 2.1 Creating columns

* You can create a column with the *'.withColumn()'* method, which takes the name of the new column and the column itself
* The new column must be an object of class *'Column'*
* Spark DataFrames are immutable, so updating one means returning a new DataFrame and reassigning it to the original variable
* Example: *'df = df.withColumn("newCol", df.oldCol + 1)'*

In [8]:
# Code for part 2.1

# creating columns

# create the DataFrame flights
flights = my_spark.table("flights")

# show the head and column types
flights.show()
flights.printSchema()

# add duration_hrs
flights = flights.withColumn("duration_hrs", flights.air_time / 60)

+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
|year|month|day|dep_time|dep_delay|arr_time|arr_delay|carrier|tailnum|flight|origin|dest|air_time|distance|hour|minute|
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
|2014|   12|  8|     658|       -7|     935|       -5|     VX| N846VA|  1780|   SEA| LAX|     132|     954|   6|    58|
|2014|    1| 22|    1040|        5|    1505|        5|     AS| N559AS|   851|   SEA| HNL|     360|    2677|  10|    40|
|2014|    3|  9|    1443|       -2|    1652|        2|     VX| N847VA|   755|   SEA| SFO|     111|     679|  14|    43|
|2014|    4|  9|    1705|       45|    1839|       34|     WN| N360SW|   344|   PDX| SJC|      83|     569|  17|     5|
|2014|    3|  9|     754|       -1|    1015|        1|     AS| N612AS|   522|   SEA| BUR|     127|     937|   7|    54|
|2014|    1| 15|    1037|        7|    1

#### Part 2.2 SQL in a nutshell

**Basic understandings of SQL:**
* The *'SELECT'* command is followed by the columns you want in the resulting table
* The *'FROM'* command is followed by the name of the table that contains these columns
* The *'WHERE'* command filters the rows of the table based on the specified logical conditions
* The *'GROUP BY'* command breaks your data into groups and applies the functions in your *'SELECT'* statement to each group


#### Part 2.3 Filtering data

* Filter data with the *'.filter()'* method, passing either a SQL condition as a string or a *'Column'* of boolean values
* It is the counterpart of SQL's *'WHERE'* clause


In [9]:
# Code for part 2.3

# filtering data

# filter flights by passing a string
long_flights1 = flights.filter("distance > 1000")

# filter flights by passing a column of boolean values
long_flights2 = flights.filter(flights.distance > 1000)

# print the data to check if they are equal
long_flights1.show()
long_flights2.show()

# note: both approaches return the same rows

+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+------------------+
|year|month|day|dep_time|dep_delay|arr_time|arr_delay|carrier|tailnum|flight|origin|dest|air_time|distance|hour|minute|      duration_hrs|
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+------------------+
|2014|    1| 22|    1040|        5|    1505|        5|     AS| N559AS|   851|   SEA| HNL|     360|    2677|  10|    40|               6.0|
|2014|    4| 19|    1236|       -4|    1508|       -7|     AS| N309AS|   490|   SEA| SAN|     135|    1050|  12|    36|              2.25|
|2014|   11| 19|    1812|       -3|    2352|       -4|     AS| N564AS|    26|   SEA| ORD|     198|    1721|  18|    12|               3.3|
|2014|    8|  3|    1120|        0|    1415|        2|     AS| N305AS|   656|   SEA| PHX|     154|    1107|  11|    20| 2.566666666666667|
|2014|   11| 12|    2346|  

#### Part 2.4 Selecting

* Select columns with the *'.select()'* method
* Specify the columns you select either as strings or as *'Column'* objects (e.g. *'df.colName'*)
* You can also perform column-wise operations inside *'.select()'*
* Use the *'.alias()'* method to give a computed column a name
* Use *'.selectExpr()'* to pass SQL expressions as strings


In [10]:
# Code for part 2.4

# selecting

# select the first set of columns
selected1 = flights.select("tailnum", "origin", "dest")

# select the second set of columns
temp = flights.select(flights.origin, flights.dest, flights.carrier)

# define first filter
filterA = flights.origin == "SEA"

# define second filter
filterB = flights.dest == "PDX"

# filter the data, first by filterA, then by filterB
selected2 = temp.filter(filterA).filter(filterB)


In [11]:
# Code for part 2.4

# selecting (ii)

# define avg_speed
avg_speed = (flights.distance / (flights.air_time / 60)).alias("avg_speed")

# select the correct column
speed1 = flights.select("origin", "dest", "tailnum", avg_speed)

# create the same table using a SQL expression
speed2 = flights.selectExpr("origin", "dest", "tailnum", "distance/ (air_time/60) as avg_speed")

#### Part 2.5 Aggregating

* Common aggregation methods:
    * *'.min()'*
    * *'.max()'*
    * *'.count()'*
    * *'.avg()'*
    * *'.sum()'*
* These are methods of the *'GroupedData'* object returned by the *'.groupBy()'* method
* *'.groupBy()'* can also be called with no arguments, in which case the whole DataFrame is treated as a single group
* Example: *'df.groupBy().min("col").show()'*
* You can also use the *'.agg()'* method to pass functions from the *'pyspark.sql.functions'* submodule

In [12]:
# Code for part 2.5

# cast "air_time" column to numeric
from pyspark.sql.types import IntegerType

flights = flights.withColumn("air_time", flights.air_time.cast(IntegerType()))

# find the shortest flight from PDX in terms of distance
flights.filter(flights.origin == "PDX").groupBy().min("distance").show()

# find the longest flight from SEA in terms of air time
flights.filter(flights.origin == "SEA").groupBy().max("air_time").show()

# average air time of Delta (DL) flights departing SEA
flights.filter(flights.carrier == "DL").filter(flights.origin == "SEA").groupBy().avg("air_time").show()

# total hours in the air
flights.withColumn("duration_hrs", flights.air_time / 60).groupBy().sum("duration_hrs").show()

+-------------+
|min(distance)|
+-------------+
|          106|
+-------------+
+-------------+
|max(air_time)|
+-------------+
|          409|
+-------------+
+------------------+
|     avg(air_time)|
+------------------+
|188.20689655172413|
+------------------+

+------------------+
| sum(duration_hrs)|
+------------------+
|25289.600000000126|
+------------------+


In [13]:
# Code for part 2.5 (continue)

# grouping and aggregating (i)

# group by tailnum
by_plane = flights.groupBy("tailnum")

# number of flights each plane made
by_plane.count().show()

# group by origin
by_origin = flights.groupBy("origin")

# average duration of flights from PDX and SEA
by_origin.avg("air_time").show()

+-------+-----+
|tailnum|count|
+-------+-----+
| N442AS|   38|
| N102UW|    2|
| N36472|    4|
| N38451|    4|
| N73283|    4|
| N513UA|    2|
| N954WN|    5|
| N388DA|    3|
| N567AA|    1|
| N516UA|    2|
| N927DN|    1|
| N8322X|    1|
| N466SW|    1|
|  N6700|    1|
| N607AS|   45|
| N622SW|    4|
| N584AS|   31|
| N914WN|    4|
| N654AW|    2|
| N336NW|    1|
+-------+-----+
+------+------------------+
|origin|     avg(air_time)|
+------+------------------+
|   SEA| 160.4361496051259|
|   PDX|137.11543248288737|
+------+------------------+


In [14]:
# Code for part 2.5 (continue)

# grouping and aggregating (ii)
import pyspark.sql.functions as F

# cast "dep_delay" column to numeric
from pyspark.sql.types import IntegerType

flights = flights.withColumn("dep_delay", flights.dep_delay.cast(IntegerType()))

# group by month and dest
by_month_dest = flights.groupBy("month", "dest")

# average departure delay by month and destination
by_month_dest.avg("dep_delay").show()

# standard deviation of departure delay
by_month_dest.agg(F.stddev("dep_delay")).show()

+-----+----+-------------------+
|month|dest|     avg(dep_delay)|
+-----+----+-------------------+
|    4| PHX| 1.6833333333333333|
|    1| RDM|             -1.625|
|    5| ONT| 3.5555555555555554|
|    7| OMA|               -6.5|
|    8| MDW|               7.45|
|    6| DEN|  5.418181818181818|
|    5| IAD|               -4.0|
|   12| COS|               -1.0|
|   11| ANC|  7.529411764705882|
|    5| AUS|              -0.75|
|    5| COS| 11.666666666666666|
|    2| PSP|                0.6|
|    4| ORD|0.14285714285714285|
|   10| DFW| 18.176470588235293|
|   10| DCA|               -1.5|
|    8| JNU|             18.125|
|   11| KOA|               -1.0|
|   10| OMA|-0.6666666666666666|
|    6| ONT|              9.625|
|    3| MSP|                3.2|
+-----+----+-------------------+
+-----+----+------------------+
|month|dest| stddev(dep_delay)|
+-----+----+------------------+
|    4| PHX|15.003380033491737|
|    1| RDM| 8.830749846821778|
|    5| ONT|18.895178691342874|
|    7| OMA|2.12

#### Part 2.6 Joining

* A join combines two tables along a column that they share, which is called the *key*
* When you join tables, you add all the columns from one table to the other
* Example: *'df.join(df2, on='colname', how='leftouter')'*

In [15]:
# Code for part 2.6

# joining 

# get two datasets in the workspace: flights and airports
flights = my_spark.read.csv('21_datasets/flights.csv', header=True)
airports = my_spark.read.csv('21_datasets/airports.csv', header=True)

# examine the data
flights.show()
airports.show()

# rename the "faa" column to "dest"
airports = airports.withColumnRenamed("faa","dest")

# join the dataframes
flights_with_airports = flights.join(airports, on='dest', how='leftouter')

# examine the new DataFrame
flights_with_airports.show()

# note: each flight row now also carries the columns of its destination airport

+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
|year|month|day|dep_time|dep_delay|arr_time|arr_delay|carrier|tailnum|flight|origin|dest|air_time|distance|hour|minute|
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
|2014|   12|  8|     658|       -7|     935|       -5|     VX| N846VA|  1780|   SEA| LAX|     132|     954|   6|    58|
|2014|    1| 22|    1040|        5|    1505|        5|     AS| N559AS|   851|   SEA| HNL|     360|    2677|  10|    40|
|2014|    3|  9|    1443|       -2|    1652|        2|     VX| N847VA|   755|   SEA| SFO|     111|     679|  14|    43|
|2014|    4|  9|    1705|       45|    1839|       34|     WN| N360SW|   344|   PDX| SJC|      83|     569|  17|     5|
|2014|    3|  9|     754|       -1|    1015|        1|     AS| N612AS|   522|   SEA| BUR|     127|     937|   7|    54|
|2014|    1| 15|    1037|        7|    1

### CHAPTER 3. Getting Started With Machine Learning Pipelines

#### Part 3.1 Machine learning pipelines

* We will use the *'pyspark.ml'* module
* At the core of this module are the *'Transformer'* and *'Estimator'* classes
* A *'Transformer'* has a *'.transform()'* method that takes a DataFrame and returns a new, transformed DataFrame
* An *'Estimator'* has a *'.fit()'* method that takes a DataFrame and returns a model object, which is itself a *'Transformer'*

In [16]:
# Code for part 3.1

# join the dataframes

# get two dataframes
flights = my_spark.read.csv('21_datasets/flights.csv', header=True)
planes = my_spark.read.csv('21_datasets/planes.csv', header=True)

# rename your column: 'year' to 'plane_year'
planes = planes.withColumnRenamed('year', 'plane_year')

# join the dataframes
model_data = flights.join(planes, on='tailnum', how='leftouter')


#### Part 3.2 Data types

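* Spark's machine learning algorithms only handle numeric data, so the columns you model with must be either integers or doubles (Spark's name for decimals)
* When a CSV is read without schema inference, every column is a string; convert the columns with the *'.cast()'* method in combination with the *'.withColumn()'* method
* Example: *'dataframe = dataframe.withColumn("col", dataframe.col.cast("new_type"))'*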

In [17]:
# Code for part 3.2

# string to integer


# check column data types
model_data.printSchema() # all columns have 'string' data type

# cast the columns to integers
model_data = model_data.withColumn("arr_delay", model_data.arr_delay.cast("integer"))
model_data = model_data.withColumn("air_time", model_data.air_time.cast("integer"))
model_data = model_data.withColumn("month", model_data.month.cast("integer"))
model_data = model_data.withColumn("plane_year", model_data.plane_year.cast("integer"))

# create a new column: 'plane_age'
model_data = model_data.withColumn("plane_age", model_data.year - model_data.plane_year)

# making a boolean column: 'is_late'
model_data = model_data.withColumn("is_late", model_data.arr_delay > 0)

# convert it to an integer
model_data = model_data.withColumn("label", model_data.is_late.cast("integer"))

# remove missing values
model_data = model_data.filter("arr_delay is not NULL and dep_delay is not NULL and air_time is not NULL and plane_year is not NULL")



root
 |-- tailnum: string (nullable = true)
 |-- year: string (nullable = true)
 |-- month: string (nullable = true)
 |-- day: string (nullable = true)
 |-- dep_time: string (nullable = true)
 |-- dep_delay: string (nullable = true)
 |-- arr_time: string (nullable = true)
 |-- arr_delay: string (nullable = true)
 |-- carrier: string (nullable = true)
 |-- flight: string (nullable = true)
 |-- origin: string (nullable = true)
 |-- dest: string (nullable = true)
 |-- air_time: string (nullable = true)
 |-- distance: string (nullable = true)
 |-- hour: string (nullable = true)
 |-- minute: string (nullable = true)
 |-- plane_year: string (nullable = true)
 |-- type: string (nullable = true)
 |-- manufacturer: string (nullable = true)
 |-- model: string (nullable = true)
 |-- engines: string (nullable = true)
 |-- seats: string (nullable = true)
 |-- speed: string (nullable = true)
 |-- engine: string (nullable = true)


#### Part 3.3 Strings and factors

* Strings cannot be easily converted to numeric data
* PySpark has a *'pyspark.ml.feature'* submodule to handle this
* It can create 'one-hot vectors' to represent categorical string information
* Encoding steps:
1. Create a *'StringIndexer'*
    * This *'Estimator'* takes a DataFrame with a column of strings and maps each unique string to a number
    * It then returns a *'Transformer'* that takes a DataFrame, attaches the mapping as metadata, and returns a new DataFrame with a numeric column corresponding to the string column
2. Encode this numeric column as a one-hot vector using *'OneHotEncoder'*
    * It works the same way as *'StringIndexer'*: you create an *'Estimator'*, which is fit to produce a *'Transformer'*
3. All you need to do is create a *'StringIndexer'* and a *'OneHotEncoder'*; the *'Pipeline'* will take care of the rest

In [27]:
# Code for part 3.3

# carrier and destination: strings to numeric
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler

# create a StringIndexer
carr_indexer = StringIndexer(inputCol="carrier", outputCol="carrier_index")
dest_indexer = StringIndexer(inputCol="dest", outputCol="dest_index")

# create a OneHotEncoder
carr_encoder = OneHotEncoder(inputCol="carrier_index", outputCol="carrier_fact")
dest_encoder = OneHotEncoder(inputCol="dest_index", outputCol="dest_fact")

# assemble the feature columns into a single vector column
vec_assembler = VectorAssembler(inputCols=["month", "air_time", "carrier_fact", "dest_fact", "plane_age"], outputCol="features")


In [28]:
# Code for part 3.3 (continue)

# create the pipeline
from pyspark.ml import Pipeline

# make the pipeline
flights_pipe = Pipeline(stages=[dest_indexer, dest_encoder, carr_indexer, carr_encoder, vec_assembler])

#### Part 3.4 Test vs. Train

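* A crucial step before modeling is splitting the data into a training set and a test set
* The test set should not be touched until you think you have a good model
* In PySpark, it is important to split the data **after** all the transformations, because operations like *'StringIndexer'* assign indices based on the data they are fit on and may not produce the same mapping for different subsets
* Use the *'.randomSplit()'* method to split the data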

In [30]:
# Code for part 3.4

# transform the model_data
piped_data = flights_pipe.fit(model_data).transform(model_data)

# split the data
training, test = piped_data.randomSplit([.6, .4])


### CHAPTER 4. Model Tuning and Selection

#### Part 4.1 What is Logistic Regression?

**Logistic Regression:**
* Similar to linear regression, but instead of predicting a numeric value, it predicts the probability of an event (here, whether a flight is late)
* To classify outcomes as 'yes' or 'no', you assign a cutoff point (threshold) to these probabilities
* You need to adjust the model's hyperparameters to tune it for better performance
* Use the *'LogisticRegression'* class from the *'pyspark.ml.classification'* submodule

In [31]:
# Code for part 4.1

# create the modeler
from pyspark.ml.classification import LogisticRegression

# create a LR estimator
lr = LogisticRegression()


#### Part 4.2 Cross Validation

**Cross Validation:**
* Split the training data into a few different partitions (folds)
* One partition is held out, and the model is fit to the others and then evaluated on the held-out partition
* This is repeated so that each partition is held out once, and the scores are averaged
* Use *'pyspark.ml.evaluation.BinaryClassificationEvaluator()'* to build the evaluator
* Use *'pyspark.ml.tuning.ParamGridBuilder()'* to build the grid
* Use *'pyspark.ml.tuning.CrossValidator()'* to perform cross-validation (you need to specify the estimator, evaluator and grid)

**Hyperparameters:**
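* Hyperparameters in the model grid:
1. elasticNetParam: the mix between the L2 penalty (0) and the L1 penalty (1)
2. regParam: the overall strength of the regularization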


In [32]:
# Code for part 4.2

# create the evaluator, grid and validator 
import pyspark.ml.evaluation as evals
import pyspark.ml.tuning as tune
import numpy as np

# create a BinaryClassificationEvaluator
evaluator = evals.BinaryClassificationEvaluator(metricName="areaUnderROC")

# create the parameter grid
grid = tune.ParamGridBuilder()

# add the hyperparameter
grid = grid.addGrid(lr.regParam, np.arange(0, .1, .01))
grid = grid.addGrid(lr.elasticNetParam, [0,1])

# build the grid
grid = grid.build()

# make the validator
# create the CrossValidator (combining earlier steps)
cv = tune.CrossValidator(estimator=lr,
                         estimatorParamMaps=grid,
                         evaluator=evaluator)

# fit the model
best_lr = cv.fit(training)
print(best_lr)


CrossValidatorModel_522be99ee996


#### Part 4.3 Evaluating binary classifiers

* A common classification metric is the AUC (Area Under the Curve)
* The curve is the ROC (Receiver Operating Characteristic) curve
* The closer the AUC is to one (1), the better the model is; an AUC of 0.5 is no better than random guessing

In [33]:
# Part 4.3 Evaluate the model

# use the model to predict the test set
test_results = best_lr.transform(test)
print(evaluator.evaluate(test_results))

# close connection to spark
my_spark.stop()

# the AUC of about 0.68 is clearly better than random guessing (0.5), though there is room to improve

0.6846671249760546


This is the end of this course!