In [1]:
from pathlib import Path
from os.path import expanduser
from os import path
import numpy as np
import sys
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import models
from torchsummary import summary
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
import os


In [2]:
root_folder = os.path.dirname(os.getcwd())
path_to_ground_truth_data = path.join(root_folder,'data/synthesized/gt_85.npy');
path_to_recording= path.join(root_folder,'data/synthesized/data_85.npy');

sys.path.append(root_folder)
from custom_resnet import CustomResnet as cnn


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

waveform_length = 72;
step_size = int(np.round(waveform_length * 0.1))
batch_size = 15000;


In [3]:
# loads spike detection model
num_classes_detect = 2;
path_to_model_detect = path.join(root_folder, 'models/detect/resnet18_num_classes_2_epoch_24.pt')
model_detect = cnn.resnet18(num_classes=num_classes_detect);
checkpoint = torch.load(path_to_model_detect)
model_detect.load_state_dict(checkpoint['model_state_dict'])
model_detect.to(device);


In [4]:
# loads re id model
num_classes_re_id = 499;
path_to_model_re_id = path.join(root_folder, 'models/re_id/resnet18_num_classes_499_epoch_74.pt')
model_re_id = cnn.ft_net(class_num = num_classes_re_id);
checkpoint = torch.load(path_to_model_re_id)
model_re_id.load_state_dict(checkpoint['model_state_dict'])
model_re_id.to(device);


In [5]:
# transforms data for inference
transform = transforms.Compose([cnn.FilterSignalUsingButtersWorth('high', 24000, np.array([100], dtype=int), 1), cnn.OptimizedZScoreNormalizaton()])
recording = cnn.Recording(path_to_recording, transform = transform);
recording = cnn.AddPaddingToRecording(recording, step_size, waveform_length)
# gets waveform data for each window
waveform_indices = cnn.GetWaveformIndices(recording, step_size, waveform_length)
transform = transforms.Compose([cnn.ExtractWaveforms(waveform_indices, waveform_length)])
waveforms = transform(recording.data);
normalized_waveforms = waveforms;


[ 0.98707844 -0.98707844] [ 1.         -0.97415687]


In [6]:
### spike detection pipeline
# loads data for inference
dataset_to_infer_detect = cnn.InferenceDataset(normalized_waveforms.to(device));
data_loader_detect = torch.utils.data.DataLoader(dataset_to_infer_detect, batch_size=batch_size, shuffle=False, num_workers=0);

#inference for detection 
result_detect = cnn.Inference(model_detect, data_loader_detect, num_classes_detect)
treshold = 0.9;
predictions = cnn.PredictionByTreshold(result_detect, treshold)
max_amplitude_index = cnn.GetNonOverlappingSpikesMaxAmplitude(recording, predictions, waveform_length, waveforms, waveform_indices)


torch.cuda.FloatTensor
done:  0  batch
done:  1  batch
done:  2  batch
done:  3  batch
done:  4  batch
done:  5  batch
done:  6  batch
done:  7  batch
done:  8  batch
done:  9  batch
done:  10  batch
done:  11  batch
done:  12  batch
done:  13  batch
done:  14  batch
done:  15  batch
done:  16  batch
done:  17  batch
done:  18  batch
done:  19  batch
done:  20  batch
done:  21  batch
done:  22  batch
done:  23  batch
done:  24  batch
done:  25  batch
done:  26  batch
done:  27  batch
done:  28  batch
done:  29  batch
done:  30  batch
done:  31  batch
done:  32  batch
done:  33  batch
done:  34  batch
done:  35  batch
done:  36  batch
done:  37  batch
done:  38  batch
done:  39  batch
done:  40  batch
done:  41  batch
done:  42  batch
done:  43  batch
done:  44  batch
done:  45  batch
done:  46  batch
done:  47  batch
done:  48  batch
done:  49  batch
done:  50  batch
done:  51  batch
done:  52  batch
done:  53  batch
done:  54  batch
done:  55  batch
done:  56  batch
done:  57  batch
d

In [7]:
del dataset_to_infer_detect, data_loader_detect
torch.cuda.empty_cache()



In [8]:

max_amplitude_index = torch.unsqueeze(torch.tensor(max_amplitude_index), 0);


In [9]:
max_amplitude_index.shape

torch.Size([1, 15154])

In [10]:
# spike classficiation pipeline
transform = transforms.Compose([cnn.ExtractWaveforms(max_amplitude_index, waveform_length)])
waveforms_re_id = transform(recording.data);
#vertical flip
waveforms_re_id_flipped = waveforms_re_id * -1;
# data loading for inference
dataset_to_re_id = cnn.InferenceDataset(waveforms_re_id.to(device));
data_loader_re_id = torch.utils.data.DataLoader(dataset_to_re_id, batch_size=batch_size, shuffle=False, num_workers=0);
dataset_to_re_id_flipped = cnn.InferenceDataset(waveforms_re_id_flipped.to(device));
data_loader_re_id_flipped = torch.utils.data.DataLoader(dataset_to_re_id_flipped, batch_size=batch_size, shuffle=False, num_workers=0);


torch.cuda.FloatTensor
torch.cuda.FloatTensor


In [11]:
# exracts features
feature_map_dims = 128
features = torch.FloatTensor(0 , 128).zero_();
model_re_id.model.fc.classifier = nn.Sequential()
result_re_id = cnn.Inference(model_re_id, data_loader_re_id, feature_map_dims)
result_re_id_flipped = cnn.Inference(model_re_id, data_loader_re_id_flipped, feature_map_dims)
features = result_re_id + result_re_id_flipped ;
temp = torch.norm(features, p=2, dim=1, keepdim=True)
unit_features = features.div(temp.expand_as(features))


