In [None]:
# install libraries first
!pip install snntorch
!pip install tonic

In [None]:
#imports
import urllib.request
import torch, torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import snntorch as snn
from snntorch import surrogate
from snntorch import functional as SF

import tonic
import tonic.transforms as transforms

from torch import randn_like

In [None]:
# dataset transforms
sensor_size = tonic.datasets.NMNIST.sensor_size

# The pipeline is bound to the name `transforms`, which replaces the
# `tonic.transforms` module alias created by the import cell. Building it
# through the fully qualified `tonic.transforms` module keeps this cell
# re-runnable: after the first run the bare name `transforms` is a Compose
# object and no longer exposes `Compose`, `Denoise` or `ToFrame`.
transforms = tonic.transforms.Compose([
    # Event denoising step, configured with filter_time=10000
    tonic.transforms.Denoise(filter_time=10000),
    # Bin each recording into exactly 100 frames
    tonic.transforms.ToFrame(sensor_size=sensor_size, n_time_bins=100),
])

In [None]:
# Directory where the NMNIST dataset is stored
data_path = 'data/nmnist'

# Use the GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Train split first, then test split; both go through the same transform pipeline
nmnist_train, nmnist_test = (
    tonic.datasets.NMNIST(data_path, train=is_train, transform=transforms)
    for is_train in (True, False)
)

In [None]:
# Hyperparameters and run settings shared by the network, optimizer, loss
# and training loop.
config = {
    "num_epochs": 100,  # Number of epochs to train for (per trial)
    "batch_size": 128,  # Batch size
    "seed": 0,  # Random seed

    # Network parameters
    "batch_norm": True,  # Whether or not to use batch normalization
    "dropout": 0.09,  # Dropout rate
    "beta": 0.92,  # Decay rate parameter (beta)
    "threshold": 2.0,  # Threshold parameter (theta)
    "lr": 1.9e-3,  # Initial learning rate
    "slope": 6.0,  # Slope value (k)

    # Fixed params
    "num_steps": 100,  # Number of timesteps to encode input for
    "correct_rate": 0.8,  # Correct rate
    "incorrect_rate": 0.2,  # Incorrect rate
    "betas": (0.9, 0.999),  # Adam optimizer beta values
}

# Actually apply the configured seed: without this call the "seed" entry is
# never consumed, so weight initialisation, dropout and loader shuffling
# differ from run to run.
torch.manual_seed(config["seed"])

In [None]:
batch_size = config["batch_size"]

# Cache the transformed samples on disk. The cache lives in a relative,
# NMNIST-specific directory: the previous absolute '/temp/dvsgesture/...'
# location named a different dataset (risking a clash with an unrelated
# cache) and sits at the filesystem root, which is often not writable.
cached_train = tonic.DiskCachedDataset(nmnist_train, cache_path='./cache/nmnist/train')
cached_test = tonic.DiskCachedDataset(nmnist_test, cache_path='./cache/nmnist/test')

# batch_first=False: the network indexes its input by time step first.
trainloader = DataLoader(cached_train, shuffle=True, batch_size=batch_size, collate_fn=tonic.collation.PadTensors(batch_first=False))
# Evaluation does not depend on sample order, so the test loader is not
# shuffled; this keeps evaluation deterministic and leaves the RNG untouched.
testloader = DataLoader(cached_test, shuffle=False, batch_size=batch_size, collate_fn=tonic.collation.PadTensors(batch_first=False))

# Pull one batch as a quick sanity check that the pipeline works end to end
frames, target = next(iter(trainloader))

In [None]:
class Net(nn.Module):
    """Spiking CNN: two conv -> avg-pool -> (batch-norm) -> LIF stages and a linear -> dropout -> LIF readout.

    Every hyperparameter is read from ``config``: ``threshold``, ``slope``,
    ``beta``, ``num_steps``, ``batch_norm`` and ``dropout``.
    """

    def __init__(self, config):
        super().__init__()

        # Hyperparameters taken from the config dictionary
        self.num_steps = config["num_steps"]
        self.batch_norm = config["batch_norm"]
        self.p1 = config["dropout"]
        self.beta = config["beta"]
        self.thr = config["threshold"]
        self.slope = config["slope"]
        self.spike_grad = surrogate.fast_sigmoid(self.slope)

        def make_lif():
            # All three spiking layers share the same decay, threshold and surrogate gradient
            return snn.Leaky(self.beta, threshold=self.thr, spike_grad=self.spike_grad)

        # Layers (registered in the same order as they are applied in forward)
        self.conv1 = nn.Conv2d(2, 16, 5, bias=False)
        self.conv1_bn = nn.BatchNorm2d(16)
        self.lif1 = make_lif()
        self.conv2 = nn.Conv2d(16, 32, 5, bias=False)
        self.conv2_bn = nn.BatchNorm2d(32)
        self.lif2 = make_lif()
        self.fc1 = nn.Linear(32 * 5 * 5, 10, bias=False)
        self.lif3 = make_lif()
        self.dropout = nn.Dropout(self.p1)

    def forward(self, x):
        """Run the network for ``self.num_steps`` steps.

        ``x`` is indexed along its first dimension by time step. Returns the
        output-layer spikes and membrane potentials, each stacked along a new
        leading dimension (one entry per step).
        """
        # Membrane states at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        # Only the final layer is recorded
        out_spikes, out_potentials = [], []

        for t in range(self.num_steps):
            # First convolutional stage
            cur1 = F.avg_pool2d(self.conv1(x[t]), 2)
            if self.batch_norm:
                cur1 = self.conv1_bn(cur1)
            spk1, mem1 = self.lif1(cur1, mem1)

            # Second convolutional stage
            cur2 = F.avg_pool2d(self.conv2(spk1), 2)
            if self.batch_norm:
                cur2 = self.conv2_bn(cur2)
            spk2, mem2 = self.lif2(cur2, mem2)

            # Readout: flatten everything but the batch dimension, then linear + dropout
            cur3 = self.dropout(self.fc1(torch.flatten(spk2, 1)))
            spk3, mem3 = self.lif3(cur3, mem3)

            out_spikes.append(spk3)
            out_potentials.append(mem3)

        return torch.stack(out_spikes), torch.stack(out_potentials)

