In [None]:
# Filter out irrelevant data

import json
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, max as spark_max, row_number, to_timestamp, date_add
from pyspark.sql.window import Window

# Build the Spark session (getOrCreate reuses an already running one).
# It runs with a local master, pulls in the MongoDB Spark connector as a
# package, and requests 4g for both executor and driver.
spark = (
    SparkSession.builder
    .appName("lvb-spark")
    .config('spark.master', 'local')
    .config('spark.jars.packages', 'org.mongodb.spark:mongo-spark-connector_2.12:3.0.0')
    .config("spark.executor.memory", "4g")
    .config("spark.driver.memory", "4g")
    .getOrCreate()
)

# Read the "departures" collection of the "lvb" database from MongoDB
df = (
    spark.read.format("mongo")
    .options(
        uri="mongodb://mongo:27017/",
        database="lvb",
        collection="departures",
    )
    .load()
)

## METRICS
# Row count of the unfiltered collection; the overall filter statistics
# printed at the end of this cell are computed against this number.
total_rows = df.count()
from pyspark.sql.functions import expr

## TIMEZONE
# Parse plannedWhen, crawlDate and when from string to timestamp, then shift
# each of them by a fixed +1 hour (UTC -> UTC+1).
#
# The columns are addressed through col() rather than being written into a
# SQL string: "when" is also a SQL keyword (CASE WHEN ...), so keeping the
# name out of the SQL parser avoids a parse failure when reserved keywords
# are enforced. Only the interval literal goes through expr().
#
# NOTE(review): the offset is a constant hour, so no daylight-saving rule is
# applied — confirm that this is intended for the date range being processed.
one_hour = expr("INTERVAL 1 HOUR")
for time_col in ("plannedWhen", "crawlDate", "when"):
    df = df.withColumn(time_col, to_timestamp(col(time_col)) + one_hour)

## DATE FILTERING
# Bounds of the observation window; both ends are inclusive
start_date = "2023-12-08 00:00:00"
end_date = "2024-01-14 23:59:59"

# Keep only the rows whose plannedWhen lies within [start_date, end_date]
planned_in_range = col("plannedWhen").between(start_date, end_date)
df = df.filter(planned_in_range)

## DEDUPLICATION
# Rows that share (tripId, stopId, plannedWhen) form one group; inside a
# group the rows are ordered by crawlDate, newest first.
dedup_keys = ["tripId", "stopId", "plannedWhen"]
newest_first = col("crawlDate").desc()
window_spec = Window.partitionBy(*dedup_keys).orderBy(newest_first)

# Number the rows of every group (1 = latest crawlDate)
df = df.withColumn("row_num", row_number().over(window_spec))

# Retain the top-ranked row of each group and drop the helper column again
is_latest = col("row_num") == 1
filtered_df = df.filter(is_latest).drop("row_num")

# Report how many rows the deduplication removed
rows_before = df.count()
rows_after = filtered_df.count()
removed_rows = rows_before - rows_after

print(f"Removed {removed_rows} duplicate rows. {rows_after} rows remain after deduplication.")

## STOPS
df = filtered_df

# Load relevant stops; only the top-level keys of the JSON object are used,
# and they serve as the stop ids to keep.
# The encoding is passed explicitly: without it open() decodes with the
# platform's locale encoding, which can corrupt or reject non-ASCII content
# on systems whose locale is not UTF-8.
with open('data/stops_fromRelevantLines.json', 'r', encoding='utf-8') as f:
    relevant_stops = json.load(f)
relevant_stop_ids = list(relevant_stops.keys())

# Filter departures by relevant stops
# NOTE(review): JSON object keys are always strings — presumably stopId is a
# string column as well; confirm against the collection schema.
filtered_df = df.filter(col("stopId").isin(relevant_stop_ids))

# Count distinct stops before and after the filter
total_stops = df.select("stopId").distinct().count()
remaining_stops = filtered_df.select("stopId").distinct().count()
filtered_stops = total_stops - remaining_stops

print(f"Filtered out {filtered_stops} stops. {remaining_stops} stops remain.")

## LINES
df = filtered_df

# Load relevant lines; only the top-level keys of the JSON object are used,
# and they serve as the line ids to keep.
# The encoding is passed explicitly: without it open() decodes with the
# platform's locale encoding, which can corrupt or reject non-ASCII content
# on systems whose locale is not UTF-8.
with open('data/lines_with_stops.json', 'r', encoding='utf-8') as f:
    relevant_lines = json.load(f)
relevant_line_ids = list(relevant_lines.keys())

# Distinct lines present before the filter is applied
total_lines = df.select("lineId").distinct().count()

# Filter by relevant lines
# NOTE(review): JSON object keys are always strings — presumably lineId is a
# string column as well; confirm against the collection schema.
filtered_df = df.filter(col("lineId").isin(relevant_line_ids))

# Count filtered lines
remaining_lines = filtered_df.select("lineId").distinct().count()
filtered_lines = total_lines - remaining_lines

print(f"Filtered out {filtered_lines} lines. {remaining_lines} lines remain.")

## HANDLE NULLS
# Null values in the delay column are replaced by 0
null_defaults = {'delay': 0}
filtered_df = filtered_df.fillna(null_defaults)


## METRICS
# Overall reduction, measured against the row count taken right after loading
filtered_rows = filtered_df.count()
print(f"Total Filter Stats: Filtered out {total_rows - filtered_rows} rows. {filtered_rows} rows remain.")

# Write the result as Parquet, replacing any previous output at that path
output_path = "data/filtered_01.parquet"
parquet_writer = filtered_df.write.mode("overwrite")
parquet_writer.parquet(output_path)

print(f"Filtered data saved to {output_path}")

# Preview the first rows of the filtered data
filtered_df.show(5)


In [None]:

# Load the Parquet output back in to inspect what was written
parquet_df = spark.read.parquet(output_path)

# First rows of the Parquet data
print("Sample of data from the Parquet file:")
parquet_df.show(5)

print("Schema of the Parquet file:")
parquet_df.printSchema()

# Optional: basic statistics as produced by describe()
print("Summary statistics:")
parquet_df.describe().show()

# Number of rows in the Parquet output
total_rows = parquet_df.count()

# Number of rows in the source MongoDB collection (read again from scratch)
mongo_count = (
    spark.read.format("mongo")
    .option("uri", "mongodb://mongo:27017/")
    .option("database", "lvb")
    .option("collection", "departures")
    .load()
    .count()
)

# Share of source rows that did not make it into the Parquet output;
# reported as 0 when the source collection is empty to avoid dividing by zero.
if mongo_count > 0:
    reduced_percentage = ((mongo_count - total_rows) / mongo_count) * 100
else:
    reduced_percentage = 0
print(f"Total number of rows: {total_rows}")
print(f"Percentage of reduced data rows: {reduced_percentage:.2f}%")