# Logistic Regression
Logistic regression is an algorithm used for problems where the target outcome variable is binary (i.e. 1 or 0). For example, our model will predict whether or not
a flight is delayed. The two possible outputs for the model will be 1 if the flight is delayed or 0 if the flight is not delayed. The features we can use to predict the outcome in a logistic regression model can be either categorical or numeric. 
The logistic regression function is the sigmoid aka "squashing" function, so called because it maps input values in the range (-inf, inf) to between 0 and 1:

![logit model](files/shared_uploads/gauriganjoo@berkeley.edu/Screen_Shot_2022_04_11_at_3_10_57_AM.png)

where x is our independent variable. 

Logistic regression assumes a linear relationship between the input and target variables.

The loss function used for logistic regression is a version of the cross-entropy loss function, sometimes called logistic loss or log loss:

![Log loss function](files/shared_uploads/gauriganjoo@berkeley.edu/Screen_Shot_2022_04_12_at_1_16_47_PM.png)
 
 
where N is the number of rows in our dataset and y_i is the true value at a point. 

This notebook creates a logistic regression model using resilient distributed datasets (RDDs).

### 1. Notebook setup

In [0]:
# Put at the top of any notebooks for storing in blob

from pyspark.sql.functions import *

blob_container = "team06" # The name of your container created in https://portal.azure.com
storage_account = "apatel" # The name of your Storage account created in https://portal.azure.com
secret_scope = "team06" # The name of the scope created in your local computer using the Databricks CLI
secret_key = "team06" # The name of the secret key created in your local computer using the Databricks CLI 
blob_url = f"wasbs://{blob_container}@{storage_account}.blob.core.windows.net"
mount_path = "/mnt/mids-w261"

In [0]:
# general imports
import sys
import csv
import numpy as np
import pandas as pd
import ast
import math
# magic commands
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [0]:
# ML modules
from sklearn import metrics
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix # for confusion_matrix

### 2. Dataset

The data we used for this algorithm was already joined and imputed. We will be ignoring data leakage issues and focus on the structure of the logistic regression.

In [0]:
# Read in full data with no nulls
df_6m_join = spark.read.parquet(f'{blob_url}/full_join_mattsFeats_anandFeats_cleaned_depPR_arrPR_v3')

In [0]:
display(df_6m_join)

