In [12]:
from dataclasses import dataclass

import numpy as np
import plotly.graph_objects as go

from swarm_nfomp.utils.math import PointArray2D, RectangleRegionArray
from swarm_nfomp.utils.position2d import Position2D

In [13]:
@dataclass
class MovingObstacle:
    y_line: np.ndarray
    width: float
    length: float
    velocity: float
    x_initial_position: float


@dataclass
class MovingObject:
    position: Position2D
    time: float

In [43]:
@dataclass
class StampedPointArray:
    times: np.ndarray
    points: PointArray2D


@dataclass
class MovingObstacleArray:
    y_line: np.ndarray
    width: np.ndarray
    length: np.ndarray
    velocity: np.ndarray
    x_initial_position: np.ndarray

    def inside(self, positions: StampedPointArray) -> np.ndarray:
        min_x = self.min_x(positions.times)
        min_y = self.min_y(positions.times)
        max_x = self.max_x(positions.times)
        max_y = self.max_y(positions.times)
        result = min_x <= positions.points.x[:, None]
        result &= positions.points.x[:, None] <= max_x
        result &= min_y <= positions.points.y[:, None]
        result &= positions.points.y[:, None] <= max_y
        return np.any(result, axis=1)

    def min_x(self, times):
        x_positions = self.x_initial_position[np.newaxis, :] + self.velocity * times[:, np.newaxis]
        result = x_positions - self.length[np.newaxis, :] / 2
        return result

    def max_x(self, times):
        x_positions = self.x_initial_position[np.newaxis, :] + self.velocity * times[:, np.newaxis]
        result = x_positions + self.length[np.newaxis, :] / 2
        return result

    def min_y(self, times):
        result = self.y_line[np.newaxis, :] - self.width[np.newaxis, :] / 2
        return np.repeat(result, len(times), axis=0)

    def max_y(self, times):
        result = self.y_line[np.newaxis, :] + self.width[np.newaxis, :] / 2
        return np.repeat(result, len(times), axis=0)

    @classmethod
    def from_dict(cls, data):
        y_line = np.array(data['y_line'])
        width = np.array(data['width'])
        length = np.array(data['length'])
        velocity = np.array(data['velocity'])
        x_initial_position = np.array(data['x_initial_position'])
        return cls(y_line, width, length, velocity, x_initial_position)

    def __len__(self):
        return len(self.x_initial_position)


class HighwayCollisionDetector:
    def __init__(self, moving_obstacle_array: MovingObstacleArray,
                 outside_rectangle_region_array: RectangleRegionArray):
        self.moving_obstacle_array = moving_obstacle_array
        self.outside_rectangle_region_array = outside_rectangle_region_array

    def is_collision(self, points: StampedPointArray):
        return self.moving_obstacle_array.inside(points) | (
            ~self.outside_rectangle_region_array.inside(points.points))

    @classmethod
    def from_dict(cls, data):
        moving_obstacle_array = MovingObstacleArray.from_dict(data['moving_obstacle_array'])
        outside_rectangle_region_array = RectangleRegionArray.from_dict(data['outside_rectangle_region_array'])
        return cls(moving_obstacle_array, outside_rectangle_region_array)


In [87]:
highway_collision_detector_parameters = {
    'moving_obstacle_array': {
        'y_line': [1, 2, 3],
        'width': [0.5, 0.6, 0.7],
        'length': [0.8, 0.9, 1.0],
        'velocity': [1, 1.75, 2],
        'x_initial_position': [0, 0, 0],
    },
    "outside_rectangle_region_array": [[-2, 24, 0, 4]]
}
collision_detector = HighwayCollisionDetector.from_dict(highway_collision_detector_parameters)

In [109]:
import itertools


def flat_lists(list_to_flat):
    return list(itertools.chain.from_iterable(list_to_flat))


