In [7]:
import os
import pickle
from collections import defaultdict

import numpy as np
from torch.utils.data import Dataset, Subset
from torch.utils.data import DataLoader

from snn_delays.config import DATASET_PATH

class SubDataset(Dataset):
    """Class-balanced subset: the first ``samples_per_class`` samples of each target class.

    Samples are selected in the base dataset's iteration order, so the selection
    is deterministic for a given base dataset.

    Args:
        base_dataset: dataset yielding ``(data, label)`` pairs. ``label`` is either
            a one-hot / score vector (class = argmax) or a scalar class index.
        samples_per_class: maximum number of samples kept per target class.
        target_classes: sized collection of class indices to keep.
        save_path: file name, joined onto ``DATASET_PATH``, where the selected
            indices are pickled. Pass ``None`` or ``''`` to skip saving.
    """

    def __init__(self, base_dataset, samples_per_class, target_classes, save_path='indexes'):
        self.base_dataset = base_dataset
        self.samples_per_class = samples_per_class
        self.target_classes = target_classes
        self.num_classes = len(target_classes)

        self.indices = self._select_indices(base_dataset, samples_per_class, target_classes)
        self.filtered_dataset = Subset(base_dataset, self.indices)

        # Test the caller's name, not the joined path (which is never empty),
        # so that save_path=None / '' actually disables saving.
        if save_path:
            save_indices_path = os.path.join(DATASET_PATH, save_path)
            with open(save_indices_path, 'wb') as f:
                pickle.dump(self.indices, f)

    @staticmethod
    def _class_index(label):
        """Return the integer class of a vector label (argmax) or a scalar label."""
        label = np.asarray(label)
        if label.ndim == 0:
            # np.argmax of a scalar is always 0, so read scalar labels directly.
            return int(label)
        return int(np.argmax(label))

    def _select_indices(self, base_dataset, samples_per_class, target_classes):
        """Return base-dataset indices of the first ``samples_per_class`` samples per target class."""
        wanted = set(target_classes)
        class_counts = defaultdict(int)
        indices = []
        if samples_per_class <= 0 or not wanted:
            return indices

        classes_left = len(wanted)  # target classes whose quota is not yet met
        for i, (_, label) in enumerate(base_dataset):
            class_idx = self._class_index(label)
            if class_idx not in wanted or class_counts[class_idx] >= samples_per_class:
                continue
            indices.append(i)
            class_counts[class_idx] += 1
            if class_counts[class_idx] == samples_per_class:
                classes_left -= 1
                if classes_left == 0:
                    # Every quota is met: stop before loading (and transforming) more samples.
                    break
        return indices

    def __len__(self):
        """Number of selected samples."""
        return len(self.filtered_dataset)

    def __getitem__(self, idx):
        """Return the ``(data, label)`` pair of the ``idx``-th selected sample."""
        if idx >= len(self.filtered_dataset):
            raise IndexError("Index out of range for SubDataset")

        img, label = self.filtered_dataset[idx]
        return img, label

In [26]:
from snn_delays.snn import SNN
from snn_delays.utils.dataset_loader import DatasetLoader
from snn_delays.utils.train_utils import train, get_device
from snn_delays.utils.test_behavior import tb_save_max_last_acc
import torch

device = get_device()
torch.manual_seed(10)

dataset = 'shd'
total_time = 50
batch_size = 32
NUM_CLASSES = 20        # number of SHD classes kept in the subset
SAMPLES_PER_CLASS = 10  # per-class sample budget for both splits

# DATASET
DL = DatasetLoader(dataset=dataset,
                   caching='memory',
                   num_workers=0,
                   batch_size=batch_size,
                   total_time=total_time,
                   crop_to=1e6)

# Only dataset_dict is used; the full-size loaders are discarded.
# (Don't bind `__`: IPython uses it for output history.)
_, _, dataset_dict = DL.get_dataloaders()

target_classes = list(range(NUM_CLASSES))
full_train_dataset = DL._dataset.train_dataset
full_test_dataset = DL._dataset.test_dataset

