# Test of SNN for Gesture Detection

## Imports

In [None]:
# SNN
import tonic

# Visualization
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# Misc
import numpy as np
import numpy.lib.recfunctions as rf

# Core
import os

## Load Dataset Using Tonic

In [None]:
# Keep the dataset under ./data, relative to the current working directory
# (not relative to this notebook file).
data_path = os.path.join(os.getcwd(), "data")
os.makedirs(data_path, exist_ok=True)

# Download (if not already present) and load both splits of DVS Gesture via Tonic:
# the training split first, then the test split.
train, test = (
    tonic.datasets.DVSGesture(save_to=data_path, train=is_train)
    for is_train in (True, False)
)

## Event Extraction

Each sample in the dataset contains a series of `events`. An event consists of an XY-coordinate whose intensity either increased or decreased, as indicated by its polarity. The timestamps give the time at which each event occurred, in microseconds (µs).

### Labels
The labels in the dataset's CSV mapping (listed below) are 1-indexed, whereas Tonic returns 0-indexed labels. Beware: add 1 to a Tonic label to look it up in this table.

- **1**: hand_clapping
- **2**: right_hand_wave
- **3**: left_hand_wave
- **4**: right_hand_clockwise
- **5**: right_hand_counter_clockwise
- **6**: left_hand_clockwise
- **7**: left_hand_counter_clockwise
- **8**: forearm_roll_forward and forearm_roll_backward (both roll directions share label 8)
- **9**: drums
- **10**: guitar
- **11**: random_other_gestures

In [None]:
events, label = train[0]

# Pull the individual fields out of the structured event array.
timestamps, x_coords, y_coords, polarities = (events[field] for field in ("t", "x", "y", "p"))

# Peek at the first few events of the sample.
cut = 5
for caption, column in (
    ("Timestamps:", timestamps),
    ("X-coordinates:", x_coords),
    ("Y-coordinates:", y_coords),
    ("Polarities:", polarities),
):
    print(caption, column[:cut])
print("Label:", label + 1)  # Tonic labels are 0-indexed; the label table is 1-indexed

## Accumulate Events and Plot

In [None]:
# DVS Gesture is recorded on a 128x128 sensor; the images below are indexed [y, x].
sensor_size = (128, 128)

# Accumulate at most this many events (fewer if the sample is shorter, which
# previously raised an IndexError).
num_events = min(10000, len(x_coords))

xs = x_coords[:num_events]
ys = y_coords[:num_events]
# Positive polarity = ON event; this works for both 0/1 and -1/+1 encodings.
on_mask = polarities[:num_events] > 0

# Create empty images for ON and OFF events and count events per pixel.
# np.add.at is used instead of fancy-indexed "+=" so that repeated (y, x)
# pairs are each counted rather than collapsed into a single increment.
on_event_image = np.zeros(sensor_size)
off_event_image = np.zeros(sensor_size)
np.add.at(on_event_image, (ys[on_mask], xs[on_mask]), 1)
np.add.at(off_event_image, (ys[~on_mask], xs[~on_mask]), 1)

# Plot the accumulated event frames
fig, ax = plt.subplots(1, 2, figsize=(10, 5))

ax[0].imshow(on_event_image, cmap="Reds")
ax[0].set_title("ON Events")

ax[1].imshow(off_event_image, cmap="Blues")
ax[1].set_title("OFF Events")

plt.show()

## Animation
A quick-and-dirty way to animate the events.

Events are binned by time, one PNG frame per bin is saved in `./frames`, and the frames are then stitched into an inline HTML5 video.

In [None]:
def to_us(seconds):
    """
    Convert a duration given in seconds into microseconds.

    Multiplies by the float ``1e6``, e.g. ``to_us(0.05) == 50000.0``.
    """
    microseconds_per_second = 1e6
    return seconds * microseconds_per_second

def create_frames_dir() -> str:
    """
    Return the path of ``<cwd>/frames``, making sure it exists and holds no files.

    If the directory is missing it is created. If it already exists, every
    regular file directly inside it is deleted. Subdirectories are left untouched.
    """
    frames_dir = os.path.join(os.getcwd(), "frames")
    if not os.path.exists(frames_dir):
        os.makedirs(frames_dir, exist_ok=True)
        return frames_dir

    # Directory already exists: wipe the frames left over from a previous run.
    for entry in os.listdir(frames_dir):
        candidate = os.path.join(frames_dir, entry)
        if os.path.isfile(candidate):
            os.unlink(candidate)
    return frames_dir

