In [1]:
from pathlib import Path
from os.path import expanduser
from os import path
import numpy as np
import sys
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import models
from torchsummary import summary
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
import os

In [2]:
max_dataset_size = 2500000;
waveform_length = 72;
nb_of_elements = 150000;
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
nb_of_datasets = 45;
snr_from = 20;
snr_to = 100;
max_shift = 0;
use_horizontal_flip = False;
use_vertical_flip = False;
filter_type = 'high';
sampling_rate = 24000;
passband = np.array([100], dtype=int);
order = 1;



In [3]:
root_folder = os.path.dirname(os.getcwd())
path_to_recordings = path.join(root_folder,'data/synthesized');
path_to_ground_truth_data = path.join(root_folder,'data/synthesized');
path_to_train_data = path.join(root_folder,'data/train_data_re_id.npy');
path_to_train_labels = path.join(root_folder,'data/train_labels_re_id.npy');
sys.path.append(root_folder)
from custom_resnet import CustomResnet as cnn
transform_list =[cnn.FilterSignalUsingButtersWorth(filter_type, sampling_rate, passband, order), cnn.OptimizedZScoreNormalizaton()];



In [4]:
# counts each class
classes_per_recording = np.zeros((nb_of_datasets), dtype='int')
for i in range(0 , nb_of_datasets):
    single_recording_ground_truth = path.join(path_to_ground_truth_data, 'gt_' + str(i + 1) + '.npy');
    gt_data = np.load(single_recording_ground_truth);
    classes_per_recording[i] = np.unique(gt_data[1,:]).size;
    print(np.unique(gt_data[1,:]))
    print(np.min(gt_data))
print(classes_per_recording)

# class counter to concatenate classes together from different recordings
class_counter = np.zeros((nb_of_datasets, 1), dtype='int');
class_counter[0] = 0;
for i in range(1, nb_of_datasets):
    class_counter[i] = class_counter[i-1] + classes_per_recording[i - 1] - 1; 
print(class_counter)



[ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14. 15. 16.]
0.0
[ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14. 15. 16. 17.
 18. 19.]
0.0
[ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12.]
0.0
[0. 1. 2. 3. 4.]
0.0
[ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14. 15.]
0.0
[ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11.]
0.0
[ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14. 15. 16. 17.]
0.0
[0. 1. 2.]
0.0
[ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14. 15. 16. 17.
 18. 19.]
0.0
[ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14. 15. 16. 17.
 18. 19. 20.]
0.0
[ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14. 15. 16. 17.
 18. 19. 20.]
0.0
[ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14. 15. 16. 17.
 18. 19. 20.]