# The pickle name is derived from SAMPLES_PER_CLASS so it cannot drift from the real budget.
sub_train_dataset = SubDataset(full_train_dataset, SAMPLES_PER_CLASS, target_classes,
                               f'{dataset}_{SAMPLES_PER_CLASS}_train')
sub_test_dataset = SubDataset(full_test_dataset, SAMPLES_PER_CLASS, target_classes,
                              f'{dataset}_{SAMPLES_PER_CLASS}_test')

# Later cells assume a fully balanced subset (every class reached its quota).
expected_size = NUM_CLASSES * SAMPLES_PER_CLASS
assert len(sub_train_dataset) == expected_size, f'train subset has {len(sub_train_dataset)} samples, expected {expected_size}'
assert len(sub_test_dataset) == expected_size, f'test subset has {len(sub_test_dataset)} samples, expected {expected_size}'

Running on: cuda:0
[CropTime(min=0, max=1000000.0), ToFrame(sensor_size=(700, 1, 1), time_window=None, event_count=None, n_time_bins=50, n_event_bins=None, overlap=0, include_incomplete=False)]


In [27]:
# NOTE(review): MemoryCachedDataset is not imported in any visible cell; it is presumably
# tonic.MemoryCachedDataset (matching caching='memory' above) — confirm and import it
# in the first cell so this cell survives Restart & Run All.
train_dataset = MemoryCachedDataset(sub_train_dataset)
test_dataset = MemoryCachedDataset(sub_test_dataset)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=False,
                          pin_memory=True,
                          num_workers=0)

# Evaluation does not need shuffling; a fixed order keeps test passes comparable.
test_loader = DataLoader(test_dataset,
                         batch_size=batch_size,
                         shuffle=False,
                         drop_last=False,
                         pin_memory=True,
                         num_workers=0)

# Sanity check: one pass over the training loader visits every sample exactly once.
n_seen = 0
for x, y in train_loader:
    print(x.shape)
    n_seen += x.shape[0]
assert n_seen == len(train_dataset), f'{n_seen} samples seen, expected {len(train_dataset)}'

torch.Size([32, 50, 1, 700])
torch.Size([32, 50, 1, 700])
torch.Size([32, 50, 1, 700])
torch.Size([32, 50, 1, 700])
torch.Size([32, 50, 1, 700])
torch.Size([32, 50, 1, 700])
torch.Size([8, 50, 1, 700])


In [28]:
ckpt_dir = 'default'
# Derive the count from the subset instead of hard-coding it
# (previously 200 = 20 classes x 10 samples), so it tracks any change to either.
dataset_dict['num_training_samples'] = len(sub_train_dataset)

snn = SNN(dataset_dict=dataset_dict, structure=(64, 2), connection_type='f',
          delay=(48, 16), delay_type='h', tau_m='normal',
          win=total_time, loss_fn='mem_sum', batch_size=batch_size, device=device,
          debug=False)

snn.set_network()
snn.to(device)

num_epochs = 100
# NOTE(review): the torch seed is only set in the data-prep cell, so re-running this
# cell on its own will not reproduce the logged run.
# NOTE(review): scheduler=(100, 0.95) with num_epochs=100 — if the first element is a
# step size in epochs, the LR decays at most once; confirm against train().
train(snn, train_loader, test_loader, 1e-3, num_epochs, dropout=0.0,
      test_behavior=tb_save_max_last_acc, ckpt_dir=ckpt_dir, scheduler=(100, 0.95), test_every=5)


[INFO] Delays: tensor([ 0, 16, 32])

[INFO] Delays i: tensor([0])

[INFO] Delays h: tensor([ 0, 16, 32])

[INFO] Delays o: tensor([0])
1000.0
Delta t: 20.0 ms
mean of normal: -0.541324854612918
training shd50_l2_48d16.t7 for 100 epochs...
Epoch [1/100], learning_rates 0.001000, 0.100000




