In [1]:
from datetime import datetime

import pyspark
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

In [2]:
spark = SparkSession.builder.master("local[*]").appName("test").getOrCreate()

24/03/09 15:20:19 WARN Utils: Your hostname, avalon resolves to a loopback address: 127.0.1.1; using 192.168.18.2 instead (on interface eth0)
24/03/09 15:20:19 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/03/09 15:20:20 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
paths_green = ["data/pq/green/2020/*", "data/pq/green/2021/*"]
paths_yellow = ["data/pq/yellow/2020/*", "data/pq/yellow/2021/*"]
df_green = spark.read.parquet(*paths_green)
df_yellow = spark.read.parquet(*paths_yellow)

                                                                                

### OOP

In [4]:
class TaxiSpark:
    def __init__(
        self,
        df: pyspark.sql.dataframe.DataFrame,
        pickup_column: str,
        dropoff_column: str,
    ) -> None:
        self.df = df
        self.pickup_column = pickup_column
        self.dropoff_column = dropoff_column
        self.df_revenue = None

    def calculate_revenue(
        self,
        start_date: datetime,
    ) -> pyspark.sql.dataframe.DataFrame:
        self.df_revenue = (
            (self.df)
            .filter(
                F.col(self.pickup_column) >= start_date,
            )
            .groupBy(
                [
                    F.date_trunc("hour", self.pickup_column).alias("hour"),
                    F.col("PULocationID").alias("zone"),
                ]
            )
            .agg(
                F.sum("total_amount").alias("amount"),
                F.count(F.expr("*")).alias("number_records"),
            )
            .orderBy(
                "hour",
                "zone",
            )
        )
        return self.df_revenue

    def write_parquet_revenue(self, num_partition: int, path: str):
        if self.df_revenue is None:
            return
        self.df_revenue.repartition(num_partition).write.parquet(path, mode="overwrite")

In [5]:
green_taxi = TaxiSpark(df_green, "lpep_pickup_datetime", "lpep_dropoff_datetime")
yellow_taxi = TaxiSpark(df_yellow, "tpep_pickup_datetime", "tpep_dropoff_datetime")

In [6]:
green_taxi.calculate_revenue(datetime(2020, 1, 1))
yellow_taxi.calculate_revenue(datetime(2020, 1, 1))

DataFrame[hour: timestamp, zone: int, amount: double, number_records: bigint]

In [21]:
green_taxi.write_parquet_revenue(20, "data/report/revenue/green")
yellow_taxi.write_parquet_revenue(20, "data/report/revenue/yellow")

                                                                                

In [7]:
green_taxi.df_revenue

DataFrame[hour: timestamp, zone: int, amount: double, number_records: bigint]

In [8]:
yellow_taxi.df_revenue

DataFrame[hour: timestamp, zone: int, amount: double, number_records: bigint]

In [9]:
# df_green_revenue = green_taxi.df_revenue
# df_yellow_revenue = yellow_taxi.df_revenue
df_green_revenue = spark.read.parquet("data/report/revenue/green")
df_yellow_revenue = spark.read.parquet("data/report/revenue/yellow")

In [10]:
df_join = (
    (df_green_revenue)
    .withColumnRenamed("amount", "green_amount")
    .withColumnRenamed("number_records", "green_number_records")
    .join(
        (df_yellow_revenue)
        .withColumnRenamed("amount", "yellow_amount")
        .withColumnRenamed("number_records", "yellow_number_records"),
        on=["hour", "zone"],
        how="outer",
    )
)

In [36]:
df_join.write.parquet("data/report/revenue/total", mode="overwrite")

                                                                                

In [11]:
df_join = spark.read.parquet("data/report/revenue/total")

In [12]:
df_join.show()

                                                                                

+-------------------+----+------------------+--------------------+------------------+---------------------+
|               hour|zone|      green_amount|green_number_records|     yellow_amount|yellow_number_records|
+-------------------+----+------------------+--------------------+------------------+---------------------+
|2020-01-01 00:00:00|  10|              NULL|                NULL|42.410000801086426|                    2|
|2020-01-01 00:00:00|  14|              NULL|                NULL| 8.800000190734863|                    1|
|2020-01-01 00:00:00|  15|              NULL|                NULL| 34.09000015258789|                    1|
|2020-01-01 00:00:00|  17| 195.0299997329712|                   9|220.21000003814697|                    8|
|2020-01-01 00:00:00|  24| 87.60000038146973|                   3| 754.9499969482422|                   45|
|2020-01-01 00:00:00|  29| 61.29999923706055|                   1|              NULL|                 NULL|
|2020-01-01 00:00:00|  33|31