def time_bin_frames(events, sensor_size, time_bin) -> None:
    """
    Accumulate events into fixed-width time bins and save one PNG frame per bin.

    The output directory comes from ``create_frames_dir()``, so frames from a
    previous run are deleted first. Each frame is named ``frame_{end}ms.png``,
    where ``end`` is the bin's upper edge in milliseconds. The frame shows the
    per-pixel event count (both polarities combined), rendered with the
    ``"Reds"`` colormap.

    Bin edges lie on multiples of ``time_bin``, so the first edge is ``time_bin``
    when the recording starts at t=0. An event at time ``t`` belongs to the bin
    ``[end - time_bin, end)``.

    Bins that contain no events are still written as blank frames, so the
    sequence stays evenly spaced in time. The final, partially-filled bin is
    written as well.

    Args:
        events: Structured event array with fields ``'t'``, ``'x'`` and ``'y'``.
            Polarity is not needed because both polarities are merged.
            Timestamps are assumed to be sorted ascending and in microseconds,
            the same unit as ``time_bin``.
        sensor_size: ``(width, height)`` of the sensor. Extra trailing entries,
            such as the polarity count in Tonic's ``(W, H, P)``, are ignored.
        time_bin: The time bin in microseconds.

    Raises:
        ValueError: If ``time_bin`` is not positive (the loop would never advance).
    """
    if time_bin <= 0:
        raise ValueError(f"time_bin must be positive, got {time_bin}")

    width, height = sensor_size[0], sensor_size[1]
    timestamps = np.asarray(events['t'])
    # Coordinates are used as array indices, so make sure they are integers.
    x_coords = np.asarray(events['x']).astype(np.intp)
    y_coords = np.asarray(events['y']).astype(np.intp)

    frame_path = create_frames_dir()
    if timestamps.size == 0:
        return

    # Upper edge of the bin holding the first event, aligned to the time_bin grid.
    # This avoids emitting one frame per event when timestamps do not start near 0.
    bin_end = float((timestamps[0] // time_bin + 1) * time_bin)
    start = 0
    while start < timestamps.size:
        # First event at or after bin_end; events [start, stop) fall in this bin.
        # max() guarantees forward progress even if timestamps are not sorted.
        stop = max(start, int(np.searchsorted(timestamps, bin_end, side="left")))
        counts = np.zeros((height, width))
        # np.add.at counts repeated (y, x) pairs correctly, unlike fancy-indexed "+=".
        np.add.at(counts, (y_coords[start:stop], x_coords[start:stop]), 1)
        plt.imsave(os.path.join(frame_path, f"frame_{bin_end*1e-3}ms.png"), counts, cmap="Reds")
        start = stop
        bin_end += time_bin

def frames_to_video(time_bin, frame_dir="./frames") -> HTML:
    """
    Assemble the PNG frames in ``frame_dir`` into an inline HTML5 video.

    Args:
        time_bin: Duration each frame represents, in microseconds. It is
            converted to milliseconds and used as the animation frame interval.
        frame_dir: Directory holding the ``frame_<ms>ms.png`` files written by
            ``time_bin_frames``. Defaults to ``./frames``. This argument was
            previously ignored.

    Returns:
        An ``IPython.display.HTML`` object embedding the video.

    Raises:
        FileNotFoundError: If ``frame_dir`` contains no ``.png`` files.
    """
    def frame_time_key(name):
        # Frames are named "frame_<ms>ms.png". Sort by the numeric time, because a
        # plain string sort would put "frame_100.0ms" before "frame_50.0ms".
        stem = name[:-len(".png")]
        if stem.startswith("frame_") and stem.endswith("ms"):
            try:
                return (0, float(stem[len("frame_"):-len("ms")]))
            except ValueError:
                pass
        # Unrecognised names sort after the real frames, alphabetically.
        return (1, name)

    png_names = [name for name in os.listdir(frame_dir) if name.endswith(".png")]
    images = [os.path.join(frame_dir, name) for name in sorted(png_names, key=frame_time_key)]
    if not images:
        raise FileNotFoundError(f"No .png frames found in {frame_dir!r}")

    fig, ax = plt.subplots()
    ax.axis("off")
    im = ax.imshow(plt.imread(images[0]), animated=True)

    def update(frame):
        im.set_array(plt.imread(images[frame]))
        return [im]

    # FuncAnimation's interval is in milliseconds; time_bin is in microseconds.
    anim = animation.FuncAnimation(fig, update, frames=len(images), interval=time_bin * 1e-3, blit=True)
    try:
        video = HTML(anim.to_html5_video())
    finally:
        # Close this figure specifically so no stray static image is rendered.
        plt.close(fig)
    return video

In [None]:
# Render the first raw training sample as 50 ms frames and play them back inline.
bin_width_us = to_us(0.05)
events, label = train[0]
time_bin_frames(events, (128, 128), bin_width_us)
frames_to_video(bin_width_us)

## Data Processing for Training

In [None]:
w, h = 32, 32
n_frames = 32  # Reserved for a ToFrame transform; not used yet.
debug = False  # When True, skip the on-disk cache and transform samples on every access.

# Denoise: removes outlier events whose surrounding pixels were inactive for 10 ms
# Downsample: downsamples the events from the full sensor to w x h (32x32)
# ToFrame (not applied yet): would convert the events to n_frames frames per trial.
#   It is left out here because time_bin_frames below expects raw events.
transforms = tonic.transforms.Compose([
    tonic.transforms.Denoise(filter_time=10000),
    tonic.transforms.Downsample(sensor_size=tonic.datasets.DVSGesture.sensor_size, target_size=(w, h)),
])

data_path = os.path.join(os.getcwd(), "data")
os.makedirs(data_path, exist_ok=True)

train2 = tonic.datasets.DVSGesture(save_to=data_path, transform=transforms, train=True)
test2 = tonic.datasets.DVSGesture(save_to=data_path, transform=transforms, train=False)

# DiskCachedDataset names its cache files by sample index. The two splits
# therefore need separate directories: with a shared one, test sample i would
# be served from the cached train sample i (or vice versa).
cache_path = os.path.join(data_path, "cache")
train_cache_path = os.path.join(cache_path, "train")
test_cache_path = os.path.join(cache_path, "test")
os.makedirs(train_cache_path, exist_ok=True)
os.makedirs(test_cache_path, exist_ok=True)
cached_train = train2 if debug else tonic.DiskCachedDataset(train2, cache_path=train_cache_path)
cached_test = test2 if debug else tonic.DiskCachedDataset(test2, cache_path=test_cache_path)

events, label = cached_train[0]
time_bin_frames(events, (w, h), to_us(0.05))
frames_to_video(to_us(0.05))