YEAR,QUARTER,MONTH,DAY_OF_MONTH,DAY_OF_WEEK,FL_DATE,OP_UNIQUE_CARRIER,OP_CARRIER_AIRLINE_ID,TAIL_NUM,OP_CARRIER_FL_NUM,ORIGIN_AIRPORT_ID,ORIGIN_AIRPORT_SEQ_ID,ORIGIN,ORIGIN_STATE_ABR,ORIGIN_STATE_FIPS,ORIGIN_STATE_NM,DEST_AIRPORT_ID,DEST_AIRPORT_SEQ_ID,DEST,DEST_STATE_ABR,DEST_STATE_FIPS,DEST_STATE_NM,CRS_DEP_TIME,DEP_TIME,DEP_DELAY,DEP_DELAY_NEW,DEP_DEL15,DEP_DELAY_GROUP,DEP_TIME_BLK,CRS_ARR_TIME,ARR_TIME,ARR_DELAY,ARR_DELAY_NEW,ARR_DEL15,CANCELLED,ACTUAL_ELAPSED_TIME,DISTANCE,DISTANCE_GROUP,CARRIER_DELAY,WEATHER_DELAY,NAS_DELAY,LATE_AIRCRAFT_DELAY,ORI_IATA,ORI_station_id,ORI_station_lat,ORI_station_lon,ORI_airport_lat,ORI_airport_lon,ORI_elevation,ORI_dist_airp_sta,DEST_IATA,DEST_station_id,DEST_station_lat,DEST_station_lon,DEST_airport_lat,DEST_airport_lon,DEST_elevation,DEST_dist_airp_sta,CRS_DEP_HRS,CRS_DEP_MINS,CRS_DEP_TIME_STR,CRS_DEP_DT_STR,CRS_DEP_DATETIME,iata_code,ORI_timezone,CRS_DEP_DATETIME_UTC,CRS_DEP_DATETIME_UTC_END,CRS_DEP_DATETIME_UTC_START,STATION,DATE,LATITUDE,LONGITUDE,ELEVATION,REPORT_TYPE,DEST_timezone,DEP_HRS,DEP_MINS,DEP_TIME_STR,DEP_DT_STR,DEP_DATETIME,ARR_DATETIME_ACTUAL_UTC,late_night,daytime,evening,region_name,new_england,mid_atlantic,south,midwest,southwest,west,pacific_islands,atlantic_islands,spring,summer,autumn,winter,dep_date,weekend_or_holiday,flightID,ID,previous_flight_delay_status,previous_flight_dep_time,time_between_departures_min,valid_dep_delay,prior_dep_delayed,previous_DEP_DELAY_NEW_value,previous_DEP_DELAY_NEW,previous_flight_arrdelay_status,previous_flight_arr_time,time_between_arrival_and_end_min,valid_arr_delay,prior_arr_delayed,previous_ARR_DELAY_NEW_value,previous_ARR_DELAY_NEW,prev_arrival_airport,plane_is_here,avg_carrier_delay_24hrs,flights_sch_Today_ORIGIN,flights_sch_Today_DEST,avg_ori_airport_delay_24hrs,year_quarter,quarter_enum,quarter_enum_prev,month_seq_index,wnd_2_wind_obs_type,wnd_3_wind_sp_rate,cig_0_height,vis_0_distance,slp_0_day_avg,tmp_0_air_temp,dew_0_point_temp,ma1_0_altimeter_setting_rate,ma1_2_station_pressure_rate,gd1_3_sky_cover_height,WND_3_wind_speed,CIG_0_sky_ceiling_height,VIS_0_visibility_dist,SLP_0_avg_station_press,TMP_0_air_temperature,DEW_0_dew_pt_temp,MA1_0_altimeter_set_rate,MA1_2_station_pres_rate,OC1_0_wind_gust_spd_rate_imp,GD1_3_cloud_height,AA1_1_liquid_precip,AA3_1_liquid_precip,gd1_0_sky_coverage,au2_4_extreme_wind_weather,mv1_0_sand_dust_near,mv1_0_thunder_rain_near,aw1_mw1_0_smoke_haze_dust,aw1_mw1_0_fog,aw1_mw1_0_rain_drizzle,aw1_mw1_0_freezing_rain_drizzle,aw1_mw1_0_snow,aw1_mw1_0_hail_or_ice,aw1_mw1_0_thunderstorm,aw1_mw1_0_tornado,SLP_0_avg_station_press_imp,WND_3_wind_speed_imp,CIG_0_sky_ceiling_height_imp,VIS_0_visibility_dist_imp,TMP_0_air_temperature_imp,DEW_0_dew_pt_temp_imp,MA1_0_altimeter_set_rate_imp,MA1_2_station_pres_rate_imp,GD1_3_cloud_height_imp,depDelayPageRank,arrDelayPageRank
2015,1,1,1,4,2015-01-01,B6,20409,N324JB,2023,12478,1247802,JFK,NY,36,New York,14843,1484304,SJU,PR,72,Puerto Rico,535,618,43.0,43.0,1.0,2,0001-0559,1020,1039,19.0,19.0,1.0,0.0,201.0,1598.0,7,19.0,0.0,0.0,0.0,JFK,74486094789,40.639,-73.762,40.63980103,-73.77890015,13.0,1.4287707081898418,,,,,,,841.0,,5,35,05:35,2015-01-01 05:35,2015-01-01T05:35:00.000+0000,JFK,America/New_York,2015-01-01T10:35:00.000+0000,2015-01-01T08:35:00.000+0000,2015-01-01T06:35:00.000+0000,74486094789,2015-01-01T07:51:00.000+0000,40.6386,-73.7622,3.4,FM-15,America/Puerto_Rico,6,18,06:18,2015-01-01 06:18,2015-01-01T06:18:00.000+0000,2015-01-01T14:39:00.000+0000,1,0,0,mid_atlantic,0,1,0,0,0,0,0,0,0,0,0,1,2015-01-01,1,180388780851,B6N324JBJFK2015-01-01535,,,,no,0,,0.0,,,,no,0,,0.0,LGA,0,0.12,274,83,0.13,2015-1,0,-1,1,N,62,22000,16093,10212,-17,-117,10213,10205,4267.0,62,22000.0,16093,10212,-17,-117,10213.0,10205,0,4267.0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,10212.0,62.0,22000.0,16093.0,-17.0,-117.0,10213.0,10205.0,4267.0,0.0112327556603591,0.0112457813144883
2015,1,3,28,6,2015-03-28,B6,20409,N316JB,2323,11278,1127802,DCA,VA,51,Virginia,13204,1320402,MCO,FL,12,Florida,2015,2027,12.0,12.0,0.0,0,2000-2059,2235,2246,11.0,11.0,0.0,0.0,139.0,759.0,4,,,,,DCA,72405013743,38.847,-77.035,38.8521,-77.037697,15.0,0.6133064760167503,MCO,72205012815.0,28.434,-81.325,28.42939949,-81.30899811,96.0,1.6462168285180496,20,15,20:15,2015-03-28 20:15,2015-03-28T20:15:00.000+0000,DCA,America/New_York,2015-03-29T00:15:00.000+0000,2015-03-28T22:15:00.000+0000,2015-03-28T20:15:00.000+0000,72405013743,2015-03-28T21:52:00.000+0000,38.8472,-77.03454,3.0,FM-15,America/New_York,20,27,20:27,2015-03-28 20:27,2015-03-28T20:27:00.000+0000,2015-03-29T02:46:00.000+0000,0,0,1,south,0,0,1,0,0,0,0,0,1,0,0,0,2015-03-28,1,94489312135,B6N316JBDCA2015-03-282015,0.0,2015-03-28T21:45:00.000+0000,150.0,yes,0,0.0,0.0,0.0,2015-03-28T23:39:00.000+0000,36.0,no,0,0.0,0.0,DCA,0,0.27,178,391,0.13,2015-1,0,-1,2,N,62,1829,16093,10190,22,-122,10193,10169,1829.0,62,1829.0,16093,10190,22,-122,10193.0,10169,103,1829.0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,10190.0,62.0,1829.0,16093.0,22.0,-122.0,10193.0,10169.0,1829.0,0.0052220858322935,0.0054210976594481
2015,1,2,4,3,2015-02-04,DL,19790,N690DL,323,10397,1039705,ATL,GA,13,Georgia,14843,1484304,SJU,PR,72,Puerto Rico,820,819,-1.0,0.0,0.0,-1,0800-0859,1246,1245,-1.0,0.0,0.0,0.0,206.0,1547.0,7,,,,,ATL,72219013874,33.63,-84.442,33.6367,-84.428101,1026.0,1.4868906335816192,,,,,,,841.0,,8,20,08:20,2015-02-04 08:20,2015-02-04T08:20:00.000+0000,ATL,America/New_York,2015-02-04T13:20:00.000+0000,2015-02-04T11:20:00.000+0000,2015-02-04T09:20:00.000+0000,72219013874,2015-02-04T10:52:00.000+0000,33.6301,-84.4418,307.8,FM-15,America/Puerto_Rico,8,19,08:19,2015-02-04 08:19,2015-02-04T08:19:00.000+0000,2015-02-04T16:45:00.000+0000,0,1,0,south,0,0,1,0,0,0,0,0,0,0,0,1,2015-02-04,0,180388798319,DLN690DLATL2015-02-04820,0.0,2015-02-03T12:10:00.000+0000,1510.0,no,0,3.0,0.0,0.0,2015-02-03T13:35:00.000+0000,1425.0,yes,0,0.0,0.0,ATL,1,0.17,997,63,0.13,2015-1,0,-1,1,C,0,7620,16093,10250,22,-39,10244,9870,6096.0,0,7620.0,16093,10250,22,-39,10244.0,9870,0,6096.0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,10250.0,0.0,7620.0,16093.0,22.0,-39.0,10244.0,9870.0,6096.0,0.03323907560623,0.0309607739359259
2015,1,1,31,6,2015-01-31,WN,19393,N8646B,4229,15304,1530402,TPA,FL,12,Florida,10821,1082103,BWI,MD,24,Maryland,1135,1132,-3.0,0.0,0.0,-1,1100-1159,1355,1345,-10.0,0.0,0.0,0.0,133.0,842.0,4,,,,,TPA,72211012842,27.962,-82.54,27.97550011,-82.53320313,26.0,1.642864417823524,BWI,72406093721.0,39.173,-76.684,39.1754,-76.668297,146.0,1.3796804494464006,11,35,11:35,2015-01-31 11:35,2015-01-31T11:35:00.000+0000,TPA,America/New_York,2015-01-31T16:35:00.000+0000,2015-01-31T14:35:00.000+0000,2015-01-31T12:35:00.000+0000,72211012842,2015-01-31T13:53:00.000+0000,27.96194,-82.5403,5.8,FM-15,America/New_York,11,32,11:32,2015-01-31 11:32,2015-01-31T11:32:00.000+0000,2015-01-31T18:45:00.000+0000,0,1,0,south,0,0,1,0,0,0,0,0,0,0,0,1,2015-01-31,1,60129554587,WNN8646BTPA2015-01-311135,0.0,2015-01-31T12:35:00.000+0000,240.0,yes,0,6.0,6.0,0.0,2015-01-31T15:24:00.000+0000,71.0,no,0,0.0,0.0,TPA,0,0.25,173,187,0.14,2015-1,0,-1,1,N,36,8230,16093,10255,111,28,10254,10250,8230.0,36,8230.0,16093,10255,111,28,10254.0,10250,0,8230.0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,10255.0,36.0,8230.0,16093.0,111.0,28.0,10254.0,10250.0,8230.0,0.0087050844763926,0.0099642075178666
2015,1,1,1,4,2015-01-01,AA,19805,N3LLAA,2299,12478,1247802,JFK,NY,36,New York,13303,1330303,MIA,FL,12,Florida,545,640,55.0,55.0,1.0,3,0001-0559,850,959,69.0,69.0,1.0,0.0,199.0,1089.0,5,55.0,0.0,14.0,0.0,JFK,74486094789,40.639,-73.762,40.63980103,-73.77890015,13.0,1.4287707081898418,MIA,72202012839.0,25.788,-80.317,25.79319954,-80.29060364,8.0,2.7052797087850293,5,45,05:45,2015-01-01 05:45,2015-01-01T05:45:00.000+0000,JFK,America/New_York,2015-01-01T10:45:00.000+0000,2015-01-01T08:45:00.000+0000,2015-01-01T06:45:00.000+0000,74486094789,2015-01-01T07:51:00.000+0000,40.6386,-73.7622,3.4,FM-15,America/New_York,6,40,06:40,2015-01-01 06:40,2015-01-01T06:40:00.000+0000,2015-01-01T14:59:00.000+0000,1,0,0,mid_atlantic,0,1,0,0,0,0,0,0,0,0,0,1,2015-01-01,1,137439444295,AAN3LLAAJFK2015-01-01545,,,,no,0,,0.0,,,,no,0,,0.0,LAX,0,0.17,274,208,0.13,2015-1,0,-1,1,N,62,22000,16093,10212,-17,-117,10213,10205,4267.0,62,22000.0,16093,10212,-17,-117,10213.0,10205,0,4267.0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,10212.0,62.0,22000.0,16093.0,-17.0,-117.0,10213.0,10205.0,4267.0,0.0112327556603591,0.0112457813144883
2015,1,3,28,6,2015-03-28,WN,19393,N443WN,1863,11278,1127802,DCA,VA,51,Virginia,10397,1039705,ATL,GA,13,Georgia,2035,2030,-5.0,0.0,0.0,-1,2000-2059,2225,2206,-19.0,0.0,0.0,0.0,96.0,547.0,3,,,,,DCA,72405013743,38.847,-77.035,38.8521,-77.037697,15.0,0.6133064760167503,ATL,72219013874.0,33.63,-84.442,33.6367,-84.428101,1026.0,1.4868906335816192,20,35,20:35,2015-03-28 20:35,2015-03-28T20:35:00.000+0000,DCA,America/New_York,2015-03-29T00:35:00.000+0000,2015-03-28T22:35:00.000+0000,2015-03-28T20:35:00.000+0000,72405013743,2015-03-28T21:52:00.000+0000,38.8472,-77.03454,3.0,FM-15,America/New_York,20,30,20:30,2015-03-28 20:30,2015-03-28T20:30:00.000+0000,2015-03-29T02:06:00.000+0000,0,0,1,south,0,0,1,0,0,0,0,0,1,0,0,0,2015-03-28,1,104952,WNN443WNDCA2015-03-282035,0.0,2015-03-28T22:20:00.000+0000,135.0,yes,0,0.0,0.0,0.0,2015-03-28T23:55:00.000+0000,40.0,no,0,0.0,0.0,DCA,0,0.26,178,912,0.13,2015-1,0,-1,2,N,62,1829,16093,10190,22,-122,10193,10169,1829.0,62,1829.0,16093,10190,22,-122,10193.0,10169,103,1829.0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,10190.0,62.0,1829.0,16093.0,22.0,-122.0,10193.0,10169.0,1829.0,0.0052220858322935,0.0054210976594481
2015,1,2,4,3,2015-02-04,DL,19790,N3735D,2203,10397,1039705,ATL,GA,13,Georgia,12889,1288903,LAS,NV,32,Nevada,820,820,0.0,0.0,0.0,0,0800-0859,1000,949,-11.0,0.0,0.0,0.0,269.0,1747.0,7,,,,,ATL,72219013874,33.63,-84.442,33.6367,-84.428101,1026.0,1.4868906335816192,LAS,72386023169.0,36.072,-115.163,36.08010101,-115.1520004,2181.0,1.3374107673224451,8,20,08:20,2015-02-04 08:20,2015-02-04T08:20:00.000+0000,ATL,America/New_York,2015-02-04T13:20:00.000+0000,2015-02-04T11:20:00.000+0000,2015-02-04T09:20:00.000+0000,72219013874,2015-02-04T10:52:00.000+0000,33.6301,-84.4418,307.8,FM-15,America/Los_Angeles,8,20,08:20,2015-02-04 08:20,2015-02-04T08:20:00.000+0000,2015-02-04T17:49:00.000+0000,0,1,0,south,0,0,1,0,0,0,0,0,0,0,0,1,2015-02-04,0,77309424304,DLN3735DATL2015-02-04820,1.0,2015-02-04T00:25:00.000+0000,775.0,yes,1,24.0,24.0,1.0,2015-02-04T03:25:00.000+0000,595.0,yes,1,16.0,16.0,ATL,1,0.17,997,373,0.13,2015-1,0,-1,1,C,0,7620,16093,10250,22,-39,10244,9870,6096.0,0,7620.0,16093,10250,22,-39,10244.0,9870,0,6096.0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,10250.0,0.0,7620.0,16093.0,22.0,-39.0,10244.0,9870.0,6096.0,0.03323907560623,0.0309607739359259
2015,1,1,31,6,2015-01-31,US,20355,N832AW,2002,15304,1530402,TPA,FL,12,Florida,11057,1105703,CLT,NC,37,North Carolina,1140,1137,-3.0,0.0,0.0,-1,1100-1159,1322,1311,-11.0,0.0,0.0,0.0,94.0,507.0,3,,,,,TPA,72211012842,27.962,-82.54,27.97550011,-82.53320313,26.0,1.642864417823524,CLT,72314013881.0,35.224,-80.955,35.2140007,-80.94309998,748.0,1.550757753493052,11,40,11:40,2015-01-31 11:40,2015-01-31T11:40:00.000+0000,TPA,America/New_York,2015-01-31T16:40:00.000+0000,2015-01-31T14:40:00.000+0000,2015-01-31T12:40:00.000+0000,72211012842,2015-01-31T13:53:00.000+0000,27.96194,-82.5403,5.8,FM-15,America/New_York,11,37,11:37,2015-01-31 11:37,2015-01-31T11:37:00.000+0000,2015-01-31T18:11:00.000+0000,0,1,0,south,0,0,1,0,0,0,0,0,0,0,0,1,2015-01-31,1,42949686370,USN832AWTPA2015-01-311140,0.0,2015-01-31T12:45:00.000+0000,235.0,yes,0,0.0,0.0,0.0,2015-01-31T15:19:00.000+0000,81.0,no,0,0.0,0.0,TPA,0,0.23,173,243,0.14,2015-1,0,-1,1,N,36,8230,16093,10255,111,28,10254,10250,8230.0,36,8230.0,16093,10255,111,28,10254.0,10250,0,8230.0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,10255.0,36.0,8230.0,16093.0,111.0,28.0,10254.0,10250.0,8230.0,0.0087050844763926,0.0099642075178666
2015,1,1,1,4,2015-01-01,B6,20409,N794JB,939,12478,1247802,JFK,NY,36,New York,10732,1073203,BQN,PR,72,Puerto Rico,545,545,0.0,0.0,0.0,0,0001-0559,1026,1007,-19.0,0.0,0.0,0.0,202.0,1576.0,7,,,,,JFK,74486094789,40.639,-73.762,40.63980103,-73.77890015,13.0,1.4287707081898418,,,,,,,841.0,,5,45,05:45,2015-01-01 05:45,2015-01-01T05:45:00.000+0000,JFK,America/New_York,2015-01-01T10:45:00.000+0000,2015-01-01T08:45:00.000+0000,2015-01-01T06:45:00.000+0000,74486094789,2015-01-01T07:51:00.000+0000,40.6386,-73.7622,3.4,FM-15,America/Puerto_Rico,5,45,05:45,2015-01-01 05:45,2015-01-01T05:45:00.000+0000,2015-01-01T14:07:00.000+0000,1,0,0,mid_atlantic,0,1,0,0,0,0,0,0,0,0,0,1,2015-01-01,1,163208893453,B6N794JBJFK2015-01-01545,,,,no,0,,0.0,,,,no,0,,0.0,HPN,0,0.12,274,6,0.13,2015-1,0,-1,1,N,62,22000,16093,10212,-17,-117,10213,10205,4267.0,62,22000.0,16093,10212,-17,-117,10213.0,10205,0,4267.0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,10212.0,62.0,22000.0,16093.0,-17.0,-117.0,10213.0,10205.0,4267.0,0.0112327556603591,0.0112457813144883
2015,1,3,28,6,2015-03-28,WN,19393,N8626B,2273,11278,1127802,DCA,VA,51,Virginia,11259,1125903,DAL,TX,48,Texas,2050,2050,0.0,0.0,0.0,0,2000-2059,2325,2238,-47.0,0.0,0.0,0.0,168.0,1184.0,5,,,,,DCA,72405013743,38.847,-77.035,38.8521,-77.037697,15.0,0.6133064760167503,DAL,72258013960.0,32.852,-96.856,32.847099,-96.851799,487.0,0.671561354090325,20,50,20:50,2015-03-28 20:50,2015-03-28T20:50:00.000+0000,DCA,America/New_York,2015-03-29T00:50:00.000+0000,2015-03-28T22:50:00.000+0000,2015-03-28T20:50:00.000+0000,72405013743,2015-03-28T21:52:00.000+0000,38.8472,-77.03454,3.0,FM-15,America/Chicago,20,50,20:50,2015-03-28 20:50,2015-03-28T20:50:00.000+0000,2015-03-29T03:38:00.000+0000,0,0,1,south,0,0,1,0,0,0,0,0,1,0,0,0,2015-03-28,1,120259125073,WNN8626BDCA2015-03-282050,0.0,2015-03-28T21:50:00.000+0000,180.0,yes,0,5.0,5.0,0.0,2015-03-29T00:03:00.000+0000,47.0,no,0,0.0,0.0,DCA,0,0.26,178,116,0.13,2015-1,0,-1,2,N,62,1829,16093,10190,22,-122,10193,10169,1829.0,62,1829.0,16093,10190,22,-122,10193.0,10169,103,1829.0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,10190.0,62.0,1829.0,16093.0,22.0,-122.0,10193.0,10169.0,1829.0,0.0052220858322935,0.0054210976594481


