In [None]:
from spikingjelly.datasets.cifar10_dvs import CIFAR10DVS
from spikingjelly.datasets.dvs128_gesture import DVS128Gesture

# Report, for each dataset, whether spikingjelly can download it automatically
# and its (resource, url, md5) entries. The header ends in a real newline
# escape ('\n'); it was previously written as the literal characters '/n'.
for display_name, dataset_cls in (('CIFAR10-DVS', CIFAR10DVS), ('DVS128Gesture', DVS128Gesture)):
    print(display_name, 'downloadable', dataset_cls.downloadable())
    print('resource, url, md5\n', dataset_cls.resource_url_md5())

In [None]:
from spikingjelly.datasets.dvs128_gesture import DVS128Gesture

# Frame-based DVS128 Gesture training split: every sample is integrated into
# 20 frames using split_by='number'.
root_dir = '../datasets/DVS128Gesture'
train_set = DVS128Gesture(
    root_dir,
    train=True,
    data_type='frame',
    frames_number=20,
    split_by='number',
)

In [None]:
# Display the first frame of one sample from each class.
from spikingjelly.datasets import play_frame
import random

import matplotlib.pyplot as plt

NUM_CLASSES = 11  # DVS128Gesture has 11 classes

# Single pass over the dataset. A separate pre-pass that loaded samples only to
# collect an unused label set has been removed; each train_set[idx] presumably
# reads a sample from disk, so that pass doubled the cost.
# NOTE(review): if train_set exposes `.targets` (torchvision DatasetFolder),
# the first index of each class could be found without loading any frames.
shown = set()
for idx in range(len(train_set)):
    frame, label = train_set[idx]
    if label not in shown:
        print(f'类别: {label}')
        # Assumes frame is [T, C, H, W]: show channel 0 of the first time step.
        fig, ax = plt.subplots()
        ax.imshow(frame[0, 0])
        # ASCII title: matplotlib's default font has no CJK glyphs.
        ax.set_title(f'class: {label}')
        plt.show()
        shown.add(label)
    if len(shown) == NUM_CLASSES:
        break
labels = set(shown)  # same label set the removed pre-pass used to build

In [None]:
# Fixed-duration integration: each frame accumulates the events of one
# FRAME_DURATION window, so the number of frames T varies per sample (hence
# pad_sequence_collate below).
# NOTE(review): 1_000_000 is 1 s per frame if event timestamps are in
# microseconds (expected for DVS128 Gesture — confirm); 10 ms would be 10_000.
import torch
from torch.utils.data import DataLoader
from spikingjelly.datasets import pad_sequence_collate, padded_sequence_mask, dvs128_gesture

FRAME_DURATION = 1_000_000
root = '../datasets/DVS128Gesture'
train_set = dvs128_gesture.DVS128Gesture(root, data_type='frame', duration=FRAME_DURATION, train=True)

# Per-sample shapes: T differs between samples.
for i in range(5):
    x, y = train_set[i]
    print(f'x[{i}].shape=[T, C, H, W]={x.shape}')

train_data_loader = DataLoader(train_set, collate_fn=pad_sequence_collate, batch_size=5)
for x, y, x_len in train_data_loader:
    print(f'x.shape=[N, T, C, H, W]={tuple(x.shape)}')
    print(f'x_len={x_len}')  # presumably the unpadded T of each sample
    mask = padded_sequence_mask(x_len)  # mask.shape = [T, N]
    print(f'mask=\n{mask.t().int()}')  # transposed to [N, T] for display
    break  # inspect only the first batch

In [None]:
# Simple spatio-temporal denoising filter (KD-tree accelerated)
%pip install scipy
from typing import Dict
import numpy as np
import spikingjelly.datasets as sjds
from scipy.spatial import cKDTree

