In [18]:
import os
import glob
import zipfile
import h5py
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import lava.lib.dl.slayer as slayer
import lava.lib.dl.slayer.io as sio

import IPython.display as display
from matplotlib import animation
# Frame size in pixels (presumably the full sensor resolution) and the integer
# factor the ROI bounds are divided by (presumably a spatial down-sampling
# factor -- TODO confirm).
Height, Width = 720, 1280
df_f = 2

# Region of interest as [start, stop] bounds: a fraction of the frame size,
# truncated to int and then floor-divided by df_f.
roix = [int(frac * Width) // df_f for frac in (0.67, 0.70)]
roiy = [int(frac * Height) // df_f for frac in (0, 1)]

# Origin of the ROI; subtracted from the raw event coordinates when the
# recordings are converted to xypt format.
widthoffset, heightoffset = roix[0], roiy[0]

# Extent of the ROI along each axis.
width = roix[1] - roix[0]
height = roiy[1] - roiy[0]

In [51]:
# path setting
# Root directory of the dataset, relative to the notebook's working directory.
# Keep the trailing '/': the class directories are built from it by plain
# string concatenation (f'{Dataset_path}numpysp2/...').
Dataset_path = '../data/'

In [52]:
# Dataset for SNN
class FCDataset(Dataset):
    '''Event-file dataset that yields dense spike tensors for the SNN.

    Each directory in ``pathlist`` is one class and its position in the list
    is the integer label. Every ``xypt_*.npy`` file found in a directory is
    read eagerly in ``__init__`` and kept in memory as an event object; the
    dense tensor is only built on access in ``__getitem__``.

    Parameters
    ----------
    pathlist : list of str, optional
        Directories, one per class. ``None`` (default) gives an empty dataset.
    sampling_time : float
        Forwarded unchanged to ``event.fill_tensor(..., sampling_time=...)``.
        NOTE(review): this looks like the width of one time bin rather than
        the total duration of a recording -- confirm the expected unit
        against the lava-dl documentation.
    sample_bins : int
        Size of the last axis of the spike tensor (number of time bins).
    x : int
        Width of the sensor (size of the third tensor axis).
    y : int
        Height of the sensor (size of the second tensor axis).
    time_unit : float
        Forwarded unchanged to ``sio.read_np_spikes(..., time_unit=...)``.
        Defaults to the value that was previously hard-coded (0.0000005).
    '''

    def __init__(
        self, pathlist=None,
        sampling_time=0.5e-6, sample_bins=100, x=128, y=128,
        time_unit=0.0000005,
    ):
        super().__init__()
        # Copy the caller's list: avoids the shared-mutable-default pitfall
        # and keeps `classnum` consistent with `pathlist` afterwards.
        self.pathlist = list(pathlist) if pathlist is not None else []
        self.classnum = len(self.pathlist)
        self.sampling_time = sampling_time
        self.sample_bins = sample_bins
        self.time_unit = time_unit
        self.x = x
        self.y = y
        self.data = []
        self.label = []
        for class_idx, path in enumerate(self.pathlist):
            # glob order is OS-dependent; sorting keeps sample indices (and
            # therefore any seeded split) stable between runs.
            for event_file in sorted(glob.glob(f'{path}/xypt_*.npy')):
                event = sio.read_np_spikes(
                    event_file, fmt='xypt', time_unit=self.time_unit)
                self.data.append(event)
                self.label.append(class_idx)

    def __len__(self):
        '''Number of event files that were loaded.'''
        return len(self.data)

    def __getitem__(self, idx):
        '''Return ``(spike, label)`` for sample ``idx``.

        The event is rasterised into a zero tensor of shape
        ``(2, y, x, sample_bins)`` and then flattened to
        ``(2 * y * x, sample_bins)``.
        '''
        spike = self.data[idx].fill_tensor(
            torch.zeros(2, self.y, self.x, self.sample_bins),
            sampling_time=self.sampling_time,
        )
        return spike.reshape(-1, self.sample_bins), self.label[idx]
        
# One directory per class. FCDataset uses the position in this list as the
# integer class label, so reordering the entries changes the labels.
# The trailing '/' matters: the conversion cell globs f'{datapath}*.npy'.
pathlist = [f'{Dataset_path}numpysp2/3um_15uL_-25bias_-25off_25fo_lux+2/',
            f'{Dataset_path}numpysp2/8um_2000D_10uL_50to10min/',
            f'{Dataset_path}numpysp2/15um_10uL_50to10min/']
            

In [40]:
#basic network
class Network(torch.nn.Module):
    '''Three-layer fully connected SLAYER CUBA network.

    Layer sizes are ``x*y*2 -> 512 -> 512 -> out``. The first two dense
    blocks use dropout (p=0.05) and learnable delays; the output block uses
    neither.

    Parameters
    ----------
    x, y : int
        Spatial size of the input; the first layer expects ``x * y * 2``
        input features.
    out : int
        Number of output neurons.
    '''

    def __init__(self,x,y,out):
        super().__init__()

        neuron_params = {
                'threshold'     : 1.25,
                'current_decay' : 0.25,
                'voltage_decay' : 0.03,
                'tau_grad'      : 0.03,
                'scale_grad'    : 3,
                'requires_grad' : True,
            }
        # Same neuron configuration plus dropout, used for the hidden layers.
        neuron_params_drop = {**neuron_params, 'dropout' : slayer.neuron.Dropout(p=0.05),}

        self.blocks = torch.nn.ModuleList([
                slayer.block.cuba.Dense(neuron_params_drop, x*y*2, 512, weight_norm=True, delay=True),
                slayer.block.cuba.Dense(neuron_params_drop, 512, 512, weight_norm=True, delay=True),
                slayer.block.cuba.Dense(neuron_params, 512, out, weight_norm=True),
            ])

    def forward(self, spike):
        '''Feed ``spike`` through every block in order and return the result.

        NOTE(review): the input is assumed to be the flattened
        ``(batch, x*y*2, time)`` tensor produced by the dataset -- confirm.
        '''
        for block in self.blocks:
            spike = block(spike)
        return spike

    def grad_flow(self, path):
        '''Plot the per-block synapse gradient norms and return them.

        The plot is written to ``path + 'gradFlow.png'`` (plain string
        concatenation, so ``path`` should end with a separator). Blocks
        without a ``synapse`` attribute are skipped.
        '''
        # helps monitor the gradient flow
        grad = [b.synapse.grad_norm for b in self.blocks if hasattr(b, 'synapse')]

        plt.figure()
        plt.semilogy(grad)
        plt.savefig(path + 'gradFlow.png')
        plt.close()

        return grad

    def export_hdf5(self, filename):
        '''Export the network to ``filename`` in hdf5 format.

        Each block is written by its own ``export_hdf5`` into the group
        ``layer/<index>``.
        '''
        # The context manager closes (and thereby flushes) the file even if a
        # block's export raises; previously the handle was never closed.
        with h5py.File(filename, 'w') as h:
            layer = h.create_group('layer')
            for i, b in enumerate(self.blocks):
                b.export_hdf5(layer.create_group(f'{i}'))

In [38]:
#transform txyp events to xypt events
#first time only
def _txyp_to_xypt(events, x_offset, y_offset):
    '''Convert a structured event array into an (N, 4) x/y/p/t array.

    ``events`` must expose the fields 'x', 'y', 'p' and 't'. The x and y
    columns are shifted by the given offsets and t is shifted so that the
    earliest event is at 0.
    '''
    def _signed(column):
        # Unsigned columns are widened to int64 first so that subtracting an
        # offset larger than the value cannot wrap around.
        return column.astype(np.int64) if column.dtype.kind == 'u' else column

    t = _signed(events['t'])
    return np.stack(
        [
            _signed(events['x']) - x_offset,
            _signed(events['y']) - y_offset,
            _signed(events['p']),
            t - t.min(),
        ],
        axis=1,
    )


for datapath in pathlist:
    # Sorted for a deterministic processing order.
    for name in sorted(glob.glob(f'{datapath}*.npy')):
        dirname, filename = os.path.split(name)
        # The outputs of this cell live in the same directory and match the
        # same glob; skip them so that re-running the cell neither crashes on
        # the non-structured arrays nor creates 'xypt_xypt_*' files.
        if filename.startswith('xypt_'):
            continue
        data = np.load(name)
        if len(data) == 0:
            # An empty recording has no minimum timestamp; nothing to convert.
            print(f'skipping empty event file: {name}')
            continue
        new_filename = f'xypt_{filename}'
        new_filepath = os.path.join(dirname, new_filename)
        np.save(new_filepath, _txyp_to_xypt(data, widthoffset, heightoffset))

In [47]:
from torch.utils.data import DataLoader, random_split

# Seed the global RNGs so that weight initialisation, the train/test split,
# the preview sample and the training shuffle are reproducible.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Spatial size handed to the dataset and the network.
# NOTE(review): presumably these should equal the ROI `width` / `height`
# computed in the configuration cell -- confirm.
x=20
y=360
trained_folder = 'Trained'
os.makedirs(trained_folder, exist_ok=True)
# anim.save() below does not create its target directory.
os.makedirs('gifs', exist_ok=True)

# Fall back to the CPU when no GPU is available instead of failing outright.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

full_dataset = FCDataset(pathlist=pathlist,sampling_time=0.0001, sample_bins=100,x=x,y=y)
if len(full_dataset) == 0:
    raise RuntimeError(
        'No xypt_*.npy event files were found under pathlist; '
        'run the conversion cell first and check Dataset_path.'
    )

# One output neuron per class directory.
net = Network(x=x,y=y,out=len(pathlist)).to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

# Split sizes: 80% of the samples for training, the remainder for testing.
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])

