In [40]:
import torch
import plotly.express as px
import plotly.graph_objects as go
import numpy as np

def _to_numpy(tensor):
    """Convert *tensor* to a NumPy array that Plotly can consume."""
    # .numpy() alone raises for tensors that require grad or are not on the CPU.
    return tensor.detach().cpu().numpy()


def _add_frame_traces(fig, frame, agents_np):
    """Draw each agent's local x-axis (cyan) and y-axis (magenta) from its position.

    ``frame[i, :, 0]`` is used as agent ``i``'s x-direction and
    ``frame[i, :, 1]`` as its y-direction.
    """
    for i in range(agents_np.shape[0]):
        x_pos, y_pos = agents_np[i, 0], agents_np[i, 1]
        for col, axis_name, color in ((0, 'X', 'cyan'), (1, 'Y', 'magenta')):
            x_dir = frame[i, 0, col].item()
            y_dir = frame[i, 1, col].item()
            fig.add_trace(go.Scatter(
                x=[x_pos, x_pos + x_dir],
                y=[y_pos, y_pos + y_dir],
                mode='lines',
                name=f'Frame {i+1} {axis_name}-direction',
                line=dict(color=color, dash='solid'),
            ))


def plot_all(goals, agents, velocity, rots, frame_idx=None, plot_frames=False,
             axis_range=(-3, 3)):
    """Plot agents, goals, velocities, headings and optionally local frames.

    Parameters
    ----------
    goals, agents, velocity : torch.Tensor
        Per-agent 2-D quantities: row ``i`` belongs to agent ``i``, and
        columns 0 and 1 are read as x and y.
    rots : torch.Tensor
        Per-agent heading angles in radians; ``rots[i, 0]`` is agent ``i``'s
        heading.
    frame_idx : int, optional
        If given, every quantity is first re-expressed relative to agent
        ``frame_idx`` (via ``localize_all_wrt_one_agent``) before plotting.
    plot_frames : bool, default False
        If true, draw each agent's local x-axis and y-axis.
    axis_range : sequence of two floats, default (-3, 3)
        Range applied to both plot axes.

    The figure is displayed with ``fig.show()``; nothing is returned.
    """
    if frame_idx is not None:
        agents, goals, rots, velocity = localize_all_wrt_one_agent(
            agents, frame_idx, goals, rots, velocity)

    goals_np = _to_numpy(goals)
    agents_np = _to_numpy(agents)
    velocity_np = _to_numpy(velocity)
    rots_np = _to_numpy(rots)

    fig = go.Figure()

    # Agents as red markers, each labelled A1, A2, ... above its marker.
    fig.add_trace(go.Scatter(x=agents_np[:, 0], y=agents_np[:, 1],
                             mode='markers', name='Agents',
                             marker=dict(color='red', size=10)))
    for i in range(agents_np.shape[0]):
        fig.add_trace(go.Scatter(
            x=[agents_np[i, 0]],
            y=[agents_np[i, 1]],
            mode='text',
            text=[f'A{i+1}'],
            textposition='top center',
            showlegend=False,
        ))

    # Goals as blue markers.
    fig.add_trace(go.Scatter(x=goals_np[:, 0], y=goals_np[:, 1],
                             mode='markers', name='Goals',
                             marker=dict(color='blue', size=10)))

    # Per agent: velocity vector, heading line and a line to its goal.
    for i in range(velocity_np.shape[0]):
        x0, y0 = agents_np[i, 0], agents_np[i, 1]
        heading = rots_np[i, 0]

        fig.add_trace(go.Scatter(
            x=[x0, x0 + velocity_np[i, 0]],
            y=[y0, y0 + velocity_np[i, 1]],
            mode='lines+markers',
            name=f'Velocity {i+1}',
            line=dict(color='green'),
        ))

        # Heading as a unit-length dotted line, labelled in turns (heading / 2*pi).
        fig.add_trace(go.Scatter(
            x=[x0, x0 + np.cos(heading)],
            y=[y0, y0 + np.sin(heading)],
            mode='lines+text',
            textposition='top center',
            text=['', f'{heading / (2 * np.pi):.2f}'],
            name=f'Rotation {i+1}',
            line=dict(color='purple', dash='dot'),
        ))

        fig.add_trace(go.Scatter(
            x=[x0, goals_np[i, 0]],
            y=[y0, goals_np[i, 1]],
            mode='lines',
            name=f'Agent to Goal {i+1}',
            line=dict(color='orange', dash='dash'),
        ))

    # Truthiness test so that plot_frames=False (the default) really disables
    # frames; drawn once here, outside the per-agent loop, so each agent's
    # frame appears exactly once.
    if plot_frames:
        _add_frame_traces(fig, local_frame(rots), agents_np)

    fig.update_layout(
        title="Goals, Agents, Velocity, Rotations, Frames and Lines to Goals",
        xaxis_title="X",
        yaxis_title="Y",
        legend_title="Legend",
        showlegend=True,
        # scaleanchor keeps equal x:y scaling so angles are not distorted.
        xaxis=dict(scaleanchor="y", scaleratio=1, range=list(axis_range)),
        yaxis=dict(range=list(axis_range)),
    )

    fig.show()


