In [None]:
import os

import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import avg, col, when

In [None]:
# Build (or reuse, if one is already active) the Spark session named
# "ETL-local" that the rest of the notebook relies on.
session_builder = SparkSession.builder.appName("ETL-local")
spark = session_builder.getOrCreate()

In [None]:
# Load the raw users file: the first row is treated as the header and Spark
# infers the column types from the data.
csv_reader = spark.read.option("header", True).option("inferSchema", True)
dataframe = csv_reader.csv("./data/users.csv")

# Preview the first five rows.
dataframe.show(5)

In [None]:
# Keep only the rows that have an age and whose country is not the
# placeholder value "Unknown".
has_age = col("age").isNotNull()
known_country = col("country") != "Unknown"
dataframe_clean = dataframe.where(has_age & known_country)

In [None]:
# Average age and average salary per country, computed on the cleaned rows.
# Explicit aliases replace the auto-generated "avg(age)" / "avg(salary)"
# column names: names containing parentheses are rejected by the Parquet
# writer on older Spark releases (this frame is written to Parquet further
# down), and plain identifiers are easier to reference downstream.
dataframe_country = dataframe_clean.groupBy("country").agg(
    avg("age").alias("avg_age"),
    avg("salary").alias("avg_salary"),
)
dataframe_country.show()

In [None]:
# Bucket every user into an age category:
#   age < 30  -> "jeune"
#   age < 45  -> "adulte"
#   otherwise -> "senior"
# The column is derived from dataframe_clean rather than the raw frame:
# with a null age both comparisons evaluate to null, so such a row would fall
# through to otherwise() and be mislabelled "senior". Starting from the
# cleaned frame also keeps the "Unknown" country out of the breakdown,
# consistent with dataframe_country.
dataframe = dataframe_clean.withColumn(
    "age_category",
    when(col("age") < 30, "jeune")
    .when(col("age") < 45, "adulte")
    .otherwise("senior")
)

In [None]:
# Mean salary for every (country, age category) pair, printed to the cell output.
salary_by_segment = dataframe.groupBy("country", "age_category").avg("salary")
salary_by_segment.show()

In [None]:
# Make sure the output folder exists, then persist the per-country statistics
# as Parquet, replacing the result of any previous run.
os.makedirs("./data/output", exist_ok=True)
country_writer = dataframe_country.write.mode("overwrite")
country_writer.parquet("./data/output/country_stats.parquet")

In [None]:
# Read the Parquet output back with pandas (pyarrow engine) to check what was
# written; the path is the directory produced by the Spark write above.
read_parquet = pd.read_parquet(
    path="./data/output/country_stats.parquet/",
    engine="pyarrow",
)
read_parquet

In [None]:
# Shut down the Spark session now that the notebook is finished with it.
spark.stop()