In [None]:
!pip install snntorch

In [None]:
# imports
import snntorch as snn
from snntorch import surrogate
from snntorch import backprop
from snntorch import functional as SF
from snntorch import utils
from snntorch import spikeplot as splt

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn.functional as F

import matplotlib.pyplot as plt
import numpy as np
import itertools

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# importing sound classification models from torchaudio
import torchaudio

import matplotlib.pyplot as plt
import tqdm

In [None]:
# List the CUDA devices visible to this process.
num_cuda_devices = torch.cuda.device_count()
print("Available devices:")
print(num_cuda_devices)
print("List of devices:")
for device_index in range(num_cuda_devices):
    print(device_index, torch.cuda.get_device_name(device_index))
# Derived from the visible device count instead of hard-coding [0, 1], so the
# notebook also runs on single-GPU (or CPU-only) machines.
device_ids = list(range(num_cuda_devices))

In [None]:
# Restrict this process to GPUs 0 and 1.
# NOTE(review): CUDA_VISIBLE_DEVICES is generally only honoured if set before CUDA is
# initialised; the previous cell may already have initialised it (get_device_name),
# in which case this assignment only affects child processes. Move it above the first
# torch.cuda call if the restriction matters.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"  # to use both GPUs

## Dataset

In [None]:
# Peek at the top-level contents of the training-set root.
os.listdir(path='/kaggle/input/TrainSet')

In [None]:
# Dataset layout on disk:
#   <dataset_folder>/<Category>/<segmented sub-folder>/<audio files>
# One category folder per class: Animals, Music and SoT (sound of things).
dataset_folder = '/kaggle/input/TrainSet'
animals_folder = f"{dataset_folder}/Animals/animals_segmented"
music_folder = f"{dataset_folder}/Music/instruments_segmented"
sot_folder = f"{dataset_folder}/SoT/sound_of_things_segmented"

In [None]:
# Map each class name to the list of file paths in its folder.
# Listings are sorted because os.listdir() returns entries in arbitrary order;
# a stable order keeps dataset indices (and therefore any seeded split) reproducible.
def _list_folder(folder):
    """Return the full paths of the entries in ``folder``, sorted by file name."""
    return [os.path.join(folder, name) for name in sorted(os.listdir(folder))]


data_paths = {
    'animals': _list_folder(animals_folder),
    'music': _list_folder(music_folder),
    'sot': _list_folder(sot_folder),
}

# Class name -> integer label (covers all three classes despite the variable name).
animals_dict = {
    "animals": 0,
    "music": 1,
    "sot": 2
}

# Number of files per class, then in total.
for class_name in ('animals', 'music', 'sot'):
    print(len(data_paths[class_name]))
print("Total number of files: ", sum(len(paths) for paths in data_paths.values()))

print(data_paths['animals'][:5])

In [None]:
# Tuple to tensor of numbers
def name_tuple_to_float_tensor(tuple, mapping=None):
    """Convert a sequence of class names into a float32 tensor of class indices.

    Args:
        tuple: sequence of class-name strings, e.g. the label tuple a DataLoader
            yields for a batch. (The name shadows the builtin; kept for compatibility.)
        mapping: optional dict from class name to integer label. Defaults to the
            notebook-level ``animals_dict``, looked up at call time.

    Returns:
        1-D ``torch.float32`` tensor with one label per name.

    Raises:
        KeyError: if a name is not present in ``mapping``.
    """
    if mapping is None:
        mapping = animals_dict
    return torch.tensor([mapping[name] for name in tuple], dtype=torch.float32)

In [None]:
import torch
from torch.utils.data import Dataset

num_classes = 3


