# Brief PySpark on Databricks

## 1. Fusionner les données

In [0]:
# Charger les bibliothèques Spark
from pyspark.sql import SparkSession
from pyspark.sql.functions import lit

In [0]:
# Get the active Spark session, or build one named "NettoyageDonnees".
spark = (
    SparkSession.builder
    .appName("NettoyageDonnees")
    .getOrCreate()
)

In [0]:
# Load the 21 monthly Yellow Taxi tables in chronological order:
# df1 = 2024_01 ... df12 = 2024_12, df13 = 2025_01 ... df21 = 2025_09.
(
    df1, df2, df3, df4, df5, df6, df7,
    df8, df9, df10, df11, df12, df13, df14,
    df15, df16, df17, df18, df19, df20, df21,
) = [
    spark.table(f"default.yellow_tripdata_{year}_{month:02d}")
    for year, last_month in ((2024, 12), (2025, 9))
    for month in range(1, last_month + 1)
]

In [0]:
# Every column name that appears in at least one of the 21 monthly tables.
all_columns = set().union(*(
    frame.columns
    for frame in (
        df1, df2, df3, df4, df5, df6, df7,
        df8, df9, df10, df11, df12, df13, df14,
        df15, df16, df17, df18, df19, df20, df21,
    )
))

In [0]:
# Helper: align a DataFrame on a common column set.
def add_missing_cols(df, all_cols):
    """Return `df` restricted/extended to exactly `all_cols`, in sorted name order.

    Names in `all_cols` that `df` lacks are added as null literal columns, so
    frames read from tables with different schemas end up with the same
    columns in the same order and can be combined with a positional `union`.

    Args:
        df: Spark DataFrame to align.
        all_cols: any iterable of column names (set, list, tuple, ...);
            duplicate names are ignored.

    Returns:
        A DataFrame selecting the names of `all_cols`, sorted alphabetically.
    """
    wanted = set(all_cols)  # accept any iterable, not only a set
    for name in sorted(wanted - set(df.columns)):
        # NOTE(review): lit(None) is untyped (NullType); the merge cell relies on
        # Spark widening it to the real type of the other tables on union — confirm.
        df = df.withColumn(name, lit(None))
    return df.select(sorted(wanted))

In [0]:
# Align every monthly table on the same sorted column set, then stack them.
# `union` matches columns by position; this is safe only because
# add_missing_cols returns every frame with the same columns in the same order.
from functools import reduce

aligned_frames = [
    add_missing_cols(frame, all_columns)
    for frame in (
        df1, df2, df3, df4, df5, df6, df7,
        df8, df9, df10, df11, df12, df13, df14,
        df15, df16, df17, df18, df19, df20, df21,
    )
]

# Left fold: ((df1 ∪ df2) ∪ df3) ∪ ... ∪ df21 — same union tree as a manual chain.
df = reduce(lambda left, right: left.union(right), aligned_frames)

# Persist the merged raw data, overwriting any previous run.
df.write \
  .option("mergeSchema", "true") \
  .mode("overwrite") \
  .saveAsTable("default.Raw_table")

In [0]:
# Preview the merged raw DataFrame written to default.Raw_table above.
display(df)

## 2. Nettoyer les données

In [0]:
from pyspark.sql import functions as F

spark.sql("USE default")

# List the tables available in the database
spark.sql("SHOW TABLES").display()

# Load the raw table produced by the merge step
df1 = spark.table("default.Raw_table")

# Records that occur more than once: one output row per distinct duplicated
# record, with `count` = how many identical copies exist.
duplicates = (
    df1.groupBy(df1.columns)
      .count()
      .filter(F.col("count") > 1)
)

display(duplicates)

# One aggregation gives both figures:
#  - number of distinct records that are duplicated;
#  - number of redundant copies, i.e. the rows dropDuplicates() removes.
dup_records, extra_copies = duplicates.agg(
    F.count(F.lit(1)),
    F.sum(F.col("count") - 1),
).first()
print("Nombre de doublons :", dup_records)
print("Lignes en double supprimées :", extra_copies or 0)  # sum is null when there are no duplicates

df_clean = df1.dropDuplicates()
df_clean.write.mode("overwrite").saveAsTable("default.clean_table")

# Explicit 5-row preview of the cleaned data.
display(df_clean.limit(5))
print('Nombre de lignes : ', df_clean.count())

## 3. Créer les tables de résultats des requêtes

In [0]:
# Work in the `default` database for the queries that follow.
spark.sql("USE default")

# List the tables available in the database.
spark.sql("SHOW TABLES").display()

# DataFrame handle on the deduplicated table.
# NOTE(review): the SQL cells below query default.clean_table directly and never use df1.
df1 = spark.table("default.clean_table")

### Identifier les 10 zones de départ les plus fréquentées chaque mois

In [0]:
# Top 10 pickup zones (PULocationID) by trip count for each (year, month).
# RANK() keeps ties, so a month can contain more than 10 rows.
# The table is dropped first: CREATE TABLE IF NOT EXISTS alone would keep
# stale results when the notebook is re-run on new data.
spark.sql("DROP TABLE IF EXISTS default.Top_10_pickup_zones")
spark.sql("""
    CREATE TABLE default.Top_10_pickup_zones AS
    WITH trips_per_month AS (
        SELECT 
            YEAR(tpep_pickup_datetime) AS year,
            MONTH(tpep_pickup_datetime) AS month,
            PULocationID,
            COUNT(*) AS trip_count
        FROM default.clean_table
        GROUP BY year, month, PULocationID
    ),
    ranking AS (
        SELECT 
            year,
            month,
            PULocationID,
            trip_count,
            RANK() OVER (PARTITION BY year, month ORDER BY trip_count DESC) AS rank
        FROM trips_per_month
    )
    SELECT *
    FROM ranking
    WHERE rank <= 10
    ORDER BY year, month, rank
""")
display(spark.table("default.Top_10_pickup_zones"))

