In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StringType, IntegerType, DoubleType
from pyspark.sql.functions import lower, lit, regexp_replace, col, contains
import os
from custom_utils import *

In [None]:
# Spark session shared by every cell below; getOrCreate() reuses a session
# that is already active instead of starting a second one.
spark = (
    SparkSession.builder
    .appName("preprocessing-notebook")
    .getOrCreate()
)

In [None]:
# Directory with the station CSV dumps for the configured year.
# NOTE(review): `project_base_dir`, `year`, `month_strings` and
# `recursive_file_retrieval` are not defined in this notebook; presumably they
# come from the `custom_utils` star import -- confirm there.
stations_dir = os.path.join(project_base_dir, f"tankerkoenig-data/stations/{year}/")
station_files = recursive_file_retrieval(stations_dir, month_strings)

In [None]:
# Explicit schema for the station CSV files; every column is nullable.
#
# house_number and post_code are deliberately read as strings: with an integer
# type, values such as "12a" or "5-7" cannot be parsed and become null, and
# post codes lose their leading zeros (e.g. "01067" -> 1067).
schema = (
    StructType()
    .add("uuid", StringType(), True)
    .add("name", StringType(), True)
    .add("brand", StringType(), True)
    .add("street", StringType(), True)
    .add("house_number", StringType(), True)
    .add("post_code", StringType(), True)
    .add("city", StringType(), True)
    .add("latitude", DoubleType(), True)
    .add("longitude", DoubleType(), True)
    .add("first_active", StringType(), True)
    .add("openingtimes_json", StringType(), True)
)

In [None]:
# Read every station file into a single DataFrame. The first line of each
# file is treated as a header, and the explicit schema is applied instead of
# inferring column types.
station_reader = (
    spark.read
    .format("csv")
    .schema(schema)
    .option("header", True)
)
all_stations = station_reader.load(station_files)

In [None]:
# Sanity check of the raw load: a ten-row preview, the (column, type) pairs
# and the total number of rows read.
all_stations.show(n=10)
station_dtypes = all_stations.dtypes
print(station_dtypes)
station_row_count = all_stations.count()
print(station_row_count)

In [None]:
# Keep only the stations of the configured brand. The brand column is
# lower-cased into a temporary helper column for the comparison, and that
# helper is removed again afterwards.
# NOTE(review): `brand` is not defined in this notebook (presumably from
# custom_utils); the comparison only matches if it is already lower-case.
lower_brand = "lower_brand"
brand_stations = (
    all_stations
    .withColumn(lower_brand, lower(col("brand")))
    .where(col(lower_brand) == brand)
    .drop(lower_brand)
)

In [None]:
# Preview the first ten stations left after the brand filter.
brand_stations.show(n=10)

In [None]:
# Add a normalised city column: lower-cased, with "ü" written as "ue".
# Only "ü" is transliterated here; other umlauts and "ß" are left unchanged.
standardised_city = "standardised_city"
lower_city = lower(col("city"))
standardised_stations = brand_stations.withColumn(
    standardised_city,
    regexp_replace(lower_city, lit("ü"), lit("ue")),
)

In [None]:
# Preview the first ten rows including the new standardised city column.
standardised_stations.show(n=10)

In [None]:
# Build one alternation pattern of the form "(a|b|c)" from the configured
# cities and keep the stations whose standardised city matches it.
# rlike() matches anywhere inside the value, so a city name also matches
# longer names that contain it; the city names are inserted into the pattern
# unescaped.
# NOTE(review): `cities` is not defined in this notebook (presumably from
# custom_utils) -- confirm its entries use the same standardised spelling.
city_alternatives = "|".join(cities)
city_rlike = f"({city_alternatives})"
matches_selected_city = col(standardised_city).rlike(city_rlike)
selected_city_stations = standardised_stations.filter(matches_selected_city)

In [None]:
# Preview the first ten stations located in one of the selected cities.
selected_city_stations.show(n=10)

In [None]:
# Frankfurt an der Oder also matches the "frankfurt" pattern and has to be
# removed. It is identified here as any "frankfurt" station whose latitude is
# above 51.5.
is_frankfurt = col(standardised_city).contains("frankfurt")
is_north_of_cutoff = col("latitude") > 51.5
selected_city_stations_clean = selected_city_stations.filter(
    ~(is_frankfurt & is_north_of_cutoff)
)

In [None]:
# Preview the cleaned selection and report how many rows remain.
# (A stray trailing comma after show() used to turn the statement into a
# throwaway one-element tuple; it has been removed.)
selected_city_stations_clean.show(10)
print(selected_city_stations_clean.count())

In [None]:
# Reduce to one row per station uuid and discard the descriptive columns that
# are not needed in the output.
# NOTE(review): when a uuid occurs several times, which of its rows is kept
# is not controlled here -- confirm that is acceptable for latitude/longitude.
columns_to_discard = [
    "name",
    "brand",
    "street",
    "house_number",
    "post_code",
    "city",
    "first_active",
    "openingtimes_json",
]
selected_city_stations_unique = (
    selected_city_stations_clean
    .dropDuplicates(["uuid"])
    .drop(*columns_to_discard)
)

In [None]:
# Count the unique stations and list all of them.
# truncate=False is required because show() clips string cells to 20
# characters by default, which would cut off the uuid values.
n_stations = selected_city_stations_unique.count()
selected_city_stations_unique.show(n_stations, truncate=False)
print(n_stations)

In [None]:
# Persist the unique stations as CSV. Spark writes a directory of part files
# at this path, without a header row.
# mode="overwrite" makes the cell re-runnable: with the default save mode the
# write raises an error as soon as the output path already exists.
output_path = os.path.join(project_base_dir, "outputs/selected_stations_unique.csv")
selected_city_stations_unique.write.csv(output_path, mode="overwrite")