class HighwayCollisionDetectorVisualizer:
    def __init__(self):
        self._fig = go.Figure()

    def visualize(self, detector: HighwayCollisionDetector, points: StampedPointArray):
        scatter = self.get_region_scatter(detector.outside_rectangle_region_array)
        self._fig.add_trace(scatter)
        self.add_moving_obstacles_frames(detector, points)
        self.update_layout()
        self._fig.update_yaxes(scaleanchor="x", scaleratio=1)

    @staticmethod
    def get_region_scatter(region, **kwargs):
        x_lists = [[region.min_x[i], region.max_x[i], region.max_x[i], region.min_x[i], region.min_x[i], None] for i in
                   range(len(region))]
        y_lists = [[region.min_y[i], region.min_y[i], region.max_y[i], region.max_y[i], region.min_y[i], None] for i in
                   range(len(region))]
        x = flat_lists(x_lists)
        y = flat_lists(y_lists)
        return go.Scatter(x=x[:-1], y=y[:-1], mode="lines", name="Outside region", **kwargs)

    def update_layout(self):
        self._fig.update_layout(
            updatemenus=[
                dict(
                    type="buttons",
                    direction="left",
                    showactive=False,
                    x=0.1,
                    y=1.2,
                    buttons=list([
                        dict(
                            label="Play",
                            method="animate",
                            args=[None, {
                                "frame": {
                                    "duration": 50,
                                    "redraw": False
                                },
                                "fromcurrent": True,
                                "transition": {
                                    "duration": 0,
                                }
                            }]
                        ),
                    ]),
                )
            ]
        )

    def add_moving_obstacles_frames(self, detector: HighwayCollisionDetector, points):
        times = points.times
        min_x = detector.moving_obstacle_array.min_x(times)
        max_x = detector.moving_obstacle_array.max_x(times)
        min_y = detector.moving_obstacle_array.min_y(times)
        max_y = detector.moving_obstacle_array.max_y(times)
        frames = []
        is_collision = detector.is_collision(points)
        print(is_collision)
        for i in range(len(times)):
            x_lists = [[min_x[i, j], max_x[i, j], max_x[i, j], min_x[i, j], min_x[i, j], None] for j in
                       range(max_y.shape[1])]
            y_lists = [[min_y[i, j], min_y[i, j], max_y[i, j], max_y[i, j], min_y[i, j], None] for j in
                       range(max_y.shape[1])]
            x = flat_lists(x_lists)[:-1]
            y = flat_lists(y_lists)[:-1]
            x_free = [points.points.x[i]] if not is_collision[i] else [0]
            y_free = [points.points.y[i]] if not is_collision[i] else [0]
            x_collision = [points.points.x[i]] if is_collision[i] else [0]
            y_collision = [points.points.y[i]] if is_collision[i] else [0]
            if i == 0:
                self._fig.add_trace(go.Scatter(x=x, y=y, mode="lines", name="MovingObstacles", fill="toself"))
                self._fig.add_trace(go.Scatter(x=x_collision, y=y_collision, mode="markers", name="collision", marker=dict(color="red")))
                self._fig.add_trace(go.Scatter(x=x_free, y=y_free, mode="markers", name="free", marker=dict(color="green")))
            frames.append(go.Frame(data=[
                self.get_region_scatter(detector.outside_rectangle_region_array),
                go.Scatter(x=x, y=y, mode="lines", name="MovingObstacles", fill="toself"),
                go.Scatter(x=x_collision, y=y_collision, mode="markers", name="collision", marker=dict(color="red")),
                go.Scatter(x=x_free, y=y_free, mode="markers", name="free", marker=dict(color="green"))
            ]))
            # frames.append(go.Frame(data=self.get_region_scatter(detector.outside_rectangle_region_array)))
        self._fig.frames = frames

    def save(self, filename):
        self._fig.write_html(filename)

In [110]:
visualizer = HighwayCollisionDetectorVisualizer()
test_times = np.linspace(0, 10, 100)
test_points = np.random.random((100, 2))
test_points[:, 0] = 26 * test_points[:, 0] - 2
test_points[:, 1] = 4 * test_points[:, 1]
visualizer.visualize(collision_detector, StampedPointArray(test_times, PointArray2D.from_vec(test_points)))
visualizer.save("data/highway_collision_detector_visualizer.html")

[False False False False False False False False False False False False
 False  True False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False False False False False
 False False False False False False False False  True False False  True
 False False False False]
