In [None]:
!pip install --quiet snntorch tonic

# Importar librerías

In [None]:
import tonic
import matplotlib.pyplot as plt
from IPython.display import HTML
import torch
import torch.nn as nn
import snntorch as snn
from snntorch import surrogate
from snntorch import functional as SF
from snntorch import utils
import time

# Eventos a Frames

In [None]:
def to_frames(events, n_time_bins=100, sensor_size=None):
    """Bin a raw DVSGesture event stream into a fixed number of frames.

    Args:
        events: raw events of one sample, as returned by ``tonic.datasets.DVSGesture``.
        n_time_bins: number of temporal bins (frames) to produce. Defaults to 100.
        sensor_size: sensor geometry passed to ``tonic.transforms.ToFrame``.
            Defaults to the full DVSGesture sensor size.

    Returns:
        The output of ``tonic.transforms.ToFrame`` for ``events`` (presumably an
        array with one frame per time bin along the first axis — see tonic docs).
    """
    if sensor_size is None:
        sensor_size = tonic.datasets.DVSGesture.sensor_size
    frame_transform = tonic.transforms.ToFrame(
        sensor_size=sensor_size,
        n_time_bins=n_time_bins)
    return frame_transform(events)

# Cargar Dataset

In [None]:
# Raw (untransformed) train/test splits from the pre-built Kaggle dataset.
dataset_path = '/kaggle/input/create-dvs128gesture-tonic-dataset'
train, test = (tonic.datasets.DVSGesture(save_to=dataset_path, train=is_train)
               for is_train in (True, False))

# Preview one training sample: bin its events into frames and animate them.
sample_events, label = train[2]
frames = to_frames(sample_events)
ani = tonic.utils.plot_animation(frames)
HTML(ani.to_jshtml())

# Aplicar transformaciones al dataset

In [None]:
import os
import tempfile

# Spatial resolution after downsampling.
w, h = 32, 32
# Temporal bins per sample (an earlier experiment used 100).
n_frames = 32

debug = False

# Preprocessing applied to every raw event stream:
#  1. Denoise: drop events with no neighbouring event within `filter_time`
#     (presumably microseconds, tonic's time unit for this dataset — TODO confirm).
#  2. Downsample from the full DVSGesture sensor to (w, h).
#  3. Bin the events into `n_frames` frames.
transforms = tonic.transforms.Compose([
    tonic.transforms.Denoise(filter_time=10000),
    tonic.transforms.Downsample(sensor_size=tonic.datasets.DVSGesture.sensor_size, target_size=(w, h)),
    tonic.transforms.ToFrame(sensor_size=(w, h, 2), n_time_bins=n_frames),
])

train2 = tonic.datasets.DVSGesture(save_to=dataset_path, transform=transforms, train=True)
test2 = tonic.datasets.DVSGesture(save_to=dataset_path, transform=transforms, train=False)

# Cache transformed samples under the OS temp dir (e.g. /tmp) instead of a
# hard-coded root-level '/temp' that needs root privileges to create.
cache_root = os.path.join(tempfile.gettempdir(), "dvsgesture")
cached_train = tonic.DiskCachedDataset(train2, cache_path=os.path.join(cache_root, "train"))
cached_test = tonic.DiskCachedDataset(test2, cache_path=os.path.join(cache_root, "test"))

# Preview the same sample after preprocessing.
frames, label = train2[2]
ani = tonic.utils.plot_animation(frames)
HTML(ani.to_jshtml())

# Configuración de GPU

In [None]:
# Use CUDA when available; otherwise run everything on the CPU.
cuda_ok = torch.cuda.is_available()
device = torch.device("cuda" if cuda_ok else "cpu")
if cuda_ok:
    print(f"GPUs Available: {torch.cuda.device_count()}")

# Arquitectura de Red

In [None]:
# Surrogate gradient for the non-differentiable spike (alternative tried: surrogate.atan()).
grad = snn.surrogate.fast_sigmoid(slope=25)
# Membrane decay factor shared by every LIF layer.
beta = 0.5
# DVSGesture classes -> one output neuron each.
num_classes = 11

