In [None]:
from pydantic import BaseModel


class ServerConfig(BaseModel):
    """Per-area configuration of this notebook, validated from JSON by pydantic."""
    # Value of the OSM `name` tag of the area the Overpass query is restricted to.
    area_name: str
    # Not read anywhere in this notebook; presumably the GTFS feed URL — TODO confirm.
    gtfs_url: str
    # GTFS `route_long_name` values of the routes that are skipped during matching.
    ignored_lines: list[str]
    # Manual overrides, GTFS stop_id -> OSM node id; they replace the automatic mapping.
    custom_stop_nodes: dict[str, int]


# JSON text is UTF-8 and the config may hold non-ASCII names, so decode it
# explicitly instead of relying on the platform's default encoding.
with open("krakow_server_config.json", encoding="utf-8") as file:
    server_config = ServerConfig.model_validate_json(file.read())

# Show the parsed configuration as the cell output.
server_config

In [None]:
from zipfile import ZipFile
import pandas as pd


# GTFS identifiers are opaque strings. Reading them with an explicit `str`
# dtype keeps them consistent with the str-keyed mappings built later in the
# notebook; pandas would otherwise infer integers for purely numeric ids
# (and drop leading zeros).
with ZipFile("GTFS_KRK_T_test.zip") as zip_file:
    with zip_file.open("stops.txt") as file:
        gtfs_stops = pd.read_csv(file, dtype={"stop_id": str}).set_index("stop_id")

    with zip_file.open("routes.txt") as file:
        gtfs_routes = pd.read_csv(file, dtype={"route_id": str}).set_index("route_id")

    with zip_file.open("trips.txt") as file:
        gtfs_trips = pd.read_csv(
            file, dtype={"trip_id": str, "route_id": str}
        ).set_index("trip_id")

    # stop_times keeps its default index; it is re-indexed in a later cell.
    with zip_file.open("stop_times.txt") as file:
        gtfs_stop_times = pd.read_csv(file, dtype={"trip_id": str, "stop_id": str})

In [None]:
# Preview the first rows of the stops table (indexed by stop_id).
gtfs_stops.head()

In [None]:
# Preview the first rows of the routes table (indexed by route_id).
gtfs_routes.head()

In [None]:
# Preview the first rows of the trips table (indexed by trip_id).
gtfs_trips.head()

In [None]:
# Preview the first rows of the stop_times table (default integer index).
gtfs_stop_times.head()

In [None]:
from collections import defaultdict

# Every stop_times column as a dict keyed by (trip_id, stop_sequence); only the
# stop_id column is used below.
gtfs_stop_times_dict = gtfs_stop_times.set_index(["trip_id", "stop_sequence"]).to_dict()
gtfs_stop_times_dict_for_stop_ids: dict[tuple[str, int], str] = gtfs_stop_times_dict["stop_id"]

# Walking the keys in sorted order groups them by trip and, within a trip, by
# ascending stop_sequence, so each list ends up in travel order.
gtfs_stop_ids_by_trip_id: defaultdict[str, list[str]] = defaultdict(list)
for trip_id, stop_sequence in sorted(gtfs_stop_times_dict_for_stop_ids):
    gtfs_stop_ids_by_trip_id[trip_id] += [
        gtfs_stop_times_dict_for_stop_ids[(trip_id, stop_sequence)]
    ]

In [None]:
import pickle
import overpy as op


# Ids of the manually configured stop nodes, as a comma-separated list.
custom_stop_node_ids = ", ".join(map(str, server_config.custom_stop_nodes.values()))

# An `id:` filter with an empty list is not a valid Overpass statement, so the
# custom-node statement is only emitted when at least one node is configured.
custom_stop_nodes_statement = (
    f"node(id:{custom_stop_node_ids})(area.search_area);"
    if custom_stop_node_ids
    else ""
)

