In [1]:
%matplotlib inline
import os
import importlib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import problem
from sklearn.preprocessing import FunctionTransformer
pd.set_option('display.max_columns', None)

In [2]:
data = pd.read_csv(
    os.path.join('data', 'train.csv.bz2')
)
data.loc[:, 'DateOfDeparture'] = pd.to_datetime(data.loc[:, 'DateOfDeparture'])

In [3]:
data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 8902 entries, 0 to 8901
Data columns (total 6 columns):
 #   Column            Non-Null Count  Dtype         
---  ------            --------------  -----         
 0   DateOfDeparture   8902 non-null   datetime64[ns]
 1   Departure         8902 non-null   object        
 2   Arrival           8902 non-null   object        
 3   WeeksToDeparture  8902 non-null   float64       
 4   log_PAX           8902 non-null   float64       
 5   std_wtd           8902 non-null   float64       
dtypes: datetime64[ns](1), float64(3), object(2)
memory usage: 417.4+ KB


In [4]:
X, y = problem.get_train_data()

In [5]:
from sklearn.preprocessing import FunctionTransformer

def _encode_dates(X):
    # With pandas < 1.0, we wil get a SettingWithCopyWarning
    # In our case, we will avoid this warning by triggering a copy
    # More information can be found at:
    # https://github.com/scikit-learn/scikit-learn/issues/16191
    X_encoded = X.copy()

    # Make sure that DateOfDeparture is of datetime format
    X_encoded.loc[:, 'DateOfDeparture'] = pd.to_datetime(X_encoded['DateOfDeparture'])
    # Encode the DateOfDeparture
    X_encoded.loc[:, 'year'] = X_encoded['DateOfDeparture'].dt.year
    X_encoded.loc[:, 'month'] = X_encoded['DateOfDeparture'].dt.month
    X_encoded.loc[:, 'day'] = X_encoded['DateOfDeparture'].dt.day
    X_encoded.loc[:, 'weekday'] = X_encoded['DateOfDeparture'].dt.weekday
    X_encoded.loc[:, 'week'] = X_encoded['DateOfDeparture'].dt.week
    X_encoded.loc[:, 'n_days'] = X_encoded['DateOfDeparture'].apply(
        lambda date: (date - pd.to_datetime("1970-01-01")).days
    )
    # Once we did the encoding, we will not need DateOfDeparture
#     return X_encoded.drop(columns=["DateOfDeparture"])
    return X_encoded

date_encoder = FunctionTransformer(_encode_dates)
X = date_encoder.fit_transform(X)

  X_encoded.loc[:, 'week'] = X_encoded['DateOfDeparture'].dt.week


In [6]:
__file__ = os.path.join('submissions', 'starting_kit', 'estimator.py')
filepath = os.path.join(os.path.dirname(__file__), 'external_data.csv')
filepath

'submissions/starting_kit/external_data.csv'

In [7]:
class MergeTransformer():
    """Custom scaling transformer"""
    
    def read_csv_ramp(self, parse_dates=["Date"]):
        self.filepath = os.path.join(
            self.filepath, self.filename
        )
        
        data = pd.read_csv(os.path.join('data', 'train.csv.bz2'))
        if parse_dates is not None:
            ext_data = pd.read_csv(self.filepath, parse_dates=parse_dates)
        else:
            ext_data = pd.read_csv(self.filepath)
        return ext_data
    
    def merge_external_data(self):

        X = self.X.copy()  # to avoid raising SettingOnCopyWarning
        # Make sure that DateOfDeparture is of dtype datetime
        X.loc[:, "DateOfDeparture"] = pd.to_datetime(X['DateOfDeparture'])

        if not(self.filename is None):
            self.X_ext = self.read_csv_ramp(parse_dates=self.parse_dates)

        if self.cols_to_keep != 'all':
            self.X_ext = self.X_ext[self.cols_to_keep]

        if self.cols_to_rename != None:
            self.X_ext = self.X_ext.rename(columns=self.cols_to_rename)

        X_merged = pd.merge(
            X, self.X_ext, how=self.how, on=self.on, sort=False
        )
        return X_merged

    
    def __init__(self, X_ext=None, filename=None, filepath='submissions/starting_kit/', cols_to_keep='all', cols_to_rename=None, how='left', on=None, parse_dates=None):
