In [None]:
import sys
import os
import inspect

import re
import json
import statistics
import argparse
import itertools
from pathlib import Path
from dataclasses import dataclass

# EXPERIMENT_DIR (defined further down) is a relative path, so it is resolved against the
# kernel's working directory. Print the cwd so a wrong launch directory is easy to spot.
print(f"{os.getcwd()=}")

from ldj import ldj
from utils import *

import numpy as np
import matplotlib.pyplot as plt
# from matplotlib.font_manager import FontProperties
import matplotlib.font_manager as fm
from matplotlib.patches import FancyBboxPatch
from matplotlib.patches import PathPatch
from matplotlib.path import get_path_collection_extents
import seaborn as sns

from rich import print, pretty
from tabulate import tabulate
from typing import  Iterable
import pretty_errors
from catppuccin import PALETTE
from IPython.display import display, HTML

pretty.install()

# Relative path: resolved against the kernel's working directory.
EXPERIMENT_DIR = Path("../experiments/collaborative-complex")
# Fail with an actionable message instead of a bare `assert` (which `python -O` strips);
# `is_dir()` already implies `exists()`.
if not EXPERIMENT_DIR.is_dir():
    raise FileNotFoundError(
        f"experiment directory not found: {EXPERIMENT_DIR.resolve()} (cwd: {os.getcwd()})"
    )

# Catppuccin "latte" colour palette.
flavor = PALETTE.latte.colors

# NOTE(review): not read by any visible cell — candidate for removal.
data = dict()

@dataclass
class Results:
    """Parsed JSON exports of the two runs of the experiment."""

    # contents of tracking-true.json
    with_tracking: dict
    # contents of tracking-false.json
    without_tracking: dict

# Load both exported runs: tracking-true.json -> with_tracking, tracking-false.json -> without_tracking.
results = Results(
    with_tracking=json.loads((EXPERIMENT_DIR / "tracking-true.json").read_text()),
    without_tracking=json.loads((EXPERIMENT_DIR / "tracking-false.json").read_text()),
)

# Drop robots that never recorded a position; downstream cells build arrays from
# `positions` and assume they are non-empty.
for dataset in (results.with_tracking, results.without_tracking):
    robots = dataset['robots']
    # Collect first, delete afterwards: a dict must not change size while it is iterated.
    empty_robot_ids = [robot_id for robot_id, robot in robots.items() if len(robot['positions']) == 0]
    for robot_id in empty_robot_ids:
        del robots[robot_id]


In [None]:
@dataclass(frozen=True)
class Statistics:
    """Summary statistics of a sample, as produced by `compute_stats`."""

    mean: float
    median: float
    stdev: float  # population standard deviation (np.std with ddof=0)
    min: float
    max: float

    def display(self) -> None:
        """Render the statistics as a centred two-column HTML table in the notebook output."""
        rows = [
            ["Mean", self.mean],
            ["Median", self.median],
            ["Standard Deviation", self.stdev],
            ["Min", self.min],
            ["Max", self.max],
        ]
        html_table = tabulate(rows, headers=["Statistic", "Value"], tablefmt="html")
        centered_html_table = f"""
        <div style="display: flex; justify-content: center;">
            {html_table}
        </div>
        """
        # `display` here is the module-level IPython function, not this method.
        display(HTML(centered_html_table))


def compute_stats(data: list[float]) -> Statistics:
    """Compute mean, median, population stdev, min and max of a sample.

    Args:
        data: a list or any array-like of numbers (e.g. an ``np.ndarray``).

    Returns:
        A `Statistics` whose fields are plain Python floats.

    Raises:
        ValueError: if ``data`` is empty.
    """
    # Convert once instead of letting every numpy reduction re-convert the input.
    values = np.asarray(data, dtype=float)
    if values.size == 0:
        raise ValueError("compute_stats() requires at least one value")
    return Statistics(
        mean=float(np.mean(values)),
        median=float(np.median(values)),
        stdev=float(np.std(values)),
        min=float(np.min(values)),
        max=float(np.max(values)),
    )