# Fetches, inside the configured area: all tram route relations, all tram stop
# positions, and the manually configured stop nodes.
OVERPASS_QUERY = f"""
[out:json];
area["name"="{server_config.area_name}"]->.search_area;
(
    relation["route"="tram"](area.search_area);
  	node["railway"="tram_stop"]["public_transport"="stop_position"](area.search_area);
    {custom_stop_nodes_statement}
);
out geom;
"""

# Run the query and cache the parsed result on disk; the next cell reloads it
# so that later cells can be re-run without hitting the Overpass server.
query_result = op.Overpass().query(OVERPASS_QUERY)
with open("overpass_query_test.pkl", "wb") as file:
    pickle.dump(query_result, file)

In [None]:
# Reload the Overpass result cached by the previous cell.
# NOTE(review): pickle.load can execute arbitrary code from the file — only
# load a pickle that this notebook wrote itself.
with open("overpass_query_test.pkl", "rb") as file:
    query_result: op.Result = pickle.load(file)

# Relations and nodes contained in the query result.
osm_relations = query_result.get_relations()
osm_nodes = query_result.get_nodes()

In [None]:
# Inspect the relations returned by the Overpass query.
osm_relations

In [None]:
# Inspect the nodes returned by the Overpass query.
osm_nodes

In [None]:
import string


def gtfs_stop_name_to_comparable(stop_name: str):
    """Normalise a GTFS stop name so it can be compared with an OSM stop name.

    The name is lower-cased, then dots, spaces and the "(nż)" marker are
    removed, in that order (so the marker is also removed when it was written
    with dots or spaces inside it).
    """
    comparable_name = stop_name.lower()
    for fragment in (".", " ", "(nż)"):
        comparable_name = comparable_name.replace(fragment, "")
    return comparable_name


def stop_name_to_comparable(stop_name: str):
    """Normalise an OSM stop name so it can be compared with a GTFS stop name.

    The name is lower-cased, digits at the very end are stripped, and then all
    dots and spaces are removed. Digits followed by any other character (even a
    trailing space) are kept.
    """
    without_trailing_digits = stop_name.lower().rstrip(string.digits)
    return without_trailing_digits.translate(str.maketrans("", "", ". "))


# Index the fetched OSM nodes by their id.
osm_node_by_id = {osm_node.id: osm_node for osm_node in osm_nodes}

# For every relation, its node members that are also present in the fetched
# node set, in member order. Non-node members and unknown nodes are dropped.
stops_by_relation = {
    osm_relation: [
        osm_node_by_id[relation_member.ref]
        for relation_member in osm_relation.members
        if isinstance(relation_member, op.RelationNode)
        if relation_member.ref in osm_node_by_id
    ]
    for osm_relation in osm_relations
}

# Normalised stop names of every relation, in the same order as the node list
# stored for that relation in `stops_by_relation`.
comparable_stop_names_by_relation = {
    relation: [
        # `tags.get("name")` is None for a node without a name tag; fall back
        # to "" so normalisation does not fail with AttributeError.
        stop_name_to_comparable(item.tags.get("name") or "")
        for item in stops
    ]
    for relation, stops in stops_by_relation.items()
}

def line_name_key_sort(item: str):
    """Sort key ordering relation names by the line number in their second word.

    The name is split on whitespace and the second word, minus one trailing
    ":", is parsed as an integer.

    Returns:
        The line number as an int. Names with fewer than two words, or whose
        second word is not an integer, get `float("inf")`: every key stays
        mutually comparable and such names sort after all numbered lines
        instead of making `sorted` raise.
    """
    words = item.split()
    if len(words) < 2:
        return float("inf")

    try:
        return int(words[1].removesuffix(":"))
    except ValueError:
        return float("inf")


# List the `name` tag of every fetched relation, ordered by line number.
sorted((item.tags["name"] for item in stops_by_relation), key=line_name_key_sort)

In [None]:
import difflib
from collections import defaultdict


