In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StringType, TimestampType, DoubleType, IntegerType
from pyspark.sql.window import Window
import pyspark.sql.functions as f
import os
from custom_utils import *

In [None]:
# Obtain the Spark session for this notebook (reuses an existing one if present).
spark = (
    SparkSession.builder
    .appName("merge_stations_and_fuel_data-notebook")
    .getOrCreate()
)

In [None]:
# Collect the price files below <project_base_dir>/tankerkoenig-data/prices/<year>/.
# NOTE(review): recursive_file_retrieval, project_base_dir, year and month_strings
# come from the star import of custom_utils; presumably month_strings restricts
# the result to the selected months -- confirm against custom_utils.
fuel_files = recursive_file_retrieval(
    os.path.join(project_base_dir, f"tankerkoenig-data/prices/{year}/"),
    month_strings,
)

In [None]:
# Sanity check: number of entries returned by recursive_file_retrieval.
len(fuel_files)

In [None]:
# Explicit schema for the price CSV files; every column is nullable.
fuel_schema = (
    StructType()
    .add("date", TimestampType(), nullable=True)
    .add("station_uuid", StringType(), nullable=True)
    .add("diesel", DoubleType(), nullable=True)
    .add("e5", DoubleType(), nullable=True)
    .add("e10", DoubleType(), nullable=True)
    .add("dieselchange", IntegerType(), nullable=True)
    .add("e5change", IntegerType(), nullable=True)
    .add("e10change", IntegerType(), nullable=True)
)

In [None]:
# Read all collected price files into one DataFrame, skipping each file's
# header row and applying the explicit schema instead of inferring types.
all_fuel_data = (
    spark.read
    .format("csv")
    .schema(fuel_schema)
    .option("header", True)
    .load(fuel_files)
)

In [None]:
# Preview the first rows of the raw price data.
all_fuel_data.show(n=10)

In [None]:
# Explicit schema for the selected-stations CSV; every column is nullable.
stations_schema = (
    StructType()
    .add("station_uuid", StringType(), nullable=True)
    .add("latitude", DoubleType(), nullable=True)
    .add("longitude", DoubleType(), nullable=True)
    .add("city", StringType(), nullable=True)
)

In [None]:
# Load the previously selected stations using the explicit schema.
# NOTE(review): unlike the price files, no "header" option is set here, so the
# file is read as header-less -- confirm that selected_stations_unique.csv
# really has no header row.
stations_data = (
    spark.read
    .format("csv")
    .schema(stations_schema)
    .load(os.path.join(project_base_dir, "outputs/selected_stations_unique.csv"))
)

In [None]:
# Preview the station table (20 rows is the default of DataFrame.show).
stations_data.show(n=20)

In [None]:
# Inner join on station_uuid: keeps only price rows of stations present in
# stations_data and attaches that table's columns to each price row.
joined_data = all_fuel_data.join(stations_data, on=["station_uuid"], how="inner")

In [None]:
# Preview the joined price/station rows.
joined_data.show(n=10)

In [None]:
# Keep only rows where the "<fuel_type>change" flag is positive and the price of
# the selected fuel is positive, then reduce to the columns needed downstream.
# The original timestamp column "date" is exposed as "dateTime".
filtered_data = (
    joined_data
    .filter(f.col(f"{fuel_type}change") > 0)
    .filter(f.col(fuel_type) > 0)
    .select(
        f.col("date").alias("dateTime"),
        "station_uuid",
        fuel_type,
        "latitude",
        "longitude",
    )
)

In [None]:
# Row count after filtering, followed by a preview.
# Written as two statements: the former tuple expression made the cell echo a
# meaningless "(None, None)" result.
print(filtered_data.count())
filtered_data.show(10)

In [None]:
# Add the calendar date of each price change as its own "date" column.
with_date_data = filtered_data.withColumn("date", f.to_date("dateTime"))

