In [None]:
import os
from glob import glob
import json
import pickle
from tqdm import tqdm
from datetime import datetime, timedelta

import avstack
import avapi

from stonesoup.types.groundtruth import GroundTruthPath, GroundTruthState

In [None]:
def object_to_stone_soup_truth(obj, t_start):
    """Convert one avstack object snapshot into a Stone Soup ``GroundTruthState``.

    The state vector is laid out as
    ``[x, vx, y, vy, z, vz, h, w, l, er, ep, ey]``, where the positions come from
    ``obj.position.x``, the velocities from ``obj.velocity.x``, ``(h, w, l)`` from
    ``obj.box.size`` and the three euler components from ``obj.attitude.euler``
    (in the order that attribute returns them).

    Args:
        obj: avstack object state with ``t``, ``position``, ``box``, ``velocity``,
            ``attitude``, ``obj_type``, ``ID`` and ``occlusion`` attributes.
        t_start: absolute ``datetime``; ``obj.t`` is interpreted as seconds after it.

    Returns:
        GroundTruthState stamped at ``t_start + obj.t`` seconds, with metadata keys
        ``object_type``, ``object_ID`` and ``occlusion``.
    """
    timestamp = t_start + timedelta(seconds=obj.t)
    pos_x, pos_y, pos_z = obj.position.x
    height, width, length = obj.box.size
    vel_x, vel_y, vel_z = obj.velocity.x
    euler_r, euler_p, euler_y = obj.attitude.euler

    # position/velocity interleaved per axis, then box size, then attitude
    state_vector = [
        pos_x, vel_x,
        pos_y, vel_y,
        pos_z, vel_z,
        height, width, length,
        euler_r, euler_p, euler_y,
    ]
    object_info = dict(
        object_type=obj.obj_type,
        object_ID=obj.ID,
        occlusion=obj.occlusion,
    )
    return GroundTruthState(state_vector, timestamp=timestamp, metadata=object_info)


def update_truth_dictionary(truth, data_input, t_start, always_visible_ID_threshold=10000):
    """Append one replayed frame of ground-truth objects to ``truth``.

    Every object in ``data_input["objects"]`` gets its state appended to its
    ``GroundTruthPath``. Objects reported by at least one agent's lidar this frame,
    or whose ID exceeds ``always_visible_ID_threshold``, also extend their
    ``[first, last]`` visibility window.

    Args:
        truth: accumulator dict with keys ``"objects"`` (ID -> GroundTruthPath) and
            ``"visible_times"`` (``{"first": {ID: datetime}, "last": {ID: datetime}}``).
            It is mutated in place.
        data_input: one replayed frame; this reads ``data_input["objects"]`` and
            ``data_input["agent_data"][agent]["objects"]["lidar"]``.
        t_start: absolute ``datetime``; ``obj.t`` is treated as seconds after it.
        always_visible_ID_threshold: objects with ``ID`` strictly greater than this
            count as visible even when no lidar reports them. The default preserves
            the previously hard-coded value.
            NOTE(review): presumably such IDs mark a special object class; confirm.

    Returns:
        The same ``truth`` dict, for call chaining.
    """
    # IDs seen by any agent's lidar in this frame
    objs_in_view = {
        obj.ID
        for agent_data in data_input["agent_data"].values()
        for obj in agent_data["objects"]["lidar"]
    }

    first_seen = truth["visible_times"]["first"]
    last_seen = truth["visible_times"]["last"]
    for obj in data_input["objects"]:
        ts = t_start + timedelta(seconds=obj.t)

        # save all object states
        if obj.ID not in truth["objects"]:
            truth["objects"][obj.ID] = GroundTruthPath()
        truth["objects"][obj.ID].append(object_to_stone_soup_truth(obj, t_start))

        # widen the visibility window; min/max keeps it correct even if frames
        # arrive out of order (identical to first/latest assignment when in order)
        if obj.ID in objs_in_view or obj.ID > always_visible_ID_threshold:
            first_seen[obj.ID] = min(first_seen.get(obj.ID, ts), ts)
            last_seen[obj.ID] = max(last_seen.get(obj.ID, ts), ts)

    return truth

### Process Truth Data

In [None]:
from simulation.replayer import DatasetReplayer

# Directory of the logged run. Its metadata.json supplies (at least) the
# "t_start" and "replayer" entries consumed by the truth-processing cell.
log_dir = "last_run"
metadata_path = os.path.join(log_dir, "metadata.json")
with open(metadata_path, "r") as f:
    metadata = json.load(f)

In [None]:
# Process truth data: replay every frame (without perception) and accumulate
# each object's ground-truth path plus the window in which it was visible.
t_start = datetime.fromtimestamp(metadata["t_start"])