# Seed the global torch RNG so weight initialisation is reproducible across runs.
torch.manual_seed(42)


def _conv_pool_out(size, kernel_size=5, pool_size=2):
    """Spatial size after an unpadded stride-1 conv followed by max-pooling."""
    return (size - kernel_size + 1) // pool_size


# Features reaching the linear layer: 32 channels x spatial map left after two
# conv+pool stages (32 * 5 * 5 = 800 for the default 32x32 input). Derived from
# (w, h) so changing the downsampling size does not break the Linear layer.
flat_features = 32 * _conv_pool_out(_conv_pool_out(h)) * _conv_pool_out(_conv_pool_out(w))

net = nn.Sequential(
    nn.Conv2d(2, 12, 5),  # in_channels, out_channels, kernel_size
    nn.MaxPool2d(2),
    snn.Leaky(beta=beta, spike_grad=grad, init_hidden=True),
    nn.Conv2d(12, 32, 5),
    nn.MaxPool2d(2),
    snn.Leaky(beta=beta, spike_grad=grad, init_hidden=True),
    nn.Flatten(),
    nn.Linear(flat_features, num_classes),
    # output=True: the last layer returns (spikes, membrane) at each step.
    snn.Leaky(beta=beta, spike_grad=grad, init_hidden=True, output=True)
).to(device)

def forward_pass(net, data):
    """Run ``net`` step by step over ``data`` and collect the output spikes.

    The network's hidden state is reset first, so every call starts fresh.

    Args:
        net: network whose forward call returns ``(spikes, membrane)``.
        data: tensor whose first dimension indexes time steps.

    Returns:
        Output spikes stacked along a new leading time dimension.
    """
    snn.utils.reset(net)
    spikes = []
    for frame in data:
        spk, _mem = net(frame)
        spikes.append(spk)
    return torch.stack(spikes)

# Optimiser and loss used by the training loop.
optimizer = torch.optim.Adam(
    net.parameters(),
    lr=0.002,
    betas=(0.9, 0.999),
)
# Spike-count MSE: the correct class is pushed towards firing on 80% of the
# time steps, every other class towards 20%.
loss_fn = SF.mse_count_loss(correct_rate=0.8, incorrect_rate=0.2)

# Histories filled by the training loop.
loss_hist, acc_hist, test_acc_hist = [], [], []

# Validación del modelo

In [None]:
def validate_model():
    correct, total = 0, 0  
    for batch, (data, targets) in enumerate(iter(test_loader)): 
        data, targets = data.to(device), targets.to(device)
        spk_rec = forward_pass(net, data)         
        correct += SF.accuracy_rate(spk_rec, targets) * data.shape[0]
        total += data.shape[0]
    return correct/total

In [None]:
train_loader = torch.utils.data.DataLoader(cached_train, batch_size=64, shuffle=True, drop_last=True, 
                                           collate_fn=tonic.collation.PadTensors(batch_first=False))
test_loader = torch.utils.data.DataLoader(cached_test, batch_size=32, shuffle=False, drop_last=True, 
                                          collate_fn=tonic.collation.PadTensors(batch_first=False))

In [None]:
# Training loop: fixed number of epochs, periodic test-set evaluation and
# checkpointing of the best model seen so far.
num_epochs = 400
eval_every = 50  # iterations between progress prints / test-set evaluations
# Relative to the working directory; the load cell reads /kaggle/working/best_model.pth.
checkpoint_path = "best_model.pth"

start_time = time.time()
best_acc = 0.0
cnt = 0  # global iteration counter (not reset between epochs)