def is_longer_match(
    longest_match: difflib.Match,
    longest_relation: op.Relation,
    match_result: difflib.Match,
    match_relation: op.Relation,
):
    """Decide whether a candidate match should replace the best match so far.

    Args:
        longest_match: Best match found so far.
        longest_relation: Relation that `longest_match` was computed against.
        match_result: Candidate match.
        match_relation: Relation that `match_result` was computed against.

    Returns:
        False when the candidate spans fewer than two stops or is shorter than
        the current best. Otherwise True exactly when the candidate covers at
        least as large a share of its relation's stops as the current best
        covers of its own relation.
    """
    if match_result.size < 2 or longest_match.size > match_result.size:
        return False

    longest_relation_stop_count = len(stops_by_relation[longest_relation])
    match_relation_stop_count = len(stops_by_relation[match_relation])
    # Compare the two coverage ratios by cross-multiplication. This is the same
    # ordering as dividing, but it does not raise ZeroDivisionError when the
    # current best relation has no resolved stops (possible for the initial
    # fallback relation).
    return (
        longest_match.size * match_relation_stop_count
        <= match_result.size * longest_relation_stop_count
    )


def find_longest_matching_relation(relations: list[op.Relation], gtfs_trip_stop_names: list[str]):
    """Pick the relation whose stop names best match a trip's stop names.

    For every relation the longest contiguous run of equal names between the
    trip and the relation is computed; `is_longer_match` decides whether that
    run replaces the best one found so far.

    Args:
        relations: Candidate relations; must not be empty.
        gtfs_trip_stop_names: Normalised stop names of the trip, in order.

    Returns:
        `(match, relation)`. When no candidate is accepted this is an empty
        match together with the first relation.
    """
    best_match = difflib.Match(0, 0, 0)
    best_relation = relations[0]

    for candidate_relation in relations:
        matcher = difflib.SequenceMatcher(
            isjunk=None,
            a=gtfs_trip_stop_names,
            b=comparable_stop_names_by_relation[candidate_relation],
        )
        candidate_match = matcher.find_longest_match(
            alo=0,
            ahi=len(gtfs_trip_stop_names),
            blo=0,
            bhi=len(stops_by_relation[candidate_relation]),
        )

        if not is_longer_match(best_match, best_relation, candidate_match, candidate_relation):
            continue

        best_match, best_relation = candidate_match, candidate_relation

    return best_match, best_relation

 
def add_trip_to_mapping(
    gtfs_trip_id: str,
    relations: list[op.Relation],
    stop_mapping: dict[str, set[int]],
    start_stop_mapping: defaultdict[str, set[int]],
    end_stop_mapping: defaultdict[str, set[int]],
):
    """Match one GTFS trip against candidate relations and record stop -> node pairs.

    Within the longest run of equal stop names between the trip and the
    selected relation, every stop strictly inside the run (its first and last
    element excluded) is added to `stop_mapping`. When trip and relation agree
    on all stop names except possibly the first and the last one, the trip's
    first and last stops are additionally recorded against the relation's first
    and last nodes.

    Args:
        gtfs_trip_id: Trip to match.
        relations: Candidate relations; must not be empty.
        stop_mapping: GTFS stop_id -> OSM node ids, updated in place. Must
            already contain a key for every stop that can be matched.
        start_stop_mapping: First-stop candidates, updated in place.
        end_stop_mapping: Last-stop candidates, updated in place.

    Returns:
        `(size of the longest match, selected relation)`.
    """
    gtfs_trip_stops = gtfs_stop_ids_by_trip_id[gtfs_trip_id]
    gtfs_trip_stop_data = gtfs_stops.loc[gtfs_trip_stops]
    gtfs_trip_stop_names = [
        gtfs_stop_name_to_comparable(item)
        for item in gtfs_trip_stop_data["stop_name"]
    ]

    longest_match, longest_relation = find_longest_matching_relation(
        relations, gtfs_trip_stop_names
    )
    relation_stops = stops_by_relation[longest_relation]
    relation_stop_names = comparable_stop_names_by_relation[longest_relation]

    matched_gtfs_stop_ids = gtfs_trip_stop_data.iloc[
        longest_match.a:longest_match.a + longest_match.size
    ].index
    matched_osm_nodes = relation_stops[
        longest_match.b:longest_match.b + longest_match.size
    ]
    for i, (gtfs_stop_id, osm_node) in enumerate(
        zip(matched_gtfs_stop_ids, matched_osm_nodes)
    ):
        # Skip both ends of the matched run; only interior stops are mapped here.
        if 0 < i < longest_match.size - 1:
            stop_mapping[gtfs_stop_id].add(osm_node.id)

    # Two empty interiors compare equal, so also require that both the trip and
    # the relation have at least one stop; otherwise the [0] / [-1] lookups
    # below raise IndexError.
    if (
        gtfs_trip_stops
        and relation_stops
        and gtfs_trip_stop_names[1:-1] == relation_stop_names[1:-1]
    ):
        start_stop_mapping[gtfs_trip_stops[0]].add(relation_stops[0].id)
        end_stop_mapping[gtfs_trip_stops[-1]].add(relation_stops[-1].id)

    return longest_match.size, longest_relation


