In [3]:
import tonic.datasets as datasets
import tonic.transforms as transforms
from tonic import DiskCachedDataset
import torch
from torch.utils.data import DataLoader
import tonic
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML
class CropTo32(object):
    """Crop event frames to a square ``size`` x ``size`` spatial window.

    The crop is applied to the last two axes (presumably H, W — confirm against
    the ToFrame output layout). With the defaults (``size=32``, ``offset=1``) it
    keeps indices 1..32 along each of those axes, i.e. exactly the original
    ``frames[..., 1:33, 1:33]`` slice; on a 34x34 input that strips a one-pixel
    border on every side.

    Args:
        size: side length of the square output window (default 32).
        offset: index of the first row/column kept (default 1).

    Raises:
        ValueError: on a non-positive ``size`` or negative ``offset``, or when
            the input is too small to contain the requested window (instead of
            silently returning a smaller frame).
    """

    def __init__(self, size=32, offset=1):
        if size <= 0 or offset < 0:
            raise ValueError(
                f"size must be positive and offset non-negative, got size={size}, offset={offset}"
            )
        self.size = size
        self.offset = offset

    def __call__(self, frames):
        """Return the ``size`` x ``size`` window of ``frames`` starting at ``offset``.

        Works on any array/tensor exposing ``.shape`` and basic slicing; the
        result is a slice of the input, not a copy.
        """
        end = self.offset + self.size
        height, width = frames.shape[-2:]
        # Plain slicing would silently truncate here; fail loudly instead so the
        # downstream 32x32 assumption is never violated unnoticed.
        if height < end or width < end:
            raise ValueError(
                f"cannot crop a {self.size}x{self.size} window at offset {self.offset} "
                f"from frames with spatial size {height}x{width}"
            )
        return frames[..., self.offset:end, self.offset:end]

class OnlyPositive(object):
    """Keep only channel 1 along axis 1 of a 4-D frame array/tensor.

    A length-1 slice is used so the channel axis survives
    (e.g. ``[T, 2, H, W] -> [T, 1, H, W]``). The class name suggests channel 1
    holds the positive/ON polarity — confirm against tonic's ToFrame ordering.
    """

    # Length-1 slice (not a bare index) so the channel dimension is preserved.
    _KEPT_CHANNEL = slice(1, 2)

    def __call__(self, frames):
        selected = frames[:, self._KEPT_CHANNEL, :, :]
        return selected



# Tunable settings for the frame pipeline and loader.
N_TIME_BINS = 200   # number of time bins ToFrame slices each recording into
BATCH_SIZE = 32
SHUFFLE_SEED = 0    # fixes the shuffle order so Restart & Run All shows the same sample

# Events are binned on the dataset's native sensor grid first; CropTo32 then
# trims the frames to 32x32 and OnlyPositive keeps a single polarity channel.
frame_transform = transforms.Compose([
    transforms.ToFrame(sensor_size=tonic.datasets.NMNIST.sensor_size, n_time_bins=N_TIME_BINS),
    CropTo32(),
    OnlyPositive(),
])

trainset = datasets.NMNIST(
    save_to='./data/nmnist',
    train=True,
    transform=frame_transform,
)

# A seeded generator makes the shuffled order reproducible across kernel
# restarts; without it the visualised sample changes on every run.
# PadTensors collates the per-sample frames into one batch tensor (with a fixed
# N_TIME_BINS all samples should already share the same length).
train_loader = DataLoader(
    trainset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
    collate_fn=tonic.collation.PadTensors(),
)

In [4]:
# Pull a single batch from the loader: frames is [B, T, 1, 32, 32].
frames, targets = next(iter(train_loader))

# Visualise the first sample; dropping the singleton channel axis gives [T, 32, 32].
sample_idx = 0
frames_sample = frames[sample_idx].squeeze(1).cpu().numpy()

# --- Matplotlib animation: figure and initial frame ---
fig, ax = plt.subplots(figsize=(4, 4))
im = ax.imshow(
    frames_sample[0],
    cmap='gray',
    vmin=0,
    vmax=frames_sample.max(),  # shared colour scale across all frames of the sample
)
ax.set_title(f"Label: {targets[sample_idx].item()}")
ax.axis('off')

def animate(i):
    """Per-frame callback for FuncAnimation.

    Puts frame ``i`` of ``frames_sample`` into the image artist, rewrites the
    axes title with the frame index and the sample's label, and returns the
    image artist in a list.
    """
    last_index = len(frames_sample) - 1
    label = targets[sample_idx].item()
    im.set_array(frames_sample[i])
    ax.set_title(f"Frame {i} / {last_index} | Label: {label}")
    return [im]

# blit=False: animate() also rewrites the axes title, which is not among the
# artists it returns, so a blitting backend would keep showing a stale title.
# interval=50 ms between frames (~20 fps).
anim = animation.FuncAnimation(
    fig,
    animate,
    frames=frames_sample.shape[0],
    interval=50,
    blit=False,
)

plt.close(fig)  # Prevent the duplicate static figure from being rendered below the animation

HTML(anim.to_jshtml())