net = Net(config).to(device)

In [None]:
class Noisy_Inference(torch.autograd.Function):
    """
    Autograd function that perturbs a weight tensor with Gaussian noise in the
    forward pass and passes the gradient through unchanged in the backward pass.

    The noise standard deviation is ``noise_sd`` times ``2 * max(|weight|)``,
    so it scales with the largest weight magnitude of the tensor.
    """
    noise_sd = 1.0e-1  # Change the strength of the noise injected into the forward weights here

    @staticmethod
    def forward(ctx, input):
        """
        Return ``input`` plus zero-mean Gaussian noise.

        Nothing is saved on ``ctx``: the backward pass does not need the
        weights, so keeping them alive would only cost memory.
        """
        # An empty tensor has no maximum (``max()`` would raise) and nothing
        # to perturb, so hand back an unmodified copy.
        if input.numel() == 0:
            return input.clone()

        delta_w = 2 * torch.abs(input).max()
        # If positive and negative devices on the chip were modelled separately,
        # the sd would be that of the sum of two gaussians, i.e.
        # sqrt(delta_w**2 + delta_w**2); this is intentionally left disabled.
        noise = torch.randn_like(input) * (Noisy_Inference.noise_sd * delta_w)
        # torch.add allocates a new tensor, so the caller's weights are untouched
        return torch.add(input, noise)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Copy the incoming gradient unchanged, so the gradient reaches the
        noise-free weights.
        """
        return grad_output
noiser = Noisy_Inference.apply

In [None]:
# Adam over every parameter of the network; learning rate and betas come from config
optimizer = torch.optim.Adam(
    net.parameters(),
    betas=config["betas"],
    lr=config["lr"],
)

# snntorch MSE count loss, configured with the correct / incorrect rates from config
criterion = SF.mse_count_loss(
    incorrect_rate=config["incorrect_rate"],
    correct_rate=config["correct_rate"],
)

In [None]:
def train(config, net, trainloader, criterion, optimizer, device="cuda", scheduler=None):
    """Complete one epoch of training.

    Args:
        config: Mapping that provides ``"num_steps"``, used to scale the
            recorded loss values.
        net: Model to train; ``net(data)`` must return a 2-tuple whose first
            element is passed to ``criterion``.
        trainloader: Iterable of ``(data, labels)`` batches.
        criterion: Callable ``criterion(output, labels)`` returning a scalar loss.
        optimizer: Optimizer updating the parameters of ``net``.
        device: Device the batches are moved to before the forward pass.
        scheduler: Optional learning-rate scheduler. When given, it is stepped
            once per batch, right after the optimizer step.

    Returns:
        A list with one entry per batch: the batch loss divided by
        ``config["num_steps"]``.
    """
    net.train()
    loss_accum = []
    for data, labels in trainloader:
        data, labels = data.to(device), labels.to(device)
        spk_rec, _ = net(data)
        loss = criterion(spk_rec, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # A scheduler that is passed in must actually be advanced; step it
        # after the optimizer so the ordering PyTorch expects is respected.
        if scheduler is not None:
            scheduler.step()

        loss_accum.append(loss.item() / config["num_steps"])

    return loss_accum

def test(config, net, testloader, device="cuda"):
    """Calculate accuracy on full test set."""
    correct = 0
    total = 0
    with torch.no_grad():
        net.eval()
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs, _ = net(images.permute(0, 1, 2, 3, 4))
            accuracy = SF.accuracy_rate(outputs, labels)
            total += labels.size(0)
            correct += accuracy * labels.size(0)

    return 100 * correct / total

In [None]:
# Per-batch training losses, accumulated over all epochs
loss_list = []

# Early-stopping settings
patience = 5  # Number of epochs to wait before stopping
min_delta = 0.0005  # Minimum change in loss to qualify as an improvement
patience_counter = 0
best_loss = float('inf')

print("=======Training Network=======")
for epoch in range(config['num_epochs']):
    # Train for one epoch; `loss` is the list of per-batch losses of this epoch
    loss = train(config, net, trainloader, criterion, optimizer,
                 device
                )
    # Append this epoch's batch losses to the running history. (The previous
    # `loss_list[i] = loss_list[i] + loss` indexed an empty list with an
    # undefined `i` and failed on the first epoch.)
    loss_list.extend(loss)

    # Use the average loss of the epoch for early stopping
    # (loss[-1], the last batch loss, would be a noisier alternative).
    avg_loss = sum(loss) / len(loss)

    # Test
    test_accuracy = test(config, net, testloader, device)
    print(f"Epoch: {epoch} \tTest Accuracy: {test_accuracy} \tLoss: {avg_loss}")

    if avg_loss < best_loss - min_delta:
        best_loss = avg_loss
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch}")
            break

# Save the final weights whether training stopped early or ran through every
# epoch; previously nothing was saved unless early stopping triggered.
torch.save(net.state_dict(), 'model.pt')