truth = {
    "objects_visible": {},
    "visible_times": {"first": {}, "last": {}},
    "objects": {},
}

# NOTE(review): `timestamps` is rebuilt from timestamps.txt by the log-loading cell.
timestamps = []
replayer = DatasetReplayer(**metadata["replayer"])
for data_input in replayer(load_perception=False):
    truth = update_truth_dictionary(truth, data_input, t_start)
    timestamps.append(t_start + data_input["timestamp_dt"])

# Clip every object that was ever visible to its [first, last] visible window
first_seen = truth["visible_times"]["first"]
last_seen = truth["visible_times"]["last"]
for obj_ID, full_path in truth["objects"].items():
    if obj_ID not in first_seen:
        continue
    visible_path = GroundTruthPath()
    for state in full_path:
        if first_seen[obj_ID] <= state.timestamp <= last_seen[obj_ID]:
            visible_path.append(state)
    truth["objects_visible"][obj_ID] = visible_path

In [None]:
from avstack.environment.objects import ObjectStateDecoder
from avstack.geometry.fov import FieldOfViewDecoder
from avstack.modules.perception.detections import DetectionContainerDecoder

from mate.estimator import TrustMessageDecoder


# Frame and timestamp data: each line of timestamps.txt is "<frame>, <unix seconds>"
frames = []
timestamps = []
with open(os.path.join(log_dir, "timestamps.txt")) as f:
    lines = f.readlines()
for line in lines:
    frame, timestamp = line.rstrip().split(", ")
    frames.append(frame)
    timestamps.append(datetime.fromtimestamp(float(timestamp)))
frame_to_ts = dict(zip(frames, timestamps))


def _agent_id_from_subdir(agent_subdir):
    """Return the integer N for an ``.../agent-<N>`` directory, otherwise ``"cc"``.

    Only the directory basename is parsed, so multi-digit IDs (``agent-12``) work
    and a parent path that happens to contain "agent" is not misread.
    """
    name = os.path.basename(os.path.normpath(agent_subdir))
    if name.startswith("agent-"):
        return int(name[len("agent-"):])
    return "cc"


def _load_per_frame_json(directory, decoder_cls):
    """Decode every ``directory/*.txt`` JSON file with ``decoder_cls``, in filename order.

    Asserts exactly one file per frame so that list index == frame index.
    NOTE(review): assumes filenames sort lexically in frame order (e.g. zero-padded); confirm.
    """
    decoded = []
    for file in sorted(glob(os.path.join(directory, "*.txt"))):
        with open(file, "r") as f:
            decoded.append(json.load(f, cls=decoder_cls))
    assert len(decoded) == len(frames), (
        f"{directory}: found {len(decoded)} files for {len(frames)} frames"
    )
    return decoded


# Process pose, detection and track data
agent_poses_all = {}
agent_fovs_all = {}
agent_dets_all = {}
tracks_all = {}
agent_subdirs = [
    os.path.join(log_dir, "command-center"),
    # numeric order, so agent-10 comes after agent-2
    *sorted(glob(os.path.join(log_dir, "agent-*")), key=_agent_id_from_subdir),
]
for agent_subdir in agent_subdirs:
    agent_ID = _agent_id_from_subdir(agent_subdir)

    # --------------------------
    # agent-specific information (loaded for agents only, not the command center)
    # --------------------------
    if agent_ID != "cc":
        agent_poses_all[agent_ID] = _load_per_frame_json(
            os.path.join(agent_subdir, "pose"), ObjectStateDecoder
        )
        agent_fovs_all[agent_ID] = _load_per_frame_json(
            os.path.join(agent_subdir, "fov"), FieldOfViewDecoder
        )
        agent_dets_all[agent_ID] = _load_per_frame_json(
            os.path.join(agent_subdir, "detections"), DetectionContainerDecoder
        )

    # --------------------------
    # agent + CC information: per-frame {track ID: state}
    # --------------------------
    tracks_all[agent_ID] = {ts: {} for ts in timestamps}

    # tracks -- stone soup tracks only save last file...need to massage
    files_tracks = sorted(glob(os.path.join(agent_subdir, "tracks", "*.pickle")))
    assert len(files_tracks) == 1, (
        f"{agent_subdir}: expected one tracks pickle, found {len(files_tracks)}"
    )
    # NOTE: pickle.load can execute arbitrary code -- only load trusted, locally produced logs
    with open(files_tracks[0], "rb") as f:
        tracks = pickle.load(f)
    for track in tracks:
        for state in track.states:
            assert state.timestamp in tracks_all[agent_ID]
            # augment each state with its parent track's ID
            state.ID = track.ID
            tracks_all[agent_ID][state.timestamp][state.ID] = state

