# Tardis EDA and Data Cleaning

This notebook performs exploratory data analysis and cleans the raw `dataset.csv` file, outputting a `cleaned_dataset.csv` file.

In [24]:
# Import necessary libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import re
from typing import Set, List, Dict, Optional, Union, Any
from collections import defaultdict, deque
from fuzzywuzzy import fuzz


## Create Output Directories

In [25]:
# Output locations for the figures this notebook writes
graphs_dir = "visualizations/graphs"
model_diagnostics_dir = "visualizations/model_diagnostics"

# Make sure each location exists (missing parents are created too). A failure
# is reported rather than raised so the remaining cells can still be run.
for output_dir in (graphs_dir, model_diagnostics_dir):
    try:
        os.makedirs(output_dir, exist_ok=True)
        print(f"Directory '{output_dir}' created successfully or already exists.")
    except OSError as e:
        print(f"Error creating directory '{output_dir}': {e}")

Directory 'visualizations/graphs' created successfully or already exists.
Directory 'visualizations/model_diagnostics' created successfully or already exists.


## Station Name Standardization Functions

In [26]:
# Known misspellings -> canonical station name.
#
# Keys are upper-cased when the table is built because the lookup happens on
# an upper-cased name; entries written with a lowercase letter (for example
# "PARIS LYOw") would otherwise never match. The table is built once at
# definition time instead of on every call.
_STATION_TYPO_FIXES = {
    wrong.upper(): right
    for wrong, right in {
        # Common misspellings
        "7ILLE": "LILLE",
        "9ILLE": "LILLE",
        "0ARIS": "PARIS",
        "0TALIE": "ITALIE",
        "A2IGNON": "AVIGNON",
        "BORDE3UX": "BORDEAUX",
        "BRE3T": "BREST",
        "LE CREUSOT M1NTCEAU": "LE CREUSOT MONTCEAU",
        "LE CREUSOT MONTCEA4": "LE CREUSOT MONTCEAU",
        "MA7NE": "MARNE",
        "MARNE BA": "MARNE LA",
        "MARSEILLE ST CHA2LES": "MARSEILLE ST CHARLES",
        "MARSEILLE ST CHA3LES": "MARSEILLE ST CHARLES",
        "MARSEILLE ST CHAR6ES": "MARSEILLE ST CHARLES",
        "MARSEILLE ST CHARL2S": "MARSEILLE ST CHARLES",
        "MARSEILLE ST CHARLQS": "MARSEILLE ST CHARLES",
        "MARSEILLE ST CHARLxS": "MARSEILLE ST CHARLES",
        "MARSEILLE ST CbARLES": "MARSEILLE ST CHARLES",
        "MARSEILLE STACHARLES": "MARSEILLE ST CHARLES",
        "MARSEILLE eT CHARLES": "MARSEILLE ST CHARLES",
        "MARSEILLEUST CHARLES": "MARSEILLE ST CHARLES",
        "MARSEILLEpST CHARLES": "MARSEILLE ST CHARLES",
        "MARSEIfLE ST CHARLES": "MARSEILLE ST CHARLES",
        "MARSuILLE ST CHARLES": "MARSEILLE ST CHARLES",
        "MARqEILLE ST CHARLES": "MARSEILLE ST CHARLES",
        "MARzE LA VALLEE": "MARNE LA VALLEE",
        "MDDRID": "MADRID",
        "METL": "METZ",
        "MO1TPELLIER": "MONTPELLIER",
        "NIM8S": "NIMES",
        "NIMqS": "NIMES",
        "P6RIS": "PARIS",
        "PAR7S": "PARIS",
        "PARBS": "PARIS",
        "PARCS": "PARIS",
        "PARES": "PARIS",
        "PARGS": "PARIS",
        "PARHS": "PARIS",
        "PARI8": "PARIS",
        "PARIP": "PARIS",
        "PARIS 3ORD": "PARIS NORD",
        "PARIS 3YON": "PARIS LYON",
        "PARIS CYON": "PARIS LYON",
        "PARIS DYON": "PARIS LYON",
        "PARIS EBT": "PARIS EST",
        "PARIS EgT": "PARIS EST",
        "PARIS HST": "PARIS EST",
        "PARIS IORD": "PARIS NORD",
        "PARIS L1ON": "PARIS LYON",
        "PARIS LBON": "PARIS LYON",
        "PARIS LMON": "PARIS LYON",
        "PARIS LY0N": "PARIS LYON",
        "PARIS LYCN": "PARIS LYON",
        "PARIS LYGN": "PARIS LYON",
        "PARIS LYOI": "PARIS LYON",
        "PARIS LYOM": "PARIS LYON",
        "PARIS LYOQ": "PARIS LYON",
        "PARIS LYOl": "PARIS LYON",
        "PARIS LYOm": "PARIS LYON",
        "PARIS LYOw": "PARIS LYON",
        "PARIS LYfN": "PARIS LYON",
        "PARIS LYtN": "PARIS LYON",
        "PARIS LYwN": "PARIS LYON",
        "PARIS LeON": "PARIS LYON",
        "PARIS LhON": "PARIS LYON",
        "PARIS LmON": "PARIS LYON",
        "PARIS LqON": "PARIS LYON",
        "PARIS LvON": "PARIS LYON",
        "PARIS MO3TPARNASSE": "PARIS MONTPARNASSE",
        "PARIS MOATPARNASSE": "PARIS MONTPARNASSE",
        "PARIS MOHTPARNASSE": "PARIS MONTPARNASSE",
        "PARIS MONTPAHNASSE": "PARIS MONTPARNASSE",
        "PARIS MONTPAMNASSE": "PARIS MONTPARNASSE",
        "PARIS MONTPARN9SSE": "PARIS MONTPARNASSE",
        "PARIS MONTPARNASEE": "PARIS MONTPARNASSE",
        "PARIS MONTPARNASNE": "PARIS MONTPARNASSE",
        "PARIS MONTPARNASS5": "PARIS MONTPARNASSE",
        "PARIS MONTPARNASSm": "PARIS MONTPARNASSE",
        "PARIS MONTPARNAStE": "PARIS MONTPARNASSE",
        "PARIS MONTPARNATSE": "PARIS MONTPARNASSE",
        "PARIS MONTPARNAZSE": "PARIS MONTPARNASSE",
        "PARIS MONTPARNAqSE": "PARIS MONTPARNASSE",
        "PARIS MONTPARNnSSE": "PARIS MONTPARNASSE",
        "PARIS MONTPARcASSE": "PARIS MONTPARNASSE",
        "PARIS MONTPARiASSE": "PARIS MONTPARNASSE",
        "PARIS MONTPARpASSE": "PARIS MONTPARNASSE",
        "PARIS MONTPURNASSE": "PARIS MONTPARNASSE",
        "PARIS MONTPfRNASSE": "PARIS MONTPARNASSE",
        "PARIS MONjPARNASSE": "PARIS MONTPARNASSE",
        "PARIS MOhTPARNASSE": "PARIS MONTPARNASSE",
        "PARIS MOmTPARNASSE": "PARIS MONTPARNASSE",
        "PARIS MOpTPARNASSE": "PARIS MONTPARNASSE",
        "PARIS NOQD": "PARIS NORD",
        "PARIS NORo": "PARIS NORD",
        "PARIS NORr": "PARIS NORD",
        "PARIS QYON": "PARIS LYON",
        "PARIS VAUGIRvRD": "PARIS VAUGIRARD",
        "PARIS iYON": "PARIS LYON",
        "PARIS nYON": "PARIS LYON",
        "PARIS pST": "PARIS EST",
        "PARIS pYON": "PARIS LYON",
        "PARIS rORD": "PARIS NORD",
        "PARIS5LYON": "PARIS LYON",
        "PARISNEST": "PARIS EST",
        "PARISOLYON": "PARIS LYON",
        "PARISWLYON": "PARIS LYON",
        "PARISYEST": "PARIS EST",
        "PARISgLYON": "PARIS LYON",
        "PARISmMONTPARNASSE": "PARIS MONTPARNASSE",
        "PARIf EST": "PARIS EST",
        "PARIw LYON": "PARIS LYON",
        "PARIx LYON": "PARIS LYON",
        "PARIz LYON": "PARIS LYON",
        "PARJS MONTPARNASSE": "PARIS MONTPARNASSE",
        "PARPS EST": "PARIS EST",
        "PARfS EST": "PARIS EST",
        "PARgS LYON": "PARIS LYON",
        "PARgS NORD": "PARIS NORD",
        "PARrS LYON": "PARIS LYON",
        "PAWIS LYON": "PARIS LYON",
        "PAfIS LYON": "PARIS LYON",
        "PApIS LYON": "PARIS LYON",
        "PAxIS LYON": "PARIS LYON",
        "PERPIGNAX": "PERPIGNAN",
        "PIMES": "NIMES",
        "PIRIS EST": "PARIS EST",
        "POITIERS": "POITIERS",
        "POUAI": "DOUAI",
        "PfRIS LYON": "PARIS LYON",
        "PkRIS MONTPARNASSE": "PARIS MONTPARNASSE",
        "PmRIS LYON": "PARIS LYON",
        "PqRIS LYON": "PARIS LYON",
        "QUIMPEK": "QUIMPER",
        "REIMM": "REIMS",
        "REIMH": "REIMS",
        "REIcS": "REIMS",
        "REN5ES": "RENNES",
        "RENAES": "RENNES",
        "RENNES": "RENNES",
        "RENNbS": "RENNES",
        "RENXES": "RENNES",
        "RRIMS": "REIMS",
        "RRNNES": "RENNES",
        "SAINT ETIENNE CHATEAUCREXX": "SAINT ETIENNE CHATEAUCREUX",
        "ST 3ALO": "ST MALO",
        "ST MALR": "ST MALO",
        "ST PIERRE 9ES CORPS": "ST PIERRE DES CORPS",
        "ST PIERRE DES 4ORPS": "ST PIERRE DES CORPS",
        "ST PIERRE DES CORPi": "ST PIERRE DES CORPS",
        "ST PIERRE DES WORPS": "ST PIERRE DES CORPS",
        "ST PIbRRE DES CORPS": "ST PIERRE DES CORPS",
        "STHMALO": "ST MALO",
        "STRAS3OURG": "STRASBOURG",
        "STRASBOERG": "STRASBOURG",
        "STRASBiURG": "STRASBOURG",
        "STUTTGART": "STUTTGART",
        "STUTTGAoT": "STUTTGART",
        "STUTTGsRT": "STUTTGART",
        "STUTdGART": "STUTTGART",
        "STUkTGART": "STUTTGART",
        "STXMALO": "ST MALO",
        "SyRASBOURG": "STRASBOURG",
        "TOUFON": "TOULON",
        "TOUJOUSE MATABIAU": "TOULOUSE MATABIAU",
        "TOULON": "TOULON",
        "TOULOUSE MATABIAU": "TOULOUSE MATABIAU",
        "TOULOUSEKMATABIAU": "TOULOUSE MATABIAU",
        "TOULyN": "TOULON",
        "TOURCO3NG": "TOURCOING",
        "TOURCOING": "TOURCOING",
        "TOURS": "TOURS",
        "TdULON": "TOULON",
        "UEIMS": "REIMS",
        "VALENCE ABIXAN TGV": "VALENCE ALIXAN TGV",
        "VALENCE ALIQAN TGV": "VALENCE ALIXAN TGV",
        "VALENCE ALIXAN T4V": "VALENCE ALIXAN TGV",
        "VALENCE ALIXAN TGM": "VALENCE ALIXAN TGV",
        "VALENCE ALIXAN TxV": "VALENCE ALIXAN TGV",
        "VALENCERALIXAN TGV": "VALENCE ALIXAN TGV",
        "VANNES": "VANNES",
        "VAdENCE ALIXAN TGV": "VALENCE ALIXAN TGV",
        "VOUAI": "DOUAI",
        "VhNNES": "VANNES",
        "XOULON": "TOULON",
        "ZUBICH": "ZURICH",
        "ZURICH": "ZURICH",
        "ZURRCH": "ZURICH",
        "ZURlCH": "ZURICH",
        "ZUUICH": "ZURICH",
        "bIMES": "NIMES",
        "eARIS LYON": "PARIS LYON",
        "iARIS LYON": "PARIS LYON",
        "jALENCE ALIXAN TGV": "VALENCE ALIXAN TGV",
        "jENNES": "RENNES",
        "nA ROCHELLE VILLE": "LA ROCHELLE VILLE",
        "rARIS EST": "PARIS EST",
        "yARIS MONTPARNASSE": "PARIS MONTPARNASSE",
        "zARIS LYON": "PARIS LYON",
        # Additional special cases
        "AIJON VILLE": "DIJON VILLE",
        "AIX EN PROVENCE TGW": "AIX EN PROVENCE TGV",
        "ANGEULEME": "ANGOULEME",
        "ANNECQ": "ANNECY",
        "ANNEY": "ANNECY",
        "AQRAS": "ARRAS",
        "ARRDS": "ARRAS",
        "AVAGNON TGV": "AVIGNON TGV",
        "BARCELENA": "BARCELONA",
        "BIMES": "NIMES",
        "BROST": "BREST",
        "BWEST": "BREST",
        "DDNKERQUE": "DUNKERQUE",
        "DOUI": "DOUAI",
        "FRANCAORT": "FRANCFORT",
        "GENEKE": "GENEVE",
        "GRENEBLE": "GRENOBLE",
        "IALIE": "ITALIE",
        "ITPLIE": "ITALIE",
        "JENNES": "RENNES",
        "LJLLE": "LILLE",
        "LOLLE": "LILLE",
        "MTTZ": "METZ",
        "NAMES": "NIMES",
        "NANCB": "NANCY",
        "NANCJ": "NANCY",
        "NANTS": "NANTES",
        "NICE DILLE": "NICE VILLE",
        "NIMQS": "NIMES",
        "NMNCY": "NANCY",
        "PARGS LYON": "PARIS LYON",
        "PARIS RORD": "PARIS NORD",
        "PARIS VAUGIRVRD": "PARIS VAUGIRARD",
        "PERIGNAN": "PERPIGNAN",
        "RENNBS": "RENNES",
        "STRASBIURG": "STRASBOURG",
        "STUTTGSRT": "STUTTGART",
        "TAUSANNE": "LAUSANNE",
        "TDULON": "TOULON",
        "TOULYN": "TOULON",
        "VHNNES": "VANNES",
        "ZURLCH": "ZURICH",
        # Additional special cases for remaining inconsistencies
        # (exact duplicates of keys listed above have been dropped)
        "AIRAS": "ARRAS",
        "ANEECY": "ANNECY",
        "ANGEAS SAINT LAUD": "ANGERS SAINT LAUD",
        "ANGOULLME": "ANGOULEME",
        "ANNECJ": "ANNECY",
        "ANNECW": "ANNECY",
        "ANNNCY": "ANNECY",
        "AQNECY": "ANNECY",
        "ARRAI": "ARRAS",
        "ARWAS": "ARRAS",
        "AVIGNONLTGV": "AVIGNON TGV",
        "AVNECY": "ANNECY",
        "AVRAS": "ARRAS",
        "BANTES": "NANTES",
        "BARCELOBA": "BARCELONA",
        "BRET": "BREST",
        "DIJON VILNE": "DIJON VILLE",
        "DJNKERQUE": "DUNKERQUE",
        "DOUAZ": "DOUAI",
        "FRANCFOST": "FRANCFORT",
        "GENEGE": "GENEVE",
        "GENENE": "GENEVE",
        "GOITIERS": "POITIERS",
        "GRCNOBLE": "GRENOBLE",
        "ITALIP": "ITALIE",
        "LAUSANWE": "LAUSANNE",
        "LAVLL": "LAVAL",
        "LE CREUSOT MONTCEA MONTCHANIN": "LE CREUSOT MONTCEAU MONTCHANIN",
        "LE MAAS": "LE MANS",
        "LILKE": "LILLE",
        "LILL": "LILLE",
        "LILME": "LILLE",
        "LKLLE": "LILLE",
        "MULHKUSE VILLE": "MULHOUSE VILLE",
        "MUTZ": "METZ",
        "NAJCY": "NANCY",
        # NOTE(review): "NANEY" and "NANZY" are one letter away from NANCY
        # but are mapped to NANTES -- confirm against the raw data.
        "NANEY": "NANTES",
        "NANJES": "NANTES",
        "NANZY": "NANTES",
        "NAVAL": "LAVAL",
        "NICE VILQE": "NICE VILLE",
        "NIMEG": "NIMES",
        "PARGS NORD": "PARIS NORD",
        "PARIS IST": "PARIS EST",
        "PARS LYON": "PARIS LYON",
        "QUDMPER": "QUIMPER",
        "REIMW": "REIMS",
        "REIS": "REIMS",
        "RENNAS": "RENNES",
        "RENNEN": "RENNES",
        "RENNVS": "RENNES",
        "RENPES": "RENNES",
        "REVNES": "RENNES",
        "RZNNES": "RENNES",
        "SA MALO": "ST MALO",
        "SD PIERRE DES CORPS": "ST PIERRE DES CORPS",
        "STNASBOURG": "STRASBOURG",
        "STUTTGAOT": "STUTTGART",
        "TANNES": "NANTES",
        "TOUDON": "TOULON",
        "TOURCOXNG": "TOURCOING",
        "TOURY": "TOURS",
        "TURS": "TOURS",
        "TVULON": "TOULON",
        "UNNECY": "ANNECY",
        "VANCY": "NANCY",
        "VANKES": "VANNES",
        "XANTES": "NANTES",
        "ZARICH": "ZURICH",
        "ZURICS": "ZURICH",
    }.items()
}

