In [1]:
import heapq
import json
import logging
from datetime import datetime, timedelta
from itertools import count
from typing import List
from typing import Optional

import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd

In [2]:
# Configure the root logger at INFO level.
logging.basicConfig(level=logging.INFO)
# Notebook-wide logger; TrainRouter and the cells below write to it.
logger = logging.getLogger(__name__)

In [4]:
class TrainRouter:
    """Earliest-arrival train router built from two GTFS folders (TGV and TER).

    Each pair of consecutive stops of a trip becomes one edge of a
    ``networkx.MultiDiGraph`` whose nodes are GTFS stop ids.  No service
    calendar is read: every departure is treated as running every day, so
    waiting times are computed modulo 24 hours.
    """

    # Length of the timetable cycle, in minutes.
    _MINUTES_PER_DAY = 24 * 60

    def __init__(self, tgv_folder, ter_folder):
        """Build the routing graph.

        Args:
            tgv_folder: Folder containing the TGV ``stops.txt`` / ``stop_times.txt``.
            ter_folder: Folder containing the TER ``stops.txt`` / ``stop_times.txt``.
        """
        self.graph = nx.MultiDiGraph()
        # stop_id -> {'name', 'lat', 'lon', 'type'} for every row of stops.txt.
        self.stops_info = {}
        # stop point id -> parent station id (or the stop itself when it has none).
        self.stop_point_to_area = {}
        logger.info("Initialisation du router...")
        self.load_data(tgv_folder, "TGV")
        self.load_data(ter_folder, "TER")
        logger.info(f"Graphe construit avec {self.graph.number_of_nodes()} nœuds et {self.graph.number_of_edges()} arêtes")

    def load_data(self, folder_path, train_type):
        """Load one GTFS folder and add its stops and hops to the router.

        Args:
            folder_path: Folder holding ``stops.txt`` and ``stop_times.txt``.
            train_type: Label stored on every edge created from this folder.

        Raises:
            Exception: Any error raised while reading or processing the files
                is logged and re-raised unchanged.
        """
        logger.info(f"Chargement des données {train_type} depuis {folder_path}")
        try:
            stops_df = pd.read_csv(f"{folder_path}/stops.txt")
            stop_times_df = pd.read_csv(f"{folder_path}/stop_times.txt")

            for _, row in stops_df.iterrows():
                stop_id = row['stop_id']
                location_type = row['location_type']

                # GTFS treats an empty location_type as 0 (a stop/platform).
                if pd.isna(location_type) or location_type == 0:
                    parent_station = row['parent_station']
                    # An empty CSV cell is read as NaN, which is truthy: test it explicitly.
                    has_parent = pd.notna(parent_station) and bool(parent_station)
                    self.stop_point_to_area[stop_id] = parent_station if has_parent else stop_id

                self.stops_info[stop_id] = {
                    'name': row['stop_name'],
                    'lat': row['stop_lat'],
                    'lon': row['stop_lon'],
                    'type': 'area' if location_type == 1 else 'point'
                }

            edges_added = 0

            for trip_id, group in stop_times_df.groupby('trip_id'):
                sorted_stops = group.sort_values('stop_sequence')
                # Plain lists avoid building a Series per row inside the hot loop.
                stop_ids = sorted_stops['stop_id'].tolist()
                departures = sorted_stops['departure_time'].tolist()
                arrivals = sorted_stops['arrival_time'].tolist()

                for i in range(len(stop_ids) - 1):
                    try:
                        departure_time = self._time_to_minutes(departures[i])
                        arrival_time = self._time_to_minutes(arrivals[i + 1])
                    except ValueError as e:
                        logger.warning(f"Erreur de conversion de temps pour le trajet {trip_id}: {e}")
                        continue

                    duration = arrival_time - departure_time
                    # Hops that do not take at least one whole minute are not routable edges.
                    if duration > 0:
                        self.graph.add_edge(
                            stop_ids[i],
                            stop_ids[i + 1],
                            weight=duration,
                            train_type=train_type,
                            departure=departures[i],
                            arrival=arrivals[i + 1],
                            departure_minutes=departure_time
                        )
                        edges_added += 1

            logger.info(f"{edges_added} arêtes ajoutées pour {train_type}")

        except Exception as e:
            logger.error(f"Erreur lors du chargement des données {train_type}: {e}")
            raise

    def _time_to_minutes(self, time_str):
        """Convert an ``HH:MM:SS`` string to whole minutes (seconds are dropped).

        Hours are not capped at 23, so ``"25:10:00"`` yields 1510.

        Raises:
            ValueError: If ``time_str`` is not a string made of three
                colon-separated integers (this includes NaN cells).
        """
        try:
            hours, minutes, _seconds = map(int, time_str.split(':'))
        except (AttributeError, TypeError, ValueError):
            raise ValueError(f"Format de temps invalide: {time_str}") from None
        return hours * 60 + minutes

    @staticmethod
    def _clock_to_timedelta(time_str):
        """Convert ``HH:MM:SS`` to a timedelta; hours of 24 or more are accepted."""
        hours, minutes, seconds = map(int, time_str.split(':'))
        return timedelta(hours=hours, minutes=minutes, seconds=seconds)

    def find_station_ids(self, city_name):
        """Return the ids of stop points whose name contains ``city_name``.

        The match is a case-insensitive substring test, so ``"lyon"`` also
        matches ``"Paris Gare de Lyon"``.  Stops of type ``'area'`` are ignored.
        """
        city_name_lower = city_name.lower()
        matching_stops = []

        for stop_id, info in self.stops_info.items():
            name = info['name']
            # A missing stop_name is read as NaN (a float): it cannot match anything.
            if not isinstance(name, str):
                continue
            if city_name_lower in name.lower() and info['type'] == 'point':
                matching_stops.append(stop_id)
                logger.info(f"Station trouvée pour {city_name}: {info['name']} ({stop_id})")

        if not matching_stops:
            logger.warning(f"Aucune station trouvée pour {city_name}")

        return matching_stops

    def _dijkstra_with_time(self, start_stop, end_stop, current_minutes):
        """Earliest-arrival search from ``start_stop`` to ``end_stop``.

        Args:
            start_stop: Id of the departure stop.
            end_stop: Id of the destination stop.
            current_minutes: Clock time, in minutes, at which the traveller
                is ready to leave ``start_stop``.

        Returns:
            ``(path, total_minutes, path)`` where ``path`` is the list of leg
            dicts (``from``, ``to``, ``train_type``, ``departure``,
            ``arrival``) and ``total_minutes`` includes every wait, or
            ``(None, inf, None)`` when no route exists.
        """
        # A stop that never appears in stop_times has no node: nothing to search.
        if start_stop not in self.graph or end_stop not in self.graph:
            return None, float('inf'), None

        distances = {node: float('inf') for node in self.graph.nodes()}
        distances[start_stop] = 0
        # The counter breaks ties so the heap never compares stop ids or leg dicts.
        tie_breaker = count()
        pq = [(0, next(tie_breaker), start_stop, current_minutes, [])]
        visited = set()

        while pq:
            total_time, _, current, current_time, path = heapq.heappop(pq)

            if current in visited:
                continue

            visited.add(current)

            if current == end_stop:
                return path, total_time, path

            for neighbor, parallel_edges in self.graph[current].items():
                for edge_data in parallel_edges.values():
                    # The running clock may exceed 24h and GTFS departures may
                    # too; the modulo always yields the wait until the next
                    # daily occurrence, which is never negative.
                    wait_time = (edge_data['departure_minutes'] - current_time) % self._MINUTES_PER_DAY

                    elapsed = wait_time + edge_data['weight']
                    new_total = total_time + elapsed

                    if new_total < distances[neighbor]:
                        distances[neighbor] = new_total
                        new_path = path + [{
                            'from': self.stops_info[current]['name'],
                            'to': self.stops_info[neighbor]['name'],
                            'train_type': edge_data['train_type'],
                            'departure': edge_data['departure'],
                            'arrival': edge_data['arrival']
                        }]
                        heapq.heappush(
                            pq,
                            (new_total, next(tie_breaker), neighbor, current_time + elapsed, new_path)
                        )

        return None, float('inf'), None

    def find_fastest_route(self, start_city: str, end_city: str, 
                         intermediate_cities: Optional[List[str]] = None, 
                         current_time: Optional[str] = None):
        """Find the fastest itinerary, optionally through intermediate cities.

        Args:
            start_city: Departure city.
            end_city: Arrival city.
            intermediate_cities: Optional list of at most five cities to
                visit, in order, between the two.
            current_time: Departure time as ``HH:MM:SS``; defaults to now.

        Returns:
            ``(path, route_info)``, two lists of leg dicts, or
            ``(None, None)`` when one of the consecutive city pairs cannot
            be connected.

        Raises:
            ValueError: If more than five intermediate cities are given or
                ``current_time`` is malformed.
        """
        if intermediate_cities is None:
            intermediate_cities = []

        if len(intermediate_cities) > 5:
            raise ValueError("Le nombre de villes intermédiaires ne peut pas dépasser 5")

        logger.info(f"Recherche d'itinéraire entre {start_city} et {end_city} "
                   f"via {', '.join(intermediate_cities) if intermediate_cities else 'trajet direct'}")

        cities = [start_city] + intermediate_cities + [end_city]

        complete_path = []
        complete_route_info = []
        if not current_time:
            current_time = datetime.now().strftime('%H:%M:%S')
        current_minutes = self._time_to_minutes(current_time)

        for current_city, next_city in zip(cities, cities[1:]):
            path, route_info = self._find_path_between_cities(
                current_city,
                next_city,
                current_minutes
            )

            if not path or not route_info:
                logger.error(f"Pas de chemin trouvé entre {current_city} et {next_city}")
                return None, None

            # The next leg can start no earlier than the arrival of this one.
            current_minutes = self._time_to_minutes(route_info[-1]['arrival'])

            complete_path.extend(path)
            complete_route_info.extend(route_info)

        return complete_path, complete_route_info

    def _find_path_between_cities(self, start_city: str, end_city: str, 
                                current_minutes: int):
        """Return the quickest route between any stop pair of the two cities.

        Every stop matching ``start_city`` is tried against every stop
        matching ``end_city``; a failing search is logged and skipped.

        Returns:
            ``(path, route_info)`` of the best pair, or ``(None, None)`` when
            a city has no matching stop or no pair is connected.
        """
        start_stops = self.find_station_ids(start_city)
        end_stops = self.find_station_ids(end_city)

        if not start_stops or not end_stops:
            return None, None

        shortest_path = None
        min_duration = float('inf')
        best_route_info = None

        for start_stop in start_stops:
            for end_stop in end_stops:
                try:
                    path, duration, route_info = self._dijkstra_with_time(
                        start_stop, end_stop, current_minutes
                    )
                    if path and duration < min_duration:
                        min_duration = duration
                        shortest_path = path
                        best_route_info = route_info
                except Exception as e:
                    logger.error(f"Erreur lors de la recherche de chemin: {e}")
                    continue

        return shortest_path, best_route_info

    def format_route(self, route_info: List[dict], current_time: Optional[str] = None) -> str:
        """Render an itinerary as a human-readable report.

        Args:
            route_info: Leg dicts as returned by ``find_fastest_route``.
            current_time: Reference time (``HH:MM:SS``) used for the initial
                wait; defaults to now.

        Returns:
            The report, or a fixed message when ``route_info`` is empty.
        """
        if not route_info:
            return "Aucun itinéraire trouvé."

        if current_time is None:
            current_time = datetime.now().strftime('%H:%M:%S')

        # All gaps are taken modulo one day: this handles both GTFS times of
        # 24:00:00 or more (which strptime rejects) and clocks that wrap at midnight.
        one_day = timedelta(days=1)

        formatted_route = [f"Heure actuelle : {current_time}\n"]
        total_duration = timedelta()

        first_departure = self._clock_to_timedelta(route_info[0]['departure'])
        wait_time = (first_departure - self._clock_to_timedelta(current_time)) % one_day

        formatted_route.append(f"Temps d'attente jusqu'au premier train : {wait_time}\n")

        previous_arrival = None
        total_connection_time = timedelta()

        for segment in route_info:
            departure = self._clock_to_timedelta(segment['departure'])
            arrival = self._clock_to_timedelta(segment['arrival'])
            duration = (arrival - departure) % one_day

            # timedelta(0) is falsy, so the first-leg test must compare to None.
            if previous_arrival is not None:
                connection_time = (departure - previous_arrival) % one_day
                total_connection_time += connection_time
                formatted_route.append(f"Temps de correspondance : {connection_time}\n")

            formatted_route.append(
                f"{segment['train_type']} : {segment['from']} → {segment['to']}\n"
                f"Départ : {segment['departure']}, Arrivée : {segment['arrival']}\n"
                f"Durée : {duration}\n"
            )

            total_duration += duration
            previous_arrival = arrival

        formatted_route.append(f"\nDurée totale en train : {total_duration}")
        formatted_route.append(f"Temps total de correspondance : {total_connection_time}")
        total_time = wait_time + total_duration + total_connection_time
        formatted_route.append(f"Temps total (attente + trajet + correspondances) : {total_time}")

        return "\n".join(formatted_route)


