In [None]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import argparse
import importlib
import os
from multiprocessing import Pool

import numpy as np
import torch
import torch.nn.functional as F
from sklearn.cluster import DBSCAN
from scipy.spatial import ConvexHull, Delaunay, QhullError
from tqdm import tqdm
import yaml

from utils.transforms import Crop
from datasets.utils import relabel, parse_calibration, parse_poses, Trans
from utils.metric import MultiClassMetric

In [None]:
# 🔹 Global settings
# (min, max) extent covered by the BEV grid along x and along y; both axes share it.
RANGE_X = RANGE_Y = (-50.0, 50.0)
# BEV grid shape (H, W): H cells along x (first index), W cells along y (second index).
SIZE_XY = (512, 512)
# Upper bound on the number of other frames aggregated as temporal context.
FRAMES_MAX = 8

In [None]:
# 🔹 Data loading
def get_data(data_path, pred_path, fname, task_cfg):
    """Load one scan and its predicted labels, remapped through ``learning_map``.

    Args:
        data_path: directory containing the scan file ``fname``.
        pred_path: directory containing the matching ``<stem>.label`` file, where
            ``<stem>`` is the part of ``fname`` before its first ``.``.
        fname: scan file name inside ``data_path``.
        task_cfg: parsed dataset config; ``task_cfg["learning_map"]`` is handed
            to ``relabel``.

    Returns:
        ``(pts, sem)``: the scan as an ``(N, 4)`` float32 array, and the labels
        returned by ``relabel``.
    """
    stem = fname.partition(".")[0]
    scan_file = os.path.join(data_path, fname)
    label_file = os.path.join(pred_path, stem + ".label")

    points = np.fromfile(scan_file, dtype=np.float32).reshape(-1, 4)
    raw_labels = np.fromfile(label_file, dtype=np.uint32)
    # Keep only the low 16 bits of each stored uint32 before remapping.
    semantic = relabel(raw_labels & 0xFFFF, task_cfg["learning_map"])
    return points, semantic


# 🔹 BEV grid coordinates
def QuantizeBEV(pcds, range_x=RANGE_X, range_y=RANGE_Y, size_xy=SIZE_XY):
    """Map each point's (x, y) position to integer BEV grid cell indices.

    Args:
        pcds: array whose column 0 is x and column 1 is y (other columns ignored).
        range_x: (min, max) extent covered by the grid along x.
        range_y: (min, max) extent covered by the grid along y.
        size_xy: number of cells along x and along y.

    Returns:
        (N, 2) int32 array of ``(xi, yi)`` cell indices. Points outside the
        ranges get an index < 0 or >= the grid size (a point exactly at the
        upper bound lands at index ``size``); callers are expected to filter
        those out with bounds checks.
    """
    x, y = pcds[:, 0], pcds[:, 1]
    dx = (range_x[1] - range_x[0]) / size_xy[0]
    dy = (range_y[1] - range_y[0]) / size_xy[1]
    # floor, not a bare astype: astype truncates toward zero, which would fold
    # points slightly below the lower bound into cell 0 instead of giving -1
    # so that the downstream bounds checks reject them.
    xi = np.floor((x - range_x[0]) / dx).astype(np.int32)
    yi = np.floor((y - range_y[0]) / dy).astype(np.int32)
    return np.stack([xi, yi], axis=-1)  # (N, 2)


# 🔹 2D grid majority vote
def determine_voxel_labels_bev(xy, sem_lbl, size_xy=SIZE_XY, num_classes=None):
    """Majority-vote one label per BEV cell from the points falling in it.

    Args:
        xy: (N, 2) integer tensor of (row, col) cell indices. Entries outside
            [0, H) x [0, W) are ignored.
        sem_lbl: (N,) non-negative integer label tensor aligned with ``xy``.
        size_xy: grid shape (H, W).
        num_classes: number of label classes; must exceed every label value.
            Defaults to ``max(sem_lbl) + 1`` (1 for empty input).

    Returns:
        (H, W) long tensor on ``sem_lbl``'s device holding, per cell, the label
        with the most votes. Ties go to the smallest label (``torch.argmax``
        returns the first maximum). Cells with no points hold 0.
    """
    H, W = size_xy
    if num_classes is None:
        # Tensor.max() raises on an empty tensor; empty input needs one class slot.
        num_classes = int(sem_lbl.max().item()) + 1 if sem_lbl.numel() > 0 else 1
    C = num_classes
    valid = (xy[:, 0] >= 0) & (xy[:, 0] < H) & (xy[:, 1] >= 0) & (xy[:, 1] < W)
    coords = xy[valid].long()
    labels = sem_lbl[valid].long()
    # Count every (cell, label) pair in one pass: bin = flat_cell * C + label.
    # This avoids materialising an (N, C) one-hot matrix.
    bins = (coords[:, 0] * W + coords[:, 1]) * C + labels
    votes = torch.bincount(bins, minlength=H * W * C)
    grid2d = votes.view(H, W, C).argmax(dim=-1)
    return grid2d  # (H, W)