# Formatting rules applied, in this order, to names that are not known typos.
# Most of them only collapse the whitespace between words; the last three
# split a suffix that was glued onto the preceding word.
_STATION_FORMAT_RULES = [
    (re.compile(pattern), replacement)
    for pattern, replacement in [
        (r"PARIS\s+([A-Z]+)", r"PARIS \1"),  # PARIS X
        (r"([A-Z]+)\s+ST\s+([A-Z]+)", r"\1 ST \2"),  # X ST Y
        (r"([A-Z]+)\s+TGV", r"\1 TGV"),  # X TGV
        (r"([A-Z]+)\s+VILLE", r"\1 VILLE"),  # X VILLE
        (r"([A-Z]+)\s+PART\s+([A-Z]+)", r"\1 PART \2"),  # X PART Y
        (r"([A-Z]+)\s+MATABIAU", r"\1 MATABIAU"),  # X MATABIAU
        (r"([A-Z]+)\s+CHALLES\s+LES\s+EAUX", r"\1 CHALLES LES EAUX"),
        (r"([A-Z]+)\s+SAINT\s+([A-Z]+)", r"\1 SAINT \2"),  # X SAINT Y
        (r"([A-Z]+)\s+CHATEAUCREUX", r"\1 CHATEAUCREUX"),
        (r"([A-Z]+)\s+ALIXAN\s+TGV", r"\1 ALIXAN TGV"),
        (r"([A-Z]+)\s+FRANCHE\s+COMTE\s+TGV", r"\1 FRANCHE COMTE TGV"),
        (r"PARIS\s+([A-Z]+)\s+([A-Z]+)", r"PARIS \1 \2"),  # PARIS X Y
        (r"([A-Z]+)\s+([A-Z]+)\s+([A-Z]+)", r"\1 \2 \3"),  # three words
        (r"([A-Z]+)\s+([A-Z]+)", r"\1 \2"),  # two words
        (r"([A-Z]+)LTGV", r"\1 TGV"),  # TGV glued on with a stray "L"
        (r"([A-Z]+)VILQE", r"\1 VILLE"),  # glued-on, misspelled VILLE
        (r"([A-Z]+)PART", r"\1 PART"),  # PART glued onto the previous word
    ]
]


