In [None]:
### Progress list ###
"""
DONE:
- Check and remove overlapping spwr events
- Create spwr dataset with 100ms around each spwr peak
    - Plot the events
    - Plot the entire signal but with spwr marked in red
- Create noise dataset. Ensure it does not contain any spwr events. Maybe plot to make sure in case Peter missed spwr events.
- Train test split
- Train the model
- Label parts of the code and give instructions at code segments where a lot of manual work is done.

WORK IN PROGRESS:            


TODO:
- Fix a way to not have Keras overwrite models
- Create a function for plotting just to clean up the code further

#UGLY CODE PARTS:
 - Everything with np.delete because it creates a copy of the array which takes up a lot of memory and unnecessary time
 - That I have to go through the noise range and manually select ranges that could contain non noise
 - That I after creating the noise data set loop through it again and manually mark ranges that are possible spwrs
   and then manually find noise data to replace it which is done now by finding a sequential range of data that is
   maybe_non_noise_indicies.shape[0]*201 long without any possible spwr inside of it. I do this now by adding a number
   to the last frame from which the noise dataset was created. But this is not a good strategy if a lot of data is needed.
"""

In [None]:
# Load libraries
import numpy as np
from tensorflow.keras import layers
from tensorflow import keras
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import h5py
import os
import pandas as pd
import scipy.io
import tensorflow as tf
import random

#Needed for Keras
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

In [None]:
# Report how many physical GPU and CPU devices TensorFlow lists on this machine
gpu_devices = tf.config.list_physical_devices('GPU')
cpu_devices = tf.config.list_physical_devices('CPU')
print("Num GPUs Available: ", len(gpu_devices))
print("Num CPUs Available: ", len(cpu_devices))

## Load data from MATLAB ##

In [None]:
### LOAD DATA ###
# Reads the spwr type indicators (scipy .mat file) and the lfp/ripple recordings
# (HDF5-based .mat file) from data_dir.
# The files are opened through paths built from data_dir instead of changing the
# working directory into the data folder, so a failed load can no longer leave the
# notebook stranded in data_dir. The cell still ends in work_dir, as before.

# Current path
cwd = os.getcwd()

work_dir = r"C:\Users\RIPPLE\source\repos\spwr-project"
data_dir = r"D:\Dropbox (Personal)\ETH_DATA\rTBY35\8_freely_behav_220517_120301"
# D:\Dropbox (Personal)\ETH_DATA\rTBY35\7_freely_behav_220511_152446

# Load data for determining the type of spwr event
g = scipy.io.loadmat(os.path.join(data_dir, 'spwr_indicies.mat'))
high_dHP_bool = (g['bzHigh_dHP_ind'].flatten()).astype(bool)
low_dHP_bool = (g['bzLow_dHP_ind'].flatten()).astype(bool)
high_iHP_bool = (g['bzHigh_iHP_ind'].flatten()).astype(bool)
low_iHP_bool = (g['bzLow_iHP_ind'].flatten()).astype(bool)

# Load lfp and ripple data
# The with-block closes the file handle once every dataset has been copied into memory.
# f[...] raises KeyError for a missing dataset instead of silently yielding None.
features_path = os.path.join(data_dir, '8_freely_behav_220517_120301_auto_features.mat')
with h5py.File(features_path, 'r') as f:
    lfp_channels = np.array(f['lfp/data']) # (16, 7403751)
    #lfp_lowpass = np.array(f['lfp/lowpass'])
    #lfp_lowpass_all = np.array(f['lfp/lowpass_all'])
    #lfp_bandpass = np.array(f['lfp/bandpass'])
    #lfp_bandpass_power = np.array(f['lfp/bandpass_power'])
    ripple_timestamps = np.array(f['ripples/timestamps']) #(2, 1467)
    ripple_durations = np.array(f['ripples/duration']) #(1, 1467)
    ripple_centralFrames = np.array(f['rippleEpisodes/centralFrames']).astype(int).transpose() #(1, 1467)
    lfp_main_channels_dHP = np.array(f['lfp_main_channel_dHP']).flatten().astype(int)-1 #-1 for python zero-indexing
    lfp_main_channels_iHP = np.array(f['lfp_main_channel_iHP']).flatten().astype(int)-1 #-1 for python zero-indexing

# Same end state as before: the rest of the notebook runs from work_dir
os.chdir(work_dir)

## Collect spwr data for all ripples and only dHP ripples

In [None]:
# A spwr event counts as dHP when it is flagged as either high or low dHP
dHP_spwr_bool = high_dHP_bool | low_dHP_bool

num_total_spwr = ripple_timestamps.shape[1] #1467
num_dHP_spwr = dHP_spwr_bool.sum() #878

# Timestamps for iHP spwr events
#iHP_spwr_bool = np.logical_or(high_iHP_bool, low_iHP_bool)
#iHP_spwr_timestamps = ripple_timestamps[:, iHP_spwr_bool] # (2, 589)

# Timestamps for separate spwr events
#high_dHP_spwr_timestamps = ripple_timestamps[:, high_dHP_bool] #(2, 589)
#low_dHP_spwr_timestamps = ripple_timestamps[:, low_dHP_bool] #(2, 289)
#high_iHP_spwr_timestamps = ripple_timestamps[:, high_iHP_bool] #(2, 384)
#low_iHP_spwr_timestamps = ripple_timestamps[:, low_iHP_bool] #(2, 205)

# Every spwr event ordered by its start time (row 0 of ripple_timestamps).
# This ordering is what the overlap check works on.
sorted_ripple_timestamps_ind = np.argsort(ripple_timestamps[0])
sorted_ripple_timestamps = ripple_timestamps[:, sorted_ripple_timestamps_ind]
sorted_ripple_centralFrames = ripple_centralFrames[:, sorted_ripple_timestamps_ind] #(1, 1467)

# dHP-only subset, then ordered by start time with one shared permutation
dHP_spwr_timestamps = ripple_timestamps[:, dHP_spwr_bool]
sort_indicies = np.argsort(dHP_spwr_timestamps[0])
dHP_spwr_timestamps = dHP_spwr_timestamps[:, sort_indicies] #(2, 878)
dHP_spwr_durations = ripple_durations[:, dHP_spwr_bool][:, sort_indicies] #(1, 878)
dHP_spwr_centralFrames = ripple_centralFrames[:, dHP_spwr_bool][:, sort_indicies] #(1, 878)

#Only use most informative channels (picked by Peter)
lfp_dHP_data = lfp_channels[lfp_main_channels_dHP] #3, 7403751
num_channels = lfp_main_channels_dHP.shape[0]

#0.09850000000005821 seconds => 100ms around each spwr to not miss any data.
max_ripple_duration = ripple_durations.max()

In [None]:
# Histogram of all ripple durations (default np.histogram binning), drawn as a step outline
counts, bins = np.histogram(ripple_durations)
fig, ax = plt.subplots()
ax.stairs(counts, bins)
plt.show()
#print(np.sort(ripple_durations)[:, -100:])
# Only a small fraction are very long spwr events. They are not discarded, because
# the algorithm should also learn this general behaviour of spwr events and
# therefore the longer spwr events are needed.

### Only use non-overlapping dHP spwr ###

In [None]:
#This code extracts spwrs' centralFrame if the spwr does not overlap other spwrs,
#i.e. within 100ms of each other from the center.
#This is done to have isolated spwrs in the training/test set and to avoid the same data in different samples.
#Hopefully this allows the model to better learn what a spwr is.

#The end result is that we are provided with the sorted dHP spwrs

timesteps = 100
spwr_size = 2*timesteps+1

