In [1]:
from pathlib import Path
from os.path import expanduser
from os import path
import numpy as np
import sys
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import models
from torchsummary import summary
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
import os

In [2]:
# --- Dataset sizing ---------------------------------------------------------
max_dataset_size = 5000000   # generation loops run until a dataset reaches this length
waveform_length = 72         # samples per extracted waveform
nb_of_elements = 30000       # forwarded to cnn.GetNoiseIndices
nb_of_datasets = 5           # recordings are named data_1 ... data_<nb_of_datasets>

# --- Compute device ---------------------------------------------------------
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# --- Augmentation (forwarded to the cnn dataset helpers) --------------------
snr_from = 20
snr_to = 100
max_shift = waveform_length // 4  # +- shift
use_horizontal_flip = True
use_vertical_flip = False

# --- Filter settings (forwarded to cnn.FilterSignalUsingButtersWorth) -------
filter_type = 'high'
sampling_rate = 24000
passband = np.array([100], dtype=int)
order = 1


In [3]:
# The project root is taken to be the parent of the current working directory.
root_folder = os.path.dirname(os.getcwd())

# Inputs: recordings and their ground truth live in the same folder.
path_to_recordings = path.join(root_folder, 'data/synthesized')
path_to_ground_truth_data = path.join(root_folder, 'data/synthesized')

# Outputs written by the cells below.
path_to_train_data = path.join(root_folder, 'data/train_data.npy')
path_to_train_labels = path.join(root_folder, 'data/train_labels.npy')
path_to_noise_data = path.join(root_folder, 'data/noise_data.npy')
path_to_mean_std = path.join(root_folder, 'data/mean_std.npy')

# Make the project root importable so the local custom_resnet package resolves.
sys.path.append(root_folder)
from custom_resnet import CustomResnet as cnn

# Transforms handed to the dataset helpers: Butterworth filter, then z-score
# normalization (in this order).
transform_list = [
    cnn.FilterSignalUsingButtersWorth(filter_type, sampling_rate, passband, order),
    cnn.OptimizedZScoreNormalizaton(),
]

  

In [None]:
# Build the spike dataset by cycling through the synthesized recordings
# (data_1 ... data_<nb_of_datasets>) until at least max_dataset_size samples
# have been collected.
#
# The per-recording datasets are gathered in a flat list and concatenated once
# at the end. Wrapping the running ConcatDataset in a new ConcatDataset on every
# pass would nest one level per iteration, and each item lookup would then have
# to recurse through the whole chain.
spike_parts = [cnn.SpikeTrainDataset()]  # empty seed dataset, kept as first element
nb_of_spike_samples = len(spike_parts[0])
counter = 1
while nb_of_spike_samples < max_dataset_size:
    # Wrap around to the first recording once all of them have been used.
    if counter > nb_of_datasets:
        counter = 1
    single_recording_path = path.join(path_to_recordings, 'data_' + str(counter) + '.npy')
    single_recording_ground_truth = path.join(path_to_ground_truth_data, 'gt_' + str(counter) + '.npy')
    # NOTE(review): 600 is a hard-coded fourth argument whose meaning is defined
    # in custom_resnet — consider promoting it to a named constant.
    temp_dataset = cnn.GenerateDataset(
        single_recording_path,
        single_recording_ground_truth,
        waveform_length,
        600,
        snr_from,
        snr_to,
        max_shift,
        use_horizontal_flip,
        use_vertical_flip,
        transform_list,
    )
    spike_parts.append(temp_dataset)
    nb_of_spike_samples += len(temp_dataset)
    counter = counter + 1

# Sample order is identical to chaining the datasets pairwise. If the loop never
# ran, the bare seed dataset is kept instead of being wrapped.
if len(spike_parts) > 1:
    dataset_spikes = torch.utils.data.ConcatDataset(spike_parts)
else:
    dataset_spikes = spike_parts[0]

temp_dataset_len:  0
shift_from:  -18
shift_to:  19
shift_step:  1
shift_indexes:  tensor([-18, -17, -16, -15, -14, -13, -12, -11, -10,  -9,  -8,  -7,  -6,  -5,
         -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
         10,  11,  12,  13,  14,  15,  16,  17,  18], dtype=torch.int32)
