In [1]:
import sys
# Extend the import search path with '../../' (resolved against the current
# working directory at import time). Presumably this is the repository root that
# holds the `utils`, `data`, `modules` and `models` packages imported below —
# TODO confirm for the directory this notebook lives in.
sys.path.append('../../')

In [None]:
import numpy as np
from tqdm import tqdm
import cv2
import torch
import lightning as pl
from einops import rearrange, reduce

from utils.padding import InputPadderFromShape
from data.utils.types import DatasetMode, DataType
from data.genx_utils.labels import ObjectLabels
from modules.utils.detection import RNNStates
from models.layers.yolox.utils.boxes import postprocess, postprocess_with_motion

# —————————————————————————————————————————
# Tracker implementation
# —————————————————————————————————————————
class TrackedObject:
    """Mutable record for one tracked box.

    Holds the identifier, the box centre (``cx``, ``cy``), the box size
    (``w``, ``h``), a per-step displacement (``dx``, ``dy``) and a
    ``lost_frames`` counter that starts at zero.
    """

    def __init__(self, object_id, cx, cy, w, h, dx=0, dy=0):
        self.object_id = object_id
        self.cx, self.cy = cx, cy
        self.w, self.h = w, h
        self.dx, self.dy = dx, dy
        self.lost_frames = 0

    def predict_next_position(self):
        """Shift the stored centre by (dx, dy) in place and return the new centre."""
        moved_cx = self.cx + self.dx
        moved_cy = self.cy + self.dy
        self.cx, self.cy = moved_cx, moved_cy
        return moved_cx, moved_cy

class ObjectTracker:
    """Greedy IoU tracker that keeps persistent ids for per-frame detections.

    Boxes are ``(cx, cy, w, h)`` sequences (centre and size). Each call to
    :meth:`update` lets every tracked object, in insertion order, claim the
    not-yet-claimed detection with the highest IoU above ``iou_threshold``.
    Objects that go unmatched for more than ``max_lost`` consecutive updates
    are dropped; unclaimed detections start new tracks.

    NOTE(review): ``update`` never writes ``dx``/``dy`` on a tracked object, so
    the motion prediction only moves objects that were created with a non-zero
    displacement — confirm whether velocity estimation is intended.
    """

    def __init__(self, max_lost=5, iou_threshold=0.3):
        self.objects = {}  # object_id -> tracked object
        self.next_id = 0  # id handed to the next newly created object
        self.max_lost = max_lost
        self.iou_threshold = iou_threshold

    def iou(self, boxA, boxB):
        """Return the intersection-over-union of two ``(cx, cy, w, h)`` boxes."""
        # Corners of the intersection rectangle.
        xA = max(boxA[0] - boxA[2]/2, boxB[0] - boxB[2]/2)
        yA = max(boxA[1] - boxA[3]/2, boxB[1] - boxB[3]/2)
        xB = min(boxA[0] + boxA[2]/2, boxB[0] + boxB[2]/2)
        yB = min(boxA[1] + boxA[3]/2, boxB[1] + boxB[3]/2)

        # Clamp at zero so disjoint boxes yield an empty intersection.
        interArea = max(0, xB - xA) * max(0, yB - yA)
        boxAArea = boxA[2] * boxA[3]
        boxBArea = boxB[2] * boxB[3]
        # The epsilon keeps the division defined for zero-area boxes.
        return interArea / float(boxAArea + boxBArea - interArea + 1e-6)

    def update(self, detections):
        """Match ``detections`` against the tracked objects and update them.

        Args:
            detections: iterable of ``(cx, cy, w, h)`` boxes for the current
                frame. Entries may be tuples, lists or array rows.
        """
        detections = list(detections)
        # Matched detections are remembered by index, not by value: two
        # detections with identical coordinates stay distinct, and unhashable
        # detections (e.g. lists) are accepted.
        assigned = set()

        # Update existing objects.
        for oid, obj in list(self.objects.items()):
            obj.predict_next_position()
            predicted_box = (obj.cx, obj.cy, obj.w, obj.h)
            best_iou = 0
            best_idx = None
            for idx, det in enumerate(detections):
                if idx in assigned:
                    continue
                overlap = self.iou(predicted_box, det)
                if overlap > best_iou and overlap > self.iou_threshold:
                    best_iou = overlap
                    best_idx = idx
            if best_idx is not None:
                obj.cx, obj.cy, obj.w, obj.h = detections[best_idx]
                assigned.add(best_idx)
                obj.lost_frames = 0
            else:
                obj.lost_frames += 1
            # Iterating over a snapshot above makes this deletion safe.
            if obj.lost_frames > self.max_lost:
                del self.objects[oid]

        # Start a new track for every detection nobody claimed.
        for idx, det in enumerate(detections):
            if idx not in assigned:
                self.objects[self.next_id] = TrackedObject(self.next_id, *det)
                self.next_id += 1

    def get_tracked_objects(self):
        """Return the current tracks as a list of ``(id, cx, cy, w, h)`` tuples."""
        return [(o.object_id, o.cx, o.cy, o.w, o.h) for o in self.objects.values()]