#Everything in this cell is indexed in start-time-sorted order. The per-event arrays that were
#loaded in file order are therefore re-ordered with the same permutation that produced
#sorted_ripple_centralFrames. If the file is already sorted by start time this is a no-op.
sorted_dHP_spwr_bool = dHP_spwr_bool[sorted_ripple_timestamps_ind]
sorted_high_dHP_bool = high_dHP_bool[sorted_ripple_timestamps_ind]
sorted_low_dHP_bool = low_dHP_bool[sorted_ripple_timestamps_ind]
sorted_ripple_durations = ripple_durations[:, sorted_ripple_timestamps_ind]

#All spwr centralFrames, flattened => (num_spwr, )
sorted_centralFrames = sorted_ripple_centralFrames.flatten()
num_sorted_spwr = sorted_centralFrames.shape[0]

#Indices (in sorted order) of the spwr events that overlap no other spwr event.
#Both dHP and iHP events take part in the overlap check because an iHP spwr is still
#captured by the dHP channels.
non_overlapping_spwr = []

#Initiate previously handled ripple end frame to start of time series, i.e. 0.
#As a side effect an event whose window would start at or before frame 0 is discarded.
prev_ripple_end = 0

#Index of the spwr event currently being examined
curr_spwr = 0

#Walk through the events with a single pointer instead of deleting from a copy of the array
#(np.delete re-allocates the whole array on every call). The decisions are the same as before:
#  1) curr overlaps the previously handled ripple -> discard curr
#  2) next starts inside curr's window             -> discard both curr and next
#  3) otherwise                                     -> keep curr
#NOTE(review): case 2 only detects the overlap when next's centralFrame is not before curr's.
#The events are sorted by start timestamp, not by centralFrame - confirm the two orders agree.
while curr_spwr < num_sorted_spwr - 1:

    curr_ripple_start = sorted_centralFrames[curr_spwr] - timesteps
    curr_ripple_end = sorted_centralFrames[curr_spwr] + timesteps
    next_ripple_start = sorted_centralFrames[curr_spwr+1] - timesteps
    next_ripple_end = sorted_centralFrames[curr_spwr+1] + timesteps

    #Discard curr ripple if it overlaps with the previously handled ripple
    if(curr_ripple_start <= prev_ripple_end):
        prev_ripple_end = curr_ripple_end
        curr_spwr += 1

    #Discard curr and next ripple if next ripple overlaps with current ripple
    elif(curr_ripple_start <= next_ripple_start <= curr_ripple_end):
        prev_ripple_end = next_ripple_end
        curr_spwr += 2

    #No overlap => save current spwr as non-overlapping and go to next spwr
    else:
        non_overlapping_spwr.append(curr_spwr)
        prev_ripple_end = curr_ripple_end
        curr_spwr += 1

#If exactly one ripple is left and it does not overlap with the previous ripple we add it.
#When the final two ripples were discarded as a pair there is nothing left to examine
#(indexing past the end here used to raise an IndexError).
if(curr_spwr == num_sorted_spwr - 1):
    if(sorted_centralFrames[curr_spwr] - timesteps > prev_ripple_end):
        non_overlapping_spwr.append(curr_spwr)

#True if spwr event is not removed
#False is spwr event is removed
kept_spwr_bool = np.full((num_total_spwr), False) # Creates boolean array with all False values
kept_spwr_bool[np.asarray(non_overlapping_spwr, dtype=int)] = True

#Removed events are counted from the final mask so that a discarded last event is included too
removed_spwr_bool = np.logical_not(kept_spwr_bool)
num_dHP_removed = np.sum(np.logical_and(removed_spwr_bool, sorted_dHP_spwr_bool))
num_iHP_removed = np.sum(np.logical_and(removed_spwr_bool, np.logical_not(sorted_dHP_spwr_bool)))
num_removed_spwr = num_dHP_removed + num_iHP_removed

print("Total number of spwr: ", num_total_spwr)
print("Total number of dHP spwr", num_dHP_spwr)
print("Total number of removed overlapping spwr events: ", num_removed_spwr)
print("Number of dHP spwr removed: ", num_dHP_removed)
print("Number of iHP spwr removed: ", num_iHP_removed)
print("Portion of total lost spwr events: ", num_removed_spwr/num_total_spwr)
print("Portion of total lost dHP spwr events: ", num_dHP_removed/num_dHP_spwr)

#A kept event is only usable if its whole window lies inside the recording
window_fits_bool = np.logical_and(sorted_centralFrames - timesteps >= 0,
                                  sorted_centralFrames + timesteps < lfp_dHP_data.shape[1])

#Only non-removed dHP spwr events are kept (mask is in sorted order)
non_overlapping_dHP_spwr_bool = np.logical_and(np.logical_and(kept_spwr_bool, sorted_dHP_spwr_bool), window_fits_bool)
print("Shape of non_overlapping_dHP_spwr_bool: ", non_overlapping_dHP_spwr_bool.shape)
print("Num non-overlapping dHP spwrs: ", np.sum(non_overlapping_dHP_spwr_bool))

num_non_overlapping_dHP_spwr = np.sum(non_overlapping_dHP_spwr_bool)

##Get lfp data only for non-overlapping dHP spwr events
#Sorted array with the centralFrames of non-overlapping dHP spwrs
dHP_non_overlapping_centralFrames = sorted_ripple_centralFrames[0, non_overlapping_dHP_spwr_bool]

#Create structure for storing the lfp data corresponding to each non-overlapping dHP spwr and load in the data
#Channel, spwr id, lfp data
non_overlapping_lfp_dHP_spwr = np.zeros((num_channels, num_non_overlapping_dHP_spwr, spwr_size))

for j, spwr_centralFrame in enumerate(dHP_non_overlapping_centralFrames):
    start = spwr_centralFrame - timesteps
    end = spwr_centralFrame + timesteps + 1 #+1 because python does not include the last idx in a slice
    non_overlapping_lfp_dHP_spwr[:, j, :] = lfp_dHP_data[:, start:end]

print("Number of high dHP spwrs in non overlapping: ", np.sum(np.logical_and(sorted_high_dHP_bool, non_overlapping_dHP_spwr_bool)))
print("Number of low dHP spwrs in non overlapping: ", np.sum(np.logical_and(sorted_low_dHP_bool, non_overlapping_dHP_spwr_bool)))

non_overlapping_dHP_ripple_durations = sorted_ripple_durations[:, non_overlapping_dHP_spwr_bool]

#Distribution of ripple durations
counts, bins = np.histogram(non_overlapping_dHP_ripple_durations)
plt.stairs(counts, bins)
plt.show()

### PLOTTING ###
This section does not have to be run to get the results.

In [None]:
#Extract frames that are not labelled as spwr (100 frames left and right around each spwr centralFrame i.e 100ms)

num_frames = lfp_dHP_data.shape[1]

start_frames = sorted_ripple_centralFrames[0, :] - 100
end_frames = sorted_ripple_centralFrames[0, :] + 101 #101 because python does not include last idx in range

#Mark every labelled window in one boolean mask instead of calling np.delete once per event
#(each np.delete copied the full multi-million frame array).
#The slice excludes end_frame itself, so exactly 201 frames are masked per event;
#the previous `<= end_frame` comparison removed 202.
is_non_spwr_frame = np.ones(num_frames, dtype=bool)
for start_frame, end_frame in zip(start_frames, end_frames):
    #max(..., 0) stops a window that begins before the recording from wrapping around
    is_non_spwr_frame[max(start_frame, 0):end_frame] = False
non_spwr_frames = np.flatnonzero(is_non_spwr_frame)