In [5]:
# Build the routing graph from the TGV and TER GTFS exports.
router = TrainRouter(
    tgv_folder="assets/export_gtfs_voyages",
    ter_folder="assets/export-ter-gtfs-last",
)

INFO:__main__:Initialisation du router...
INFO:__main__:Chargement des données TGV depuis assets/export_gtfs_voyages
INFO:__main__:29346 arêtes ajoutées pour TGV
INFO:__main__:Chargement des données TER depuis assets/export-ter-gtfs-last
INFO:__main__:276046 arêtes ajoutées pour TER
INFO:__main__:Graphe construit avec 5479 nœuds et 305392 arêtes


In [10]:
# Report the size of the routing graph through the notebook logger.
graph = router.graph
for message in (
    "Statistiques du graphe:",
    f"Nombre de nœuds (gares): {graph.number_of_nodes()}",
    f"Nombre d'arêtes (trajets): {graph.number_of_edges()}",
):
    logger.info(message)

INFO:__main__:Statistiques du graphe:
INFO:__main__:Nombre de nœuds (gares): 5479
INFO:__main__:Nombre d'arêtes (trajets): 305392


In [12]:
# Freeze the query time once and hand the same value to both the search and
# the report; otherwise format_route falls back to its own datetime.now()
# and the displayed initial wait no longer matches the time the search used.
current_time = datetime.now().strftime('%H:%M:%S')
path, route_info = router.find_fastest_route("strasbourg", "lyon", current_time=current_time)
if route_info:
    print(router.format_route(route_info, current_time=current_time))