# 🔹 Map cell labels back to points
def get_point_labels_from_voxel_labels_bev(xy, grid2d):
    """Give every point the label of the BEV cell it falls in.

    Args:
        xy: (N, 2) integer tensor of (row, col) cell indices.
        grid2d: (H, W) tensor of per-cell labels.

    Returns:
        (N,) long tensor on ``grid2d``'s device; points whose indices fall
        outside the grid receive 0.
    """
    n_rows, n_cols = grid2d.shape
    rows, cols = xy[:, 0], xy[:, 1]
    inside = (rows >= 0) & (rows < n_rows) & (cols >= 0) & (cols < n_cols)

    point_labels = torch.zeros(xy.shape[0], dtype=torch.long, device=grid2d.device)
    # Row-major flat index into the grid for the in-range points only.
    flat_cells = rows[inside] * n_cols + cols[inside]
    point_labels[inside] = grid2d.reshape(-1)[flat_cells]
    return point_labels


# 🔹 DBSCAN + 2D convex-hull re-classification
def cluster_bev(pts_xy, labels, eps=0.5, min_samples=5):
    """Re-decide label 2 vs. label 1 per DBSCAN cluster of label-2 points.

    The points labelled 2 (``mov`` in the vote below; label 1 is ``stat``) are
    clustered with DBSCAN on their (x, y) coordinates. For every cluster (DBSCAN
    noise is skipped), the 2D convex hull of its points is built and all points
    of the frame inside the hull vote: if label-2 points outnumber label-1
    points the cluster keeps label 2, otherwise all its points become 1.
    Votes are taken on the labels as they were on entry, so the result does
    not depend on the order in which clusters are visited.

    Args:
        pts_xy: (N, 2) tensor of x, y coordinates.
        labels: (N,) integer label tensor aligned with ``pts_xy``. It is
            modified in place and also returned.
        eps: DBSCAN neighbourhood radius.
        min_samples: DBSCAN core-point threshold.

    Returns:
        ``labels`` with the clustered label-2 points relabelled.
    """
    # as_tuple=True keeps the index 1-D even for a single hit; .squeeze() would
    # collapse it to 0-d and hand DBSCAN a 1-D array.
    idx_fg = torch.nonzero(labels == 2, as_tuple=True)[0]
    if idx_fg.numel() == 0:
        return labels

    all_pts = pts_xy.detach().cpu().numpy()  # loop-invariant, converted once
    fg_pts = all_pts[idx_fg.cpu().numpy()]
    vote_labels = labels.clone()  # snapshot: earlier relabels must not sway later votes

    cluster_ids = DBSCAN(eps=eps, min_samples=min_samples).fit(fg_pts).labels_
    for cl in np.unique(cluster_ids):
        if cl < 0:  # DBSCAN noise
            continue
        member = cluster_ids == cl
        cluster_pts = fg_pts[member]
        try:
            hull = ConvexHull(cluster_pts)
            tri = Delaunay(cluster_pts[hull.vertices])
        except QhullError:
            # Degenerate cluster (e.g. collinear or coincident points): there is
            # no 2D hull to vote in, so leave its labels unchanged.
            continue
        # Convert numpy masks to tensors on the right device before indexing.
        in_hull = torch.from_numpy(tri.find_simplex(all_pts) >= 0).to(labels.device)
        inside = vote_labels[in_hull]
        stat = int((inside == 1).sum().item())
        mov = int((inside == 2).sum().item())
        new_lbl = 2 if mov > stat else 1
        affected = idx_fg[torch.from_numpy(member).to(idx_fg.device)]
        labels[affected] = new_lbl
    return labels


# 🔹 Per-frame post-processing (BEV)
def _label_file_name(fname):
    """Return the ``.label`` file name paired with scan file ``fname``."""
    # Stem is everything before the first '.'.
    return fname.split(".")[0] + ".label"