snr_ratio:  59.20295199155965
flip_data_horz:  0
[ 0.98707844 -0.98707844] [ 1.         -0.97415687]
[<custom_resnet.CustomResnet.Awgn object at 0x7fe1c3506630>, <custom_resnet.CustomResnet.FilterSignalUsingButtersWorth object at 0x7fe1d33bf3c8>, <custom_resnet.CustomResnet.OptimizedZScoreNormalizaton object at 0x7fe1d33bf438>]


  temp = temp.new_tensor(data);


dataset len:  444444
temp_dataset_len:  0
shift_from:  -18
shift_to:  19
shift_step:  1
shift_indexes:  tensor([-18, -17, -16, -15, -14, -13, -12, -11, -10,  -9,  -8,  -7,  -6,  -5,
         -4,  -3,  -2,  -1,   0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
         10,  11,  12,  13,  14,  15,  16,  17,  18], dtype=torch.int32)
snr_ratio:  63.12879912824271
flip_data_horz:  0
[ 0.98707844 -0.98707844] [ 1.         -0.97415687]
[<custom_resnet.CustomResnet.Awgn object at 0x7fe1c3526550>, <custom_resnet.CustomResnet.FilterSignalUsingButtersWorth object at 0x7fe1d33bf3c8>, <custom_resnet.CustomResnet.OptimizedZScoreNormalizaton object at 0x7fe1d33bf438>]


In [None]:
# Build the noise dataset by cycling through the synthesized recordings
# (data_1 ... data_<nb_of_datasets>) until at least max_dataset_size samples
# have been collected.
#
# The per-recording datasets are gathered in a flat list and concatenated once
# at the end. Wrapping the running ConcatDataset in a new ConcatDataset on every
# pass would nest one level per iteration, and each item lookup would then have
# to recurse through the whole chain.
noise_parts = [cnn.SpikeTrainDataset()]  # empty seed dataset, kept as first element
nb_of_noise_samples = len(noise_parts[0])
counter = 1
while nb_of_noise_samples < max_dataset_size:
    # Wrap around to the first recording once all of them have been used.
    if counter > nb_of_datasets:
        counter = 1
    single_recording_path = path.join(path_to_recordings, 'data_' + str(counter) + '.npy')
    single_recording_ground_truth = path.join(path_to_ground_truth_data, 'gt_' + str(counter) + '.npy')
    noise_indices = cnn.GetNoiseIndices(
        single_recording_path,
        single_recording_ground_truth,
        waveform_length,
        nb_of_elements,
        snr_from,
        snr_to,
        max_shift,
        use_horizontal_flip,
        use_vertical_flip,
        transform_list,
    )
    # Append a row of zeros (class 0) below the indices.
    # Assumes noise_indices has shape (1, n) so both rows line up — TODO confirm.
    noise_class = torch.zeros(1, noise_indices.nelement(), dtype=torch.int)
    noise_data = torch.cat((noise_indices, noise_class), 0)
    # The index/class table is written to disk and handed to GenerateDataset as
    # its ground-truth file.
    # NOTE(review): the same file is overwritten on every pass; this is only safe
    # if GenerateDataset reads it eagerly — confirm.
    np.save(path_to_noise_data, noise_data.numpy())
    temp_dataset = cnn.GenerateDataset(single_recording_path, path_to_noise_data, waveform_length, 10)
    noise_parts.append(temp_dataset)
    nb_of_noise_samples += len(temp_dataset)
    counter = counter + 1

# Sample order is identical to chaining the datasets pairwise. If the loop never
# ran, the bare seed dataset is kept instead of being wrapped.
if len(noise_parts) > 1:
    dataset_noise = torch.utils.data.ConcatDataset(noise_parts)
else:
    dataset_noise = noise_parts[0]

In [None]:
# Use the same number of samples from both classes: the smaller of the two lengths.
dataset_size = min(len(dataset_spikes), len(dataset_noise))

In [None]:
# Convert the first dataset_size spike samples to a single numpy array of shape
# (dataset_size, 1, waveform_length).
np_data_spikes = np.zeros((dataset_size, 1, waveform_length))
# Pairing the dataset with range(dataset_size) stops the loop after exactly
# dataset_size samples. This also holds for dataset_size == 0, where an
# `i + 1 == dataset_size` check would never fire and the write would run past
# the end of the (empty) array.
for i, (data, target) in zip(range(dataset_size), dataset_spikes):
    np_data_spikes[i, :] = data.numpy()
    # Coarse progress indicator.
    if i % 10000 == 0:
        print(i)