class AudioDataset(Dataset):
    """Audio classification dataset backed by a ``{class_name: [file paths]}`` dict.

    Each item is ``(features, sample_rate, class_name)``, plus the file path as a
    fourth element when ``printPath`` is True. ``features`` is ``transform(waveform)``
    (or the raw waveform when no transform is given) averaged over dim 0, which is
    the channel dimension of the tensor returned by ``torchaudio.load``.
    """

    def __init__(self, data_paths, transform=None):
        self.data_paths = data_paths
        self.classes = list(data_paths.keys())
        # Flat, index-aligned lists of file paths and their class names, so that
        # __getitem__ is a direct lookup instead of a scan over the classes.
        self.files = [path for cls in self.classes for path in data_paths[cls]]
        self.labels = [cls for cls in self.classes for _ in data_paths[cls]]
        self.transform = transform
        # When True, __getitem__ also returns the source file path.
        self.printPath = False

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        if idx < 0:  # Python-style negative indexing over the whole dataset
            idx += len(self.files)
        if not 0 <= idx < len(self.files):
            raise IndexError('Index out of range')
        file = self.files[idx]
        classe = self.labels[idx]

        waveform, sample_rate = torchaudio.load(file)
        # Fall back to the raw waveform when no transform is configured.
        features = self.transform(waveform) if self.transform else waveform
        # Average over dim 0 (channels) to get a channel-independent tensor.
        mfcc = torch.mean(features, dim=0).detach()

        if self.printPath:
            return mfcc, sample_rate, classe, file
        return mfcc, sample_rate, classe

In [None]:
from torchvision.transforms import Compose, RandomApply


class _TimeMajor(torch.nn.Module):
    """Apply a transform expecting ``(..., time, feature)`` to a ``(..., feature, time)`` tensor."""

    def __init__(self, inner):
        super().__init__()
        self.inner = inner

    def forward(self, x):
        return self.inner(x.transpose(-1, -2)).transpose(-1, -2)


def _cepstral_mean_norm():
    """Sliding-window cepstral mean normalisation for MFCC output ``(..., n_mfcc, time)``."""
    # https://dsp.stackexchange.com/questions/19564/cepstral-mean-normalization
    # SlidingWindowCmn expects (..., time, freq) while MFCC produces (..., n_mfcc, time),
    # so the features are transposed around it.
    return _TimeMajor(torchaudio.transforms.SlidingWindowCmn(
        cmn_window=600, min_cmn_window=100, center=False, norm_vars=False))


# Plain MFCC features. (This name shadows the torchvision `transforms` module;
# it is kept because later cells use it.)
transforms = torchaudio.transforms.MFCC(
    sample_rate=48000,
    n_mfcc=20)

# Augmented pipeline: waveform-level pitch shift, then MFCC + CMN, then SpecAugment-style
# masking. The masks must run on the (n_mfcc, time) features; applied to a raw
# (channels, samples) waveform, FrequencyMasking would mask the channel axis instead.
transforms_2 = Compose([
    RandomApply([torchaudio.transforms.PitchShift(
        sample_rate=48000, n_steps=2)], p=0.3),
    torchaudio.transforms.MFCC(sample_rate=48000, n_mfcc=20),
    _cepstral_mean_norm(),
    RandomApply([torchaudio.transforms.FrequencyMasking(
        freq_mask_param=15)], p=0.2),  # SpecAugment
    RandomApply([torchaudio.transforms.TimeMasking(
        time_mask_param=35)], p=0.2),  # SpecAugment
])

# MFCC followed by cepstral mean normalisation (CMN operates on features, not waveforms).
transform_norm_and_MFFC = Compose([
    torchaudio.transforms.MFCC(sample_rate=48000, n_mfcc=20),
    _cepstral_mean_norm(),
])

In [None]:
# Three views of the same files, each built with a different transform pipeline.
dataset1, dataset2, dataset3 = (
    AudioDataset(data_paths, transform=pipeline)
    for pipeline in (transforms, transforms_2, transform_norm_and_MFFC)
)
print(len(dataset1), len(dataset2), len(dataset3))

# Inspect the first example of the first view.
mfcctensor, sample_rate, classe = dataset1[0]
print(mfcctensor.shape, sample_rate, classe)

print(mfcctensor)

In [None]:
import librosa.display
import matplotlib.pyplot as plt

# torchaudio's MFCC uses MelSpectrogram defaults (n_fft=400 -> hop_length=200 samples).
# librosa.display.specshow assumes sr=22050 and hop_length=512 unless told otherwise,
# which would mislabel the time axis.
MFCC_HOP_LENGTH = 200