In [None]:
@dataclass(frozen=True)
class PerpendicularPositionErrorResult:
    """Per-robot deviation of the recorded trajectory from the planned waypoint path.

    Both lists are aligned with the iteration order of ``exported_data['robots']``.
    """

    # Sum over all recorded positions of the distance to the projected point on the path.
    errors: list[float]
    # Root-mean-square of those per-position distances.
    rmses: list[float]


def perpendicular_position_error(exported_data: dict) -> PerpendicularPositionErrorResult:
    """Measure how far each robot strayed from the polyline through its mission waypoints.

    Each recorded position is projected with ``closest_projection_onto_line_segments``
    (presumably onto the closest point of the waypoint polyline) and the Euclidean
    distance to that projection is taken.

    Args:
        exported_data: parsed experiment export; reads ``['robots'][*]['positions']`` and
            ``['robots'][*]['mission']['routes'][*]['waypoints']``.

    Returns:
        Per robot, the sum of distances (``errors``) and the RMS distance (``rmses``).
    """
    errors: list[float] = []
    rmses: list[float] = []

    for robot_data in exported_data['robots'].values():
        positions = np.asarray(robot_data['positions'])

        # Concatenate the waypoints of all routes, in order, into one polyline.
        waypoints = np.squeeze(np.array([
            wp
            for route in robot_data['mission']['routes']
            for wp in route['waypoints']
        ]))

        # presumably consecutive waypoint pairs -> segments (sliding_window comes from utils)
        lines: list[LinePoints] = [LinePoints(start=start, end=end) for start, end in sliding_window(waypoints, 2)]
        projections = np.asarray([closest_projection_onto_line_segments(p, lines) for p in positions])

        # One distance per recorded position.
        distances = np.linalg.norm(positions - projections, axis=1)

        errors.append(float(np.sum(distances)))
        # RMSE = sqrt(mean(d^2)); sqrt(sum(d) / n) would be the square root of the mean
        # absolute distance, which is not a root-mean-square error.
        rmses.append(float(np.sqrt(np.mean(distances ** 2))))

    return PerpendicularPositionErrorResult(errors=errors, rmses=rmses)


In [None]:
@dataclass(frozen=True)
class CollisionsResult:
    """Number of recorded collision events, split by category."""

    interrobot: int  # number of entries in exported_data['collisions']['robots']
    environment: int  # number of entries in exported_data['collisions']['environment']


def collisions(exported_data: dict) -> CollisionsResult:
    """Count the collision events stored in an exported experiment."""
    recorded = exported_data['collisions']
    return CollisionsResult(
        interrobot=len(recorded['robots']),
        environment=len(recorded['environment']),
    )


In [None]:
@dataclass(frozen=True)
class TotalDistanceTravelledResult:
    """Per-robot travelled path length and the length of its planned waypoint path.

    Both lists are aligned with the iteration order of ``exported_data['robots']``.
    """

    # Length of the polyline through the recorded positions.
    distance: list[float]
    # Length of the polyline through the mission waypoints (first/last clamped to the limits).
    optimal_distance: list[float]

    @staticmethod
    def new() -> 'TotalDistanceTravelledResult':
        """Return an empty result whose lists are appended to per robot."""
        return TotalDistanceTravelledResult(distance=[], optimal_distance=[])


def _as_points(values) -> np.ndarray:
    """Convert nested coordinate lists to a 2-D float array (singleton axes squeezed, one point kept 2-D)."""
    return np.atleast_2d(np.squeeze(np.asarray(values, dtype=float)))


def _path_length(points: np.ndarray) -> float:
    """Sum of Euclidean distances between consecutive points; 0.0 for fewer than two points."""
    if len(points) < 2:
        return 0.0
    return float(np.sum(np.linalg.norm(np.diff(points, axis=0), axis=1)))