Step [2/6], Loss: 3.00106
l1_score: 0
Step [4/6], Loss: 3.10128
l1_score: 0
Step [6/6], Loss: 3.08916
l1_score: 0
Time elasped: 2.447533369064331
Epoch [2/100], learning_rates 0.001000, 0.100000
Step [2/6], Loss: 2.93791
l1_score: 0
Step [4/6], Loss: 2.90847
l1_score: 0
Step [6/6], Loss: 2.89381
l1_score: 0
Time elasped: 2.561569929122925
Epoch [3/100], learning_rates 0.001000, 0.100000
Step [2/6], Loss: 2.79530
l1_score: 0
Step [4/6], Loss: 2.88531
l1_score: 0
Step [6/6], Loss: 2.81564
l1_score: 0
Time elasped: 2.387892007827759
Epoch [4/100], learning_rates 0.001000, 0.100000
Step [2/6], Loss: 2.65524
l1_score: 0
Step [4/6], Loss: 2.57609
l1_score: 0
Step [6/6], Loss: 2.76876
l1_score: 0
Time elasped: 2.4930062294006348
Epoch [5/100], learning_rates 0.001000, 0.100000
Step [2/6], Loss: 2.57873
l1_score: 0
Step [4/6], Loss: 2.69775
l1_score: 0
Step [6/6], Loss: 2.61444
l1_score: 0
Time elasped: 2.6294474601745605
Test Loss: 2.561345168522426
Avg spk_count per neuron for all 50 time-st

SSC (Spiking Speech Commands): same pipeline as above, 35 classes × 10 samples per class

In [32]:
from snn_delays.snn import SNN
from snn_delays.utils.dataset_loader import DatasetLoader
from snn_delays.utils.train_utils import train, get_device
from snn_delays.utils.test_behavior import tb_save_max_last_acc
import torch

device = get_device()
torch.manual_seed(10)

dataset = 'ssc'
total_time = 50
batch_size = 32
NUM_CLASSES = 35        # number of SSC classes kept in the subset
SAMPLES_PER_CLASS = 10  # per-class sample budget for both splits

# DATASET
DL = DatasetLoader(dataset=dataset,
                   caching='memory',
                   num_workers=0,
                   batch_size=batch_size,
                   total_time=total_time,
                   crop_to=1e6)

# Only dataset_dict is used; the full-size loaders are discarded.
# (Don't bind `__`: IPython uses it for output history.)
_, _, dataset_dict = DL.get_dataloaders()

target_classes = list(range(NUM_CLASSES))
full_train_dataset = DL._dataset.train_dataset
full_test_dataset = DL._dataset.test_dataset

# The pickle name is derived from SAMPLES_PER_CLASS so it cannot drift from the real budget.
sub_train_dataset = SubDataset(full_train_dataset, SAMPLES_PER_CLASS, target_classes,
                               f'{dataset}_{SAMPLES_PER_CLASS}_train')
sub_test_dataset = SubDataset(full_test_dataset, SAMPLES_PER_CLASS, target_classes,
                              f'{dataset}_{SAMPLES_PER_CLASS}_test')

# Later cells assume a fully balanced subset (every class reached its quota).
expected_size = NUM_CLASSES * SAMPLES_PER_CLASS
assert len(sub_train_dataset) == expected_size, f'train subset has {len(sub_train_dataset)} samples, expected {expected_size}'
assert len(sub_test_dataset) == expected_size, f'test subset has {len(sub_test_dataset)} samples, expected {expected_size}'

Running on: cuda:0
[CropTime(min=0, max=1000000.0), ToFrame(sensor_size=(700, 1, 1), time_window=None, event_count=None, n_time_bins=50, n_event_bins=None, overlap=0, include_incomplete=False)]


In [33]:
# NOTE(review): MemoryCachedDataset is not imported in any visible cell; it is presumably
# tonic.MemoryCachedDataset (matching caching='memory' above) — confirm and import it
# in the first cell so this cell survives Restart & Run All.
train_dataset = MemoryCachedDataset(sub_train_dataset)
test_dataset = MemoryCachedDataset(sub_test_dataset)

train_loader = DataLoader(train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=False,
                          pin_memory=True,
                          num_workers=0)