for epoch in range(num_epochs):
    for batch, (data, targets) in enumerate(train_loader):
        net.train()
        data = data.to(device).float()
        targets = targets.to(device)

        spk_rec = forward_pass(net, data)
        loss = loss_fn(spk_rec, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_hist.append(loss.item())
        # Computed once and reused by the progress print below.
        acc = SF.accuracy_rate(spk_rec, targets)
        acc_hist.append(acc)

        if cnt % eval_every == 0:
            print(f"Epoch {epoch}, Iteration {batch} \nTrain Loss: {loss.item():.2f}")
            print(f"Train Accuracy: {acc * 100:.2f}%")
            test_acc = validate_model()
            test_acc_hist.append(test_acc)
            print(f"Test Accuracy: {test_acc * 100:.2f}%\n")

            if test_acc > best_acc:
                print(f"New best model found! Saving model at epoch {epoch} with accuracy {test_acc * 100:.2f}%")
                best_acc = test_acc
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': net.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'best_acc': best_acc,
                    'loss': loss.item(),
                }, checkpoint_path)

        cnt += 1

end_time = time.time()
elapsed_time = end_time - start_time

# Split elapsed seconds into minutes / seconds / milliseconds for display.
minutes, seconds = divmod(elapsed_time, 60)
seconds, milliseconds = divmod(seconds, 1)
milliseconds = round(milliseconds * 1000)

print(f"Elapsed time: {int(minutes)} minutes, {int(seconds)} seconds, {milliseconds} milliseconds")

In [None]:
# Training curves. Train accuracy and loss are logged every iteration, but the
# training loop logs test accuracy only every 50 iterations, so its points are
# placed at the iterations where they were measured.
eval_interval = 50  # keep in sync with the training loop's evaluation interval
test_iters = [k * eval_interval for k in range(len(test_acc_hist))]

fig, axes = plt.subplots(1, 3, figsize=(18, 4))

# Train accuracy (one point per iteration)
axes[0].plot(acc_hist)
axes[0].set_title("Train Set Accuracy")
axes[0].set_xlabel("Iteration")
axes[0].set_ylabel("Accuracy")

# Test accuracy (one point per evaluation, on the same iteration scale)
axes[1].plot(test_iters, test_acc_hist, marker=".")
axes[1].set_title("Test Set Accuracy")
axes[1].set_xlabel("Iteration")
axes[1].set_ylabel("Accuracy")

# Training loss (one point per iteration)
axes[2].plot(loss_hist)
axes[2].set_title("Loss History")
axes[2].set_xlabel("Iteration")
axes[2].set_ylabel("Loss")

plt.show()

In [None]:
def to_frames(events, n_time_bins=100, sensor_size=None):
    """Bin a raw DVSGesture event stream into a fixed number of frames.

    Redefinition of the helper from the 'Eventos a Frames' cell with the same
    defaults, so this section can also be run on its own.

    Args:
        events: raw events of one sample, as returned by ``tonic.datasets.DVSGesture``.
        n_time_bins: number of temporal bins (frames) to produce. Defaults to 100.
        sensor_size: sensor geometry passed to ``tonic.transforms.ToFrame``.
            Defaults to the full DVSGesture sensor size.

    Returns:
        The output of ``tonic.transforms.ToFrame`` for ``events`` (presumably an
        array with one frame per time bin along the first axis — see tonic docs).
    """
    if sensor_size is None:
        sensor_size = tonic.datasets.DVSGesture.sensor_size
    frame_transform = tonic.transforms.ToFrame(
        sensor_size=sensor_size,
        n_time_bins=n_time_bins)
    return frame_transform(events)

# Cargar el modelo entrenado

In [None]:
# Restore the best checkpoint written by the training loop.
model_path = "/kaggle/working/best_model.pth"
# map_location lets a checkpoint saved on GPU be loaded in a CPU-only session.
# weights_only=False unpickles arbitrary objects (the checkpoint also stores
# non-tensor values such as 'best_acc'): only load checkpoints you produced yourself.
checkpoint = torch.load(model_path, map_location=device, weights_only=False)
net.load_state_dict(checkpoint['model_state_dict'])
net.eval()

# Inferencia de solo texto

In [None]:
# Classify a few random test samples (batch_size=1) and compare with the true labels.
# A separate loader name keeps the evaluation `test_loader` used by validate_model intact.
num_samples = 4
sample_loader = torch.utils.data.DataLoader(cached_test, batch_size=1, shuffle=True, drop_last=True,
                                            collate_fn=tonic.collation.PadTensors(batch_first=False))

