In [1]:
from pathlib import Path
from os.path import expanduser
from os import path
import numpy as np
import sys
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import models
from torchsummary import summary
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
import os


In [2]:
root_folder = os.path.dirname(os.getcwd())
path_to_recording= path.join(root_folder,'data/recording_datasets/datasets_2.npy');
path_to_model = path.join(root_folder, 'models/resnet18_num_classes_2_epoch_0.pt')
path_to_results = path.join(root_folder, 'data/results/datasets_1_results.npy')

sys.path.append(root_folder)
from custom_resnet import CustomResnet as cnn

In [3]:
waveform_length = 72;
step_size = int(np.round(waveform_length * 0.1))
batch_size = 60000;
num_classes = 2;
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = cnn.resnet18(num_classes=num_classes);
checkpoint = torch.load(path_to_model)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device);


In [4]:
transform = transforms.Compose([cnn.MovingWeightedMeanAndStdNormalization(1000)])
recording = cnn.Recording(path_to_recording, transform = transform);
recording = cnn.AddPaddingToRecording(recording, step_size, waveform_length)



72000000
tensor([-0.4852, -0.0765,  0.0115, -0.3168, -0.8177, -1.1595, -1.3421, -1.4938,
        -1.5787, -1.5511, -1.4830, -1.4844, -1.5745, -1.6838, -1.6920, -1.5752,
        -1.4379, -1.3655, -1.3350, -1.2545, -1.0723, -0.7421, -0.2437,  0.3325,
         0.7915,  0.9613,  0.8495,  0.6891,  0.6677,  0.7274,  0.7781,  0.8906,
         1.1738,  1.5295,  1.7142,  1.7459,  1.9466,  2.3850,  2.7226,  2.7006,
         2.3137,  1.6934,  1.0278,  0.5063,  0.1528, -0.0873, -0.2789, -0.4114,
        -0.4238, -0.3308, -0.1393,  0.1240,  0.2714,  0.0547, -0.3991, -0.7987,
        -0.9988, -0.9398, -0.5659,  0.1558,  0.8844,  1.1747,  1.0718,  0.9475,
         1.0456,  1.2560,  1.3518,  1.2474,  1.0342,  0.7762,  0.4605,  0.1095,
        -0.1335, -0.2250, -0.2098, -0.1273, -0.0247,  0.0788,  0.1373,  0.0759,
        -0.0732, -0.0366,  0.2924,  0.6132,  0.6220,  0.3356, -0.0087, -0.2095,
        -0.2549, -0.2624, -0.3526, -0.5380, -0.6964, -0.7022, -0.5567, -0.4517,
        -0.5400, -0.7467, -0.91

In [5]:
waveform_indices = cnn.GetWaveformIndices(recording, step_size, waveform_length)

In [6]:
transform = transforms.Compose([cnn.ExtractWaveforms(waveform_indices, waveform_length)])
waveforms = transform(recording.data);



In [7]:
normalized_waveforms = waveforms;
dataset_to_infer = cnn.InferenceDataset(normalized_waveforms.to(device));

torch.cuda.FloatTensor


In [8]:
%pdb off
data_loader = torch.utils.data.DataLoader(dataset_to_infer, batch_size=batch_size, shuffle=False, num_workers=0);
result = cnn.Inference(model, data_loader, num_classes)

Automatic pdb calling has been turned OFF
done:  0  batch
done:  1  batch
done:  2  batch
done:  3  batch
done:  4  batch
done:  5  batch
done:  6  batch
done:  7  batch
done:  8  batch
done:  9  batch
done:  10  batch
done:  11  batch
done:  12  batch
done:  13  batch
done:  14  batch
done:  15  batch
done:  16  batch
done:  17  batch
done:  18  batch
done:  19  batch
done:  20  batch
done:  21  batch
done:  22  batch
done:  23  batch
done:  24  batch
done:  25  batch
done:  26  batch
done:  27  batch
done:  28  batch
done:  29  batch
done:  30  batch
done:  31  batch
done:  32  batch
done:  33  batch
done:  34  batch
done:  35  batch
done:  36  batch
done:  37  batch
done:  38  batch
done:  39  batch
done:  40  batch
done:  41  batch
done:  42  batch
done:  43  batch
done:  44  batch
done:  45  batch
done:  46  batch
done:  47  batch
done:  48  batch
done:  49  batch
done:  50  batch
done:  51  batch
done:  52  batch
done:  53  batch
done:  54  batch
done:  55  batch
done:  56  batch

In [9]:
soft_max = nn.Softmax(1)
probability = soft_max(result)
argmax = torch.argmax(probability, 1);


In [10]:
def predictionByTreshold(result, treshold):
  soft_max = nn.Softmax(1)
  probability = soft_max(result)
  prediction = torch.argmax(probability, 1);
  for i  in range(prediction.nelement()):
    if prediction[i] > 0:
      score = probability[i, prediction[i]];
      if(score < treshold):
        prediction[i] = 0;
  return prediction;

In [11]:
treshold = 0.85;
predictions = predictionByTreshold(result, treshold)

In [12]:
print(np.where(predictions == 1)[0].shape)
print(np.where(predictions == 0)[0].shape)

(724283,)
(9561421,)


In [13]:
print(np.where(argmax == 1)[0].shape)
print(np.where(argmax == 0)[0].shape)

(1083899,)
(9201805,)


In [14]:
shift_check = 1
classes = [1]
extracted_sequences = []
temp = [];
for i in range(0, predictions.nelement()):
  append_sequence = True;
  current_class = None
  for j in range(0, len(classes)):
    class_id = classes[j];
    if(predictions[i] == class_id):
      temp.append(i);
      if(predictions.nelement() - shift_check  > i):
        for k in range(1, shift_check + 1):
          if(predictions[i+k] == class_id):
            append_sequence = False;
  if(append_sequence and len(temp) != 0):
    extracted_sequences.append(temp)
    temp = [];
          

      


In [15]:
len(extracted_sequences)

96969

In [16]:
#for seq in extracted_sequences:
#  print(len(seq))

In [17]:
predicted_index_from_to = np.zeros((len(extracted_sequences), 2));
for i, spike in enumerate(extracted_sequences):
  predicted_index_from_to[i, 0] = waveform_indices[0, spike[0]] - waveform_length // 2;
  predicted_index_from_to[i, 1] = waveform_indices[0, spike[-1]] + waveform_length // 2;


In [18]:
predicted_index_from_to = predicted_index_from_to.astype(int)

In [19]:
import matplotlib.pyplot as plt
max_index = np.zeros((len(predicted_index_from_to)))
for i in range(0, len(predicted_index_from_to)):
  wf = abs(recording[0, predicted_index_from_to[i, 0]:predicted_index_from_to[i, 1] ])
  max_index[i] =  predicted_index_from_to[i, 0] + np.argmax(wf);

In [30]:
path_to_ground_truth_data = path.join(root_folder,'data/recording_datasets/ground_truth_data_multiunit_2.npy');
ground_truth = np.load(path_to_ground_truth_data);
neuron_indexes = ground_truth[1, :];
spike_positions = ground_truth[0, :];
spike_positions = spike_positions.astype(int)

spike_positions.shape
total = spike_positions.size
print(total)


58481


In [31]:
max_index_unique = np.unique(max_index)


In [32]:
max_index_unique.astype(int)

array([    1516,     1763,     1893, ..., 71997977, 71999418, 71999801])

In [33]:
print(max_index_unique.shape)
print(max_index.shape)
np.save(path_to_results, max_index_unique)

(90288,)
(96969,)


In [25]:
is_found = np.zeros(max_index_unique.size)
for i in range(0, max_index_unique.size):
  found = np.where(max_index_unique[i] == spike_positions)
  if(len(found[0]) == 0):
    is_found[i] = 0
  else:
    is_found[i] = 1;

In [26]:
tp = np.where(is_found == 1)[0].size
print(tp)
fp = np.where(is_found == 0)[0].size
print(fp)
fn = total - tp - fp
print(fn)

56610
33678
-31807


In [27]:
precision = tp / (tp + fp)
print(precision)
recall = tp / (tp + fn)
print(recall)
f1 = 2 * ((precision * recall) / (precision + recall))
print(f1)

0.626993620414673
2.2823851953392733
0.9837432987809644


In [35]:
tp

56610

In [29]:
tp / total

0.968006703031754

In [None]:
# tp is spike
# fp is noise but predicted as spike
# tn is noise
# fn is spike but predicted as noise