In [1]:
import os
import sys
import json
import numpy as np
import imageio.v3 as iio
import ipywidgets as widgets
from IPython.display import display, clear_output
from scipy.spatial import ConvexHull
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from typing import Optional, Tuple

sys.path.append(os.path.abspath(os.path.join('..', 'src')))
from registry import SOLUTIONS
from base import PlaneDetectionResult
from image_utils import prepare_3d_coordinates, load_image



In [2]:
class UIState:
    """Shared, mutable holder for the viewer's widgets and navigation state.

    The category/file dropdowns, the "next" button and the counter label start
    as ``None`` and are assigned later by ``setup_interface``. The method
    selector, the output area and the ground-truth checkbox are built here.
    """

    def __init__(self):
        # Placeholders, populated by setup_interface().
        self.category_dropdown = self.file_dropdown = None

        # Detection method selector: one option per key of SOLUTIONS.
        self.method_dropdown = widgets.Dropdown(
            options=SOLUTIONS.keys(),
            description='Метод:',
        )

        self.next_button = self.counter_label = None

        # Area into which the per-file report and the 3D figure are rendered.
        self.output = widgets.Output()

        # Per-category navigation data:
        #   image_dict[category] -> sorted list of .png file names
        #   file_index[category] -> position of the selected file in that list
        self.image_dict, self.file_index = {}, {}

        self.show_gt_checkbox = widgets.Checkbox(
            description='Показывать Ground Truth',
            value=True,
            indent=False,
        )


# Single shared instance used by all callbacks below.
state = UIState()

# Sub-directories of ../data/images offered in the category dropdown.
categories = ["clean", "medium", "heavy"]


In [None]:
def get_ground_truth_poly_points(category, filename, img, gt_path,
                                 angle_min=-30, angle_max=30, time_step=0.02,
                                 range_scale=1000.0, max_range=20.0):
    """Convert the annotated ground-truth polygon of one image into 3D points.

    The annotation is read from ``<gt_path>/<category>/<json_name>``, where
    ``json_name`` is ``filename`` with ``'.png'`` replaced by ``'.jpg.json'``.
    The first entry of its ``"objects"`` list is used; each element of its
    ``"data"`` field supplies a pixel column (index 0) and row (index 1).

    Each vertex is mapped as follows:

    * column ``x`` -> angle, linear from ``angle_min`` (column 0) to
      ``angle_max`` (last column), in degrees;
    * row ``y`` -> ``Y = y * time_step``;
    * pixel value -> ``r = img[y, x] / range_scale``; ``r > max_range`` is set to 0;
    * ``X = r * sin(angle)``, ``Z = -r * cos(angle)``.

    Vertices lying on or beyond the image border (e.g. ``x == width``) are
    sampled from the nearest in-bounds pixel instead of raising ``IndexError``.

    Args:
        category: sub-directory of ``gt_path`` holding the annotation.
        filename: image file name.
        img: image array; only ``img.shape[:2]`` and pixel lookups are used.
        gt_path: root directory of the annotations.
        angle_min, angle_max: angular span of the image columns, in degrees.
        time_step: Y increment per image row.
        range_scale: divisor converting pixel values to range.
        max_range: ranges above this value are replaced with 0.

    Returns:
        ``(N, 3)`` array of ``(X, Y, Z)`` vertices, or ``None`` when the
        annotation file does not exist or has no objects.
    """
    h, w = img.shape[:2]

    json_name = filename.replace('.png', '.jpg.json')
    json_path = os.path.join(gt_path, category, json_name)
    if not os.path.exists(json_path):
        return None

    with open(json_path, 'r') as f:
        gt = json.load(f)

    objects = gt.get("objects")
    if not objects:
        return None

    vertices = objects[0]["data"]
    px_x = np.array([pt[0] for pt in vertices], dtype=float)
    px_y = np.array([pt[1] for pt in vertices], dtype=float)

    # max(..., 1) avoids a division by zero for single-column images.
    column_fraction = px_x / max(w - 1, 1)
    gt_angles = np.radians(angle_min + column_fraction * (angle_max - angle_min))
    gt_times = px_y * time_step

    # Annotation vertices may sit exactly on the image border (x == w or
    # y == h); clamp the lookup indices so they stay inside the image.
    cols = np.clip(px_x.astype(int), 0, w - 1)
    rows = np.clip(px_y.astype(int), 0, h - 1)
    gt_r = img[rows, cols] / range_scale
    gt_r[gt_r > max_range] = 0

    gt_X = gt_r * np.sin(gt_angles)
    gt_Y = gt_times
    gt_Z = -gt_r * np.cos(gt_angles)

    return np.column_stack((gt_X, gt_Y, gt_Z))