In [0]:
display(df_6m_join.filter(df_6m_join.FL_DATE == '2015-01-01'))

YEAR,QUARTER,MONTH,DAY_OF_MONTH,DAY_OF_WEEK,FL_DATE,OP_UNIQUE_CARRIER,OP_CARRIER_AIRLINE_ID,TAIL_NUM,OP_CARRIER_FL_NUM,ORIGIN_AIRPORT_ID,ORIGIN_AIRPORT_SEQ_ID,ORIGIN,ORIGIN_STATE_ABR,ORIGIN_STATE_FIPS,ORIGIN_STATE_NM,DEST_AIRPORT_ID,DEST_AIRPORT_SEQ_ID,DEST,DEST_STATE_ABR,DEST_STATE_FIPS,DEST_STATE_NM,CRS_DEP_TIME,DEP_TIME,DEP_DELAY,DEP_DELAY_NEW,DEP_DEL15,DEP_DELAY_GROUP,DEP_TIME_BLK,CRS_ARR_TIME,ARR_TIME,ARR_DELAY,ARR_DELAY_NEW,ARR_DEL15,CANCELLED,ACTUAL_ELAPSED_TIME,DISTANCE,DISTANCE_GROUP,CARRIER_DELAY,WEATHER_DELAY,NAS_DELAY,LATE_AIRCRAFT_DELAY,ORI_IATA,ORI_station_id,ORI_station_lat,ORI_station_lon,ORI_airport_lat,ORI_airport_lon,ORI_elevation,ORI_dist_airp_sta,DEST_IATA,DEST_station_id,DEST_station_lat,DEST_station_lon,DEST_airport_lat,DEST_airport_lon,DEST_elevation,DEST_dist_airp_sta,CRS_DEP_HRS,CRS_DEP_MINS,CRS_DEP_TIME_STR,CRS_DEP_DT_STR,CRS_DEP_DATETIME,iata_code,ORI_timezone,CRS_DEP_DATETIME_UTC,CRS_DEP_DATETIME_UTC_END,CRS_DEP_DATETIME_UTC_START,STATION,DATE,LATITUDE,LONGITUDE,ELEVATION,REPORT_TYPE,DEST_timezone,DEP_HRS,DEP_MINS,DEP_TIME_STR,DEP_DT_STR,DEP_DATETIME,ARR_DATETIME_ACTUAL_UTC,late_night,daytime,evening,region_name,new_england,mid_atlantic,south,midwest,southwest,west,pacific_islands,atlantic_islands,spring,summer,autumn,winter,dep_date,weekend_or_holiday,flightID,ID,previous_flight_delay_status,previous_flight_dep_time,time_between_departures_min,valid_dep_delay,prior_dep_delayed,previous_DEP_DELAY_NEW_value,previous_DEP_DELAY_NEW,previous_flight_arrdelay_status,previous_flight_arr_time,time_between_arrival_and_end_min,valid_arr_delay,prior_arr_delayed,previous_ARR_DELAY_NEW_value,previous_ARR_DELAY_NEW,prev_arrival_airport,plane_is_here,avg_carrier_delay_24hrs,flights_sch_Today_ORIGIN,flights_sch_Today_DEST,avg_ori_airport_delay_24hrs,year_quarter,quarter_enum,quarter_enum_prev,month_seq_index,wnd_2_wind_obs_type,wnd_3_wind_sp_rate,cig_0_height,vis_0_distance,slp_0_day_avg,tmp_0_air_temp,dew_0_point_temp,ma1_0_altimeter_setting_rate,ma1_2_station_pressure_rate,gd1_3_sky_cover_height,WND_3_wind_speed,CIG_0_sky_ceiling_height,VIS_0_visibility_dist,SLP_0_avg_station_press,TMP_0_air_temperature,DEW_0_dew_pt_temp,MA1_0_altimeter_set_rate,MA1_2_station_pres_rate,OC1_0_wind_gust_spd_rate_imp,GD1_3_cloud_height,AA1_1_liquid_precip,AA3_1_liquid_precip,gd1_0_sky_coverage,au2_4_extreme_wind_weather,mv1_0_sand_dust_near,mv1_0_thunder_rain_near,aw1_mw1_0_smoke_haze_dust,aw1_mw1_0_fog,aw1_mw1_0_rain_drizzle,aw1_mw1_0_freezing_rain_drizzle,aw1_mw1_0_snow,aw1_mw1_0_hail_or_ice,aw1_mw1_0_thunderstorm,aw1_mw1_0_tornado,SLP_0_avg_station_press_imp,WND_3_wind_speed_imp,CIG_0_sky_ceiling_height_imp,VIS_0_visibility_dist_imp,TMP_0_air_temperature_imp,DEW_0_dew_pt_temp_imp,MA1_0_altimeter_set_rate_imp,MA1_2_station_pres_rate_imp,GD1_3_cloud_height_imp,depDelayPageRank,arrDelayPageRank
2015,1,1,1,4,2015-01-01,B6,20409,N324JB,2023,12478,1247802,JFK,NY,36,New York,14843,1484304,SJU,PR,72,Puerto Rico,535,618,43.0,43.0,1.0,2,0001-0559,1020,1039,19.0,19.0,1.0,0.0,201.0,1598.0,7,19.0,0.0,0.0,0.0,JFK,74486094789,40.639,-73.762,40.63980103,-73.77890015,13.0,1.4287707081898418,,,,,,,841.0,,5,35,05:35,2015-01-01 05:35,2015-01-01T05:35:00.000+0000,JFK,America/New_York,2015-01-01T10:35:00.000+0000,2015-01-01T08:35:00.000+0000,2015-01-01T06:35:00.000+0000,74486094789,2015-01-01T07:51:00.000+0000,40.6386,-73.7622,3.4,FM-15,America/Puerto_Rico,6,18,06:18,2015-01-01 06:18,2015-01-01T06:18:00.000+0000,2015-01-01T14:39:00.000+0000,1,0,0,mid_atlantic,0,1,0,0,0,0,0,0,0,0,0,1,2015-01-01,1,180388780851,B6N324JBJFK2015-01-01535,,,,no,0,,0.0,,,,no,0,,0.0,LGA,0,0.12,274,83,0.13,2015-1,0,-1,1,N,62,22000,16093,10212,-17,-117,10213,10205,4267.0,62,22000.0,16093,10212,-17,-117,10213.0,10205,0,4267.0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,10212.0,62.0,22000.0,16093.0,-17.0,-117.0,10213.0,10205.0,4267.0,0.0112327556603591,0.0112457813144883
2015,1,1,1,4,2015-01-01,AA,19805,N3LLAA,2299,12478,1247802,JFK,NY,36,New York,13303,1330303,MIA,FL,12,Florida,545,640,55.0,55.0,1.0,3,0001-0559,850,959,69.0,69.0,1.0,0.0,199.0,1089.0,5,55.0,0.0,14.0,0.0,JFK,74486094789,40.639,-73.762,40.63980103,-73.77890015,13.0,1.4287707081898418,MIA,72202012839.0,25.788,-80.317,25.79319954,-80.29060364,8.0,2.7052797087850293,5,45,05:45,2015-01-01 05:45,2015-01-01T05:45:00.000+0000,JFK,America/New_York,2015-01-01T10:45:00.000+0000,2015-01-01T08:45:00.000+0000,2015-01-01T06:45:00.000+0000,74486094789,2015-01-01T07:51:00.000+0000,40.6386,-73.7622,3.4,FM-15,America/New_York,6,40,06:40,2015-01-01 06:40,2015-01-01T06:40:00.000+0000,2015-01-01T14:59:00.000+0000,1,0,0,mid_atlantic,0,1,0,0,0,0,0,0,0,0,0,1,2015-01-01,1,137439444295,AAN3LLAAJFK2015-01-01545,,,,no,0,,0.0,,,,no,0,,0.0,LAX,0,0.17,274,208,0.13,2015-1,0,-1,1,N,62,22000,16093,10212,-17,-117,10213,10205,4267.0,62,22000.0,16093,10212,-17,-117,10213.0,10205,0,4267.0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,10212.0,62.0,22000.0,16093.0,-17.0,-117.0,10213.0,10205.0,4267.0,0.0112327556603591,0.0112457813144883
2015,1,1,1,4,2015-01-01,B6,20409,N794JB,939,12478,1247802,JFK,NY,36,New York,10732,1073203,BQN,PR,72,Puerto Rico,545,545,0.0,0.0,0.0,0,0001-0559,1026,1007,-19.0,0.0,0.0,0.0,202.0,1576.0,7,,,,,JFK,74486094789,40.639,-73.762,40.63980103,-73.77890015,13.0,1.4287707081898418,,,,,,,841.0,,5,45,05:45,2015-01-01 05:45,2015-01-01T05:45:00.000+0000,JFK,America/New_York,2015-01-01T10:45:00.000+0000,2015-01-01T08:45:00.000+0000,2015-01-01T06:45:00.000+0000,74486094789,2015-01-01T07:51:00.000+0000,40.6386,-73.7622,3.4,FM-15,America/Puerto_Rico,5,45,05:45,2015-01-01 05:45,2015-01-01T05:45:00.000+0000,2015-01-01T14:07:00.000+0000,1,0,0,mid_atlantic,0,1,0,0,0,0,0,0,0,0,0,1,2015-01-01,1,163208893453,B6N794JBJFK2015-01-01545,,,,no,0,,0.0,,,,no,0,,0.0,HPN,0,0.12,274,6,0.13,2015-1,0,-1,1,N,62,22000,16093,10212,-17,-117,10213,10205,4267.0,62,22000.0,16093,10212,-17,-117,10213.0,10205,0,4267.0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,10212.0,62.0,22000.0,16093.0,-17.0,-117.0,10213.0,10205.0,4267.0,0.0112327556603591,0.0112457813144883
2015,1,1,1,4,2015-01-01,B6,20409,N531JB,583,12478,1247802,JFK,NY,36,New York,13204,1320402,MCO,FL,12,Florida,600,557,-3.0,0.0,0.0,-1,0600-0659,851,832,-19.0,0.0,0.0,0.0,155.0,944.0,4,,,,,JFK,74486094789,40.639,-73.762,40.63980103,-73.77890015,13.0,1.4287707081898418,MCO,72205012815.0,28.434,-81.325,28.42939949,-81.30899811,96.0,1.6462168285180496,6,0,06:00,2015-01-01 06:00,2015-01-01T06:00:00.000+0000,JFK,America/New_York,2015-01-01T11:00:00.000+0000,2015-01-01T09:00:00.000+0000,2015-01-01T07:00:00.000+0000,74486094789,2015-01-01T09:00:00.000+0000,40.6386,-73.7622,3.4,FM-12,America/New_York,5,57,05:57,2015-01-01 05:57,2015-01-01T05:57:00.000+0000,2015-01-01T13:32:00.000+0000,0,1,0,mid_atlantic,0,1,0,0,0,0,0,0,0,0,0,1,2015-01-01,1,94489280568,B6N531JBJFK2015-01-01600,,,,no,0,,0.0,1.0,2015-01-01T11:33:00.000+0000,-33.0,no,0,47.0,0.0,JFK,0,0.12,274,303,0.13,2015-1,0,-1,1,N,77,99999,16000,10208,-17,-133,99999,10199,,77,,16000,10208,-17,-133,,10199,0,,0,0,0,0,0,0,0,0,0,0,0,0,0,0,10208.0,77.0,9351.858930602955,16000.0,-17.0,-133.0,10190.759954493742,10199.0,1968.7180124223603,0.0112327556603591,0.0112457813144883
2015,1,1,1,4,2015-01-01,DL,19790,N967DL,421,12478,1247802,JFK,NY,36,New York,10397,1039705,ATL,GA,13,Georgia,600,605,5.0,5.0,0.0,0,0600-0659,843,824,-19.0,0.0,0.0,0.0,139.0,760.0,4,,,,,JFK,74486094789,40.639,-73.762,40.63980103,-73.77890015,13.0,1.4287707081898418,ATL,72219013874.0,33.63,-84.442,33.6367,-84.428101,1026.0,1.4868906335816192,6,0,06:00,2015-01-01 06:00,2015-01-01T06:00:00.000+0000,JFK,America/New_York,2015-01-01T11:00:00.000+0000,2015-01-01T09:00:00.000+0000,2015-01-01T07:00:00.000+0000,74486094789,2015-01-01T09:00:00.000+0000,40.6386,-73.7622,3.4,FM-12,America/New_York,6,5,06:05,2015-01-01 06:05,2015-01-01T06:05:00.000+0000,2015-01-01T13:24:00.000+0000,0,1,0,mid_atlantic,0,1,0,0,0,0,0,0,0,0,0,1,2015-01-01,1,463,DLN967DLJFK2015-01-01600,,,,no,0,,0.0,,,,no,0,,0.0,BHM,0,0.0,274,696,0.13,2015-1,0,-1,1,N,77,99999,16000,10208,-17,-133,99999,10199,,77,,16000,10208,-17,-133,,10199,0,,0,0,0,0,0,0,0,0,0,0,0,0,0,0,10208.0,77.0,9351.858930602955,16000.0,-17.0,-133.0,10190.759954493742,10199.0,1968.7180124223603,0.0112327556603591,0.0112457813144883
2015,1,1,1,4,2015-01-01,B6,20409,N570JB,353,12478,1247802,JFK,NY,36,New York,14027,1402702,PBI,FL,12,Florida,600,554,-6.0,0.0,0.0,-1,0600-0659,858,836,-22.0,0.0,0.0,0.0,162.0,1028.0,5,,,,,JFK,74486094789,40.639,-73.762,40.63980103,-73.77890015,13.0,1.4287707081898418,PBI,72203012844.0,26.685,-80.099,26.68320084,-80.09559631,19.0,0.3929102161055523,6,0,06:00,2015-01-01 06:00,2015-01-01T06:00:00.000+0000,JFK,America/New_York,2015-01-01T11:00:00.000+0000,2015-01-01T09:00:00.000+0000,2015-01-01T07:00:00.000+0000,74486094789,2015-01-01T09:00:00.000+0000,40.6386,-73.7622,3.4,FM-12,America/New_York,5,54,05:54,2015-01-01 05:54,2015-01-01T05:54:00.000+0000,2015-01-01T13:36:00.000+0000,0,1,0,mid_atlantic,0,1,0,0,0,0,0,0,0,0,0,1,2015-01-01,1,154618827766,B6N570JBJFK2015-01-01600,,,,no,0,,0.0,,,,no,0,,0.0,PSE,0,0.12,274,81,0.13,2015-1,0,-1,1,N,77,99999,16000,10208,-17,-133,99999,10199,,77,,16000,10208,-17,-133,,10199,0,,0,0,0,0,0,0,0,0,0,0,0,0,0,0,10208.0,77.0,9351.858930602955,16000.0,-17.0,-133.0,10190.759954493742,10199.0,1968.7180124223603,0.0112327556603591,0.0112457813144883
2015,1,1,1,4,2015-01-01,B6,20409,N645JB,525,12478,1247802,JFK,NY,36,New York,15304,1530402,TPA,FL,12,Florida,600,554,-6.0,0.0,0.0,-1,0600-0659,901,900,-1.0,0.0,0.0,0.0,186.0,1005.0,5,,,,,JFK,74486094789,40.639,-73.762,40.63980103,-73.77890015,13.0,1.4287707081898418,TPA,72211012842.0,27.962,-82.54,27.97550011,-82.53320313,26.0,1.642864417823524,6,0,06:00,2015-01-01 06:00,2015-01-01T06:00:00.000+0000,JFK,America/New_York,2015-01-01T11:00:00.000+0000,2015-01-01T09:00:00.000+0000,2015-01-01T07:00:00.000+0000,74486094789,2015-01-01T09:00:00.000+0000,40.6386,-73.7622,3.4,FM-12,America/New_York,5,54,05:54,2015-01-01 05:54,2015-01-01T05:54:00.000+0000,2015-01-01T14:00:00.000+0000,0,1,0,mid_atlantic,0,1,0,0,0,0,0,0,0,0,0,1,2015-01-01,1,60129542185,B6N645JBJFK2015-01-01600,,,,no,0,,0.0,,,,no,0,,0.0,JFK,0,0.12,274,177,0.13,2015-1,0,-1,1,N,77,99999,16000,10208,-17,-133,99999,10199,,77,,16000,10208,-17,-133,,10199,0,,0,0,0,0,0,0,0,0,0,0,0,0,0,0,10208.0,77.0,9351.858930602955,16000.0,-17.0,-133.0,10190.759954493742,10199.0,1968.7180124223603,0.0112327556603591,0.0112457813144883
2015,1,1,1,4,2015-01-01,UA,19977,N509UA,415,12478,1247802,JFK,NY,36,New York,14771,1477101,SFO,CA,6,California,605,605,0.0,0.0,0.0,0,0600-0659,949,945,-4.0,0.0,0.0,0.0,400.0,2586.0,11,,,,,JFK,74486094789,40.639,-73.762,40.63980103,-73.77890015,13.0,1.4287707081898418,SFO,72494023234.0,37.62,-122.365,37.61899948,-122.375,13.0,0.8877548411465989,6,5,06:05,2015-01-01 06:05,2015-01-01T06:05:00.000+0000,JFK,America/New_York,2015-01-01T11:05:00.000+0000,2015-01-01T09:05:00.000+0000,2015-01-01T07:05:00.000+0000,74486094789,2015-01-01T09:00:00.000+0000,40.6386,-73.7622,3.4,FM-12,America/Los_Angeles,6,5,06:05,2015-01-01 06:05,2015-01-01T06:05:00.000+0000,2015-01-01T17:45:00.000+0000,0,1,0,mid_atlantic,0,1,0,0,0,0,0,0,0,0,0,1,2015-01-01,1,60130380498,UAN509UAJFK2015-01-01605,,,,no,0,,0.0,,,,no,0,,0.0,,0,0.0,274,411,0.13,2015-1,0,-1,1,N,77,99999,16000,10208,-17,-133,99999,10199,,77,,16000,10208,-17,-133,,10199,0,,0,0,0,0,0,0,0,0,0,0,0,0,0,0,10208.0,77.0,9351.858930602955,16000.0,-17.0,-133.0,10190.759954493742,10199.0,1968.7180124223603,0.0112327556603591,0.0112457813144883
2015,1,1,1,4,2015-01-01,B6,20409,N559JB,601,12478,1247802,JFK,NY,36,New York,11697,1169703,FLL,FL,12,Florida,605,603,-2.0,0.0,0.0,-1,0600-0659,913,911,-2.0,0.0,0.0,0.0,188.0,1069.0,5,,,,,JFK,74486094789,40.639,-73.762,40.63980103,-73.77890015,13.0,1.4287707081898418,FLL,74783012849.0,26.079,-80.162,26.072599,-80.152702,9.0,1.1700439648776644,6,5,06:05,2015-01-01 06:05,2015-01-01T06:05:00.000+0000,JFK,America/New_York,2015-01-01T11:05:00.000+0000,2015-01-01T09:05:00.000+0000,2015-01-01T07:05:00.000+0000,74486094789,2015-01-01T09:00:00.000+0000,40.6386,-73.7622,3.4,FM-12,America/New_York,6,3,06:03,2015-01-01 06:03,2015-01-01T06:03:00.000+0000,2015-01-01T14:11:00.000+0000,0,1,0,mid_atlantic,0,1,0,0,0,0,0,0,0,0,0,1,2015-01-01,1,52,B6N559JBJFK2015-01-01605,,,,no,0,,0.0,,,,no,0,,0.0,MCO,0,0.12,274,222,0.13,2015-1,0,-1,1,N,77,99999,16000,10208,-17,-133,99999,10199,,77,,16000,10208,-17,-133,,10199,0,,0,0,0,0,0,0,0,0,0,0,0,0,0,0,10208.0,77.0,9351.858930602955,16000.0,-17.0,-133.0,10190.759954493742,10199.0,1968.7180124223603,0.0112327556603591,0.0112457813144883
2015,1,1,1,4,2015-01-01,B6,20409,N763JB,1403,12478,1247802,JFK,NY,36,New York,14843,1484304,SJU,PR,72,Puerto Rico,614,612,-2.0,0.0,0.0,-1,0600-0659,1057,1034,-23.0,0.0,0.0,0.0,202.0,1598.0,7,,,,,JFK,74486094789,40.639,-73.762,40.63980103,-73.77890015,13.0,1.4287707081898418,,,,,,,841.0,,6,14,06:14,2015-01-01 06:14,2015-01-01T06:14:00.000+0000,JFK,America/New_York,2015-01-01T11:14:00.000+0000,2015-01-01T09:14:00.000+0000,2015-01-01T07:14:00.000+0000,74486094789,2015-01-01T09:00:00.000+0000,40.6386,-73.7622,3.4,FM-12,America/Puerto_Rico,6,12,06:12,2015-01-01 06:12,2015-01-01T06:12:00.000+0000,2015-01-01T14:34:00.000+0000,0,1,0,mid_atlantic,0,1,0,0,0,0,0,0,0,0,0,1,2015-01-01,1,180388794339,B6N763JBJFK2015-01-01614,,,,no,0,,0.0,,,,no,0,,0.0,BUR,0,0.12,274,83,0.13,2015-1,0,-1,1,N,77,99999,16000,10208,-17,-133,99999,10199,,77,,16000,10208,-17,-133,,10199,0,,0,0,0,0,0,0,0,0,0,0,0,0,0,0,10208.0,77.0,9351.858930602955,16000.0,-17.0,-133.0,10190.759954493742,10199.0,1968.7180124223603,0.0112327556603591,0.0112457813144883