# The dataset already averaged over channels; presumably (n_mfcc, frames) here.
mfcc_array = mfcctensor.detach().numpy()
print(mfcc_array.shape)

fig, ax = plt.subplots(figsize=(10, 4))
img = librosa.display.specshow(mfcc_array, x_axis='time', sr=sample_rate,
                               hop_length=MFCC_HOP_LENGTH, ax=ax)
fig.colorbar(img, ax=ax)
ax.set(title='MFCC', ylabel='MFCC coefficient')
fig.tight_layout()
plt.show()

In [None]:
BATCH_SIZE = 32
SPLIT_SEED = 0  # fixed seed so the split is reproducible across kernel restarts

# dataset1/2/3 are different transforms of the SAME recordings. Splitting their
# concatenation at random would put views of one recording into both train and
# val/test, leaking evaluation audio into training. Split the recordings first,
# then give every split all three views of its own recordings.
num_files = len(dataset1)
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
file_order = torch.randperm(num_files, generator=split_generator).tolist()
n_train_files = int(0.8 * num_files)
n_val_files = int(0.1 * num_files)
split_file_indices = {
    'train': file_order[:n_train_files],
    'val': file_order[n_train_files:n_train_files + n_val_files],
    'test': file_order[n_train_files + n_val_files:],
}


def _all_views(file_indices):
    """Concatenate the subsets of dataset1/2/3 that contain exactly these files."""
    return torch.utils.data.ConcatDataset(
        [torch.utils.data.Subset(ds, file_indices) for ds in (dataset1, dataset2, dataset3)])


train_dataset = _all_views(split_file_indices['train'])
val_dataset = _all_views(split_file_indices['val'])
test_dataset = _all_views(split_file_indices['test'])
train_size, val_size, test_size = len(train_dataset), len(val_dataset), len(test_dataset)
print(train_size, val_size, test_size)
print(train_size + val_size + test_size)

# Only the training loader needs shuffling; evaluation order does not matter.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Inspect the first training batch.
dataiter = iter(train_dataloader)
mfcctensor, sample_rate, classe = next(dataiter)
print(mfcctensor.shape, sample_rate, classe)

## Models

In [None]:
# Leaky neuron model, overriding the backward pass with a custom function
class LeakySigmoidSurrogate(nn.Module):
    """Leaky integrate-and-fire neuron with a fast-sigmoid surrogate gradient.

    Forward: ``spk = H(mem - threshold)`` (Heaviside), then
    ``mem' = beta * mem + input_ - spk * threshold`` (reset by subtraction).
    Backward: the Heaviside derivative is replaced by the fast-sigmoid
    gradient ``1 / (k * |mem - threshold| + 1) ** 2``.

    Args:
        beta: membrane decay rate.
        threshold: firing threshold.
        k: slope of the fast-sigmoid surrogate (larger = sharper).
    """

    def __init__(self, beta, threshold=1.0, k=25):
        super().__init__()
        self.beta = beta
        self.threshold = threshold
        self.k = k
        self.surrogate_func = self.FastSigmoid.apply

    def forward(self, input_, mem):
        """Advance the neuron by one time step and return ``(spk, mem)``."""
        spk = self.surrogate_func(mem - self.threshold, self.k)
        # Subtract the threshold only from neurons that spiked; detached so the
        # surrogate gradient flows through spk but not through the reset term.
        reset = (spk * self.threshold).detach()
        mem = self.beta * mem + input_ - reset
        return spk, mem

    class FastSigmoid(torch.autograd.Function):
        """Heaviside on the forward pass, fast-sigmoid gradient on the backward pass."""

        @staticmethod
        def forward(ctx, mem, k=25):
            # Keep the (shifted) membrane potential and slope for the backward pass.
            ctx.save_for_backward(mem)
            ctx.k = k
            return (mem > 0).float()

        @staticmethod
        def backward(ctx, grad_output):
            (mem,) = ctx.saved_tensors
            grad = grad_output / (ctx.k * torch.abs(mem) + 1.0) ** 2
            # No gradient for the non-tensor slope argument k.
            return grad, None

