In [2]:
import numpy as np
import random
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset, Dataset

import snntorch as snn
from snntorch import surrogate

import os
from scipy.io import loadmat
from scipy.signal import ellip, lfilter,butter,find_peaks

import matplotlib.pyplot as plt

In [3]:
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
DEVICE

device(type='cpu')

In [4]:
MODEL_TYPE = "SNN"
BATCH_SIZE = 512
EPOCHS = 64

In [5]:
SEED_VALUE = 1337

torch.manual_seed(SEED_VALUE)
np.random.seed(SEED_VALUE)
random.seed(SEED_VALUE)
if torch.cuda.is_available():
    print("using cuda")
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.cuda.manual_seed_all(SEED_VALUE)

In [6]:
def draw_figures(original_data, events, label, predict_data, single_label):
    
    ind_ = range(0, 800)
    # print(events.shape, label.shape, predict_data.shape)

    fig, axs = plt.subplots(4, 1, layout='constrained')
    
    axs[0].stem(events[0, ind_])
    axs[0].set_ylabel("Events")

    axs[1].plot(label[0, ind_])
    axs[1].set_ylabel("Label")

    axs[2].plot(predict_data[0, ind_])
    axs[2].set_ylabel("Pred")

    axs[3].plot(single_label[0, ind_])
    axs[3].set_ylabel("Label_Single")

    plt.savefig("./spike_label_plots.jpg")
    plt.close()

In [7]:
def convert_data_to_spiketrain(filepath, filename, multipier):
    MAT = loadmat(os.path.join(filepath, filename + ".mat"))
    data = np.array(MAT['data'])[0]
    spikeTime = np.array(MAT['spike_times'])[0][0][0]
    sampling_interval =np.array(MAT['samplingInterval'][0][0]) * 1e-3
    sample_rate = 1/sampling_interval
    b, a = butter(4, [300*2/sample_rate, 5000*2/sample_rate], btype='band')
    data = lfilter(b, a, data)

    ABS_THD = 4*np.median(np.abs(data) /0.6745)
    data_up = np.copy(data)
    data_up[data_up < ABS_THD] = 0
    peaks, _ = find_peaks(data_up)
    data_down = np.copy(data)
    data_down[data_down > -ABS_THD] = 0
    valleys, _ = find_peaks(abs(data_down))
    median_peak = np.median(data_up[peaks])
    median_valleys = np.median(data_down[valleys])
    spike_amplitude = (median_peak - median_valleys) / 2
    modulation_thd = spike_amplitude * multipier
    
    pulseTrain = np.load(os.path.join(filepath, filename + ".npy"))
    # ON_Threshold = modulation_thd
    # OFF_Threshold = -modulation_thd
    # pulseTrain = delta_modulation_synced(data, ON_Threshold, OFF_Threshold)

    spikeTimeGT = np.array(MAT['OVERLAP_DATA'] > 0).astype(np.float32)
    data_len = spikeTimeGT.shape[1]
    spikeTimeGT = np.insert(spikeTimeGT, 0, [0 for _ in range(22)])
    spikeTimeGT = spikeTimeGT[:data_len].reshape(1, -1)
    
    spikeTimeGT_id = np.array(MAT['spike_times'])[0][0][0] + 22

    return data, pulseTrain, spikeTimeGT, sample_rate, spikeTimeGT_id

In [8]:
def load_data(filepath, filename):
    original_data, pulseTrain, spikeTimeGT, sample_rate, spikeTimeGT_single = convert_data_to_spiketrain(filepath, filename, multipier=0.3)
    
    spike_data_ind = (pulseTrain[0]/10).astype(int).reshape(-1,1)
    spike_num = pulseTrain[1].reshape(-1,1)
    spikeTimeGT_single = spikeTimeGT_single.astype(int)

    spike = np.zeros_like(spikeTimeGT)
    for i, id in enumerate(spike_data_ind):
        spike[:, id] = spike_num[i]

    label_single = np.zeros_like(spikeTimeGT)
    for i, id in enumerate(spikeTimeGT_single):
        label_single[:, id] = np.ones(1)

    draw_figures(original_data=original_data, events=spike, label=spikeTimeGT, predict_data=spikeTimeGT, single_label=label_single)
    
    return original_data, spike, spikeTimeGT, sample_rate, label_single