# Evaluation does not need shuffling; a fixed order keeps test passes comparable.
test_loader = DataLoader(test_dataset,
                         batch_size=batch_size,
                         shuffle=False,
                         drop_last=False,
                         pin_memory=True,
                         num_workers=0)

# Sanity check: one pass over the training loader visits every sample exactly once.
n_seen = 0
for x, y in train_loader:
    print(x.shape)
    n_seen += x.shape[0]
assert n_seen == len(train_dataset), f'{n_seen} samples seen, expected {len(train_dataset)}'

torch.Size([32, 50, 1, 700])
torch.Size([32, 50, 1, 700])
torch.Size([32, 50, 1, 700])
torch.Size([32, 50, 1, 700])
torch.Size([32, 50, 1, 700])
torch.Size([32, 50, 1, 700])
torch.Size([32, 50, 1, 700])
torch.Size([32, 50, 1, 700])
torch.Size([32, 50, 1, 700])
torch.Size([32, 50, 1, 700])
torch.Size([30, 50, 1, 700])


In [34]:
ckpt_dir = 'default'
# Derive the count from the subset instead of hard-coding it
# (previously 350 = 35 classes x 10 samples), so it tracks any change to either.
dataset_dict['num_training_samples'] = len(sub_train_dataset)

snn = SNN(dataset_dict=dataset_dict, structure=(64, 2), connection_type='f',
          delay=(48, 16), delay_type='h', tau_m='normal',
          win=total_time, loss_fn='mem_sum', batch_size=batch_size, device=device,
          debug=False)

snn.set_network()
snn.to(device)

num_epochs = 100
# NOTE(review): the torch seed is only set in the data-prep cell, so re-running this
# cell on its own will not reproduce the logged run.
# NOTE(review): scheduler=(100, 0.95) with num_epochs=100 — if the first element is a
# step size in epochs, the LR decays at most once; confirm against train().
train(snn, train_loader, test_loader, 1e-3, num_epochs, dropout=0.0,
      test_behavior=tb_save_max_last_acc, ckpt_dir=ckpt_dir, scheduler=(100, 0.95), test_every=5)


[INFO] Delays: tensor([ 0, 16, 32])

[INFO] Delays i: tensor([0])

[INFO] Delays h: tensor([ 0, 16, 32])

[INFO] Delays o: tensor([0])
1000.0
Delta t: 20.0 ms
mean of normal: -0.541324854612918
training ssc50_l2_48d16.t7 for 100 epochs...
Epoch [1/100], learning_rates 0.001000, 0.100000




Step [3/10], Loss: 3.57200
l1_score: 0
Step [6/10], Loss: 3.59528
l1_score: 0
Step [9/10], Loss: 3.55838
l1_score: 0
Time elasped: 2.525960922241211
Epoch [2/100], learning_rates 0.001000, 0.100000
Step [3/10], Loss: 3.51128
l1_score: 0
Step [6/10], Loss: 3.54426
l1_score: 0
Step [9/10], Loss: 3.54864
l1_score: 0
Time elasped: 2.2742600440979004
Epoch [3/100], learning_rates 0.001000, 0.100000
Step [3/10], Loss: 3.50547
l1_score: 0
Step [6/10], Loss: 3.59255
l1_score: 0
Step [9/10], Loss: 3.49924
l1_score: 0
Time elasped: 2.3037807941436768
Epoch [4/100], learning_rates 0.001000, 0.100000
Step [3/10], Loss: 3.43836
l1_score: 0
Step [6/10], Loss: 3.52880
l1_score: 0
Step [9/10], Loss: 3.53652
l1_score: 0
Time elasped: 2.3182871341705322
Epoch [5/100], learning_rates 0.001000, 0.100000
Step [3/10], Loss: 3.35322
l1_score: 0
Step [6/10], Loss: 3.46812
l1_score: 0
Step [9/10], Loss: 3.45658
l1_score: 0
Time elasped: 2.482037305831909
Test Loss: 3.507571805607189
Avg spk_count per neuron fo