In [None]:
# Membrane decay rate and fast-sigmoid surrogate gradient (slope 25) for the LIF layers.
beta = 0.5
spike_grad = surrogate.fast_sigmoid(slope=25)

# A standalone Leaky neuron built from these settings.
lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)

In [None]:
# dataloader arguments
# Batch size mirrored from the DataLoader setting (used by Net.forward), plus default dtype.
batch_size, dtype = BATCH_SIZE, torch.float

In [None]:
# Neuron and simulation parameters: surrogate slope, decay rate, simulated time steps.
spike_grad = surrogate.fast_sigmoid(slope=25)
beta, num_steps = 0.5, 50

In [None]:
# spike_grad, beta and num_steps are set in the preceding cell; a verbatim duplicate
# of those assignments was removed from this cell.

In [None]:
def forward_pass(net, num_steps, data):
    """Feed the same input to ``net`` for ``num_steps`` time steps.

    The hidden state of every LIF neuron in ``net`` is reset first (via
    ``snntorch.utils.reset``), so each call simulates from a clean state.

    Returns:
        (spk_rec, mem_rec): per-step outputs stacked along a new leading time dimension.
    """
    utils.reset(net)
    spikes, membranes = [], []
    for _ in range(num_steps):
        spk_t, mem_t = net(data)
        spikes.append(spk_t)
        membranes.append(mem_t)
    return torch.stack(spikes), torch.stack(membranes)

In [None]:
def batch_accuracy(train_dataloader, net, num_steps):
    """Spike-count accuracy of ``net`` over every batch of a DataLoader.

    Args:
        train_dataloader: loader yielding ``(data, sample_rate, class_names)`` batches.
            Despite the name, any split's loader may be passed (the notebook passes
            the validation loader).
        net: spiking network, simulated with ``forward_pass`` for ``num_steps`` steps.
        num_steps: number of simulation time steps per batch.

    Returns:
        Sample-weighted mean of ``SF.accuracy_rate`` over all batches.

    Raises:
        ValueError: if the loader yields no samples.
    """
    was_training = net.training
    net.eval()
    total = 0
    acc = 0
    try:
        with torch.no_grad():
            for data, _, class_names in train_dataloader:
                data = data.cuda().unsqueeze(1)
                # Integer labels, matching the dtype used for the training loss.
                targets = name_tuple_to_float_tensor(class_names).cuda().long()
                spk_rec, _ = forward_pass(net, num_steps, data)
                batch_n = spk_rec.size(1)  # dim 0 is time, dim 1 is the batch
                acc += SF.accuracy_rate(spk_rec, targets) * batch_n
                total += batch_n
    finally:
        # Restore the caller's mode instead of leaving the net in eval mode.
        net.train(was_training)

    if total == 0:
        raise ValueError("batch_accuracy: the dataloader yielded no samples")
    return acc / total

## Simple Model

In [None]:
# Define Network
class Net(nn.Module):
    """Two conv + LIF blocks followed by a fully connected LIF output layer.

    Args:
        fc_in_features: flattened size of the second conv block's output. The default
            matches the ``nn.Linear(76416, 3)`` of the ``nn.Sequential`` model in this
            notebook, which uses the same conv stack on the same inputs.
        num_outputs: number of output neurons (one per class).
    """

    def __init__(self, fc_in_features=76416, num_outputs=3):
        super().__init__()

        # Initialize layers
        self.conv1 = nn.Conv2d(1, 12, 5)
        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.conv2 = nn.Conv2d(12, 64, 5)
        self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad)
        self.fc1 = nn.Linear(fc_in_features, num_outputs)
        self.lif3 = snn.Leaky(beta=beta, spike_grad=spike_grad)

    def forward(self, x):
        """Run one step and return ``(spk3, mem3)`` of the output layer.

        NOTE(review): membranes are re-initialised via ``init_leaky()`` on every call,
        so no state carries over between successive calls (e.g. time steps driven by
        ``forward_pass``). Confirm this is intended.
        """
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        cur1 = F.max_pool2d(self.conv1(x), 2)
        spk1, mem1 = self.lif1(cur1, mem1)

        cur2 = F.max_pool2d(self.conv2(spk1), 2)
        spk2, mem2 = self.lif2(cur2, mem2)

        # flatten(1) keeps the actual batch dimension, so a smaller last batch works.
        cur3 = self.fc1(spk2.flatten(1))
        spk3, mem3 = self.lif3(cur3, mem3)

        return spk3, mem3