def _history_frame_ids(idx, n_frames):
    """Return the indices of the frames aggregated as context for frame ``idx``.

    With at least FRAMES_MAX predecessors these are the FRAMES_MAX frames right
    before ``idx``. Near the start of a sequence the first FRAMES_MAX frames are
    used instead (skipping ``idx`` itself), clipped to ``n_frames`` so that a
    sequence shorter than FRAMES_MAX does not index past its end.
    """
    if idx >= FRAMES_MAX:
        return list(range(idx - FRAMES_MAX, idx))
    return [i for i in range(min(FRAMES_MAX, n_frames)) if i != idx]


def post_processing_bev(idx):
    """Refine the predicted labels of frame ``idx`` and write them to ``save_path``.

    Steps:
      1. Load frame ``idx`` (points + labels from ``pred_path``).
      2. Load the context frames from ``_history_frame_ids`` and move them into
         frame ``idx`` with ``inv(poses[idx]) @ poses[h]`` via ``Trans``.
      3. Crop context and current points with ``crop_fov``.
      4. Majority-vote a label per BEV cell over context + current points and
         give each cropped current point its cell's label.
      5. Re-decide clustered label-2 points with ``cluster_bev`` on the full
         (uncropped) current frame.
      6. Map back through ``learning_map_inv`` and write a uint32 ``.label``.

    Reads module globals set by the ``__main__`` block: ``files``,
    ``data_path``, ``pred_path``, ``poses``, ``save_path``, ``crop_fov`` and
    ``task_cfg``.
    """
    global files, data_path, pred_path, pred_bf_path, poses, save_path, crop_fov, task_cfg
    # NOTE(review): pred_bf_path is declared above but not read by this function.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # 1) Current frame
    pts_c, lbl_c = get_data(data_path, pred_path, files[idx], task_cfg)

    # 2) Context frames, expressed in the current frame's coordinates
    inv_cur = np.linalg.inv(poses[idx])
    hist_pts, hist_lbl = [], []
    for h in _history_frame_ids(idx, len(files)):
        p, l = get_data(data_path, pred_path, files[h], task_cfg)
        hist_pts.append(Trans(p, inv_cur.dot(poses[h])))
        hist_lbl.append(l)

    # 3) FOV crop (context first, then the current frame)
    if hist_pts:
        h_pts, h_lbl, _ = crop_fov(np.concatenate(hist_pts, axis=0),
                                   np.concatenate(hist_lbl, axis=0))
    else:
        # e.g. a one-frame sequence: no context, and np.concatenate([]) raises.
        h_pts, h_lbl = pts_c[:0], lbl_c[:0]
    pts_o, lbl_o = pts_c.copy(), lbl_c.copy()
    pts_c, lbl_c, mask = crop_fov(pts_c, lbl_c)

    # 4) BEV vote; the current-frame points are the tail of all_pts
    all_pts = np.vstack([h_pts[:, :3], pts_c[:, :3]])
    # int64 before torch.tensor: torch cannot convert every numpy integer dtype
    # (e.g. uint32 on older versions).
    all_lbl = np.concatenate([h_lbl, lbl_c]).astype(np.int64)
    bev_xy_t = torch.tensor(QuantizeBEV(all_pts), device=device).long()
    sem_t = torch.tensor(all_lbl, device=device)
    grid2d = determine_voxel_labels_bev(bev_xy_t, sem_t)
    curr_xy = bev_xy_t[h_pts.shape[0]:]
    new_lbl = get_point_labels_from_voxel_labels_bev(curr_xy, grid2d)
    lbl_o[mask] = new_lbl.cpu().numpy()

    # 5) Cluster-level re-decision on the full current frame
    pts_xy_t = torch.tensor(pts_o[:, :2], device=device)
    lbl_t = torch.tensor(np.asarray(lbl_o).astype(np.int64), device=device)
    lbl_o = cluster_bev(pts_xy_t, lbl_t).cpu().numpy()

    # 6) Save. .label files are read back as uint32 (np.fromfile(..., np.uint32)),
    # so force that dtype; otherwise an int64 array would be written at 8 bytes
    # per label and corrupt the file.
    out = relabel(lbl_o, task_cfg["learning_map_inv"])
    os.makedirs(save_path, exist_ok=True)
    np.asarray(out, dtype=np.uint32).tofile(
        os.path.join(save_path, _label_file_name(files[idx])))