predictions = []
true_labels = []

net.eval()
# One iterator for all draws: calling iter() per draw rebuilds the sampler each time
# and can return the same sample more than once.
sample_iter = iter(sample_loader)
with torch.no_grad():
    for _ in range(num_samples):
        data, target = next(sample_iter)
        data, target = data.to(device).float(), target.to(device)

        spk_rec = forward_pass(net, data)
        # spk_rec is (time, batch=1, classes): the class with the most spikes wins.
        predicted_label = torch.argmax(spk_rec.sum(dim=0), dim=1).item()

        predictions.append(predicted_label)
        true_labels.append(target.item())

# Labels are shown 1-based (class index + 1).
for pred, true in zip(predictions, true_labels):
    print(f"Predicción: {pred+1}, Etiqueta correcta: {true+1}")
    




# Inferencia con texto y video

In [None]:

from IPython.display import HTML, display

# Animate a few random test samples and show the predicted vs. true label for each.
# A separate loader name keeps the evaluation `test_loader` used by validate_model intact.
num_samples = 4
sample_loader = torch.utils.data.DataLoader(cached_test, batch_size=1, shuffle=True, drop_last=True,
                                            collate_fn=tonic.collation.PadTensors(batch_first=False))

predictions = []
true_labels = []

net.eval()
# One iterator for all draws so the samples are distinct.
sample_iter = iter(sample_loader)
with torch.no_grad():
    for _ in range(num_samples):
        data, target = next(sample_iter)
        data, target = data.to(device).float(), target.to(device)

        # data is (time, batch, channels, H, W); animate batch element 0.
        frames = data[:, 0, ...].cpu()
        ani = tonic.utils.plot_animation(frames)
        display(HTML(ani.to_jshtml()))

        spk_rec = forward_pass(net, data)
        # The class with the most output spikes over time is the prediction.
        predicted_label = torch.argmax(spk_rec.sum(dim=0), dim=1).item()

        predictions.append(predicted_label)
        true_labels.append(target.item())
        print(f"Valor predicho: {predicted_label+1}, Etiqueta correcta: {target.item()+1}")

# Summary (labels shown 1-based).
for pred, true in zip(predictions, true_labels):
    print(f"Predicción: {pred+1}, Etiqueta correcta: {true+1}")



In [None]:
# Classify a few random test samples (batch_size=1) and compare with the true labels.
# A separate loader name keeps the evaluation `test_loader` used by validate_model intact.
num_samples = 4
sample_loader = torch.utils.data.DataLoader(cached_test, batch_size=1, shuffle=True, drop_last=True,
                                            collate_fn=tonic.collation.PadTensors(batch_first=False))

predictions = []
true_labels = []

net.eval()
# One iterator for all draws: calling iter() per draw rebuilds the sampler each time
# and can return the same sample more than once.
sample_iter = iter(sample_loader)
with torch.no_grad():
    for _ in range(num_samples):
        data, target = next(sample_iter)
        data, target = data.to(device).float(), target.to(device)

        spk_rec = forward_pass(net, data)
        # spk_rec is (time, batch=1, classes): the class with the most spikes wins.
        predicted_label = torch.argmax(spk_rec.sum(dim=0), dim=1).item()

        predictions.append(predicted_label)
        true_labels.append(target.item())

# Labels are shown 1-based (class index + 1).
for pred, true in zip(predictions, true_labels):
    print(f"Predicción: {pred+1}, Etiqueta correcta: {true+1}")
    



In [None]:
import torch
import matplotlib.pyplot as plt
import snntorch as snn

def forward_pass2(net, data):
    """Like ``forward_pass`` but also records the output membrane potential.

    Resets the network's hidden state, then feeds ``data`` one time step
    (first dimension) at a time.

    Returns:
        ``(spikes, membrane)``, each stacked along a new leading time dimension.
    """
    snn.utils.reset(net)
    spikes, membranes = [], []
    for frame in data:
        spk, mem = net(frame)
        spikes.append(spk)
        membranes.append(mem)
    return torch.stack(spikes), torch.stack(membranes)

