Copyright (c) Microsoft Corporation. All rights reserved.

Licensed under the MIT License.

# Tutorial: Load taxi data and enrich it with weather data in a Spark DataFrame

Begin by creating a dataframe to hold the taxi data. To download two months of taxi data, fetch one month at a time, randomly sample 0.1% of that month's records, and then combine the samples into green_taxi_df. Sampling each month this way keeps the dataframe from becoming bloated.

In [4]:
import pandas as pd
from datetime import datetime
from dateutil.relativedelta import relativedelta
from azureml.opendatasets import NycTlcGreen
from functools import reduce  # For Python 3.x
from pyspark.sql import DataFrame


# Fetch the taxi data one month at a time and keep a seeded random sample of
# each month, so the combined dataframe stays small.
# `start`/`end` bound the first month; later months shift both bounds by
# `sample_month` months (relativedelta clamps Jan 31 + 1 month to Feb 29, 2016).
# NOTE(review): `end` is midnight of the last day; if NycTlcGreen filters on
# the full pickup timestamp, trips later on that day are excluded — confirm.
start = datetime.strptime("1/1/2016", "%m/%d/%Y")
end = datetime.strptime("1/31/2016", "%m/%d/%Y")

num_months = 2           # number of consecutive months to fetch
sample_fraction = 0.001  # keep 0.1% of each month's rows
sample_seed = 3          # fixed seed for a repeatable sample

dfs = []
for sample_month in range(num_months):
    temp_df_green = NycTlcGreen(
        start + relativedelta(months=sample_month),
        end + relativedelta(months=sample_month)).to_spark_dataframe()
    dfs.append(temp_df_green.sample(
        withReplacement=False, fraction=sample_fraction, seed=sample_seed))

# `union` is the canonical name; `unionAll` is an older alias with the same
# (position-based, no de-duplication) semantics.
green_taxi_df = reduce(DataFrame.union, dfs)

Save a copy of the raw column names (`raw_columns`) for cleanup in the last step.

In [6]:
# Snapshot the original column names as a separate list before helper
# columns are added, so the final select can restore exactly these columns.
raw_columns = [*green_taxi_df.columns]

NYC Latitude & Longitude: (40.71455, -74.00712) found by Bing search.

Add these coordinates to the taxi dataframe.

Set the latitude and longitude of every row to the location of New York City.

In [9]:
from pyspark.sql.functions import lit

# Pin every trip to the same NYC point via two constant columns.
nyc_lat, nyc_long = 40.71455, -74.00712
green_taxi_df = (
    green_taxi_df
    .withColumn('lat', lit(nyc_lat))
    .withColumn('long', lit(nyc_long))
)
display(green_taxi_df.limit(5))

vendorID,lpepPickupDatetime,lpepDropoffDatetime,passengerCount,tripDistance,puLocationId,doLocationId,pickupLongitude,pickupLatitude,dropoffLongitude,dropoffLatitude,rateCodeID,storeAndFwdFlag,paymentType,fareAmount,extra,mtaTax,improvementSurcharge,tipAmount,tollsAmount,ehailFee,totalAmount,tripType,puYear,puMonth,lat,long
1,2016-01-18T17:14:03.000+0000,2016-01-18T17:26:34.000+0000,1,1.6,,,-73.95868682861328,40.71489334106445,-73.95039367675781,40.69751739501953,1,N,2,9.5,0.0,0.5,0.3,0.0,0.0,,10.3,1,2016,1,40.71455,-74.00712
2,2016-01-18T18:00:42.000+0000,2016-01-18T18:08:53.000+0000,1,1.37,,,-73.95466613769531,40.789363861083984,-73.95391082763672,40.77495574951172,1,N,1,7.5,0.0,0.5,0.3,1.24,0.0,,9.54,1,2016,1,40.71455,-74.00712
2,2016-01-18T18:28:33.000+0000,2016-01-18T18:51:43.000+0000,1,5.52,,,-73.99703216552734,40.68907165527344,-73.9892349243164,40.745548248291016,1,N,1,20.0,0.0,0.5,0.3,5.0,0.0,,25.8,1,2016,1,40.71455,-74.00712
2,2016-01-18T18:50:23.000+0000,2016-01-18T18:56:13.000+0000,2,1.16,,,-73.90315246582031,40.745941162109375,-73.91759490966797,40.744773864746094,1,N,2,6.0,0.0,0.5,0.3,0.0,0.0,,6.8,1,2016,1,40.71455,-74.00712
1,2016-01-18T18:51:36.000+0000,2016-01-18T19:00:36.000+0000,1,1.6,,,-73.98666381835938,40.70247268676758,-73.97904968261719,40.68352127075195,1,N,1,8.0,0.0,0.5,0.3,1.2,0.0,,10.0,1,2016,1,40.71455,-74.00712


Initialize LocationTimeCustomerData using the Spark dataframe green_taxi_df.

In [11]:
from azureml.opendatasets.accessories.location_data import LatLongColumn
from azureml.opendatasets.accessories.location_time_customer_data \
    import LocationTimeCustomerData
from azureml.opendatasets import NoaaIsdWeather


# Describe the taxi data to the enricher: the 'lat'/'long' columns give the
# location and 'lpepPickupDatetime' gives the event time.
taxi_location = LatLongColumn('lat', 'long')
green_taxi = LocationTimeCustomerData(
    green_taxi_df, taxi_location, 'lpepPickupDatetime')

In [12]:
# Permit cartesian-product joins for this session (already the default in
# Spark 3.0+); presumably required by the enricher's matching — TODO confirm.
spark.conf.set(key='spark.sql.crossJoin.enabled', value='true')

Initialize the NoaaIsdWeather class, get an enricher from it, and enrich the taxi data without aggregation.

In [14]:
# Load NOAA ISD weather covering the same two months as the taxi sample.
# 2016 is a leap year: the taxi loop's second window ends at
# end + relativedelta(months=1) == 2016-02-29, so the weather range must run
# through Feb 29 (not Feb 28); otherwise taxi rows from the leap day have no
# weather to match in the left join.
weather = NoaaIsdWeather(
    cols=["temperature", "precipTime", "precipDepth", "snowDepth"],
    start_date=datetime(2016, 1, 1, 0, 0),
    end_date=datetime(2016, 2, 29, 23, 59))
weather_enricher = weather.get_enricher()
# Returns the enriched taxi data plus the processed weather data, matched on
# location and on time rounded to whole days, without aggregating.
new_green_taxi, processed_weather = weather_enricher.enrich_customer_data_no_agg(
    customer_data_object=green_taxi,
    location_match_granularity=5,
    time_round_granularity='day')

Preview the Spark dataframe new_green_taxi.data.

In [16]:
# Peek at the first three enriched taxi rows.
enriched_preview = new_green_taxi.data.limit(3)
display(enriched_preview)

lat,long,vendorID,lpepPickupDatetime,lpepDropoffDatetime,passengerCount,tripDistance,puLocationId,doLocationId,pickupLongitude,pickupLatitude,dropoffLongitude,dropoffLatitude,rateCodeID,storeAndFwdFlag,paymentType,fareAmount,extra,mtaTax,improvementSurcharge,tipAmount,tollsAmount,ehailFee,totalAmount,tripType,puYear,puMonth,row_id,customer_rankgrouprbzmn,customer_join_time1v3cp
40.71455,-74.00712,1,2016-01-18T17:14:03.000+0000,2016-01-18T17:26:34.000+0000,1,1.6,,,-73.95868682861328,40.71489334106445,-73.95039367675781,40.69751739501953,1,N,2,9.5,0.0,0.5,0.3,0.0,0.0,,10.3,1,2016,1,77309411328,1,2016-01-18T00:00:00.000+0000
40.71455,-74.00712,2,2016-01-18T18:00:42.000+0000,2016-01-18T18:08:53.000+0000,1,1.37,,,-73.95466613769531,40.789363861083984,-73.95391082763672,40.77495574951172,1,N,1,7.5,0.0,0.5,0.3,1.24,0.0,,9.54,1,2016,1,77309411329,1,2016-01-18T00:00:00.000+0000
40.71455,-74.00712,2,2016-01-18T18:28:33.000+0000,2016-01-18T18:51:43.000+0000,1,5.52,,,-73.99703216552734,40.68907165527344,-73.9892349243164,40.745548248291016,1,N,1,20.0,0.0,0.5,0.3,5.0,0.0,,25.8,1,2016,1,77309411330,1,2016-01-18T00:00:00.000+0000


Define a dict `aggregations` that specifies how to aggregate each field at the day level, because the enrichment above rounds times to days (`time_round_granularity='day'`). For `snowDepth` and `temperature` we'll take the daily mean, and for `precipTime` and `precipDepth` we'll take the daily maximum. Use the groupby() function along with the aggregations to group the data.

In [18]:
# Per-column aggregation applied to the weather rows of each join key:
# averages for snow depth and temperature, maxima for the precipitation fields.
aggregations = dict(
    snowDepth="mean",
    precipTime="max",
    temperature="mean",
    precipDepth="max",
)

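The keys (`public_rankgroup`, `public_join_time`, `customer_rankgroup`, `customer_join_time`) used by groupby() and the later join() have to be looked up here as a workaround. The current design of the enricher gives these columns generated names with suffixes (for example `customer_join_time1v3cp`).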

In [20]:
def _first_column_with_prefix(columns, prefix):
    """Return the first column name in *columns* that starts with *prefix*.

    The enricher's generated columns carry what appear to be random suffixes
    (e.g. ``customer_join_time1v3cp``), so they can only be found by prefix.
    Non-string entries are skipped.

    Raises:
        ValueError: if no string column name starts with *prefix* (instead of
            the bare IndexError that ``[...][0]`` on an empty list gives).
    """
    match = next(
        (name for name in columns
         if isinstance(name, str) and name.startswith(prefix)),
        None)
    if match is None:
        raise ValueError(
            f"no column starting with {prefix!r}; got {list(columns)!r}")
    return match


public_rankgroup = processed_weather.id

public_join_time = _first_column_with_prefix(
    processed_weather.data.columns, 'ds_join_time')

customer_rankgroup = weather_enricher.location_selector.customer_rankgroup

customer_join_time = _first_column_with_prefix(
    new_green_taxi.data.columns, 'customer_join_time')

# Collapse the weather rows to one row per (public_rankgroup, public_join_time)
# key, applying the per-column functions from `aggregations`.
weather_df_grouped = processed_weather.data.groupby(
    public_rankgroup, public_join_time).agg(aggregations)
display(weather_df_grouped.limit(3))

public_rankgroup1qf8h,ds_join_timevotkg,avg(snowDepth),avg(temperature),max(precipDepth),max(precipTime)
1,2016-01-13T00:00:00.000+0000,0.0,-2.266428571428573,3.0,24.0
1,2016-02-07T00:00:00.000+0000,1.8,3.8400000000000007,0.0,24.0
1,2016-01-21T00:00:00.000+0000,0.0,-0.0722627737226275,0.0,24.0


Join the taxi data with the grouped weather data, and preview the joined result.

In [22]:
taxi_df = new_green_taxi.data

# Match each taxi row to its weather group on both the location rank group
# and the rounded join time; keep taxi rows with no weather match (left join).
join_condition = [
    taxi_df[customer_rankgroup] == weather_df_grouped[public_rankgroup],
    taxi_df[customer_join_time] == weather_df_grouped[public_join_time],
]
joined_dataset = taxi_df.join(weather_df_grouped, on=join_condition, how='left')

# Drop the helper/key columns: keep the original taxi columns plus the
# aggregated weather columns.
weather_columns = [
    "avg(temperature)", "max(precipTime)", "max(precipDepth)", "avg(snowDepth)"]
final_df = joined_dataset.select(raw_columns + weather_columns)
display(final_df.limit(5))

vendorID,lpepPickupDatetime,lpepDropoffDatetime,passengerCount,tripDistance,puLocationId,doLocationId,pickupLongitude,pickupLatitude,dropoffLongitude,dropoffLatitude,rateCodeID,storeAndFwdFlag,paymentType,fareAmount,extra,mtaTax,improvementSurcharge,tipAmount,tollsAmount,ehailFee,totalAmount,tripType,puYear,puMonth,avg(temperature),max(precipTime),max(precipDepth),avg(snowDepth)
2,2016-01-13T00:56:41.000+0000,2016-01-13T01:05:00.000+0000,1,1.85,,,-73.95475006103516,40.687801361083984,-73.95307922363281,40.70832824707031,1,N,1,8.5,0.5,0.5,0.3,1.2,0.0,,11.0,1,2016,1,-2.266428571428573,24.0,3.0,0.0
2,2016-01-13T01:46:31.000+0000,2016-01-13T01:53:55.000+0000,1,1.75,,,-73.8910903930664,40.74677658081055,-73.88246154785156,40.7307014465332,1,N,2,8.0,0.5,0.5,0.3,0.0,0.0,,9.3,1,2016,1,-2.266428571428573,24.0,3.0,0.0
1,2016-01-13T01:49:57.000+0000,2016-01-13T01:56:53.000+0000,2,1.4,,,-73.8910903930664,40.74702453613281,-73.86837768554688,40.75227355957031,1,N,2,7.0,0.5,0.5,0.3,0.0,0.0,,8.3,1,2016,1,-2.266428571428573,24.0,3.0,0.0
1,2016-01-13T01:41:11.000+0000,2016-01-13T02:02:37.000+0000,3,6.1,,,-73.95850372314453,40.719234466552734,-74.001708984375,40.73370361328125,1,N,2,21.0,0.5,0.5,0.3,0.0,0.0,,22.3,1,2016,1,-2.266428571428573,24.0,3.0,0.0
2,2016-01-13T04:17:12.000+0000,2016-01-13T04:25:11.000+0000,1,1.07,,,-73.89081573486328,40.74679565429688,-73.87198638916016,40.74689102172852,1,N,2,6.5,0.5,0.5,0.3,0.0,0.0,,7.8,1,2016,1,-2.266428571428573,24.0,3.0,0.0


Check the join success rate by comparing each weather column's non-null count with the total number of rows.

In [24]:
# Non-null counts per column show how many taxi rows received weather values.
# NOTE(review): toPandas() collects every row to the driver; fine for this
# 0.1% sample, but avoid it on the full dataset.
final_pdf = final_df.toPandas()
final_pdf.info()

In [25]:
# Register the result under the name 'joined_df' so the SQL cell can query it.
final_df.createOrReplaceTempView(name='joined_df')

In [26]:
%sql
-- First five trips picked up on 2016-01-26, in pickup order.
SELECT *
FROM joined_df
WHERE lpepPickupDatetime >= '2016-01-26'
  AND lpepPickupDatetime < '2016-01-27'
ORDER BY lpepPickupDatetime
LIMIT 5

vendorID,lpepPickupDatetime,lpepDropoffDatetime,passengerCount,tripDistance,puLocationId,doLocationId,pickupLongitude,pickupLatitude,dropoffLongitude,dropoffLatitude,rateCodeID,storeAndFwdFlag,paymentType,fareAmount,extra,mtaTax,improvementSurcharge,tipAmount,tollsAmount,ehailFee,totalAmount,tripType,puYear,puMonth,avg(temperature),max(precipTime),max(precipDepth),avg(snowDepth)
2,2016-01-26T00:02:33.000+0000,2016-01-26T00:16:54.000+0000,2,3.27,,,-73.95603942871094,40.71393966674805,-73.90204620361328,40.70497131347656,1,N,2,12.5,0.5,0.5,0.3,0.0,0.0,,13.8,1,2016,1,4.209285714285715,24.0,0.0,40.06896551724138
2,2016-01-26T01:50:11.000+0000,2016-01-26T01:56:21.000+0000,1,2.06,,,-73.90959930419922,40.77006530761719,-73.88544464111328,40.75567626953125,1,N,2,8.0,0.5,0.5,0.3,0.0,0.0,,9.3,1,2016,1,4.209285714285715,24.0,0.0,40.06896551724138
2,2016-01-26T02:45:59.000+0000,2016-01-26T03:04:32.000+0000,1,5.59,,,-73.9578857421875,40.80094146728516,-73.93751525878906,40.84716033935547,1,N,1,19.5,0.5,0.5,0.3,2.0,0.0,,22.8,1,2016,1,4.209285714285715,24.0,0.0,40.06896551724138
2,2016-01-26T07:35:27.000+0000,2016-01-26T08:04:34.000+0000,1,2.54,,,-73.95887756347656,40.6507453918457,-73.97756958007812,40.684326171875,1,N,1,18.5,0.0,0.5,0.3,0.0,0.0,,19.3,1,2016,1,4.209285714285715,24.0,0.0,40.06896551724138
2,2016-01-26T08:40:46.000+0000,2016-01-26T09:24:24.000+0000,2,6.21,,,-73.93299102783203,40.67950820922852,-74.00007629394531,40.73252868652344,1,N,1,29.5,0.0,0.5,0.3,6.06,0.0,,36.36,1,2016,1,4.209285714285715,24.0,0.0,40.06896551724138