In [0]:
# Checks the size of the toydata
df_6m_join.filter((df_6m_join.FL_DATE == '2015-01-01')).count()

In [0]:
# Saves few features for model use
to_keep = ['prior_dep_delayed', 'previous_DEP_DELAY_NEW', 'plane_is_here', 'Evening', 'previous_ARR_DELAY_NEW', 'avg_carrier_delay_24hrs', 'prior_arr_delayed', 'avg_ori_airport_delay_24hrs', 'DEP_DEL15' ]
df_6m_join.select(to_keep).filter(df_6m_join.FL_DATE == '2015-01-01').cache()
toy_data = df_6m_join.select(to_keep).filter(df_6m_join.FL_DATE == '2015-01-01')

In [0]:
# Generate 80/20 (pseudo)random train/test split 
train, test = toy_data.randomSplit([.8,.2], seed = 1)

In [0]:
# Saves the true y values for computing metrics later
actual = test.select(test.DEP_DEL15).toPandas()['DEP_DEL15'].values

In [0]:
# Remove the output variable from train and test set
for_model = ['prior_dep_delayed', 'previous_DEP_DELAY_NEW', 'plane_is_here', 'Evening', 'previous_ARR_DELAY_NEW', 'avg_carrier_delay_24hrs', 'prior_arr_delayed', 'avg_ori_airport_delay_24hrs' ]
train = train.select(for_model)
test = test.select(for_model)