print("Number of timepoints: ", num_frames)
# Plot all of the data
plt.figure()
str_idx = 0
end_idx = num_frames
for channel in range(lfp_dHP_data.shape[0]):
    channel_trace = lfp_dHP_data[channel, str_idx:end_idx]
    plt.plot(np.arange(len(channel_trace)), channel_trace, alpha=0.8)

# Plot the non-overlapping dHP spwrs
for i, central_frame in enumerate(dHP_non_overlapping_centralFrames):
    x_axis = range(central_frame-100, central_frame+101)
    for channel in range(non_overlapping_lfp_dHP_spwr.shape[0]):
        plt.plot(x_axis, non_overlapping_lfp_dHP_spwr[channel, i, :], 'r')

# Plot all labelled spwr
for central_frame in sorted_ripple_centralFrames[0, :]:
    start_idx = central_frame-100
    end_idx = central_frame+101
    x_axis = range(start_idx, end_idx)
    for channel in range(lfp_dHP_data.shape[0]):
        plt.plot(x_axis, lfp_dHP_data[channel, start_idx:end_idx], 'b', alpha=0.2)

#Optional (disabled): segment of the data without any spwr labelled inside of it
#noise_start = 5800000
#noise_end = 5900000
#noisy_data = lfp_dHP_data[:, noise_start:noise_end]
#plt.plot(range(noise_start,noise_end), lfp_dHP_data[0, noise_start:noise_end], 'k', alpha=0.8)
#plt.plot(range(noise_start,noise_end), lfp_dHP_data[1, noise_start:noise_end], 'k', alpha=0.8)
#plt.plot(range(noise_start,noise_end), lfp_dHP_data[2, noise_start:noise_end], 'k', alpha=0.8)

#Optional (disabled): plot only the frames without labelled spwr
#print("Plotting frames without labelled spwr")
#plt.plot(non_spwr_frames,lfp_dHP_data[0, non_spwr_frames], 'k', alpha=0.8)
#plt.plot(non_spwr_frames,lfp_dHP_data[1, non_spwr_frames], 'k', alpha=0.8)
#plt.plot(non_spwr_frames,lfp_dHP_data[2, non_spwr_frames], 'k', alpha=0.8)

plt.show()

In [None]:
#Data information
#Frame counts use spwr_size (the window length around each centralFrame) instead of a literal 201.
num_total_frames = lfp_dHP_data.shape[1]
num_window_frames_all_spwr = num_total_spwr*spwr_size

print("Number of total frames: ", num_total_frames)
print("Number of frames not labelled as spwr: ", len(non_spwr_frames))
print("Num total spwr frames with 100ms around centralFrame: ", num_window_frames_all_spwr)
#This value is the complement of the line above; it was previously printed under a "spwr frames" label.
#It ignores that overlapping windows share frames.
print("Num frames outside spwr windows if no windows overlapped: ", num_total_frames - num_window_frames_all_spwr)
print("Num non-overlapping dHP spwr frames with 100ms around centralFrame: ", num_non_overlapping_dHP_spwr*spwr_size)
print("Number of non-overlapping dHP spwr: ", num_non_overlapping_dHP_spwr)


In [None]:
# Plot one non-overlapping dHP spwr window, one line per channel.
# ripple_id selects the window along axis 1 of non_overlapping_lfp_dHP_spwr.
ripple_id = 150
plt.figure()
for channel in range(3):
    channel_trace = non_overlapping_lfp_dHP_spwr[channel, ripple_id, :]
    plt.plot(np.arange(len(channel_trace)), channel_trace)
plt.show()

# Should these be considered spwrs?

In [None]:
#35 bad centralFrame?
#78, 206 and 294, 310 are not ripples right?
maybe_not_ripples_id = [1,3,6,7,8,9,12,13,15,17,34,35,41,44,60,61,67,71,77,78,79,80,94,99,104,178,206,294,300,310,313]

# One figure per doubtful window, preceded by its id so the plot can be matched to the list.
# To browse every window instead, iterate over range(0, non_overlapping_lfp_dHP_spwr.shape[1]).
for ripple_id in maybe_not_ripples_id:
    print(ripple_id)
    plt.figure()
    for channel in range(3):
        channel_trace = non_overlapping_lfp_dHP_spwr[channel, ripple_id, :]
        plt.plot(np.arange(len(channel_trace)), channel_trace)
    plt.show()
    plt.close()

In [None]:
#For plotting more than just the 200 frames around a ripple centralFrame.
#Given the ripple id, look up its centralFrame and plot a wider span around it
#(currently 200 frames before the centralFrame and 100 frames after it).

ripple_id = 150
centralFrame_of_ripple = dHP_non_overlapping_centralFrames[ripple_id]

#The ripple before the current ripple
#print(sorted_ripple_centralFrames[:, 250:300])
#print(centralFrame_of_ripple)
#centralFrame_of_ripple = 1177374

frame_range = range(centralFrame_of_ripple-200, centralFrame_of_ripple+101)
plt.figure()
for channel in range(3):
    channel_trace = lfp_dHP_data[channel, frame_range]
    plt.plot(np.arange(len(channel_trace)), channel_trace)
plt.show()

### IS THIS A SPWR? ###

In [None]:
#Unlabelled potential ripple that I found myself.
#First print: labelled centralFrames inside the plotted span.
#Second print: labelled centralFrames in a wider span before it. This lookup used to be a bare
#expression in the middle of the cell, which a notebook never displays, so it is printed now.
print(sorted_ripple_centralFrames[(5384000 < sorted_ripple_centralFrames) & (sorted_ripple_centralFrames < 5384300)])
print(sorted_ripple_centralFrames[(5100000 < sorted_ripple_centralFrames) & (sorted_ripple_centralFrames < 5384300)])


plt.figure()
str_idx = 5384000
end_idx = 5384300
#str_idx = 5384500
#end_idx = 5384650
for channel in range(3):
    channel_trace = lfp_dHP_data[channel, str_idx:end_idx]
    plt.plot(np.arange(len(channel_trace)), channel_trace)
plt.show()

In [None]:
#Low drop spwr, good to know not all spwr events drop to around -1000
#Shows the 201-frame window centred on frame 6293668, one line per channel.
low_drop_centralFrame = 6293668
str_idx = low_drop_centralFrame-100
end_idx = low_drop_centralFrame+101
for channel in range(3):
    channel_trace = lfp_dHP_data[channel, str_idx:end_idx]
    plt.plot(np.arange(len(channel_trace)), channel_trace)
plt.show()

# Noisy data
### Instructions
What I have done here is to select an interval based on the entire dataset where there are no labelled ripples (See plot of all data above). After that I looked at 200 frames at a time and manually wrote down the ranges of frames that were potentially not to be considered noise. I did this for a range of 100000 because I needed 439 * 201 = 88239 noise frames initially to have an equally large noise dataset as spwr dataset. However, due to a bug in my code this was later reduced to 319 * 201 but I could still use the same range. If one has a larger spwr dataset one would have to go through a larger range of non-labelled data and write down the ranges with potentially non-noise.

In [None]:
#Don't have it run this
#Manual browsing helper: shows the noise range 200 frames at a time and prints each
#span as a range(...) expression that can be copied into the candidate lists.

noise_start = 5800000
noise_end = 5900000

#Have done the first 500 (0-499 in python)
for i in range(1, 2):
    str_idx = noise_start + i*200
    end_idx = str_idx + 200
    print(f"range({str_idx},{end_idx})")
    plt.figure()
    for channel, colour in enumerate(('b', 'C1', 'g')):
        plt.plot(lfp_dHP_data[channel, str_idx:end_idx], color=colour)
    plt.show()
    plt.close()

