In [None]:
import datetime
%load_ext autoreload
%autoreload 2

In [None]:
from benchmarks.generator import get_revenue_behavior, get_revenue_behavior_deprecated
from benchmarks.utils import sns_box_plot, sns_line_plot, int_input, get_schedule_from_supply, infer_line_stations, \
    get_services_by_tsp_df, plot_marey_chart, plot_scheduled_services

from robin.services_generator.entities import ServiceGenerator
from robin.supply.entities import Service, Supply

from pathlib import Path
import shutil

In [None]:
# Input configuration files
supply_config_path = Path("../configs/generator/supply_config.yml")
generator_config_path = Path("../configs/generator/generator_config.yml")

# Output locations
generator_save_path = Path('../data/generator/supply_dummy.yml')
supply_save_path = '../configs/mealpy/'
figures = '../figures/'


def _recreate_directory(directory: Path) -> None:
    """Remove ``directory`` recursively if it exists, then create it (with parents) empty."""
    if directory.exists():
        shutil.rmtree(directory)
    directory.mkdir(parents=True)


# Start every run from empty output folders.
# WARNING: both folders are deleted recursively, including files this notebook did not create.
_recreate_directory(generator_save_path.parent)
_recreate_directory(Path(supply_save_path))

In [None]:
seed = 21

# ``generator_config_path`` is a ``Path`` object and therefore always truthy;
# this guard only skips generation if the variable is set to None by hand.
if generator_config_path:
    generator = ServiceGenerator(supply_config_path=supply_config_path)
    # Kept in a named variable rather than ``_``: assigning ``_`` shadows
    # IPython's last-output variable and stops it from being updated.
    generated_services = generator.generate(file_name=generator_save_path,
                                            path_config=generator_config_path,
                                            n_services=16,  # number of services requested from the generator
                                            seed=seed)
    print(f'Number of service requests generated: {len(generated_services)}')

In [None]:
# Load the generated supply and summarise how many services each TSP operates.
supply = Supply.from_yaml(generator_save_path)
tsp_df = get_services_by_tsp_df(supply.services)
# Rich notebook rendering of the summary table instead of a plain-text print.
display(tsp_df)
print("Services: ", len(supply.services))

In [None]:
# Derive the inputs used by the rest of the notebook from the generated supply.
# NOTE(review): `requested_schedule` and `line` are not read by the cells shown here,
# and `get_revenue_behavior` is imported while the deprecated variant is used — confirm intent.
requested_schedule = get_schedule_from_supply(generator_save_path)
# Per-service revenue parameters keyed by service id; later cells read the
# 'ru', 'canon', 'importance' and 'k' entries of each value.
revenue_behavior = get_revenue_behavior_deprecated(supply)
lines = supply.lines
line = infer_line_stations(lines)
# Total number of services, summed over the per-TSP counts of the summary table.
n_services = sum(tsp_df["Number of Services"].values)

In [None]:
from collections import Counter

# Count services per RU (taken from each service's 'ru' entry), keeping the
# order in which RUs are first seen, then label each RU as "RU<id>".
ru_counts = Counter(behavior['ru'] for behavior in revenue_behavior.values())
services_by_ru = {f"RU{k}": v for k, v in ru_counts.items()}
print(services_by_ru)

In [None]:
# Percentage (rounded to 2 decimals) of all services operated by each RU.
frame_capacity = {
    ru: round(count / n_services * 100, 2)
    for ru, count in services_by_ru.items()
}
print(frame_capacity)

In [None]:
marey_chart_path = Path('../reports/mealpy/marey_chart.pdf')
# No earlier cell creates the reports folder, so make sure it exists before saving.
marey_chart_path.parent.mkdir(parents=True, exist_ok=True)

plot_marey_chart(requested_supply=supply,
                 colors_by_tsp=True,
                 main_title="Marey chart",
                 plot_security_gaps=True,
                 security_gap=10,
                 save_path=marey_chart_path)