0.0
[ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10.]
0.0
[0. 1. 2. 3.]
0.0
[0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]
0.0
[0. 1. 2. 3. 4. 5. 6. 7. 8.]
0.0
[0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]
0.0
[0. 1. 2. 

In [None]:
# couts sample count for each class
nb_of_classes = int(np.sum(classes_per_recording - 1) + 1);


In [None]:
print(classes_per_recording)
print(class_counter)

[17 20 13  5 16 12 18  3 20 21 21 21 11  4 10  9 10  8 19 15  5  7 18  7
 17 14  6 15  4  6 12 11  5  9 13 19 11 14  3  8 16 10 17 14  9]
[[  0]
 [ 16]
 [ 35]
 [ 47]
 [ 51]
 [ 66]
 [ 77]
 [ 94]
 [ 96]
 [115]
 [135]
 [155]
 [175]
 [185]
 [188]
 [197]
 [205]
 [214]
 [221]
 [239]
 [253]
 [257]
 [263]
 [280]
 [286]
 [302]
 [315]
 [320]
 [334]
 [337]
 [342]
 [353]
 [363]
 [367]
 [375]
 [387]
 [405]
 [415]
 [428]
 [430]
 [437]
 [452]
 [461]
 [477]
 [490]]


In [None]:
16+19+12

47

In [None]:
max_dataset_size = max_dataset_size + nb_of_classes - max_dataset_size % nb_of_classes;
samples_per_class = max_dataset_size // nb_of_classes;
print(samples_per_class)
print(max_dataset_size / nb_of_classes)
print(max_dataset_size)


5011
5011.0
2500489


In [None]:
dataset = np.empty((1 , 1 ,waveform_length))
labels = np.empty((1, 0), dtype = "int")
# adds single unit activity data
for i in range(0, nb_of_datasets):
    added_sample_counter = np.zeros((classes_per_recording[i] - 1));
    data_iter_counter = 0;
    total_nb_of_samples = int((classes_per_recording[i] - 1) * samples_per_class);
    temp_spikes = np.zeros((total_nb_of_samples, 1, waveform_length));
    temp_labels = np.zeros((1, total_nb_of_samples), dtype='int');
    print(dataset.shape)
    # generates augmentations until there are same sized samples for each class
    while (data_iter_counter < total_nb_of_samples):
        single_recording_path = path.join(path_to_recordings, 'data_' + str(i + 1) + '.npy');
        single_recording_ground_truth = path.join(path_to_ground_truth_data, 'gt_' + str(i + 1) + '.npy');
        dataset_spikes = cnn.GenerateDataset(single_recording_path, single_recording_ground_truth, waveform_length, 600, snr_from, snr_to, max_shift, use_horizontal_flip, use_vertical_flip, transform_list);
        # adds to temp variable to add to real dataset afterwards
        for j, (data, target) in enumerate(dataset_spikes):
            neuron_index = target.item();
            if(neuron_index > 0 and added_sample_counter[neuron_index - 1] < samples_per_class):
                spike_waveform =  (data.numpy());
                temp_spikes[data_iter_counter, :] = spike_waveform;
                class_index = class_counter[i] + neuron_index;
                # changes class index
                temp_labels[0, data_iter_counter] = class_index;
                added_sample_counter[neuron_index - 1] = added_sample_counter[neuron_index - 1] + 1;
                data_iter_counter = data_iter_counter + 1;

    print(added_sample_counter)
    dataset = np.concatenate((dataset, temp_spikes), axis=0)            
    labels = np.concatenate((labels, temp_labels), axis=1)            

    
# adds multi unit activity data
data_iter_counter = 0;
total_nb_of_samples = int(samples_per_class);
temp_spikes = np.zeros((total_nb_of_samples, 1, waveform_length));
temp_labels = np.zeros((1, total_nb_of_samples), dtype='int');
while (data_iter_counter < total_nb_of_samples):
    for i in range(0, nb_of_datasets):
        single_recording_path = path.join(path_to_recordings, 'data_' + str(i + 1) + '.npy');
        single_recording_ground_truth = path.join(path_to_ground_truth_data, 'gt_' + str(i + 1) + '.npy');
        dataset_spikes = cnn.GenerateDataset(single_recording_path, single_recording_ground_truth, waveform_length, 600, snr_from, snr_to, max_shift, use_horizontal_flip, use_vertical_flip, transform_list);
        # adds to temp variable to add to real dataset afterwards
        for j, (data, target) in enumerate(dataset_spikes):
            neuron_index = target.item();
            if(neuron_index == 0 and data_iter_counter < total_nb_of_samples):
                spike_waveform =  (data.numpy());
                temp_spikes[data_iter_counter, :] = spike_waveform;
                class_index = neuron_index;
                # changes class index
                temp_labels[0, data_iter_counter] = class_index;
                data_iter_counter = data_iter_counter + 1;
dataset = np.concatenate((dataset, temp_spikes), axis=0)            
labels = np.concatenate((labels, temp_labels), axis=1)            


(1, 1, 72)
temp_dataset_len:  0
shift_from:  0
shift_to:  1
shift_step:  1
shift_indexes:  tensor([0], dtype=torch.int32)
snr_ratio:  39.42016185337235
flip_data_horz:  0
[ 0.98707844 -0.98707844] [ 1.         -0.97415687]
[<custom_resnet.CustomResnet.Awgn object at 0x7f9e3553a048>, <custom_resnet.CustomResnet.FilterSignalUsingButtersWorth object at 0x7f9e453c0128>, <custom_resnet.CustomResnet.OptimizedZScoreNormalizaton object at 0x7f9e453c0080>]
tensor(16, dtype=torch.int32)


  temp = temp.new_tensor(data);


dataset len:  12008
temp_dataset_len:  0
shift_from:  0
shift_to:  1
shift_step:  1
shift_indexes:  tensor([0], dtype=torch.int32)
snr_ratio:  39.98072087099692
flip_data_horz:  0
[ 0.98707844 -0.98707844] [ 1.         -0.97415687]
[<custom_resnet.CustomResnet.Awgn object at 0x7f9e3553a160>, <custom_resnet.CustomResnet.FilterSignalUsingButtersWorth object at 0x7f9e453c0128>, <custom_resnet.CustomResnet.OptimizedZScoreNormalizaton object at 0x7f9e453c0080>]
tensor(16, dtype=torch.int32)
dataset len:  12008
temp_dataset_len:  0
shift_from:  0
shift_to:  1
shift_step:  1
shift_indexes:  tensor([0], dtype=torch.int32)
snr_ratio:  58.617973319546145
flip_data_horz:  0
[ 0.98707844 -0.98707844] [ 1.         -0.97415687]
[<custom_resnet.CustomResnet.Awgn object at 0x7f9e3553a160>, <custom_resnet.CustomResnet.FilterSignalUsingButtersWorth object at 0x7f9e453c0128>, <custom_resnet.CustomResnet.OptimizedZScoreNormalizaton object at 0x7f9e453c0080>]
tensor(16, dtype=torch.int32)
dataset len:  120

dataset len:  12008
temp_dataset_len:  0
shift_from:  0
shift_to:  1
shift_step:  1
shift_indexes:  tensor([0], dtype=torch.int32)
snr_ratio:  37.37168887935474
flip_data_horz:  0
[ 0.98707844 -0.98707844] [ 1.         -0.97415687]
[<custom_resnet.CustomResnet.Awgn object at 0x7f9e3553afd0>, <custom_resnet.CustomResnet.FilterSignalUsingButtersWorth object at 0x7f9e453c0128>, <custom_resnet.CustomResnet.OptimizedZScoreNormalizaton object at 0x7f9e453c0080>]
tensor(16, dtype=torch.int32)
dataset len:  12008
temp_dataset_len:  0
shift_from:  0
shift_to:  1
shift_step:  1
shift_indexes:  tensor([0], dtype=torch.int32)
snr_ratio:  86.17663169552628
flip_data_horz:  0
[ 0.98707844 -0.98707844] [ 1.         -0.97415687]
[<custom_resnet.CustomResnet.Awgn object at 0x7f9e3553a128>, <custom_resnet.CustomResnet.FilterSignalUsingButtersWorth object at 0x7f9e453c0128>, <custom_resnet.CustomResnet.OptimizedZScoreNormalizaton object at 0x7f9e453c0080>]
tensor(16, dtype=torch.int32)
dataset len:  1200

dataset len:  12008
temp_dataset_len:  0
shift_from:  0
shift_to:  1
shift_step:  1
shift_indexes:  tensor([0], dtype=torch.int32)
snr_ratio:  97.0942939959379
flip_data_horz:  0
[ 0.98707844 -0.98707844] [ 1.         -0.97415687]
[<custom_resnet.CustomResnet.Awgn object at 0x7f9e306920f0>, <custom_resnet.CustomResnet.FilterSignalUsingButtersWorth object at 0x7f9e453c0128>, <custom_resnet.CustomResnet.OptimizedZScoreNormalizaton object at 0x7f9e453c0080>]
tensor(16, dtype=torch.int32)
dataset len:  12008
temp_dataset_len:  0
shift_from:  0
shift_to:  1
shift_step:  1
shift_indexes:  tensor([0], dtype=torch.int32)
snr_ratio:  23.76144001237021
flip_data_horz:  0
[ 0.98707844 -0.98707844] [ 1.         -0.97415687]
[<custom_resnet.CustomResnet.Awgn object at 0x7f9e30692160>, <custom_resnet.CustomResnet.FilterSignalUsingButtersWorth object at 0x7f9e453c0128>, <custom_resnet.CustomResnet.OptimizedZScoreNormalizaton object at 0x7f9e453c0080>]
tensor(16, dtype=torch.int32)
dataset len:  12008

[ 0.98707844 -0.98707844] [ 1.         -0.97415687]
[<custom_resnet.CustomResnet.Awgn object at 0x7f9e30692198>, <custom_resnet.CustomResnet.FilterSignalUsingButtersWorth object at 0x7f9e453c0128>, <custom_resnet.CustomResnet.OptimizedZScoreNormalizaton object at 0x7f9e453c0080>]
tensor(19, dtype=torch.int32)
dataset len:  12778
temp_dataset_len:  0
shift_from:  0
shift_to:  1
shift_step:  1
shift_indexes:  tensor([0], dtype=torch.int32)
snr_ratio:  87.68040094829745
flip_data_horz:  0
[ 0.98707844 -0.98707844] [ 1.         -0.97415687]
[<custom_resnet.CustomResnet.Awgn object at 0x7f9e30692240>, <custom_resnet.CustomResnet.FilterSignalUsingButtersWorth object at 0x7f9e453c0128>, <custom_resnet.CustomResnet.OptimizedZScoreNormalizaton object at 0x7f9e453c0080>]
tensor(19, dtype=torch.int32)
dataset len:  12778
temp_dataset_len:  0
shift_from:  0
shift_to:  1
shift_step:  1
shift_indexes:  tensor([0], dtype=torch.int32)
snr_ratio:  61.19593940365579
flip_data_horz:  0
[ 0.98707844 -0.98

dataset len:  12778
temp_dataset_len:  0
shift_from:  0
shift_to:  1
shift_step:  1
shift_indexes:  tensor([0], dtype=torch.int32)
snr_ratio:  54.72823749881662
flip_data_horz:  0
[ 0.98707844 -0.98707844] [ 1.         -0.97415687]
[<custom_resnet.CustomResnet.Awgn object at 0x7f9e30692198>, <custom_resnet.CustomResnet.FilterSignalUsingButtersWorth object at 0x7f9e453c0128>, <custom_resnet.CustomResnet.OptimizedZScoreNormalizaton object at 0x7f9e453c0080>]
tensor(19, dtype=torch.int32)
dataset len:  12778
temp_dataset_len:  0
shift_from:  0
shift_to:  1
shift_step:  1
shift_indexes:  tensor([0], dtype=torch.int32)
snr_ratio:  31.997593318245226
flip_data_horz:  0
[ 0.98707844 -0.98707844] [ 1.         -0.97415687]
[<custom_resnet.CustomResnet.Awgn object at 0x7f9e30692198>, <custom_resnet.CustomResnet.FilterSignalUsingButtersWorth object at 0x7f9e453c0128>, <custom_resnet.CustomResnet.OptimizedZScoreNormalizaton object at 0x7f9e453c0080>]
tensor(19, dtype=torch.int32)
dataset len:  127

dataset len:  12778
temp_dataset_len:  0
shift_from:  0
shift_to:  1
shift_step:  1
shift_indexes:  tensor([0], dtype=torch.int32)
snr_ratio:  29.315784375678366
flip_data_horz:  0
[ 0.98707844 -0.98707844] [ 1.         -0.97415687]
[<custom_resnet.CustomResnet.Awgn object at 0x7f9e30692198>, <custom_resnet.CustomResnet.FilterSignalUsingButtersWorth object at 0x7f9e453c0128>, <custom_resnet.CustomResnet.OptimizedZScoreNormalizaton object at 0x7f9e453c0080>]
tensor(19, dtype=torch.int32)
dataset len:  12778
temp_dataset_len:  0
shift_from:  0
shift_to:  1
shift_step:  1
shift_indexes:  tensor([0], dtype=torch.int32)
snr_ratio:  55.662361616400275
flip_data_horz:  0
[ 0.98707844 -0.98707844] [ 1.         -0.97415687]
[<custom_resnet.CustomResnet.Awgn object at 0x7f9e30692198>, <custom_resnet.CustomResnet.FilterSignalUsingButtersWorth object at 0x7f9e453c0128>, <custom_resnet.CustomResnet.OptimizedZScoreNormalizaton object at 0x7f9e453c0080>]
tensor(19, dtype=torch.int32)
dataset len:  12

In [None]:
np.min(labels)

In [None]:
class_sample_count = np.unique(labels, return_counts=True)[1]
class_sample_count

In [None]:
np.save(path_to_train_data, dataset)
np.save(path_to_train_labels, labels)


In [None]:
mean = (np.mean(np_data_spikes))
std = (np.std(np_data_spikes))
np.save(path_to_mean_std, [mean, std])

print(mean)
print(std)