# Plot the output-layer membrane potential of every class neuron over time
# for a few random test samples.
num_samples = 4
sample_loader = torch.utils.data.DataLoader(cached_test, batch_size=1, shuffle=True, drop_last=True,
                                            collate_fn=tonic.collation.PadTensors(batch_first=False))

net.eval()
# One iterator for all draws so the samples are distinct.
sample_iter = iter(sample_loader)
with torch.no_grad():
    for i in range(num_samples):
        data, _ = next(sample_iter)
        data = data.to(device).float()

        spk_rec, mem_rec = forward_pass2(net, data)

        # mem_rec is (time, batch=1, classes). size(1) is the batch, not the neuron
        # count, so select sample 0 explicitly and plot one trace per output neuron.
        mem_traces = mem_rec[:, 0, :].detach().cpu().numpy()
        fig, ax = plt.subplots(figsize=(10, 6))
        for neuron in range(mem_traces.shape[1]):
            ax.plot(mem_traces[:, neuron])  # one line per neuron, no legend

        ax.set_title(f"Potencial de Membrana durante la Inferencia {i+1}")
        ax.set_xlabel("Tiempo")
        ax.set_ylabel("Potencial de Membrana")
        fig.tight_layout()
        plt.show()


In [None]:
import torch
import matplotlib.pyplot as plt
import snntorch as snn
from IPython.display import HTML, display

def forward_pass2(net, data):
    """Like ``forward_pass`` but also records the output membrane potential.

    Resets the network's hidden state, then feeds ``data`` one time step
    (first dimension) at a time.

    Returns:
        ``(spikes, membrane)``, each stacked along a new leading time dimension.
    """
    snn.utils.reset(net)
    spikes, membranes = [], []
    for frame in data:
        spk, mem = net(frame)
        spikes.append(spk)
        membranes.append(mem)
    return torch.stack(spikes), torch.stack(membranes)

# For a few random test samples: animate the input, print predicted vs. true label,
# and plot the output-layer membrane potential over time.
# A separate loader name keeps the evaluation `test_loader` used by validate_model intact.
num_samples = 4
sample_loader = torch.utils.data.DataLoader(cached_test, batch_size=1, shuffle=True, drop_last=True,
                                            collate_fn=tonic.collation.PadTensors(batch_first=False))

predictions = []
true_labels = []

net.eval()
# One iterator for all draws so the samples are distinct.
sample_iter = iter(sample_loader)
with torch.no_grad():
    for i in range(num_samples):
        data, target = next(sample_iter)
        data, target = data.to(device).float(), target.to(device)

        # data is (time, batch, channels, H, W); animate batch element 0.
        frames = data[:, 0, ...].cpu()
        ani = tonic.utils.plot_animation(frames)
        display(HTML(ani.to_jshtml()))

        spk_rec, mem_rec = forward_pass2(net, data)

        # The class with the most output spikes over time is the prediction.
        predicted_label = torch.argmax(spk_rec.sum(dim=0), dim=1).item()
        predictions.append(predicted_label)
        true_labels.append(target.item())

        print(f"Valor predicho: {predicted_label + 1}, Etiqueta correcta: {target.item() + 1}")

        # mem_rec is (time, batch=1, classes): plot one trace per output neuron of sample 0.
        mem_traces = mem_rec[:, 0, :].detach().cpu().numpy()
        fig, ax = plt.subplots(figsize=(10, 6))
        for neuron in range(mem_traces.shape[1]):
            ax.plot(mem_traces[:, neuron])  # one line per neuron, no legend

        ax.set_title(f"Potencial de Membrana durante la Inferencia {i + 1}")
        ax.set_xlabel("Tiempo")
        ax.set_ylabel("Potencial de Membrana")
        fig.tight_layout()
        plt.show()

# Final summary (labels shown 1-based).
for pred, true in zip(predictions, true_labels):
    print(f"Predicción: {pred + 1}, Etiqueta correcta: {true + 1}")