#         super().__init__(func)
        self.X_ext = X_ext
        self.filename = filename
        self.filepath = filepath
        self.cols_to_keep = cols_to_keep
        self.cols_to_rename = cols_to_rename
        self.how = how
        self.on = on
        self.parse_dates = parse_dates
        
    def fit_transform(self, X):
        self.fit(X)
        return self.transform()

    def fit(self, X):
        self.X = X

    def transform(self):
        return self.merge_external_data()

In [8]:
merge_transform = MergeTransformer(
    X_ext=None, 
    filename='external_data.csv',
    filepath='submissions/starting_kit/',
    cols_to_rename={'Date': 'DateOfDeparture', 'AirPort': 'Arrival'}, 
    how='left',
    on=['DateOfDeparture', 'Arrival'],
    parse_dates=['Date'])

X = merge_transform.fit_transform(X)

In [9]:
coordinates_data = pd.read_csv('data/list-of-airports-in-united-states-of-america-hxl-tags-1.csv', index_col=0)
coordinates_data[coordinates_data.loc[:, 'iata_code'] == 'ORD']

coordinates_data.loc[:, 'iso_region'] = coordinates_data.loc[:, 'iso_region'].str.strip('US-')
# print(coordinates_data.head())

merge_transform = MergeTransformer(
    X_ext=coordinates_data, 
    filename=None,
    filepath=None,
    cols_to_keep=['latitude_deg', 'longitude_deg', 'iata_code', 'iso_region'], 
    cols_to_rename={'iata_code': 'Departure',
                    'latitude_deg': 'latitude_departure',
                    'longitude_deg': 'longitude_departure',
                    'iso_region': 'state'}, 
    how='left',
    on=['Departure'],
    parse_dates=None)

X = merge_transform.fit_transform(X)

merge_transform = MergeTransformer(
    X_ext=coordinates_data, 
    filename=None,
    filepath=None,
    cols_to_keep=['latitude_deg', 'longitude_deg', 'iata_code'], 
    cols_to_rename={'iata_code': 'Arrival', 'latitude_deg': 'latitude_arrival', 'longitude_deg': 'longitude_arrival'}, 
    how='left',
    on=['Arrival'],
    parse_dates=None)

X = merge_transform.fit_transform(X)

import geopy.distance

X['distance'] = X.apply(lambda x: geopy.distance.geodesic(
    (x.latitude_departure, x.longitude_departure), 
    (x.latitude_arrival, x.longitude_arrival)).km, axis=1)
X

