In [1]:
import os
os.environ["JAVA_HOME"] = "/opt/homebrew/opt/openjdk@17/libexec/openjdk.jdk/Contents/Home"  # Set JAVA_HOME environment variable

In [2]:
# Imports nécessaires
import os
import csv
import io
import glob
import builtins
import zipfile
import json
import gc
from pathlib import Path
from datetime import datetime
from collections import defaultdict
import pyarrow.parquet as pq
import warnings
warnings.filterwarnings('ignore')

# PySpark imports
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql.window import Window
from pyspark.sql import DataFrame

# Pour l'affichage
import pandas as pd
pd.set_option('display.max_columns', None)
pd.set_option('display.max_colwidth', None)

# Liste des pays de l'Union Européenne (27 membres)
EU_COUNTRIES = {
    'AT', 'BE', 'BG', 'CY', 'CZ', 'DE', 'DK', 'EE', 'ES', 'FI', 
    'FR', 'GR', 'HR', 'HU', 'IE', 'IT', 'LT', 'LU', 'LV', 'NL', 
    'PL', 'PT', 'RO', 'SE', 'SI', 'SK'
}

print("Imports réalisés avec succès")

Imports réalisés avec succès


In [3]:
# Configuration Spark optimisée pour MacBook Pro M3 Max (36GB, 14 cœurs)
# Rappels : réserver ~8-10GB pour macOS, utiliser tous les cœurs disponibles
spark = SparkSession.builder \
    .appName("ObRail_GTFS_Analysis") \
    .master("local[14]") \
    .config("spark.driver.memory", "30g") \
    .config("spark.driver.maxResultSize", "6g") \
    .config("spark.sql.shuffle.partitions", "56") \
    .config("spark.default.parallelism", "56") \
    .config("spark.sql.adaptive.enabled", "true") \
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true") \
    .config("spark.sql.adaptive.skewJoin.enabled", "true") \
    .config("spark.sql.files.maxPartitionBytes", "134217728") \
    .config("spark.sql.execution.arrow.pyspark.enabled", "true") \
    .config("spark.memory.fraction", "0.75") \
    .config("spark.memory.storageFraction", "0.3") \
    .config("spark.sql.autoBroadcastJoinThreshold", "20971520") \
    .config("spark.local.dir", "/tmp/spark-temp") \
    .config("spark.ui.showConsoleProgress", "false") \
    .config("spark.sql.session.timeZone", "UTC") \
    .getOrCreate()

# Configuration du niveau de log
spark.sparkContext.setLogLevel("WARN")

print(f"✓ Spark Session créée avec succès")
print(f"  - Version Spark : {spark.version}")
print(f"  - Master : {spark.sparkContext.master}")
print(f"  - Mémoire Driver : 26 GB")
print(f"  - Cœurs utilisés : 14")
print(f"  - Partitions shuffle : 56")
print(f"  - Application ID : {spark.sparkContext.applicationId}")

Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
26/02/24 11:41:21 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
26/02/24 11:41:21 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in standalone/kubernetes and LOCAL_DIRS in YARN).


✓ Spark Session créée avec succès
  - Version Spark : 4.1.1
  - Master : local[14]
  - Mémoire Driver : 26 GB
  - Cœurs utilisés : 14
  - Partitions shuffle : 56
  - Application ID : local-1771929682077


In [4]:
# Chargement des DataFrames Parquet
enriched_path = './processed/gtfs_enriched/'
botn_cleaned_path = './processed/botn_cleaned.parquet'

# Chargement du DataFrame gtfs_enriched (dossier Parquet)
df_enriched = spark.read.parquet(enriched_path)
print('Schéma de df_enriched :')
df_enriched.printSchema()

# Chargement du DataFrame botn_cleaned (fichier Parquet unique)
df_botn_cleaned = spark.read.parquet(botn_cleaned_path)
print('Schéma de df_botn_cleaned :')
df_botn_cleaned.printSchema()