In [None]:
#I manually stepped through the data from frame 5800000 to 5900000 and picked out events of interest
#These were parts of the signal that looked like:
    #spwrs that hadn't been labelled
    #Sharp drops in all channels that could be an indication of the sharp drop in spwr
    #Inverted spwrs with the lowest channel on top and the highest channel at the bottom
#After talking to Peter and going through the data again we determined that only some spwr_like events
#should be excluded from the dataset as labelling them as noise could be wrong even if the 
#automatic detection algorithm from buzaki did not classify that part of the signal as a spwr.
spwr_like_string = """#spwr-like
range(5864600,5864800)
#spwr-like
range(5867400,5867600)
#two spwr?
range(5869600,5869800)
#spwr-like
range(5870000,5870200)
#spwr_like
range(5871400,5871700)
#spwr_like
range(5896300,5896500)
#spwr_like
range(5888700,5888900)
#spwr-like
range(5885900,5886100)
#spwr-like
range(5882600,5882800)
#spwr_like
range(5879800,5880000)
#spwr_like
range(5878600,5878800)
#spwr_like
range(5877800,5878000)
#spwr_like
range(5876400,5876600)
#spwr_like
range(5874700,5874900)
#spwr_like
range(5874400,5874600)
#spwr_like
range(5873000,5873200)
#spwr-like
range(5872200,5872400)
#spwr-like
range(5802200,5802400)
#spwr-like
range(5804800,5805000)
#spwr-like
range(5805800,5806200)
#spwr-like
range(5807100,5807500)
#spwr-like
range(5808600,5809000)
#spwr-like
range(5828400,5828600)
#spwr-like
range(5839000,5839200)
#spwr-like and invert
range(5840800,5841800)
#spwr-like
range(5844350,5844550)
#spwr-like
range(5847800,5848000)
#spwr_like
range(5851000,5851600)
#Inverted spwr?
range(5851800,5852000)
#spwr-like
range(5852400,5852600)
#Two spwrs?
range(5853400,5853800)"""

#Convert string of ranges to list of ranges
#Each "range(a,b)" line is parsed explicitly instead of being passed to exec: lines are
#recognised by their content rather than by being on an odd line number, and anything that is
#neither a "#" note, blank, nor a well-formed range raises instead of being executed.
spwr_like = []
for line in spwr_like_string.splitlines():
    entry = line.strip()
    if not entry or entry.startswith('#'):
        continue
    if not (entry.startswith('range(') and entry.endswith(')')):
        raise ValueError("Unrecognised line in spwr_like_string: " + repr(line))
    first_frame, stop_frame = (int(bound) for bound in entry[len('range('):-1].split(','))
    spwr_like.append(range(first_frame, stop_frame))

#Sort the list according to the first frame in each range
spwr_like = sorted(spwr_like, key=lambda r: r.start)

sharp_drops = [
    range(5814200, 5814400), range(5814900, 5815100), range(5818200, 5818400), range(5850400, 5850700),
    range(5844200, 5844400), range(5837200, 5837400), range(5836600, 5836800), range(5831600, 5831800),
    range(5810750, 5811100), range(5801700, 5801900), range(5872600, 5872800), range(5860400, 5860600),
    range(5867600, 5867800),
]

inverted_spwr = [
    range(5881400, 5881600), range(5810700, 5810900), range(5823200, 5823400), range(5880200, 5880400),
    range(5880800, 5881200), range(5882200, 5882400), range(5885000, 5885200),
]

# Checking sharp drop, inverted and spwr-like to see if they can be noise
For each of the identified potentially non-noise ranges of each type, I plotted the data again to go through them with Peter and determine whether they should be considered noise or spwr. The conclusion was that the inverted spwrs and sharp drops could be considered noise, but some of the spwr_like signals might be unlabelled spwrs. The ranges of those spwr_like signals were therefore saved and excluded when creating the noise dataset, by only gathering data outside of those ranges.

In [None]:
# One figure per inverted-spwr candidate range; the printed number is its position
# in inverted_spwr, which the index lists at the bottom of the cell refer to.
for i, frames in enumerate(inverted_spwr):
    print(i)
    plt.figure()
    for channel, colour in enumerate(('b', 'C1', 'g')):
        channel_trace = lfp_dHP_data[channel, frames]
        plt.plot(np.arange(len(channel_trace)), channel_trace, colour)
    plt.show()
    plt.close()

"""
#4
frames = range(5881000,5881200)
plt.figure()
plt.plot(np.arange(len(lfp_dHP_data[0, frames])),lfp_dHP_data[0, frames], 'b')
plt.plot(np.arange(len(lfp_dHP_data[1, frames])),lfp_dHP_data[1, frames], 'C1')
plt.plot(np.arange(len(lfp_dHP_data[2, frames])),lfp_dHP_data[2, frames], 'g')
plt.show()
plt.close()
"""
#Probably fine
inverted_spwr_noisy = list(range(7))
#Check with Peter
inverted_spwr_no_noisy = []

In [None]:
# One figure per sharp-drop candidate range; the printed number is its position
# in sharp_drops, which the index lists at the bottom of the cell refer to.
for i, frames in enumerate(sharp_drops):
    print(i)
    plt.figure()
    for channel, colour in enumerate(('b', 'C1', 'g')):
        channel_trace = lfp_dHP_data[channel, frames]
        plt.plot(np.arange(len(channel_trace)), channel_trace, colour)
    plt.show()
    plt.close()

#Probably fine
sharp_drops_noisy = list(range(13))
#Check with Peter
sharp_drops_no_noisy = []

In [None]:
# One figure per spwr-like candidate range. The position and the range itself are
# printed first so a plot can be matched to an entry of spwr_like.
for i, frames in enumerate(spwr_like):
    print(i)
    print(frames)
    plt.figure()
    for channel, colour in enumerate(('b', 'C1', 'g')):
        channel_trace = lfp_dHP_data[channel, frames]
        plt.plot(np.arange(len(channel_trace)), channel_trace, colour)
    plt.show()
    plt.close()

# The two lists below hold positions in spwr_like.
# NOTE(review): the "too large interval" remarks name positions 25 and 28, but in the
# start-sorted spwr_like list those positions hold 200-frame ranges; the long ranges
# (5840800-5841800 and 5851000-5851600) sit at positions 7 and 10. The positions may have
# been chosen against a differently ordered list - confirm before relying on them.
#Probably fine
spwr_like_probably_noise = [4,5,6,7,8,9,10,11,12,15,16,18,22,27,28,30]
#Check with Peter
#25 too large interval
#28 too large interval
spwr_like_not_noise = [0,1,2,3,13,14,17,19,20,21,23,24,25,26,29]

In [None]:
#25
# Closer look at frames 5840800-5841599, one line per channel
frames = range(5840800,5841600)
plt.figure()
for channel, colour in enumerate(('b', 'C1', 'g')):
    channel_trace = lfp_dHP_data[channel, frames]
    plt.plot(np.arange(len(channel_trace)), channel_trace, colour)
plt.show()
plt.close()

## Create noise dataset

In [None]:
#Noise dataset is created using data from frame 5800000 to 5900000
#Consecutive windows of spwr_size frames are taken from noise_start onwards. A window is only
#sampled after IT has been checked against the spwr_like ranges marked as not noise. Previously
#the window was advanced first and then sampled, so the sampled window was never the one that
#had been checked and could reach into an excluded range or past noise_end.
#NOTE: this changes which frames end up in the noise dataset, so the manually chosen
#maybe_non_noise_indicies and the +4600 offset used further down should be re-checked.
noise_start = 5800000
noise_end = 5900000

#num_channel, num_dHP in dataset, signal length
noisy_lfp_dHP = np.zeros((num_channels, num_non_overlapping_dHP_spwr, spwr_size))