Unnamed: 0,DateOfDeparture,Departure,Arrival,WeeksToDeparture,std_wtd,year,month,day,weekday,week,n_days,Max TemperatureC,Mean TemperatureC,Min TemperatureC,Dew PointC,MeanDew PointC,Min DewpointC,Max Humidity,Mean Humidity,Min Humidity,Max Sea Level PressurehPa,Mean Sea Level PressurehPa,Min Sea Level PressurehPa,Max VisibilityKm,Mean VisibilityKm,Min VisibilitykM,Max Wind SpeedKm/h,Mean Wind SpeedKm/h,Max Gust SpeedKm/h,Precipitationmm,CloudCover,Events,WindDirDegrees,latitude_departure,longitude_departure,state,latitude_arrival,longitude_arrival,distance
0,2012-06-19,ORD,DFW,12.875000,9.812647,2012,6,19,1,25,15510,34,29,24,22,21,19,82,63,44,1012,1010,1009,16,16,16,48,29,60.0,0.00,5,,161,41.97859955,-87.90480042,IL,32.89680099487305,-97.03800201416016,1290.346797
1,2012-09-10,LAS,DEN,14.285714,9.466734,2012,9,10,0,37,15593,33,25,16,-2,-6,-8,21,14,7,1011,1008,1005,16,16,16,35,15,42.0,0.00,3,,207,36.08010101,-115.1520004,NV,39.861698150635,-104.672996521,1011.046677
2,2012-10-05,DEN,LAX,10.863636,9.035883,2012,10,5,4,40,15618,22,19,16,17,16,14,93,77,61,1018,1016,1014,16,13,8,24,8,29.0,0.00,5,Fog,266,39.861698150635,-104.672996521,CO,33.94250107,-118.4079971,1387.023784
3,2011-10-09,ATL,ORD,11.480000,7.990202,2011,10,9,6,40,15256,27,19,11,12,10,9,83,58,33,1028,1026,1024,16,16,16,23,6,29.0,0.00,1,,93,33.63669967651367,-84.4281005859375,GA,41.97859955,-87.90480042,974.957144
4,2012-02-21,DEN,SFO,11.450000,9.517159,2012,2,21,1,8,15391,16,12,8,10,8,7,93,79,64,1027,1025,1024,16,12,3,24,8,29.0,0.00,7,,300,39.861698150635,-104.672996521,CO,37.61899948120117,-122.375,1556.391964
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9616,2012-09-25,DFW,ORD,12.772727,10.641034,2012,9,25,1,39,15608,25,17,9,17,11,4,93,68,43,1012,1011,1009,16,16,16,27,15,34.0,0.00,3,,216,32.89680099487305,-97.03800201416016,TX,41.97859955,-87.90480042,1290.346797
9617,2012-01-19,SFO,LAS,11.047619,7.908705,2012,1,19,3,3,15358,13,8,3,-7,-9,-12,40,31,22,1021,1016,1013,16,16,16,16,8,19.0,0.00,6,,197,37.61899948120117,-122.375,CA,36.08010101,-115.1520004,666.249783
9618,2013-02-03,ORD,PHL,6.076923,4.030334,2013,2,3,6,5,15739,1,-1,-3,-3,-6,-9,92,72,51,1018,1013,1010,16,13,2,40,9,58.0,0.25,7,Snow,296,41.97859955,-87.90480042,IL,39.87189865112305,-75.24109649658203,1090.917547
9619,2013-02-03,ORD,PHL,6.076923,4.030334,2013,2,3,6,5,15739,1,-1,-3,-3,-6,-9,92,72,51,1018,1013,1010,16,13,2,40,9,58.0,0.25,7,Snow,296,41.97859955,-87.90480042,IL,0,0,9837.635043


In [10]:
merge_transform = MergeTransformer(
    X_ext=None, 
    filename='oil_price.csv',
    filepath='data/',
    cols_to_keep=['date', 'value'], 
    cols_to_rename={'date': 'DateOfDeparture', 'value': 'OilPrice'},
    how='left',
    on=['DateOfDeparture'],
    parse_dates=['date'])

X = merge_transform.fit_transform(X)

In [11]:
import holidays

us_holidays = holidays.US(years=2011, state='CA')
for key, value in us_holidays.items():
    print(f"key = {key}, value = {value}")

key = 2011-01-01, value = New Year's Day
key = 2010-12-31, value = New Year's Day (Observed)
key = 2011-01-17, value = Martin Luther King Jr. Day
key = 2011-02-21, value = Washington's Birthday
key = 2011-03-31, value = César Chávez Day
key = 2011-05-30, value = Memorial Day
key = 2011-07-04, value = Independence Day
key = 2011-09-05, value = Labor Day
key = 2011-10-10, value = Columbus Day
key = 2011-11-11, value = Veterans Day
key = 2011-11-24, value = Thanksgiving
key = 2011-12-25, value = Christmas Day
key = 2011-12-26, value = Christmas Day (Observed)