def update_relations_for_route(
    route_number: str,
    gtfs_route_id: str,
    gtfs_stop_id_to_osm_node_id_mapping: dict[str, set[int]],
    start_gtfs_stop_id_to_osm_node_id_mapping: defaultdict[str, set[int]],
    end_gtfs_stop_id_to_osm_node_id_mapping: defaultdict[str, set[int]],
    longest_match_by_relation: dict[op.Relation, int],
    longest_relation_by_trip_id: dict[str, op.Relation],
    missing_relation_lines: list[str],
):
    """Match every trip of one GTFS route against the relations of that line.

    Candidate relations are those whose `ref` tag equals `route_number`. When
    there are none, the route number is appended to `missing_relation_lines`
    and nothing else happens. Otherwise each trip of the route is passed to
    `add_trip_to_mapping`, and the accumulators are updated in place: the
    longest match seen per relation and the relation selected per trip.
    """
    relations = [
        relation
        for relation in stops_by_relation
        if relation.tags.get("ref") == route_number
    ]

    if len(relations) == 0:
        missing_relation_lines.append(route_number)
        return

    trips_on_route = gtfs_trips.loc[gtfs_trips["route_id"] == gtfs_route_id]
    for gtfs_trip_id in trips_on_route.index:
        match_size, matched_relation = add_trip_to_mapping(
            gtfs_trip_id=str(gtfs_trip_id),
            relations=relations,
            stop_mapping=gtfs_stop_id_to_osm_node_id_mapping,
            start_stop_mapping=start_gtfs_stop_id_to_osm_node_id_mapping,
            end_stop_mapping=end_gtfs_stop_id_to_osm_node_id_mapping,
        )

        # Keep the largest match size ever seen for this relation.
        previous_best = longest_match_by_relation.get(matched_relation, 0)
        longest_match_by_relation[matched_relation] = (
            match_size if match_size > previous_best else previous_best
        )

        longest_relation_by_trip_id[gtfs_trip_id] = matched_relation