#Ranges that must not be sampled, ordered by their first frame
excluded_ranges = sorted((spwr_like[k] for k in spwr_like_not_noise), key=lambda r: r[0])

#Curr noise_sample
i = 0

#Bounds of the most recently sampled window; later cells read these after the loop
start_frame = noise_start
end_frame = start_frame + spwr_size

#Keep track of which spwr_like data we are currently at
curr_spwr_like_idx = 0

#First frame of the next window to be considered
next_start_frame = noise_start

#While we still need more data samples for the noise dataset
while i < num_non_overlapping_dHP_spwr:

    candidate_end_frame = next_start_frame + spwr_size

    #This only triggers if there is not enough data in the set range
    #To solve this one must extend the range, go through the data manually and pick out events that could be unlabelled spwrs
    #The samples that could not be filled stay all-zero in noisy_lfp_dHP.
    if(candidate_end_frame > noise_end):
        print("NEED MORE DATA")
        break

    if(curr_spwr_like_idx < len(excluded_ranges)):
        curr_spwr_like_start = excluded_ranges[curr_spwr_like_idx][0]
        curr_spwr_like_end = excluded_ranges[curr_spwr_like_idx][-1]

        #The window covers [next_start_frame, candidate_end_frame), so it reaches the current
        #spwr_like range as soon as its end goes past that range's first frame.
        if(candidate_end_frame > curr_spwr_like_start):
            #Continue right after the spwr_like range (never move backwards if that range is
            #already behind us) and re-check the new window against the next spwr_like range.
            next_start_frame = max(next_start_frame, curr_spwr_like_end + 1)
            curr_spwr_like_idx += 1
            continue

    #The window is clear of spwr_like data => add a datapoint to the noise dataset.
    start_frame = next_start_frame
    end_frame = candidate_end_frame
    noisy_lfp_dHP[:, i, :] = lfp_dHP_data[:, start_frame:end_frame]
    i += 1
    next_start_frame = end_frame

#Should be less than noise_end
print("Current end frame: ", end_frame)

# Fine tune the noise dataset
After creating the noise dataset, I went through all of the noise datapoints and found potential spwrs that I had missed during the first investigation of the data from 5800000 to 5900000. I saved their indicies in the maybe_non_noise_indicies array and replaced them with noise collected after the last collected noise datapoint. This was also done manually: I added an offset to the frames after the last recorded noise datapoint and looked through the data until I found enough consecutive 201-frame windows that contained only noise (one window for each datapoint to be replaced).

In [None]:
#Plot each of the noise sampled datapoints
#The printed number is the sample's position along axis 1 of noisy_lfp_dHP.
sample_axis = np.arange(noisy_lfp_dHP.shape[2])
for sample_idx in range(noisy_lfp_dHP.shape[1]):
    print(sample_idx)
    plt.figure()
    for channel, colour in enumerate(('b', 'C1', 'g')):
        plt.plot(sample_axis, noisy_lfp_dHP[channel, sample_idx, :], colour)
    plt.show()
    plt.close()


In [None]:
#ALL OF THESE ARE INTERESTING TO EVALUATE THE ALGORITHM ON. DOES IT THINK THESE ARE NOISE OR SPWR?
#263 ripple?
#Positions in the noise dataset that possibly hold data which should not be considered noise
#!!!NOTE!!! These indicies will increase if based on the number of noise samples required
#TODO Maybe fix this so that it is more automated.
maybe_non_noise_indicies = np.array([33, 184, 194, 195, 210, 227, 250, 263, 321, 332])

#Only keep the indicies that are in the current dataset; the last position is left out as well
last_noise_idx = noisy_lfp_dHP.shape[1]-1
is_in_dataset = maybe_non_noise_indicies < last_noise_idx
maybe_non_noise_indicies = maybe_non_noise_indicies[is_in_dataset]
print("Current indicies in noise data that should not be considered noise: ", maybe_non_noise_indicies)

In [None]:
#This code is to extract additional noise data to replace possibly non-noise in the noise dataset

print("Current start frame: ", start_frame)
print("Current end_frame: ", end_frame)

#The +4600 was found by hand so that the windows plotted below do not contain a spwr_like signal
#TODO: x and y are bad names but I need something other than start and end to not overwrite them.
x_frame = start_frame+4600
y_frame = end_frame+4600

frame_range = range(x_frame,y_frame)

#Plot one candidate replacement window per entry of maybe_non_noise_indicies
#so each one can be checked by eye
num_replacements = maybe_non_noise_indicies.shape[0]
for _ in range(num_replacements):
    plt.figure()
    for channel, colour in enumerate(('b', 'C1', 'g')):
        channel_trace = lfp_dHP_data[channel, frame_range]
        plt.plot(np.arange(len(channel_trace)), channel_trace, colour)
    plt.show()
    plt.close()
    #Step to the window that follows directly after the one just shown
    x_frame, y_frame = y_frame, y_frame + 201
    frame_range = range(x_frame,y_frame)

In [None]:
#Save possibly spwr data in the noise dataset for later analysis
#and replace it with noise from the next timesteps after the last one.

x_frame = start_frame+4600
y_frame = end_frame+4600

maybe_spwr_data = np.zeros((3, len(maybe_non_noise_indicies), spwr_size))
for slot, noise_idx in enumerate(maybe_non_noise_indicies):
    for channel in range(3):
        #Keep the doubtful sample for testing before it is overwritten
        maybe_spwr_data[channel, slot, :] = noisy_lfp_dHP[channel, noise_idx, :]
        #Put the replacement window in its place
        noisy_lfp_dHP[channel, noise_idx, :] = lfp_dHP_data[channel, x_frame:y_frame]

    #The next replacement comes from the window directly after this one
    x_frame, y_frame = y_frame, y_frame + spwr_size

In [None]:
# Display the noise-dataset positions held in maybe_non_noise_indicies
maybe_non_noise_indicies

# Prepare the data for the model

In [None]:
### Normalize, standardize and split the data into a train and a test set ###

# Stack the spwr windows (first half of the sample axis) and the noise windows
# (second half) so both classes are normalised with the same statistics.
# Slicing with num_channels, instead of writing out channels 0, 1 and 2, keeps
# this cell correct if the number of channels changes.
data = np.concatenate(
    (non_overlapping_lfp_dHP_spwr[:num_channels], noisy_lfp_dHP[:num_channels]),
    axis=1,
).astype(float, copy=False)
assert data.shape == (num_channels, num_non_overlapping_dHP_spwr*2, spwr_size), (
    "spwr and noise datasets must each hold num_non_overlapping_dHP_spwr windows"
)

# Per-channel statistics over every sample and frame. All four are kept at
# notebook level because later cells pass them to the preprocessing helpers.
# NOTE(review): the statistics are computed on train and test data together,
# so the test set is not fully independent of the normalisation.
channel_means = np.mean(data, axis=(1,2))
channel_stds = np.std(data, axis=(1,2))
print("Channel means: ", channel_means)
print("Channel stds: ", channel_stds)

channel_mins = np.amin(data, axis=(1,2))
channel_maxs = np.amax(data, axis=(1,2))
print("Channel mins: ", channel_mins)
print("Channel maxs: ", channel_maxs)

# The model is trained on min-max normalised data. The z-normalised variant,
#     (data - channel_means[:, None, None]) / channel_stds[:, None, None]
# used to be computed first and then silently overwritten; it is no longer
# computed because its result was never used.
channel_spans = channel_maxs - channel_mins
data_norm = (data - channel_mins[:, None, None]) / channel_spans[:, None, None]

#Separate back to spwr and noise to ensure equal number of spwr and noise in train/test
spwr_data_norm = data_norm[:, :num_non_overlapping_dHP_spwr, :]
noise_data_norm = data_norm[:, num_non_overlapping_dHP_spwr:, :]

