In [10]:
import json
import math
import sqlite3
import warnings
from collections import defaultdict
from pathlib import Path

import plotly.express as px
import plotly.graph_objects as go
import polars as pl

def _find_repo_root(start: Path) -> Path:
    markers = ["run_inference.py", "deterministic_vrp_solver", ".git"]
    current = start
    for _ in range(5):
        if any((current / m).exists() for m in markers):
            return current
        if current.parent == current:
            break
        current = current.parent
    return start

# Resolve every input relative to the detected repository root.
repo_cwd = _find_repo_root(Path.cwd())
solver_dir = repo_cwd

# Solver output and the Ozon logistics dataset it was computed from.
_dataset_dir = repo_cwd / "ml_ozon_logistic"
solution_path = repo_cwd / "solution.json"
orders_path = _dataset_dir / "ml_ozon_logistic_dataSetOrders.json"
couriers_path = _dataset_dir / "ml_ozon_logistic_dataSetCouriers.json"

# Port databases shipped with the solver; the first existing candidate wins,
# otherwise the first entry is kept as the (possibly missing) default.
_solver_data_dir = repo_cwd / "deterministic_vrp_solver" / "data"
candidate_ports_paths = [_solver_data_dir / "ports_database.sqlite"]
ports_db_path = next(filter(Path.exists, candidate_ports_paths), candidate_ports_paths[0])
warehouse_ports_db_path = _solver_data_dir / "warehouse_ports_database.sqlite"

# Plot appearance knobs.
show_ports = True
line_opacity = 0.45
order_marker_size = 5
port_marker_size = 8

In [11]:
with open(solution_path, "r", encoding="utf-8") as f:
    solution = json.load(f)
routes = solution.get("routes", [])

# One row per order: the orders file nests them under an "Orders" list column.
orders_df = (
    pl.read_json(str(orders_path))
    .explode("Orders")
    .unnest("Orders")
    .select(["ID", "MpId", "Lat", "Long"])
)

# order ID -> (lat, lon, MpId); the plotting cell unpacks this tuple order.
order_to_geo = {
    oid: (lat, lon, mp_id) for oid, mp_id, lat, lon in orders_df.iter_rows()
}

warehouse = pl.read_json(str(couriers_path)).select("Warehouse").item(0, "Warehouse")
if not {"Lat", "Long"} <= set(warehouse):
    warnings.warn(
        "Warehouse record lacks Lat/Long; defaulting to 0.0, "
        "which will misplace the warehouse marker."
    )
warehouse_lat = float(warehouse.get("Lat", 0.0))
warehouse_lon = float(warehouse.get("Long", 0.0))


def _load_port_tables(db_path: Path) -> tuple[defaultdict, dict]:
    """Read polygon->port links and port coordinates from the ports SQLite DB.

    Best-effort: a missing file, a missing table or an unconvertible value
    yields an empty result for that table plus a warning, so the map still
    renders (just without ports) instead of the cell failing.

    Args:
        db_path: Path to the SQLite database with ``polygon_ports`` and ``ports``.

    Returns:
        ``(polygon_to_ports, ports_geo)`` where ``polygon_to_ports`` is a
        ``defaultdict(list)`` mapping ``int(polygon_id)`` to a list of
        ``int(port_id)``, and ``ports_geo`` maps ``int(port_id)`` to
        ``(float(lat), float(lon))``.
    """
    by_polygon = defaultdict(list)
    coords = {}
    if not db_path.exists():
        warnings.warn(f"Ports database not found at {db_path}; ports will not be plotted.")
        return by_polygon, coords

    # Open read-only: a plain sqlite3.connect() creates an empty database file
    # when the path does not exist, polluting the repo's data directory.
    conn = sqlite3.connect(f"{db_path.resolve().as_uri()}?mode=ro", uri=True)
    try:
        try:
            for mp_id, port_id in conn.execute("SELECT polygon_id, port_id FROM polygon_ports"):
                by_polygon[int(mp_id)].append(int(port_id))
        except (sqlite3.Error, TypeError, ValueError) as exc:
            warnings.warn(f"Could not read polygon_ports from {db_path}: {exc}")
            by_polygon = defaultdict(list)  # discard any partially read rows

        try:
            coords = {
                int(pid): (float(plat), float(plon))
                for pid, plat, plon in conn.execute("SELECT port_id, lat, lon FROM ports")
            }
        except (sqlite3.Error, TypeError, ValueError) as exc:
            warnings.warn(f"Could not read ports from {db_path}: {exc}")
            coords = {}
    finally:
        conn.close()
    return by_polygon, coords