def total_distance_travelled(
    exported_data: dict,
    xlimit: float = 95.0,
    ylimit: float = 60.0,
) -> TotalDistanceTravelledResult:
    """Compare each robot's travelled distance with the length of its planned route.

    Args:
        exported_data: parsed experiment export; reads ``['robots'][*]['positions']`` and
            ``['robots'][*]['mission']['routes'][*]['waypoints']``. It is not modified.
        xlimit: the x coordinate of the first and last waypoint is clamped to
            ``[-xlimit, xlimit]`` before measuring the planned route.
        ylimit: same for the y coordinate, clamped to ``[-ylimit, ylimit]``.

    Returns:
        ``distance[i]`` is the length of robot i's recorded trajectory and
        ``optimal_distance[i]`` the length of the polyline through its (clamped) waypoints.
    """
    result = TotalDistanceTravelledResult.new()
    lower = np.array([-xlimit, -ylimit], dtype=float)
    upper = np.array([xlimit, ylimit], dtype=float)

    for robot_data in exported_data['robots'].values():
        positions = np.asarray(robot_data['positions'], dtype=float)

        # Concatenate the waypoints of all routes, in order, into one polyline.
        waypoints = _as_points([
            wp
            for route in robot_data['mission']['routes']
            for wp in route['waypoints']
        ])

        # NOTE(review): only the endpoints are clamped; presumably they may lie outside the
        # environment bounds — confirm the intended limits.
        if waypoints.size > 0:
            for ix in (0, -1):
                waypoints[ix, :2] = np.clip(waypoints[ix, :2], lower, upper)

        result.optimal_distance.append(_path_length(waypoints))
        result.distance.append(_path_length(positions))

    return result

# With Tracking (Path Tracking)

## Makespan

In [None]:
# Makespan of the run with tracking enabled.
makespan_with_tracking = results.with_tracking['makespan']
print(f"makespan = {makespan_with_tracking:.2f} seconds")


## Perpendicular Position Error

In [None]:
# Summary of the per-robot perpendicular-error RMSE for the run with tracking.
pperror = perpendicular_position_error(exported_data=results.with_tracking)
compute_stats(data=pperror.rmses).display()


## Collisions

In [None]:
# Collision counts for the run with tracking (shown via the cell's rich repr).
collisions(exported_data=results.with_tracking)


# Without Tracking (Waypoint Tracking)

## Makespan

In [None]:
# Makespan of the run with tracking disabled.
makespan_without_tracking = results.without_tracking['makespan']
print(f"makespan = {makespan_without_tracking:.2f} seconds")


## Perpendicular Position Error

In [None]:
# Summary of the per-robot perpendicular-error RMSE for the run without tracking.
pperror = perpendicular_position_error(exported_data=results.without_tracking)
compute_stats(data=pperror.rmses).display()


## Collisions

In [None]:
# Collision counts for the run without tracking (shown via the cell's rich repr).
collisions(exported_data=results.without_tracking)


In [None]:
waypoint = total_distance_travelled(results.without_tracking)
path = total_distance_travelled(results.with_tracking)

# Ratio of travelled to planned distance per robot; 1.0 means the robot travelled exactly
# the planned length. Report both the mean and the (population) standard deviation —
# previously np.std was printed under the name `mean`.
for tracking_label, travelled in (("waypoint", waypoint), ("path", path)):
    ratios = np.asarray(travelled.distance) / np.asarray(travelled.optimal_distance)
    print(f"{tracking_label} mean={np.mean(ratios):.4f} std={np.std(ratios):.4f}")

In [None]:
# Sanity check: list robots without any recorded position. The loading cell removes them,
# so this normally prints nothing. The loop variable is `robot`, so the module-level `data`
# is left alone.
for robot_id, robot in results.without_tracking['robots'].items():
    if len(robot['positions']) == 0:
        print(f"{robot_id=}")

---

# Velocity

Per-robot mission duration (`finished_at - started_at`), summarised for each tracking mode.

In [None]:
# Mission duration (finished_at - started_at) of every robot, summarised per tracking mode.
# NOTE(review): a jq-style hint ('.robots | values | .velocity') suggests velocity was the
# goal; velocity itself is not computed here.
for tracking_results, tracking_name in [
    (results.without_tracking, "Waypoint Tracking"),
    (results.with_tracking, "Path Tracking"),
]:
    # Loop names deliberately avoid `iter`, which would shadow the builtin notebook-wide.
    durations = np.array(
        [
            robot_data['mission']['finished_at'] - robot_data['mission']['started_at']
            for robot_data in tracking_results['robots'].values()
        ],
        dtype=float,
    )

    display(HTML(f"<h1 align='center'>{tracking_name}</h1>"))
    compute_stats(durations).display()