In [0]:
# Check set sizes
print(train.count())
print(test.count())

In [0]:
# Transform into RDDs
trainRDD = train.rdd.map(lambda x: (x[0:],x[0])).cache()
testRDD = test.rdd.map(lambda x: (x[0:],x[0])).cache()

### 3. Train Model

In [0]:
# Defined for logistic regression calculation
def sigmoid(x):
    return 1 / (1 + math.exp(-x))

The following cell defines the cost/loss function for logistic regression that we will try to minimize. If our predictions from the logit model are close to the 
actual values, then the cost will be small. If our predictions are off, then the cost will be higher. Also note that our loss function is convex, meaning that the function 
does have a single minimum.

In [0]:
# Logistic loss function
def LogLoss(dataRDD, W): 
    """
    Compute logistic loss error.
    Args:
        dataRDD - each record is a tuple of (features_array, y)
        W       - (array) model coefficients with bias at index 0
    """
    
    augmentedData = dataRDD.map(lambda x: (np.append([1.0], x[0]), x[1]))
    ################## YOUR CODE HERE ##################
    loss = augmentedData.map(lambda x: x[1]*np.log(sigmoid(W.dot(x[0]))) + (1-x[1])*np.log(1 - sigmoid(W.dot(x[0])))).mean()*-1
    ################## (END) YOUR CODE ##################
    return loss

