# Trips and Users

### Description

Table: Trips

| Column Name | Type     |
|-------------|----------|
| id          | int      |
| client_id   | int      |
| driver_id   | int      |
| city_id     | int      |
| status      | enum     |
| request_at  | date     |

id is the primary key (column with unique values) for this table.
The table holds all taxi trips. Each trip has a unique id, while client_id and driver_id are foreign keys to users_id in the Users table.
status is an ENUM (category) column with the values ('completed', 'cancelled_by_driver', 'cancelled_by_client').
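Both cancellation values begin with `cancelled`, so a prefix test such as the solution's `F.col("status").like("cancelled%")` selects exactly the cancelled trips. This holds only for the three values listed above; a future status sharing that prefix would also match. A quick plain-Python check (the helper `is_cancelled` is hypothetical, not part of the notebook):

```python
# The three status values listed in the Trips table description.
STATUSES = ("completed", "cancelled_by_driver", "cancelled_by_client")

def is_cancelled(status: str) -> bool:
    # Hypothetical helper: prefix test analogous to like("cancelled%").
    return status.startswith("cancelled")

# The prefix test picks out exactly the two cancellation values.
CANCELLED = {s for s in STATUSES if is_cancelled(s)}
```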

Table: Users

| Column Name | Type     |
|-------------|----------|
| users_id    | int      |
| banned      | enum     |
| role        | enum     |

users_id is the primary key (column with unique values) for this table.
The table holds all users. Each user has a unique users_id, and role is an ENUM (category) column with the values ('client', 'driver', 'partner').
banned is an ENUM (category) column with the values ('Yes', 'No').

For each day, the cancellation rate is the number of cancelled requests (by either the client or the driver) involving only unbanned users, divided by the total number of requests involving only unbanned users on that day.
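Written as a formula, the definition above reads as follows. The notation is introduced here for illustration only: $T_d$ is the set of trips requested on day $d$, $U$ the set of unbanned users, and $C = \{\texttt{cancelled\_by\_driver}, \texttt{cancelled\_by\_client}\}$ the cancellation statuses.

```latex
E_d = \{\, t \in T_d : t.\mathrm{client\_id} \in U \,\wedge\, t.\mathrm{driver\_id} \in U \,\}
\qquad
\mathrm{rate}_d = \frac{\lvert \{\, t \in E_d : t.\mathrm{status} \in C \,\} \rvert}{\lvert E_d \rvert},
\quad \text{defined only when } \lvert E_d \rvert > 0
```

A day whose trips all involve a banned user has $\lvert E_d \rvert = 0$ and produces no output row.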

Write a solution to find the cancellation rate of requests with unbanned users (both the client and the driver must not be banned) for each day between "2013-10-01" and "2013-10-03", inclusive. Round Cancellation Rate to two decimal places.

Return the result table in any order.
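Before turning to Spark, it can help to have a small plain-Python oracle for the definition, useful for checking a Spark result by hand. This is a sketch, not the notebook's method: the function name `cancellation_rates` is hypothetical, and the rows are the same sample data used in the notebook below.

```python
from collections import defaultdict
from datetime import date

# Same rows as the notebook's sample Trips and Users data.
TRIPS = [
    (1, 1, 10, 1, "completed", date(2013, 10, 1)),
    (2, 2, 11, 1, "cancelled_by_driver", date(2013, 10, 1)),
    (3, 3, 12, 6, "completed", date(2013, 10, 1)),
    (4, 4, 13, 6, "cancelled_by_client", date(2013, 10, 1)),
    (5, 1, 10, 1, "completed", date(2013, 10, 2)),
    (6, 2, 11, 6, "completed", date(2013, 10, 2)),
    (7, 3, 12, 6, "completed", date(2013, 10, 2)),
    (8, 2, 12, 12, "completed", date(2013, 10, 3)),
    (9, 3, 10, 12, "completed", date(2013, 10, 3)),
    (10, 4, 13, 12, "cancelled_by_driver", date(2013, 10, 3)),
]
USERS = [
    (1, "No", "client"), (2, "Yes", "client"), (3, "No", "client"), (4, "No", "client"),
    (10, "No", "driver"), (11, "No", "driver"), (12, "No", "driver"), (13, "No", "driver"),
]

def cancellation_rates(trips, users, start, end):
    """Hypothetical oracle: {day: rate} for days in [start, end] with eligible trips."""
    unbanned = {uid for uid, banned, _role in users if banned == "No"}
    eligible = defaultdict(int)
    cancelled = defaultdict(int)
    for _id, client, driver, _city, status, day in trips:
        if not (start <= day <= end):
            continue
        # A trip is eligible only if BOTH the client and the driver are unbanned.
        if client in unbanned and driver in unbanned:
            eligible[day] += 1
            if status in ("cancelled_by_driver", "cancelled_by_client"):
                cancelled[day] += 1
    # Only days with at least one eligible trip appear, so no division by zero.
    return {day: round(cancelled[day] / eligible[day], 2) for day in eligible}

rates = cancellation_rates(TRIPS, USERS, date(2013, 10, 1), date(2013, 10, 3))
# {date(2013, 10, 1): 0.33, date(2013, 10, 2): 0.0, date(2013, 10, 3): 0.5}
```

Python's built-in `round` rounds exact halves to even, whereas Spark's `round` uses HALF_UP, so the two can disagree on exact ties; for the sample data they agree.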

### Imports

In [None]:
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, DateType
from pyspark.sql.window import Window
from pyspark.sql import functions as F
from datetime import datetime

### Sample data

In [None]:
trips_schema = StructType([
    StructField("id", IntegerType(), nullable=False),
    StructField("client_id", IntegerType(), nullable=False),
    StructField("driver_id", IntegerType(), nullable=False),
    StructField("city_id", IntegerType(), nullable=False),
    StructField("status", StringType(), nullable=False),
    StructField("request_at", DateType(), nullable=False)
])

users_schema = StructType([
    StructField("users_id", IntegerType(), nullable=False),
    StructField("banned", StringType(), nullable=False),
    StructField("role", StringType(), nullable=False)
])

trips_data = [
    (1, 1, 10, 1, "completed", datetime.strptime("2013-10-01", "%Y-%m-%d")),
    (2, 2, 11, 1, "cancelled_by_driver", datetime.strptime("2013-10-01", "%Y-%m-%d")),
    (3, 3, 12, 6, "completed", datetime.strptime("2013-10-01", "%Y-%m-%d")),
    (4, 4, 13, 6, "cancelled_by_client", datetime.strptime("2013-10-01", "%Y-%m-%d")),
    (5, 1, 10, 1, "completed", datetime.strptime("2013-10-02", "%Y-%m-%d")),
    (6, 2, 11, 6, "completed", datetime.strptime("2013-10-02", "%Y-%m-%d")),
    (7, 3, 12, 6, "completed", datetime.strptime("2013-10-02", "%Y-%m-%d")),
    (8, 2, 12, 12, "completed", datetime.strptime("2013-10-03", "%Y-%m-%d")),
    (9, 3, 10, 12, "completed", datetime.strptime("2013-10-03", "%Y-%m-%d")),
    (10, 4, 13, 12, "cancelled_by_driver", datetime.strptime("2013-10-03", "%Y-%m-%d"))
]

users_data = [
    (1, "No", "client"),
    (2, "Yes", "client"),
    (3, "No", "client"),
    (4, "No", "client"),
    (10, "No", "driver"),
    (11, "No", "driver"),
    (12, "No", "driver"),
    (13, "No", "driver")
]

spark.createDataFrame(trips_data, schema=trips_schema).createOrReplaceTempView("Trips")
spark.createDataFrame(users_data, schema=users_schema).createOrReplaceTempView("Users")

display(spark.table("Trips"))
display(spark.table("Users"))
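In this sample, client 2 is the only banned user, so trips 2, 6 and 8 are excluded. Working the rate by hand for each day gives the values the solution should return:

```latex
\begin{aligned}
\text{2013-10-01:}&\quad \tfrac{1}{3} \approx 0.33 \quad (\text{eligible trips } 1, 3, 4;\ \text{trip } 4 \text{ cancelled})\\
\text{2013-10-02:}&\quad \tfrac{0}{2} = 0.00 \quad (\text{eligible trips } 5, 7)\\
\text{2013-10-03:}&\quad \tfrac{1}{2} = 0.50 \quad (\text{eligible trips } 9, 10;\ \text{trip } 10 \text{ cancelled})
\end{aligned}
```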


### Solution

In [None]:
# Every aggregate below is evaluated over all trip rows with the same request_at.
window = Window.partitionBy("request_at")

# Collect the ids of unbanned users on the driver. The list is inlined into the
# isin() predicate as literals, so a broadcast variable adds nothing here.
users_not_banned_list = [
    row.users_id
    for row in spark.read.table("Users").filter("banned = 'No'").select("users_id").collect()
]

# A trip counts toward the rate only if BOTH the client and the driver are unbanned.
both_unbanned = (
    F.col("client_id").isin(users_not_banned_list)
    & F.col("driver_id").isin(users_not_banned_list)
)

trips_enriched = (
    spark.read.table("Trips")
        # Filter the date range first so the window only covers the requested days.
        .filter(F.col("request_at").between("2013-10-01", "2013-10-03"))
        .withColumn("eligible", F.count_if(both_unbanned).over(window))
        .withColumn("cancelled", F.count_if(F.col("status").like("cancelled%") & both_unbanned).over(window))
        .select("request_at", "eligible", "cancelled")
)

# Days whose trips all involve a banned user have no eligible requests and are skipped.
result = (
    trips_enriched.filter("eligible > 0")
        .select(F.col("request_at").alias("Day"),
                F.round(F.col("cancelled") / F.col("eligible"), 2).alias("Cancellation Rate"))
        # Every trip row of a day carries the same counts; keep one row per day.
        .distinct()
)

display(result)
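The solution attaches per-day counts to every trip row with `count_if(...).over(window)` and then collapses the repeated rows with `.distinct()`. The plain-Python sketch below mimics that pattern to show why the deduplication is needed: every row of a day carries the same count. The helper `per_day_window_counts` is hypothetical and is not a Spark API.

```python
from collections import Counter

def per_day_window_counts(rows, predicate):
    """Mimic count_if(predicate).over(Window.partitionBy("day")):
    attach to every row the number of rows on the same day satisfying predicate."""
    counts = Counter(row["day"] for row in rows if predicate(row))
    # Counter returns 0 for days with no matching rows, like count_if.
    return [{**row, "count": counts[row["day"]]} for row in rows]

rows = [
    {"day": "2013-10-01", "status": "completed"},
    {"day": "2013-10-01", "status": "cancelled_by_client"},
    {"day": "2013-10-02", "status": "completed"},
]
with_counts = per_day_window_counts(rows, lambda r: r["status"].startswith("cancelled"))

# Project onto (day, count) and deduplicate, like .select(...).distinct().
summary = sorted({(r["day"], r["count"]) for r in with_counts})
# [("2013-10-01", 1), ("2013-10-02", 0)]
```

A window keeps one output row per input row, which is why the solution needs `.distinct()`; a `groupBy("request_at").agg(...)` would produce one row per day directly and is an equally valid design.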