else:
    print("Aucun itinéraire trouvé.")

INFO:__main__:Recherche d'itinéraire entre strasbourg et lyon via trajet direct
INFO:__main__:Station trouvée pour strasbourg: Strasbourg (StopPoint:OCEICE-87212027)
INFO:__main__:Station trouvée pour strasbourg: Strasbourg (StopPoint:OCELyria-87212027)
INFO:__main__:Station trouvée pour strasbourg: Strasbourg (StopPoint:OCEOUIGO-87212027)
INFO:__main__:Station trouvée pour strasbourg: Strasbourg (StopPoint:OCETGV INOUI-87212027)
INFO:__main__:Station trouvée pour strasbourg: Strasbourg (StopPoint:OCECar TER-87212027)
INFO:__main__:Station trouvée pour strasbourg: Strasbourg (StopPoint:OCETrain TER-87212027)
INFO:__main__:Station trouvée pour strasbourg: Strasbourg Roethig (StopPoint:OCETrain TER-87212191)
INFO:__main__:Station trouvée pour lyon: Paris Gare de Lyon Hall 1 - 2 (StopPoint:OCELyria-87686006)
INFO:__main__:Station trouvée pour lyon: Paris Gare de Lyon Hall 1 - 2 (StopPoint:OCEOUIGO-87686006)
INFO:__main__:Station trouvée pour lyon: Paris Gare de Lyon Hall 1 - 2 (StopPoint:

Heure actuelle : 15:55:13

Temps d'attente jusqu'au premier train : 0:20:47

TGV : Strasbourg → Mulhouse
Départ : 16:16:00, Arrivée : 17:06:00
Durée : 0:50:00

Temps de correspondance : 0:03:00

TGV : Mulhouse → Belfort - Montbéliard TGV
Départ : 17:09:00, Arrivée : 17:31:00
Durée : 0:22:00

Temps de correspondance : 0:03:00

TGV : Belfort - Montbéliard TGV → Besançon Franche-Comté TGV
Départ : 17:34:00, Arrivée : 17:55:00
Durée : 0:21:00

Temps de correspondance : 0:04:00

TGV : Besançon Franche-Comté TGV → Chalon-sur-Saône
Départ : 17:59:00, Arrivée : 18:53:00
Durée : 0:54:00

Temps de correspondance : 0:03:00

TGV : Chalon-sur-Saône → Lyon Part Dieu
Départ : 18:56:00, Arrivée : 19:56:00
Durée : 1:00:00


Durée totale en train : 3:27:00
Temps total de correspondance : 0:13:00
Temps total (attente + trajet + correspondances) : 4:00:47


In [15]:
# Display the raw leg dictionaries returned by find_fastest_route.
route_info

[{'from': 'Strasbourg',
  'to': 'Mulhouse',
  'train_type': 'TGV',
  'departure': '16:16:00',
  'arrival': '17:06:00'},
 {'from': 'Mulhouse',
  'to': 'Belfort - Montbéliard TGV',
  'train_type': 'TGV',
  'departure': '17:09:00',
  'arrival': '17:31:00'},
 {'from': 'Belfort - Montbéliard TGV',
  'to': 'Besançon Franche-Comté TGV',
  'train_type': 'TGV',
  'departure': '17:34:00',
  'arrival': '17:55:00'},
 {'from': 'Besançon Franche-Comté TGV',
  'to': 'Chalon-sur-Saône',
  'train_type': 'TGV',
  'departure': '17:59:00',
  'arrival': '18:53:00'},
 {'from': 'Chalon-sur-Saône',
  'to': 'Lyon Part Dieu',
  'train_type': 'TGV',
  'departure': '18:56:00',
  'arrival': '19:56:00'}]

In [45]:
# Read both JSON assets: the track circuits and the passenger stations.
with open('assets/liste-des-circuits-de-voie.json', 'r', encoding='utf-8') as voies_file, \
        open('assets/gares-de-voyageurs.json', 'r', encoding='utf-8') as gares_file:
    voies_data = json.load(voies_file)
    gares_data = json.load(gares_file)

In [46]:
# Coordinate lists of every track record whose geometry is a plain LineString.
# Records whose 'geo_shape' or 'geometry' is missing or null are skipped
# instead of raising a TypeError.
voies_segments = [
    shape['geometry']['coordinates']
    for shape in (voie.get('geo_shape') for voie in voies_data)
    if shape and shape.get('geometry') and shape['geometry'].get('type') == "LineString"
]

# (lon, lat) of every station that has a position.
gares_positions = [
    (position['lon'], position['lat'])
    for position in (gare.get('position_geographique') for gare in gares_data)
    if position is not None
]

In [47]:
# Station name -> (lat, lon).  Stations whose position is missing or null are
# left out; .get() keeps this consistent with the membership test used when
# building gares_positions, instead of raising KeyError on an absent key.
gares_dict = {
    gare['nom']: (gare['position_geographique']['lat'], gare['position_geographique']['lon'])
    for gare in gares_data
    if gare.get('position_geographique') is not None
}

In [62]:
# For each leg, look up the (lat, lon) of its origin and then of its
# destination, keeping the itinerary order.
coordinates = [
    gares_dict[leg[endpoint]]
    for leg in route_info
    for endpoint in ('from', 'to')
]

In [None]:
# Split the itinerary's (lat, lon) pairs into two parallel sequences.
lats_chemin, lons_chemin = zip(*coordinates)

fig, ax = plt.subplots(figsize=(12, 10))

# Background layer: every track segment as a thin gray line.
for troncon in voies_segments:
    lons, lats = zip(*troncon)
    ax.plot(lons, lats, color='gray', linewidth=0.5, alpha=0.7)

# Foreground layer: the itinerary, one marker per station.
ax.plot(lons_chemin, lats_chemin, marker='o', color='blue', label='Itinéraire')

ax.set_title("Réseau ferré SNCF en France", fontsize=14)
ax.set_xlabel("Longitude", fontsize=12)
ax.set_ylabel("Latitude", fontsize=12)
ax.legend()
ax.grid(True)
ax.axis('equal')
plt.show()