In [12]:
X.head()

Unnamed: 0,DateOfDeparture,Departure,Arrival,WeeksToDeparture,std_wtd,year,month,day,weekday,week,n_days,Max TemperatureC,Mean TemperatureC,Min TemperatureC,Dew PointC,MeanDew PointC,Min DewpointC,Max Humidity,Mean Humidity,Min Humidity,Max Sea Level PressurehPa,Mean Sea Level PressurehPa,Min Sea Level PressurehPa,Max VisibilityKm,Mean VisibilityKm,Min VisibilitykM,Max Wind SpeedKm/h,Mean Wind SpeedKm/h,Max Gust SpeedKm/h,Precipitationmm,CloudCover,Events,WindDirDegrees,latitude_departure,longitude_departure,state,latitude_arrival,longitude_arrival,distance,OilPrice
0,2012-06-19,ORD,DFW,12.875,9.812647,2012,6,19,1,25,15510,34,29,24,22,21,19,82,63,44,1012,1010,1009,16,16,16,48,29,60.0,0.0,5,,161,41.97859955,-87.90480042,IL,32.89680099487305,-97.03800201416016,1290.346797,84.222
1,2012-09-10,LAS,DEN,14.285714,9.466734,2012,9,10,0,37,15593,33,25,16,-2,-6,-8,21,14,7,1011,1008,1005,16,16,16,35,15,42.0,0.0,3,,207,36.08010101,-115.1520004,NV,39.861698150635,-104.672996521,1011.046677,92.39
2,2012-10-05,DEN,LAX,10.863636,9.035883,2012,10,5,4,40,15618,22,19,16,17,16,14,93,77,61,1018,1016,1014,16,13,8,24,8,29.0,0.0,5,Fog,266,39.861698150635,-104.672996521,CO,33.94250107,-118.4079971,1387.023784,97.08
3,2011-10-09,ATL,ORD,11.48,7.990202,2011,10,9,6,40,15256,27,19,11,12,10,9,83,58,33,1028,1026,1024,16,16,16,23,6,29.0,0.0,1,,93,33.63669967651367,-84.4281005859375,GA,41.97859955,-87.90480042,974.957144,
4,2012-02-21,DEN,SFO,11.45,9.517159,2012,2,21,1,8,15391,16,12,8,10,8,7,93,79,64,1027,1025,1024,16,12,3,24,8,29.0,0.0,7,,300,39.861698150635,-104.672996521,CO,37.61899948120117,-122.375,1556.391964,106.168


In [13]:
states = X.loc[:, 'state'].unique()
years = [2011, 2012, 2013]

X['bank_holidays'] = X.apply(lambda x: x.DateOfDeparture in holidays.US(years = x.year, state=x.state), axis=1)

In [14]:
school_holidays = pd.read_csv('data/holidays.csv', sep=';', parse_dates=['date'])

In [15]:
merge_transform = MergeTransformer(
    X_ext=school_holidays, 
    filename=None,
    filepath=None,
    cols_to_keep=['date', 'is_vacation'], 
    cols_to_rename={'date': 'DateOfDeparture', 'is_vacation': 'school_holidays'},
    how='left',
    on=['DateOfDeparture'],
    parse_dates=None)

X = merge_transform.fit_transform(X)

In [16]:
X.loc[:, 'holidays'] = X.loc[:, 'bank_holidays'] | X.loc[:, 'school_holidays']
X.drop(['bank_holidays', 'school_holidays'], inplace=True, axis=1)

In [17]:
X.head()

