In [1]:
import os
import warnings

import geopy.distance
import holidays
import pandas as pd
from geopy.point import Point
from sklearn.preprocessing import FunctionTransformer

from create_db import create_db
from merge_transformer import MergeTransformer

warnings.simplefilter(action='ignore', category=FutureWarning)

In [2]:
def check_csv_correct(X):
    """Return True when ``X`` contains no missing values.

    Parameters
    ----------
    X : pandas.DataFrame
        Frame to validate before it is written to CSV.

    Returns
    -------
    bool
        ``True`` if no cell of ``X`` is missing (NaN / None / NaT),
        ``False`` as soon as at least one cell is missing. An empty frame
        has no missing cells and therefore returns ``True``.
    """
    # The former ``X.isna().sum().all() != 0`` only returned False when
    # *every* column had at least one missing value, so a single NaN in one
    # column slipped through. Any missing cell must fail the check.
    return not X.isna().to_numpy().any()

In [13]:
# Build the table collection once; merge_dfs looks tables up from it by name
# (e.g. database['AirportStatistics'], database['Airport']).
# NOTE(review): execution counts jump (In [2] -> In [13] -> In [49]), which
# suggests cells were deleted or re-run out of order — confirm the notebook
# still works after "Restart Kernel -> Run All".
database = create_db()

In [49]:
def reduce_X(X, start_date='2011-09-01', end_date='2013-03-05', verbose=True):
    """Keep only rows whose ``DateOfDeparture`` lies in ``[start_date, end_date]``.

    Parameters
    ----------
    X : pandas.DataFrame
        Frame with a ``DateOfDeparture`` column (strings or datetimes).
        It is not modified.
    start_date, end_date : str or datetime-like, optional
        Inclusive bounds of the window. The defaults are the dates
        previously hard-coded as '09/01/2011' and '03/05/2013'
        (month-first), written in unambiguous ISO form.
    verbose : bool, optional
        Print the shape before and after filtering (the original behavior).

    Returns
    -------
    pandas.DataFrame
        A new frame where ``DateOfDeparture`` is parsed to datetime and
        only rows inside the window are kept.
    """
    if verbose:
        print(X.shape)

    # Build a new frame instead of assigning through ``X.loc[:, ...]``. That
    # form mutated the caller's frame, and on recent pandas it tries to set
    # values in place (hence the FutureWarning) rather than replacing the
    # column with a datetime one.
    X = X.assign(DateOfDeparture=pd.to_datetime(X['DateOfDeparture']))

    in_window = X['DateOfDeparture'].between(pd.Timestamp(start_date),
                                             pd.Timestamp(end_date))
    X = X[in_window]

    if verbose:
        print(X.shape)

    return X

In [50]:
def _merge_external_tables(database):
    # Chain the MergeTransformer joins that enrich the AirportStatistics table.
    X = database['AirportStatistics']

    merge_transform = MergeTransformer(X_ext=database['Date'], how='left', on=['DateOfDeparture'])
    X = merge_transform.fit_transform(X)

    merge_transform = MergeTransformer(X_ext=database['Airport'],
                                       cols_to_keep=['AirPort', 'latitude_dep', 'longitude_dep', 'state_dep'],
                                       cols_to_rename={'iata': 'AirPort', 'latitude_deg': 'latitude_dep',
                                                       'longitude_deg': 'longitude_dep',
                                                       'state': 'state_dep'},
                                       how='left', on=['AirPort'])
    X = merge_transform.fit_transform(X)

    merge_transform = MergeTransformer(X_ext=database['StateFeatures'],
                                       cols_to_keep=['State', 'year', 'month', 'UnemploymentRate', 'holidays', 'GDP_per_cap', 'closest_holidays'],
                                       cols_to_rename={'Abbreviation': 'state_dep'}, how='left',
                                       on=['year', 'month', 'day', 'state_dep'])
    X = merge_transform.fit_transform(X)

    merge_transform = MergeTransformer(X_ext=database['Routes'],
                                       cols_to_rename={'Departure': 'AirPort'},
                                       on=['AirPort'])
    X = merge_transform.fit_transform(X)

    merge_transform = MergeTransformer(X_ext=database['Passengers'],
                                       cols_to_rename={'Date': 'DateOfDeparture', 'Airport': 'Arrival'},
                                       on=['Arrival', 'DateOfDeparture'])
    X = merge_transform.fit_transform(X)

    merge_transform = MergeTransformer(X_ext=database['Airport'],
                                       cols_to_keep=['Arrival', 'latitude_arr', 'longitude_arr'],
                                       cols_to_rename={'iata': 'Arrival', 'latitude_deg': 'latitude_arr',
                                                       'longitude_deg': 'longitude_arr'},
                                       how='left', on=['Arrival'])
    X = merge_transform.fit_transform(X)

    return X