In [None]:
# Show every generated service, one per line, using its string representation.
for service in supply.services:
    print(str(service))

In [None]:
from robin.supply.entities import Station

from typing import List, Mapping

def get_stations_positions(stations: List[Station]) -> Mapping[Station, float]:
    """
    Compute the cumulative distance of each station along the line.

    The first station is placed at 0.0 and each following station at the
    running sum of geopy ``geodesic`` distances (in km) between consecutive
    stations, in the order given.

    Args:
        stations: Ordered, hashable Station objects exposing ``coords`` as accepted
            by ``geopy.distance.geodesic`` (presumably (lat, lon) — TODO confirm).
            May be empty or None.

    Returns:
        A dictionary mapping each station to its distance in km from the first
        station; empty when ``stations`` is empty or None.
    """
    positions = {}
    if not stations:
        return positions

    # Imported locally so the function does not depend on the later cell that
    # imports ``geodesic`` at module level having been executed first.
    from geopy.distance import geodesic

    # First station is at position zero
    positions[stations[0]] = 0.0
    total_distance = 0.0

    # Accumulate the length of each consecutive station-to-station stretch
    for prev, curr in zip(stations, stations[1:]):
        total_distance += geodesic(prev.coords, curr.coords).km
        positions[curr] = total_distance

    return positions

In [None]:
from functools import cache
from datetime import timedelta
from typing import Callable, Tuple

@cache
def get_time_from_position(
    point_a: Tuple[timedelta, float],
    point_b: Tuple[timedelta, float]
) -> Callable[[float], timedelta]:
    """
    Invert the straight line through two (time, position) samples.

    The returned callable maps a position to the time at which a trajectory
    moving linearly from ``point_a`` to ``point_b`` is at that position.
    Calls are memoised with ``functools.cache``, so equal arguments return the
    very same function object.

    Args:
        point_a: (time, position) of the first sample.
        point_b: (time, position) of the second sample.

    Returns:
        A function ``f(position) -> timedelta``.

    Raises:
        ValueError: If both samples share the same time or the same position.
    """
    # Work in minutes so the slope is expressed as position units per minute
    t0 = point_a[0].total_seconds() / 60
    t1 = point_b[0].total_seconds() / 60
    start_pos, end_pos = point_a[1], point_b[1]

    # A vertical or horizontal line cannot be inverted
    if t0 == t1:
        raise ValueError("point_a and point_b must have different times")
    if start_pos == end_pos:
        raise ValueError("point_a and point_b must have different positions")

    slope = (end_pos - start_pos) / (t1 - t0)

    def time_from_position(position: float) -> timedelta:
        """Time at which the linear trajectory reaches ``position``."""
        return timedelta(minutes=(position - start_pos) / slope + t0)

    return time_from_position

In [None]:
def get_line_stations(services: List[Service]) -> List[Station]:
    """
    Find a corridor path that contains every station served by ``services``.

    Args:
        services: List of Service objects. The candidate paths are taken from
            the corridor of the first service (assumes all services share that
            corridor — TODO confirm).

    Returns:
        The first path of ``services[0].line.corridor.paths`` that contains
        every station visited by any of the services. Returns None when
        ``services`` is empty or when no path covers all those stations.
    """
    if not services:
        # Nothing to look up; avoids IndexError on services[0]
        return None

    # Every station visited by at least one service
    served_stations = {station for service in services for station in service.line.stations}

    # First path covering all served stations, or None
    return next(
        (path for path in services[0].line.corridor.paths
         if all(station in path for station in served_stations)),
        None,
    )

In [None]:
import numpy as np
from typing import List, Mapping, NamedTuple
from datetime import timedelta

from geopy.distance import geodesic
from robin.supply.entities import Service, Station