The next cell defines the gradient function. Without going into the math behind it, a gradient tells us what direction the model needs to go to minimize our loss. This will help us update the model's parameters on the next iteration. This process is called gradient descent.  

Logistic regression can be prone to overfitting, meaning that they will not generalize well to other datasets after training. In order to combat overfitting, we are adding a regularization term to the model that penalizes using more variables (L1) or too large of coefficients on our variables (L2).

In [0]:
# Ridge/lasso gradient function
def GDUpdate_wReg(dataRDD, W, learningRate = 0.1, regType='ridge', regParam = 0.1):
    """
    Perform one gradient descent step/update with ridge or lasso regularization.
    Args:
        dataRDD - tuple of (features_array, y)
        W       - (array) model coefficients with bias at index 0
        learningRate - (float) defaults to 0.1
        regType - (str) 'ridge' or 'lasso', defaults to None
        regParam - (float) regularization term coefficient
    Returns:
        model   - (array) updated coefficients, bias still at index 0
    """
    # augmented data
    N=dataRDD.count()
    augmentedData = dataRDD.map(lambda x: (np.append([1.0], x[0]), x[1]))
    
    new_model = None
    #################### YOUR CODE HERE ###################
    grad = augmentedData.map(lambda x: ((sigmoid(W.dot(x[0])) - x[1])*x[0])).sum()
    if regType == 'ridge':
        grad += regParam * np.append([0.0], W[1:])
    elif regType == 'lasso':
        grad += regParam * np.append([0.0], np.sign(W)[1:])
    new_model = W - learningRate * grad/N
    ################## (END) YOUR CODE ####################
    return new_model