#Create labels (1 = spwr, 0 = noise)
spwr_labels = np.ones((num_non_overlapping_dHP_spwr), dtype=int)
noise_labels = np.zeros((num_non_overlapping_dHP_spwr), dtype=int)
labels = np.concatenate((spwr_labels, noise_labels))

# One random permutation is drawn and applied to BOTH the spwr and the noise
# windows, so each class contributes the same number of train/test samples.
# NOTE(review): no random seed is set in this cell, so the split changes on
# every run.
ind = np.random.permutation(spwr_data_norm.shape[1])
train_test_split = 0.8
half_train_size = int(len(ind)*train_test_split)
half_test_size = num_non_overlapping_dHP_spwr-half_train_size
train_size = half_train_size*2
test_size = half_test_size*2

print("Num training examples: ", train_size)
print("Num test examples: ", test_size)

train_ind = ind[:half_train_size]
test_ind = ind[half_train_size:]

spwr_data_norm_train = spwr_data_norm[:, train_ind, :]
noise_data_norm_train = noise_data_norm[:, train_ind, :]
spwr_data_norm_test = spwr_data_norm[:, test_ind, :]
noise_data_norm_test = noise_data_norm[:, test_ind, :]

# Within each set: spwr windows first, noise windows second (shuffled below).
x_train = np.concatenate((spwr_data_norm_train, noise_data_norm_train), axis=1)
y_train = np.concatenate((spwr_labels[train_ind], noise_labels[train_ind]))

x_test = np.concatenate((spwr_data_norm_test, noise_data_norm_test), axis=1)
y_test = np.concatenate((spwr_labels[test_ind], noise_labels[test_ind]))

train_perm_ind = np.random.permutation(train_size)
test_perm_ind = np.random.permutation(test_size)

#Shuffle the data again but for spwr and noise together in train and test
x_train = x_train[:, train_perm_ind, :]
y_train = y_train[train_perm_ind]
x_test = x_test[:, test_perm_ind, :]
y_test = y_test[test_perm_ind]

#For Keras model input as it expect (datapoint_ind, data, feature number)
x_train = np.transpose(x_train, axes=(1,2,0))
x_test = np.transpose(x_test, axes=(1,2,0))

num_classes = len(np.unique(y_train))
print("Number of classes: ", num_classes)
print("x train shape: ", x_train.shape)
print("y train shape: ", y_train.shape)
print("x test shape: ", x_test.shape)
print("y test shape: ", y_test.shape)

In [None]:
#Make model

### Change main_name so that a previously saved model/figure is not overwritten ###
main_name = "test"
# Both output files share the main_name prefix.
model_name = f"{main_name}_model.h5"
fig_name = f"{main_name}_model_epoch_accuracy.pdf"

def make_model(input_shape, n_classes=None, n_filters=64, kernel_size=3, n_conv_blocks=3):
    """
    Build the 1D convolutional classifier.

    The network is a stack of identical blocks (Conv1D with "same" padding ->
    BatchNormalization -> ReLU), followed by global average pooling over the
    time axis and a softmax Dense layer with one unit per class.

    Parameters:
        - input_shape: shape of one input sample, without the batch dimension
        - n_classes: number of output units. When None (the default) the
          notebook-level num_classes is used, which is what the original
          implementation always did.
        - n_filters: number of filters in every Conv1D layer
        - kernel_size: kernel size of every Conv1D layer
        - n_conv_blocks: number of Conv1D/BatchNormalization/ReLU blocks
    Returns:
        An uncompiled keras Model.
    """
    if n_classes is None:
        # Fall back on the global so existing calls make_model(input_shape) still work.
        n_classes = num_classes

    input_layer = keras.layers.Input(input_shape)

    features = input_layer
    for _ in range(n_conv_blocks):
        features = keras.layers.Conv1D(
            filters=n_filters, kernel_size=kernel_size, padding="same"
        )(features)
        features = keras.layers.BatchNormalization()(features)
        features = keras.layers.ReLU()(features)

    gap = keras.layers.GlobalAveragePooling1D()(features)

    output_layer = keras.layers.Dense(n_classes, activation="softmax")(gap)

    return keras.models.Model(inputs=input_layer, outputs=output_layer)


# Build the network for one sample of x_train (batch dimension dropped).
model = make_model(input_shape=x_train.shape[1:])
# Last expression of the cell: renders the layer graph with tensor shapes.
keras.utils.plot_model(model, show_shapes=True)

In [None]:
#Train model
epochs = 500
batch_size = 32

# Write the model to model_name whenever the validation loss improves.
checkpoint_callback = keras.callbacks.ModelCheckpoint(
    model_name, save_best_only=True, monitor="val_loss"
)
# Halve the learning rate after 20 epochs without validation-loss improvement.
reduce_lr_callback = keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss", factor=0.5, patience=20, min_lr=0.0001
)
# Stop training after 50 epochs without validation-loss improvement.
early_stopping_callback = keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=50, verbose=1
)
callbacks = [checkpoint_callback, reduce_lr_callback, early_stopping_callback]

# Labels are integer class ids, hence the "sparse" loss and metric.
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["sparse_categorical_accuracy"],
)

fit_options = dict(
    batch_size=batch_size,
    epochs=epochs,
    callbacks=callbacks,
    validation_split=0.2,
    verbose=1,
)
history = model.fit(x_train, y_train, **fit_options)

In [None]:
#Test model
# Reload the model saved under model_name by the checkpoint callback rather
# than using the in-memory model from the last training epoch.
model = keras.models.load_model(model_name)

evaluation = model.evaluate(x_test, y_test)
test_loss, test_acc = evaluation

print("Test accuracy", test_acc)
print("Test loss", test_loss)

In [None]:
# Plot training and validation accuracy per epoch and save the figure to fig_name
metric = "sparse_categorical_accuracy"
accuracy_fig, accuracy_ax = plt.subplots()
for history_key in (metric, "val_" + metric):
    accuracy_ax.plot(history.history[history_key])
accuracy_ax.set_title("model " + metric)
accuracy_ax.set_ylabel(metric, fontsize="large")
accuracy_ax.set_xlabel("epoch", fontsize="large")
accuracy_ax.legend(["train", "val"], loc="best")
accuracy_fig.savefig(fig_name)
plt.show()
plt.close(accuracy_fig)

In [None]:
def plot_lfp_data(lfp_data):
    """
    Plot every sample window of lfp_data in its own figure.

    For each sample the sample index is printed and channels 0, 1 and 2 are
    drawn against the frame index inside the window (blue, 'C1', green).

    Parameters:
        - lfp_data is indexed as (channel, sample, frame); only the first
          three channels are plotted
    Returns:
        None
    """
    channel_styles = ('b', 'C1', 'g')
    num_windows = lfp_data.shape[1]
    for sample_idx in range(num_windows):
        print(sample_idx)
        frame_axis = np.arange(lfp_data.shape[2])
        plt.figure()
        for channel, style in enumerate(channel_styles):
            plt.plot(frame_axis, lfp_data[channel, sample_idx, :], style)
        plt.show()
        # Close each figure so a long list of windows does not keep them all open.
        plt.close()

In [None]:
# Plot the windows that were removed from the noise dataset because they might be spwrs.
plot_lfp_data(maybe_spwr_data)

In [None]:
#All of the events that I removed from the noise dataset because I thought they looked too much like ripples
#are classified as ripples by the model.