In [None]:
# Initialize Network. With init_hidden=True each Leaky layer keeps its own membrane
# state between calls (forward_pass resets it before each simulation).
#
# Not wrapped in nn.DataParallel: DataParallel replicates the module on every forward
# call and discards attribute updates made on the replicas, so the Leaky layers'
# membrane state would be lost between time steps.
# 76416 = flattened size of the second conv block's output for the MFCC inputs used here.
net = nn.Sequential(nn.Conv2d(1, 12, 5),
                    nn.MaxPool2d(2),
                    snn.Leaky(beta=beta, spike_grad=spike_grad,
                              init_hidden=True),
                    nn.Conv2d(12, 64, 5),
                    nn.MaxPool2d(2),
                    snn.Leaky(beta=beta, spike_grad=spike_grad,
                              init_hidden=True),
                    nn.Flatten(),
                    nn.Linear(76416, 3),
                    snn.Leaky(beta=beta, spike_grad=spike_grad,
                              init_hidden=True, output=True)
                    )
net = net.cuda()

In [None]:
# Smoke test: push one training batch through the network for num_steps steps.
data, sr, targets = next(iter(train_dataloader))
data = data.cuda().unsqueeze(1)
print(data.shape)
targets = name_tuple_to_float_tensor(targets).cuda().long()

for step in range(num_steps):
    spk_out, mem_out = net(data)

In [None]:
# Full simulation with a hidden-state reset; both records are stacked over time.
spk_rec, mem_rec = forward_pass(net, num_steps, data)
print(*(rec.shape for rec in (spk_rec, mem_rec)))

In [None]:
# Rate-coded cross-entropy loss and spike-count accuracy on the smoke-test batch
# (SF is snntorch.functional, imported at the top).
loss_fn = SF.ce_rate_loss()
targets = targets.long()
print(targets.shape)
loss_val = loss_fn(spk_rec, targets)
print(loss_val)

acc = SF.accuracy_rate(spk_rec, targets)
print(acc * 100, '%')

In [None]:
# Accuracy on the validation loader before training.
test_acc = batch_accuracy(val_dataloader, net, num_steps)
print(100 * test_acc, '%')

In [None]:
# Free PyTorch's cached GPU memory before training starts.
torch.cuda.empty_cache()

In [None]:
optimizer = torch.optim.Adam(net.parameters(), lr=1e-5, betas=(0.9, 0.999))
# Defined here so training does not depend on the exploratory loss cell above.
loss_fn = SF.ce_rate_loss()
num_epochs = 40
EVAL_EVERY = 50  # evaluate on the validation set every EVAL_EVERY iterations
loss_hist = []
test_acc_hist = []  # validation accuracy, one entry per evaluation
counter = 0

