# Working with RDDs

In [38]:
import pyspark
from collections import namedtuple
from datetime import datetime
import pandas as pd
from pyspark.sql import SparkSession, types

In [2]:
spark = SparkSession.builder \
    .master("local[*]") \
    .appName("NYTaxi") \
    .getOrCreate()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/10/28 22:28:02 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
df_green = spark.read.parquet('data/pq/green/*/*')

                                                                                

We want to implement the following, but this time with RDDs instead of the convenient API that DataFrames provide.

``` sql
SELECT 
    date_trunc('hour', lpep_pickup_datetime) AS hour, 
    PULocationID AS zone,

    SUM(total_amount) AS amount,
    COUNT(1) AS number_records
FROM
    green
WHERE
    lpep_pickup_datetime >= '2020-01-01 00:00:00'
GROUP BY
    1, 2
```

In [4]:
df_green.rdd

MapPartitionsRDD[7] at javaToPython at NativeMethodAccessorImpl.java:0

In [5]:
df_green.rdd.take(5)

                                                                                

[Row(VendorID=2, lpep_pickup_datetime=datetime.datetime(2020, 1, 23, 13, 10, 15), lpep_dropoff_datetime=datetime.datetime(2020, 1, 23, 13, 38, 16), store_and_fwd_flag='N', RatecodeID=1, PULocationID=74, DOLocationID=130, passenger_count=1, trip_distance=12.77, fare_amount=36.0, extra=0.0, mta_tax=0.5, tip_amount=2.05, tolls_amount=6.12, ehail_fee=None, improvement_surcharge=0.3, total_amount=44.97, payment_type=1, trip_type=1, congestion_surcharge=0.0),
 Row(VendorID=None, lpep_pickup_datetime=datetime.datetime(2020, 1, 20, 15, 9), lpep_dropoff_datetime=datetime.datetime(2020, 1, 20, 15, 46), store_and_fwd_flag=None, RatecodeID=None, PULocationID=67, DOLocationID=39, passenger_count=None, trip_distance=8.0, fare_amount=29.9, extra=2.75, mta_tax=0.5, tip_amount=0.0, tolls_amount=0.0, ehail_fee=None, improvement_surcharge=0.3, total_amount=33.45, payment_type=None, trip_type=None, congestion_surcharge=None),
 Row(VendorID=2, lpep_pickup_datetime=datetime.datetime(2020, 1, 15, 20, 23, 41)

In [6]:
# Select only the columns we need
rdd = df_green \
    .select("lpep_pickup_datetime", "PULocationID", "total_amount") \
    .rdd

In [7]:
rdd.take(5)

[Row(lpep_pickup_datetime=datetime.datetime(2020, 1, 23, 13, 10, 15), PULocationID=74, total_amount=44.97),
 Row(lpep_pickup_datetime=datetime.datetime(2020, 1, 20, 15, 9), PULocationID=67, total_amount=33.45),
 Row(lpep_pickup_datetime=datetime.datetime(2020, 1, 15, 20, 23, 41), PULocationID=260, total_amount=8.3),
 Row(lpep_pickup_datetime=datetime.datetime(2020, 1, 5, 16, 32, 26), PULocationID=82, total_amount=8.3),
 Row(lpep_pickup_datetime=datetime.datetime(2020, 1, 29, 19, 22, 42), PULocationID=166, total_amount=12.74)]

## Implementing the `WHERE` Clause with `filter()`

In [8]:
start_date = datetime(year=2020, month=1, day=1)

def filter_date_outliers(row):
    return row.lpep_pickup_datetime >= start_date

rdd \
    .filter(filter_date_outliers)
    # Or just
    # .filter(lambda row: row.lpep_pickup_datetime >= start_date)

PythonRDD[16] at RDD at PythonRDD.scala:53

## Preparing for Grouping with `map()` and Implementing the `GROUP BY` Clause with `reduce()`

In [9]:
sample_row = rdd.take(1)[0]
sample_row

Row(lpep_pickup_datetime=datetime.datetime(2020, 1, 23, 13, 10, 15), PULocationID=74, total_amount=44.97)

In [10]:
# We only want the hour (see `date_trunc` in the SQL query)
sample_row.lpep_pickup_datetime.replace(minute=0, second=0, microsecond=0)

datetime.datetime(2020, 1, 23, 13, 0)

In [11]:
def prepare_for_grouping(row):
    # (hour, zone) comprises our composite key 
    hour = row.lpep_pickup_datetime.replace(minute=0, second=0, microsecond=0)
    zone = row.PULocationID
    
    # (amount, count) makes up our value
    amount = row.total_amount
    count = 1
    
    key = (hour, zone)
    value = (amount, count)
    
    return (key, value)

In [12]:
rdd \
    .filter(filter_date_outliers) \
    .map(prepare_for_grouping) \
    .take(5)

[((datetime.datetime(2020, 1, 23, 13, 0), 74), (44.97, 1)),
 ((datetime.datetime(2020, 1, 20, 15, 0), 67), (33.45, 1)),
 ((datetime.datetime(2020, 1, 15, 20, 0), 260), (8.3, 1)),
 ((datetime.datetime(2020, 1, 5, 16, 0), 82), (8.3, 1)),
 ((datetime.datetime(2020, 1, 29, 19, 0), 166), (12.74, 1))]

In [13]:
def custom_calc(left_value, right_value):
    left_amount, left_count = left_value
    right_amount, right_count = right_value
    
    output_amount = left_amount + right_amount
    output_count = left_count + right_count
    
    return (output_amount, output_count)

In [14]:
rdd \
    .filter(filter_date_outliers) \
    .map(prepare_for_grouping) \
    .reduceByKey(custom_calc) \
    .take(5)

                                                                                

[((datetime.datetime(2020, 1, 15, 20, 0), 260), (163.90000000000003, 14)),
 ((datetime.datetime(2020, 1, 29, 19, 0), 166), (695.0099999999999, 45)),
 ((datetime.datetime(2020, 1, 16, 8, 0), 41), (736.1399999999996, 54)),
 ((datetime.datetime(2020, 1, 4, 20, 0), 129), (583.27, 38)),
 ((datetime.datetime(2020, 1, 2, 8, 0), 66), (197.69, 10))]

That took considerably longer, because it had to go through and aggregate the entire dataset! Now let's just format the output (the nested tuples aren't very nice).

In [15]:
def unwrap(row):
    return (row[0][0], row[0][1], row[1][0], row[1][1])

In [16]:
rdd \
    .filter(filter_date_outliers) \
    .map(prepare_for_grouping) \
    .reduceByKey(custom_calc) \
    .map(unwrap) \
    .take(5)

                                                                                

[(datetime.datetime(2020, 1, 15, 20, 0), 260, 163.90000000000003, 14),
 (datetime.datetime(2020, 1, 29, 19, 0), 166, 695.0099999999999, 45),
 (datetime.datetime(2020, 1, 16, 8, 0), 41, 736.1399999999996, 54),
 (datetime.datetime(2020, 1, 4, 20, 0), 129, 583.27, 38),
 (datetime.datetime(2020, 1, 2, 8, 0), 66, 197.69, 10)]

And finally, to turn it into a DataFrame:

In [17]:
rdd \
    .filter(filter_date_outliers) \
    .map(prepare_for_grouping) \
    .reduceByKey(custom_calc) \
    .map(unwrap) \
    .toDF() \
    .show()

                                                                                

+-------------------+---+------------------+---+
|                 _1| _2|                _3| _4|
+-------------------+---+------------------+---+
|2020-01-15 20:00:00|260|163.90000000000003| 14|
|2020-01-29 19:00:00|166| 695.0099999999999| 45|
|2020-01-16 08:00:00| 41| 736.1399999999996| 54|
|2020-01-04 20:00:00|129|            583.27| 38|
|2020-01-02 08:00:00| 66|            197.69| 10|
|2020-01-03 09:00:00| 61|            142.21|  9|
|2020-01-17 21:00:00|236|              33.6|  4|
|2020-01-12 12:00:00| 82|            290.41| 14|
|2020-01-28 16:00:00|197| 831.4399999999998| 18|
|2020-01-10 22:00:00| 95| 407.7100000000002| 37|
|2020-01-10 01:00:00|215|            109.69|  2|
|2020-01-07 18:00:00| 25| 554.2900000000001| 37|
|2020-01-18 07:00:00| 55|              48.3|  1|
|2020-01-28 09:00:00|166| 473.0200000000002| 36|
|2020-01-12 15:00:00| 82| 265.7900000000001| 29|
|2020-01-10 20:00:00| 66|            405.88| 21|
|2020-01-31 15:00:00| 43|345.58000000000004| 19|
|2020-01-31 21:00:00

Oh no! Not surprisingly, we have lost our column names and schema! Fret not, as we will restore them with named tuples 😎

In [18]:
ProcessedRow = namedtuple("ProcessedRow", ["hour", "zone", "revenue", "count"])

In [19]:
def unwrap_and_add_col_names(row):
    return ProcessedRow(
        hour=row[0][0],
        zone=row[0][1],
        revenue=row[1][0],
        count=row[1][1]
    )

In [20]:
df_result = rdd \
    .filter(filter_date_outliers) \
    .map(prepare_for_grouping) \
    .reduceByKey(custom_calc) \
    .map(unwrap_and_add_col_names) \
    .toDF()

                                                                                

We didn't do `.show()` yet Spark seems to be processing the entire dataset. Why? This time `.toDF()` is the culprit. Since we have not specified a schema, Spark has to go through every single row to infer the schema.

In [21]:
df_result.schema

StructType([StructField('hour', TimestampType(), True), StructField('zone', LongType(), True), StructField('revenue', DoubleType(), True), StructField('count', LongType(), True)])

Let's provide the schema...

In [22]:
result_schema = types.StructType([
    types.StructField('hour', types.TimestampType(), True), 
    types.StructField('zone', types.IntegerType(), True), 
    types.StructField('revenue', types.DoubleType(), True), 
    types.StructField('count', types.IntegerType(), True)
])

In [23]:
df_result = rdd \
    .filter(filter_date_outliers) \
    .map(prepare_for_grouping) \
    .reduceByKey(custom_calc) \
    .map(unwrap_and_add_col_names) \
    .toDF(schema=result_schema)

Now unless we do `.show()` or another "eager" action, it will be a "lazy" transformation. Speaking of eager actions, let's save the resulting dataset to a Parquet file.

In [24]:
df_result.write.parquet("data/report/green_agg_hourly_by_zone")

                                                                                

If we look at the execution graph, we'll see that it was a two-stage one, as before. The "prep" step was stage 1 and then `.reduceByKey()` does reshuffling (the second stage).

In [25]:
!ls data/report/green_agg_hourly_by_zone

_SUCCESS
part-00000-919726c8-2ae2-44d9-ae94-3447dcccc59e-c000.snappy.parquet
part-00001-919726c8-2ae2-44d9-ae94-3447dcccc59e-c000.snappy.parquet
part-00002-919726c8-2ae2-44d9-ae94-3447dcccc59e-c000.snappy.parquet
part-00003-919726c8-2ae2-44d9-ae94-3447dcccc59e-c000.snappy.parquet


## The `mapPartitions()` Function

It is similar to `map()` but applies a function to a whole partition of data. Very convenient when we have a huge amount of data split into smaller, more manageable partitions. Effectively "chunking" the data. Most notably applicable in machine learning (e.g. inference from a trained model). Let's see it in action. Imagine we have an ML model that predicts the duration of a trip given a handful of features:

In [27]:
features = ["VendorID", "lpep_pickup_datetime", "PULocationID", "DOLocationID", "trip_distance"]
duration_rdd = df_green \
    .select(features) \
    .rdd

In [31]:
def apply_model_bactched(partition):
    return [1] # Just for demonstration

Note that the function above _must_ return an iterable object.

In [32]:
duration_rdd.mapPartitions(apply_model_bactched).collect()

                                                                                

[1, 1, 1, 1]

Why did we get this output? Because there were four partitions. The `collect()` function "flattens" the output, i.e. the original output would have been `[[1], [1], [1], [1]]`.  
Now let's see if we can take a peek inside the partitions with a sneaky "model" implementation...

In [36]:
def apply_model_bactched(partition):
    # We cannot do `cnt = len(partition)` because partitions are iterator obejects. Instead:
    
    cnt = 0
    for row in partition:
        cnt += 1
    
    return [cnt]

In [37]:
duration_rdd.mapPartitions(apply_model_bactched).collect()

                                                                                

[1141148, 438057, 432402, 292910]

We can see that our partitions are not very balanced at all. Ideally, we would want them to be balanced. To mitigate this, we could perform "repartitioning". But repartitioning is an expensive operation so we'd have to be smart about when and how we do it. While an important concept, repartitioning is outside the scope of this lesson so we won't discuss it any further right now.

As data scientists, we like 🐼s. So let's convert our Spark DataFrame to a Pandas one.

In [39]:
def apply_model_bactched(rows):
    df = pd.DataFrame(rows, columns=features)
    cnt = len(df)
    return [cnt]

**Note:** An important caveat to the above implementation is that the casting to `pd.DataFrame` will materialize the entire partition in memory. So your exectors need to have enough memory. If they don't, you'll have to, you guessed it, repartition (or otherwise manually "chunk" your data somehow, say, using [`islice()` from `itertools`](https://docs.python.org/3/library/itertools.html#itertools.islice)).

In [40]:
duration_rdd.mapPartitions(apply_model_bactched).collect()

                                                                                

[1141148, 438057, 432402, 292910]

In [42]:
# Load our fancy machine learning model
# model = ...

def model_predict(df):
    # y_pred = model.predict(df)
    y_pred = df.trip_distance * 5 # Truly state of the art
    return y_pred

In [45]:
def apply_model_bactched(rows):
    df = pd.DataFrame(rows, columns=features)
    predictions = model_predict(df)
    df["predicted_duration"] = predictions
    
    for row in df.itertuples():
        yield row

In [46]:
duration_rdd.mapPartitions(apply_model_bactched).take(10)

                                                                                

[Pandas(Index=0, VendorID=2.0, lpep_pickup_datetime=Timestamp('2020-01-23 13:10:15'), PULocationID=74, DOLocationID=130, trip_distance=12.77, predicted_duration=63.849999999999994),
 Pandas(Index=1, VendorID=nan, lpep_pickup_datetime=Timestamp('2020-01-20 15:09:00'), PULocationID=67, DOLocationID=39, trip_distance=8.0, predicted_duration=40.0),
 Pandas(Index=2, VendorID=2.0, lpep_pickup_datetime=Timestamp('2020-01-15 20:23:41'), PULocationID=260, DOLocationID=157, trip_distance=1.27, predicted_duration=6.35),
 Pandas(Index=3, VendorID=2.0, lpep_pickup_datetime=Timestamp('2020-01-05 16:32:26'), PULocationID=82, DOLocationID=83, trip_distance=1.25, predicted_duration=6.25),
 Pandas(Index=4, VendorID=2.0, lpep_pickup_datetime=Timestamp('2020-01-29 19:22:42'), PULocationID=166, DOLocationID=42, trip_distance=1.84, predicted_duration=9.200000000000001),
 Pandas(Index=5, VendorID=2.0, lpep_pickup_datetime=Timestamp('2020-01-15 11:07:42'), PULocationID=179, DOLocationID=223, trip_distance=0.7

In [47]:
df_predictions = duration_rdd \
    .mapPartitions(apply_model_bactched) \
    .toDF() \
    .drop("Index")

                                                                                

In [48]:
df_predictions.show()

[Stage 27:>                                                         (0 + 1) / 1]

+--------+--------------------+------------+------------+-------------+------------------+
|VendorID|lpep_pickup_datetime|PULocationID|DOLocationID|trip_distance|predicted_duration|
+--------+--------------------+------------+------------+-------------+------------------+
|     2.0|                  {}|          74|         130|        12.77|63.849999999999994|
|     NaN|                  {}|          67|          39|          8.0|              40.0|
|     2.0|                  {}|         260|         157|         1.27|              6.35|
|     2.0|                  {}|          82|          83|         1.25|              6.25|
|     2.0|                  {}|         166|          42|         1.84| 9.200000000000001|
|     2.0|                  {}|         179|         223|         0.76|               3.8|
|     2.0|                  {}|          41|         237|         3.32|16.599999999999998|
|     2.0|                  {}|          75|         161|         2.21|             11.05|

                                                                                

Ideally we should have specified the schema so Spark doesn't infer it. But we've done that a number of times now so you've got the idea. (Also I'm not sure what's going on with `lpep_pickup_datetime` above...)

In [50]:
df_predictions.select("predicted_duration").show()

[Stage 29:>                                                         (0 + 1) / 1]

+------------------+
|predicted_duration|
+------------------+
|63.849999999999994|
|              40.0|
|              6.35|
|              6.25|
| 9.200000000000001|
|               3.8|
|16.599999999999998|
|             11.05|
|               4.5|
|              30.5|
|               8.7|
|5.8999999999999995|
|              11.0|
|              15.2|
|              4.25|
|25.299999999999997|
|7.8500000000000005|
|              34.0|
| 5.300000000000001|
|              6.15|
+------------------+
only showing top 20 rows



                                                                                