In [None]:
@dataclass(frozen=True)
class Statistics:
    """Summary statistics of a sample, as produced by `compute_stats`.

    NOTE(review): this cell re-defines the class and function from an earlier cell;
    keep the two in sync (or delete one) so re-running either gives the same behaviour.
    """

    mean: float
    median: float
    stdev: float  # population standard deviation (np.std with ddof=0)
    min: float
    max: float

    def display(self) -> None:
        """Render the statistics as a centred two-column HTML table in the notebook output."""
        rows = [
            ["Mean", self.mean],
            ["Median", self.median],
            ["Standard Deviation", self.stdev],
            ["Min", self.min],
            ["Max", self.max],
        ]
        html_table = tabulate(rows, headers=["Statistic", "Value"], tablefmt="html")
        centered_html_table = f"""
        <div style="display: flex; justify-content: center;">
            {html_table}
        </div>
        """
        # `display` here is the module-level IPython function, not this method.
        display(HTML(centered_html_table))


def compute_stats(data: list[float]) -> Statistics:
    """Compute mean, median, population stdev, min and max of a sample.

    Args:
        data: a list or any array-like of numbers (e.g. an ``np.ndarray``).

    Returns:
        A `Statistics` whose fields are plain Python floats.

    Raises:
        ValueError: if ``data`` is empty.
    """
    # Convert once instead of letting every numpy reduction re-convert the input.
    values = np.asarray(data, dtype=float)
    if values.size == 0:
        raise ValueError("compute_stats() requires at least one value")
    return Statistics(
        mean=float(np.mean(values)),
        median=float(np.median(values)),
        stdev=float(np.std(values)),
        min=float(np.min(values)),
        max=float(np.max(values)),
    )


In [None]:
def diffstat(vec0: np.ndarray, vec1: np.ndarray):
    """Placeholder for comparing the statistics of two vectors.

    TODO: not implemented. A body is required: a `def` with an empty body is a
    SyntaxError and stops "Run All" at this cell.
    """
    raise NotImplementedError("diffstat() is not implemented yet")

In [None]:
from dataclasses import dataclass
from typing import List
import numpy as np
from IPython.display import display, HTML
from tabulate import tabulate

@dataclass(frozen=True)
class Statistics:
    """Summary statistics of a sample, as produced by `compute_stats`.

    This cell re-defines the class from an earlier cell. It keeps the `display` method so
    `compute_stats(...).display()` calls elsewhere in the notebook still work after this
    cell has run.
    """

    mean: float
    median: float
    stdev: float  # population standard deviation (np.std with ddof=0)
    min: float
    max: float

    def display(self) -> None:
        """Render the statistics as a centred two-column HTML table in the notebook output."""
        rows = [
            ["Mean", self.mean],
            ["Median", self.median],
            ["Standard Deviation", self.stdev],
            ["Min", self.min],
            ["Max", self.max],
        ]
        html_table = tabulate(rows, headers=["Statistic", "Value"], tablefmt="html")
        centered_html_table = f"""
        <div style="display: flex; justify-content: center;">
            {html_table}
        </div>
        """
        # `display` here is the module-level IPython function, not this method.
        display(HTML(centered_html_table))


def compute_stats(data: List[float]) -> Statistics:
    """Compute mean, median, population stdev, min and max of a sample.

    Args:
        data: a list or any array-like of numbers (e.g. an ``np.ndarray``).

    Returns:
        A `Statistics` whose fields are plain Python floats.

    Raises:
        ValueError: if ``data`` is empty.
    """
    # Convert once instead of letting every numpy reduction re-convert the input.
    values = np.asarray(data, dtype=float)
    if values.size == 0:
        raise ValueError("compute_stats() requires at least one value")
    return Statistics(
        mean=float(np.mean(values)),
        median=float(np.median(values)),
        stdev=float(np.std(values)),
        min=float(np.min(values)),
        max=float(np.max(values)),
    )

