In [1]:
from pathlib import Path
from os.path import expanduser
from os import path
import numpy as np
import sys
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import models
from torchsummary import summary
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
import os


In [2]:
# All paths are resolved relative to the parent of the notebook's working directory.
root_folder = os.path.dirname(os.getcwd())
path_to_recording = path.join(root_folder, 'data/recording_datasets/datasets_single9.npy')
path_to_model = path.join(root_folder, 'models/resnet18_num_classes_2_epoch_24.pt')
path_to_results = path.join(root_folder, 'data/results/single_recording_gd9.npy')

# Make the project-local `custom_resnet` module importable. The guard keeps
# re-runs of this cell from appending the same entry to sys.path again and again.
if root_folder not in sys.path:
    sys.path.append(root_folder)
from custom_resnet import CustomResnet as cnn

In [3]:
# Windowing / inference configuration.
waveform_length = 72
step_size = int(np.round(waveform_length * 0.1))  # 10% of the window length -> 7 samples
batch_size = 60000
num_classes = 2
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = cnn.resnet18(num_classes=num_classes)
# map_location lets a checkpoint saved from a GPU run be loaded on a CPU-only machine.
checkpoint = torch.load(path_to_model, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
# Put the network in inference mode so normalization/dropout layers (if the custom
# resnet has them) use their eval-time behaviour instead of training-time behaviour.
model.eval();


In [None]:
# Preprocessing pipeline: high-pass filter, then moving mean/std normalization.
# NOTE(review): the printed coefficients below ([0.987, -0.987] / [1, -0.974]) are
# consistent with a 1st-order high-pass at 100 Hz for a 24000 Hz sampling rate;
# 1000 is presumably the normalization window in samples — confirm in custom_resnet.
highpass_filter = cnn.FilterSignalUsingButtersWorth('high', 24000, np.array([100], dtype=int), 1)
moving_normalization = cnn.MovingMeanAndStdNormalization(1000)
transform = transforms.Compose([highpass_filter, moving_normalization])

# Load and preprocess the recording, then pad it for the same step/window used below.
recording = cnn.AddPaddingToRecording(
    cnn.Recording(path_to_recording, transform=transform),
    step_size,
    waveform_length,
)


[ 0.98707844 -0.98707844] [ 1.         -0.97415687]
started movingmeanandstd


In [None]:
# Sanity check: the normalized recording should be centred close to 0.
recording.data.numpy().ravel().mean()

In [None]:
# Sanity check: spread of the normalized recording over all samples.
recording.data.numpy().ravel().std()

In [None]:
# Positions at which windows will be cut from the padded recording, using the same
# step_size / waveform_length that the padding was built for.
# NOTE(review): later cells read waveform_indices[0, ...], so this is presumably a 2-D
# array/tensor whose row 0 holds the positions — confirm against cnn.GetWaveformIndices.
waveform_indices = cnn.GetWaveformIndices(recording, step_size, waveform_length)

In [None]:
# Cut one fixed-length window per index out of the preprocessed recording.
extract_waveforms = cnn.ExtractWaveforms(waveform_indices, waveform_length)
transform = transforms.Compose([extract_waveforms])
waveforms = transform(recording.data)



In [None]:
# Alias only: the moving mean/std normalization was already applied by the
# recording transform, so no further normalization happens here.
normalized_waveforms = waveforms
waveforms_on_device = normalized_waveforms.to(device)
dataset_to_infer = cnn.InferenceDataset(waveforms_on_device)

In [None]:
# Classify every extracted window. shuffle=False matters: later cells index
# waveforms / waveform_indices by the row position of `result`.
data_loader = DataLoader(dataset_to_infer, batch_size=batch_size, shuffle=False, num_workers=0)
# Inference needs no gradients. no_grad avoids keeping autograd graphs for every
# batch, and is harmless if cnn.Inference already disables them internally.
with torch.no_grad():
    result = cnn.Inference(model, data_loader, num_classes)

In [None]:
# Per-window class probabilities and the plain (unthresholded) argmax decision.
probability = F.softmax(result, dim=1)
argmax = probability.argmax(dim=1)


In [None]:
def predictionByTreshold(result, treshold):
    """Convert raw class scores to predictions, demoting low-confidence non-zero classes to 0.

    A softmax over dim 1 gives per-row class probabilities, and each row is assigned
    its argmax class. Any row whose predicted class is > 0 and whose probability for
    that class is below ``treshold`` is reset to class 0.

    Args:
        result: 2-D tensor of raw scores, one row per sample and one column per class.
        treshold: minimum probability required to keep a non-zero class prediction.

    Returns:
        1-D integer tensor of predicted class indices, on the same device as ``result``.
    """
    probability = torch.softmax(result, dim=1)
    prediction = torch.argmax(probability, 1)
    # Probability each row assigns to its own predicted class, i.e. probability[i, prediction[i]].
    score = probability.gather(1, prediction.unsqueeze(1)).squeeze(1)
    # One masked assignment instead of a Python loop over every row. That loop did one
    # tensor index per element, which is very slow (and forces a host sync per row on GPU).
    prediction[(prediction > 0) & (score < treshold)] = 0
    return prediction

In [None]:
# Keep a class-1 prediction (presumably "spike") only when the model is >= 90% confident.
treshold = 0.9
predictions = predictionByTreshold(result, treshold=treshold)

In [None]:
# Shape of the windows predicted as class 1. The indices are computed here so this cell
# does not depend on `pred_ind`, which is only defined by a later cell.
waveforms[np.where(predictions == 1)[0]].shape

In [None]:
# Map each window classified as class 1 back to a sample index in the recording:
# window start position + offset of the largest |amplitude| inside that window.
pred_ind = np.where(predictions == 1)[0];
# NOTE(review): assumes waveform_indices[0] holds window centre positions, so subtracting
# waveform_length // 2 yields the window start — confirm against cnn.GetWaveformIndices.
waveform_start = waveform_indices[0, pred_ind] - waveform_length // 2;
# Offset of the peak absolute value along axis 2 of each selected window.
waveform_argmax = np.argmax(abs(waveforms[pred_ind, :, :,]), axis=2);
# NOTE(review): .view(-1) flattens the per-window argmax; this only lines up with
# waveform_start if axis 1 has size 1 (single channel) — confirm.
# np.unique removes duplicate peaks (presumably from overlapping windows) and sorts the
# result; the merge step below relies on the sorted order.
predicted_index = np.unique(waveform_start + waveform_argmax.view(-1).to(torch.int32));

In [None]:

# Merge detections closer than one waveform length. Repeatedly take the closest
# adjacent pair and drop the index whose sample has the smaller |amplitude|, so the
# larger peak survives. The loop stops when every gap is >= waveform_length or fewer
# than two detections remain (np.argmin on an empty np.diff would raise ValueError).
signal_row = recording.data[0].numpy()
while predicted_index.size > 1:
    difference = np.diff(predicted_index)
    arg_min = np.argmin(difference)
    val_min = difference[arg_min]
    if val_min >= waveform_length:
        break
    pair = predicted_index[arg_min:arg_min + 2]
    isuppr = int(np.argmin(np.abs(signal_row[pair])))  # 0 or 1: which of the pair to drop
    predicted_index = np.delete(predicted_index, arg_min + isuppr)


In [None]:
# Number of windows predicted as class 1, then as class 0, after thresholding.
for label in (1, 0):
    print(np.where(predictions == label)[0].shape)

In [None]:
# Predicted sample indices after close detections have been merged.
predicted_index

In [None]:
# Same counts for the plain argmax decision (no confidence threshold), for comparison.
for label in (1, 0):
    print(np.where(argmax == label)[0].shape)

In [None]:
# NOTE(review): dead code — `extracted_sequences` is not defined anywhere in this notebook.
#for seq in extracted_sequences:
#  print(len(seq))

In [None]:
# Ground truth: row 0 is read as spike sample positions, row 1 as the neuron index per spike.
path_to_ground_truth_data = path.join(root_folder, 'data/recording_datasets/single_recording_gd9.npy')
ground_truth = np.load(path_to_ground_truth_data)
spike_positions, neuron_indexes = ground_truth[0, :].astype(int), ground_truth[1, :]

total = spike_positions.size  # number of ground-truth spikes
print(total)


In [None]:
# Alias used by the evaluation cells below: the final predicted sample indices.
max_index_unique = predicted_index


In [None]:
# Ground-truth spike sample positions (integer array).
spike_positions

In [None]:
# Final predicted sample indices, for side-by-side comparison with the ground truth.
max_index_unique

In [None]:
# Inspect a window-sized neighbourhood around one hand-picked sample index.
# NOTE(review): this cell originally read `recording1`, which is never defined in this
# notebook (probably a leftover from a deleted cell, perhaps an unfiltered copy).
# `recording` is used here so the cell runs on a fresh kernel.
inspect_index = 12994
half_window = waveform_length // 2  # 36 samples for waveform_length = 72
recording.data[0, inspect_index - half_window: inspect_index + half_window]

In [None]:
# Plot the window around one randomly chosen final detection as a visual sanity check.
half_window = waveform_length // 2
rnd = np.random.randint(0, max_index_unique.size)
centre = int(max_index_unique[rnd])
# Clamp at 0: a negative slice start would wrap around and give an empty window
# for detections near the start of the recording.
window = recording.data[0, max(0, centre - half_window): centre + half_window].numpy().ravel()
fig, ax = plt.subplots()
ax.plot(window)
ax.set(title=f'Detection at sample {centre}', xlabel='Sample (relative to window start)', ylabel='Amplitude');

In [None]:
# Persist the final predicted sample indices. Create the results folder first,
# because np.save does not create missing directories.
print(max_index_unique.shape)
os.makedirs(os.path.dirname(path_to_results), exist_ok=True)
np.save(path_to_results, max_index_unique)

In [None]:
# 1.0 where a predicted index coincides exactly with a ground-truth spike position, else 0.0.
# np.isin replaces an O(n*m) Python loop that called np.where once per prediction.
# NOTE(review): exact-sample matching allows no tolerance for a small offset between the
# detected peak and the annotated position — consider a tolerance window.
is_found = np.isin(max_index_unique, spike_positions).astype(float)

In [None]:
# True positives: predictions that hit a ground-truth spike. False positives: the rest.
tp = int(np.count_nonzero(is_found == 1))
print(tp)
fp = int(np.count_nonzero(is_found == 0))
print(fp)


In [None]:
# Spread of the first row of the preprocessed recording.
recording.data[0, :].numpy().ravel().std()

In [None]:
# Number of predicted indices that exactly match a ground-truth spike position.
tp

In [None]:
# Fraction of ground-truth spikes recovered (sensitivity / recall).
# NOTE(review): tp counts unique predicted indices, while total counts ground-truth entries;
# if two ground-truth spikes share a sample index, this ratio undercounts.
tp / total

In [None]:
# NOTE(review): `transform1` is never defined earlier in this notebook (leftover from a
# deleted cell), and the filter class lives on the `cnn` module. Start from an empty list
# so the cell runs on a fresh kernel and stays idempotent when re-run.
transform1 = []
transform1.append(cnn.FilterSignalUsingButtersWorth('high', 24000, np.array([100], dtype=int), 1))


In [None]:
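# Confusion-matrix terms (class 1 = spike, class 0 = noise):
# tp (true positive):  a spike predicted as spike
# fp (false positive): noise predicted as spike
# tn (true negative):  noise predicted as noise
# fn (false negative): a spike predicted as noise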