# 🔹 Evaluation metrics
def metric_bev(root_seq, save_seq, seq="08"):
    """Score the refined predictions of one sequence against ground truth.

    Args:
        root_seq: dataset ``sequences`` directory; ground-truth labels are read
            from ``<root_seq>/<seq>/labels/``.
        save_seq: output directory of this sequence; predictions are read from
            ``<save_seq>/predictions/`` using the ground-truth file names.
        seq: sequence id under ``root_seq`` (default ``"08"``).

    Prints ``MultiClassMetric.get_metric()`` for classes ``static``/``moving``.
    """
    val_path = os.path.join(save_seq, "predictions/")
    gt_path = os.path.join(root_seq, seq, "labels/")
    flist = sorted(os.listdir(gt_path))
    with open("datasets/semantic-kitti.yaml", "r") as f:
        cfg = yaml.safe_load(f)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    crit = MultiClassMetric(["static", "moving"])
    for fn in tqdm(flist):
        g = np.fromfile(os.path.join(gt_path, fn), dtype=np.uint32) & 0xFFFF
        # int64 before torch.tensor: torch cannot convert every numpy integer
        # dtype (e.g. uint32 on older versions).
        gt = torch.tensor(relabel(g, cfg["learning_map"]).astype(np.int64), device=device)
        p = np.fromfile(os.path.join(val_path, fn), dtype=np.uint32) & 0xFFFF
        pr = torch.tensor(relabel(p, cfg["learning_map"]).astype(np.int64), device=device)
        # NOTE(review): cluster_bev in this file treats label 2 as moving, but
        # one_hot with num_classes=2 only accepts labels 0 and 1 -- confirm the
        # label range MultiClassMetric.addBatch expects.
        crit.addBatch(gt, F.one_hot(pr, num_classes=2))
    print("BEV Best Epoch →", crit.get_metric())

In [None]:
# 🔹 Main
if __name__=="__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--tag",    type=str, default="bev")
    parser.add_argument("--modal",  type=str, default="val")
    # Dataset `sequences` directory; override per machine instead of editing code.
    parser.add_argument("--data_root", type=str,
                        default="/home/workspace/KITTI/dataset/sequences")
    args = parser.parse_args()

    # Experiment name comes from the config module's get_config()
    cfg_mod = importlib.import_module(args.config.replace(".py","").replace("/","."))
    pGen,_,_,_ = cfg_mod.get_config()
    prefix = pGen.name

    # Dataset settings
    with open("datasets/semantic-kitti.yaml","r") as f:
        task_cfg = yaml.load(f, Loader=yaml.SafeLoader)
    base = args.data_root
    if args.modal=="val":
        seqs = ["08"]
        pred_root    = f"experiments/{prefix}/{args.tag}/val_results/sequences"
        pred_bf_root = f"experiments/{prefix}/{args.tag}/val_bf_results/sequences"
        save_root    = f"experiments/{prefix}/{args.tag}/bev_refine_val/sequences"
    else:
        seqs = [str(i) for i in range(11,22)]
        pred_root    = f"experiments/{prefix}/{args.tag}/test_results/sequences"
        pred_bf_root = f"experiments/{prefix}/{args.tag}/test_bf_results/sequences"
        save_root    = f"experiments/{prefix}/{args.tag}/bev_refine_test/sequences"

    crop_fov = Crop(dims=(0,1,2), fov=[[-50,-50,-4],[50,50,2]])

    # post_processing_bev reads the module globals assigned below; Pool workers
    # inherit them only under the fork start method (the Linux default).
    for seq in seqs:
        print(f"▶ Sequence {seq}")
        data_path    = os.path.join(base, seq, "velodyne/")
        calib_path   = os.path.join(base, seq, "calib.txt")
        pose_path    = os.path.join(base, seq, "poses.txt")
        pred_path    = os.path.join(pred_root,    seq, "predictions/")
        pred_bf_path = os.path.join(pred_bf_root, seq, "predictions/")
        save_path    = os.path.join(save_root,    seq, "predictions/")

        files = sorted(os.listdir(data_path))
        poses = parse_poses(pose_path, parse_calibration(calib_path))

        with Pool(8) as p:
            list(tqdm(p.imap(post_processing_bev, range(len(files))),
                      total=len(files)))

    # Evaluate only after every Pool is done: metric_bev uses CUDA in this
    # process, and forked workers cannot re-initialise CUDA afterwards.
    if args.modal=="val":
        for seq in seqs:
            # metric_bev resolves ground truth under its first argument (the
            # sequences root) and predictions under its second (this sequence's
            # output directory, which holds predictions/).
            metric_bev(base, os.path.join(save_root, seq))