In [26]:
# !pip install snntorch
# !pip install tonic

In [94]:
import tonic
from tonic import DiskCachedDataset
import tonic.transforms as transforms
import torch
from torch.utils.data import DataLoader
import numpy as np
from sklearn import linear_model
import time
import lsm_weight_definitions as lsm_wts
import lsm_models

In [95]:
loss_fn = torch.nn.CrossEntropyLoss()

In [116]:
def train_one_epoch(model, epoch_index, training_loader, optimizer, loss_fn):
    running_loss = 0.
    last_loss = 0.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print(device)
    model.to(device)
    print(next(model.parameters()).is_cuda)
    for i, data in enumerate(training_loader):
        inputs, labels = data
        print(labels.shape)
        print(type(inputs))
        print(inputs.shape)
        optimizer.zero_grad()
        inputs = torch.reshape(inputs, (inputs.shape[1], inputs.shape[0], -1)).to('cuda:0')
        print(inputs.shape)
        print(inputs.device)
        outputs = model(inputs)
        # outputs.to(device)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 25 == 24:
            last_loss = running_loss / 1000 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            # tb_x = epoch_index * len(training_loader) + i + 1
            # tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.0

    return last_loss

In [30]:
if __name__ == "__main__":

    #Load dataset (Using NMNIST here)
    sensor_size = tonic.datasets.NMNIST.sensor_size
    frame_transform = transforms.Compose([transforms.Denoise(filter_time=3000),
                                          transforms.ToFrame(sensor_size=sensor_size,time_window=1000)])

    trainset = tonic.datasets.NMNIST(save_to='./data', transform=frame_transform, train=True)
    testset = tonic.datasets.NMNIST(save_to='./data', transform=frame_transform, train=False)

    cached_trainset = DiskCachedDataset(trainset, cache_path='./cache/nmnist/train')
    cached_testset = DiskCachedDataset(testset, cache_path='./cache/nmnist/test')

    batch_size = 309
    trainloader = DataLoader(cached_trainset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors(batch_first=False), shuffle=True)
    testloader = DataLoader(cached_testset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors(batch_first=False))

    #Set device
    #device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
    device = torch.device("cpu") if torch.cuda.is_available() else torch.device("cpu")
    print(device)

    data, targets = next(iter(trainloader))
    flat_data = torch.reshape(data, (data.shape[0], data.shape[1], -1))
    print(flat_data.shape)

    in_sz = flat_data.shape[-1]

    #Set neuron parameters
    tauV = 15.0
    tauI = 15.0
    th = 18
    curr_prefac = np.float32(1/tauI)
    alpha = np.float32(np.exp(-1/tauI))
    beta = np.float32(1 - 1/tauV)

    Win, Wlsm = lsm_wts.initWeights1(25, 1.5, 0.13, in_sz)
    N = Wlsm.shape[0]
    lsm_net = lsm_models.LSM(N, in_sz, np.float32(curr_prefac*Win), np.float32(curr_prefac*Wlsm), alpha=alpha, beta=beta, th=th).to(device)
    lsm_net.eval()
    #Run with no_grad for LSM
    with torch.no_grad():
        start_time = time.time()
        for i, (data, targets) in enumerate(iter(trainloader)):
            if i%25 == 24:
                print("train batches completed: ", i)
            flat_data = torch.reshape(data, (data.shape[0], data.shape[1], -1)).to(device)
            spk_rec = lsm_net(flat_data)
            lsm_out = torch.mean(spk_rec, dim=0)
            if i==0:
                in_train = torch.mean(flat_data, dim=0).cpu().numpy()
                lsm_out_train = lsm_out.cpu().numpy()
                lsm_label_train = np.int32(targets.numpy())
            else:
                in_train = np.concatenate((in_train, torch.mean(flat_data, dim=0).cpu().numpy()), axis=0)
                lsm_out_train = np.concatenate((lsm_out_train, lsm_out.cpu().numpy()), axis=0)
                lsm_label_train = np.concatenate((lsm_label_train, np.int32(targets.numpy())), axis=0)
        end_time = time.time()

        print("running time of training epoch: ", end_time - start_time, "seconds")

        for i, (data, targets) in enumerate(iter(testloader)):
            if i%25 == 24:
                print("test batches completed: ", i)
            flat_data = torch.reshape(data, (data.shape[0], data.shape[1], -1)).to(device)
            lsm_net.eval()
            spk_rec = lsm_net(flat_data)
            lsm_out = torch.mean(spk_rec, dim=0)
            if i==0:
                in_test = torch.mean(flat_data, dim=0).cpu().numpy()
                lsm_out_test = lsm_out.cpu().numpy()
                lsm_label_test = np.int32(targets.numpy())
            else:
                in_test = np.concatenate((in_test, torch.mean(flat_data, dim=0).cpu().numpy()), axis=0)
                lsm_out_test = np.concatenate((lsm_out_test, lsm_out.cpu().numpy()), axis=0)
                lsm_label_test = np.concatenate((lsm_label_test, np.int32(targets.numpy())), axis=0)

    print(lsm_out_train.shape)
    print(lsm_out_test.shape)

    print(in_train.shape)
    print(in_test.shape)

    print("mean in spiking (train) : ", np.mean(in_train))
    print("mean in spiking (test) : ", np.mean(in_test))

    print("mean LSM spiking (train) : ", np.mean(lsm_out_train))
    print("mean LSM spiking (test) : ", np.mean(lsm_out_test))

    print("training linear model:")
    clf = linear_model.SGDClassifier(max_iter=10000, tol=1e-6)
    clf.fit(lsm_out_train, lsm_label_train)

    score = clf.score(lsm_out_test, lsm_label_test)
    print("test score = " + str(score))