def preprocess_lfp_data_min_max_norm(lfp_data, mins, maxs):
    """
    Min-max normalise LFP windows per channel and reorder them for the model.

    Each channel c is mapped to (lfp_data[c] - mins[c]) / (maxs[c] - mins[c]).
    Any number of channels is supported (the original handled exactly three).

    Parameters:
        - lfp_data: array indexed as (channel, sample, frame)
        - mins: per-channel minimum, at least one entry per channel of lfp_data
        - maxs: per-channel maximum, at least one entry per channel of lfp_data
    Returns:
        Float array indexed as (sample, frame, channel).
    Raises:
        ValueError if mins or maxs has fewer entries than lfp_data has channels.
    """
    lfp_data = np.asarray(lfp_data, dtype=float)
    mins = np.asarray(mins, dtype=float)
    maxs = np.asarray(maxs, dtype=float)

    num_ch = lfp_data.shape[0]
    if mins.shape[0] < num_ch or maxs.shape[0] < num_ch:
        raise ValueError(
            "mins and maxs need one entry per channel (%d channels given)" % num_ch
        )

    # Add two trailing axes so the per-channel statistics broadcast over
    # the sample and frame axes.
    lows = mins[:num_ch, np.newaxis, np.newaxis]
    spans = (maxs[:num_ch] - mins[:num_ch])[:, np.newaxis, np.newaxis]
    test_data = (lfp_data - lows) / spans
    return np.transpose(test_data, axes=(1,2,0))

def preprocess_lfp_data_z_norm(lfp_data, means, stds):
    """
    Z-normalise LFP windows per channel and reorder them for the model.

    Each channel c is mapped to (lfp_data[c] - means[c]) / stds[c].
    Any number of channels is supported (the original handled exactly three).

    Parameters:
        - lfp_data: array indexed as (channel, sample, frame)
        - means: per-channel mean, at least one entry per channel of lfp_data
        - stds: per-channel standard deviation, same length requirement
    Returns:
        Float array indexed as (sample, frame, channel).
    Raises:
        ValueError if means or stds has fewer entries than lfp_data has channels.
    """
    lfp_data = np.asarray(lfp_data, dtype=float)
    means = np.asarray(means, dtype=float)
    stds = np.asarray(stds, dtype=float)

    num_ch = lfp_data.shape[0]
    if means.shape[0] < num_ch or stds.shape[0] < num_ch:
        raise ValueError(
            "means and stds need one entry per channel (%d channels given)" % num_ch
        )

    # Add two trailing axes so the per-channel statistics broadcast over
    # the sample and frame axes.
    centres = means[:num_ch, np.newaxis, np.newaxis]
    scales = stds[:num_ch, np.newaxis, np.newaxis]
    test_data = (lfp_data - centres) / scales
    return np.transpose(test_data, axes=(1,2,0))

# Classify the windows that were pulled out of the noise dataset.
# Z-normalised alternative (needs a model trained on z-normalised data):
#y_hat = model.predict(preprocess_lfp_data_z_norm(maybe_spwr_data, channel_means, channel_stds))
maybe_spwr_input = preprocess_lfp_data_min_max_norm(maybe_spwr_data, channel_mins, channel_maxs)
y_hat = model.predict(maybe_spwr_input)
print(y_hat)
# Predicted class per window (index of the largest softmax output).
print(np.argmax(y_hat, axis=1))

In [None]:
# Compare predictions with labels on the test data (100% accuracy was observed here)
y_hat = model.predict(x_test)
test_predictions = np.argmax(y_hat, axis=1)
# Non-zero entries mark misclassified windows.
prediction_errors = test_predictions - y_test
print(test_predictions)
print(y_test)
print(prediction_errors)
print("Number of errors on test data: ", np.sum(np.abs(prediction_errors)))

In [None]:
#Test dataset to evaluate the model performance on new frames it has never seen before

"""
Good examples:
num_samples = 10
start_frame = 3020000
#4,5 and 7
"""

num_samples = 10
start_frame = 3020000
end_frame = start_frame + spwr_size
frame_range = range(start_frame, end_frame)

# Cut num_samples back-to-back windows of spwr_size frames out of the raw signal.
# NOTE(review): this rebinds the notebook-level start_frame/end_frame/frame_range.
more_test_data = np.zeros((num_channels, num_samples, spwr_size))
for sample in range(num_samples):
    more_test_data[:, sample, :] = lfp_dHP_data[:, frame_range]
    # Slide the window forward by exactly one window length.
    start_frame, end_frame = end_frame, end_frame + spwr_size
    frame_range = range(start_frame, end_frame)

#y_hat = model.predict(preprocess_lfp_data_z_norm(more_test_data, channel_means, channel_stds))
more_test_input = preprocess_lfp_data_min_max_norm(more_test_data, channel_mins, channel_maxs)
y_hat = model.predict(more_test_input)
#print(y_hat)
print(np.argmax(y_hat, axis=1))

plot_lfp_data(more_test_data)

In [None]:
#For online classification investigation

#This is a ripple in the dataset
ripple_id = 150
centralFrame_of_ripple = dHP_non_overlapping_centralFrames[ripple_id]

# Number of extra window start positions examined on each side of the ripple.
offset = 150

# First and (exclusive) last window START frame that is classified.
curr_start_f = centralFrame_of_ripple - offset - 100
ripple_end_f = centralFrame_of_ripple + 101 + offset

# One window of spwr_size frames per start frame, advancing one frame at a time.
start_frames = np.arange(curr_start_f, ripple_end_f, 1)
time_lfp_data = np.zeros((num_channels, ripple_end_f-curr_start_f, spwr_size))
for ind_f, curr_f in enumerate(range(curr_start_f, ripple_end_f, 1)):
    f_range = range(curr_f, curr_f+spwr_size)
    time_lfp_data[:, ind_f, :] = lfp_dHP_data[:, f_range]

time_lfp_input = preprocess_lfp_data_min_max_norm(time_lfp_data, channel_mins, channel_maxs)
y_hat = model.predict(time_lfp_input)
preds = np.argmax(y_hat, axis=1)

#Print predictions
#print(np.argmax(y_hat, axis=1)) #54-233 model predicts 1

# Figure/axes that animate() redraws for every window.
fig,ax = plt.subplots()
                         
def animate(i):
    """
    Redraw the notebook-level axes `ax` for window i of time_lfp_data.

    The title and the four axes spines are green when preds[i] is truthy and
    red otherwise; the three channels are drawn against absolute frame numbers.

    Parameters:
        - i: index of the window (animation frame) to draw
    Returns:
        Tuple with the three Line2D objects that were drawn.
    """
    ax.clear()

    window_start = start_frames[i]
    ax.set_xlim(window_start, window_start+spwr_size)

    outcome_colour = 'green' if preds[i] else 'red'
    ax.set_title("Prediction is " + str(preds[i]), fontsize = 20, color = outcome_colour)
    for side in ('bottom', 'top', 'left', 'right'):
        ax.spines[side].set_color(outcome_colour)

    drawn_lines = []
    for channel, line_colour in enumerate(('blue', 'orange', 'green')):
        frame_axis = np.arange(window_start, window_start+spwr_size)
        line, = ax.plot(frame_axis, time_lfp_data[channel, i, :], color = line_colour, lw=1)
        drawn_lines.append(line)
    return tuple(drawn_lines)

# NOTE(review): this expression is not the last statement of the cell and is
# not assigned, so its value is neither displayed nor used.
time_lfp_data.shape[1] - spwr_size
# Animate the first (ripple_end_f - curr_start_f - 200) windows, 40 ms apart.
# NOTE(review): the reason for excluding the last 200 windows is not visible
# here, and it differs from the expression above, which subtracts spwr_size.
ani = FuncAnimation(fig, animate, interval=40, blit=True, repeat=True, frames=ripple_end_f-curr_start_f-200)
# Write the animation to anim.gif in the working directory at 25 frames per second.
ani.save("anim.gif", dpi=300, writer=PillowWriter(fps=25))