def standardize_station_name(name: str) -> Optional[str]:
    """
    Standardize a station name: fix known typos, then normalise its formatting.

    The name is upper-cased and its whitespace collapsed, then looked up in the
    table of known misspellings (case-insensitively). If it is not a known
    typo, the formatting rules are applied, every character other than A-Z and
    whitespace is removed, and whitespace is collapsed again.

    Args:
        name (str): Input station name

    Returns:
        Optional[str]: Standardized station name or None if input is None/NaN/NaT
    """
    if name is None or pd.isna(name):
        return None

    # Normalise case and spacing before the lookup so that padded or
    # mixed-case variants of a known typo are recognised as well.
    name = re.sub(r"\s+", " ", str(name).upper()).strip()

    corrected = _STATION_TYPO_FIXES.get(name)
    if corrected is not None:
        return corrected

    for pattern, replacement in _STATION_FORMAT_RULES:
        name = pattern.sub(replacement, name)

    # Remove special characters and numbers, then tidy the spacing again
    name = re.sub(r"[^A-Z\s]", "", name)
    return re.sub(r"\s+", " ", name).strip()


def get_unique_stations(df: pd.DataFrame) -> Set[str]:
    """
    Collect every distinct station name used as a departure or an arrival.

    Args:
        df (pd.DataFrame): Frame with "Departure station" and
            "Arrival station" columns

    Returns:
        Set[str]: Distinct non-missing names found in either column
    """
    stations: Set[str] = set()
    for column in ("Departure station", "Arrival station"):
        # Missing entries are discarded before the names are merged in
        stations.update(df[column].dropna().unique())
    return stations


