# Introduction to PySpark

## Libraries and functions

### Libraries

In [69]:
import pandas as pd
import numpy as np

from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
from pyspark.ml.classification import LogisticRegression
import pyspark.ml.evaluation as evals
import pyspark.ml.tuning as tune

### Functions

## Read Data

In [None]:
## Spark session
# getOrCreate() hands back the session that is already running, if any,
# and only builds a new one when none exists.
spark = (
    SparkSession
    .builder
    .getOrCreate()
)

In [21]:
## Read data
# Each DataFrame is loaded from the CSV that matches its name.
# NOTE(review): no schema is supplied, so columns presumably arrive as strings;
# later cells cast the columns they need to integer.
flights = spark.read.csv('datasets/flights_small.csv', header=True)
airports = spark.read.csv('datasets/airports.csv', header=True)
planes = spark.read.csv('datasets/planes.csv', header=True)

## Register each DataFrame as a temporary view so it can be queried via spark.sql()
flights.createOrReplaceTempView('flight')
airports.createOrReplaceTempView('airport')
planes.createOrReplaceTempView('plane')

# List the tables/views visible to this session
spark.catalog.listTables()

### Initial Tests

In [23]:


# Sanity-check one table by looking at its first rows in pandas.
# - The preview gets its own name so `flights` stays bound to the Spark
#   DataFrame instead of being replaced by a pandas object.
# - limit(5) is applied in Spark, so only the previewed rows are collected
#   to the driver rather than the whole table.
flights_preview = spark.sql('SELECT * FROM flight').limit(5).toPandas()
flights_preview.head()

Unnamed: 0,year,month,day,dep_time,dep_delay,arr_time,arr_delay,carrier,tailnum,flight,origin,dest,air_time,distance,hour,minute
0,2014,12,8,658,-7,935,-5,VX,N846VA,1780,SEA,LAX,132,954,6,58
1,2014,1,22,1040,5,1505,5,AS,N559AS,851,SEA,HNL,360,2677,10,40
2,2014,3,9,1443,-2,1652,2,VX,N847VA,755,SEA,SFO,111,679,14,43
3,2014,4,9,1705,45,1839,34,WN,N360SW,344,PDX,SJC,83,569,17,5
4,2014,3,9,754,-1,1015,1,AS,N612AS,522,SEA,BUR,127,937,7,54


In [30]:
## Work with a Spark DataFrame (no conversion to pandas)
# Query the 'flight' view and turn air_time into an integer column in one chain.
flights = (
    spark.sql('SELECT * FROM flight')
    .withColumn('air_time', F.col('air_time').cast('integer'))
)

# Print the first 5 rows
flights.show(5)

+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
|year|month|day|dep_time|dep_delay|arr_time|arr_delay|carrier|tailnum|flight|origin|dest|air_time|distance|hour|minute|
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
|2014|   12|  8|     658|       -7|     935|       -5|     VX| N846VA|  1780|   SEA| LAX|     132|     954|   6|    58|
|2014|    1| 22|    1040|        5|    1505|        5|     AS| N559AS|   851|   SEA| HNL|     360|    2677|  10|    40|
|2014|    3|  9|    1443|       -2|    1652|        2|     VX| N847VA|   755|   SEA| SFO|     111|     679|  14|    43|
|2014|    4|  9|    1705|       45|    1839|       34|     WN| N360SW|   344|   PDX| SJC|      83|     569|  17|     5|
|2014|    3|  9|     754|       -1|    1015|        1|     AS| N612AS|   522|   SEA| BUR|     127|     937|   7|    54|
+----+-----+---+--------+---------+-----

In [34]:
## Two equivalent ways to filter; both print the average air_time of
## carrier 'DL' flights departing from 'SEA'.

# Variant 1: predicate written as a SQL expression string
(
    flights
    .filter("carrier='DL' AND origin='SEA'")
    .groupby()
    .avg('air_time')
    .show()
)

# Variant 2: predicate built from Column comparisons, one filter per condition
(
    flights
    .filter(flights.carrier == 'DL')
    .filter(flights.origin == 'SEA')
    .groupBy()
    .avg('air_time')
    .show()
)

+------------------+
|     avg(air_time)|
+------------------+
|188.20689655172413|
+------------------+

+------------------+
|     avg(air_time)|
+------------------+
|188.20689655172413|
+------------------+