cuda
torch.Size([310, 256, 2312])
train batches completed:  24
train batches completed:  49
train batches completed:  74
train batches completed:  99
train batches completed:  124
train batches completed:  149
train batches completed:  174
train batches completed:  199
train batches completed:  224
running time of training epoch:  160.4107174873352 seconds
test batches completed:  24
(60000, 1000)
(10000, 1000)
(60000, 2312)
(10000, 2312)
mean in spiking (train) :  0.004515322
mean in spiking (test) :  0.0045424583
mean LSM spiking (train) :  0.1648524
mean LSM spiking (test) :  0.16819212
training linear model:
test score = 0.9605


In [113]:
#Load dataset (Using NMNIST here)
sensor_size = tonic.datasets.NMNIST.sensor_size
frame_transform = transforms.Compose([transforms.Denoise(filter_time=3000),
                                        transforms.ToFrame(sensor_size=sensor_size,time_window=1000)])

trainset = tonic.datasets.NMNIST(save_to='./data', transform=frame_transform, train=True)
testset = tonic.datasets.NMNIST(save_to='./data', transform=frame_transform, train=False)

cached_trainset = DiskCachedDataset(trainset, cache_path='./cache/nmnist/train')
cached_testset = DiskCachedDataset(testset, cache_path='./cache/nmnist/test')

batch_size = 256
trainloader = DataLoader(cached_trainset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors(batch_first=False), shuffle=True)
testloader = DataLoader(cached_testset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors(batch_first=False))

#Set device
#device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

data, targets = next(iter(trainloader))
flat_data = torch.reshape(data, (data.shape[1], data.shape[0], -1))
print(flat_data.shape)

in_sz = flat_data.shape[-1]

#Set neuron parameters
tauV = 15.0
tauI = 15.0
th = 18
curr_prefac = np.float32(1/tauI)
alpha = np.float32(np.exp(-1/tauI))
beta = np.float32(1 - 1/tauV)

Win, Wlsm = lsm_wts.initWeights1(25, 1.5, 0.13, in_sz)
N = Wlsm.shape[0]
print(N)
lsm_net = lsm_models.LSM(N, in_sz, np.float32(curr_prefac*Win), np.float32(curr_prefac*Wlsm), alpha=alpha, beta=beta, th=th).to(device)

cuda
torch.Size([256, 311, 2312])
1000


In [114]:
optimizer = torch.optim.SGD(lsm_net.parameters(), lr=0.003, momentum=0.85)

In [117]:
import datetime
timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
epoch_number = 0

EPOCHS = 10
lsm_net.to('cpu')
best_vloss = 1_000_000.
loss_arr = []
epoch_ind = []
for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    lsm_net.train(True)
    #train_one_epoch(model, epoch_index, tb_writer, training_loader, optimizer, loss_fn):
    # trainloader.to(device)
    avg_loss = train_one_epoch(lsm_net,epoch_number, trainloader, optimizer, loss_fn)
    loss_arr.append(avg_loss)
    epoch_ind.append(epoch_number+1)
    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    lsm_net.eval()

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i, vdata in enumerate(testloader):
            vinputs, vlabels = vdata
            voutputs = lsm_net(vinputs)
            vloss = loss_fn(voutputs, vlabels)
            running_vloss += vloss

    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    # writer.add_scalars('Training vs. Validation Loss',
                    # { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    # epoch_number + 1)
    # writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(lsm_net.state_dict(), model_path)

    epoch_number += 1

EPOCH 1:
cuda
True
torch.Size([256])
<class 'torch.Tensor'>
torch.Size([311, 256, 2, 34, 34])
torch.Size([256, 311, 2312])
cuda:0


RuntimeError: Expected target size [256, 1000], got [256]

In [None]:
import matplotlib.pyplot as plt
plt.plot(epoch_ind, loss_arr)