def find_similar_stations(stations: Set[str], threshold: int = 85) -> List[tuple]:
    """
    List every pair of station names whose fuzzy similarity reaches a threshold.

    Args:
        stations (Set[str]): Set of station names
        threshold (int): Minimum `fuzz.ratio` score (0-100) for a pair to be kept

    Returns:
        List[tuple]: `(name_a, name_b, ratio)` tuples; names are compared in
        sorted order, so each unordered pair appears at most once
    """
    ordered = sorted(stations)  # fixed order keeps the output reproducible
    scored_pairs = (
        (left, right, fuzz.ratio(left, right))
        for position, left in enumerate(ordered)
        for right in ordered[position + 1 :]
    )
    return [pair for pair in scored_pairs if pair[2] >= threshold]


def create_station_mapping(df: pd.DataFrame, threshold: int = 80) -> Dict[str, str]:
    """
    Map misspelled station names to the most frequent spelling in their cluster.

    Names are grouped into clusters by a breadth-first search in which two
    names are linked when `fuzz.ratio` is at least `threshold` (so a cluster is
    a chain of similar names, not necessarily all pairwise similar). Within
    each cluster of two or more names, the spelling that occurs most often
    across both station columns is canonical; ties go to the alphabetically
    first spelling so repeated runs give the same mapping.

    Args:
        df (pd.DataFrame): Frame with "Departure station" and
            "Arrival station" columns
        threshold (int): Similarity threshold (0-100)

    Returns:
        Dict[str, str]: Mapping of incorrect to correct station names
    """
    # Sorted because set iteration order is not stable between runs, which
    # previously made the tie-break below non-deterministic.
    unique_stations = sorted(get_unique_stations(df), key=str)

    # Count each spelling once instead of re-filtering the frame per name
    occurrences = pd.concat(
        [df["Departure station"], df["Arrival station"]]
    ).value_counts()

    visited = set()
    station_mapping: Dict[str, str] = {}

    for seed in unique_stations:
        if seed in visited:
            continue

        # Breadth-first expansion over the "similar enough" relation
        cluster = {seed}
        visited.add(seed)
        queue = deque([seed])
        while queue:
            current = queue.popleft()
            for other in unique_stations:
                if other not in visited and fuzz.ratio(current, other) >= threshold:
                    cluster.add(other)
                    visited.add(other)
                    queue.append(other)

        if len(cluster) < 2:
            continue

        # max() returns the first maximal element, so iterating the cluster in
        # sorted order breaks frequency ties alphabetically.
        canonical = max(
            sorted(cluster, key=str), key=lambda s: occurrences.get(s, 0)
        )
        for spelling in cluster:
            if spelling != canonical:
                station_mapping[spelling] = canonical

    return station_mapping

## Data Cleaning Functions

In [27]:
def load_data(file_path: str) -> Optional[pd.DataFrame]:
    """
    Read the semicolon-separated dataset into a dataframe.

    Any failure while reading is reported on stdout and turned into a `None`
    return value rather than an exception.

    Args:
        file_path (str): Path to the CSV file

    Returns:
        Optional[pd.DataFrame]: The loaded frame, or None if reading failed
    """
    try:
        # The raw export uses ";" as its field separator
        frame = pd.read_csv(file_path, sep=";")
        print(f"Successfully loaded data with shape: {frame.shape}")
    except Exception as error:
        print(f"Error loading data: {error}")
        return None
    return frame