Unnamed: 0,DateOfDeparture,Departure,Arrival,WeeksToDeparture,std_wtd,year,month,day,weekday,week,n_days,Max TemperatureC,Mean TemperatureC,Min TemperatureC,Dew PointC,MeanDew PointC,Min DewpointC,Max Humidity,Mean Humidity,Min Humidity,Max Sea Level PressurehPa,Mean Sea Level PressurehPa,Min Sea Level PressurehPa,Max VisibilityKm,Mean VisibilityKm,Min VisibilitykM,Max Wind SpeedKm/h,Mean Wind SpeedKm/h,Max Gust SpeedKm/h,Precipitationmm,CloudCover,Events,WindDirDegrees,latitude_departure,longitude_departure,state,latitude_arrival,longitude_arrival,distance,OilPrice,holidays
0,2012-06-19,ORD,DFW,12.875,9.812647,2012,6,19,1,25,15510,34,29,24,22,21,19,82,63,44,1012,1010,1009,16,16,16,48,29,60.0,0.0,5,,161,41.97859955,-87.90480042,IL,32.89680099487305,-97.03800201416016,1290.346797,84.222,True
1,2012-09-10,LAS,DEN,14.285714,9.466734,2012,9,10,0,37,15593,33,25,16,-2,-6,-8,21,14,7,1011,1008,1005,16,16,16,35,15,42.0,0.0,3,,207,36.08010101,-115.1520004,NV,39.861698150635,-104.672996521,1011.046677,92.39,False
2,2012-10-05,DEN,LAX,10.863636,9.035883,2012,10,5,4,40,15618,22,19,16,17,16,14,93,77,61,1018,1016,1014,16,13,8,24,8,29.0,0.0,5,Fog,266,39.861698150635,-104.672996521,CO,33.94250107,-118.4079971,1387.023784,97.08,False
3,2011-10-09,ATL,ORD,11.48,7.990202,2011,10,9,6,40,15256,27,19,11,12,10,9,83,58,33,1028,1026,1024,16,16,16,23,6,29.0,0.0,1,,93,33.63669967651367,-84.4281005859375,GA,41.97859955,-87.90480042,974.957144,,False
4,2012-02-21,DEN,SFO,11.45,9.517159,2012,2,21,1,8,15391,16,12,8,10,8,7,93,79,64,1027,1025,1024,16,12,3,24,8,29.0,0.0,7,,300,39.861698150635,-104.672996521,CO,37.61899948120117,-122.375,1556.391964,106.168,True


In [18]:
airports_rank = pd.read_csv('data/airports_passengers.csv', sep=';', encoding = "utf-8")
airports_rank.head()

Unnamed: 0,Rank,Airports (large hubs),IATA,Major city served,State,2019,2018,2017,2016
0,1,Hartsfield-Jackson Atlanta International Airport,ATL,Atlanta,GA,,51866464,50251964,50501858
1,2,Los Angeles International Airport,LAX,Los Angeles,CA,,42626783,41232432,39636042
2,3,O'Hare International Airport,ORD,Chicago,IL,,39874879,38593028,37589899
3,4,Dallas/Fort Worth International Airport,DFW,Dallas,TX,,32800721,31816933,31283579
4,5,Denver International Airport,DEN,Denver,CO,,31363573,29809097,28267394


In [19]:
merge_transform = MergeTransformer(
    X_ext=airports_rank, 
    filename=None,
    filepath=None,
    cols_to_keep=['2016', 'IATA', 'State'], 
    cols_to_rename={'IATA': 'Departure', '2016': 'airport_departure_capacity'},
    how='left',
    on=['Departure'],
    parse_dates=None)

X = merge_transform.fit_transform(X)

merge_transform = MergeTransformer(
    X_ext=airports_rank, 
    filename=None,
    filepath=None,
    cols_to_keep=['2016', 'IATA'], 
    cols_to_rename={'IATA': 'Arrival', '2016': 'airport_arrival_capacity'},
    how='left',
    on=['Arrival'],
    parse_dates=None)

X = merge_transform.fit_transform(X)

In [20]:
X