class Segment(NamedTuple):
    """
    One leg of a service between two consecutive stations of its line.

    ``time_at`` maps a position on the leg to the time the service is there.
    """
    # Index of the owning service in the list given to get_conflict_matrix
    service_idx: int
    # Position of the leg's departure station (distance along the line)
    start_pos: float
    # Position of the leg's arrival station
    end_pos: float
    # Interpolator: position -> timedelta for this leg
    time_at: "Callable[[float], timedelta]"

def _build_segments_for_service(
    idx: int,
    service: Service,
    positions: Mapping[Station, float]
) -> List[Segment]:
    """
    Split a service into one Segment per pair of consecutive stations.

    Each Segment spans from the position of one station to the position of the
    next and carries a linear interpolator between the departure time from the
    first station and the arrival time at the next one.
    """
    route = service.line.stations
    timetable = service.schedule
    legs: List[Segment] = []

    for leg in range(len(route) - 1):
        origin_pos = positions[route[leg]]
        dest_pos = positions[route[leg + 1]]

        # Schedule entries are (arrival, departure) pairs
        departure = timetable[leg][1]
        arrival = timetable[leg + 1][0]

        interpolator = get_time_from_position(
            (departure, origin_pos),
            (arrival, dest_pos)
        )
        legs.append(Segment(idx, origin_pos, dest_pos, interpolator))

    return legs


def _segments_conflict(
    seg1: Segment,
    seg2: Segment,
    safety_headway: int
) -> bool:
    """
    Determine if two motion segments conflict within a given safety headway.

    The segments are compared on the stretch of line they share. Both services'
    times are interpolated at the two ends of that stretch. The pair is
    conflict-free only if the same service is ahead at both ends (no overtaking
    inside the stretch) and the gap at each end is at least
    ``2 * safety_headway`` minutes.

    Segments may run towards decreasing positions (``start_pos > end_pos``).
    Their intervals are normalised before the overlap is computed, so such
    segments are checked like any other. Segments running in opposite
    directions are reported as non-conflicting.

    Args:
        seg1: First segment.
        seg2: Second segment.
        safety_headway: Headway in minutes; the required gap is twice this value.

    Returns:
        True if the segments conflict, False otherwise.
    """
    # NOTE(review): opposite-direction segments never conflict here. This matches
    # what the previous raw (un-normalised) overlap test produced. It is only right
    # if opposing trains use separate tracks — confirm the line topology.
    if (seg1.end_pos > seg1.start_pos) != (seg2.end_pos > seg2.start_pos):
        return False

    # Spatial overlap, independent of the direction of travel
    overlap_start = max(min(seg1.start_pos, seg1.end_pos), min(seg2.start_pos, seg2.end_pos))
    overlap_end = min(max(seg1.start_pos, seg1.end_pos), max(seg2.start_pos, seg2.end_pos))
    if overlap_start >= overlap_end:
        return False

    # Signed gaps (seg2 minus seg1) at both ends of the overlap, in exact minutes.
    # Flooring to whole minutes made the check depend on the argument order:
    # a gap of -19.5 floored to -20 and passed a 20-minute threshold that +19.5 failed.
    dt_start = (seg2.time_at(overlap_start) - seg1.time_at(overlap_start)).total_seconds() / 60
    dt_end = (seg2.time_at(overlap_end) - seg1.time_at(overlap_end)).total_seconds() / 60

    # No conflict if the ordering is preserved and both gaps meet the threshold
    required_gap = 2 * safety_headway
    same_order = dt_start * dt_end > 0
    return not (same_order and abs(dt_start) >= required_gap and abs(dt_end) >= required_gap)