# Convert per-frame dicts to lists -- the dicts were only needed so that a state
# with an already-seen ID overwrote the earlier entry
for tracks_by_ts in tracks_all.values():
    for timestamp, states_by_ID in tracks_by_ts.items():
        tracks_by_ts[timestamp] = list(states_by_ID.values())

# Trust-related information: one message per frame
trust_all = _load_per_frame_json(os.path.join(log_dir, "trust"), TrustMessageDecoder)

### Make Visualizations

In [None]:
from ordered_set import OrderedSet
from stonesoup.plotter import AnimatedPlotterly

plot_result = False

if plot_result:
    # Final command-center tracks. The leftover `tracks` variable from the loading
    # cell is the pickle of the last directory loaded there (not a dict keyed by
    # "cc"), so reload the command center's pickle explicitly.
    # NOTE: pickle.load can execute arbitrary code -- only load trusted, local logs.
    cc_track_files = sorted(
        glob(os.path.join(log_dir, "command-center", "tracks", "*.pickle"))
    )
    assert len(cc_track_files) == 1
    with open(cc_track_files[0], "rb") as f:
        cc_tracks_final = pickle.load(f)

    # plot in global frame; mapping [0, 2] selects the x/y position components
    truths_this = OrderedSet(truth["objects_visible"].values())
    plotter = AnimatedPlotterly(timestamps, tail_length=0.3, sim_duration=4)
    plotter.plot_ground_truths(truths_this, [0, 2])
    plotter.plot_tracks(
        set(cc_tracks_final), uncertainty=True, plot_history=True, mapping=[0, 2]
    )
    # A bare expression inside an `if` is never auto-displayed by IPython; show it explicitly.
    plotter.fig.show()

In [None]:
import matplotlib.pyplot as plt

from avstack.geometry import GlobalOrigin3D
from mate import plotting


tracks_cc = []  # trusted command-center tracks used for each frame's plots
vis_dir = os.path.join(log_dir, "visualization")

# Every per-frame sequence was asserted to hold len(frames) entries when loaded,
# so iterate over frames directly instead of assuming an agent with ID 0 exists.
for idx_frame in tqdm(range(len(frames))):
    frame_ts = timestamps[idx_frame]
    suffix = f"-idxframe-{idx_frame}"

    # pull off data -- dicts of the current frame keyed by agent ID
    agent_poses = {
        agent_ID: poses[idx_frame].position.x
        for agent_ID, poses in agent_poses_all.items()
    }
    agent_fovs_global = {
        agent_ID: fovs[idx_frame].change_reference(GlobalOrigin3D, inplace=False)
        for agent_ID, fovs in agent_fovs_all.items()
    }
    agent_dets_global = {
        agent_ID: dets[idx_frame].apply_and_return(
            "change_reference", GlobalOrigin3D, inplace=False
        )
        for agent_ID, dets in agent_dets_all.items()
    }
    tracks_global = {
        agent_ID: tracks_by_ts[frame_ts]
        for agent_ID, tracks_by_ts in tracks_all.items()
    }
    agent_trust = trust_all[idx_frame].agent_trust
    track_trust = trust_all[idx_frame].track_trust

    # HACK only use cc tracks that are in trust
    tracks_global["cc"] = [
        track for track in tracks_global["cc"] if track.ID in track_trust
    ]
    tracks_cc.append(tracks_global["cc"])

    # make visualization plots
    plotting.plot_agents_detections(
        agent_poses,
        agent_fovs_global,
        agent_dets_global,
        show=False,
        save=True,
        fig_dir=os.path.join(vis_dir, "detections"),
        suffix=suffix,
        extension="png",
    )
    plotting.plot_agents_tracks(
        agent_poses,
        agent_fovs_global,
        tracks_global["cc"],
        show=False,
        save=True,
        fig_dir=os.path.join(vis_dir, "tracks"),
        suffix=suffix,
        extension="png",
    )
    plotting.plot_trust(
        tracks_global["cc"],
        track_trust,
        agent_trust,
        show=False,
        save=True,
        fig_dir=os.path.join(vis_dir, "trust"),
        use_subfolders=True,
        suffix=suffix,
        extension="png",
    )
    # Three plotting calls per frame: close every open figure, not just the
    # current one, so figures do not accumulate over the loop.
    plt.close("all")

### Compute Metrics

In [None]:
from ordered_set import OrderedSet
from stonesoup.dataassociator.tracktotrack import TrackToTruth
from stonesoup.metricgenerator.manager import MultiManager
from stonesoup.measures import Euclidean
from stonesoup.metricgenerator.ospametric import OSPAMetric
from stonesoup.metricgenerator.tracktotruthmetrics import SIAPMetrics
from stonesoup.metricgenerator.uncertaintymetric import SumofCovarianceNormsMetric


tracking_filters = ["Baseline"]