def fast_filter_events(events: Dict, min_neighbors=1, time_window=1000, radius=1.0, p_norm=2.0):
    """
    Spatio-temporal denoising filter accelerated with a KD-tree.

    Each event is mapped to the point ``(t // time_window, x, y)``. An event is
    kept when the ball of ``radius`` around it (Minkowski ``p_norm`` distance)
    contains more than ``min_neighbors`` events, the event itself included.

    :param events: dict with equal-length 1-D arrays under 't', 'x', 'y', 'p'
    :param min_neighbors: an event survives when its neighbour count (self
        included) is strictly greater than this, i.e. it needs at least
        ``min_neighbors`` other events nearby
    :param time_window: width of one temporal bin, in the units of ``events['t']``
    :param radius: neighbourhood radius in (bin, pixel, pixel) units
    :param p_norm: Minkowski norm for the distance. The default 2 (Euclidean)
        with radius 1 reaches only the 4 edge-adjacent pixels in the same bin
        and the same pixel in adjacent bins; ``np.inf`` also includes diagonals.
    :return: dict with the surviving events' 't', 'x', 'y', 'p' arrays, in
        their original order
    """
    t, x, y, p = (np.asarray(events[key]) for key in ('t', 'x', 'y', 'p'))
    if len(t) == 0:
        # Nothing to filter; do not build a KD-tree over zero points.
        return {'t': t, 'x': x, 'y': y, 'p': p}
    features = np.stack([t // time_window, x, y], axis=1)
    tree = cKDTree(features)
    # Number of events within the ball around each event (self included).
    counts = tree.query_ball_point(features, r=radius, p=p_norm, return_length=True)
    keep = counts > min_neighbors
    return {'t': t[keep], 'x': x[keep], 'y': y[keep], 'p': p[keep]}

def integrate_events_to_2_frames_denoised(events: Dict, H: int, W: int):
    """
    Denoise ``events`` with ``fast_filter_events`` and integrate the survivors
    into two frames split at a random event index.

    Frame 0 is integrated from event indices ``(0, split)`` and frame 1 from
    ``(split, n)``, where ``split`` comes from NumPy's global RNG.

    :return: float array of shape [2, 2, H, W]; all zeros when filtering
        removes every event
    """
    filtered = fast_filter_events(events, min_neighbors=1, time_window=1000)
    n_events = len(filtered['t'])
    if n_events == 0:
        # Every event was discarded as noise: emit blank frames.
        return np.zeros([2, 2, H, W])
    split_at = np.random.randint(low=0, high=n_events)
    frames = np.zeros([2, 2, H, W])
    xs, ys, ps = filtered['x'], filtered['y'], filtered['p']
    for slot, (start, end) in enumerate(((0, split_at), (split_at, n_events))):
        frames[slot] = sjds.integrate_events_segment_to_frame(xs, ys, ps, H, W, start, end)
    return frames

In [None]:
# Spatio-temporal denoising filter that only counts same-polarity neighbours
from typing import Dict
import numpy as np
import spikingjelly.datasets as sjds
from scipy.spatial import cKDTree

def polarity_consistent_filter_events(events: Dict, min_neighbors=1, time_window=1000, radius=1.0, p_norm=2.0):
    """
    Spatio-temporal denoising filter that only counts same-polarity neighbours.

    Each event is mapped to ``(t // time_window, x, y)``. An event is kept when
    the ball of ``radius`` around it (Minkowski ``p_norm`` distance) contains
    more than ``min_neighbors`` events of the same polarity, itself included.

    Counting same-polarity neighbours in a tree of all events equals counting
    all neighbours in a tree built from that polarity's events alone, so one
    KD-tree per polarity value replaces a per-event Python loop.

    :param events: dict with equal-length 1-D arrays under 't', 'x', 'y', 'p'
    :param min_neighbors: an event survives when its same-polarity neighbour
        count (self included) is strictly greater than this
    :param time_window: width of one temporal bin, in the units of ``events['t']``
    :param radius: neighbourhood radius in (bin, pixel, pixel) units
    :param p_norm: Minkowski norm for the distance (``np.inf`` also reaches
        diagonal pixels at radius 1; the Euclidean default does not)
    :return: dict with the surviving events' 't', 'x', 'y', 'p' arrays, in
        their original order; the arrays are empty when nothing survives
    """
    t, x, y, p = (np.asarray(events[key]) for key in ('t', 'x', 'y', 'p'))
    features = np.stack([t // time_window, x, y], axis=1)
    # Boolean mask rather than a list of indices: an empty index list becomes a
    # float array, which NumPy refuses as an index (IndexError).
    keep = np.zeros(len(t), dtype=bool)
    for polarity in np.unique(p):
        members = np.flatnonzero(p == polarity)
        subset = features[members]
        counts = cKDTree(subset).query_ball_point(subset, r=radius, p=p_norm, return_length=True)
        keep[members] = counts > min_neighbors
    return {'t': t[keep], 'x': x[keep], 'y': y[keep], 'p': p[keep]}

def integrate_events_to_2_frames_denoised2(events: Dict, H: int, W: int):
    """
    Integrate ``events`` into two frames after polarity-consistent denoising.

    Events surviving ``polarity_consistent_filter_events`` are split at a random
    event index: frame 0 is integrated from indices ``(0, split)`` and frame 1
    from ``(split, n)``.

    :return: float array of shape [2, 2, H, W]; all zeros when no event survives
    """
    kept = polarity_consistent_filter_events(events, min_neighbors=1, time_window=1000)
    total = len(kept['t'])
    if not total:
        return np.zeros([2, 2, H, W])
    split = np.random.randint(low=0, high=total)
    frames = np.zeros([2, 2, H, W])
    bounds = (0, split, total)
    for k in range(2):
        frames[k] = sjds.integrate_events_segment_to_frame(
            kept['x'], kept['y'], kept['p'], H, W, bounds[k], bounds[k + 1])
    return frames

In [None]:
# Training split whose frames come from the polarity-consistent denoising integrator.
# NOTE(review): spikingjelly presumably generates these frames once and caches them
# on disk, so the random split inside the integrator is fixed per sample — confirm.
train_set = DVS128Gesture(root_dir, train=True, data_type='frame',
                          custom_integrate_function=integrate_events_to_2_frames_denoised2)

In [None]:
# Custom integration: split each sample's events into 2 frames at a random event index
from typing import Dict
import spikingjelly.datasets as sjds
import numpy as np
def integrate_events_to_2_frames_randomly(events: Dict, H: int, W: int):
    """
    Integrate ``events`` into two frames split at a random event index.

    Frame 0 is integrated from event indices ``(0, index_split)`` and frame 1
    from ``(index_split, n)``, where ``index_split`` is drawn from ``[0, n)``
    with NumPy's global RNG.

    :param events: dict of event arrays under 't', 'x', 'y', 'p'
    :param H: size of the second-to-last frame dimension
    :param W: size of the last frame dimension
    :return: float array of shape [2, 2, H, W]; all zeros when ``events`` is
        empty (``np.random.randint`` raises for ``high=0``)
    """
    n_events = len(events['t'])
    frames = np.zeros([2, 2, H, W])
    if n_events == 0:
        return frames
    index_split = np.random.randint(low=0, high=n_events)
    x, y, p = events['x'], events['y'], events['p']
    frames[0] = sjds.integrate_events_segment_to_frame(x, y, p, H, W, 0, index_split)
    frames[1] = sjds.integrate_events_segment_to_frame(x, y, p, H, W, index_split, n_events)
    return frames
# Training split integrated with the random two-frame split defined above.
train_set = DVS128Gesture(
    root_dir,
    train=True,
    data_type='frame',
    custom_integrate_function=integrate_events_to_2_frames_randomly,
)

In [None]:
from spikingjelly.datasets import play_frame
SAMPLE_INDEX = 500  # arbitrary training sample to visualize
frame, label = train_set[SAMPLE_INDEX]
play_frame(frame)  # spikingjelly helper that plays the sample's frames

In [None]:
from spikingjelly.datasets import play_frame
import random

import matplotlib.pyplot as plt

NUM_CLASSES = 11  # DVS128Gesture has 11 classes

# Display the first frame of one sample per class in a single pass. A former
# pre-pass loaded samples only to collect an unused label set; each
# train_set[idx] presumably reads a sample from disk, so that pass doubled the cost.
# NOTE(review): if train_set exposes `.targets` (torchvision DatasetFolder),
# the first index of each class could be found without loading any frames.
shown = set()
for idx in range(len(train_set)):
    frame, label = train_set[idx]
    if label not in shown:
        print(f'类别: {label}')
        # Assumes frame is [T, C, H, W]: show channel 0 of the first time step.
        fig, ax = plt.subplots()
        ax.imshow(frame[0, 0])
        # ASCII title: matplotlib's default font has no CJK glyphs.
        ax.set_title(f'class: {label}')
        plt.show()
        shown.add(label)
    if len(shown) == NUM_CLASSES:
        break
labels = set(shown)  # same label set the removed pre-pass used to build