In [None]:
# Assign every price change to its nearest full hour (minutes >= 30 round up).
#
# The rounding is applied to the timestamp itself, so a change at e.g. 23:45 is
# moved to 00:00 of the following day. "date" is therefore re-derived from the
# rounded timestamp; otherwise such a row would carry hour 0 together with the
# previous day's date and be mixed up with that day's real midnight slot.
#
# All column references use f.col: a bare `col` is not imported explicitly in
# this notebook.
date_and_time_data = (
    with_date_data
    # drop seconds so that only whole minutes take part in the rounding
    .withColumn("truncated_timestamp", f.date_trunc("minute", f.col("dateTime")))
    .withColumn("minutes", f.minute(f.col("truncated_timestamp")))
    # 0 for minutes < 30, 60 otherwise
    .withColumn("new_minutes", f.round(f.col("minutes") / 60) * 60)
    # signed offset in seconds from the truncated timestamp to the nearest hour
    .withColumn("add_seconds", (f.col("new_minutes") - f.col("minutes")) * 60)
    .withColumn(
        "new_timestamp",
        f.from_unixtime(f.unix_timestamp("truncated_timestamp") + f.col("add_seconds")),
    )
    .withColumn("hour", f.hour(f.col("new_timestamp")))
    # keep (date, hour) consistent with the rounded timestamp
    .withColumn("date", f.to_date(f.col("new_timestamp")))
    .drop("truncated_timestamp", "minutes", "new_minutes", "add_seconds", "new_timestamp")
)

In [None]:
# Show 10 randomly chosen rows, then the total row count.
# Written as two statements: the former tuple expression made the cell echo a
# meaningless "(None, None)" result.
date_and_time_data.orderBy(f.rand()).limit(10).show()
print(date_and_time_data.count())

In [None]:
# If a station has several price changes within the same (date, hour) slot,
# keep only the earliest one.
# The window is ordered by "dateTime": ordering by "hour" (a partition key)
# would make every row in a partition tie, so row_number() would pick an
# arbitrary row that may differ between runs.
w2 = Window.partitionBy(["station_uuid", "date", "hour"]).orderBy(f.col("dateTime"))
deduplicated_timeslot_data = (
    date_and_time_data
    .withColumn("row", f.row_number().over(w2))
    .filter(f.col("row") == 1)
    .drop("row")
)

In [None]:
# Preview the de-duplicated rows, then print how many remain.
# Written as two statements: the former tuple expression made the cell echo a
# meaningless "(None, None)" result.
deduplicated_timeslot_data.show(10)
print(deduplicated_timeslot_data.count())

In [None]:
# Add the day of the week as an integer: 1 = Sunday, 2 = Monday, ..., 7 = Saturday.
# f.dayofweek is used instead of date_format(..., "F"): the "F" pattern letter
# is a day-of-week-in-month style field, not the day of the week, so it never
# produced a usable weekday value.
data_with_weekdays = deduplicated_timeslot_data.withColumn(
    "weekday", f.dayofweek(f.col("date"))
)

In [None]:
# Preview the rows including the new weekday column.
data_with_weekdays.show(n=10)

In [None]:
# Per-station rolling window, ordered by time, ending at the current row.
# NOTE(review): the frame spans the current row plus the rolling_window_size
# rows before it, i.e. rolling_window_size + 1 rows -- confirm this is intended.
w = (
    Window.partitionBy("station_uuid")
    .orderBy("dateTime")
    .rowsBetween(-rolling_window_size, Window.currentRow)
)

# Replace the absolute price with its deviation from the station's rolling mean;
# the raw price and the helper mean column are dropped afterwards.
deviation_data = (
    data_with_weekdays
    .withColumn("rolling_price_mean", f.avg(f.col(fuel_type)).over(w))
    .withColumn("deviation", f.col(fuel_type) - f.col("rolling_price_mean"))
    .drop(fuel_type, "rolling_price_mean")
)

In [None]:
# Preview the final deviation table.
deviation_data.show(n=10)

In [None]:
# Persist the preprocessed data as CSV (Spark writes a directory of part files
# at this path). mode("overwrite") makes the notebook re-runnable: with the
# default mode ("errorifexists") every run after the first fails because the
# output path already exists.
(
    deviation_data.write
    .mode("overwrite")
    .option("header", True)
    .csv(os.path.join(project_base_dir, "outputs/preprocessed_price_data.csv"))
)