def detect_node_mapping_errors():
    gtfs_stop_id_to_osm_node_id_mapping: dict[str, set[int]] = {
        str(stop_id): set()
        for stop_id in gtfs_stops.index
    }
    start_gtfs_stop_id_to_osm_node_id_mapping: dict[str, set[int]] = defaultdict(set)
    end_gtfs_stop_id_to_osm_node_id_mapping: dict[str, set[int]] = defaultdict(set)

    longest_match_by_relation: dict[op.Relation, int] = {}
    longest_relation_by_trip_id: dict[str, op.Relation] = {}
    missing_relation_lines: list[str] = []
    for gtfs_route_id, gtfs_route_row in gtfs_routes.iterrows():
        route_number = str(gtfs_route_row["route_long_name"])
        if route_number in server_config.ignored_lines:
            continue

        update_relations_for_route(
            route_number,
            gtfs_route_id,
            gtfs_stop_id_to_osm_node_id_mapping,
            start_gtfs_stop_id_to_osm_node_id_mapping,
            end_gtfs_stop_id_to_osm_node_id_mapping,
            longest_match_by_relation,
            longest_relation_by_trip_id,
            missing_relation_lines
        )

    for gtfs_stop_id, node_id in server_config.custom_stop_nodes.items():
        gtfs_stop_id_to_osm_node_id_mapping[gtfs_stop_id] = {node_id}

    nodes_without_mapping: set[str] = set()
    gtfs_stop_id_to_node_id: dict[str, int] = {}
    nodes_with_conflict: dict[str, list[tuple[str | None, int]]] = {}
    for gtfs_stop_id, osm_node_ids in gtfs_stop_id_to_osm_node_id_mapping.items():
        match len(osm_node_ids):
            case 0:
                nodes_without_mapping.add(gtfs_stop_id)
            case 1:
                gtfs_stop_id_to_node_id[gtfs_stop_id] = next(iter(osm_node_ids))
            case _:
                nodes_with_conflict[gtfs_stop_id] = [
                    (osm_node_by_id[node_id].tags.get("name"), node_id)
                    for node_id in osm_node_ids
                ]

    start_gtfs_stop_id_to_node_ids = {
        gtfs_stop_id: list(node_ids)
        for gtfs_stop_id, node_ids in start_gtfs_stop_id_to_osm_node_id_mapping.items()
    }

    end_gtfs_stop_id_to_node_ids = {
        gtfs_stop_id: list(node_ids)
        for gtfs_stop_id, node_ids in end_gtfs_stop_id_to_osm_node_id_mapping.items()
    }

    nodes_without_mapping = (
        nodes_without_mapping
        .difference(start_gtfs_stop_id_to_osm_node_id_mapping)
        .difference(end_gtfs_stop_id_to_osm_node_id_mapping)
    )

    underutilized_relations = [
        relation
        for relation, stops in stops_by_relation.items()
        if longest_match_by_relation[relation] < len(stops)
    ]

    return (
        gtfs_stop_id_to_node_id,
        start_gtfs_stop_id_to_node_ids,
        end_gtfs_stop_id_to_node_ids,
        missing_relation_lines,
        nodes_with_conflict,
        nodes_without_mapping,
        underutilized_relations,
        longest_relation_by_trip_id
    )


# Run the matching once and expose every part of its result as a notebook
# global; the following cells display them one by one.
(
    gtfs_stop_id_to_node_id,
    start_gtfs_stop_id_to_node_ids,
    end_gtfs_stop_id_to_node_ids,
    missing_relation_lines,
    nodes_with_conflict,
    nodes_without_mapping,
    underutilized_relations,
    longest_relation_by_trip_id,
) = detect_node_mapping_errors()

In [None]:
# Route numbers for which no relation with a matching `ref` tag was found.
missing_relation_lines

In [None]:
# Candidate OSM node ids recorded for stops that start a trip.
start_gtfs_stop_id_to_node_ids

In [None]:
# Candidate OSM node ids recorded for stops that end a trip.
end_gtfs_stop_id_to_node_ids

In [None]:
# Stops that resolved to exactly one OSM node.
gtfs_stop_id_to_node_id

In [None]:
# Stops with more than one candidate OSM node, as (node name, node id) pairs.
nodes_with_conflict

In [None]:
# Stops with no candidate OSM node, not even as a first or last stop.
nodes_without_mapping

In [None]:
# Relations whose longest match is shorter than their own stop list.
underutilized_relations

In [None]:
import random