# Render one random test sample as a GIF as a sanity check of the input.
spike_tensor, label = test_dataset[np.random.randint(len(test_dataset))]
# Undo the flattening done by the dataset: (2*y*x, T) -> (2, y, x, T).
spike_tensor = spike_tensor.reshape(2, y, x, -1)
testevent = sio.tensor_to_event(spike_tensor.cpu().data.numpy())
anim = testevent.anim(plt.figure(figsize=(5, 10)), frame_rate=240)
anim.save('gifs/input.gif', animation.PillowWriter(fps=24), dpi=300)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)
error = slayer.loss.SpikeRate(true_rate=0.2, false_rate=0.03, reduction='sum').to(device)
stats = slayer.utils.LearningStats()
assistant = slayer.utils.Assistant(net, error, optimizer, stats, classifier=slayer.classifier.Rate.predict)



In [48]:
epochs = 100

# NOTE: `input`, `label` and `output` are deliberately left as module-level
# names (even though `input` shadows the builtin) because later cells may
# reuse the last batch.
for epoch in range(epochs):
    for i, (input, label) in enumerate(train_loader): # training loop
        output = assistant.train(input, label)
    print(f'\r[Epoch {epoch:2d}/{epochs}] {stats}', end='')

    for i, (input, label) in enumerate(test_loader): # testing loop
        output = assistant.test(input, label)
    print(f'\r[Epoch {epoch:2d}/{epochs}] {stats}', end='')

    if epoch%20 == 19: # cleanup display
        print('\r', ' '*len(f'\r[Epoch {epoch:2d}/{epochs}] {stats}'))
        stats_str = str(stats).replace("| ", "\n")
        print(f'[Epoch {epoch:2d}/{epochs}]\n{stats_str}')

    # Commit this epoch's numbers BEFORE reading `best_accuracy`. The logged
    # output shows "accuracy" above "max", so the maximum is only refreshed in
    # update(); checking the flag first would describe the previous epoch and
    # the checkpoint would lag one epoch behind the best one.
    # All printing happens above because it needs the running numbers.
    stats.update()
    if stats.testing.best_accuracy:
        torch.save(net.state_dict(), trained_folder + '/network.pt')
    stats.save(trained_folder + '/')
    net.grad_flow(trained_folder + '/')

                                                                                                                                                                                            
[Epoch 19/100]
Train loss =     0.35395 (min =     0.34658)     accuracy = 0.99520 (max = 0.99431)  
Test  loss =     0.63297 (min =     0.58552)     accuracy = 0.94452 (max = 0.94879) 
                                                                                                                                                                                            
[Epoch 39/100]
Train loss =     0.30917 (min =     0.31615)     accuracy = 0.99626 (max = 0.99742)  
Test  loss =     0.66170 (min =     0.57850)     accuracy = 0.94346 (max = 0.95235) 
                                                                                                                                                                                            
[Epoch 59/100]
Train loss =     0.30690 (min =     0.29111)  