In [37]:
## Aggregate with a function from pyspark.sql.functions
# Standard deviation of dep_delay for every (month, dest) group, exposed as
# column 'std'; only the first 3 groups are printed.
# NOTE(review): dep_delay is not cast anywhere in this notebook; presumably
# Spark coerces it to a numeric type for stddev -- confirm.
(
    flights
    .groupBy('month', 'dest')
    .agg(F.stddev('dep_delay').alias('std'))
    .show(3)
)

+-----+----+------------------+
|month|dest|               std|
+-----+----+------------------+
|   11| TUS|3.0550504633038935|
|   11| ANC|18.604716401245316|
|    1| BUR| 15.22627576540667|
+-----+----+------------------+
only showing top 3 rows



## ML Pipeline

In [73]:
## Prepare the modelling dataset
# Cast the numeric inputs to integer, then derive the target:
# label = 1 when the flight arrived late (arr_delay > 0), else 0.
# The label expression runs after the arr_delay cast, so it compares integers.
model_data = (
    flights
    .withColumn("arr_delay", F.col("arr_delay").cast('integer'))
    .withColumn("air_time", F.col("air_time").cast('integer'))
    .withColumn("month", F.col("month").cast('integer'))
    .withColumn("label", (F.col("arr_delay") > 0).cast('integer'))
)

# Drop rows where any of the delay / air-time columns is NULL
# NOTE(review): dep_delay is not cast in this cell, so this check only removes
# true NULLs in that column -- confirm that is sufficient for the data.
model_data = model_data.filter("arr_delay IS NOT NULL AND \
                                dep_delay IS NOT NULL AND \
                                air_time IS NOT NULL")

In [74]:
### Build the feature pipeline

## Categorical columns: string -> numeric index -> one-hot vector

# dest
dest_indexer = StringIndexer(inputCol="dest", outputCol='dest_index')
dest_encoder = OneHotEncoder(inputCol='dest_index', outputCol='dest_fact')

# carrier
carr_indexer = StringIndexer(inputCol='carrier', outputCol='carrier_index')
carr_encoder = OneHotEncoder(inputCol='carrier_index', outputCol='carrier_fact')

## Combine every model input into the single 'features' vector column
feature_columns = ["month", "air_time", "carrier_fact", "dest_fact"]
vec_assembler = VectorAssembler(inputCols=feature_columns, outputCol='features')

## Assemble the pipeline; each indexer must precede its encoder, and the
## assembler comes last because it consumes the encoders' outputs.
pipeline_stages = [
    dest_indexer,
    dest_encoder,
    carr_indexer,
    carr_encoder,
    vec_assembler,
]
flights_pipe = Pipeline(stages=pipeline_stages)

## Prepare datasets

In [75]:
# Fit the feature pipeline and apply it to the data.
# NOTE: the pipeline is fitted on the full dataset before splitting, so the
# indexers have seen every carrier/dest value that can appear in the test set.
piped_data = flights_pipe.fit(model_data).transform(model_data)

# Fixed seed so the train/test split and the cross-validation folds (and
# therefore the reported metric) are the same on every Restart & Run All.
SEED = 42

# Split the data into training (60%) and test (40%) sets
training, test = piped_data.randomSplit([.6, .4], seed=SEED)

# Estimator to tune; it reads the default 'features' and 'label' columns
lr = LogisticRegression()

# Models are compared by area under the ROC curve
evaluator = evals.BinaryClassificationEvaluator(metricName='areaUnderROC')

# Hyperparameter grid: every combination of regParam x elasticNetParam.
# Built in one chain so `grid` is only ever bound to the final list of param maps.
grid = (
    tune.ParamGridBuilder()
    .addGrid(lr.regParam, [.01, .1, 0, 1])
    .addGrid(lr.elasticNetParam, [0, 1])
    .build()
)

# Cross-validator that fits `lr` for each grid entry and scores it with `evaluator`
cv = tune.CrossValidator(
    estimator=lr,
    estimatorParamMaps=grid,
    evaluator=evaluator,
    seed=SEED,
)

## Train the model

In [76]:
# Run the cross-validated grid search on the training split
models = cv.fit(training)

# Keep the best model found by the search
best_lr = models.bestModel

# To skip the grid search, fit the plain estimator instead:
#best_lr = lr.fit(training)

# Show a summary of the selected model
print(best_lr)

LogisticRegressionModel: uid = LogisticRegression_67a83a04c56a, numClasses = 2, numFeatures = 80


## Validate

In [81]:
# Score the held-out test split with the selected model
test_results = best_lr.transform(test)

# Compute the metric configured on `evaluator` for those predictions and print it
test_auc = evaluator.evaluate(test_results)
print(test_auc)

0.6920433237776938