def _add_route_distance(X):
    # Return a copy of X with a 'distance' column: geopy distance (km) from departure to arrival airport.
    distances = X.apply(lambda row: geopy.distance.distance(
        Point(latitude=row.latitude_dep, longitude=row.longitude_dep),
        Point(latitude=row.latitude_arr, longitude=row.longitude_arr)).km, axis=1)
    return X.assign(distance=distances)


def merge_dfs(database, submission_dir):
    """Merge the external tables into one feature frame and write it to CSV.

    Starting from ``database['AirportStatistics']``, date features,
    departure-airport coordinates/state, state-level features, routes,
    passenger statistics and arrival-airport coordinates are merged in via
    ``MergeTransformer``. A ``distance`` column (km, via geopy) between
    departure and arrival airports is added. The result is restricted to a
    fixed feature list and to the ``reduce_X`` date window, then written to
    ``<submission_dir>/external_data.csv``.

    Parameters
    ----------
    database
        Table collection returned by ``create_db``, indexed by table name
        ('AirportStatistics', 'Date', 'Airport', 'StateFeatures', 'Routes',
        'Passengers').
    submission_dir : str or path-like
        Directory in which ``external_data.csv`` is written. A trailing
        separator is optional.

    Returns
    -------
    bool
        ``True`` once the CSV has been written.

    Raises
    ------
    ValueError
        If the merged frame contains missing values according to
        ``check_csv_correct``. No CSV is written in that case.
    """
    X = _merge_external_tables(database)

    # Validate before the costly row-wise distance computation. Missing
    # coordinates would otherwise most likely fail inside geopy with a far
    # less helpful error than this one.
    all_check = check_csv_correct(X)
    if not all_check:
        raise ValueError("Merged external data contains missing values; "
                         "refusing to write external_data.csv.")

    X = _add_route_distance(X)

    features_to_keep = ['DateOfDeparture', 'AirPort',
                        'n_days', 'day_nb', 'oil_stock_price', 'oil_stock_volume',
                        'AAL_stock_price', 'AAL_stock_volume', 'SP_stock_price',
                        'SP_stock_volume', 'day_mean', 'week_mean', 'month_mean', 'day_nb_mean', 'route_mean',
                        'Arrival', 'distance', 'Total', 'Flights', 'Booths', 'Mean per flight']
    # Explicit copy so reduce_X operates on an independent frame rather than
    # a column slice (avoids SettingWithCopyWarning ambiguity).
    X = X[features_to_keep].copy()

    X = reduce_X(X)
    # os.path.join works whether or not submission_dir ends with a separator.
    X.to_csv(os.path.join(submission_dir, 'external_data.csv'))

    return all_check

In [51]:
# def create_external_data(submission_dir):
    
#     database = create_db()
#     merge_dfs(database, submission_dir)

In [52]:
# Build external_data.csv for the test submission. The cell's displayed
# value is merge_dfs's return (True on success).
merge_dfs(database=database, submission_dir='../submissions/test/')

(109312, 21)
(70656, 21)


True