In [None]:
# Side-by-side views of neighbouring windows around the two prediction flips.
# The window indices and the predictions shown in the titles are hard-coded
# (they are not read from preds).
window_frames = np.arange(spwr_size)
channel_styles = ('b', 'C1', 'g')
transition_pairs = (
    ((53, 0), (54, 1)),
    ((233, 1), (234, 0)),
)

for pair in transition_pairs:
    fig, (ax1, ax2) = plt.subplots(1,2, figsize=(10, 5))
    # `panel` rather than `ax`, so the axes used by animate() are not rebound.
    for panel, (time_idx, predicted) in zip((ax1, ax2), pair):
        for channel, style in enumerate(channel_styles):
            panel.plot(window_frames, time_lfp_data[channel, time_idx, :], style)
        panel.set_title("TIME " + str(time_idx) + ": PREDICTION IS " + str(predicted))
plt.show()

In [None]:
#Noisy dataset samples

num_samples = 100

# Copy the first num_samples windows of the noise dataset, all channels at once.
more_test_data = np.zeros((num_channels, num_samples, spwr_size))
for sample in range(num_samples):
    more_test_data[:, sample, :] = noisy_lfp_dHP[:num_channels, sample, :]

#y_hat = model.predict(preprocess_lfp_data_z_norm(more_test_data, channel_means, channel_stds))
noise_input = preprocess_lfp_data_min_max_norm(more_test_data, channel_mins, channel_maxs)
y_hat = model.predict(noise_input)
#print(y_hat)
print(np.argmax(y_hat, axis=1))

plot_lfp_data(more_test_data)

In [None]:
#Train dataset

# NOTE(review): this rebinds `ind`, which held the train/test permutation.
ind = 100
y_hat = model.predict(x_train)
train_predictions = np.argmax(y_hat, axis=1)
print(y_train)
# Non-zero entries mark misclassified training windows.
print(np.abs(y_train - train_predictions))
# NOTE(review): these are raw spwr windows ind..ind+9, not the shuffled and
# normalised rows of x_train that were just classified.
plot_lfp_data(non_overlapping_lfp_dHP_spwr[:, ind:ind+10, :])


In [None]:
# LEGACY VERSION, kept for reference.
# This code gave 100% accuracy on the test dataset even though it does not work as intended:
#  - it does not remove a spwr that overlaps with the previous spwr once that one has been removed.
#    I.e. if three consecutive spwrs overlap, the third is kept as long as it does not overlap with the 4th.
#  - it never includes the last spwr, even if that one does not overlap with any other spwr.

#The intention: check if 100ms around each ripple centre is free of other ripples, if not it is removed
#from the spwr dataset. This is done to have isolated spwrs in the training/test set and to avoid the
#same data in different samples. Hopefully this allows the model to better learn what a spwr is.
timesteps = 100

#Copy of all spwr centralFrames. Overlapping spwr frames are deleted from this copy below.
copy_centralFrames = np.copy(sorted_ripple_centralFrames)
copy_centralFrames = copy_centralFrames.flatten() #Flatten => (num_spwr, )

#Current index into copy_centralFrames (the array that shrinks as overlapping spwrs are deleted)
idx = 0

#Index of current and next spwr in the ORIGINAL ordering of all spwr events (used with dHP_spwr_bool)
curr_spwr = 0
next_spwr = 1

#Original-ordering indices of the spwr events that are KEPT, i.e. found not to overlap with their successor
non_overlapping_spwr = []

#Counters for discarded events, split by whether dHP_spwr_bool marks them as dHP
num_dHP_removed = 0
num_iHP_removed = 0

#While there are at least 2 spwr events left. If only one left I cannot compare it to anything
#TODO do I leave the last spwr out of the dataset even if it does not overlap?
num_spwr_left = copy_centralFrames.shape[0] - 1 
while idx < num_spwr_left:

    #Window of +-timesteps frames around the current centre, and the start of the next window
    curr_ripple_start = copy_centralFrames[idx] - timesteps
    curr_ripple_end = copy_centralFrames[idx] + timesteps
    next_ripple_start = copy_centralFrames[idx+1] - timesteps

    #If next overlaps with current ripple 
    if(curr_ripple_start < next_ripple_start < curr_ripple_end):
        #Discard both overlapping spwr events one after another
        copy_centralFrames = np.delete(copy_centralFrames, idx) #This discards curr
        copy_centralFrames = np.delete(copy_centralFrames, idx) #This discards next which is now at curr pos

        #Count the discarded current event
        if(dHP_spwr_bool[curr_spwr]):
            num_dHP_removed += 1

        else:
            num_iHP_removed += 1

        #Count the discarded next event
        if(dHP_spwr_bool[curr_spwr+1]):
            num_dHP_removed += 1
        else:
            num_iHP_removed += 1

        #Two events are gone from the array; idx stays put because the array shifted left by two
        num_spwr_left += -2  
        curr_spwr += 2
        next_spwr += 2

    #No overlap => save current spwr as non-overlapping and go to next spwr
    else:
        non_overlapping_spwr.append(curr_spwr)
        curr_spwr = next_spwr
        idx+= 1
        next_spwr += 1 
    
#Check to see if the last ripple is not overlapping
print("num spwr left: ", num_spwr_left)
print("idx: ", idx)

# non_overlapping_spwr lists the events that were KEPT, so it must not be used
# as the "removed" count. Every discarded event increments exactly one of the
# two counters below, so their sum is the number of removed events.
num_spwr_removed = num_dHP_removed + num_iHP_removed

print("Total number of spwr: ", num_total_spwr)
print("Total number of dHP spwr", num_dHP_spwr)
print("Total number of kept non-overlapping spwr events: ", len(non_overlapping_spwr))
print("Total number of removed overlapping spwr events: ", num_spwr_removed)
print("Number of dHP spwr removed: ", num_dHP_removed)
print("Number of iHP spwr removed: ", num_iHP_removed)
print("Portion of total lost spwr events: ", num_spwr_removed/num_total_spwr)
print("Portion of total lost dHP spwr events: ", num_dHP_removed/num_dHP_spwr)

# Boolean mask over all spwr events: True where the event was kept,
# False where it was removed.
kept_spwr_bool = np.zeros(num_total_spwr, dtype=bool)
kept_spwr_bool[non_overlapping_spwr] = True

#Only non-removed dHP spwr events are kept
non_overlapping_dHP_spwr_bool = np.logical_and(kept_spwr_bool, dHP_spwr_bool)
num_kept_dHP = np.sum(non_overlapping_dHP_spwr_bool)
print("Num non-overlapping dHP spwrs: ", num_kept_dHP)

In [None]:
"""
print(dHP_spwr_bool[:12])
print(kept_spwr_bool[:12])
print(non_overlapping_dHP_spwr_bool[:12])
print(sorted_ripple_centralFrames[0, :12])
"""

#Check that spwr events have only been removed in pairs.
#non_overlapping_spwr holds the indices of the kept events. Between two consecutive kept indices,
#(difference - 1) events were removed; if that number is odd, the removal code above is wrong,
#because it only ever removes two events at a time.

for i in range(len(non_overlapping_spwr)-1):
    #A difference larger than 1 means events were removed between these two kept events
    if(np.abs(non_overlapping_spwr[i]-non_overlapping_spwr[i+1])>1):
        #Odd number of removed events in the gap => report and stop at the first violation
        if(non_overlapping_spwr[i+1]-non_overlapping_spwr[i] -1 )% 2:
            print("BAD")
            break