def compare_stats(data1: List[float], data2: List[float], label1: str = "Vector 1", label2: str = "Vector 2") -> None:
    """Show the summary statistics of two samples side by side as a centred HTML table.

    The mean of ``data1`` is highlighted in bold green.

    Args:
        data1: first sample, passed to ``compute_stats``.
        data2: second sample, passed to ``compute_stats``.
        label1: column header for ``data1`` (HTML-escaped before rendering).
        label2: column header for ``data2`` (HTML-escaped before rendering).
    """
    from html import escape

    stats1 = compute_stats(data1)
    stats2 = compute_stats(data2)

    def fmt(value: float) -> str:
        return f"{value:.4f}"

    rows = [
        ["Statistic", escape(label1), escape(label2)],
        # Highlight is put into the cell itself; a string replace on tabulate's output
        # cannot match reliably (it pads cells and adds alignment attributes).
        ["Mean", f"<b style='color: green;'>{fmt(stats1.mean)}</b>", fmt(stats2.mean)],
        ["Median", fmt(stats1.median), fmt(stats2.median)],
        ["Standard Deviation", fmt(stats1.stdev), fmt(stats2.stdev)],
        ["Min", fmt(stats1.min), fmt(stats2.min)],
        ["Max", fmt(stats1.max), fmt(stats2.max)],
    ]

    # `unsafehtml` keeps the highlight markup (plain `html` escapes cell contents), and
    # `disable_numparse` stops tabulate from re-formatting "2.0000" back into "2".
    html_table = tabulate(rows, headers="firstrow", tablefmt="unsafehtml", disable_numparse=True)
    centered_html_table = f"""
    <div style="display: flex; justify-content: center;">
        {html_table}
    </div>
    """
    display(HTML(centered_html_table))

In [None]:
# Smoke test of the two-sample compare_stats with toy data.
compare_stats(data1=[1, 2, 3], data2=[4, 5, 6])

In [None]:
def compare_stats(vectors: list[np.ndarray], titles: list[str], higher_is_better: bool = True) -> None:
    """Show the summary statistics of several samples side by side as a centred HTML table.

    The best mean (largest if ``higher_is_better``, otherwise smallest) is highlighted
    in bold green.

    Args:
        vectors: samples, one table column each; each is passed to ``compute_stats``.
        titles: column header for each sample (HTML-escaped before rendering).
        higher_is_better: whether a larger mean counts as better.

    Raises:
        ValueError: if ``vectors`` is empty or ``vectors`` and ``titles`` differ in length.
    """
    from html import escape

    if len(vectors) != len(titles):
        raise ValueError(f"got {len(vectors)} vectors but {len(titles)} titles")
    if len(vectors) == 0:
        raise ValueError("compare_stats() needs at least one vector")

    stats = [compute_stats(v) for v in vectors]

    means = [s.mean for s in stats]
    best = int(np.argmax(means) if higher_is_better else np.argmin(means))

    rows = [["Statistic", *(escape(title) for title in titles)]]
    fields = [
        ("Mean", "mean"),
        ("Median", "median"),
        ("Standard Deviation", "stdev"),
        ("Min", "min"),
        ("Max", "max"),
    ]
    for label, field in fields:
        cells = [f"{getattr(s, field):.4f}" for s in stats]
        if field == "mean":
            cells[best] = f"<b style='color: green;'>{cells[best]}</b>"
        rows.append([label, *cells])

    # `unsafehtml` keeps the highlight markup (plain `html` escapes cell contents), and
    # `disable_numparse` stops tabulate from re-formatting "2.0000" back into "2".
    html_table = tabulate(rows, headers="firstrow", tablefmt="unsafehtml", disable_numparse=True)
    centered_html_table = f"""
    <div style="display: flex; justify-content: center;">
        {html_table}
    </div>
    """
    display(HTML(centered_html_table))