polygon_to_ports, ports_geo = _load_port_tables(ports_db_path)

In [12]:
# Draw each route from and back to the warehouse. Solution routes start and end
# there, but the warehouse has no entry in order_to_geo, so looking every stop
# up in that map alone silently drops the depot legs and single-order routes.
connect_routes_to_warehouse = True

fig = go.Figure()

courier_colors = px.colors.qualitative.Alphabet * 10
polygon_colors = px.colors.qualitative.Light24 * 10

# Group visited orders by polygon (MpId) so each polygon gets one colour/trace.
mp_to_points = defaultdict(lambda: {"lat": [], "lon": [], "ids": []})
for route in routes:
    for oid in route["route"][1:-1]:  # skip warehouse at ends
        lat, lon, mp = order_to_geo.get(oid, (None, None, None))
        if lat is None:
            continue
        mp_to_points[mp]["lat"].append(lat)
        mp_to_points[mp]["lon"].append(lon)
        mp_to_points[mp]["ids"].append(oid)

for i, (mp, pts) in enumerate(mp_to_points.items()):
    fig.add_trace(go.Scattergl(
        x=pts["lon"], y=pts["lat"],
        mode="markers",
        marker=dict(size=order_marker_size, color=polygon_colors[i % len(polygon_colors)]),
        text=[f"order {oid}<br>MP {mp}" for oid in pts["ids"]],
        hoverinfo="text",
        showlegend=False,
    ))

for idx, route in enumerate(routes):
    color = courier_colors[idx % len(courier_colors)]
    stops = route["route"]
    last_pos = len(stops) - 1
    coords = []
    n_orders = 0
    for pos, oid in enumerate(stops):
        lat, lon, _ = order_to_geo.get(oid, (None, None, None))
        if lat is not None:
            coords.append((lat, lon))
            n_orders += 1
        elif connect_routes_to_warehouse and pos in (0, last_pos):
            # Unknown endpoint: it is the warehouse, not an order.
            coords.append((warehouse_lat, warehouse_lon))
    # Skip routes with no plottable order (e.g. [warehouse, warehouse]).
    if n_orders and len(coords) >= 2:
        fig.add_trace(go.Scattergl(
            x=[c[1] for c in coords], y=[c[0] for c in coords],
            mode="lines",
            line=dict(color=color, width=1),
            opacity=line_opacity,
            showlegend=False,
        ))

if show_ports and ports_geo:
    # A port can serve several polygons; plot each one once.
    port_ids = sorted({
        pid for ports in polygon_to_ports.values() for pid in ports if pid in ports_geo
    })
    port_lats = [ports_geo[pid][0] for pid in port_ids]
    port_lons = [ports_geo[pid][1] for pid in port_ids]
    if port_lats:
        fig.add_trace(go.Scattergl(
            x=port_lons, y=port_lats,
            mode="markers",
            marker=dict(size=port_marker_size, symbol="square", color="#000000"),
            text=[f"port {pid}" for pid in port_ids],
            hoverinfo="text",
            showlegend=False,
        ))

fig.add_trace(go.Scattergl(
    x=[warehouse_lon], y=[warehouse_lat],
    mode="markers",
    marker=dict(size=14, symbol="star", color="#FF0000", line=dict(width=1, color="#660000")),
    name="Warehouse",
    showlegend=False,
))

# A degree of longitude spans cos(latitude) times the distance of a degree of
# latitude; anchoring y to x with this ratio keeps the map roughly undistorted
# around the warehouse latitude (guarded against cos -> 0 near the poles).
lat_scale = 1.0 / max(math.cos(math.radians(warehouse_lat)), 1e-6)
fig.update_layout(
    title="Courier routes by polygon",
    width=1200, height=800,
    margin=dict(l=10, r=10, t=40, b=10),
    xaxis_title="Longitude",
    yaxis=dict(title="Latitude", scaleanchor="x", scaleratio=lat_scale),
    template="plotly_white",
)
fig.show()