def clean_station_names(df: pd.DataFrame, threshold: int = 85) -> pd.DataFrame:
    """
    Return a copy of `df` whose station columns hold cleaned names.

    Both station columns are first passed through `standardize_station_name`;
    a correction mapping is then built with `create_station_mapping`, printed,
    and applied to both columns. Missing values are left untouched.

    Args:
        df (pd.DataFrame): Input dataframe
        threshold (int): Similarity threshold (0-100) forwarded to
            `create_station_mapping`

    Returns:
        pd.DataFrame: Dataframe with cleaned station names
    """
    station_columns = ("Departure station", "Arrival station")
    df_clean = df.copy()

    # Pass 1: rule-based standardisation of every name
    for column in station_columns:
        df_clean[column] = df_clean[column].apply(standardize_station_name)

    # Pass 2: fuzzy clustering of whatever variation is left
    station_mapping = create_station_mapping(df_clean, threshold)

    # Report the corrections that are about to be applied
    if station_mapping:
        print("\nStation name corrections:")
        for incorrect in sorted(station_mapping):
            print(f"{incorrect} -> {station_mapping[incorrect]}")

    def _to_canonical(value):
        # Names without a correction (and missing values) pass through as-is
        return station_mapping.get(value, value) if pd.notna(value) else value

    for column in station_columns:
        df_clean[column] = df_clean[column].map(_to_canonical)

    return df_clean


def print_extreme_delay_values(df: pd.DataFrame, delay_threshold: int = 600) -> None:
    """
    Print the most extreme values for each delay column.

    A delay column is any numeric, non-boolean column whose name contains
    "delay" (case-insensitive). For each one this prints the five highest and
    five lowest values, the first rows above `delay_threshold`, and the first
    negative rows.

    Args:
        df (pd.DataFrame): Frame to inspect; it is not modified
        delay_threshold (int): Values strictly above this are listed as
            suspiciously large
    """
    # is_numeric_dtype covers every integer/float width and the nullable
    # dtypes; comparing the dtype with the builtin float/int only matched
    # float64 and the platform default integer.
    delay_cols = [
        col
        for col in df.columns
        if "delay" in col.lower()
        and pd.api.types.is_numeric_dtype(df[col])
        and not pd.api.types.is_bool_dtype(df[col])
    ]
    for col in delay_cols:
        print(f"\nTop 5 highest values for {col}:")
        print(df[[col]].sort_values(by=col, ascending=False).head(5))
        print(f"Top 5 lowest values for {col}:")
        print(df[[col]].sort_values(by=col, ascending=True).head(5))
        print(f"Values above {delay_threshold} for {col}:")
        print(df[df[col] > delay_threshold][[col]].head())
        print(f"Negative values for {col}:")
        print(df[df[col] < 0][[col]].head())


def cap_and_clean_delay_outliers(
    df: pd.DataFrame, delay_threshold: int = 300
) -> pd.DataFrame:
    """
    Replace error codes and out-of-range values in numeric columns with NaN.

    Each numeric, non-boolean column is classified by name (case-insensitive)
    and cleaned according to its kind:

    * percentage columns ("pct delay" or "percentage" in the name):
      valid range is [0, 100];
    * train-count columns ("number of trains" or "number_of_trains"):
      negative values are invalid;
    * delay-duration columns (any other name containing "delay"):
      valid range is [0, delay_threshold].

    In every such column the sentinel codes -1, 999 and 9999 are replaced with
    NaN first. Invalid values are set to NaN, not clipped. Count columns such
    as "Number of trains delayed ..." contain the word "delay" but are not
    durations, so `delay_threshold` is not applied to them.

    Args:
        df (pd.DataFrame): Input frame; it is not modified
        delay_threshold (int): Largest delay, in the column's own unit,
            that is kept

    Returns:
        pd.DataFrame: Cleaned copy of `df`
    """
    df_clean = df.copy()
    error_codes = [-1, 999, 9999]

    for col in df_clean.columns:
        values = df_clean[col]
        if not pd.api.types.is_numeric_dtype(values) or pd.api.types.is_bool_dtype(
            values
        ):
            continue

        lowered = str(col).lower()
        is_pct = "pct delay" in lowered or "percentage" in lowered
        is_count = "number of trains" in lowered or "number_of_trains" in lowered
        is_delay = "delay" in lowered and not (is_pct or is_count)
        if not (is_pct or is_count or is_delay):
            continue

        # Replace error codes with NaN
        values = values.replace(error_codes, np.nan)

        if is_pct:
            invalid = (values < 0) | (values > 100)
        elif is_count:
            invalid = values < 0
        else:
            invalid = (values < 0) | (values > delay_threshold)

        # mask() upcasts integer columns cleanly when a NaN has to be stored
        df_clean[col] = values.mask(invalid)

    return df_clean


def clean_data(df: pd.DataFrame) -> pd.DataFrame:
    """
    Clean the dataset: station names, duplicate rows and delay outliers.

    Along the way it prints the station names before and after cleaning, the
    frame's info, the per-column missing-value counts, the number of duplicate
    rows and the extreme values of the delay columns.

    Args:
        df (pd.DataFrame): Input dataframe; it is not modified

    Returns:
        pd.DataFrame: Cleaned dataframe
    """
    # Make a copy to avoid modifying the original
    df_clean = df.copy()

    # Print unique station names before cleaning
    print("\nUnique station names before cleaning:")
    print(
        "Departure stations:", sorted(df_clean["Departure station"].dropna().unique())
    )
    print("Arrival stations:", sorted(df_clean["Arrival station"].dropna().unique()))

    # Clean station names using fuzzy matching with an explicit similarity threshold
    print("\nCleaning station names...")
    df_clean = clean_station_names(df_clean, threshold=85)

    # Print unique station names after cleaning
    print("\nUnique station names after cleaning:")
    print(
        "Departure stations:", sorted(df_clean["Departure station"].dropna().unique())
    )
    print("Arrival stations:", sorted(df_clean["Arrival station"].dropna().unique()))

    # Display initial info. DataFrame.info() writes its report itself and
    # returns None, so it must not be wrapped in print().
    print("\nInitial data info:")
    df_clean.info()

    # Check for missing values
    print("\nMissing values per column:")
    print(df_clean.isnull().sum())

    # Check for duplicates
    print(f"\nNumber of duplicate rows: {df_clean.duplicated().sum()}")

    # Remove duplicates if any
    df_clean = df_clean.drop_duplicates()

    # Report extreme delay values (the report uses a looser threshold than the
    # cleaning step), then null out error codes and out-of-range values.
    print_extreme_delay_values(df_clean, delay_threshold=600)
    df_clean = cap_and_clean_delay_outliers(df_clean, delay_threshold=300)

    return df_clean