def get_conflict_matrix(
    services: List[Service],
    safety_headway: int = 10
) -> np.ndarray:
    """
    Compute a symmetric conflict matrix for a fleet of services.

    Args:
        services: Services to compare; row/column k of the result refers to services[k].
        safety_headway: Headway in minutes forwarded to the pairwise segment check.

    Returns:
        An (n, n) boolean array with a False diagonal. Entry [i, j] is True if any
        segment of service i conflicts with any segment of service j. An empty
        service list yields a (0, 0) array.
    """
    n = len(services)
    conflicts = np.zeros((n, n), dtype=bool)
    if n == 0:
        # Nothing to compare; also avoids the line lookup needing services[0]
        return conflicts

    # Spatial positions of all stations on the line
    stations = get_line_stations(services)
    positions = get_stations_positions(stations)

    # Precompute segments per service
    all_segments = [
        _build_segments_for_service(i, svc, positions)
        for i, svc in enumerate(services)
    ]

    # Check each unordered pair once; any() stops at the first conflicting segment pair
    for i in range(n):
        for j in range(i + 1, n):
            conflict_found = any(
                _segments_conflict(seg1, seg2, safety_headway)
                for seg1 in all_segments[i]
                for seg2 in all_segments[j]
            )
            conflicts[i, j] = conflicts[j, i] = conflict_found

    return conflicts

In [None]:
# Pairwise conflicts between the generated services; row/column k is supply.services[k].
conflict_matrix = get_conflict_matrix(supply.services, safety_headway=10)
conflict_matrix

In [None]:
# Number of services involved in at least one conflict
conflict_matrix.any(axis=1).sum()

In [None]:
# Re-list the services in the same order as the conflict-matrix rows.
for service in supply.services:
    print(str(service))

In [None]:
# Number of conflicting partners of each service (row sums of the matrix)
for i in range(len(conflict_matrix)):
    row = conflict_matrix[i]
    print(f"Service {i}: {row.sum()}")

In [None]:
# Plain-text dump of the full boolean conflict matrix
print(conflict_matrix)

In [None]:
# Per-service revenue parameters, keyed by service id (displayed as cell output)
revenue_behavior

In [None]:
# Collect every conflicting pair (i < j) once, indexed two ways:
#   conflict_pairs: label -> (service id i, service id j)
#   conflicts:      (service id i, service id j) -> label and combined 'canon' weight
conflict_pairs = {}
conflicts = {}
for i, row in enumerate(conflict_matrix):
    for j in range(i + 1, len(row)):
        if not row[j]:
            continue
        label = f"Conflicto_{i}_{j}"
        first, second = supply.services[i], supply.services[j]
        conflict_pairs[label] = (first.id, second.id)
        conflicts[(first.id, second.id)] = {
            'name': label,
            'weight': revenue_behavior[first.id]['canon'] + revenue_behavior[second.id]['canon']
        }

print(conflict_pairs)
print()
print(conflicts)

In [None]:
# Number of conflicting service pairs
len(conflict_pairs)

In [None]:
services_dict = dict()

def timedelta_to_minutes(td: timedelta) -> int:
    """
    Return ``td`` as a whole number of minutes, rounding towards negative infinity.
    """
    whole_minutes, _remainder = divmod(td.total_seconds(), 60)
    return int(whole_minutes)

def _conflicts_involving(service_id):
    """Labels (in ``conflict_pairs`` order) of the conflict pairs that include ``service_id``."""
    return [label for label, pair in conflict_pairs.items() if service_id in pair]


# Per-service summary: stop times in minutes, conflicts and revenue parameters
for service in supply.services:
    schedule_minutes = {
        station.id: (timedelta_to_minutes(arrival), timedelta_to_minutes(departure))
        for station, (arrival, departure) in zip(service.line.stations, service.schedule)
    }
    conflict_labels = _conflicts_involving(service.id)
    behavior = revenue_behavior[service.id]
    services_dict[service.id] = {
        'schedule': schedule_minutes,
        'conflicts': conflict_labels,
        'revenue': behavior['canon'],
        'importance': behavior['importance'],
        'penalty_sensitivity': behavior['k'],
    }

print(services_dict)

In [None]:
# Per-service revenue parameters, keyed by service id (displayed as cell output)
revenue_behavior