In [None]:
# Raise the interpreter's recursion limit for the cells that follow.
sys.setrecursionlimit(30000)
#argmax_spikes = np.max(abs(np_data_spikes), axis=2);
#valid_spikes = np.where(argmax_spikes >= 4.5)[0].ravel();
#np_data_spikes = np_data_spikes[valid_spikes, :, :]

# Keep only waveforms whose absolute peak lies in the central window
# [waveform_length/2 - waveform_length/4, waveform_length/2 + waveform_length/4].
# Intended to drop waveforms that contain multiple spikes.
spike_argmax = np.abs(np_data_spikes).argmax(axis=2)
over_treshold = waveform_length // 2 + waveform_length // 4
under_treshold = waveform_length // 2 - waveform_length // 4
is_valid_train_data = np.logical_and(
    spike_argmax >= under_treshold,
    spike_argmax <= over_treshold,
).ravel()
np_data_spikes = np_data_spikes[is_valid_train_data]

# From here on dataset_size is the number of spikes that survived the filter.
dataset_size = len(np_data_spikes)
np_classes_spikes = np.ones(dataset_size)

# Collect as many noise samples as there are spike samples (dataset_size) into a
# numpy array of shape (dataset_size, 1, waveform_length); noise is class 0.
np_data_noise = np.zeros((dataset_size, 1, waveform_length))
np_classes_noise = np.zeros(dataset_size)
# Pairing the dataset with range(dataset_size) stops the loop after exactly
# dataset_size samples. This also holds for dataset_size == 0, where an
# `i + 1 == dataset_size` check would never fire and the write would run past
# the end of the (empty) array. The loop index doubles as the write position.
for i, (data, target) in zip(range(dataset_size), dataset_noise):
    np_data_noise[i, :] = data.numpy()
    # Coarse progress indicator.
    if i % 10000 == 0:
        print(i)
    

    


In [None]:
# Display the shape of the spike array.
np_data_spikes.shape

In [None]:
# Print samples 0..36 (the slice end 37 is exclusive) of the first spike waveform.
print(np_data_spikes[0, :, 0:37])
    

In [None]:
# Print the noise and spike array shapes for a side-by-side comparison.
print(np_data_noise.shape)
print(np_data_spikes.shape)


In [None]:
# Stack spikes first, then noise, along the sample axis.
dataset = np.concatenate([np_data_spikes, np_data_noise], axis=0)

In [None]:
# Labels in the same order as `dataset`: ones for spikes, then zeros for noise.
labels = np.concatenate([np_classes_spikes, np_classes_noise], axis=0)

In [None]:
# Persist the combined samples and their labels as .npy files.
np.save(path_to_train_data, dataset)
np.save(path_to_train_labels, labels)


In [None]:
# Global mean and standard deviation over all spike samples.
# Note: computed from np_data_spikes only, not from the combined dataset.
mean = np_data_spikes.mean()
std = np_data_spikes.std()

# Stored as a two-element sequence: [mean, std].
np.save(path_to_mean_std, [mean, std])

for statistic in (mean, std):
    print(statistic)

In [None]:
# Plot one randomly chosen noise waveform as a visual sanity check.
import matplotlib.pyplot as plt  # also imported in the first cell; kept so this cell runs on its own
fig, ax = plt.subplots()
rnd = np.random.randint(0, np_data_noise.shape[0])
ax.plot(np_data_noise[rnd, 0, :])


In [None]:
# Peak absolute amplitude of every spike waveform.
# Despite the name, this holds the maximum VALUE, not the index of the maximum.
argmax_spikes = np.abs(np_data_spikes).max(axis=2)

In [None]:
# Number of spike waveforms whose peak absolute amplitude is below 4.5
# (first element of the displayed shape tuple).
np.nonzero(argmax_spikes < 4.5)[0].shape

In [None]:
# Display the per-waveform peak absolute amplitudes (maximum values, not indices).
argmax_spikes