### SQL Homework
Use this notebook to answer the questions.
It can be in the same project as the previous homework.  
When you are ready, **upload** it to your GitHub repo and send me the link (put the repo's address in a .txt file, zip it, and upload that as the homework).  
If your repo is private, invite me: balazs.balogh@cubixedu.com.

#### Import the SparkSession, create it, then load the taxi data (yellow_tripdata_2024-08.parquet)

In [41]:
from pyspark.sql import SparkSession

spark = (
    SparkSession
    .builder
    .appName("SQL")        # type: ignore
    .master("local[*]")    # Local mode with as many worker threads as logical cores (6 cores and 12 threads in my case).
    .getOrCreate()
)

taxi_data_2024_08 = (
    spark
    .read
    .format("parquet")
    # Raw string, so the Windows backslashes are never read as escape sequences.
    .load(r"C:\Users\Lerry\Desktop\cubix_data_engineer_homework\scr\data\yellow_tripdata_2024-08.parquet")
)

taxi_data_2024_08.createOrReplaceTempView("taxi_data_2024_08")

taxi_data_2024_08.show(10)


+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|       1| 2024-08-01 00:21:00|  2024-08-01 00:36:13|              1|          7.4|         1|                 N|         138|          80|           1|       28.9| 7.75|    0.5|      7.6

#### 1. What is the total fare amount for all trips?  
Please round the answer to two decimal places.

In [43]:
spark.sql("""

    SELECT
        ROUND(SUM(fare_amount), 2) AS total_fare_amount
    FROM 
        taxi_data_2024_08
""").show()


+-----------------+
|total_fare_amount|
+-----------------+
|    5.875099806E7|
+-----------------+
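Spark prints this large DOUBLE in scientific notation: 5.875099806E7 is 58,750,998.06. Casting the result with `CAST(... AS DECIMAL(18, 2))` or wrapping it in `format_number(..., 2)` should display it in plain form. The aggregate-then-round pattern itself can be sketched with Python's built-in `sqlite3` module on a few hypothetical fare rows (an assumption standing in for the parquet data; SQLite's dialect is not Spark SQL, so only the logic carries over):

```python
import sqlite3

# Hypothetical (payment_type, fare_amount) rows; not real taxi data.
ROWS = [(1, 12.5), (2, 30.0), (1, -5.25), (2, 7.333)]


def total_fare(rows):
    """Sum all fares, then round the aggregate to two decimals."""
    con = sqlite3.connect(":memory:")
    con.execute("CREATE TABLE trips (payment_type INTEGER, fare_amount REAL)")
    con.executemany("INSERT INTO trips VALUES (?, ?)", rows)
    (total,) = con.execute(
        "SELECT ROUND(SUM(fare_amount), 2) AS total_fare_amount FROM trips"
    ).fetchone()
    con.close()
    return total


# Plain display with thousands separators, no scientific notation.
print(f"{total_fare(ROWS):,.2f}")
print(f"{58750998.06:,.2f}")
```

Rounding after `SUM` (rather than rounding each fare first) avoids accumulating per-row rounding errors.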



#### 2. Show the maximum fare amount, minimum fare amount, and average fare amount for each payment type. Order by payment type.
Round to two decimal places where needed.

In [44]:
spark.sql("""

    SELECT
        payment_type,
        ROUND(MAX(fare_amount), 2) AS max_fare_amount,
        ROUND(MIN(fare_amount), 2) AS min_fare_amount,
        ROUND(AVG(fare_amount), 2) AS avg_fare_amount
    FROM
        taxi_data_2024_08
    GROUP BY
        payment_type
    ORDER BY
        payment_type

""").show()

+------------+---------------+---------------+---------------+
|payment_type|max_fare_amount|min_fare_amount|avg_fare_amount|
+------------+---------------+---------------+---------------+
|           0|         394.76|          -72.2|          19.62|
|           1|          650.0|         -108.7|           20.6|
|           2|         1386.2|        -1174.1|          19.09|
|           3|          999.0|         -999.0|           6.26|
|           4|          900.0|         -900.0|           1.37|
+------------+---------------+---------------+---------------+
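Spark does not guarantee the row order of a `GROUP BY` result, so "order by payment type" needs an explicit `ORDER BY`. A small `sqlite3` sketch of the same grouped MAX/MIN/AVG, with hypothetical rows in place of the taxi data:

```python
import sqlite3

# Hypothetical (payment_type, fare_amount) rows; not real taxi data.
ROWS = [(2, 10.0), (1, 20.0), (2, -4.0), (1, 5.0), (1, 11.0)]


def fare_stats(rows):
    """Per payment type: rounded max, min and average fare, sorted by type."""
    con = sqlite3.connect(":memory:")
    con.execute("CREATE TABLE trips (payment_type INTEGER, fare_amount REAL)")
    con.executemany("INSERT INTO trips VALUES (?, ?)", rows)
    query = """
        SELECT
            payment_type,
            ROUND(MAX(fare_amount), 2) AS max_fare_amount,
            ROUND(MIN(fare_amount), 2) AS min_fare_amount,
            ROUND(AVG(fare_amount), 2) AS avg_fare_amount
        FROM trips
        GROUP BY payment_type
        ORDER BY payment_type  -- group order is not guaranteed without this
    """
    result = con.execute(query).fetchall()
    con.close()
    return result


for row in fare_stats(ROWS):
    print(row)
```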



#### 3. For trips with a fare amount greater than 20, what is the total tip amount for each day (based on the tpep_pickup_datetime)?
Round the tip to two decimal places, and order the results from the highest total tip amount down.  
Hint: Check the DATE() function to convert tpep_pickup_datetime to a date, keeping only YYYY-MM-DD.

In [64]:
spark.sql("""

    SELECT
        DATE(tpep_pickup_datetime) AS days,
        ROUND(SUM(tip_amount), 2) AS total_tip_amount
    FROM
        taxi_data_2024_08
    WHERE
        fare_amount > 20
    GROUP BY
        days
    ORDER BY
        total_tip_amount DESC

""").show()

+----------+----------------+
|      days|total_tip_amount|
+----------+----------------+
|2024-08-07|       200170.81|
|2024-08-01|       197694.16|
|2024-08-29|       192837.64|
|2024-08-08|       192807.33|
|2024-08-19|       184872.55|
|2024-08-22|       182233.24|
|2024-08-30|       180368.43|
|2024-08-28|       180220.74|
|2024-08-05|       176271.55|
|2024-08-15|        176223.0|
|2024-08-06|       175451.67|
|2024-08-11|       174769.89|
|2024-08-20|       173310.48|
|2024-08-25|       171774.94|
|2024-08-04|       169589.92|
|2024-08-23|        168495.1|
|2024-08-02|        167908.5|
|2024-08-14|        166266.4|
|2024-08-21|       165647.98|
|2024-08-27|        164666.8|
+----------+----------------+
only showing top 20 rows
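The `WHERE` filter runs before grouping, so only trips with `fare_amount > 20` contribute to a day's total, and a day with no qualifying trips drops out entirely. The sketch below mirrors this with `sqlite3`, whose `DATE()` likewise truncates a `'YYYY-MM-DD HH:MM:SS'` string to `'YYYY-MM-DD'`; the rows and the `trips` table are hypothetical:

```python
import sqlite3

# Hypothetical (pickup_datetime, fare_amount, tip_amount) rows; not real data.
ROWS = [
    ("2024-08-01 08:00:00", 25.0, 5.0),
    ("2024-08-01 21:30:00", 40.0, 8.5),
    ("2024-08-02 09:10:00", 22.0, 20.0),
    ("2024-08-02 12:00:00", 15.0, 100.0),  # fare <= 20: filtered out
    ("2024-08-03 00:05:00", 18.0, 3.0),    # fare <= 20: this day disappears
]


def daily_tips(rows):
    """Total tips per pickup day for trips with fare > 20, highest first."""
    con = sqlite3.connect(":memory:")
    con.execute(
        "CREATE TABLE trips (pickup_datetime TEXT, fare_amount REAL, tip_amount REAL)"
    )
    con.executemany("INSERT INTO trips VALUES (?, ?, ?)", rows)
    query = """
        SELECT
            DATE(pickup_datetime) AS day,      -- keep only YYYY-MM-DD
            ROUND(SUM(tip_amount), 2) AS total_tip_amount
        FROM trips
        WHERE fare_amount > 20                 -- row filter, applied before grouping
        GROUP BY DATE(pickup_datetime)
        ORDER BY total_tip_amount DESC
    """
    result = con.execute(query).fetchall()
    con.close()
    return result


print(daily_tips(ROWS))
```

A condition on the aggregated total itself (e.g. only days above some tip sum) would go in `HAVING` instead.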



#### 4. For each trip, show the fare amount along with a column that indicates if the trip was "expensive" (greater than 30) or "cheap" (less than or equal to 30).
Hint: Use CASE WHEN to decide between expensive and cheap.

In [80]:
spark.sql("""
    
    SELECT
        fare_amount,
        CASE
            WHEN fare_amount > 30 THEN 'expensive'
            ELSE 'cheap'
        END AS fare_category
    FROM 
        taxi_data_2024_08
""").show()


+-----------+-------------+
|fare_amount|fare_category|
+-----------+-------------+
|       28.9|        cheap|
|       40.8|    expensive|
|       52.0|    expensive|
|       17.0|        cheap|
|        5.1|        cheap|
|        5.8|        cheap|
|       16.3|        cheap|
|       17.7|        cheap|
|       27.5|        cheap|
|       11.4|        cheap|
|        9.3|        cheap|
|        5.1|        cheap|
|       16.3|        cheap|
|       45.0|    expensive|
|        8.6|        cheap|
|        7.2|        cheap|
|       13.5|        cheap|
|       14.9|        cheap|
|       10.7|        cheap|
|        8.6|        cheap|
+-----------+-------------+
only showing top 20 rows
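Because a comparison with NULL evaluates to unknown, `ELSE 'cheap'` also labels trips with a missing `fare_amount` as cheap. A variant that spells out both conditions leaves such rows NULL instead. This `sqlite3` sketch (hypothetical fares, including the 30.0 boundary and a missing value) shows that behavior:

```python
import sqlite3

# Hypothetical fares: below, at and above the 30 boundary, plus a missing value.
FARES = [28.9, 30.0, 30.01, 45.0, None]


def categorize(fares):
    """Label each fare 'expensive' (> 30) or 'cheap' (<= 30); NULL stays NULL."""
    con = sqlite3.connect(":memory:")
    con.execute("CREATE TABLE trips (id INTEGER PRIMARY KEY, fare_amount REAL)")
    con.executemany(
        "INSERT INTO trips (fare_amount) VALUES (?)", [(f,) for f in fares]
    )
    query = """
        SELECT
            fare_amount,
            CASE
                WHEN fare_amount > 30 THEN 'expensive'
                WHEN fare_amount <= 30 THEN 'cheap'
            END AS fare_category  -- no ELSE: a NULL fare yields NULL
        FROM trips
        ORDER BY id
    """
    result = con.execute(query).fetchall()
    con.close()
    return result


for row in categorize(FARES):
    print(row)
```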



#### 5. Find the first trip (based on tpep_pickup_datetime) for each VendorID and display the fare amount.
Hint: You can use a CTE with ROW_NUMBER().

In [84]:
spark.sql("""

    WITH first_trip AS (
        SELECT
            tpep_pickup_datetime,
            fare_amount,
            VendorID,
            ROW_NUMBER() OVER (PARTITION BY VendorID ORDER BY tpep_pickup_datetime) AS row_num
        FROM
            taxi_data_2024_08
    )

    SELECT
        VendorID,
        fare_amount
    FROM
        first_trip
    WHERE row_num = 1
              
""").show()

+--------+-----------+
|VendorID|fare_amount|
+--------+-----------+
|       1|       70.0|
|       2|       10.7|
+--------+-----------+
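`ROW_NUMBER()` restarts at 1 in each `VendorID` partition and counts in pickup order, so `row_num = 1` marks the earliest trip. If two trips share the earliest timestamp, it picks one of them arbitrarily; `RANK()` would keep both. A `sqlite3` sketch with hypothetical rows (ISO-formatted timestamp strings, which sort chronologically; window functions need SQLite 3.25 or newer):

```python
import sqlite3

# Hypothetical (VendorID, pickup_datetime, fare_amount) rows; not real data.
ROWS = [
    (1, "2024-08-01 00:30:00", 12.0),
    (1, "2024-08-01 00:05:00", 70.0),
    (2, "2024-08-01 01:00:00", 9.5),
    (2, "2024-08-01 00:45:00", 10.7),
]


def first_trip_fares(rows):
    """Fare of the earliest pickup per vendor, via ROW_NUMBER() in a CTE."""
    con = sqlite3.connect(":memory:")
    con.execute(
        "CREATE TABLE trips (VendorID INTEGER, pickup_datetime TEXT, fare_amount REAL)"
    )
    con.executemany("INSERT INTO trips VALUES (?, ?, ?)", rows)
    query = """
        WITH first_trip AS (
            SELECT
                VendorID,
                fare_amount,
                ROW_NUMBER() OVER (
                    PARTITION BY VendorID     -- numbering restarts per vendor
                    ORDER BY pickup_datetime  -- earliest pickup gets 1
                ) AS row_num
            FROM trips
        )
        SELECT VendorID, fare_amount
        FROM first_trip
        WHERE row_num = 1
        ORDER BY VendorID
    """
    result = con.execute(query).fetchall()
    con.close()
    return result


print(first_trip_fares(ROWS))
```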



#### 7. Calculate the average trip distance for each VendorID, and assign a label of 'Above Average' or 'Below Average' for each trip based on the distance relative to the VendorID’s average trip distance.
Hint: Use a CTE joined back to the main DataFrame.

In [90]:
spark.sql("""

    WITH vendor_avg_trip_distance AS (
        SELECT
            VendorID,
            AVG(trip_distance) AS avg_trip_distance
        FROM
            taxi_data_2024_08  
        GROUP BY
            VendorID
    )

    SELECT
        t.VendorID,
        t.trip_distance,
        CASE
            WHEN t.trip_distance > v.avg_trip_distance THEN 'Above Average'
            ELSE 'Below Average'
        END AS distance_label
    FROM
        taxi_data_2024_08 t
    JOIN
        vendor_avg_trip_distance v
    ON
        t.VendorID = v.VendorID

""").show()


+--------+-------------+--------------+
|VendorID|trip_distance|distance_label|
+--------+-------------+--------------+
|       1|          7.4| Above Average|
|       2|         9.91| Above Average|
|       1|         13.4| Above Average|
|       1|          3.9| Below Average|
|       1|          0.4| Below Average|
|       1|          0.4| Below Average|
|       2|         3.21| Below Average|
|       2|          3.8| Below Average|
|       2|         5.54| Above Average|
|       2|         1.56| Below Average|
|       2|         1.32| Below Average|
|       2|         0.45| Below Average|
|       2|         2.65| Below Average|
|       1|         11.7| Above Average|
|       2|         1.54| Below Average|
|       1|          0.4| Below Average|
|       2|         2.19| Below Average|
|       2|         2.58| Below Average|
|       1|          1.4| Below Average|
|       1|          1.1| Below Average|
+--------+-------------+--------------+
only showing top 20 rows
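The CTE computes one average per vendor and the join attaches it to every trip; a trip exactly at the average falls into 'Below Average' because the test is a strict `>`. Spark SQL and SQLite both also accept `AVG` as a window function, which yields the same labels without a join. This `sqlite3` sketch runs both forms on hypothetical rows (the `trips` table, its `id` column and the helper names are assumptions for illustration) and checks that they agree:

```python
import sqlite3

# Hypothetical (VendorID, trip_distance) rows: vendor 1 averages 3.0, vendor 2 averages 4.0.
ROWS = [(1, 1.0), (1, 2.0), (1, 6.0), (2, 4.0), (2, 4.0)]


def build_db(rows):
    con = sqlite3.connect(":memory:")
    con.execute(
        "CREATE TABLE trips (id INTEGER PRIMARY KEY, VendorID INTEGER, trip_distance REAL)"
    )
    con.executemany("INSERT INTO trips (VendorID, trip_distance) VALUES (?, ?)", rows)
    return con


def label_with_join(con):
    """CTE with per-vendor averages, joined back to every trip."""
    return con.execute("""
        WITH vendor_avg AS (
            SELECT VendorID, AVG(trip_distance) AS avg_trip_distance
            FROM trips
            GROUP BY VendorID
        )
        SELECT
            t.VendorID,
            t.trip_distance,
            CASE
                WHEN t.trip_distance > v.avg_trip_distance THEN 'Above Average'
                ELSE 'Below Average'
            END AS distance_label
        FROM trips t
        JOIN vendor_avg v ON t.VendorID = v.VendorID
        ORDER BY t.id
    """).fetchall()


def label_with_window(con):
    """Same labels via a window average; needs SQLite 3.25+."""
    return con.execute("""
        SELECT
            VendorID,
            trip_distance,
            CASE
                WHEN trip_distance > AVG(trip_distance) OVER (PARTITION BY VendorID)
                THEN 'Above Average'
                ELSE 'Below Average'
            END AS distance_label
        FROM trips
        ORDER BY id
    """).fetchall()


con = build_db(ROWS)
for row in label_with_join(con):
    print(row)
print(label_with_join(con) == label_with_window(con))
con.close()
```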