def get_node_id_for_trip_stop(
    gtfs_stop_id: str, gtfs_stop_sequence: int, total_stops: int
):
    """Resolve one stop of a trip to an OSM node id.

    An unambiguous mapping is used when there is one. Otherwise a node is
    picked at random from the first-stop or last-stop candidates of the stop.
    When the stop has both kinds of candidates, the first-stop ones are used
    while `gtfs_stop_sequence` lies in the first half of the trip and the
    last-stop ones afterwards.

    Returns:
        The node id, or None when the stop has no candidate at all.
    """
    if gtfs_stop_id in gtfs_stop_id_to_node_id:
        return gtfs_stop_id_to_node_id[gtfs_stop_id]

    start_candidates = start_gtfs_stop_id_to_node_ids.get(gtfs_stop_id)
    end_candidates = end_gtfs_stop_id_to_node_ids.get(gtfs_stop_id)

    if start_candidates is None and end_candidates is None:
        return None
    if end_candidates is None:
        return random.choice(start_candidates)
    if start_candidates is None:
        return random.choice(end_candidates)

    # Both kinds of candidates exist: decide by the position within the trip.
    in_first_half = gtfs_stop_sequence < total_stops / 2
    return random.choice(start_candidates if in_first_half else end_candidates)


def get_stop_nodes_by_gtfs_trip_id():
    """Build the list of OSM node ids for every trip that has a selected relation.

    When a trip and its relation agree on all stop names except possibly the
    first and the last one, the relation's node list is used as is. Otherwise
    each stop is resolved individually with `get_node_id_for_trip_stop`.

    Returns:
        `(resolved, unresolved)`: `resolved` maps trip_id to its node ids;
        `unresolved` lists `(trip_id, node ids)` for trips where at least one
        stop could not be resolved (None in its place).
    """
    resolved_trips: dict[str, list[int]] = {}
    unresolved_trips: list[tuple[str, list[int | None]]] = []

    for gtfs_trip_id, longest_relation in longest_relation_by_trip_id.items():
        relation_nodes = stops_by_relation[longest_relation]
        relation_names = comparable_stop_names_by_relation[longest_relation]

        trip_stop_ids = gtfs_stop_ids_by_trip_id[gtfs_trip_id]
        trip_stop_names = [
            gtfs_stop_name_to_comparable(stop_name)
            for stop_name in gtfs_stops.loc[trip_stop_ids]["stop_name"]
        ]

        if trip_stop_names[1:-1] == relation_names[1:-1]:
            # The trip follows the relation: take the relation's nodes directly.
            resolved_trips[gtfs_trip_id] = [node.id for node in relation_nodes]
            continue

        stop_count = len(trip_stop_ids)
        node_ids = [
            get_node_id_for_trip_stop(stop_id, position, stop_count)
            for position, stop_id in enumerate(trip_stop_ids)
        ]

        if any(node_id is None for node_id in node_ids):
            unresolved_trips.append((gtfs_trip_id, node_ids))
        else:
            resolved_trips[gtfs_trip_id] = node_ids

    return resolved_trips, unresolved_trips
    

# Resolve every trip to OSM node ids; trips with an unresolved stop are
# collected separately.
(
    stop_nodes_by_gtfs_trip_id,
    gtfs_trips_with_missing_node_ids
) = get_stop_nodes_by_gtfs_trip_id()

In [None]:
# Trips for which every stop resolved to an OSM node id.
stop_nodes_by_gtfs_trip_id

In [None]:
# Trips with at least one unresolved stop (None marks the unresolved position).
gtfs_trips_with_missing_node_ids

In [None]:
# Distinct node sequences over all resolved trips: trips that visit the same
# nodes in the same order collapse into a single tuple of OSM nodes.
trips = set(
    tuple(osm_node_by_id[node_id] for node_id in trip_node_ids)
    for trip_node_ids in stop_nodes_by_gtfs_trip_id.values()
)

# Iterator consumed by the map cell below, one node sequence per run.
unusual_trips_iter = iter(trips)

In [None]:
import folium

# Base map centred on the hard-coded coordinates (50.05, 19.95).
m = folium.Map(location=(50.05, 19.95), zoom_start=13)

# Draw the next node sequence from the iterator as a polyline through its
# nodes; `next` raises StopIteration once the iterator is exhausted.
tram_route = folium.PolyLine([(item.lat, item.lon) for item in next(unusual_trips_iter)])
tram_route.add_to(m)

# Render the map as the cell output.
m