In [None]:
def plot_lidar(X, Y, Z, plane_result: Optional[PlaneDetectionResult] = None, ground_truth: Optional[np.ndarray] = None):
    """Show an interactive Plotly 3D view of a LiDAR scan.

    Draws, in order:

    * the raw cloud ``X, Y, Z`` (flattened, coloured by Z);
    * if ``plane_result`` is given: each non-empty ``*_points`` region it
      exposes, and the plane ``A*x + B*y + C*z + D = 0`` from ``plane_coeffs``.
      With a non-empty ``bottom_hull``, the plane is drawn as a mesh over the
      hull vertices, plus the hull outline. Otherwise it is drawn as a surface
      over the X/Y extent of the cloud;
    * if ``ground_truth`` is given: its rows as an orange polyline
      (columns 0, 1, 2 are x, y, z).

    A plane with ``C`` (near) zero cannot be written as ``z = f(x, y)``, so no
    plane surface is drawn in that case (the hull outline still is).

    Args:
        X, Y, Z: coordinate arrays of equal size (any shape; flattened for plotting).
        plane_result: optional detection result.
        ground_truth: optional array whose first three columns are x, y, z.
    """
    fig = go.Figure()

    # Raw point cloud, coloured by Z.
    fig.add_trace(go.Scatter3d(
        x=X.ravel(), y=Y.ravel(), z=Z.ravel(),
        mode='markers',
        marker=dict(size=1, color=Z.ravel(), colorscale='Blues', opacity=0.8),
        name="Лидар"
    ))

    if plane_result is not None:
        # (attribute on plane_result, marker colour); missing/empty regions are skipped.
        for name, color in (
            ("bottom_points", "blue"),
            ("leftside_points", "red"),
            ("rightside_points", "green"),
            ("front_points", "yellow"),
            ("back_points", "purple"),
        ):
            points = getattr(plane_result, name, None)
            if points is not None and len(points) > 0:
                fig.add_trace(go.Scatter3d(
                    x=points[:, 0],
                    y=points[:, 1],
                    z=points[:, 2],
                    mode='markers',
                    marker=dict(size=2, color=color),
                    name=name
                ))

        A, B, C, D = plane_result.plane_coeffs
        # z = -(A*x + B*y + D) / C is undefined for (near-)vertical planes.
        can_solve_z = not np.isclose(C, 0.0)

        hull = plane_result.bottom_hull
        if hull is not None and len(hull) > 0:
            hull = np.asarray(hull)

            if can_solve_z:
                # Project the hull vertices onto the plane. The mesh uses the
                # unique vertices only (no repeated closing point).
                Zproj = -(A * hull[:, 0] + B * hull[:, 1] + D) / C
                fig.add_trace(go.Mesh3d(
                    x=hull[:, 0], y=hull[:, 1], z=Zproj,
                    color='blue',
                    opacity=0.4,
                    name="Detected Plane",
                    showscale=False
                ))

            # Hull outline at its own heights, closed back to the first vertex.
            outline = np.vstack([hull, hull[0]])
            fig.add_trace(go.Scatter3d(
                x=outline[:, 0], y=outline[:, 1], z=outline[:, 2],
                mode='lines',
                line=dict(color='black', width=5),
                name='Bottom Hull'
            ))
        elif can_solve_z:
            # No hull available: draw the plane over the X/Y bounding box of the cloud.
            x_range = np.linspace(X.min(), X.max(), 10)
            y_range = np.linspace(Y.min(), Y.max(), 10)
            Xgrid, Ygrid = np.meshgrid(x_range, y_range)
            Zgrid = -(A * Xgrid + B * Ygrid + D) / C

            fig.add_trace(go.Surface(
                x=Xgrid, y=Ygrid, z=Zgrid,
                colorscale=[[0, 'lightgreen'], [1, 'lightgreen']],
                showscale=False,
                opacity=0.4,
                name="Detected Plane"
            ))

    if ground_truth is not None:
        gt_X, gt_Y, gt_Z = ground_truth[:, 0], ground_truth[:, 1], ground_truth[:, 2]
        fig.add_trace(go.Scatter3d(
            x=gt_X, y=gt_Y, z=gt_Z,
            mode='lines+markers',
            line=dict(color='orange', width=4),
            marker=dict(size=3, color='orange'),
            name="Ground Truth"
        ))

    fig.update_layout(
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Z',
        ),
        width=800,
        height=700
    )
    fig.show()

In [5]:
def update_file_options(category: str):
    """Load the PNG files of ``category`` into the file dropdown and select the first.

    Reads ``../data/images/<category>`` and keeps only ``.png`` files, sorted.
    Stores the list and a reset index of 0 in ``state``, then refreshes the
    counter label. An empty category leaves the dropdown without options
    instead of raising ``IndexError``.

    If the new category's first file has the same name as the file already
    selected, the dropdown value does not change. Traitlets notifies observers
    only on actual changes, so ``on_file_change`` would not run. In that case
    the plot is redrawn explicitly so the view always matches the category.
    """
    image_dir = os.path.join("..", "data", "images", category)
    files = sorted(f for f in os.listdir(image_dir) if f.endswith(".png"))
    state.image_dict[category] = files
    state.file_index[category] = 0

    previous_file = state.file_dropdown.value
    state.file_dropdown.options = files
    if files:
        state.file_dropdown.value = files[0]
    update_counter()

    if files and files[0] == previous_file:
        on_button_click(None)