## EDA Functions

In [28]:
# Histogram of all numerical columns
def save_distribution_charts(df):
    """
    Save one histogram per numeric column of `df`.

    Each figure is written to
    `visualizations/graphs/distribution_<column>.png`, where `<column>` is the
    lower-cased column name with spaces, comparison signs and parentheses
    rewritten so it can be used in a file name.
    """
    # (old, new) substitutions applied in order to build the file-name stem
    substitutions = ((" ", "_"), (">", "gt"), ("<", "lt"), ("(", ""), (")", ""))

    for col in df.select_dtypes(include=[np.number]).columns:
        plt.figure(figsize=(10, 6))
        sns.histplot(data=df, x=col)
        plt.title(f"Distribution of {col}")
        plt.xlabel(col)
        plt.ylabel("Frequency")

        safe_col_name = col.lower()
        for old, new in substitutions:
            safe_col_name = safe_col_name.replace(old, new)

        plt.tight_layout()
        plt.savefig(f"visualizations/graphs/distribution_{safe_col_name}.png")
        plt.close()  # release the figure so a long loop does not pile them up


# Heatmap to show correlation between numerical columns
def save_correlation_matrix(df):
    """
    Save an annotated heatmap of the correlations between numeric columns.

    The figure is written to
    `visualizations/graphs/correlations_between_columns.png`.
    """
    numeric_names = df.select_dtypes(include=[np.number]).columns

    # Large canvas and small annotation font so the cell labels stay legible
    plt.figure(figsize=(18, 14))
    correlations = df[numeric_names].corr()
    sns.heatmap(correlations, annot=True, cmap="coolwarm", annot_kws={"size": 7})
    plt.title("Correlation Matrix")
    plt.tight_layout()
    plt.savefig("visualizations/graphs/correlations_between_columns.png")
    plt.close()


# Line chart to show how average delay evolves over time
def save_average_delay_line_chart(df):
    """
    Save a line chart of the mean departure delay per date.

    Nothing is drawn unless `df` has both a "Date" column and an
    "Average delay of all trains at departure" column. Dates are parsed with
    `errors="coerce"`, so unparseable dates become NaT and drop out of the
    grouping. The figure is written to
    `visualizations/graphs/delay_over_time.png`.

    The caller's frame is left untouched: dates are parsed into a temporary
    copy instead of overwriting `df["Date"]`.
    """
    delay_col = "Average delay of all trains at departure"
    if "Date" not in df.columns or delay_col not in df.columns:
        return

    # assign() returns a new frame, so the parsed dates never leak back into
    # the frame the caller passed in.
    dated = df.assign(Date=pd.to_datetime(df["Date"], errors="coerce"))
    delay_by_day = dated.groupby("Date")[delay_col].mean()

    plt.figure(figsize=(12, 6))
    plt.plot(delay_by_day.index, delay_by_day.values)
    plt.title("Average Departure Delay Over Time")
    plt.xlabel("Date")
    plt.ylabel("Delay (minutes)")
    plt.grid(True)
    plt.tight_layout()
    plt.savefig("visualizations/graphs/delay_over_time.png")
    plt.close()


# Line chart to show number of trains scheduled each day
def save_trains_scheduled_line_chart(df):
    """
    Save a line chart of the total number of scheduled trains per date.

    Nothing is drawn unless `df` has both a "Date" column and a
    "Number of scheduled trains" column. Dates are parsed with
    `errors="coerce"`, so unparseable dates become NaT and drop out of the
    grouping. The figure is written to
    `visualizations/graphs/trains_per_day.png`.

    The caller's frame is left untouched: dates are parsed into a temporary
    copy instead of overwriting `df["Date"]`.
    """
    trains_col = "Number of scheduled trains"
    if "Date" not in df.columns or trains_col not in df.columns:
        return

    # assign() returns a new frame, so the parsed dates never leak back into
    # the frame the caller passed in.
    dated = df.assign(Date=pd.to_datetime(df["Date"], errors="coerce"))
    daily_trains = dated.groupby("Date")[trains_col].sum()

    plt.figure(figsize=(12, 6))
    plt.plot(daily_trains.index, daily_trains.values, color="green")
    plt.title("Number of Trains Per Day")
    plt.xlabel("Date")
    plt.ylabel("Train Count")
    plt.grid(True)
    plt.tight_layout()
    plt.savefig("visualizations/graphs/trains_per_day.png")
    plt.close()


# Pie chart showing the percentage of cancelled vs non-cancelled trains
def save_cancellation_pie_chart(df):
    """
    Save a pie chart comparing cancelled and non-cancelled trains.

    Both shares come from column totals: cancelled is the sum of
    "Number of cancelled trains", and non-cancelled is the sum of
    "Number of scheduled trains" minus that. Nothing is drawn if either
    column is missing. The figure is written to
    `visualizations/graphs/cancellation_ratio.png`.
    """
    required = ("Number of cancelled trains", "Number of scheduled trains")
    if not all(column in df.columns for column in required):
        return

    cancelled = df["Number of cancelled trains"].sum()
    not_cancelled = df["Number of scheduled trains"].sum() - cancelled

    plt.figure(figsize=(6, 6))
    plt.pie(
        [cancelled, not_cancelled],
        labels=["Cancelled", "Not Cancelled"],
        autopct="%1.1f%%",
    )
    plt.title("Train Cancellations")
    plt.tight_layout()
    plt.savefig("visualizations/graphs/cancellation_ratio.png")
    plt.close()


# Horizontal bar chart showing the 5 stations with the highest average arrival delays
def save_top_delayed_stations_bar_chart(df):
    """
    Save a horizontal bar chart of the five worst arrival stations.

    Stations are ranked by the mean of "Average delay of all trains at
    arrival" over their rows. Nothing is drawn if that column or
    "Arrival station" is missing. The figure is written to
    `visualizations/graphs/most_delayed_stations.png`.
    """
    delay_col = "Average delay of all trains at arrival"
    if "Arrival station" not in df.columns or delay_col not in df.columns:
        return

    top5 = (
        df.groupby("Arrival station")[delay_col]
        .mean()
        .sort_values(ascending=False)
        .head(5)
    )

    plt.figure(figsize=(10, 6))
    top5.plot(kind="barh", color="skyblue")
    plt.title("Most Delayed Arrival Stations")
    plt.xlabel("Average Delay (min)")
    plt.tight_layout()
    plt.savefig("visualizations/graphs/most_delayed_stations.png")
    plt.close()