The next cell defines a gradient descent function that will take the gradient of a model and use that to update the model weights. It repeats this multiple times in order to find the 'best' weights that minimize the loss function. Since the loss function is convex, we can gaurantee that the minimum found was the global minimum: the model is the optimal model given the problem and our features used to predict it.

In [0]:
# Gradient descent function
def GradientDescent_wReg(trainRDD, testRDD, wInit, nSteps = 100, learningRate = 0.1,
                         regType='ridge', regParam = 0.1, verbose = False):
    """
    Perform nSteps iterations of regularized gradient descent and 
    track loss on a test and train set. Return lists of
    test/train loss and the models themselves.
    """
    # initialize lists to track model performance
    train_history, test_history, model_history = [], [], []
    
    # perform n updates & compute test and train loss after each
    model = wInit
    for idx in range(nSteps):  
        # update the model
        model = GDUpdate_wReg(trainRDD, model, learningRate, regType, regParam)
        
        # keep track of test/train loss for plotting
        train_history.append(LogLoss(trainRDD, model))
        test_history.append(LogLoss(testRDD, model))
        model_history.append(model)
        
        # console output if desired
        if verbose:
            print("----------")
            print(f"STEP: {idx+1}")
            print(f"Model: {[round(w,3) for w in model]}")
    return train_history, test_history, model_history

