In [None]:
import numpy as np
import torch
import plotly.graph_objects as go
from math import atan2
import plotly.colors



In [None]:
import numpy as np
import torch
import math
import plotly.graph_objects as go
from typing import Optional

# Corner-index pairs forming the 12 edges of a box (corner layout: see _bbox_corners).
_BOX_EDGES = (
    (0, 1), (1, 2), (2, 3), (3, 0),  # top face
    (4, 5), (5, 6), (6, 7), (7, 4),  # bottom face
    (0, 4), (1, 5), (2, 6), (3, 7),  # vertical edges
)


def _bbox_corners(boxes: np.ndarray) -> np.ndarray:
    """Compute the 8 corners of every box at once.

    Args:
      boxes: float array of shape (N, >=8) with columns
        [cx, cy, w, l, cz, h, sin, cos, ...].

    Returns:
      (N, 8, 3) array of corner coordinates. Corners 0-3 lie on the top face
      (z = cz + h/2) and 4-7 on the bottom face; corner i is directly above
      corner i + 4.
    """
    # Sign of each corner in the box-local frame: length runs along local x,
    # width along local y, height along z.
    sign_x = np.array([1, 1, -1, -1, 1, 1, -1, -1], dtype=np.float64)
    sign_y = np.array([1, -1, -1, 1, 1, -1, -1, 1], dtype=np.float64)
    sign_z = np.array([1, 1, 1, 1, -1, -1, -1, -1], dtype=np.float64)

    x_local = (boxes[:, 3:4] / 2.0) * sign_x  # +/- half length, (N, 8)
    y_local = (boxes[:, 2:3] / 2.0) * sign_y  # +/- half width
    z_local = (boxes[:, 5:6] / 2.0) * sign_z  # +/- half height

    # yaw = atan2(sin, cos); taking cos/sin of it again normalises (sin, cos)
    # pairs that are not exactly unit length.
    yaw = np.arctan2(boxes[:, 6:7], boxes[:, 7:8])
    c, s = np.cos(yaw), np.sin(yaw)

    xs = boxes[:, 0:1] + c * x_local - s * y_local
    ys = boxes[:, 1:2] + s * x_local + c * y_local
    zs = boxes[:, 4:5] + z_local
    return np.stack([xs, ys, zs], axis=-1)


def _segments_with_gaps(endpoints: np.ndarray) -> list:
    """Flatten (..., 2) segment endpoints into [a0, b0, None, a1, b1, None, ...].

    plotly breaks a line trace at ``None``, so every edge is drawn as its own
    segment while all edges still live in a single trace.
    """
    flat = []
    for start, end in endpoints.reshape(-1, 2).tolist():
        flat.extend((start, end, None))
    return flat