In [13]:
df_zones = spark.read.parquet("data/pq/zones")

In [14]:
df_result = df_join.join(df_zones, df_join.zone == df_zones.LocationID)

In [29]:
df_result.drop("LocationId", "zone").show(2)

+-------------------+------------+--------------------+------------------+---------------------+--------+------------+
|               hour|green_amount|green_number_records|     yellow_amount|yellow_number_records| Borough|service_zone|
+-------------------+------------+--------------------+------------------+---------------------+--------+------------+
|2020-01-01 00:00:00|        NULL|                NULL|42.410000801086426|                    2|  Queens|   Boro Zone|
|2020-01-01 00:00:00|        NULL|                NULL| 8.800000190734863|                    1|Brooklyn|   Boro Zone|
+-------------------+------------+--------------------+------------------+---------------------+--------+------------+
only showing top 2 rows



In [30]:
df_result.drop("LocationId", "zone").write.parquet("tmp/revenue-zones")

                                                                                

### Using SQL Query

In [4]:
df_green.createOrReplaceTempView("green")
df_yellow.createOrReplaceTempView("yellow")

In [5]:
df_green_revenue = spark.sql(
    """
SELECT
    date_trunc("hour", lpep_pickup_datetime) AS hour,
    PULocationID AS zone,

    SUM(total_amount) AS amount,
    COUNT(1) AS number_records
FROM
    green
WHERE
    lpep_pickup_datetime >= "2020-01-01 00:00:00"
GROUP BY
    1, 2
ORDER BY
    1, 2
"""
)

In [6]:
df_green_revenue.show()

                                                                                

+-------------------+----+------------------+--------------+
|               hour|zone|            amount|number_records|
+-------------------+----+------------------+--------------+
|2020-01-01 00:00:00|   7| 769.7299957275391|            45|
|2020-01-01 00:00:00|  17| 195.0299997329712|             9|
|2020-01-01 00:00:00|  18| 7.800000190734863|             1|
|2020-01-01 00:00:00|  22|15.800000190734863|             1|
|2020-01-01 00:00:00|  24| 87.60000038146973|             3|
|2020-01-01 00:00:00|  25| 531.0000057220459|            26|
|2020-01-01 00:00:00|  29| 61.29999923706055|             1|
|2020-01-01 00:00:00|  32| 68.94999885559082|             2|
|2020-01-01 00:00:00|  33|317.26999831199646|            11|
|2020-01-01 00:00:00|  35| 129.9600019454956|             5|
|2020-01-01 00:00:00|  36| 295.3400011062622|            11|
|2020-01-01 00:00:00|  37|  175.669997215271|             6|
|2020-01-01 00:00:00|  38| 98.79000091552734|             2|
|2020-01-01 00:00:00|  4

### Using Pyspark API

In [7]:
df_green_revenue = (
    (df_green)
    .filter(
        F.col("lpep_pickup_datetime") >= datetime(2020, 1, 1),
    )
    .groupBy(
        [
            F.date_trunc("hour", "lpep_pickup_datetime").alias("hour"),
            F.col("PULocationID").alias("zone"),
        ]
    )
    .agg(
        F.sum("total_amount").alias("amount"),
        F.count(F.expr("*")).alias("number_records"),
    )
    .orderBy(
        "hour",
        "zone",
    )
)

In [8]:
df_green_revenue.show()



+-------------------+----+------------------+--------------+
|               hour|zone|            amount|number_records|
+-------------------+----+------------------+--------------+
|2020-01-01 00:00:00|   7| 769.7299957275391|            45|
|2020-01-01 00:00:00|  17| 195.0299997329712|             9|
|2020-01-01 00:00:00|  18| 7.800000190734863|             1|
|2020-01-01 00:00:00|  22|15.800000190734863|             1|
|2020-01-01 00:00:00|  24| 87.60000038146973|             3|
|2020-01-01 00:00:00|  25| 531.0000057220459|            26|
|2020-01-01 00:00:00|  29| 61.29999923706055|             1|
|2020-01-01 00:00:00|  32| 68.94999885559082|             2|
|2020-01-01 00:00:00|  33|317.26999831199646|            11|
|2020-01-01 00:00:00|  35| 129.9600019454956|             5|
|2020-01-01 00:00:00|  36| 295.3400011062622|            11|
|2020-01-01 00:00:00|  37|  175.669997215271|             6|
|2020-01-01 00:00:00|  38| 98.79000091552734|             2|
|2020-01-01 00:00:00|  4

                                                                                

In [9]:
df_green_revenue.repartition(20).write.parquet(
    "data/report/revenue/green", mode="overwrite"
)

                                                                                