In [9]:
def transform_to_3d(samples, labels, overlap=True, stride=1, bin_width=2, num_steps=6, model_type="ANN"):
    if overlap:
        advance_num = stride
        bin_width_num = bin_width
    else:
        advance_num = stride
        bin_width_num = stride
    
    new_samples, new_labels = [], []
    temp_sample = torch.zeros((samples.shape[0], int(samples.shape[1] // advance_num), bin_width_num), dtype=torch.float32)
    temp_label = torch.zeros((labels.shape[0], int(samples.shape[1] // advance_num)), dtype=torch.float32)

    for col in range(temp_sample.shape[1]):
        if col <  bin_width_num/advance_num:
            bin_start = 0
            bin_end = int(col * advance_num)
            if col == 0:
                bin_end = 1
            temp_sample[:, col, bin_start:bin_end] = samples[:, bin_start: bin_end]
            # continue
        else:
            bin_start = int(col * advance_num - bin_width_num)
            bin_end = int(col * advance_num)
            temp_sample[:, col, :] = samples[:, bin_start: bin_end]

        # temp_label[:, col] = 1 if 1 in labels[:, bin_start: bin_end] else 0
        temp_label[:, col] = labels[:, col * advance_num]

    if num_steps < bin_width_num:
        sum_num = bin_width_num // num_steps
        temp_sample_num_steps = torch.zeros((int(bin_width_num//num_steps), temp_sample.shape[1], num_steps), dtype=torch.float32)
        for idx in range(num_steps):
            start_idx = idx*sum_num
            end_idx = idx*sum_num + sum_num
            temp_sample_num_steps[:, :, idx] = temp_sample[:, :, start_idx: end_idx].squeeze().t()
            # temp_sample_num_steps[:, :, idx] = torch.sum(abs(temp_sample[:, :, start_idx: end_idx]), dim=2)

        if model_type == 'ANN':
            new_samples.append(temp_sample_num_steps)
        else:
            new_samples.append(torch.sign(temp_sample_num_steps))
    else:

        if model_type == 'ANN':
            new_samples.append(temp_sample)
        else:
            new_samples.append(torch.sign(temp_sample))

    new_labels.append(temp_label)

    return new_samples, new_labels

In [10]:
class SNNModelTraining(nn.Module):
    def __init__(self, input_dim, beta=0.5, mem_threshold=0.5, spike_grad2=surrogate.atan(alpha=2),
                 layer1=64, layer2=16, output_dim=1, dropout_rate=0.1, num_step=6):
        super().__init__()

        self.num_step = num_step
        self.input_dim = input_dim
        self.beta = beta
        self.spike_grad = spike_grad2
        self.mem_threshold = mem_threshold

        self.fc1 = nn.Linear(input_dim, layer1)
        self.fc2 = nn.Linear(layer1, layer2)
        self.fc3 = nn.Linear(layer2, output_dim)

        self.lif1 = snn.Leaky(beta=self.beta, spike_grad=self.spike_grad, threshold=0.5, learn_beta=False, 
                              learn_threshold=False, init_hidden=False, reset_mechanism="none") # 0.8
        self.lif2 = snn.Leaky(beta=self.beta, spike_grad=self.spike_grad, threshold=0.5, learn_beta=False,
                              learn_threshold=False, init_hidden=False, reset_mechanism="none") # 0.4
        self.lif3 = snn.Leaky(beta=self.beta, spike_grad=self.spike_grad, threshold=0.5, learn_beta=False,
                              learn_threshold=False, init_hidden=False, reset_mechanism="none")
        self.dropout = nn.Dropout(dropout_rate)

        self.norm_layer = nn.LayerNorm([self.num_step, self.input_dim])
        self.reset_mem = False
        self.relu = nn.ReLU()
        self.Q_bit = 8
        self.quant_max = 1


    def forward(self, x):
        if self.reset_mem:
            mem1 = self.lif1.init_leaky()
            mem2 = self.lif2.init_leaky()
            mem3 = self.lif3.init_leaky()

        # x = self.norm_layer(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1)
        scale_f=(2**self.Q_bit-1)/self.quant_max
            
        for step in range(self.num_step):
            
            input_ = x[:, step, :]
            cur1 = self.dropout(self.fc1(input_))
            spk1, mem1 = self.lif1(cur1, mem1)

            cur2 = self.dropout(self.fc2(spk1))
            spk2, mem2 = self.lif2(cur2, mem2)

            cur3 = self.fc3(spk2)
            spk3, mem3 = self.lif3(cur3, mem3)

            # mem1 = torch.round(scale_f*torch.clamp(mem1, -self.quant_max, self.quant_max))/scale_f
            # mem2 = torch.round(scale_f*torch.clamp(mem2, -self.quant_max, self.quant_max))/scale_f
            # mem3 = torch.round(scale_f*torch.clamp(mem3, -self.quant_max, self.quant_max))/scale_f

        return spk3   # Training Mode only need to return the output

In [11]:
def split_dataset(samples, train_ratio=0.5):
    len_dataset = samples.shape[1]
    train_len = int(len_dataset*train_ratio)
    ind_train = list(range(0, train_len))
    end_ind_val = int(train_len+train_len//2)
    ind_val = list(range(train_len, end_ind_val))
    ind_test = list(range(end_ind_val, len_dataset))

    return ind_train, ind_val, ind_test

In [12]:
class MyDataset(Dataset):
    def __init__(self, samples, labels):
        self.samples = samples
        self.labels = labels

    def __getitem__(self, idx):
        sample = self.samples[:, idx, :]
        # sample = self.samples[:, idx]
        label = self.labels[:, idx]

        return sample, label
    
    def __len__(self):
        return self.samples.shape[1]

In [13]:
def calculate_SPD_metrics(predicted_spike_times, new_labels_single):
    # Initialize counters
    true_positives = 0
    false_positives = 0
    false_negatives = 0
    # Iterate through ground truth spike times
    
    within_window_true = np.where(new_labels_single[:,0]==1)
    
    num_of_tp = 0
    num_of_FN = 0
    num_of_FP = 0
    for item in within_window_true[0]:
        if 1 in predicted_spike_times[item-12: item+12, 0]:
            num_of_tp += 1
        else:
            num_of_FN += 1

    true_positives = num_of_tp
    false_negatives = num_of_FN

    within_window_pred = np.where(predicted_spike_times[:,0]==1)
    for item in within_window_pred[0]:
        if 1 not in new_labels_single[item-12: item+12, 0]:
            num_of_FP += 1
   

    # Calculate sensitivity, FDR, and accuracy
    sensitivity = true_positives / (true_positives + false_negatives)
    fdr = num_of_FP/(true_positives + num_of_FP)
    accuracy = true_positives / (true_positives+num_of_FP+false_negatives)

    return sensitivity, fdr, accuracy

In [14]:
def refractory_label(pred, label_single, bin_width, stride):
    temp_label = torch.zeros((pred.shape[0], pred.shape[1]), dtype=torch.float32)
    for col in range(temp_label.shape[0]):
        if col <  bin_width/stride:
            bin_start = 0
            bin_end = int(col * stride)
            if col == 0:
                bin_end = 1
            label_ = 1 if 1 in label_single[:, bin_start:bin_end] else 0
            temp_label[col, :] = label_
        else:
            bin_start = int(col * stride - stride)
            bin_end = int(col * stride)
            label_ = 1 if 1 in label_single[:, bin_start:bin_end] else 0

            temp_label[col, :] = label_
    temp_label = np.insert(np.array(temp_label), 0, [0 for _ in range(2)])
    temp_label = temp_label[:pred.shape[0]].reshape(1, -1)

    return torch.tensor(temp_label).reshape(-1, 1)

def refractory_pred(pred, refractory_interval=3):
    spike_updated = []
    checkpoint = 0
    pred_true_id = np.where(pred[:,0]==1)
    for item in pred_true_id[0]:
        if item > checkpoint:
            spike_updated.append(item)
            checkpoint = spike_updated[-1] + refractory_interval
    spike_updated = np.array(spike_updated).astype(int)
    pred_refract = np.zeros_like(pred)
    for i, id in enumerate(spike_updated):
        pred_refract[id, 0] = 1

    return torch.tensor(pred_refract)

In [15]:
def accuracy_fn(y_true, y_pred):
    y_pred = y_pred.to('cpu')
    y_true = y_true.to('cpu')
    # x_correct = torch.eq(y_true, y_pred).sum().item() # torch.eq() calculates where two tensors are equal
    correct_num = torch.eq(y_true, y_pred).sum().item()
    # correct = (x_correct+y_correct)/2
    acc = correct_num / y_pred.numel()
    return acc

## Training

In [16]:
def training(dataset, net, model_weight_name, stride, bin_width, num_steps, ind_train, ind_val):
    # dataset = (spikes, labels, label_single)
    samples = torch.tensor(dataset[0][:, ind_train[0]: ind_train[-1]+1])
    labels = torch.tensor(dataset[1][:, ind_train[0]: ind_train[-1]+1])
    labels_single = torch.tensor(dataset[2][:, ind_train[0]: ind_train[-1]+1])

    new_samples, new_labels = transform_to_3d(samples, labels, overlap=True, stride=stride, bin_width=bin_width, num_steps=num_steps, model_type="ANN")

    new_samples = new_samples[0]
    new_labels = new_labels[0]

    training_set = MyDataset(new_samples, new_labels)

    train_loader = DataLoader(
                            dataset=training_set,
                            batch_size=BATCH_SIZE,
                            drop_last=False,
                            shuffle=False,
                        )

    
    criterion = torch.nn.MSELoss()
    # criterion = nn.CrossEntropyLoss()
    optimiser = torch.optim.AdamW(net.parameters(), lr=0.008, 
                                  betas=(0.9, 0.999), weight_decay=0) #0.008
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimiser, T_max=EPOCHS+5)
    best_training_acc, best_val_acc = float("-inf"), float("-inf")
    net.to(DEVICE)
    
    for epoch in tqdm(range(EPOCHS)):
        net.train()
        for i, (sample, label) in enumerate(train_loader):
            if MODEL_TYPE == "SNN":
                net.reset_mem = True
            sample = sample.to(DEVICE)
            label = label.to(DEVICE)

            pred = net(sample)
            # print(pred.shape, label.shape)
            
            # pred_acc, ind_acc = torch.max(pred, dim=1)
            # loss_val = criterion(pred_acc.reshape(-1, 1), label)
            # current_acc = accuracy_fn(label, ind_acc.reshape(-1, 1))
            # loss_val = criterion(pred, label.squeeze().long())
            loss_val = criterion(pred, label)
            
            current_acc = accuracy_fn(label, pred)

            if current_acc > best_training_acc:
                best_training_acc = current_acc

            optimiser.zero_grad()
            loss_val.backward()
            optimiser.step()
            
            # draw_figures(original_data=sample, events=sample.detach().cpu().numpy(), label=label.detach().cpu().numpy(), predict_data=pred.detach().cpu().numpy(), )
        print("Current Training Accuracy: ", current_acc)
        lr_scheduler.step()

        current_val_acc = validation(dataset, net, stride, bin_width, num_steps, ind_val)
        if current_val_acc > best_val_acc:
            best_val_acc = current_val_acc
            torch.save(net.state_dict(), model_weight_name)
        # for name,param in net.named_parameters():
        #     print(name, param)
        print(f"{epoch} validation ACC: {current_val_acc}")

def validation(dataset, net, stride, bin_width, num_steps, ind_val):
    net.eval()
    samples = torch.tensor(dataset[0][:, ind_val[0]: ind_val[-1]+1])
    labels = torch.tensor(dataset[1][:, ind_val[0]: ind_val[-1]+1])
    labels_single = torch.tensor(dataset[2][:, ind_val[0]: ind_val[-1]+1])

    new_samples, new_labels = transform_to_3d(samples, labels, overlap=True, stride=stride, bin_width=bin_width, num_steps=num_steps, model_type="ANN")

    new_samples = new_samples[0]
    new_labels = new_labels[0]

    val_set = MyDataset(new_samples, new_labels)
    
    val_loader = DataLoader(
                        dataset=val_set,
                        batch_size=BATCH_SIZE,
                        drop_last=False,
                        shuffle=False,
                    )
    
    val_acc_final = 0
    pred_rec = []
    with torch.no_grad():
        for i, (sample, label) in enumerate(val_loader):
            if MODEL_TYPE == "SNN":
                net.reset_mem = True
            sample = sample.to(DEVICE)
            label = label.to(DEVICE)

            pred = net(sample)
            pred_rec.append(pred)
            val_acc_final += accuracy_fn(label, pred)
            # print(torch.eq(label, pred))
            # draw_figures(original_data, torch.sum(sample, dim=2).cpu().numpy(), label.cpu().numpy(), pred.cpu().numpy())
    
    return val_acc_final/(i+1)

## Testing

In [17]:
def test(dataset, net, model_weight_name, stride, bin_width, num_steps, ind_test):
    net.load_state_dict(torch.load(model_weight_name))
    net.eval()
    samples = torch.tensor(dataset[0][:, ind_test[0]: ind_test[-1]+1])
    labels = torch.tensor(dataset[1][:, ind_test[0]: ind_test[-1]+1])
    labels_single = torch.tensor(dataset[2][:, ind_test[0]: ind_test[-1]+1])

    new_samples, new_labels_single = transform_to_3d(samples, labels_single, overlap=True, stride=stride, bin_width=bin_width, num_steps=num_steps, model_type="ANN")
    new_samples, new_labels = transform_to_3d(samples, labels, overlap=True, stride=stride, bin_width=bin_width, num_steps=num_steps, model_type="ANN")

    new_samples = new_samples[0]
    new_labels = new_labels[0]
    new_labels_single = new_labels_single[0]
    # print(new_samples.shape, new_labels.shape)

    test_set = MyDataset(new_samples, new_labels)
    
    test_loader = DataLoader(
                        dataset=test_set,
                        batch_size=new_samples.shape[1],
                        drop_last=False,
                        shuffle=False,
                    )
    
    test_acc_final = 0
    pred_rec = []
    with torch.no_grad():
        for i, (sample, label) in enumerate(test_loader):
            if MODEL_TYPE == "SNN":
                net.reset_mem = True
            sample = sample.to(DEVICE)
            label = label.to(DEVICE)

            pred = net(sample)
            pred_rec.append(pred)
            # print(pred[pred==1])
            test_acc_final += accuracy_fn(label, pred)

    new_labels = new_labels.t()
    new_labels_single = new_labels_single.t()
    label_refract = refractory_label(pred, labels_single, bin_width, stride)
    pred_refract = refractory_pred(pred.cpu().numpy())
    # draw_figures(original_data=sample, events=sample.detach().cpu().numpy(), label=label.detach().cpu().numpy(), predict_data=pred_refract.detach().cpu().numpy(), single_label=label_refract.detach().cpu().numpy())
    sensitivity, fdr, accuracy = calculate_SPD_metrics(pred_refract.detach().cpu().numpy(), label_refract.detach().cpu().numpy())
    
    print(f"Sensitivity: {sensitivity:.4f}")
    print(f"FDR: {fdr:.4f}")
    print(f"Accuracy: {accuracy:.4f}")

## Run

In [18]:
file_path = './Simulator_data/'
file_name = 'C_Easy2_noise02' # C_Difficult2_noise005, C_Difficult2_noise01, C_Difficult2_noise015, C_Difficult2_noise02, C_Easy2_noise005, C_Easy2_noise01, C_Easy2_noise015, C_Easy2_noise02
# C_Easy1_noise005, C_Easy1_noise01, C_Easy1_noise015, C_Easy1_noise02

In [19]:
model_weight_name =  "./SNN_weight/" + file_name + "_model_state_dict.pth"
original_data, spikes, labels, sampling_freq, label_single = load_data(file_path, file_name)

sampling_sample_1ms = (sampling_freq/1000).astype(int) #1ms
stride = int(sampling_sample_1ms)
bin_width = int(sampling_sample_1ms)
num_steps = 1

ind_train, ind_val, ind_test = split_dataset(spikes, train_ratio=0.5)

In [20]:
layer1 = 16
layer2 = 2

net = SNNModelTraining(input_dim=int(bin_width//num_steps), num_step=num_steps, layer1=layer1, layer2=layer2)
net.to(DEVICE)

SNNModelTraining(
  (fc1): Linear(in_features=24, out_features=16, bias=True)
  (fc2): Linear(in_features=16, out_features=2, bias=True)
  (fc3): Linear(in_features=2, out_features=1, bias=True)
  (lif1): Leaky()
  (lif2): Leaky()
  (lif3): Leaky()
  (dropout): Dropout(p=0.1, inplace=False)
  (norm_layer): LayerNorm((1, 24), eps=1e-05, elementwise_affine=True)
  (relu): ReLU()
)

In [21]:
training(dataset=(spikes, labels, label_single), 
         net=net, 
         model_weight_name=model_weight_name, 
         stride=stride, 
         bin_width=bin_width, 
         num_steps=num_steps, 
         ind_train=ind_train, 
         ind_val=ind_val)

  0%|          | 0/64 [00:00<?, ?it/s]

Current Training Accuracy:  0.875


  2%|▏         | 1/64 [00:00<00:41,  1.52it/s]

0 validation ACC: 0.8554550438596491
Current Training Accuracy:  0.868421052631579


  3%|▎         | 2/64 [00:01<00:37,  1.63it/s]

1 validation ACC: 0.8592961896929824
Current Training Accuracy:  0.8256578947368421


  5%|▍         | 3/64 [00:01<00:36,  1.68it/s]

2 validation ACC: 0.816149259868421
Current Training Accuracy:  0.17434210526315788


  6%|▋         | 4/64 [00:02<00:35,  1.68it/s]

3 validation ACC: 0.816149259868421
Current Training Accuracy:  0.881578947368421


  8%|▊         | 5/64 [00:03<00:36,  1.62it/s]

4 validation ACC: 0.8634765625
Current Training Accuracy:  0.868421052631579


  9%|▉         | 6/64 [00:03<00:35,  1.64it/s]

5 validation ACC: 0.8644462719298246
Current Training Accuracy:  0.6973684210526315


 11%|█         | 7/64 [00:04<00:34,  1.66it/s]

6 validation ACC: 0.8662280701754386
Current Training Accuracy:  0.881578947368421


 11%|█         | 7/64 [00:04<00:39,  1.44it/s]


RuntimeError: File ./SNN_weight/C_Easy2_noise02_model_state_dict.pth cannot be opened.

In [20]:
test(dataset=(spikes, labels, label_single), 
    net=net, 
    model_weight_name=model_weight_name, 
    stride=stride, 
    bin_width=bin_width, 
    num_steps=num_steps, 
    ind_test=ind_test)

  net.load_state_dict(torch.load(model_weight_name))


Sensitivity: 0.9632
FDR: 0.0531
Accuracy: 0.9138