In [0]:
## Creates logistic model

# Instantiates a baseline model to be improved through gradient descent
wInit = np.random.uniform(0,1,9)

# Computes logstic regression with ridge regularization
ridge_results = GradientDescent_wReg(trainRDD, testRDD, wInit, nSteps = 50, regType='ridge', regParam = 0.1)

In [0]:
# Predict probabilities 
w = ridge_results[2][-1] # final model
augmentedTestData = testRDD.map(lambda x: (np.append([1.0], x[0]), x[1]))
results = augmentedTestData.map(lambda x: (sigmoid(w.dot(x[0])),x[1])).collect() 

In [0]:
# Sets prediction to 1 if probability is greater than or equal to 0.5, and 0 otherwise.
df = pd.DataFrame(results)
df['pred'] = df[0] >= .5
predicted = df['pred'].astype(int)

In [0]:
# Checking the actuals and predicted
print(np.array(predicted))
print(actual)
print(confusion_matrix(actual,np.array(predicted)))

In [0]:
# Save true positive, false positive, true negative, and false negative
tp, fp, tn, fn = confusion_matrix(actual,np.array(predicted)).ravel()
print(tn, fp, fn, tp)

In [0]:
# Calculate metrics
precision = tp/(tp + fp)
recall = tp/(tp + fn)
F1 = 2*(precision*recall)/(precision+recall)

In [0]:
print("Precision: " + str(precision))
print("Recall:" + str(recall))
print("F1:" + str(F1))

### 4. Calculation Checks

In [0]:
#TESTING THE CONSISTANCY
ridge_results = GradientDescent_wReg(trainRDD, testRDD, wInit, nSteps = 50, regType='ridge', regParam = 0.1)
# predict probabilities for homegrown
w = ridge_results[2][-1] # final model
augmentedTestData = testRDD.map(lambda x: (np.append([1.0], x[0]), x[1]))
results = augmentedTestData.map(lambda x: (sigmoid(w.dot(x[0])),x[1])).collect() 
# set prediction to 1 if probability is greater than or equal to 0.5, and 0 otherwise.
df = pd.DataFrame(results)
df['pred'] = df[0] >= .5
predicted = df['pred'].astype(int)
confusion_matrix(actual,np.array(predicted))
tp, fp, tn, fn = confusion_matrix(actual,np.array(predicted)).ravel()
print(tn, fp, fn, tp)
precision = tp/(tp + fp)
recall = tp/(tp + fn)
F1 = 2*(precision*recall)/(precision+recall)
print("Precision: " + str(precision))
print("Recall:" + str(recall))
print("F1:" + str(F1))

In [0]:
#TESTING DIFF NSTEPS
ridge_results = GradientDescent_wReg(trainRDD, testRDD, wInit, nSteps = 10, regType='ridge', regParam = 0.1)
# predict probabilities for homegrown
w = ridge_results[2][-1] # final model
augmentedTestData = testRDD.map(lambda x: (np.append([1.0], x[0]), x[1]))
results = augmentedTestData.map(lambda x: (sigmoid(w.dot(x[0])),x[1])).collect() 
# set prediction to 1 if probability is greater than or equal to 0.5, and 0 otherwise.
df = pd.DataFrame(results)
df['pred'] = df[0] >= .5
predicted = df['pred'].astype(int)
confusion_matrix(actual,np.array(predicted))
tp, fp, tn, fn = confusion_matrix(actual,np.array(predicted)).ravel()
print(tn, fp, fn, tp)
precision = tp/(tp + fp)
recall = tp/(tp + fn)
F1 = 2*(precision*recall)/(precision+recall)
print("Precision: " + str(precision))
print("Recall:" + str(recall))
print("F1:" + str(F1))

In [0]:
#TESTING DIFF NSTEPS
ridge_results = GradientDescent_wReg(trainRDD, testRDD, wInit, nSteps = 100, regType='ridge', regParam = 0.1)
# predict probabilities for homegrown
w = ridge_results[2][-1] # final model
augmentedTestData = testRDD.map(lambda x: (np.append([1.0], x[0]), x[1]))
results = augmentedTestData.map(lambda x: (sigmoid(w.dot(x[0])),x[1])).collect() 
# set prediction to 1 if probability is greater than or equal to 0.5, and 0 otherwise.
df = pd.DataFrame(results)
df['pred'] = df[0] >= .5
predicted = df['pred'].astype(int)
confusion_matrix(actual,np.array(predicted))
tp, fp, tn, fn = confusion_matrix(actual,np.array(predicted)).ravel()
print(tn, fp, fn, tp)
precision = tp/(tp + fp)
recall = tp/(tp + fn)
F1 = 2*(precision*recall)/(precision+recall)
print("Precision: " + str(precision))
print("Recall:" + str(recall))
print("F1:" + str(F1))

### 5. Sklearn Check

The following section will run another logistic regression with our toy_data using the sklearn LogisticRegression function so that we can check that our results are correct, based on the final precision, recall and F1 scores.

In [0]:
toy_data = toy_data.toPandas()
Y = toy_data.DEP_DEL15.copy()
X = toy_data.drop(['DEP_DEL15'], axis =1)


In [0]:
x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=0)
model = LogisticRegression(penalty='l1', max_iter=2000, random_state=0, solver='liblinear')
model.fit(x_train, y_train)
y_pred = pd.Series(model.predict(x_test))
y_test = y_test.reset_index(drop=True)

In [0]:
z = pd.concat([y_test, y_pred], axis=1)
z.columns = ['Actual', 'Prediction']
print("Precision:", metrics.precision_score(y_test, y_pred))
print("Recall:", metrics.recall_score(y_test, y_pred))
print("F1:", metrics.f1_score(y_test, y_pred))

In [0]:
tp, fp, tn, fn = confusion_matrix(y_test, y_pred).ravel()
print(tn, fp, fn, tp)
confusion_matrix(y_test, y_pred)

In [0]:
precision = tp/(tp + fp)
recall = tp/(tp + fn)
F1 = 2*(precision*recall)/(precision+recall)

In [0]:
print("Precision: " + str(precision))
print("Recall:" + str(recall))
print("F1:" + str(F1))

###6. Conclusions
The model we created that used 50 iterations and an L2 regularization term had a precision of 0.9497560975609756, a recall of 0.921875, and an F1 score of 0.9356078808265257l. The high precision indicates that the model did a good job of predicting the data it was given.
The high recall indicates that we had few false negatives. The high F1 score indicates that we had relatively balanced performance between false positives and negatives, without too many of either. 
However, this model does not take into account data leakage and uses future data to predict past delays. Our final logistic regression model makes use of rolling windows for the cross validation to prevent data leakage.

###References

[Loistic Regression](https://blog.exsilio.com/all/accuracy-precision-recall-f1-score-interpretation-of-performance-measures/)

[Logistic regression](https://developers.google.com/machine-learning/crash-course/logistic-regression/model-training)

[Loss Function](https://www.datarobot.com/blog/introduction-to-loss-functions/)