# Outer training loop
for epoch in tqdm.trange(num_epochs):
    for data, _, class_names in train_dataloader:
        data = data.cuda().unsqueeze(1)
        targets = name_tuple_to_float_tensor(class_names).cuda().long()

        # Ensure train mode each step: the evaluation below may switch to eval mode.
        net.train()
        spk_rec, _ = forward_pass(net, num_steps, data)

        # Rate-coded cross-entropy over all time steps.
        loss_val = loss_fn(spk_rec, targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        loss_hist.append(loss_val.item())

        if counter % EVAL_EVERY == 0:
            # batch_accuracy runs under torch.no_grad() in eval mode.
            test_acc = batch_accuracy(val_dataloader, net, num_steps)
            print(f"Iteration {counter}, Val Acc: {test_acc * 100:.2f}%\n")
            test_acc_hist.append(float(test_acc))

        counter += 1

In [None]:
# Plot validation accuracy recorded during training. Evaluations happen every
# 50 iterations (EVAL_EVERY in the training cell), not once per epoch.
eval_interval = 50
fig, ax = plt.subplots(facecolor="w")
ax.plot([i * eval_interval for i in range(len(test_acc_hist))], test_acc_hist)
ax.set(title="Validation Set Accuracy", xlabel="Iteration", ylabel="Accuracy")
plt.show()

In [None]:
# Free PyTorch's cached GPU memory after training, before the inspection cells below.
torch.cuda.empty_cache()

In [None]:
# Un-augmented dataset that also returns each item's file path, so that predictions
# can be traced back to (and played from) the source audio.
# NOTE(review): this spans every file in data_paths, including training files, and it
# replaces the held-out `test_dataloader`. The confusion matrix below is therefore not
# a held-out evaluation; restrict this to the test split's files if that is the intent.
dataset_with_path = AudioDataset(data_paths, transform=transforms)
dataset_with_path.printPath = True
test_dataloader = torch.utils.data.DataLoader(
    dataset_with_path, batch_size=BATCH_SIZE, shuffle=True)

# Take the first (shuffled) batch.
dataiter = iter(test_dataloader)
data, sample_rate, classe, file = next(dataiter)
print(data.shape, sample_rate, classe, file)

data = data.cuda().unsqueeze(1)

In [None]:
# Simulate the inspection batch for the spike-count animation below.
spk_rec, mem_rec = forward_pass(net=net, num_steps=num_steps, data=data)

In [None]:
from IPython.display import HTML

idx = 0  # which sample of the batch to visualise
labels = ['animals', 'music', 'sot']

fig, ax = plt.subplots(facecolor='w', figsize=(12, 7))
print(spk_rec.shape)
# Animated histogram of output-neuron spike counts over time for sample `idx`.
anim = splt.spike_count(spk_rec[:, idx].detach().cpu(), fig, ax, labels=labels,
                        animate=True, interpolate=4)

HTML(anim.to_html5_video())
# anim.save("spike_bar.mp4")

In [None]:
import IPython.display as ipd

# `classe` holds the class names of the batch drawn from test_dataloader above.
# (`targets` still refers to an earlier training batch and must not be used here.)
print(f"The target label is: {classe[idx]}")

# Load the sample's audio file.
waveform, sample_rate = torchaudio.load(file[idx])
print(waveform.shape, sample_rate)

# Play it inline.
ipd.Audio(waveform.numpy(), rate=sample_rate)

In [None]:
# Confusion matrix
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt

# Class names ordered by their integer label (animals_dict maps name -> label).
class_names = sorted(animals_dict, key=animals_dict.get)
class_ids = list(range(num_classes))

y_pred = []
y_true = []

net.eval()
with torch.no_grad():
    for data, _, batch_class_names, _ in test_dataloader:
        data = data.cuda().unsqueeze(1)
        batch_targets = name_tuple_to_float_tensor(batch_class_names).long()
        # Simulate all num_steps steps (with a state reset) and predict the class whose
        # output neuron spiked most often (rate decoding, as in SF.accuracy_rate).
        # A single net(data) call would see only one step and carry stale state.
        spk_rec_batch, _ = forward_pass(net, num_steps, data)
        predicted = spk_rec_batch.sum(dim=0).argmax(dim=1)
        y_pred += predicted.cpu().tolist()
        y_true += batch_targets.tolist()

# Pass labels= explicitly so the matrix and report keep every class even if one is absent.
conf_matrix = confusion_matrix(y_true, y_pred, labels=class_ids)
df_cm = pd.DataFrame(conf_matrix, index=class_names, columns=class_names)

fig, ax = plt.subplots(figsize=(10, 7))
sns.heatmap(df_cm, annot=True, fmt='d', cmap='Blues', ax=ax)
ax.set(xlabel='Predicted', ylabel='True', title='Confusion matrix')
plt.show()

print(classification_report(y_true, y_pred, labels=class_ids, target_names=class_names))

In [None]:
# Save the model weights. Unwrap nn.DataParallel (if used) so the checkpoint keys
# carry no "module." prefix and load directly into the bare model.
model_to_save = net.module if isinstance(net, nn.DataParallel) else net
torch.save(model_to_save.state_dict(), 'snn_simple_model.pth')