done:  0  batch
done:  1  batch
done:  0  batch
done:  1  batch


In [12]:
del dataset_to_re_id, data_loader_re_id, dataset_to_re_id_flipped, data_loader_re_id_flipped
torch.cuda.empty_cache()



In [13]:
treshold = 0.6;
feature_list = torch.unsqueeze(unit_features[0, :], 0).to(device)
ids_list = torch.tensor([1], dtype=torch.int)
for i in range(1,unit_features.size()[0]):
    query = torch.unsqueeze(unit_features[i, :], 0).to(device);
    query_raveled = query.view(-1,1);
    score = torch.mm(feature_list, query_raveled);
    max_val, max_ind = torch.max(score, 0);
    # spike is already in list
    if(max_val.item() > treshold):
        ids_list = torch.cat((ids_list, torch.tensor([ids_list[max_ind.item()]], dtype = torch.int)), 0);
    # adds a new spike id
    else:
        new_id = torch.max(ids_list, 0)[0] + 1;
        ids_list = torch.cat((ids_list, torch.tensor([new_id.item()], dtype = torch.int)), 0);

    feature_list = torch.cat((feature_list, query), 0);


In [14]:
ids_list[0:100]

tensor([ 1,  2,  3,  4,  5,  6,  7,  6,  8,  9, 10,  6,  3, 10,  2,  3,  1,  4,
         3,  4,  3, 11, 12,  4,  8, 13, 14,  3,  8, 10,  3, 15, 16, 17, 10,  3,
         3, 10,  6,  4, 18, 13,  3, 19, 15,  9, 13, 18, 20, 15,  3, 11, 13,  6,
         8,  3, 21,  3, 10,  9, 16, 15,  8, 19, 10, 20,  3,  3, 15, 11,  1, 15,
        10, 22,  3,  3, 23, 11,  3,  3, 18, 17, 21,  8, 18,  3, 16, 13, 16,  3,
        10, 15,  3, 13, 24, 15,  3,  4,  3, 22], dtype=torch.int32)

In [15]:
temp =np.load(path_to_ground_truth_data)


In [16]:
unique_classes, nb_of_occourences = np.unique(ids_list, return_counts=True);
print(unique_classes)
print(nb_of_occourences)

[ 1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
 25 26 27 28 29 30 31 32 33]
[ 687    5 3477  980  107 1087    2  929  634 1039  487    2  465   10
  994  172  210  888   41  554  731  645    4    6    1  220    7    1
  195  568    2    3    1]


In [17]:
unique_classes_gt, nb_of_occourences_gt = np.unique(temp[1, :], return_counts=True);
print(unique_classes)
print(nb_of_occourences)

[ 1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
 25 26 27 28 29 30 31 32 33]
[ 687    5 3477  980  107 1087    2  929  634 1039  487    2  465   10
  994  172  210  888   41  554  731  645    4    6    1  220    7    1
  195  568    2    3    1]


In [18]:
found_spike_index = [];
found = [];
for i in range(max_amplitude_index.size()[1]):
    spike = max_amplitude_index[0, i];
    ind =np.where(spike.numpy() == temp[0, :].astype(int))
    if(len(ind[0]) == 0):
        found.append(0);
        found_spike_index.append(0)

    else:
        found.append(1)
        found_spike_index.append(int(temp[1, ind[0]]) + 1)


In [19]:
found = np.asarray(found)
found_spike_index = np.asarray(found_spike_index)

In [20]:
tp = np.where(found == 1)[0].size
fp = ntotalp.where(found == 0)[0].size
print(tp / (tp+fp))
print(fp)

0.9196911706480138
1217


In [21]:
unique_classes, nb_of_occourences = np.unique(found_spike_index, return_counts=True);


In [22]:
unique_classes.size

20

In [23]:
unique_classes, nb_of_occourences = np.unique(found_spike_index, return_counts=True);

spike_classes = np.empty((unique_classes.size,),dtype=object)
for i in range(spike_classes.size):
    spike_classes[i] = [];
for i in range(found_spike_index.size):
    index = found_spike_index[i];
    index_re_id = ids_list[i]
    if (index == 0):
        spike_classes[index].append(index_re_id.item());
    else:
        spike_classes[index].append(index_re_id.item())
        
        

In [24]:
tp = 0;
total = 0;
for i in range(1, spike_classes.size):
    unique_classes, nb_of_occourences = np.unique(spike_classes[i], return_counts=True);
    tp = tp + np.max(nb_of_occourences);
    total = total + np.sum(nb_of_occourences);
    print("max occ: ", np.max(nb_of_occourences), "total: ", np.sum(nb_of_occourences))
    

max occ:  2489 total:  2497
max occ:  209 total:  262
max occ:  483 total:  483
max occ:  168 total:  168
max occ:  194 total:  194
max occ:  683 total:  687
max occ:  630 total:  630
max occ:  957 total:  957
max occ:  629 total:  672
max occ:  213 total:  219
max occ:  992 total:  993
max occ:  842 total:  844
max occ:  872 total:  879
max occ:  554 total:  594
max occ:  969 total:  970
max occ:  731 total:  838
max occ:  550 total:  553
max occ:  463 total:  476
max occ:  1019 total:  1021


In [25]:
tp / total

0.9791920786395925

In [26]:
unique_classes, nb_of_occourences = np.unique(spike_classes[0], return_counts=True);
print(nb_of_occourences)

[  3 957   8 130   3   4  20   4   1   1   4  46   1   4  10   7   1  13]