def plot_bboxes_plotly(bboxes, max_boxes: Optional[int] = None,
                       color="blue", line_width: float = 2.0, opacity: float = 0.8,
                       show_center: bool = True, top_view: bool = True,
                       orthographic: bool = False, fig=None, name: str = 'bboxes'):
    """Visualise a set of 3-D bounding boxes as a wireframe with plotly.

    Args:
      bboxes: torch.Tensor or np.ndarray of shape (N, >=10) with columns
        [cx, cy, w, l, cz, h, sin, cos, vx, vy], where yaw = atan2(sin, cos).
        Only the first 8 columns are used for drawing.
      max_boxes: if given, draw only the first ``max_boxes`` boxes (handy for
        debugging). Must be non-negative.
      color: line / marker colour (named colour, RGB or HEX string).
      line_width: edge line width.
      opacity: edge trace opacity in [0, 1].
      show_center: also draw a marker at each box centre.
      top_view: start with the camera above the scene (eye on +z) instead of
        an oblique view.
      orthographic: use an orthographic projection instead of perspective.
      fig: existing ``go.Figure`` to draw into; a new one is created if None.
        Traces are added to it in place and its scene/size/margin layout
        settings are overwritten.
      name: legend label of the edge trace (useful when overlaying several
        box sets on one figure).

    Returns:
      plotly.graph_objects.Figure (``fig`` itself when one was passed).

    Raises:
      ValueError: if ``bboxes`` is not 2-D with at least 10 columns, or if
        ``max_boxes`` is negative.
    """
    if isinstance(bboxes, torch.Tensor):
        # Cast to float64 before .numpy(): numpy has no bfloat16, and all the
        # geometry below is done in double precision anyway.
        data = bboxes.detach().cpu().to(torch.float64).numpy()
    else:
        data = np.asarray(bboxes, dtype=np.float64)

    if data.ndim != 2 or data.shape[1] < 10:
        raise ValueError("bboxes must be shape (N,10)")

    if max_boxes is not None:
        # A negative value would otherwise slice from the end (data[:-1]).
        if max_boxes < 0:
            raise ValueError("max_boxes must be non-negative")
        data = data[:max_boxes]

    if fig is None:
        fig = go.Figure()

    corners = _bbox_corners(data)                     # (N, 8, 3)
    segments = corners[:, np.asarray(_BOX_EDGES), :]  # (N, 12, 2, 3)

    # All edges of all boxes go into one trace, so the legend shows one entry.
    fig.add_trace(go.Scatter3d(
        x=_segments_with_gaps(segments[..., 0]),
        y=_segments_with_gaps(segments[..., 1]),
        z=_segments_with_gaps(segments[..., 2]),
        mode='lines',
        line=dict(color=color, width=line_width),
        opacity=opacity,
        name=name,
        showlegend=True,
    ))

    if show_center and len(data) > 0:
        fig.add_trace(go.Scatter3d(
            x=data[:, 0].tolist(), y=data[:, 1].tolist(), z=data[:, 4].tolist(),
            mode='markers',
            marker=dict(size=3, symbol='circle', color=color),
            name='centers',
            showlegend=True,
            opacity=0.9,
        ))

    # White background; aspectmode='data' keeps the true axis proportions.
    scene = dict(
        xaxis=dict(title='X', backgroundcolor="white"),
        yaxis=dict(title='Y', backgroundcolor="white"),
        zaxis=dict(title='Z', backgroundcolor="white"),
        aspectmode='data',
    )

    # Larger eye distance = camera further away; tune for the scene scale.
    if top_view:
        camera = dict(eye=dict(x=0, y=0, z=4))
    else:
        camera = dict(eye=dict(x=1.5, y=1.5, z=0.8))

    proj = dict(type="orthographic") if orthographic else dict(type="perspective")

    fig.update_layout(
        scene=scene,
        height=800,
        width=800,
        scene_camera=dict(**camera, projection=proj),
        paper_bgcolor="white",
        plot_bgcolor="white",
        margin=dict(l=0, r=0, t=0, b=0),
    )

    return fig



In [None]:
# Ground-truth boxes drawn in red. `fig` is read by the overlay cell below, so keep the name.
# map_location='cpu' lets tensors saved from a CUDA device load on a CPU-only kernel.
gt_boxes = torch.load('gt_bbox.pt', map_location='cpu')
fig = plot_bboxes_plotly(gt_boxes, color='red', show_center=False)
fig.show()

In [None]:
# Overlay projected boxes on a *copy* of the GT figure `fig` from the cell above.
# plot_bboxes_plotly adds traces to the figure it is given, so drawing on `fig`
# directly would stack duplicate traces every time this cell is re-run.
# NOTE(review): assumes bbox_projected.pt holds a 3-D tensor whose last dim is
# the 10 box parameters — confirm.
N_KEEP = 31  # number of entries along dim 1 to keep
bbox_projected = torch.load('bbox_projected.pt', map_location='cpu')
bbox_projected_flat = bbox_projected[:, :N_KEEP, :].reshape(-1, 10)
fig_withpred = plot_bboxes_plotly(bbox_projected_flat, show_center=False, fig=go.Figure(fig))
fig_withpred.show()

In [None]:
# Matched predictions (default colour) and matched ground truth (red) in one figure.
# Uses its own variable name so the GT figure `fig` read by the overlay cell is not clobbered.
# NOTE(review): assumes 'pred' and 'gt' both hold boxes in the (N, 10) layout
# expected by plot_bboxes_plotly — confirm.
mmatch_result = torch.load('mmatch_result_dict.pt', map_location='cpu')
mmatch_result_pred = mmatch_result['pred']
mmatch_result_gt = mmatch_result['gt']
fig_mmatch = plot_bboxes_plotly(mmatch_result_pred, show_center=False)
fig_mmatch = plot_bboxes_plotly(mmatch_result_gt, color='red', show_center=False, fig=fig_mmatch)
fig_mmatch.show()

In [56]:
# Ground-truth boxes (red) with the boxes from bbox_preds.pt (default colour) overlaid.
# The cell builds its base figure from scratch, so re-running it does not stack traces.
# Note: plot_bboxes_plotly adds to and returns the figure it is given, so
# for_all_with_preds is the same object as for_all.
gt_bbox = torch.load('gt_bbox.pt', map_location='cpu')
bbox_preds = torch.load('bbox_preds.pt', map_location='cpu')
for_all = plot_bboxes_plotly(gt_bbox, color='red', show_center=False)
for_all_with_preds = plot_bboxes_plotly(bbox_preds, show_center=False, fig=for_all)
for_all_with_preds.show()