def update_counter():
    """Refresh the ``"<position> / <total>"`` label for the active category.

    The position is 1-based. A category with no files shows ``"0 / 0"``
    rather than the misleading ``"1 / 0"``.
    """
    cat = state.category_dropdown.value
    total = len(state.image_dict.get(cat, []))
    position = state.file_index.get(cat, 0) + 1 if total else 0
    state.counter_label.value = f"{position} / {total}"

def on_next_click(_):
    """Button callback: move the file dropdown to the next image of the category.

    Does nothing when the current file is already the last one. The argument
    (the clicked button) is ignored.
    """
    category = state.category_dropdown.value
    next_pos = state.file_index.get(category, 0) + 1
    available = state.image_dict.get(category, [])
    if next_pos >= len(available):
        return  # already at the last file

    state.file_index[category] = next_pos
    state.file_dropdown.value = available[next_pos]
    update_counter()

def on_file_change(change):
    """Observer for the file dropdown's ``value``: sync navigation state and redraw.

    Stores the position of the newly selected file within the active
    category's list, falling back to 0 when it is not found. It then refreshes
    the counter and re-renders the plot. A ``None`` selection (no file
    selected, e.g. a category without images) only refreshes the counter,
    because there is nothing to render.
    """
    cat = state.category_dropdown.value
    selected = change['new']
    try:
        state.file_index[cat] = state.image_dict.get(cat, []).index(selected)
    except ValueError:
        state.file_index[cat] = 0
    update_counter()

    if selected is not None:
        on_button_click(None)

def on_button_click(_):
    """Re-render the currently selected image into ``state.output``.

    Steps:

    1. Load ``../data/images/<category>/<filename>``.
    2. Convert it with ``prepare_3d_coordinates``.
    3. Run ``SOLUTIONS[method]`` on the image.
    4. Plot the result, together with the ground-truth polygon when the
       checkbox is ticked.

    Ground truth is only loaded when it will be shown. With no file or method
    selected, the output is cleared and nothing else happens. The argument is
    ignored so the function can serve directly as a widget callback.
    """
    with state.output:
        clear_output()
        category = state.category_dropdown.value
        filename = state.file_dropdown.value
        method = state.method_dropdown.value
        if filename is None or method is None:
            # Nothing selected (e.g. an empty category) -> leave the output blank.
            return

        img_path = os.path.join("..", "data", "images", category, filename)
        img = load_image(img_path)
        X, Y, Z = prepare_3d_coordinates(img)
        result = SOLUTIONS[method](img)

        gt = None
        if state.show_gt_checkbox.value:
            gt = get_ground_truth_poly_points(category, filename, img, "../data/ground_truth")

        print(f"Файл: {filename}")

        plot_lidar(X, Y, Z, result, gt)


In [6]:
def setup_interface(on_process):
    """Create the viewer widgets, connect their callbacks and display the UI.

    Args:
        on_process: callable taking one (ignored) argument. It is called to
            re-render whenever the selected method or the ground-truth
            checkbox changes. Category and file changes are routed through
            ``update_file_options`` and ``on_file_change`` instead.

    Every widget that receives an observer here is created anew. Re-running
    this cell therefore does not stack duplicate handlers on widgets that
    outlive a single call.
    """
    state.category_dropdown = widgets.Dropdown(options=categories, description='Категория:')
    state.category_dropdown.observe(lambda change: update_file_options(change['new']), names='value')

    state.file_dropdown = widgets.Dropdown(description='Файл:')
    state.file_dropdown.observe(on_file_change, names='value')

    state.next_button = widgets.Button(description="Следующий", button_style='info')
    state.next_button.on_click(on_next_click)

    state.counter_label = widgets.Label(value="")

    # The selector built in UIState() persists across calls; observing it on
    # every call would trigger one redraw per call for each method change.
    state.method_dropdown = widgets.Dropdown(options=list(SOLUTIONS.keys()), description='Метод:')
    state.method_dropdown.observe(lambda change: on_process(None), names='value')

    state.show_gt_checkbox = widgets.Checkbox(
        value=True,
        description='Показывать Ground Truth',
        indent=False
    )
    state.show_gt_checkbox.observe(lambda change: on_process(None), names='value')

    # Populating the first category selects its first file, which fires
    # on_file_change and renders the initial view into state.output.
    update_file_options(categories[0])

    ui = widgets.VBox([
        state.category_dropdown,
        widgets.HBox([state.file_dropdown, state.next_button, state.counter_label]),
        state.method_dropdown,
        state.show_gt_checkbox,
        state.output
    ])
    display(ui)

setup_interface(on_button_click)


VBox(children=(Dropdown(description='Категория:', options=('clean', 'medium', 'heavy'), value='clean'), HBox(c…