# —————————————————————————————————————————
# Inference loop
# —————————————————————————————————————————
# Frame size per dataset name. `inference` hands this tuple unchanged to
# cv2.VideoWriter as the frame size, so it is presumably (width, height) — confirm.
dataset2size = dict(
    gen1=(304, 240),
    gen4=(640, 360),
    VGA=(640, 480),
)
# Class names per dataset name; `inference` uses the tuple length as num_classes.
dataset2labelmap = dict(
    gen1=("car", "pedestrian"),
    gen4=("pedestrian", "two wheeler", "car"),
    VGA=("pedestrian", "two wheeler", "car"),
)

def _boxes_to_cxcywh(frame_dets):
    """Convert one frame of post-processed detections to (cx, cy, w, h) tuples.

    Only the first four columns of ``frame_dets`` are read and they are taken
    to be corner coordinates (x1, y1, x2, y2); any further columns (scores,
    class ids, ...) are ignored, so the column count beyond four does not matter.
    ``None`` yields an empty list.
    """
    # NOTE(review): upstream YOLOX post-processing leaves ``None`` for a frame
    # without detections — presumably the project's version does the same.
    if frame_dets is None:
        return []
    corners = frame_dets[:, :4].detach().cpu().numpy()
    return [
        ((x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1)
        for x1, y1, x2, y2 in corners
    ]


def inference(
    data: pl.LightningDataModule,
    model: pl.LightningModule,
    ckpt_path: str,
    show_gt: bool,
    show_pred: bool,
    output_path: str,
    fps: int,
    num_sequence: int,
    dataset_mode: DatasetMode
):
    """Run the model over a dataloader and track its detections frame by frame.

    Args:
        data: data module; ``data.dataset_name`` selects the frame size and
            the label map.
        model: detection module to run.
        ckpt_path: checkpoint whose ``'state_dict'`` entry is loaded into
            ``model``; skipped when falsy.
        show_gt: when true, ground-truth labels are collected per frame.
        show_pred: when true, the model is run and its boxes are fed to the
            tracker.
        output_path: path handed to ``cv2.VideoWriter``.
        fps: frame rate handed to ``cv2.VideoWriter``.
        num_sequence: processing stops once more than this many batches with
            a set "first sample" flag have been seen.
        dataset_mode: compared against ``"train"``, ``"val"`` and ``"test"``.

    Raises:
        ValueError: if ``dataset_mode`` is none of the three values above.

    NOTE: no frame is ever written to the video writer here; the drawing call
    is still a placeholder, so the output file stays without frames.
    """
    # Prepare the video writer.
    size = dataset2size[data.dataset_name]
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(output_path, fourcc, fps, size)

    try:
        # Select the dataloader for the requested mode.
        if dataset_mode == "train":
            data.setup('fit')
            loader = data.train_dataloader()
            model.setup("fit")
        elif dataset_mode == "val":
            data.setup('validate')
            loader = data.val_dataloader()
            model.setup("validate")
        elif dataset_mode == "test":
            data.setup('test')
            loader = data.test_dataloader()
            model.setup("test")
        else:
            raise ValueError(f"Invalid mode: {dataset_mode}")

        num_classes = len(dataset2labelmap[data.dataset_name])
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Prepare the model for prediction.
        if show_pred:
            model.eval().to(device)
            rnn_state = RNNStates()
            padder = InputPadderFromShape(model.in_res_hw)

        # Load the checkpoint.
        if ckpt_path:
            ckpt = torch.load(ckpt_path, map_location=device)
            model.load_state_dict(ckpt['state_dict'])

        # Defined up front so a first batch without a "first sample" flag
        # cannot hit unbound names below.
        tracker = ObjectTracker(max_lost=5, iou_threshold=0.3)
        prev_states = None

        sequence_count = 0
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            for batch in tqdm(loader):
                ev_repr = batch["data"][DataType.EV_REPR]
                labels = batch["data"][DataType.OBJLABELS_SEQ]
                is_first = batch["data"][DataType.IS_FIRST_SAMPLE]

                # A new sequence starts: reset the tracker and recurrent state.
                if is_first.any():
                    sequence_count += 1
                    if sequence_count > num_sequence:
                        break
                    tracker = ObjectTracker(max_lost=5, iou_threshold=0.3)
                    prev_states = None

                seq_len = len(ev_repr)
                for t in range(seq_len):
                    # Per-frame results, defined on every path so the drawing
                    # hook below can rely on them once it is enabled.
                    labels_yolox = None
                    tracked_objs = []

                    # Prepare the input tensor.
                    ev = ev_repr[t].to(torch.float32).to(device)

                    # Ground truth (only when requested).
                    if show_gt:
                        cur_labels, valid_idx = labels[t].get_valid_labels_and_batch_indices()
                        if len(cur_labels) > 0:
                            labels_yolox = ObjectLabels.get_labels_as_batched_tensor(
                                obj_label_list=cur_labels, format_=model.mdl_config.label.format
                            )

                    # Prediction.
                    if show_pred:
                        ev_padded = padder.pad_tensor_ev_repr(ev)
                        if model.mdl.model_type == 'DNN':
                            preds, _ = model.forward(event_tensor=ev_padded)
                        else:  # RNN
                            if prev_states is None:
                                rnn_state.reset(worker_id=0, indices_or_bool_tensor=is_first)
                                prev_states = rnn_state.get_states(worker_id=0)
                            preds, _, states = model.forward(
                                event_tensor=ev_padded, previous_states=prev_states
                            )
                            prev_states = states
                            rnn_state.save_states_and_detach(worker_id=0, states=states)

                        fmt = model.mdl_config.label.format
                        if fmt == 'yolox':
                            dets = postprocess(predictions=preds, num_classes=num_classes,
                                               conf_thre=0.1, nms_thre=0.45)
                        else:
                            dets = postprocess_with_motion(prediction=preds, num_classes=num_classes,
                                                           conf_thre=0.1, nms_thre=0.45)

                        # Convert the boxes of the first batch entry to (cx, cy, w, h).
                        det_list = _boxes_to_cxcywh(dets[0])

                        # Update the tracker and read back [(id, cx, cy, w, h), ...].
                        tracker.update(det_list)
                        tracked_objs = tracker.get_tracked_objects()

                    # TODO: draw each frame here (or buffer it), e.g.
                    # visualize(video_writer, ev, labels_yolox, tracked_objs, data.dataset_name)

                # Per-sequence flushing would go here if frames were buffered.
    finally:
        # Release the writer even when setup or inference raises.
        video_writer.release()
    print("Inference done.")


In [None]:
from omegaconf import OmegaConf

from modules.utils.fetch import fetch_data_module, fetch_model_module
from config.modifier import dynamically_modify_train_config


# Load the experiment configuration from YAML.
yaml_path = "sample.config.yaml"
config = OmegaConf.load(yaml_path)
# The return value is not used here; presumably this adjusts `config` in place — confirm.
dynamically_modify_train_config(config)
## Load the dataset (data module)
data = fetch_data_module(config=config)
## Load the model (module)
module = fetch_model_module(config=config)

# Placeholder: point this at a real checkpoint file before running.
ckpt_path = "path/to/your/checkpoint.ckpt"

# Run prediction + tracking on the first sequence of the test split.
inference(data=data,
          model=module,
          ckpt_path=ckpt_path,
          show_gt=False,
          show_pred=True,
          output_path="./output.mp4",
          fps=10,
          num_sequence=1,
          dataset_mode="test")