# In [None]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as f
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType, TimestampType
import os
import shutil


# Get (or create) the SparkSession that runs this job.
spark = (
    SparkSession.builder
    .appName("EventCount")
    .getOrCreate()
)

# Input directory watched by the streaming CSV source.
data_path = "taxi-data"

# Stop every streaming query still active in this session (for example one
# started by an earlier execution of this cell) before a new one is started.
active_queries = spark.streams.active
for running_query in active_queries:
    running_query.stop()

# Root directory for the per-hour output. Anything left over from a previous
# run is deleted, and the directory is then recreated empty.
directory = "output"
if os.path.exists(directory):
    shutil.rmtree(directory)
os.makedirs(directory)

# Column layout of the headerless taxi-trip CSV files, in file order.
# Every field uses StructField's default (nullable=True).
schema = StructType([
    StructField(column_name, column_type)
    for column_name, column_type in (
        ("type", StringType()),
        ("VendorID", IntegerType()),
        ("pickup_datetime", TimestampType()),
        ("dropoff_datetime", TimestampType()),
        ("passenger_count", IntegerType()),
        ("trip_distance", DoubleType()),
        ("pickup_longitude", DoubleType()),
        ("pickup_latitude", DoubleType()),
        ("RatecodeID", IntegerType()),
        ("store_and_fwd_flag", StringType()),
        ("dropoff_longitude", DoubleType()),
        ("dropoff_latitude", DoubleType()),
        ("payment_type", IntegerType()),
        ("fare_amount", DoubleType()),
        ("extra", DoubleType()),
        ("mta_tax", DoubleType()),
        ("tip_amount", DoubleType()),
        ("tolls_amount", DoubleType()),
        ("improvement_surcharge", DoubleType()),
        ("total_amount", DoubleType()),
    )
])

# Streaming DataFrame over CSV files arriving in data_path. The files carry
# no header row, so `schema` supplies the column names and types.
streaming_df = (
    spark.readStream
    .schema(schema)
    .option("header", "false")
    .format("csv")
    .load(data_path)
)

# Running dropoff counts keyed by hour of day, accumulated across
# micro-batches by process_streaming_data. Held only in this Python process.
hour_counts = dict()


def process_streaming_data(df, epoch_id):
    """Fold one micro-batch into the running per-hour totals and write them out.

    Registered as the ``foreachBatch`` callback of the streaming query, so
    Spark calls it once per micro-batch.

    Args:
        df: The micro-batch DataFrame (columns as in ``schema``).
        epoch_id: Micro-batch id supplied by Spark; not used here.

    Side effects:
        Adds this batch's per-hour dropoff counts to the module-level
        ``hour_counts`` dict. Then, for each hour whose total changed, it
        overwrites ``<directory>/output-<(hour + 1) * 3600000>`` with a single
        CSV part file containing an ``hour,count`` header and one data row.
    """
    # Count this batch's rows per hour of the dropoff time. A row whose
    # dropoff_datetime is null (e.g. a value that failed to parse) has no
    # hour. Left in, it would yield hour=None and crash the query at
    # ``None + 1`` when building the output path, so such rows are dropped.
    batch_hour_counts = (
        df.where(f.col("dropoff_datetime").isNotNull())
        .withColumn("hour", f.hour(f.col("dropoff_datetime")))
        .groupBy("hour")
        .count()
        .collect()
    )

    # Update the running totals. hour_counts is only mutated, never rebound,
    # so no ``global`` declaration is needed.
    updated_hours = []
    for row in batch_hour_counts:
        hour = row["hour"]
        hour_counts[hour] = hour_counts.get(hour, 0) + row["count"]
        updated_hours.append(hour)

    # Hours absent from this batch keep the total already written for them,
    # so only the hours that changed need rewriting (each write is a Spark job).
    for hour in updated_hours:
        output_directory = f"{directory}/output-{(hour + 1) * 3600000}"
        row_df = spark.createDataFrame([(hour, hour_counts[hour])], ["hour", "count"])
        # coalesce(1) yields a single part file without triggering a shuffle.
        row_df.coalesce(1).write.csv(output_directory, mode="overwrite", header="true")


# Checkpoint location for the streaming query.
checkpoint_location = "checkpoint"

# The running totals live only in the in-memory ``hour_counts`` dict, and the
# output directory is recreated empty at startup. A checkpoint left over from
# an earlier run would therefore make Spark skip input files it already
# processed, and the fresh totals would undercount. Start from a clean
# checkpoint so each run recounts every input file.
if os.path.exists(checkpoint_location):
    shutil.rmtree(checkpoint_location)

# Configure the streaming query. Each micro-batch is handed to
# process_streaming_data, and progress is recorded under checkpoint_location.
query = (
    streaming_df.writeStream
    .outputMode("update")
    .foreachBatch(process_streaming_data)
    .option("checkpointLocation", checkpoint_location)
)

# Start the query and block until it stops or fails.
query.start().awaitTermination()