# NOTE(review): c (OSPA cutoff) and association_threshold are presumably in
# metres -- confirm against the tracker's units.
ospa_generators = [
    OSPAMetric(
        c=40,
        p=1,
        generator_name=f"{tracking_filter} OSPA metrics",
        tracks_key=f"tracks_{tracking_filter}",
        truths_key="truths",
    )
    for tracking_filter in tracking_filters
]

# Truth state vectors are [x, vx, y, vy, z, vz, ...], so position = (0, 2) and
# velocity = (1, 3). Assumes the tracks share that layout -- confirm.
siap_generators = [
    SIAPMetrics(
        position_measure=Euclidean((0, 2)),
        velocity_measure=Euclidean((1, 3)),
        generator_name=f"{tracking_filter} SIAP metrics",
        tracks_key=f"tracks_{tracking_filter}",
        truths_key="truths",
    )
    for tracking_filter in tracking_filters
]

uncertainty_generators = [
    SumofCovarianceNormsMetric(
        # Must not reuse the OSPA generator's name: results are keyed by
        # generator_name (e.g. metrics["Baseline SIAP metrics"]).
        generator_name=f"{tracking_filter} Uncertainty metrics",
        tracks_key=f"tracks_{tracking_filter}",
    )
    for tracking_filter in tracking_filters
]
associator = TrackToTruth(association_threshold=30)

generators = ospa_generators + siap_generators + uncertainty_generators
metric_manager = MultiManager(generators, associator=associator)

# Final command-center tracks. The leftover `tracks` variable from the loading
# cell is the pickle of the last directory loaded there (not a dict keyed by
# "cc"), so reload the command center's pickle explicitly.
# NOTE: pickle.load can execute arbitrary code -- only load trusted, local logs.
cc_track_files = sorted(
    glob(os.path.join(log_dir, "command-center", "tracks", "*.pickle"))
)
assert len(cc_track_files) == 1
with open(cc_track_files[0], "rb") as f:
    cc_tracks_final = pickle.load(f)

metric_manager.add_data(
    {
        "truths": OrderedSet(truth["objects_visible"].values()),
        "tracks_Baseline": set(cc_tracks_final),
    }
)
metrics = metric_manager.generate_metrics()

In [None]:
from stonesoup.plotter import MetricPlotter


# Plot the OSPA distance time series produced by the metric generators
ospa_metric_names = ["OSPA distances"]
fig1 = MetricPlotter()
fig1.plot_metrics(metrics, metric_names=ospa_metric_names)

In [None]:
from stonesoup.metricgenerator.metrictables import SIAPTableGenerator


siap_metrics = metrics["Baseline SIAP metrics"]
# Keep the scalar SIAP summaries; entries ending in " at times" are
# presumably the per-timestep series and are excluded from the table.
siap_averages_baseline = {
    siap_metrics.get(metric)
    for metric in siap_metrics
    if metric.startswith("SIAP") and not metric.endswith(" at times")
}

# Print the heading before building the table so it labels the table.
print("\n\nSIAP metrics for Baseline:")
# Use a named variable rather than `_`: rebinding `_` stops IPython from
# updating its output-history variables (`_`, `__`, `___`).
siap_table_baseline = SIAPTableGenerator(siap_averages_baseline).compute_metric()

In [None]:
# save object results (legacy per-frame version of `update_truth_dictionary`)
# ============================
# As originally written this cell always failed on Restart & Run All: it read the
# nonexistent key truth["global"] and the undefined names `objs_in_view` and
# `visible_times`. If enabled, it also re-appends the last replayed frame's object
# states to `truth`, so it is disabled by default and kept only for debugging.
RUN_LEGACY_TRUTH_UPDATE = False

if RUN_LEGACY_TRUTH_UPDATE:
    # objects seen by any agent's lidar in the last replayed frame
    objs_in_view = {
        obj.ID
        for agent_data in data_input["agent_data"].values()
        for obj in agent_data["objects"]["lidar"]
    }
    # Stores raw `obj.t` values (not datetimes like truth["visible_times"]),
    # so keep a separate local dict.
    visible_times = {"first": {}, "last": {}}

    for obj in data_input["objects"]:
        # save object in global frame
        obj.t = data_input["timestamp"]  # HACK
        if obj.ID not in truth["objects"]:
            truth["objects"][obj.ID] = GroundTruthPath()
        truth["objects"][obj.ID].append(object_to_stone_soup_truth(obj, t_start))

        # save times object was in view
        if (obj.ID in objs_in_view) or (obj.ID > 10000):
            if obj.ID not in visible_times["first"]:
                visible_times["first"][obj.ID] = obj.t
            if obj.t < visible_times["first"][obj.ID]:
                raise RuntimeError("New timestamp is earlier than 'first'")
            visible_times["last"][obj.ID] = obj.t