Schéma de df_enriched :
root
 |-- source: string (nullable = true)
 |-- stop_id: string (nullable = true)
 |-- trip_id: string (nullable = true)
 |-- route_id: string (nullable = true)
 |-- service_id: string (nullable = true)
 |-- route_type: integer (nullable = true)
 |-- route_short_name: string (nullable = true)
 |-- route_long_name: string (nullable = true)
 |-- trip_headsign: string (nullable = true)
 |-- trip_short_name: string (nullable = true)
 |-- agency_id: string (nullable = true)
 |-- agency_name: string (nullable = true)
 |-- agency_timezone: string (nullable = true)
 |-- stop_name: string (nullable = true)
 |-- stop_lat: double (nullable = true)
 |-- stop_lon: double (nullable = true)
 |-- parent_station: string (nullable = true)
 |-- arrival_time: string (nullable = true)
 |-- departure_time: string (nullable = true)
 |-- stop_sequence: integer (nullable = true)
 |-- start_date: date (nullable = true)
 |-- end_date: date (nullable = true)
 |-- segment_dist_m: double (nu

In [5]:
# ═══════════════════════════════════════════════════════════
# 1. AJOUT is_night_train
# ═══════════════════════════════════════════════════════════
from pyspark.sql import functions as F

df_enriched = df_enriched.withColumn("is_night_train",
    F.when(F.col("route_type") == 105, True).otherwise(False)
)

df_botn_cleaned = df_botn_cleaned.withColumn("is_night_train", F.lit(True))

print(f"GTFS is_night_train=True  : {df_enriched.where(F.col('is_night_train')).count():,}")
print(f"GTFS is_night_train=False : {df_enriched.where(~F.col('is_night_train')).count():,}")
print(f"BOTN is_night_train=True  : {df_botn_cleaned.count():,}")


GTFS is_night_train=True  : 4,355
GTFS is_night_train=False : 23,245,345
BOTN is_night_train=True  : 3,265


In [6]:

# ═══════════════════════════════════════════════════════════
# 2. PROFIL DES TRIPS BOTN vs GTFS — Identifiants disponibles
# ═══════════════════════════════════════════════════════════

# Résumé au niveau trip pour chaque source
df_gtfs_trips = (
    df_enriched
    .groupBy("trip_id")
    .agg(
        F.first("agency_name").alias("agency_name"),
        F.first("agency_id").alias("agency_id"),
        F.first("route_short_name").alias("route_short_name"),
        F.first("trip_short_name").alias("trip_short_name"),
        F.first("trip_headsign").alias("trip_headsign"),
        F.first("source").alias("source"),
        F.first("country").alias("origin_country"),
        F.min("stop_sequence").alias("first_seq"),
        F.max("stop_sequence").alias("last_seq"),
        F.count("*").alias("n_stops"),
    )
)

df_botn_trips = (
    df_botn_cleaned
    .groupBy("trip_id")
    .agg(
        F.first("agency_name").alias("agency_name"),
        F.first("agency_id").alias("agency_id"),
        F.first("route_short_name").alias("route_short_name"),
        F.first("trip_short_name").alias("trip_short_name"),
        F.first("trip_headsign").alias("trip_headsign"),
        F.first("country").alias("origin_country"),
        F.count("*").alias("n_stops"),
    )
)

print(f"\nTrips GTFS  : {df_gtfs_trips.count():,}")
print(f"Trips BOTN  : {df_botn_trips.count()}")

# ═══════════════════════════════════════════════════════════
# 3. MATCHING MULTI-CRITÈRES
# ═══════════════════════════════════════════════════════════

# --- 3a. Match exact trip_id ---
match_trip_id = df_botn_trips.join(
    df_gtfs_trips.select(F.col("trip_id").alias("gtfs_trip_id")),
    F.col("trip_id") == F.col("gtfs_trip_id")
).count()
print(f"\n=== MATCHING ===")
print(f"Match exact trip_id                : {match_trip_id}")

# --- 3b. Match exact trip_short_name ---
match_tsn = (
    df_botn_trips.alias("b")
    .join(
        df_gtfs_trips.select("trip_short_name").distinct().alias("g"),
        F.col("b.trip_short_name") == F.col("g.trip_short_name")
    )
    .select("b.trip_id").distinct().count()
)
print(f"Match exact trip_short_name        : {match_tsn}")

# --- 3c. Match exact route_short_name ---
match_rsn = (
    df_botn_trips.alias("b")
    .join(
        df_gtfs_trips.select("route_short_name").distinct().alias("g"),
        F.col("b.route_short_name") == F.col("g.route_short_name")
    )
    .select("b.trip_id").distinct().count()
)
print(f"Match exact route_short_name       : {match_rsn}")

# --- 3d. Match agency_name (fuzzy : premier token) ---
df_botn_ag = df_botn_trips.withColumn("botn_ag_token", F.split("agency_name", " ")[0])
df_gtfs_ag = df_gtfs_trips.withColumn("gtfs_ag_token", F.split("agency_name", " ")[0]).select("gtfs_ag_token").distinct()
match_agency = (
    df_botn_ag.join(df_gtfs_ag, F.col("botn_ag_token") == F.col("gtfs_ag_token"))
    .select("trip_id").distinct().count()
)
print(f"Match agency premier token         : {match_agency}")

# --- 3e. Match trip_headsign (destination identique) ---
match_headsign = (
    df_botn_trips.alias("b")
    .join(
        df_gtfs_trips.select("trip_headsign").distinct().alias("g"),
        F.col("b.trip_headsign") == F.col("g.trip_headsign")
    )
    .select("b.trip_id").distinct().count()
)
print(f"Match exact trip_headsign          : {match_headsign}")

# --- 3f. Quels opérateurs BOTN sont dans GTFS ? ---
print(f"\n--- Opérateurs BOTN présents dans GTFS (agency_name token) ---")
(
    df_botn_trips.withColumn("ag_token", F.split("agency_name", " ")[0])
    .select("ag_token").distinct()
    .join(
        df_gtfs_trips.withColumn("ag_token", F.split("agency_name", " ")[0])
        .select("ag_token").distinct(),
        "ag_token", "left"
    )
    .withColumn("in_gtfs", F.when(F.col("ag_token").isNotNull(), "✓").otherwise("✗"))
    .orderBy("ag_token")
    .show(40, truncate=False)
)


Trips GTFS  : 1,682,011
Trips BOTN  : 408

=== MATCHING ===
Match exact trip_id                : 0
Match exact trip_short_name        : 144
Match exact route_short_name       : 12
Match agency premier token         : 128
Match exact trip_headsign          : 126

--- Opérateurs BOTN présents dans GTFS (agency_name token) ---
+-------------------+-------+
|ag_token           |in_gtfs|
+-------------------+-------+
|Astra              |✓      |
|C.F.R.             |✓      |
|Calea              |✓      |
|Caledonian         |✓      |
|European           |✓      |
|First              |✓      |
|Go-Ahead           |✓      |
|HŽ                 |✓      |
|Merresor           |✓      |
|MÁV-START          |✓      |
|Optima             |✓      |
|PKP                |✓      |
|RegioJet           |✓      |
|SJ                 |✓      |
|SNCF               |✓      |
|TCDD               |✓      |
|Train4you          |✓      |
|Trenitalia         |✓      |
|VR-Yhtymä          |✓      |
|Vygruppen   

In [7]:
# ═══════════════════════════════════════════════════════════
# ANALYSE DÉTAILLÉE DES MATCHS
# ═══════════════════════════════════════════════════════════

# 1. Les 144 matchs trip_short_name — quels opérateurs ?
df_match_tsn = (
    df_botn_trips.alias("b")
    .join(
        df_gtfs_trips.select(
            F.col("trip_short_name").alias("g_tsn"),
            F.col("agency_name").alias("g_agency"),
            F.col("trip_id").alias("g_trip_id"),
        ),
        F.col("b.trip_short_name") == F.col("g_tsn")
    )
)

print("--- Matchs trip_short_name par opérateur BOTN ---")
(
    df_match_tsn
    .select("b.trip_id", "b.agency_id", "b.trip_short_name", "g_agency")
    .distinct()
    .groupBy("b.agency_id")
    .agg(
        F.count("*").alias("n_matched"),
    )
    .orderBy(F.desc("n_matched"))
    .show(30, truncate=False)
)

# 2. Quels BOTN n'ont PAS matché ? Patterns à explorer
df_no_match_tsn = (
    df_botn_trips.alias("b")
    .join(
        df_gtfs_trips.select(F.col("trip_short_name").alias("g_tsn")).distinct(),
        F.col("b.trip_short_name") == F.col("g_tsn"),
        "left_anti"
    )
)
print(f"\n--- BOTN sans match trip_short_name : {df_no_match_tsn.count()} trips ---")
print("Par opérateur :")
(
    df_no_match_tsn
    .groupBy("agency_id")
    .agg(
        F.count("*").alias("n"),
        F.collect_list("trip_short_name").alias("exemples_tsn")
    )
    .orderBy(F.desc("n"))
    .show(30, truncate=False)
)

# 3. Comparaison manuelle : qu'est-ce qui diffère ?
# Ex: un trip BOTN "NJ 40295" → dans GTFS c'est peut-être "40295" ou "NJ40295" ?
print("\n--- Échantillon trip_short_name BOTN non matchés ---")
df_no_match_tsn.select("trip_short_name", "agency_id").show(20, truncate=False)

# 4. Chercher des matchs partiels : le numéro du train
# Extraire la partie numérique du trip_short_name
df_botn_num = (
    df_botn_trips
    .withColumn("train_number", F.regexp_extract("trip_short_name", "(\\d{3,6})", 1))
    .where(F.col("train_number") != "")
)

df_gtfs_num = (
    df_gtfs_trips
    .withColumn("train_number", F.regexp_extract("trip_short_name", "(\\d{3,6})", 1))
    .where(F.col("train_number") != "")
    .select("train_number", "trip_id", "agency_name").distinct()
)

# Match sur numéro de train + même opérateur (premier token)
match_num_agency = (
    df_botn_num
    .withColumn("botn_ag", F.split("agency_name", " ")[0])
    .join(
        df_gtfs_num.withColumn("gtfs_ag", F.split("agency_name", " ")[0]),
        (F.col("botn_ag") == F.col("gtfs_ag")) & 
        (df_botn_num["train_number"] == df_gtfs_num["train_number"])
    )
    .select(df_botn_num["trip_id"]).distinct().count()
)
print(f"\nMatch numéro train + opérateur : {match_num_agency}")

# 5. Match géographique : même origin_city + destination_city + opérateur
w_first = Window.partitionBy("trip_id").orderBy("stop_sequence")
w_last = Window.partitionBy("trip_id").orderBy(F.desc("stop_sequence"))

for label, df, prefix in [("GTFS", df_enriched, "g_"), ("BOTN", df_botn_cleaned, "b_")]:
    globals()[f"df_{prefix}od"] = (
        df
        .withColumn("rn_first", F.row_number().over(w_first))
        .withColumn("rn_last", F.row_number().over(w_last))
        .groupBy("trip_id")
        .agg(
            F.first(F.when(F.col("rn_first") == 1, F.col("city"))).alias(f"{prefix}origin_city"),
            F.first(F.when(F.col("rn_last") == 1, F.col("city"))).alias(f"{prefix}dest_city"),
            F.first("agency_name").alias(f"{prefix}agency"),
        )
    )

match_od_agency = (
    globals()["df_b_od"].alias("b")
    .join(
        globals()["df_g_od"].alias("g"),
        (F.col("b_origin_city") == F.col("g_origin_city")) &
        (F.col("b_dest_city") == F.col("g_dest_city"))
    )
    .select("b.trip_id").distinct().count()
)
print(f"Match origin_city + dest_city      : {match_od_agency}")

--- Matchs trip_short_name par opérateur BOTN ---
+-----------+---------+
|agency_id  |n_matched|
+-----------+---------+
|UZ         |251      |
|CFR        |73       |
|BDŽ        |52       |
|HŽPP       |43       |
|ES         |37       |
|OTE        |29       |
|ÖBB        |28       |
|UEX        |27       |
|ST         |25       |
|SJN        |20       |
|SJ         |16       |
|RJ         |16       |
|TCDD       |14       |
|CFM/CFR    |13       |
|VY         |8        |
|ŽS/ŽPCG    |7        |
|ZSSK       |6        |
|HŽPP/ÖBB   |4        |
|MÁV/ÖBB    |4        |
|CFR/MÁV/ÖBB|2        |
|PKP/ČD     |2        |
+-----------+---------+


--- BOTN sans match trip_short_name : 264 trips ---
Par opérateur :
+------------+---+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [8]:
# ═══════════════════════════════════════════════════════════
# STRATÉGIE DE MATCHING AVANCÉE
# ═══════════════════════════════════════════════════════════

# --- 1. Normaliser trip_short_name : extraire le numéro de train ---
# BOTN: "IC Notte 795" → "795", "NJ 40469" → "40469", "013 L" → "013"
# GTFS: formats variés

@F.udf(StringType())
def extract_train_numbers(name):
    """Extrait tous les blocs numériques ≥3 chiffres"""
    if name is None: return None
    import re
    nums = re.findall(r'\d{3,6}', name)
    return "|".join(nums) if nums else None

df_botn_nums = (
    df_botn_trips
    .withColumn("train_nums", extract_train_numbers(F.col("trip_short_name")))
    .withColumn("botn_ag", F.split("agency_name", " ")[0])
    .where(F.col("train_nums").isNotNull())
)

df_gtfs_nums = (
    df_gtfs_trips
    .withColumn("train_nums", extract_train_numbers(F.col("trip_short_name")))
    .withColumn("gtfs_ag", F.split("agency_name", " ")[0])
    .where(F.col("train_nums").isNotNull())
)

# Match : un numéro BOTN contenu dans un numéro GTFS + même opérateur
# Exploser les numéros
df_botn_exploded = (
    df_botn_nums
    .withColumn("num", F.explode(F.split("train_nums", "\\|")))
)
df_gtfs_exploded = (
    df_gtfs_nums
    .withColumn("num", F.explode(F.split("train_nums", "\\|")))
    .select(F.col("num").alias("g_num"), "gtfs_ag", 
            F.col("trip_id").alias("g_trip_id"),
            F.col("trip_short_name").alias("g_tsn"))
)

match_num_ag = (
    df_botn_exploded.alias("b")
    .join(
        df_gtfs_exploded.alias("g"),
        (F.col("b.num") == F.col("g.g_num")) & 
        (F.col("b.botn_ag") == F.col("g.gtfs_ag"))
    )
    .select("b.trip_id", "b.trip_short_name", "b.botn_ag", "g.g_tsn", "g.gtfs_ag")
    .distinct()
)
print(f"Match numéro train + opérateur (élargi) : {match_num_ag.select('b.trip_id').distinct().count()}")

# Échantillon des nouveaux matchs (pas dans les 144 exacts)
print("\n--- Nouveaux matchs numériques ---")
(
    match_num_ag
    .join(
        df_botn_trips.alias("exact")
        .join(df_gtfs_trips.select("trip_short_name").distinct(), "trip_short_name")
        .select(F.col("trip_id").alias("already_matched")),
        F.col("b.trip_id") == F.col("already_matched"),
        "left_anti"
    )
    .orderBy("b.botn_ag", "b.trip_short_name")
    .show(30, truncate=False)
)

# --- 2. Match géographique : stops proches entre BOTN et GTFS ---
# Origine et destination par coordonnées au lieu du nom de ville

w_first = Window.partitionBy("trip_id").orderBy("stop_sequence")
w_last = Window.partitionBy("trip_id").orderBy(F.desc("stop_sequence"))

df_botn_od = (
    df_botn_cleaned
    .withColumn("rn_f", F.row_number().over(w_first))
    .withColumn("rn_l", F.row_number().over(w_last))
    .groupBy("trip_id")
    .agg(
        F.first(F.when(F.col("rn_f") == 1, F.col("stop_lat"))).alias("b_olat"),
        F.first(F.when(F.col("rn_f") == 1, F.col("stop_lon"))).alias("b_olon"),
        F.first(F.when(F.col("rn_l") == 1, F.col("stop_lat"))).alias("b_dlat"),
        F.first(F.when(F.col("rn_l") == 1, F.col("stop_lon"))).alias("b_dlon"),
        F.first(F.when(F.col("rn_f") == 1, F.col("stop_name"))).alias("b_origin"),
        F.first(F.when(F.col("rn_l") == 1, F.col("stop_name"))).alias("b_dest"),
        F.first("agency_name").alias("b_agency"),
    )
)

df_gtfs_od = (
    df_enriched
    .withColumn("rn_f", F.row_number().over(w_first))
    .withColumn("rn_l", F.row_number().over(w_last))
    .groupBy("trip_id")
    .agg(
        F.first(F.when(F.col("rn_f") == 1, F.col("stop_lat"))).alias("g_olat"),
        F.first(F.when(F.col("rn_f") == 1, F.col("stop_lon"))).alias("g_olon"),
        F.first(F.when(F.col("rn_l") == 1, F.col("stop_lat"))).alias("g_dlat"),
        F.first(F.when(F.col("rn_l") == 1, F.col("stop_lon"))).alias("g_dlon"),
        F.first(F.when(F.col("rn_f") == 1, F.col("stop_name"))).alias("g_origin"),
        F.first(F.when(F.col("rn_l") == 1, F.col("stop_name"))).alias("g_dest"),
    )
)

# Match : origine à < 5km ET destination à < 5km
match_geo = (
    df_botn_od.alias("b").crossJoin(df_gtfs_od.alias("g"))
    .withColumn("dist_origin",
        2 * 6371 * F.asin(F.sqrt(
            F.sin(F.radians(F.col("g_olat") - F.col("b_olat")) / 2) ** 2 +
            F.cos(F.radians("b_olat")) * F.cos(F.radians("g_olat")) *
            F.sin(F.radians(F.col("g_olon") - F.col("b_olon")) / 2) ** 2
        ))
    )
    .withColumn("dist_dest",
        2 * 6371 * F.asin(F.sqrt(
            F.sin(F.radians(F.col("g_dlat") - F.col("b_dlat")) / 2) ** 2 +
            F.cos(F.radians("b_dlat")) * F.cos(F.radians("g_dlat")) *
            F.sin(F.radians(F.col("g_dlon") - F.col("b_dlon")) / 2) ** 2
        ))
    )
    .where((F.col("dist_origin") < 5) & (F.col("dist_dest") < 5))
)

n_geo = match_geo.select(F.col("b.trip_id")).distinct().count()
print(f"\nMatch géographique (origin + dest < 5km) : {n_geo} trips BOTN")

Match numéro train + opérateur (élargi) : 29

--- Nouveaux matchs numériques ---
+----------------------+---------------+----------+-------------------------+----------+
|trip_id               |trip_short_name|botn_ag   |g_tsn                    |gtfs_ag   |
+----------------------+---------------+----------+-------------------------+----------+
|PKP/ČD EC 460         |EC 460         |PKP       |460 Baltic Express       |PKP       |
|PKP/ČD EC 461         |EC 461         |PKP       |461 Baltic Express       |PKP       |
|PKP/ČD EN 40416       |EN 40416       |PKP       |40416 Carpatia           |PKP       |
|PKP/ČD EN 40417       |EN 40417       |PKP       |40417 Carpatia           |PKP       |
|PKP/ČD EN 443 / IC 406|EN 443 / IC 406|PKP       |443 Chopin               |PKP       |
|PKP IC 16170          |IC 16170       |PKP       |16170 Karkonosze         |PKP       |
|PKP IC 38172          |IC 38172       |PKP       |38172 Podhalanin         |PKP       |
|PKP/ČD IC 407 / EN 442|IC 40

In [9]:
# ═══════════════════════════════════════════════════════════
# MATCHING COMBINÉ — STRATÉGIE PROGRESSIVE
# ═══════════════════════════════════════════════════════════

# === NIVEAU 1 : match exact trip_short_name (déjà 144) ===
match_L1 = (
    df_botn_trips.alias("b")
    .join(
        df_gtfs_trips.select(
            F.col("trip_short_name").alias("g_tsn"),
            F.col("trip_id").alias("g_trip_id")
        ),
        F.col("b.trip_short_name") == F.col("g_tsn")
    )
    .select(
        F.col("b.trip_id").alias("botn_trip_id"),
        F.col("g_trip_id").alias("gtfs_trip_id"),
        F.lit("L1_exact_tsn").alias("match_method")
    )
    .distinct()
)

# === NIVEAU 2 : numéro train + opérateur (PKP, SNCF, ZSSK) ===
match_L2 = (
    df_botn_exploded.alias("b")
    .join(
        df_gtfs_exploded.alias("g"),
        (F.col("b.num") == F.col("g.g_num")) & 
        (F.col("b.botn_ag") == F.col("g.gtfs_ag"))
    )
    .select(
        F.col("b.trip_id").alias("botn_trip_id"),
        F.col("g.g_trip_id").alias("gtfs_trip_id"),
        F.lit("L2_num_agency").alias("match_method")
    )
    .distinct()
    # Exclure ceux déjà matchés L1
    .join(match_L1.select("botn_trip_id").distinct(), "botn_trip_id", "left_anti")
)

n_L1 = match_L1.select("botn_trip_id").distinct().count()
n_L2 = match_L2.select("botn_trip_id").distinct().count()
print(f"L1 exact trip_short_name  : {n_L1} trips BOTN")
print(f"L2 numéro + opérateur    : {n_L2} trips BOTN (nouveaux)")

# === NIVEAU 3 : géo (origin/dest < 5km) avec bucketing ===
# Réduire GTFS aux candidats pertinents : même opérateur OU route_type ∈ [100-199, 2]
df_gtfs_candidates = (
    df_gtfs_od
    .join(
        df_gtfs_trips.select("trip_id", "agency_name", F.col("source")),
        "trip_id"
    )
    .withColumn("gtfs_ag", F.split("agency_name", " ")[0])
)

# Bucketing : arrondir lat/lon à 0.5°
df_botn_od_buck = (
    df_botn_od
    .withColumn("o_bucket", F.concat(
        F.round(F.col("b_olat") * 2, 0).cast("int"), F.lit("_"),
        F.round(F.col("b_olon") * 2, 0).cast("int")
    ))
    .withColumn("d_bucket", F.concat(
        F.round(F.col("b_dlat") * 2, 0).cast("int"), F.lit("_"),
        F.round(F.col("b_dlon") * 2, 0).cast("int")
    ))
    # Seulement les non-matchés
    .join(
        match_L1.select("botn_trip_id").distinct()
        .union(match_L2.select("botn_trip_id").distinct())
        .withColumnRenamed("botn_trip_id", "trip_id"),
        "trip_id", "left_anti"
    )
)

df_gtfs_od_buck = (
    df_gtfs_candidates
    .withColumn("o_bucket", F.concat(
        F.round(F.col("g_olat") * 2, 0).cast("int"), F.lit("_"),
        F.round(F.col("g_olon") * 2, 0).cast("int")
    ))
    .withColumn("d_bucket", F.concat(
        F.round(F.col("g_dlat") * 2, 0).cast("int"), F.lit("_"),
        F.round(F.col("g_dlon") * 2, 0).cast("int")
    ))
)

# Join sur buckets (origin bucket + dest bucket)
match_L3 = (
    df_botn_od_buck.alias("b")
    .join(
        df_gtfs_od_buck.alias("g"),
        (F.col("b.o_bucket") == F.col("g.o_bucket")) &
        (F.col("b.d_bucket") == F.col("g.d_bucket"))
    )
    # Haversine origin
    .withColumn("dist_o",
        2 * 6371 * F.asin(F.sqrt(
            F.sin(F.radians(F.col("g_olat") - F.col("b_olat")) / 2) ** 2 +
            F.cos(F.radians("b_olat")) * F.cos(F.radians("g_olat")) *
            F.sin(F.radians(F.col("g_olon") - F.col("b_olon")) / 2) ** 2
        ))
    )
    # Haversine dest
    .withColumn("dist_d",
        2 * 6371 * F.asin(F.sqrt(
            F.sin(F.radians(F.col("g_dlat") - F.col("b_dlat")) / 2) ** 2 +
            F.cos(F.radians("b_dlat")) * F.cos(F.radians("g_dlat")) *
            F.sin(F.radians(F.col("g_dlon") - F.col("b_dlon")) / 2) ** 2
        ))
    )
    .where((F.col("dist_o") < 5) & (F.col("dist_d") < 5))
    .select(
        F.col("b.trip_id").alias("botn_trip_id"),
        F.col("g.trip_id").alias("gtfs_trip_id"),
        "b_origin", "b_dest", "g_origin", "g_dest",
        F.round("dist_o", 1).alias("dist_origin_km"),
        F.round("dist_d", 1).alias("dist_dest_km"),
        F.lit("L3_geo_od").alias("match_method")
    )
    .distinct()
)

n_L3 = match_L3.select("botn_trip_id").distinct().count()
print(f"L3 géographique O/D      : {n_L3} trips BOTN (nouveaux)")

# Échantillon
print("\n--- Matchs géographiques L3 ---")
(
    match_L3
    .select("botn_trip_id", "b_origin", "b_dest", "g_origin", "g_dest", "dist_origin_km", "dist_dest_km")
    .distinct()
    .orderBy("botn_trip_id")
    .show(30, truncate=False)
)

# === TOTAL ===
total_matched = (
    match_L1.select("botn_trip_id").distinct()
    .union(match_L2.select("botn_trip_id").distinct())
    .union(match_L3.select("botn_trip_id").distinct())
    .distinct().count()
)
print(f"\n=== TOTAL MATCHÉS : {total_matched} / 408 ({total_matched/408:.1%}) ===")

L1 exact trip_short_name  : 144 trips BOTN
L2 numéro + opérateur    : 18 trips BOTN (nouveaux)
L3 géographique O/D      : 0 trips BOTN (nouveaux)

--- Matchs géographiques L3 ---
+------------+--------+------+--------+------+--------------+------------+
|botn_trip_id|b_origin|b_dest|g_origin|g_dest|dist_origin_km|dist_dest_km|
+------------+--------+------+--------+------+--------------+------------+
+------------+--------+------+--------+------+--------------+------------+


=== TOTAL MATCHÉS : 162 / 408 (39.7%) ===


In [10]:
# ═══════════════════════════════════════════════════════════
# 1. TABLE DE MATCHING CONSOLIDÉE
# ═══════════════════════════════════════════════════════════

df_matches = match_L1.union(match_L2.select("botn_trip_id", "gtfs_trip_id", "match_method"))

# Garder le meilleur match par trip BOTN (L1 prioritaire)
w_best = Window.partitionBy("botn_trip_id").orderBy(
    F.when(F.col("match_method") == "L1_exact_tsn", 0).otherwise(1)
)
df_matches_best = (
    df_matches
    .withColumn("rn", F.row_number().over(w_best))
    .where(F.col("rn") == 1)
    .drop("rn")
)

print(f"Matchs uniques BOTN→GTFS : {df_matches_best.count()}")
print(f"GTFS trips matchés       : {df_matches_best.select('gtfs_trip_id').distinct().count()}")

# ═══════════════════════════════════════════════════════════
# 2. FLAGUER is_night_train SUR GTFS
# ═══════════════════════════════════════════════════════════

# GTFS trips identifiés comme trains de nuit : matchés BOTN OU route_type 105
gtfs_night_trip_ids = df_matches_best.select(F.col("gtfs_trip_id").alias("trip_id"))

df_enriched = (
    df_enriched
    .withColumn("is_night_train",
        F.when(F.col("route_type") == 105, True)
         .when(F.col("trip_id").isin(
             [r[0] for r in gtfs_night_trip_ids.distinct().collect()]
         ), True)
         .otherwise(False)
    )
)

n_night_gtfs = df_enriched.where(F.col("is_night_train")).select("trip_id").distinct().count()
n_day_gtfs = df_enriched.where(~F.col("is_night_train")).select("trip_id").distinct().count()
print(f"\nGTFS trains de nuit  : {n_night_gtfs:,} trips")
print(f"GTFS trains de jour  : {n_day_gtfs:,} trips")

# ═══════════════════════════════════════════════════════════
# 3. PROPAGER ÉMISSIONS BOTN → GTFS MATCHÉS
# ═══════════════════════════════════════════════════════════

# Récupérer émissions au niveau trip BOTN
df_botn_emissions = (
    df_botn_cleaned
    .select("trip_id", "botn_emissions_co2e", "botn_co2_per_km")
    .distinct()
)

# Joindre via la table de matching
df_emissions_mapping = (
    df_matches_best
    .join(df_botn_emissions, df_matches_best["botn_trip_id"] == df_botn_emissions["trip_id"])
    .select(
        F.col("gtfs_trip_id").alias("trip_id"),
        "botn_emissions_co2e", "botn_co2_per_km"
    )
)

# Ajouter au GTFS
df_enriched = (
    df_enriched
    .join(df_emissions_mapping, "trip_id", "left")
)

n_with_emissions = df_enriched.where(F.col("botn_co2_per_km").isNotNull()).select("trip_id").distinct().count()
print(f"GTFS trips avec émissions BOTN : {n_with_emissions:,}")

# ═══════════════════════════════════════════════════════════
# 4. RÉCAP FINAL
# ═══════════════════════════════════════════════════════════

print(f"\n{'='*50}")
print(f"GTFS enrichi : {df_enriched.count():,} lignes")
print(f"  Trains de nuit  : {n_night_gtfs:,} trips ({n_night_gtfs/df_enriched.select('trip_id').distinct().count():.2%})")
print(f"  Trains de jour  : {n_day_gtfs:,} trips")
print(f"  Avec émissions  : {n_with_emissions:,} trips")
print(f"\nBOTN             : {df_botn_cleaned.count():,} lignes / {df_botn_cleaned.select('trip_id').distinct().count()} trips")
print(f"  Matchés → GTFS  : {df_matches_best.count()} trips ({df_matches_best.count()/408:.1%})")
print(f"  Non matchés     : {408 - df_matches_best.count()} trips")
print(f"{'='*50}")

# Colonnes GTFS finales
print(f"\nSchéma GTFS enrichi :")
df_enriched.printSchema()

Matchs uniques BOTN→GTFS : 162
GTFS trips matchés       : 160

GTFS trains de nuit  : 1,088 trips
GTFS trains de jour  : 1,680,961 trips
GTFS trips avec émissions BOTN : 160

GTFS enrichi : 23,249,736 lignes
  Trains de nuit  : 1,088 trips (0.06%)
  Trains de jour  : 1,680,961 trips
  Avec émissions  : 160 trips

BOTN             : 3,265 lignes / 408 trips
  Matchés → GTFS  : 162 trips (39.7%)
  Non matchés     : 246 trips

Schéma GTFS enrichi :
root
 |-- trip_id: string (nullable = true)
 |-- source: string (nullable = true)
 |-- stop_id: string (nullable = true)
 |-- route_id: string (nullable = true)
 |-- service_id: string (nullable = true)
 |-- route_type: integer (nullable = true)
 |-- route_short_name: string (nullable = true)
 |-- route_long_name: string (nullable = true)
 |-- trip_headsign: string (nullable = true)
 |-- trip_short_name: string (nullable = true)
 |-- agency_id: string (nullable = true)
 |-- agency_name: string (nullable = true)
 |-- agency_timezone: string (nul

In [11]:
# ═══════════════════════════════════════════════════════════
# 5. AJOUTER LES TRIPS BOTN NON MATCHÉS (sans introduire de colonnes conflictuelles)
# ═══════════════════════════════════════════════════════════

# 1) Identifier les trips BOTN non matchés
botn_matched_ids = df_matches_best.select("botn_trip_id").distinct()
botn_unmatched = df_botn_cleaned.join(
    botn_matched_ids,
    df_botn_cleaned["trip_id"] == botn_matched_ids["botn_trip_id"],
    "left_anti"
)

# 2) Éviter d'introduire des colonnes provenant d'autres tables (ex: df_ref) ou des colonnes problématiques
conflicting_cols = []
if 'date' in botn_unmatched.columns:
    conflicting_cols.append('date')
# ajouter toute colonne de df_ref qui apparaîtrait dans botn_unmatched
try:
    ref_cols = df_ref.columns
except Exception:
    ref_cols = []
for c in ref_cols:
    if c in botn_unmatched.columns:
        conflicting_cols.append(c)

if conflicting_cols:
    botn_unmatched = botn_unmatched.drop(*conflicting_cols)

# 3) Préparer les colonnes : faire strictement correspondre le schéma de df_enriched
from pyspark.sql import functions as F

gtfs_cols = df_enriched.columns
botn_cols = botn_unmatched.columns

# Ajouter les colonnes manquantes avec des valeurs nulles
for col in gtfs_cols:
    if col not in botn_cols:
        botn_unmatched = botn_unmatched.withColumn(col, F.lit(None))

# Supprimer les colonnes supplémentaires non présentes dans df_enriched
extra_cols = [c for c in botn_unmatched.columns if c not in gtfs_cols]
if extra_cols:
    botn_unmatched = botn_unmatched.drop(*extra_cols)

# 4) S'assurer que `source` existe et marquer l'origine
if "source" not in df_enriched.columns:
    df_enriched = df_enriched.withColumn("source", F.lit("GTFS"))
botn_unmatched = botn_unmatched.withColumn("source", F.lit("BOTN_only"))

# 5) Réordonner les colonnes pour correspondre
enriched_type_map = {f.name: f.dataType for f in df_enriched.schema.fields}

for col_name, target_type in enriched_type_map.items():
    if col_name not in botn_unmatched.columns:
        continue
    
    botn_field = dict((f.name, f.dataType) for f in botn_unmatched.schema.fields)
    source_type = botn_field.get(col_name)
    
    # Cas critique : STRING → DATE avec format GTFS YYYYMMDD
    if isinstance(target_type, DateType) and not isinstance(source_type, DateType):
        botn_unmatched = botn_unmatched.withColumn(
            col_name,
            F.to_date(
                F.regexp_replace(F.col(col_name).cast("string"), r'^(\d{4})(\d{2})(\d{2})$', '$1-$2-$3')
            )
        )
    # Autres cas de type mismatch → cast générique
    elif type(target_type) != type(source_type):
        botn_unmatched = botn_unmatched.withColumn(
            col_name,
            F.col(col_name).cast(target_type)
        )
botn_unmatched = botn_unmatched.select(gtfs_cols)

# 6) Union (concaténation des lignes)
final_df = df_enriched.unionByName(botn_unmatched)

# 7) Comptages de contrôle (au niveau trip_id distinct)
n_gtfs_trips = df_enriched.select("trip_id").distinct().count()
n_botn_unmatched_trips = botn_unmatched.select("trip_id").distinct().count()
n_final_trips = final_df.select("trip_id").distinct().count()

print(f"\nTrips GTFS distincts (avant)       : {n_gtfs_trips:,}")
print(f"Trips BOTN non matchés ajoutés     : {n_botn_unmatched_trips:,}")
print(f"Trips distincts après union        : {n_final_trips:,}")

# (Optionnel) compter les lignes totales
print(f"Lignes (rows) df_enriched          : {df_enriched.count():,}")
print(f"Lignes ajoutées (BOTN non matchés) : {botn_unmatched.count():,}")
print(f"Lignes totales final_df            : {final_df.count():,}")
# 8) Schéma final de la table enrichie
print(f"\nSchéma de la table finale enrichie :")
final_df.printSchema()



Trips GTFS distincts (avant)       : 1,682,011
Trips BOTN non matchés ajoutés     : 246
Trips distincts après union        : 1,682,257
Lignes (rows) df_enriched          : 23,249,736
Lignes ajoutées (BOTN non matchés) : 1,690
Lignes totales final_df            : 23,251,426

Schéma de la table finale enrichie :
root
 |-- trip_id: string (nullable = true)
 |-- source: string (nullable = true)
 |-- stop_id: string (nullable = true)
 |-- route_id: string (nullable = true)
 |-- service_id: string (nullable = true)
 |-- route_type: integer (nullable = true)
 |-- route_short_name: string (nullable = true)
 |-- route_long_name: string (nullable = true)
 |-- trip_headsign: string (nullable = true)
 |-- trip_short_name: string (nullable = true)
 |-- agency_id: string (nullable = true)
 |-- agency_name: string (nullable = true)
 |-- agency_timezone: string (nullable = true)
 |-- stop_name: string (nullable = true)
 |-- stop_lat: double (nullable = true)
 |-- stop_lon: double (nullable = true)
 |

In [12]:
# ═══════════════════════════════════════════════════════════
# ENRICHIR ÉMISSIONS : GTFS route_type 105 sans match BOTN
# Même logique : opérateur → pays → moyenne globale
# ═══════════════════════════════════════════════════════════

# 1. Identifier les trips 105 sans émissions
df_105_no_emi = (
    final_df
    .where((F.col("route_type") == 105) & F.col("botn_co2_per_km").isNull())
    .select("trip_id", "agency_name", "country")
    .distinct()
)
print(f"Trips route_type 105 sans émissions : {df_105_no_emi.select('trip_id').distinct().count()}")

# 2. Moyennes BOTN par opérateur (premier token)
df_avg_agency = (
    df_botn_cleaned
    .where(F.col("botn_co2_per_km").isNotNull())
    .withColumn("ag_token", F.split("agency_name", " ")[0])
    .select("trip_id", "ag_token", "botn_co2_per_km")
    .distinct()
    .groupBy("ag_token")
    .agg(F.avg("botn_co2_per_km").alias("avg_co2_agency"))
)

# 3. Moyennes BOTN par pays d'origine
w_first = Window.partitionBy("trip_id").orderBy("stop_sequence")
df_botn_origin = (
    df_botn_cleaned
    .withColumn("rn", F.row_number().over(w_first))
    .where(F.col("rn") == 1)
    .select("trip_id", F.col("country").alias("origin_country"))
)

df_avg_country = (
    df_botn_cleaned
    .where(F.col("botn_co2_per_km").isNotNull())
    .select("trip_id", "botn_co2_per_km").distinct()
    .join(df_botn_origin, "trip_id")
    .groupBy("origin_country")
    .agg(F.avg("botn_co2_per_km").alias("avg_co2_country"))
)

avg_global = df_botn_cleaned.where(F.col("botn_co2_per_km").isNotNull()).select("trip_id", "botn_co2_per_km").distinct().select(F.avg("botn_co2_per_km")).collect()[0][0]
print(f"Moyenne globale BOTN : {avg_global:.2f} g/km")

# 4. Appliquer : opérateur GTFS → pays origine GTFS → global
df_gtfs_origin = (
    final_df
    .withColumn("rn", F.row_number().over(w_first))
    .where(F.col("rn") == 1)
    .select("trip_id", F.col("country").alias("origin_country"))
)

df_fill = (
    df_105_no_emi
    .select("trip_id", "agency_name").distinct()
    .withColumn("ag_token", F.split("agency_name", " ")[0])
    .join(df_gtfs_origin, "trip_id", "left")
    .join(df_avg_agency, "ag_token", "left")
    .join(df_avg_country, "origin_country", "left")
    .withColumn("fill_co2_per_km",
        F.coalesce(
            F.col("avg_co2_agency"),
            F.col("avg_co2_country"),
            F.lit(avg_global)
        )
    )
    .select("trip_id", "fill_co2_per_km")
)

# 5. Appliquer au df_enriched
df_enriched = (
    final_df
    .join(df_fill, "trip_id", "left")
    .withColumn("botn_co2_per_km",
        F.coalesce(F.col("botn_co2_per_km"), F.col("fill_co2_per_km"))
    )
    .withColumn("botn_emissions_co2e",
        F.coalesce(
            F.col("botn_emissions_co2e"),
            F.round(F.col("botn_co2_per_km") * F.col("segment_dist_m") / 1000000, 2)
        )
    )
    .drop("fill_co2_per_km")
)

# Vérif
n_105_total = df_enriched.where(F.col("route_type") == 105).select("trip_id").distinct().count()
n_105_with = df_enriched.where((F.col("route_type") == 105) & F.col("botn_co2_per_km").isNotNull()).select("trip_id").distinct().count()
n_105_without = df_enriched.where((F.col("route_type") == 105) & F.col("botn_co2_per_km").isNull()).select("trip_id").distinct().count()

print(f"\nTrips route_type 105 : {n_105_total:,}")
print(f"  Avec émissions     : {n_105_with:,}")
print(f"  Sans émissions     : {n_105_without:,}")

# Détail des sources d'enrichissement
print("\n--- Opérateurs enrichis ---")
df_enriched.select("trip_id", "agency_name", "is_night_train", "botn_co2_per_km", "botn_emissions_co2e").filter(F.col("botn_co2_per_km").isNotNull() & F.col("is_night_train")).show(20, truncate=False)
(
    df_enriched
    .where((F.col("route_type") == 105) & F.col("is_night_train"))
    .select("trip_id", "agency_name", "botn_co2_per_km")
    .distinct()
    .withColumn("ag_token", F.split("agency_name", " ")[0])
    .groupBy("ag_token")
    .agg(
        F.count("*").alias("n_trips"),
        F.round(F.avg("botn_co2_per_km"), 2).alias("avg_co2"),
    )
    .orderBy(F.desc("n_trips"))
    .show(20, truncate=False)
)

Trips route_type 105 sans émissions : 928
Moyenne globale BOTN : 17.57 g/km

Trips route_type 105 : 1,175
  Avec émissions     : 1,175
  Sans émissions     : 0

--- Opérateurs enrichis ---
+---------------------------+-------------------------------------+--------------+------------------+-------------------+
|trip_id                    |agency_name                          |is_night_train|botn_co2_per_km   |botn_emissions_co2e|
+---------------------------+-------------------------------------+--------------+------------------+-------------------+
|1.TA.12-D11-j24-1.1.R      |OEBB Personenverkehr AG Kundenservice|true          |14.38848921       |16.0               |
|1.TA.12-D11-j24-1.1.R      |OEBB Personenverkehr AG Kundenservice|true          |14.38848921       |16.0               |
|10.TA.12-D5-j24-1.3.R      |OEBB Personenverkehr AG Kundenservice|true          |15.768362396      |0.61               |
|10.TA.12-D5-j24-1.3.R      |OEBB Personenverkehr AG Kundenservice|true        

In [13]:
df_ref = spark.read.parquet("./processed/co2_reference.parquet")

In [14]:
# ═══════════════════════════════════════════════════════════
# CALCUL ÉMISSIONS V2 — Lignes sans botn_emissions_co2e
# Intégration intensité Ember (2024/2025) & Correctif Diesel
# ═══════════════════════════════════════════════════════════
from pyspark.sql import functions as F
from pyspark.sql.window import Window

# 1. Nouvelles constantes sourcées (ifeu, UIC, KTH, Rambøll)
ENERGY_BY_ROUTE_TYPE = {
    101: 0.050,  # TGV / High-speed (Forte traînée aéro compensée par occupation)
    100: 0.040,  # Intercity / Long-distance (L'optimum énergétique)
    102: 0.040,  # Long Distance Trains
    103: 0.040,  # Inter Regional Rail
    105: 0.090,  # Sleeper Rail / Night Train (Hôtel roulant, forte conso auxiliaire)
    106: 0.080,  # Regional Rail (Arrêts fréquents, occupation moyenne faible)
    109: 0.080,  # Suburban Railway
    2:   0.050,  # Rail (GTFS générique)
}
DEFAULT_ENERGY = 0.050  # Fallback
DEFAULT_GRID_INTENSITY = 230.0  # Moyenne UE approx
DIESEL_EMISSION_GCO2_PKM = 65.0 # Facteur d'émission standard régional diesel

# Taux d'électrification des réseaux ferrés par pays (IRG-Rail / Eurostat)
# Utilisé comme probabilité de traction électrique pour le trafic régional
ELECTRIFICATION_RATES = {
    'CH': 1.00, 'LI': 1.00, 'BE': 0.88, 'SE': 0.75, 'NL': 0.74, 'AT': 0.73, 
    'IT': 0.72, 'VA': 0.72, 'PL': 0.64, 'ES': 0.64, 'AD': 0.64, 'DE': 0.61, 
    'FR': 0.60, 'MC': 0.60, 'GB': 0.38, 'DK': 0.30, 'IE': 0.03, 'BA': 0.70,
    'BG': 0.74, 'HR': 0.37, 'CZ': 0.34, 'FI': 0.55, 'GR': 0.20, 'HU': 0.41,
    'NO': 0.65, 'PT': 0.65, 'RO': 0.37, 'SK': 0.44, 'SI': 0.50
}
DEFAULT_ELECTRIFICATION = 0.56 # Moyenne européenne

# 2. Pays d'origine du trip (premier arrêt)
w_first = Window.partitionBy("trip_id").orderBy("stop_sequence")
df_trip_origin = (
    df_enriched
    .withColumn("rn", F.row_number().over(w_first))
    .where(F.col("rn") == 1)
    .select("trip_id", F.col("country").alias("origin_country"))
)

# Nettoyage avant la jointure
if 'origin_country' in df_enriched.columns:
    df_enriched = df_enriched.drop('origin_country')

# Renommer la table de référence pour éviter les ambiguïtés
df_ref_renamed = df_ref.withColumnRenamed("country_alpha2", "eea_country").withColumnRenamed("eea_gco2_per_kwh", "eea_gco2_per_kwh_ref")

# Sélectionner uniquement les colonnes nécessaires depuis la référence pour éviter doublons
# RENOMMER les colonnes de référence pour qu'elles soient uniques après la jointure
if 'eea_country' in df_ref_renamed.columns and 'eea_gco2_per_kwh_ref' in df_ref_renamed.columns:
    df_ref_sel = (
        df_ref_renamed
        .select(
            F.col('eea_country').alias('ref_eea_country'),
            F.col('eea_gco2_per_kwh_ref').alias('ref_eea_gco2_per_kwh_ref')
        )
        .distinct()
    )
else:
    df_ref_sel = df_ref_renamed.select('*').distinct()

# Jointure avec df_enriched — utiliser les noms explicitement renommés
# Joindre d'abord df_trip_origin pour obtenir 'origin_country'
df_enriched = df_enriched.join(df_trip_origin, "trip_id", "left")

# Puis joindre la référence en comparant origin_country == ref_eea_country
df_enriched = df_enriched.join(
    df_ref_sel,
    df_enriched["origin_country"] == df_ref_sel["ref_eea_country"],
    "left"
)

# 3. Préparation des mappings PySpark (Énergie et Électrification)
energy_map_expr = F.create_map([F.lit(x) for p in ENERGY_BY_ROUTE_TYPE.items() for x in p])
elec_map_expr = F.create_map([F.lit(x) for p in ELECTRIFICATION_RATES.items() for x in p])

# Utiliser la colonne renommée `ref_eea_gco2_per_kwh_ref` pour éviter l'ambiguïté
df_enriched = df_enriched.withColumn("energy_kwh_pkm",
    F.coalesce(energy_map_expr[F.col("route_type")], F.lit(DEFAULT_ENERGY))
).withColumn("elec_rate",
    F.coalesce(elec_map_expr[F.col("origin_country")], F.lit(DEFAULT_ELECTRIFICATION))
).withColumn("grid_intensity",
    F.coalesce(F.col("ref_eea_gco2_per_kwh_ref"), F.lit(DEFAULT_GRID_INTENSITY))
)

# 4. Calcul de eea_co2_per_pkm avec la logique Hybride (Électrique/Diesel)
# - Les trains grandes lignes et de nuit (100, 101, 102, 103, 105) sont présumés 100% électriques.
# - Les trains régionaux et génériques (106, 109, 2) subissent le ratio d'électrification national.
df_enriched = df_enriched.withColumn("eea_co2_per_pkm",
    F.round(
        F.when(
            F.col("route_type").isin(106, 109, 2),
            (F.col("elec_rate") * (F.col("energy_kwh_pkm") * F.col("grid_intensity"))) +
            ((F.lit(1.0) - F.col("elec_rate")) * F.lit(DIESEL_EMISSION_GCO2_PKM))
        ).otherwise(
            F.col("energy_kwh_pkm") * F.col("grid_intensity")
        ),
        2
    )
)

# 5. Calculer émissions totales (kg CO₂) = gCO₂/pkm × distance_km / 1000
df_enriched = df_enriched.withColumn("eea_emissions_co2",
    F.round(
        F.col("eea_co2_per_pkm") * F.col("segment_dist_m") / 1_000_000, 
        2
    )
)

date_cols = [f.name for f in df_enriched.schema.fields 
             if str(f.dataType) == "StringType()" and "date" in f.name.lower()]

for col_name in date_cols:
    df_enriched = df_enriched.withColumn(
        col_name,
        F.to_date(F.col(col_name), "yyyyMMdd")
    )

# Sauvegarde
df_enriched.write.mode("overwrite").parquet("./processed/gtfs_co2_enriched.parquet")
df_enriched = spark.read.parquet("./processed/gtfs_co2_enriched.parquet")

# 6. Statistiques de contrôle
n_total = df_enriched.select("trip_id").distinct().count()
n_with_botn = df_enriched.where(F.col("botn_co2_per_km").isNotNull()).select("trip_id").distinct().count()
n_with_eea = df_enriched.where(F.col("eea_co2_per_pkm").isNotNull()).select("trip_id").distinct().count()

print(f"Trips total           : {n_total:,}")
print(f"Trips avec BOTN CO₂   : {n_with_botn:,}")
print(f"Trips avec EEA CO₂    : {n_with_eea:,}")

# Aperçu par route_type
print(f"\n--- Émissions EEA V2 moyennes par route_type ---")
(
    df_enriched
    .select("trip_id", "route_type", "origin_country", "eea_co2_per_pkm", "energy_kwh_pkm", "grid_intensity")
    .distinct()
    .groupBy("route_type")
    .agg(
        F.count("*").alias("n_trips"),
        F.round(F.avg("energy_kwh_pkm"), 3).alias("avg_kwh_pkm"),
        F.round(F.avg("grid_intensity"), 1).alias("avg_grid_gco2"),
        F.round(F.avg("eea_co2_per_pkm"), 2).alias("avg_gco2_pkm"),
    )
    .orderBy("route_type")
    .show(20, truncate=False)
)

# Nettoyage des colonnes intermédiaires
for c in ["origin_country", "ref_eea_gco2_per_kwh_ref", "ref_eea_country", "energy_kwh_pkm", "elec_rate", "grid_intensity"]:
    if c in df_enriched.columns:
        df_enriched = df_enriched.drop(c)


26/02/24 11:45:40 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.


Trips total           : 1,682,257
Trips avec BOTN CO₂   : 1,334
Trips avec EEA CO₂    : 1,682,257

--- Émissions EEA V2 moyennes par route_type ---
+----------+-------+-----------+-------------+------------+
|route_type|n_trips|avg_kwh_pkm|avg_grid_gco2|avg_gco2_pkm|
+----------+-------+-----------+-------------+------------+
|2         |890820 |0.05       |224.6        |33.01       |
|100       |44853  |0.04       |123.8        |4.95        |
|101       |43405  |0.05       |74.9         |3.75        |
|102       |65518  |0.04       |59.3         |2.37        |
|103       |82491  |0.04       |77.8         |3.11        |
|105       |1175   |0.09       |114.0        |10.26       |
|106       |234919 |0.08       |132.1        |22.44       |
|107       |1459   |0.05       |47.9         |2.39        |
|108       |1004   |0.05       |41.6         |2.08        |
|109       |340983 |0.08       |125.9        |18.92       |
+----------+-------+-----------+-------------+------------+



In [15]:
# ═══════════════════════════════════════════════════════════
# COMPARAISON BOTN vs EEA — Écarts sur trips avec les deux
# ═══════════════════════════════════════════════════════════

df_compare = (
    df_enriched
    .where(
        F.col("botn_emissions_co2e").isNotNull() & 
        F.col("eea_emissions_co2").isNotNull() &
        (F.col("eea_emissions_co2") > 0)
    )
    .select("trip_id", "agency_name", "route_type", "segment_dist_m" ,
            "botn_emissions_co2e", "botn_co2_per_km",
            "eea_emissions_co2", "eea_co2_per_pkm")
    .distinct()
    .withColumn("diff_pct", 
        F.round((F.col("botn_emissions_co2e") - F.col("eea_emissions_co2")) 
                / F.col("eea_emissions_co2") * 100, 1))
)

df_compare = df_compare.withColumn(
    "seat_capacity_implied", 
    F.round(F.col("botn_emissions_co2e") / F.col("eea_emissions_co2"), 0)
)

n_compare = df_compare.count()
print(f"Trips avec BOTN + EEA : {n_compare}")

# Stats globales
print(f"\n--- Distribution des écarts BOTN vs EEA (%) ---")
df_compare.select(
    F.round(F.min("diff_pct"), 1).alias("min_%"),
    F.round(F.percentile_approx("diff_pct", 0.25), 1).alias("Q1_%"),
    F.round(F.median("diff_pct"), 1).alias("median_%"),
    F.round(F.avg("diff_pct"), 1).alias("mean_%"),
    F.round(F.percentile_approx("diff_pct", 0.75), 1).alias("Q3_%"),
    F.round(F.max("diff_pct"), 1).alias("max_%"),
).show(truncate=False)

# Par opérateur (premier token)
print(f"--- Écart moyen par opérateur ---")
(
    df_compare
    .withColumn("operator", F.split("agency_name", " ")[0])
    .groupBy("operator")
    .agg(
        F.count("*").alias("n"),
        F.round(F.avg("botn_emissions_co2e"), 2).alias("avg_botn_kg"),
        F.round(F.avg("eea_emissions_co2"), 2).alias("avg_eea_kg"),
        F.round(F.avg("diff_pct"), 1).alias("avg_diff_%"),
        F.round(F.avg("seat_capacity_implied"), 2).alias("avg_capacity_implied")
    )
    .orderBy("avg_diff_%")
    .show(30, truncate=False)
)

# Échantillon des plus gros écarts
print(f"\n--- Top 10 écarts les plus importants ---")
(
    df_compare
    .select("trip_id", "route_type", "seat_capacity_implied", "agency_name",
            F.round("botn_emissions_co2e", 2).alias("botn_kg"),
            F.round("eea_emissions_co2", 2).alias("eea_kg"),
            "diff_pct")
    .orderBy(F.abs("diff_pct").desc())
    .show(10, truncate=False)
)

Trips avec BOTN + EEA : 7157

--- Distribution des écarts BOTN vs EEA (%) ---
+------+-----+--------+------+-----+--------+
|min_% |Q1_% |median_%|mean_%|Q3_% |max_%   |
+------+-----+--------+------+-----+--------+
|-100.0|-33.3|58.0    |4460.0|500.0|303900.0|
+------+-----+--------+------+-----+--------+

--- Écart moyen par opérateur ---
+------------+---+-----------+----------+----------+--------------------+
|operator    |n  |avg_botn_kg|avg_eea_kg|avg_diff_%|avg_capacity_implied|
+------------+---+-----------+----------+----------+--------------------+
|RENFE       |5  |0.16       |2.7       |-92.4     |0.0                 |
|800486      |264|0.11       |0.21      |-49.0     |0.88                |
|IDS         |17 |0.09       |0.19      |-48.9     |0.35                |
|Arverio     |3  |0.18       |0.35      |-48.1     |1.0                 |
|Lugano-Ponte|11 |0.02       |0.05      |-47.6     |0.91                |
|800456      |18 |0.26       |0.48      |-45.2     |0.5          

In [16]:
df_enriched.show(20, truncate=False)

+----------------+-----------+-------+---------------+--------------------------+----------+----------------+---------------+-------------+---------------+---------+---------------+---------------+------------------------------------------+----------+----------+--------------+------------+--------------+-------------+----------+--------+------------------+------------+--------------------+-------+--------------+-------------------+---------------+---------------+-----------------+
|trip_id         |source     |stop_id|route_id       |service_id                |route_type|route_short_name|route_long_name|trip_headsign|trip_short_name|agency_id|agency_name    |agency_timezone|stop_name                                 |stop_lat  |stop_lon  |parent_station|arrival_time|departure_time|stop_sequence|start_date|end_date|segment_dist_m    |days_of_week|city                |country|is_night_train|botn_emissions_co2e|botn_co2_per_km|eea_co2_per_pkm|eea_emissions_co2|
+----------------+----------

In [17]:
# ═══════════════════════════════════════════════════════════
# ESTIMATION VOLUMÉTRIE — Trips simples vs Combinatoire
# ═══════════════════════════════════════════════════════════
from pyspark.sql import functions as F

# Nombre d'arrêts par trip
df_stops_per_trip = (
    df_enriched
    .groupBy("trip_id")
    .agg(F.countDistinct("stop_sequence").alias("n_stops"))
)

n_trips = df_stops_per_trip.count()

# Solution 1 : 1 trajet par trip (A → D)
sol1 = n_trips

# Solution 2 : toutes les paires C(n,2) = n*(n-1)/2
df_combos = df_stops_per_trip.withColumn(
    "n_pairs", (F.col("n_stops") * (F.col("n_stops") - 1) / 2).cast("long")
)
sol2 = df_combos.agg(F.sum("n_pairs")).collect()[0][0]

print(f"Trips uniques          : {n_trips:,}")
print(f"Solution 1 (A→Z)       : {sol1:,} lignes")
print(f"Solution 2 (toutes paires) : {sol2:,} lignes")
print(f"Facteur d'explosion    : x{sol2/sol1:.1f}")

# Distribution du nombre d'arrêts
print(f"\n--- Distribution arrêts par trip ---")
df_stops_per_trip.select(
    F.min("n_stops").alias("min"),
    F.round(F.percentile_approx("n_stops", 0.25)).alias("Q1"),
    F.round(F.median("n_stops")).alias("médiane"),
    F.round(F.avg("n_stops"), 1).alias("moyenne"),
    F.round(F.percentile_approx("n_stops", 0.75)).alias("Q3"),
    F.max("n_stops").alias("max"),
).show(truncate=False)

# Top trips les plus longs (les plus explosifs)
print(f"--- Top 10 trips avec le plus d'arrêts ---")
df_combos.orderBy(F.desc("n_stops")).show(10, truncate=False)

Trips uniques          : 1,682,257
Solution 1 (A→Z)       : 1,682,257 lignes
Solution 2 (toutes paires) : 218,054,406 lignes
Facteur d'explosion    : x129.6

--- Distribution arrêts par trip ---
+---+---+-------+-------+---+---+
|min|Q1 |médiane|moyenne|Q3 |max|
+---+---+-------+-------+---+---+
|2  |7  |11.0   |13.4   |18 |202|
+---+---+-------+-------+---+---+

--- Top 10 trips avec le plus d'arrêts ---
+-------+-------+-------+
|trip_id|n_stops|n_pairs|
+-------+-------+-------+
|926815 |202    |20301  |
|200979 |186    |17205  |
|412826 |186    |17205  |
|1134025|151    |11325  |
|361811 |151    |11325  |
|310003 |151    |11325  |
|112291 |151    |11325  |
|876902 |151    |11325  |
|111002 |151    |11325  |
|533775 |151    |11325  |
+-------+-------+-------+
only showing top 10 rows


In [None]:
# ═══════════════════════════════════════════════════════════
# TABLE FINALE — Toutes paires O/D par trip (218M lignes)
# ═══════════════════════════════════════════════════════════
from pyspark.sql import functions as F
from pyspark.sql.window import Window

# 1. Distance cumulée depuis le 1er arrêt du trip
#    segment_dist_m(stop_k) = distance de stop_k à stop_k+1
#    cumul(stop_1) = 0, cumul(stop_2) = seg(1), cumul(stop_3) = seg(1)+seg(2)...
w_cumul = (Window.partitionBy("trip_id")
           .orderBy("stop_sequence")
           .rowsBetween(Window.unboundedPreceding, -1))

df_work = (
    df_enriched
    .withColumn("cumul_dist_m",
        F.coalesce(
            F.sum(F.coalesce(F.col("segment_dist_m"), F.lit(0.0))).over(w_cumul),
            F.lit(0.0)
        ))
    .repartition(200, "trip_id")
    .persist()
)

print(f"Stops total : {df_work.count():,}")

# 2. Côté DÉPART
df_dep = df_work.select(
    F.col("trip_id"),
    F.col("stop_sequence").alias("dep_seq"),
    F.col("stop_name").alias("departure_station"),
    F.col("city").alias("departure_city"),
    F.col("country").alias("departure_country"),
    F.col("departure_time"),
    F.col("parent_station").alias("departure_parent_station"),
    F.col("cumul_dist_m").alias("dep_cumul"),
)

# 3. Côté ARRIVÉE
df_arr = df_work.select(
    F.col("trip_id").alias("_trip_id_arr"),
    F.col("stop_sequence").alias("arr_seq"),
    F.col("stop_name").alias("arrival_station"),
    F.col("city").alias("arrival_city"),
    F.col("country").alias("arrival_country"),
    F.col("arrival_time"),
    F.col("parent_station").alias("arrival_parent_station"),
    F.col("cumul_dist_m").alias("arr_cumul"),
)

# 4. Attributs trip-level (1 ligne par trip)
df_trip_attrs = (
    df_work
    .select(
        "trip_id", "source", "route_id", "service_id", "route_type",
        "route_short_name", "route_long_name", "trip_headsign", "trip_short_name",
        "agency_name", "agency_timezone",
        "start_date", "end_date", "days_of_week", "is_night_train", "eea_co2_per_pkm"
    )
    .distinct()
)

# 5. Génération de toutes les paires (dep_seq < arr_seq)
df_pairs = (
    df_dep.join(
        df_arr,
        (F.col("trip_id") == F.col("_trip_id_arr")) &
        (F.col("dep_seq") < F.col("arr_seq")),
        "inner"
    )
    .withColumn("distance_m", F.round(F.col("arr_cumul") - F.col("dep_cumul"), 2))
    .drop("_trip_id_arr", "dep_cumul", "arr_cumul", "dep_seq", "arr_seq")
)

# 6. Jointure attributs + calcul émissions
df_final = (
    df_pairs
    .join(df_trip_attrs, "trip_id", "inner")
    .withColumnRenamed("eea_co2_per_pkm", "co2_per_pkm")
    .withColumn("emissions_co2",
        F.round(F.col("co2_per_pkm") * F.col("distance_m") / 1_000_000, 2))
    .withColumnRenamed("trip_headsign", "destination")
    .withColumnRenamed("start_date", "service_start_date")
    .withColumnRenamed("end_date", "service_end_date")
)

# 7. Sélection finale
final_cols = [
    "source", "trip_id", "destination", "trip_short_name",
    "agency_name", "agency_timezone",
    "service_id", "route_id", "route_type", "route_short_name", "route_long_name",
    "departure_station", "departure_city", "departure_country", "departure_time",
    "departure_parent_station",
    "arrival_station", "arrival_city", "arrival_country", "arrival_time",
    "arrival_parent_station",
    "service_start_date", "service_end_date", "days_of_week",
    "is_night_train", "distance_m", "co2_per_pkm", "emissions_co2"
]

df_final = df_final.select(final_cols)

# 8. Écriture Parquet (partitionné par pays de départ pour perf)
(
    df_final
    .repartition(100)
    .write
    .mode("overwrite")
    .parquet("./processed/routes_all_pairs.parquet")
)

df_work.unpersist()



Stops total : 23,251,426

Table finale : 295,442,818 lignes
+------------+--------------------------------------------------------------------------------------------------------+----------------------------+---------------+-------------------------------------+---------------+----------------+----------------------------------------------+----------+----------------+-----------------------------------------+------------------+----------------+-----------------+--------------+------------------------+----------------+-------------+---------------+------------+----------------------+------------------+----------------+------------+--------------+----------+-----------+-------------+
|source      |trip_id                                                                                                 |destination                 |trip_short_name|agency_name                          |agency_timezone|service_id      |route_id                                      |route_type|route_short_name

In [21]:
# 9. Relecture + stats
df_final = spark.read.parquet("./processed/routes_all_pairs.parquet")
n = df_final.count()
print(f"\n{'='*50}")
print(f"Table finale : {n:,} lignes")
print(f"{'='*50}")

df_final.show(5, truncate=False)
df_final.printSchema()


Table finale : 295,442,818 lignes
+------------+--------------------------------------------------------------------------------------------------------+----------------------------+---------------+-------------------------------------+---------------+----------------+----------------------------------------------+----------+----------------+-----------------------------------------+------------------+----------------+-----------------+--------------+------------------------+----------------+-------------+---------------+------------+----------------------+------------------+----------------+------------+--------------+----------+-----------+-------------+
|source      |trip_id                                                                                                 |destination                 |trip_short_name|agency_name                          |agency_timezone|service_id      |route_id                                      |route_type|route_short_name|route_long_name         

In [23]:
# ═══════════════════════════════════════════════════════════
# CONTRÔLE DOUBLONS SÉMANTIQUES — Même trajet, même horaire
# ═══════════════════════════════════════════════════════════
from pyspark.sql import functions as F

# 1. Doublons exacts : même gare départ/arrivée + mêmes horaires
dup_key = ["departure_station", "departure_time", "arrival_station", "arrival_time"]

df_dup_counts = (
    df_final
    .groupBy(dup_key)
    .agg(
        F.count("*").alias("n_occurrences"),
        F.countDistinct("trip_id").alias("n_trips"),
        F.countDistinct("source").alias("n_sources"),
        F.collect_set("agency_name").alias("agencies"),
        F.collect_set("route_short_name").alias("routes"),
    )
    .where(F.col("n_occurrences") > 1)
    .orderBy(F.desc("n_occurrences"))
)

n_dup_groups = df_dup_counts.count()
n_dup_rows = df_dup_counts.agg(F.sum("n_occurrences")).collect()[0][0] or 0

print(f"Groupes doublons       : {n_dup_groups:,}")
print(f"Lignes impliquées      : {n_dup_rows:,}")
print(f"Lignes uniques         : {df_final.count() - n_dup_rows + n_dup_groups:,}")
print(f"Taux de duplication    : {n_dup_rows / df_final.count() * 100:.1f}%")

print(f"\n--- Top 20 groupes les plus dupliqués ---")
df_dup_counts.show(20, truncate=40)

# 2. Doublons élargis : même ville départ/arrivée + mêmes horaires
#    (cas où la même gare a des noms différents selon les sources)
dup_key_city = ["departure_city", "departure_time", "arrival_city", "arrival_time"]

df_dup_city = (
    df_final
    .groupBy(dup_key_city)
    .agg(
        F.count("*").alias("n_occurrences"),
        F.countDistinct("trip_id").alias("n_trips"),
        F.countDistinct("departure_station").alias("n_dep_stations"),
        F.countDistinct("arrival_station").alias("n_arr_stations"),
    )
    .where((F.col("n_occurrences") > 1) & (F.col("n_dep_stations") > 1))
    .orderBy(F.desc("n_occurrences"))
)

n_alias = df_dup_city.count()
print(f"\n--- Doublons par ville (gares alias) : {n_alias:,} groupes ---")
if n_alias > 0:
    df_dup_city.show(20, truncate=40)
else:
    print("Aucun doublon par alias de gare détecté.")

# 3. Répartition des doublons par source
print(f"\n--- Doublons par combinaison de sources ---")
(
    df_final
    .join(
        df_dup_counts.select(dup_key),
        dup_key,
        "inner"
    )
    .groupBy("source")
    .agg(F.count("*").alias("n_dup_rows"))
    .orderBy(F.desc("n_dup_rows"))
    .show(20, truncate=False)
)

Groupes doublons       : 34,341,309
Lignes impliquées      : 279,090,441
Lignes uniques         : 50,693,686
Taux de duplication    : 94.5%

--- Top 20 groupes les plus dupliqués ---
+-----------------+--------------+---------------+------------+-------------+-------+---------+----------------------------------------+---------------+
|departure_station|departure_time|arrival_station|arrival_time|n_occurrences|n_trips|n_sources|                                agencies|         routes|
+-----------------+--------------+---------------+------------+-------------+-------+---------+----------------------------------------+---------------+
|           Toulon|      12:43:00|        Antibes|    14:17:00|          573|    573|        3|[Transdev, SNCF / TER, ZOU ! Intermét...|     [SUD_IV15]|
|           Cannes|      11:05:00|        Antibes|    11:17:00|          573|    573|        3|[Transdev, SNCF / TER, ZOU ! Intermét...|     [SUD_IV15]|
|           Toulon|      09:43:00|        Antibes|  

In [24]:
# ═══════════════════════════════════════════════════════════
# DÉTECTION — Calendriers englobés (superset/subset)
# Ex: "1111100" englobe "1000000"
# ═══════════════════════════════════════════════════════════
from pyspark.sql import functions as F

# Clé sans days_of_week
route_key = ["departure_station", "departure_time", "arrival_station", "arrival_time"]

# Groupes avec plusieurs calendriers pour le même trajet
df_multi_cal = (
    df_dedup
    .groupBy(route_key)
    .agg(
        F.count("*").alias("n_variants"),
        F.collect_set("days_of_week").alias("calendars"),
    )
    .where(F.col("n_variants") > 1)
)

n_groups = df_multi_cal.count()
n_rows_impacted = df_multi_cal.agg(F.sum("n_variants")).collect()[0][0] or 0

print(f"Groupes multi-calendrier : {n_groups:,}")
print(f"Lignes impliquées        : {n_rows_impacted:,}")
print(f"Lignes potentiellement suppressibles : ~{n_rows_impacted - n_groups:,}")

# Exemples concrets
print(f"\n--- Top 20 cas ---")
(
    df_multi_cal
    .orderBy(F.desc("n_variants"))
    .withColumn("calendars_str", F.concat_ws(" | ", "calendars"))
    .select(*route_key, "n_variants", "calendars_str")
    .show(20, truncate=50)
)

# Fonction : est-ce que cal_a englobe cal_b ? (bitwise)
# "1111100" englobe "1000000" si pour chaque position, a[i] >= b[i]
# En binaire : a & b == b (intersection = le subset)
@F.udf("boolean")
def is_subset(parent, child):
    """True si child est un sous-ensemble strict de parent"""
    if parent == child or len(parent) != 7 or len(child) != 7:
        return False
    for p, c in zip(parent, child):
        if c == '1' and p == '0':
            return False
    return True

# Self-join pour trouver les paires parent/enfant
df_with_cal = df_dedup.select(*route_key, "days_of_week")

df_parents = df_with_cal.alias("parent")
df_children = df_with_cal.alias("child")

df_subsumption = (
    df_parents.join(
        df_children,
        [F.col(f"parent.{k}") == F.col(f"child.{k}") for k in route_key] +
        [F.col("parent.days_of_week") != F.col("child.days_of_week")],
        "inner"
    )
    .where(is_subset(F.col("parent.days_of_week"), F.col("child.days_of_week")))
    .select(
        *[F.col(f"parent.{k}") for k in route_key],
        F.col("parent.days_of_week").alias("parent_cal"),
        F.col("child.days_of_week").alias("child_cal"),
    )
)

n_subsumable = df_subsumption.count()
n_children = df_subsumption.select(*route_key, "child_cal").distinct().count()

print(f"\n--- Résultats subsomption ---")
print(f"Paires parent→enfant  : {n_subsumable:,}")
print(f"Calendriers englobés   : {n_children:,} (suppressibles)")

print(f"\n--- Exemples ---")
df_subsumption.show(20, truncate=50)

Groupes multi-calendrier : 22,535,647
Lignes impliquées        : 82,438,818
Lignes potentiellement suppressibles : ~59,903,171

--- Top 20 cas ---
+-----------------------+--------------+----------------+------------+----------+--------------------------------------------------+
|      departure_station|departure_time| arrival_station|arrival_time|n_variants|                                     calendars_str|
+-----------------------+--------------+----------------+------------+----------+--------------------------------------------------+
|               Lenzburg|      24:19:00|       Zürich Hb|    24:40:00|        54|0110110 | 0111001 | 0010001 | 0001100 | 1111000...|
|                  Aarau|      24:12:00|       Zürich Hb|    24:40:00|        54|0110110 | 0111001 | 0010001 | 0001100 | 1111000...|
|                  Aarau|      24:12:00|        Lenzburg|    24:18:00|        54|0110110 | 0111001 | 0010001 | 0001100 | 1111000...|
|                  Olten|      24:02:00|       Zürich H

Traceback (most recent call last):
  File "/Users/llacroix/dev/ETL/.venv/lib/python3.13/site-packages/pyspark/python/lib/pyspark.zip/pyspark/daemon.py", line 233, in manager
    code = worker(sock, authenticated)
  File "/Users/llacroix/dev/ETL/.venv/lib/python3.13/site-packages/pyspark/python/lib/pyspark.zip/pyspark/daemon.py", line 87, in worker
    outfile.flush()
    ~~~~~~~~~~~~~^^
BrokenPipeError: [Errno 32] Broken pipe


In [25]:
# ═══════════════════════════════════════════════════════════
# DÉDUPLICATION V2 — Fusion calendriers OR bitwise
# Sur df_final (AVANT première dédup)
# ═══════════════════════════════════════════════════════════
from pyspark.sql import functions as F
from pyspark.sql.window import Window

route_key = ["departure_station", "departure_time", "arrival_station", "arrival_time"]

# 1. OR bitwise sur days_of_week par groupe
#    "1100000" OR "0011100" → "1111100"
@F.udf("string")
def bitwise_or_calendars(calendars):
    """Fusionne une liste de calendriers 7 chars par OR bitwise"""
    if not calendars:
        return "0000000"
    result = [0] * 7
    for cal in calendars:
        if cal and len(cal) == 7:
            for i, c in enumerate(cal):
                if c == '1':
                    result[i] = 1
    return ''.join(str(x) for x in result)

# 2. Agréger les calendriers + garder la meilleure ligne
#    Priorité : distance non-nulle > trip_short_name renseigné > émissions non-nulles
w_best = (
    Window.partitionBy(route_key)
    .orderBy(
        F.when(F.col("distance_m").isNotNull() & (F.col("distance_m") > 0), 0).otherwise(1),
        F.when(F.col("trip_short_name").isNotNull(), 0).otherwise(1),
        F.when(F.col("emissions_co2").isNotNull(), 0).otherwise(1),
        F.col("trip_id"),
    )
)

# Calendrier fusionné par groupe
df_merged_cal = (
    df_final
    .groupBy(route_key)
    .agg(F.collect_set("days_of_week").alias("all_calendars"))
    .withColumn("merged_days_of_week", bitwise_or_calendars("all_calendars"))
    .drop("all_calendars")
)

# Meilleure ligne par groupe
df_best = (
    df_final
    .withColumn("_rn", F.row_number().over(w_best))
    .where(F.col("_rn") == 1)
    .drop("_rn", "days_of_week")
)

# Jointure : ligne best + calendrier fusionné
df_dedup2 = (
    df_best
    .join(df_merged_cal, route_key, "inner")
    .withColumnRenamed("merged_days_of_week", "days_of_week")
)

# 3. Stats
n_before = df_final.count()
n_after = df_dedup2.count()

print(f"Avant dédup  : {n_before:,}")
print(f"Après dédup  : {n_after:,}")
print(f"Supprimées   : {n_before - n_after:,} ({(n_before - n_after)/n_before*100:.1f}%)")

# Vérif doublons résiduels
n_check = (
    df_dedup2
    .groupBy(route_key)
    .count()
    .where(F.col("count") > 1)
    .count()
)
print(f"Doublons résiduels : {n_check}")

# 4. Sauvegarde
df_dedup2.write.mode("overwrite").parquet("./processed/routes_dedup_v2.parquet")
df_dedup2 = spark.read.parquet("./processed/routes_dedup_v2.parquet")

print(f"\n--- Distribution calendriers fusionnés ---")
(
    df_dedup2
    .groupBy("days_of_week")
    .agg(F.count("*").alias("n"))
    .orderBy(F.desc("n"))
    .show(10, truncate=False)
)

Avant dédup  : 295,442,818
Après dédup  : 50,691,718
Supprimées   : 244,751,100 (82.8%)
Doublons résiduels : 0

--- Distribution calendriers fusionnés ---
+------------+--------+
|days_of_week|n       |
+------------+--------+
|1111111     |19208323|
|0000000     |10344203|
|1111100     |8508771 |
|0000011     |2685592 |
|1111110     |2536903 |
|0000001     |1359097 |
|0000010     |1317463 |
|0000110     |426990  |
|1111101     |401461  |
|0000100     |303513  |
+------------+--------+
only showing top 10 rows