Unnamed: 0,DateOfDeparture,Departure,Arrival,WeeksToDeparture,std_wtd,year,month,day,weekday,week,n_days,Max TemperatureC,Mean TemperatureC,Min TemperatureC,Dew PointC,MeanDew PointC,Min DewpointC,Max Humidity,Mean Humidity,Min Humidity,Max Sea Level PressurehPa,Mean Sea Level PressurehPa,Min Sea Level PressurehPa,Max VisibilityKm,Mean VisibilityKm,Min VisibilitykM,Max Wind SpeedKm/h,Mean Wind SpeedKm/h,Max Gust SpeedKm/h,Precipitationmm,CloudCover,Events,WindDirDegrees,latitude_departure,longitude_departure,state,latitude_arrival,longitude_arrival,distance,OilPrice,holidays,airport_departure_capacity,State,airport_arrival_capacity
0,2012-06-19,ORD,DFW,12.875000,9.812647,2012,6,19,1,25,15510,34,29,24,22,21,19,82,63,44,1012,1010,1009,16,16,16,48,29,60.0,0.00,5,,161,41.97859955,-87.90480042,IL,32.89680099487305,-97.03800201416016,1290.346797,84.222,True,37589899,IL,31283579
1,2012-09-10,LAS,DEN,14.285714,9.466734,2012,9,10,0,37,15593,33,25,16,-2,-6,-8,21,14,7,1011,1008,1005,16,16,16,35,15,42.0,0.00,3,,207,36.08010101,-115.1520004,NV,39.861698150635,-104.672996521,1011.046677,92.390,False,22833267,NV,28267394
2,2012-10-05,DEN,LAX,10.863636,9.035883,2012,10,5,4,40,15618,22,19,16,17,16,14,93,77,61,1018,1016,1014,16,13,8,24,8,29.0,0.00,5,Fog,266,39.861698150635,-104.672996521,CO,33.94250107,-118.4079971,1387.023784,97.080,False,28267394,CO,39636042
3,2011-10-09,ATL,ORD,11.480000,7.990202,2011,10,9,6,40,15256,27,19,11,12,10,9,83,58,33,1028,1026,1024,16,16,16,23,6,29.0,0.00,1,,93,33.63669967651367,-84.4281005859375,GA,41.97859955,-87.90480042,974.957144,,False,50501858,GA,37589899
4,2012-02-21,DEN,SFO,11.450000,9.517159,2012,2,21,1,8,15391,16,12,8,10,8,7,93,79,64,1027,1025,1024,16,12,3,24,8,29.0,0.00,7,,300,39.861698150635,-104.672996521,CO,37.61899948120117,-122.375,1556.391964,106.168,True,28267394,CO,25707101
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9616,2012-09-25,DFW,ORD,12.772727,10.641034,2012,9,25,1,39,15608,25,17,9,17,11,4,93,68,43,1012,1011,1009,16,16,16,27,15,34.0,0.00,3,,216,32.89680099487305,-97.03800201416016,TX,41.97859955,-87.90480042,1290.346797,91.370,False,31283579,TX,37589899
9617,2012-01-19,SFO,LAS,11.047619,7.908705,2012,1,19,3,3,15358,13,8,3,-7,-9,-12,40,31,22,1021,1016,1013,16,16,16,16,8,19.0,0.00,6,,197,37.61899948120117,-122.375,CA,36.08010101,-115.1520004,666.249783,100.480,False,25707101,CA,22833267
9618,2013-02-03,ORD,PHL,6.076923,4.030334,2013,2,3,6,5,15739,1,-1,-3,-3,-6,-9,92,72,51,1018,1013,1010,16,13,2,40,9,58.0,0.25,7,Snow,296,41.97859955,-87.90480042,IL,39.87189865112305,-75.24109649658203,1090.917547,,False,37589899,IL,14564419
9619,2013-02-03,ORD,PHL,6.076923,4.030334,2013,2,3,6,5,15739,1,-1,-3,-3,-6,-9,92,72,51,1018,1013,1010,16,13,2,40,9,58.0,0.25,7,Snow,296,41.97859955,-87.90480042,IL,0,0,9837.635043,,False,37589899,IL,14564419
