In [1]:
import sys
sys.path.append('..')

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from config.modifier import dynamically_modify_train_config
from modules.utils.fetch import fetch_data_module, fetch_model_module
from omegaconf import OmegaConf, DictConfig

def visualize_image(image, title=None):
    """
    Display a single image with matplotlib and hide the axes.

    Parameters:
        image: numpy array or torch.Tensor.
            - 3-D channel-first input with 3 channels, (3, H, W), is transposed
              to (H, W, 3).
            - 3-D input with a leading singleton channel, (1, H, W), is squeezed
              to (H, W) and shown through imshow's colormap. This only happens
              when the last axis is not 1, 3 or 4, i.e. only where imshow would
              otherwise reject the shape.
            - 2-D (H, W) arrays and channel-last arrays are shown as-is.
        title: optional title drawn above the image; skipped when falsy.
    """
    # Convert a torch.Tensor to a numpy array (detached from autograd, moved to CPU).
    if isinstance(image, torch.Tensor):
        image = image.detach().cpu().numpy()
    image = np.asarray(image)

    # Only a 3-D array can be channel-first. The ndim check keeps a 2-D image
    # whose height happens to be 3 away from a 3-axis transpose (which raises).
    if image.ndim == 3 and image.shape[0] == 3:
        image = np.transpose(image, (1, 2, 0))
    elif image.ndim == 3 and image.shape[0] == 1 and image.shape[-1] not in (1, 3, 4):
        # imshow rejects (1, H, W); drop the singleton channel so it renders as 2-D.
        image = image[0]

    # Clip RGB(A) data to the range imshow accepts: [0, 1] for floats, [0, 255]
    # for integers. imshow would clip anyway, but it emits a warning when it
    # does. 2-D data is left untouched because imshow normalizes it via a colormap.
    if image.ndim == 3 and image.shape[-1] in (3, 4):
        if np.issubdtype(image.dtype, np.floating):
            image = np.clip(image, 0.0, 1.0)
        elif np.issubdtype(image.dtype, np.integer):
            image = np.clip(image, 0, 255)

    # Draw on the current pyplot figure.
    plt.imshow(image)
    if title:
        plt.title(title)
    plt.axis("off")
    plt.show()

def modify_config(
    config: "DictConfig",
    dt: int = 50,
    dataset_path: str = "../datasets/pre_gen4",
    sequence_length: int = 10,
    train_batch_size: int = 1,
    eval_batch_size: int = 1,
) -> "DictConfig":
    """Apply this notebook's dataset overrides to `config` in place and return it.

    Parameters:
        config: config object whose `dataset` section is overwritten.
        dt: value embedded in the event-representation name
            ``event_frame_dt={dt}``.
        dataset_path: value written to ``config.dataset.path``.
        sequence_length: value written to ``config.dataset.sequence_length``.
        train_batch_size: value written to ``config.dataset.batch_size.train``.
        eval_batch_size: value written to ``config.dataset.batch_size.eval``.

    Returns:
        The same `config` object (mutated). Both ``modify_config(config)`` and
        ``config = modify_config(config)`` therefore work.
    """
    dataset_cfg = config.dataset
    dataset_cfg.path = dataset_path
    dataset_cfg.ev_repr_name = f"event_frame_dt={dt}"
    dataset_cfg.sequence_length = sequence_length
    dataset_cfg.batch_size.train = train_batch_size
    dataset_cfg.batch_size.eval = eval_batch_size
    # Return the config: callers rebind it (`config = modify_config(config)`),
    # and an implicit `None` here would silently replace their config with None.
    return config
   

In [3]:
# Experiment config YAML loaded with OmegaConf in the next cell. The path is
# relative, so it is resolved against the kernel's current working directory.
yaml_path = "param.yaml"

In [None]:


# Load the base experiment config, apply the notebook's dataset overrides,
# then let the project's modifier fill in derived fields.
config = OmegaConf.load(yaml_path)

# modify_config mutates `config` in place (attribute assignment), so call it
# for its side effect instead of rebinding `config` to its return value.
modify_config(config)

# NOTE(review): dynamically_modify_train_config is defined outside this
# notebook. In RVT-style code it mutates the config in place and returns None
# (confirm). Only rebind when it actually hands back a config, so this cell
# works either way.
modified_config = dynamically_modify_train_config(config)
if modified_config is not None:
    config = modified_config

data_module = fetch_data_module(config=config)
data_module.setup("test")

model_module = fetch_model_module(config=config)
model_module.setup("test")

In [5]:
# Test-split dataloader from the data module that was set up with stage "test" above.
test_dataloader = data_module.test_dataloader()

In [6]:
from data.utils.types import DataType
from utils.padding import InputPadderFromShape

# Take the first batch from the test loader and keep its "data" payload,
# which is indexed by DataType keys.
data_iter = iter(test_dataloader)
data = next(data_iter)["data"]

# Event representations, one entry per time step.
ev_tensor_sequence = data[DataType.EV_REPR]

# NOTE(review): desired_hw is taken from dims 2:4 of the first event tensor
# itself (presumably H, W of a (B, C, H, W) tensor — confirm). The padder
# therefore targets the size the tensors already have and likely pads
# nothing. If the model expects a padded input resolution, derive desired_hw
# from the model/config instead.
input_padder = InputPadderFromShape(desired_hw=ev_tensor_sequence[0].shape[2:4])

# Remaining fields of the batch; the token mask may be absent from the payload.
sparse_obj_labels = data[DataType.OBJLABELS_SEQ]
is_first_sample = data[DataType.IS_FIRST_SAMPLE]
token_mask_sequence = data.get(DataType.TOKEN_MASK)

sequence_len = len(ev_tensor_sequence)


In [None]:
# Render each time step of the first test batch after running it through the padder.
for tidx, ev_tensors in enumerate(ev_tensor_sequence):
    ev_tensors = input_padder.pad_tensor_ev_repr(ev_tensors)
    # squeeze(0) drops the leading axis (presumably batch size 1 — confirm).
    visualize_image(ev_tensors.squeeze(0), title=f"t={tidx}")