In [None]:
import torch
import matplotlib.pyplot as plt
import snntorch as snn
from IPython.display import HTML, display

def forward_pass2(net, data):
    """Like ``forward_pass`` but also records the output membrane potential.

    Resets the network's hidden state, then feeds ``data`` one time step
    (first dimension) at a time.

    Returns:
        ``(spikes, membrane)``, each stacked along a new leading time dimension.
    """
    snn.utils.reset(net)
    spikes, membranes = [], []
    for frame in data:
        spk, mem = net(frame)
        spikes.append(spk)
        membranes.append(mem)
    return torch.stack(spikes), torch.stack(membranes)

from snntorch import spikeplot as splt  # was used below but never imported

# For a few random test samples: animate the input, print the prediction, plot the
# output membrane potential and an animated per-class spike-count histogram.
# A separate loader name keeps the evaluation `test_loader` used by validate_model intact.
num_samples = 4
sample_loader = torch.utils.data.DataLoader(cached_test, batch_size=1, shuffle=True, drop_last=True,
                                            collate_fn=tonic.collation.PadTensors(batch_first=False))

predictions = []
true_labels = []

# 1-based class names for the spike-count plot (one per output neuron).
labels = [str(k) for k in range(1, 12)]

net.eval()
# One iterator for all draws so the samples are distinct.
sample_iter = iter(sample_loader)
with torch.no_grad():
    for i in range(num_samples):
        data, target = next(sample_iter)
        data, target = data.to(device).float(), target.to(device)

        # Sanity check of the inputs
        print(f"Data shape: {data.shape}, Target: {target}")

        # data is (time, batch, channels, H, W); animate batch element 0.
        frames = data[:, 0, ...].cpu()
        ani = tonic.utils.plot_animation(frames)
        display(HTML(ani.to_jshtml()))

        spk_rec, mem_rec = forward_pass2(net, data)

        # Sanity check of the output spikes
        print(f"spk_rec shape: {spk_rec.shape}")

        # The class with the most output spikes over time is the prediction.
        predicted_label = torch.argmax(spk_rec.sum(dim=0), dim=1).item()
        predictions.append(predicted_label)
        true_labels.append(target.item())

        print(f"Valor predicho: {predicted_label + 1}, Etiqueta correcta: {target.item() + 1}")

        # mem_rec is (time, batch=1, classes): plot one trace per output neuron of sample 0.
        mem_traces = mem_rec[:, 0, :].detach().cpu().numpy()
        fig, ax = plt.subplots(figsize=(10, 6))
        for neuron in range(mem_traces.shape[1]):
            ax.plot(mem_traces[:, neuron])  # one line per neuron, no legend

        ax.set_title(f"Potencial de Membrana durante la Inferencia {i + 1}")
        ax.set_xlabel("Tiempo")
        ax.set_ylabel("Potencial de Membrana")
        fig.tight_layout()
        plt.show()

        # Spike-count histogram. The target index must be compared against the class
        # dimension (last axis), not size(1), which is the batch size (1).
        idx = target.item()
        num_classes = spk_rec.size(-1)
        if idx < num_classes:
            fig, ax = plt.subplots(facecolor='w', figsize=(12, 7))
            print(f"La etiqueta objetivo es: {target.item()}")

            # spk_rec[:, 0] selects batch element 0 -> (time, classes).
            anim = splt.spike_count(spk_rec[:, 0].detach().cpu(), fig, ax, labels=labels,
                                    animate=True, interpolate=1)

            # to_html5_video needs an ffmpeg writer available to matplotlib.
            display(HTML(anim.to_html5_video()))
            # Close so the static figure is not rendered again at the end of the cell.
            plt.close(fig)
        else:
            print(f"Índice {idx} fuera de los límites de spk_rec (tamaño: {num_classes})")

# Final summary (labels shown 1-based).
for pred, true in zip(predictions, true_labels):
    print(f"Predicción: {pred + 1}, Etiqueta correcta: {true + 1}")
