In [2]:
# Filter out irrelevant data

import json
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, max as spark_max, row_number, to_timestamp, date_add
from pyspark.sql.window import Window

# Initialize SparkSession
spark = SparkSession.builder \
    .appName("lvb-spark") \
    .config('spark.master', 'local') \
    .config('spark.jars.packages', 'org.mongodb.spark:mongo-spark-connector_2.12:3.0.0') \
    .getOrCreate()

# Load departures data from MongoDB
df = spark.read.format("mongo") \
    .option("uri", "mongodb://mongo:27017/") \
    .option("database", "lvb") \
    .option("collection", "departures") \
    .load()

## METRICS
total_rows = df.count()

## TIMEZONE
# Shift time from UTC to UTC+1
df = df.withColumn("plannedWhen", date_add(col("plannedWhen"), 1))
df = df.withColumn("crawlDate", date_add(col("crawlDate"), 1))
df = df.withColumn("when", date_add(col("when"), 1))

## DATE FILTERING
# Define the date range
start_date = "2023-12-08"
end_date = "2024-01-14"

# Convert plannedWhen to timestamp and filter by date range
df = df.withColumn("plannedWhen", to_timestamp(col("plannedWhen")))
df = df.filter((col("plannedWhen") >= start_date) & (col("plannedWhen") <= end_date))

## DEDUPLICATION
window_spec = Window.partitionBy("tripId", "stopId", "plannedWhen").orderBy(col("crawlDate").desc())

# Add row numbers within each window
df = df.withColumn("row_num", row_number().over(window_spec))

# Keep only the first row (latest crawlDate) for each window
filtered_df = df.filter(col("row_num") == 1).drop("row_num")

# Count the number of rows before and after deduplication
rows_before = df.count()
rows_after = filtered_df.count()
removed_rows = rows_before - rows_after

print(f"Removed {removed_rows} duplicate rows. {rows_after} rows remain after deduplication.")

## STOPS
df = filtered_df

# Load relevant stops
with open('data/stops_fromRelevantLines.json', 'r') as f:
    relevant_stops = json.load(f)
relevant_stop_ids = list(relevant_stops.keys())

# Filter departures by relevant stops
filtered_df = df.filter(col("stopId").isin(relevant_stop_ids))

# Count filtered stops
total_stops = df.select("stopId").distinct().count()
remaining_stops = filtered_df.select("stopId").distinct().count()
filtered_stops = total_stops - remaining_stops

print(f"Filtered out {filtered_stops} stops. {remaining_stops} stops remain.")

## LINES
df = filtered_df

# Load relevant lines
with open('data/lines_with_stops.json', 'r') as f:
    relevant_lines = json.load(f)
relevant_line_ids = list(relevant_lines.keys())

total_lines = df.select("lineId").distinct().count()

# filter by relevant lines
filtered_df = df.filter(col("lineId").isin(relevant_line_ids))

# Count filtered lines
remaining_lines = filtered_df.select("lineId").distinct().count()
filtered_lines = total_lines - remaining_lines

print(f"Filtered out {filtered_lines} lines. {remaining_lines} lines remain.")

## METRICS
filtered_rows = filtered_df.count()
print(f"Total Filter Stats: Filtered out {total_rows - filtered_rows} rows. {filtered_rows} rows remain.")

# Output as Parquet
output_path = "data/filtered_01.parquet"
filtered_df.write.mode("overwrite").parquet(output_path)

print(f"Filtered data saved to {output_path}")

# Show sample of the filtered data
filtered_df.show(5)


                                                                                

Removed 1956332 duplicate rows. 4840435 rows remain after deduplication.


                                                                                

Filtered out 134 stops. 659 stops remain.


                                                                                

Filtered out 0 lines. 52 lines remain.


                                                                                

Total Filter Stats: Filtered out 3555799 rows. 4120672 rows remain.


                                                                                

Filtered data saved to data/filtered_01.parquet


[Stage 123:>                                                        (0 + 1) / 1]

+---+--------------------+----------+-----+--------------------+-----------+-------------------+------+--------------------+----------+
|__v|                 _id| crawlDate|delay|           direction|     lineId|        plannedWhen|stopId|              tripId|      when|
+---+--------------------+----------+-----+--------------------+-----------+-------------------+------+--------------------+----------+
|  0|{657e724cd9f42ed9...|2023-12-18|  120|          Stötteritz| 8-naslvt-4|2023-12-18 00:00:00|958145|1|1000039|0|81|17...|2023-12-18|
|  0|{657bf8c9da53b8a4...|2023-12-16|  120|BMW Zentralgebäud...|5-naslvb-82|2023-12-16 00:00:00|955799|1|1000092|0|81|15...|2023-12-16|
|  0|{657bfb98e996a6f5...|2023-12-16|  120|BMW Zentralgebäud...|5-naslvb-82|2023-12-16 00:00:00|955815|1|1000092|0|81|15...|2023-12-16|
|  0|{657bfaec1545b917...|2023-12-16|  120|BMW Zentralgebäud...|5-naslvb-82|2023-12-16 00:00:00|956123|1|1000092|0|81|15...|2023-12-16|
|  0|{657bfb99e996a6f5...|2023-12-16|  180|BMW Z

                                                                                

In [None]:

# Read the saved Parquet file
parquet_df = spark.read.parquet(output_path)

# Show sample of the filtered data from the Parquet file
print("Sample of data from the Parquet file:")
parquet_df.show(5)

print("Schema of the Parquet file:")
parquet_df.printSchema()

# Optional: Get some basic statistics
print("Summary statistics:")
parquet_df.describe().show()
# Count total number of rows
total_rows = parquet_df.count()
mongo_count = spark.read.format("mongo") \
    .option("uri", "mongodb://mongo:27017/") \
    .option("database", "lvb") \
    .option("collection", "departures") \
    .load().count()

# Calculate the percentage of reduced data rows
reduced_percentage = ((mongo_count - total_rows) / mongo_count) * 100 if mongo_count > 0 else 0
print(f"Total number of rows: {total_rows}")
print(f"Percentage of reduced data rows: {reduced_percentage:.2f}%")