### Calculer la durée moyenne des trajets par mois

In [0]:
# Average trip duration per pickup month, in seconds (UNIX_TIMESTAMP difference).
# The month label is zero-padded ('yyyy-MM') so that ORDER BY is chronological;
# CONCAT(YEAR, '-', MONTH) produced '2024-1', '2024-10', ... which sorts
# '2024-10' before '2024-2'.
# The table is dropped first so a re-run rebuilds it from the current clean_table.
spark.sql("DROP TABLE IF EXISTS default.avg_trip_duration_per_month")
spark.sql("""
    CREATE TABLE default.avg_trip_duration_per_month AS
    SELECT 
        DATE_FORMAT(tpep_pickup_datetime, 'yyyy-MM') AS datemonth,
        ROUND(AVG(UNIX_TIMESTAMP(tpep_dropoff_datetime) - UNIX_TIMESTAMP(tpep_pickup_datetime)), 2) AS duration
    FROM default.clean_table
    GROUP BY datemonth
    ORDER BY datemonth
""")

df = spark.table("default.avg_trip_duration_per_month")
display(df)

### Déterminer la distance moyenne par type de paiement

In [0]:
# Average trip distance per payment type, ordered by that average.
# The table is dropped first so a re-run rebuilds it instead of keeping stale results.
spark.sql("DROP TABLE IF EXISTS default.avg_distance_by_payment_type")
spark.sql("""
    CREATE TABLE default.avg_distance_by_payment_type AS
        SELECT payment_type, avg(trip_distance) as avg_distance
        FROM default.clean_table
        GROUP BY payment_type
        ORDER BY avg_distance""")
display(spark.table("default.avg_distance_by_payment_type"))

### Estimer le montant moyen des courses en fonction du nombre de passagers

In [0]:
# Average fare (total_amount, rounded to 2 decimals) per passenger count.
# The DROP runs first, so the CREATE TABLE IF NOT EXISTS always rebuilds the table.
for _statement in (
    """
        DROP TABLE IF EXISTS default.avg_amount_by_passenger_count""",
    """ 
        CREATE TABLE IF NOT EXISTS default.avg_amount_by_passenger_count AS
        SELECT round(avg(total_amount), 2) as avg_total_amount, passenger_count
        FROM default.clean_table
        GROUP BY passenger_count
        ORDER BY passenger_count;""",
):
    spark.sql(_statement)
del _statement

display(spark.table("default.avg_amount_by_passenger_count"))

### Mesurer la somme totale des pourboires versés chaque mois

In [0]:
# Total tips per pickup month, rebuilt on every run (drop, then create).
# The month label is zero-padded ('yyyy-MM') so that ordering is chronological;
# concat(year, '-', month) produced '2024-1', '2024-10', ... which sorts
# '2024-10' before '2024-2'.
spark.sql("""
        DROP TABLE IF EXISTS default.total_tip_per_month""")
spark.sql("""
        CREATE TABLE default.total_tip_per_month AS
        SELECT round(SUM(tip_amount), 2) as total_tip_amount, date_format(tpep_pickup_datetime, 'yyyy-MM') as month_year
        FROM default.clean_table
        GROUP BY date_format(tpep_pickup_datetime, 'yyyy-MM')
        ORDER BY month_year""")
display(spark.table("default.total_tip_per_month"))

## 4. Export

In [0]:

import os

# Credentials for the SQL Server export, read from environment variables.
# NOTE(review): on Databricks a secret scope is usually preferable to plain env vars.
jdbc_username = os.environ.get("LOGIN")
jdbc_password = os.environ.get("PASS")

# Fail fast with an explicit message rather than handing None credentials
# to the JDBC writer in the export loop.
missing_vars = [
    var_name
    for var_name, value in (("LOGIN", jdbc_username), ("PASS", jdbc_password))
    if not value
]
if missing_vars:
    raise RuntimeError(
        "Missing environment variable(s) for the SQL Server export: "
        + ", ".join(missing_vars)
    )

# Only the user name is printed; the password is never echoed.
print("Utilisateur connecté :", jdbc_username)

jdbc_url = "jdbc:sqlserver://cacfsql.database.windows.net:1433;database=brief"

# Shared JDBC properties: credentials plus the SQL Server driver class.
connection_properties = {
    "user": jdbc_username,
    "password": jdbc_password,
    "driver": "com.microsoft.sqlserver.jdbc.SQLServerDriver"
}

In [0]:
# Result tables from section 3, exported to SQL Server under an "lr_" prefix.
tables = ["total_tip_per_month", "avg_distance_by_payment_type", "avg_trip_duration_per_month", "avg_amount_by_passenger_count", "top_10_pickup_zones"]

for table in tables:
    # Dedicated name so the loop does not clobber the notebook-level `df`.
    export_df = spark.table("default." + table)
    # Reuse the connection properties built alongside jdbc_url instead of
    # re-declaring an identical dict for every table.
    export_df.write.mode("overwrite").jdbc(
        url=jdbc_url,
        table="lr_" + table,
        properties=connection_properties,
    )