def localize_all_wrt_one_agent(agents, target_agent, goals, rots, velocity):
    """Re-express all agent quantities relative to one chosen agent.

    Positions (``agents``, ``goals``) are translated so the target agent sits
    at the origin and are then rotated into its local frame. ``velocity`` is
    only rotated, with no translation. ``rots`` becomes each heading minus
    the target agent's heading.

    Parameters
    ----------
    agents, goals, velocity : torch.Tensor
        Per-agent vectors (presumably shape ``(N, 2)`` -- TODO confirm).
    target_agent : int
        Index of the agent whose frame is used.
    rots : torch.Tensor
        Per-agent heading angles in radians.

    Returns
    -------
    tuple
        ``(agents, goals, rots, velocity)`` in the local frame. Note that
        this order differs from the argument order.
    """
    num_rows = agents.shape[0]
    origin = agents[target_agent]
    target_frame = local_frame(rots)[target_agent]
    # One copy of the target frame per row, so the batch sizes match.
    batched_frame = target_frame.expand(num_rows, *target_frame.shape).contiguous()

    local_agents = localize(agents, batched_frame, center=origin)
    local_velocity = localize(velocity, batched_frame)
    local_goals = localize(goals, batched_frame, center=origin)
    relative_rots = rots - rots[target_agent]
    return local_agents, local_goals, relative_rots, local_velocity

def local_frame(rotation):
    """Build 2x2 rotation matrices whose columns are the local x/y axes.

    For an angle ``theta`` the matrix is ``[[cos, -sin], [sin, cos]]``.
    Column 0 is the local x-direction and column 1 the local y-direction.
    Multiplying a row vector by it (``p @ R``) rotates ``p`` by ``-theta``.

    Parameters
    ----------
    rotation : torch.Tensor
        Angles in radians. Accepts either of two shapes:
        - shape ``(..., 1)``, e.g. ``(N, 1)``;
        - shape ``(...)`` without the trailing singleton axis, e.g. ``(N,)``
          or a 0-d scalar.
        A 1-element tensor of shape ``(1,)`` is read as ``(..., 1)`` and
        yields a single ``(2, 2)`` matrix.

    Returns
    -------
    torch.Tensor
        Shape ``(..., 2, 2)``.
    """
    # The construction below relies on a trailing singleton axis. Without it,
    # the final cat would interleave rows of different angles into a (2N, 2)
    # tensor.
    if rotation.dim() == 0 or rotation.shape[-1] != 1:
        rotation = rotation.unsqueeze(-1)
    c = torch.cos(rotation)
    s = torch.sin(rotation)
    # For row vectors, use R(-theta)^T = [[cos, -sin], [sin, cos]]
    frame = torch.cat(
        [
            torch.stack([c, -s], dim=-1),
            torch.stack([s, c], dim=-1),
        ],
        dim=-2,
    )
    return frame




def localize(pos, frame, center=None):
    """Express vectors in a local frame: ``(pos - center) @ frame`` per row.

    Parameters
    ----------
    pos : torch.Tensor
        Vectors of shape ``(..., D)``.
    frame : torch.Tensor
        Rotation matrices of shape ``(..., D, D)``. These are broadcast
        against the leading dims of ``pos``. Either of these works:
        - one frame per row, e.g. ``(N, D, D)`` for ``(N, D)`` positions;
        - a single ``(D, D)`` frame shared by all rows.
    center : torch.Tensor, optional
        Subtracted from ``pos`` before rotating. Omit it for direction-like
        quantities such as velocities.

    Returns
    -------
    torch.Tensor
        Shape ``(..., D)``.
    """
    if center is not None:
        pos = pos - center
    # Treat each vector as a 1xD row matrix. matmul broadcasts batch
    # dimensions, whereas bmm requires exactly matching 3-D operands.
    return torch.matmul(pos.unsqueeze(-2), frame).squeeze(-2)
# Random demo scene: 3 agents and 3 goals with coordinates in [-2, 2),
# velocities with components in [0, 1), headings in [0, 2*pi) radians.
goals = torch.rand(3, 2).mul(4).sub(2)
agents = torch.rand(3, 2).mul(4).sub(2)
velocity = torch.rand(3, 2)
rots = torch.rand(3, 1).mul(2).mul(torch.pi)

# World view first, then the same scene relative to agent index 0 (labelled
# A1) and agent index 1 (labelled A2).
plot_all(goals, agents, velocity, rots, plot_frames=True)
plot_all(goals, agents, velocity, rots, plot_frames=True, frame_idx=0)
plot_all(goals, agents, velocity, rots, plot_frames=True, frame_idx=1)