# Pie chart showing average percentage of delays due to different causes
def save_delay_causes_pie_chart(df):
    """
    Save a pie chart of the mean share of delay attributed to each cause.

    Only the known "Pct delay due to ..." columns that are present in `df`
    are used; if none is present nothing is drawn. Slice labels are the
    column names without the "Pct delay due to " prefix. The figure is
    written to `visualizations/graphs/delay_causes.png`.
    """
    known_causes = (
        "Pct delay due to external causes",
        "Pct delay due to infrastructure",
        "Pct delay due to traffic management",
        "Pct delay due to rolling stock",
        "Pct delay due to station management and equipment reuse",
        "Pct delay due to passenger handling (crowding, disabled persons, connections)",
    )
    present = [cause for cause in known_causes if cause in df.columns]
    if not present:
        return

    mean_causes = df[present].mean()
    # Shorten the labels: the shared prefix adds nothing on the chart
    clean_labels = mean_causes.index.str.replace("Pct delay due to ", "")

    plt.figure(figsize=(7, 7))
    plt.pie(mean_causes, labels=clean_labels, autopct="%1.1f%%")
    plt.title("Average Delay Causes")
    plt.tight_layout()
    plt.savefig("visualizations/graphs/delay_causes.png")
    plt.close()


# KDE and boxplots for delay columns
def save_delay_kde_and_boxplots(df):
    """
    Save a histogram-with-KDE and a boxplot for every delay column.

    A delay column is any numeric, non-boolean column whose name contains
    "delay" (case-insensitive). Missing values are dropped before plotting.
    Figures are written to `visualizations/graphs/kde_<column>.png` and
    `visualizations/graphs/boxplot_<column>.png`, where `<column>` is the
    lower-cased column name with spaces, comparison signs and parentheses
    rewritten -- the same scheme `save_distribution_charts` uses. Without it,
    names such as "... > 15min" produce file names that Windows rejects.
    """
    delay_cols = [
        col
        for col in df.columns
        if "delay" in col.lower()
        and pd.api.types.is_numeric_dtype(df[col])
        and not pd.api.types.is_bool_dtype(df[col])
    ]
    for col in delay_cols:
        values = df[col].dropna()
        safe_col_name = (
            col.lower()
            .replace(" ", "_")
            .replace(">", "gt")
            .replace("<", "lt")
            .replace("(", "")
            .replace(")", "")
        )

        # Histogram with KDE
        plt.figure(figsize=(10, 6))
        sns.histplot(values, kde=True, bins=30)
        plt.title(f"Distribution and KDE of {col}")
        plt.xlabel(col)
        plt.ylabel("Frequency")
        plt.tight_layout()
        plt.savefig(f"visualizations/graphs/kde_{safe_col_name}.png")
        plt.close()

        # Boxplot
        plt.figure(figsize=(6, 6))
        sns.boxplot(y=values)
        plt.title(f"Boxplot of {col}")
        plt.ylabel(col)
        plt.tight_layout()
        plt.savefig(f"visualizations/graphs/boxplot_{safe_col_name}.png")
        plt.close()


# Bar plots for delay by hour, day, and month
def save_delay_by_time_plots(df):
    """
    Save bar charts of the mean departure delay by hour, weekday and month.

    Nothing is drawn unless `df` has both a "Date" column and an
    "Average delay of all trains at departure" column. Dates are parsed with
    `errors="coerce"` and rows whose date cannot be parsed are ignored.
    Figures are written to `visualizations/graphs/delay_by_hour.png`,
    `delay_by_dayofweek.png` and `delay_by_month.png`.

    The caller's frame is left untouched: parsing and the derived
    Hour/DayOfWeek/Month columns live in a temporary copy.

    NOTE(review): if "Date" carries no time-of-day component, every row falls
    in hour 0 and the by-hour chart has a single bar -- confirm the
    granularity of the source data.
    """
    delay_col = "Average delay of all trains at departure"
    if "Date" not in df.columns or delay_col not in df.columns:
        return

    # assign() returns new frames, so nothing is written back into `df`
    timed = df.assign(Date=pd.to_datetime(df["Date"], errors="coerce")).dropna(
        subset=["Date"]
    )
    timed = timed.assign(
        Hour=timed["Date"].dt.hour,
        DayOfWeek=timed["Date"].dt.dayofweek,
        Month=timed["Date"].dt.month,
    )

    # (grouping column, chart title, x-axis label, output file stem)
    charts = [
        ("Hour", "Average Departure Delay by Hour of Day", "Hour", "delay_by_hour"),
        (
            "DayOfWeek",
            "Average Departure Delay by Day of Week",
            "Day of Week (0=Mon)",
            "delay_by_dayofweek",
        ),
        ("Month", "Average Departure Delay by Month", "Month", "delay_by_month"),
    ]
    for group_col, title, xlabel, file_stem in charts:
        plt.figure(figsize=(10, 6))
        timed.groupby(group_col)[delay_col].mean().plot(kind="bar")
        plt.title(title)
        plt.xlabel(xlabel)
        plt.ylabel("Average Delay (min)")
        plt.tight_layout()
        plt.savefig(f"visualizations/graphs/{file_stem}.png")
        plt.close()


# Heatmap for route delays (departure vs arrival station)
def save_route_delay_heatmap(df):
    """
    Write a heatmap of the mean arrival delay for each station pair.

    Rows of the heatmap are departure stations, columns are arrival stations.
    Nothing happens unless the frame has the "Departure station",
    "Arrival station" and "Average delay of all trains at arrival" columns.
    The figure is saved to "visualizations/graphs/route_delay_heatmap.png".
    """
    delay_col = "Average delay of all trains at arrival"
    needed = ("Departure station", "Arrival station", delay_col)
    if any(col not in df.columns for col in needed):
        return

    # Mean delay per (departure, arrival) pair.
    route_means = df.pivot_table(
        values=delay_col,
        index="Departure station",
        columns="Arrival station",
        aggfunc="mean",
    )

    fig, ax = plt.subplots(figsize=(18, 12))
    sns.heatmap(route_means, cmap="coolwarm", linewidths=0.5, ax=ax)
    ax.set_title("Route Heatmap: Average Arrival Delay by Route")
    ax.set_xlabel("Arrival Station")
    ax.set_ylabel("Departure Station")
    fig.tight_layout()
    fig.savefig("visualizations/graphs/route_delay_heatmap.png")
    plt.close(fig)


