# In [None]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

input_path = "dbfs:/user/mehak/processed/berlin_clean.csv"

# Start or reuse the Spark session.
spark = SparkSession.builder.appName("TrafficDataProcessing").getOrCreate()

# Load the CSV file into a Spark DataFrame; the first line is treated as the
# header and column types are inferred from the data.
spark_df = spark.read.csv(input_path, header=True, inferSchema=True)

# Show the first 5 rows to verify the data is loaded.
spark_df.show(5)

# Check the schema and data types.
spark_df.printSchema()

# Summary statistics.
spark_df.describe().show()

# Total row count and per-column null counts, computed in ONE aggregation.
# Filtering and counting column by column launches a separate Spark job (and a
# separate scan of the input) for every column; a single agg() does one scan.
# count() skips nulls and when() without otherwise() yields null for
# non-matching rows, so each expression counts exactly the null cells.
null_count_exprs = [
    F.count(F.when(spark_df[column_name].isNull(), 1)).alias(column_name)
    for column_name in spark_df.columns
]
total_rows, *null_counts = spark_df.agg(
    F.count(F.lit(1)).alias("total_rows"), *null_count_exprs
).first()

print(f"Total rows: {total_rows}")
# Pair by position rather than by name so unusual column names cannot collide.
for column_name, null_count in zip(spark_df.columns, null_counts):
    print(column_name, null_count)

# Drop every row that contains a null in any column.
spark_df = spark_df.na.drop()


# Import the functions module under a namespace: `from ... import sum` would
# shadow Python's builtin sum() for the rest of the script.
from pyspark.sql import functions as F
import matplotlib.pyplot as plt

# Distribution of zahl_tvz (traffic volume): summary statistics for the column.
spark_df.describe("zahl_tvz").show()

# Reusable aggregate expression: sum of zahl_tvz, exposed as "total_traffic".
total_traffic = F.sum("zahl_tvz").alias("total_traffic")

# Total traffic by berlin_bez, highest first.
traffic_by_district = spark_df.groupBy("berlin_bez").agg(total_traffic)
traffic_by_district.orderBy("total_traffic", ascending=False).show()

# Average traffic by lor_prg, highest first.
(
    spark_df.groupBy("lor_prg")
    .agg(F.avg("zahl_tvz").alias("avg_traffic"))
    .orderBy("avg_traffic", ascending=False)
    .show()
)

# Top 10 values of the "name" column by total traffic
# (the column appears to contain street names).
(
    spark_df.groupBy("name")
    .agg(total_traffic)
    .orderBy("total_traffic", ascending=False)
    .show(10)
)

# Convert the per-district totals to pandas for plotting.
pdf = traffic_by_district.toPandas()

# Bar chart of total traffic per berlin_bez.
pdf.plot(kind="bar", x="berlin_bez", y="total_traffic", title="Traffic by Berlin Bez")
plt.show()

# Import the functions module under a namespace: `from ... import sum` would
# shadow Python's builtin sum() for the rest of the script.
from pyspark.sql import functions as F
import matplotlib.pyplot as plt

# Reusable aggregate expression: sum of zahl_tvz, exposed as "total_traffic".
total_traffic = F.sum("zahl_tvz").alias("total_traffic")
grouped_by_district = spark_df.groupBy("berlin_bez")

# Total and average traffic by Berlin district (berlin_bez), highest total first.
(
    grouped_by_district
    .agg(total_traffic, F.avg("zahl_tvz").alias("avg_traffic"))
    .orderBy("total_traffic", ascending=False)
    .show()
)

# Top 10 values of the "name" column by total traffic.
(
    spark_df.groupBy("name")
    .agg(total_traffic)
    .orderBy("total_traffic", ascending=False)
    .show(10)
)

# Per-district totals as a pandas DataFrame (columns: berlin_bez, total_traffic).
pdf = grouped_by_district.agg(total_traffic).toPandas()

# Bar chart of total traffic per district.
pdf.plot(kind="bar", x="berlin_bez", y="total_traffic", title="Traffic by Berlin District")
plt.show()

import os
import matplotlib.pyplot as plt

# Ensure the output directory exists before saving.
folder_path = "/dbfs/FileStore/reports/figures"
os.makedirs(folder_path, exist_ok=True)

# Derive the file path from folder_path so the directory that is created and
# the directory that is written to cannot drift apart.
fig_path = os.path.join(folder_path, "traffic_volume_berlin_bez.png")

# Plot total traffic per district (pdf is the pandas DataFrame built earlier
# in the script) and keep a handle on the resulting figure instead of relying
# on pyplot's implicit "current figure".
ax = pdf.plot(kind="bar", x="berlin_bez", y="total_traffic", title="Traffic by Berlin District")
fig = ax.get_figure()

# bbox_inches="tight" keeps the rotated x tick labels inside the saved image.
fig.savefig(fig_path, bbox_inches="tight")

# Release the figure's resources now that it has been written to disk.
plt.close(fig)

print(f"Figure saved to: {fig_path}")