# Main EDA function that calls all the helper chart functions
def perform_eda(df: pd.DataFrame) -> None:
    """
    Print summary statistics for the frame and run every chart helper on it.

    Args:
        df: DataFrame to analyse; it is handed unchanged to each helper.

    Returns:
        None. Output goes to stdout and to image files written by the helpers.
    """
    print("\nBasic statistics:")
    print(df.describe())  # Show mean, std, min, max, etc.

    # The chart helpers visible in this notebook save into visualizations/graphs,
    # so make sure that exact directory exists, not just its parent.
    # exist_ok avoids the separate exists() check.
    os.makedirs("visualizations/graphs", exist_ok=True)

    save_distribution_charts(df)
    save_correlation_matrix(df)
    save_average_delay_line_chart(df)
    save_trains_scheduled_line_chart(df)
    save_cancellation_pie_chart(df)
    save_top_delayed_stations_bar_chart(df)
    save_delay_causes_pie_chart(df)
    save_delay_kde_and_boxplots(df)
    save_delay_by_time_plots(df)
    save_route_delay_heatmap(df)

## Workflow: Load, Clean, and EDA

In [29]:
# --- 1. Load the data ---
DATA_PATH = "dataset.csv"
df = load_data(DATA_PATH)

if df is not None:
    # Raw-frame snapshot: column dtypes and non-null counts
    print("\nInitial data info (before cleaning):")
    df.info()

    # Per-column count of missing entries in the raw data
    print("\nMissing values per column (before cleaning):")
    print(df.isna().sum())
else:
    # The later cells check for None themselves, so only report the failure here
    print("Exiting due to data loading error.")

Successfully loaded data with shape: (10840, 26)

Initial data info (before cleaning):
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 10840 entries, 0 to 10839
Data columns (total 26 columns):
 #   Column                                                                         Non-Null Count  Dtype  
---  ------                                                                         --------------  -----  
 0   Date                                                                           10299 non-null  object 
 1   Service                                                                        10288 non-null  object 
 2   Departure station                                                              10298 non-null  object 
 3   Arrival station                                                                10293 non-null  object 
 4   Average journey time                                                           10010 non-null  float64
 5   Number of scheduled trains                 

In [30]:
# --- 2. Clean the data ---
# Reset first: if loading failed on this run, a df_clean left over from an
# earlier run must not be picked up by the EDA and save cells below.
df_clean = None

if df is not None:
    print("\nStarting data cleaning...")
    df_clean = clean_data(df)
    print("\nData cleaning complete.")

    # Display info after cleaning
    print("\nData info (after cleaning):")
    df_clean.info()

    # Check for missing values after cleaning
    print("\nMissing values per column (after cleaning):")
    print(df_clean.isnull().sum())


Starting data cleaning...

Unique station names before cleaning:
Departure stations: ['0ARIS EST', '0TALIE', '7ILLE', '9ILLE', 'A2IGNON TGV', 'AIJON VILLE', 'AIX EN PROVENCE TGV', 'AIX EN PROVENCE TGw', 'ANGERS SAINT LAUD', 'ANGOULEIE', 'ANGOULEME', 'ANGOULlME', 'ANGeULEME', 'ANNE7Y', 'ANNECY', 'ANNECq', 'ANNEYY', 'ARRAS', 'ARRdS', 'AVAGNON TGV', 'AVIGNON TGV', 'AVIGNONlTGV', 'AqRAS', 'BARCELOBA', 'BARCELONA', 'BARIS LYON', 'BELLEGARDE (AIN)', 'BELLEGARDE (AINl', 'BELLEGARDE (AuN)', 'BELLEGARDED(AIN)', 'BESANCON FRANCHE COMTE TGV', 'BESANCtN FRANCHE COMTE TGV', 'BORDE3UX ST JEAN', 'BORDEAUX ST JEAN', 'BORDEAUX ST JEQN', 'BORDEAUX ST dEAN', 'BORDEAUX ST iEAN', 'BORDEAUXvST JEAN', 'BORDEAUl ST JEAN', 'BRE3T', 'BREST', 'BRoST', 'BwEST', 'CHAMBERY CHALLES LES EAUX', 'CHAMBERY CHALLES LES fAUX', 'CHAMBQRY CHALLES LES EAUX', 'DIJON VILLE', 'DIJON VILnE', 'DIJON VJLLE', 'DIJlN VILLE', 'DOU6I', 'DOUAI', 'DQNKERQUE', 'DUNIERQUE', 'DUNKERQUE', 'DdNKERQUE', 'DjNKERQUE', 'EARIS LYON', 'FARSEILLE 

In [31]:
# --- 3. Perform EDA ---
# df_clean may be undefined (cleaning cell never ran) or None; both mean "skip".
if globals().get("df_clean") is None:
    print("Cannot perform EDA: Cleaned dataframe is not available.")
else:
    print("\nPerforming EDA...")
    perform_eda(df_clean)
    print("EDA complete. Check the 'visualizations/graphs' directory for charts.")


Performing EDA...

Basic statistics:
       Average journey time  Number of scheduled trains  \
count           9842.000000                 9853.000000   
mean             250.907742                  395.417824   
std              352.256931                  568.991652   
min                0.000000                    0.000000   
25%              104.000000                  153.000000   
50%              168.000000                  236.000000   
75%              230.000000                  403.000000   
max             1702.593801                 2681.206158   

       Number of cancelled trains  Cancellation comments  \
count                 9844.000000                    0.0   
mean                    14.001022                    NaN   
std                     29.947415                    NaN   
min                      0.000000                    NaN   
25%                      0.000000                    NaN   
50%                      2.000000                    NaN   
75%       

In [32]:
# --- 4. Save cleaned data ---
# Only write the CSV when a cleaned frame actually exists in this session.
if globals().get("df_clean") is None:
    print("Cannot save cleaned data: Cleaned dataframe is not available.")
else:
    CLEANED_DATA_PATH = "cleaned_dataset.csv"
    df_clean.to_csv(CLEANED_DATA_PATH, index=False)
    print(f"\nCleaned dataset saved to {CLEANED_DATA_PATH}")


Cleaned dataset saved to cleaned_dataset.csv
