# Imports

In [None]:
import os
import rippl_AI
import importlib
importlib.reload(rippl_AI)
import aux_fcn

import matplotlib.pyplot as plt
import numpy as np
%matplotlib qt

# Basic detection example
This section provides a usage example of the prediction and detection functions.

### Download data

In [None]:
from figshare.figshare import Figshare
fshare = Figshare()

# Figshare article ids and the session name each one is stored under (paired by position)
article_ids = [14959449]
sess = ['Dlx1']
# The loop variable is `article_id` rather than `id`: notebook loop variables live in the
# global namespace, so naming it `id` would shadow the builtin for every later cell.
for article_id, s in zip(article_ids, sess):
    datapath = os.path.join('Downloaded_data', f'{s}')
    # An existing session folder is taken to mean the data was already fetched
    if os.path.isdir(datapath):
        print("Data already exists. Moving on.")
    else:
        print("Downloading data... Please wait, this might take up some time")        # Can take up to 10 minutes
        fshare.retrieve_files_from_article(article_id, directory=datapath)
        print("Data downloaded!")

### Data loading

In [None]:
# Folder holding the downloaded Dlx1 session
path = os.path.join('Downloaded_data', 'Dlx1', 'figshare_14959449')

# Session info (sampling frequency, experiment name, reference/dead channels) and channel map
sf, expName, ref_channels, dead_channels = aux_fcn.load_info(path)
channels_map = aux_fcn.load_channels_map(path)

# Reformat the channel map and reference channels into correct values
channels, shanks, ref_channels = aux_fcn.reformat_channels(channels_map, ref_channels)
print('Channel map value: ', channels)

# Read the .dat recording for the reformatted channels
LFP = aux_fcn.load_raw_data(path, expName, channels, verbose=True)
print('Sampling frequency: ', sf)
print('Shape of the original data', LFP.shape)
print(channels_map)


In [None]:
# The predict function takes care of normalizing and subsampling your data
# If no architecture or model is specified, the best CNN1D will be used
# NOTE(review): d_sf=2500 is presumably the sampling frequency the data is subsampled to;
# confirm it matches what the model and the later get_intervals call expect
prob,LFP_norm=rippl_AI.predict(LFP,sf,d_sf=2500) 

In [None]:
# An interactive GUI will be displayed; choose your desired threshold
det_ind=rippl_AI.get_intervals(prob,LFP_norm=LFP_norm) 
# The first dimension of det_ind is reported as the number of detected events
print(f"{det_ind.shape[0]} events where detected")

# Get performances after detection
For every model, predictions are generated, `get_intervals` is applied automatically over a range of thresholds, and the resulting performance metric (F1) is plotted.

### Data download

In [None]:
from figshare.figshare import Figshare
fshare = Figshare()

# Figshare article ids and the session name each one is stored under (paired by position)
article_ids = [14959449]
sess = ['Dlx1']
# The loop variable is `article_id` rather than `id`: notebook loop variables live in the
# global namespace, so naming it `id` would shadow the builtin for every later cell.
for article_id, s in zip(article_ids, sess):
    datapath = os.path.join('Downloaded_data', f'{s}')
    # An existing session folder is taken to mean the data was already fetched
    if os.path.isdir(datapath):
        print("Data already exists. Moving on.")
    else:
        print("Downloading data... Please wait, this might take up some time")        # Can take up to 10 minutes
        fshare.retrieve_files_from_article(article_id, directory=datapath)
        print("Data downloaded!")

### Data loading

In [None]:
# Folder holding the downloaded Dlx1 session
path = os.path.join('Downloaded_data', 'Dlx1', 'figshare_14959449')

# Session info (sampling frequency, experiment name, reference/dead channels) and channel map
sf, expName, ref_channels, dead_channels = aux_fcn.load_info(path)
channels_map = aux_fcn.load_channels_map(path)

# Load the ground truth (GT) tagged events and divide them by the sampling frequency
# (presumably converting sample indices to seconds; TODO confirm)
ripples = aux_fcn.load_ripples(path) / sf

# Reformat the channel map and reference channels into correct values
channels, shanks, ref_channels = aux_fcn.reformat_channels(channels_map, ref_channels)
print('Channel map value: ', channels)

# Read the .dat recording for the reformatted channels
LFP = aux_fcn.load_raw_data(path, expName, channels, verbose=True)
print('Sampling frequency: ', sf)
print('Shape of the original data', LFP.shape)

In [None]:
# Two loops going over every possible model: 5 architectures x 5 model numbers each
architectures = ['XGBOOST', 'SVM', 'LSTM', 'CNN1D', 'CNN2D']
# Build each row as its own list. `[[None]*5]*5` would repeat a reference to ONE inner
# list five times, so every architecture would write into (and overwrite) the same row
# and only the outputs of the last architecture would survive.
SWR_prob = [[None] * 5 for _ in architectures]
for i, architecture in enumerate(architectures):
    print(i, architecture)
    for n in range(1, 6):
        # Make sure the number of channels the selected model expects matches the channels array passed to predict.
        # Here the channel array is manually set to 3 channels for CNN2D models 3 to 5; every other model gets 8.
        if architecture == 'CNN2D' and n >= 3:
            channels = [0, 3, 7]
        else:
            channels = [0, 1, 2, 3, 4, 5, 6, 7]
        SWR_prob[i][n - 1], _ = rippl_AI.predict(LFP, sf, arch=architecture, model_number=n, channels=channels)

# SWR_prob[i][n-1] holds the output of model number n of architectures[i]


In [None]:
# Sweep 10 thresholds between 0.1 and 1 and plot the F1 curve of every model:
# one row of axes per architecture, one column per model number.
th_arr = np.linspace(0.1, 1, 10)
fig, axs = plt.subplots(5, 5, figsize=(10, 10), sharex='all', sharey='all')
for i, row_axes in enumerate(axs):
    for j, ax in enumerate(row_axes):
        F1_arr = np.zeros(shape=(len(th_arr)))
        for k, th in enumerate(th_arr):
            det_ind = rippl_AI.get_intervals(SWR_prob[i][j], threshold=th)
            # The third value returned by get_performance is the one kept as F1
            F1_arr[k] = aux_fcn.get_performance(det_ind, ripples)[2]
        ax.plot(th_arr, F1_arr)
    # The first axes of each row carries the architecture name
    row_axes[0].set_title(architectures[i])

axs[0, 0].set(xlabel='Threshold', ylabel='F1')

# Detecting with less than 8 channels
Detectors need several channels for optimal performance. We found out that 8 channels have enough information to assure good performance. But what happens if we don't have 8? We have seen that interpolating the missing channels also works. In this section, we will show how to use the interpolation function we have created for this purpose, inside the `aux_fcn` package.

### Data download

In [None]:
# Try the import path used by the other download cells of this notebook first, then fall
# back to the nested path this cell originally used (presumably a source-checkout layout;
# TODO confirm which layout the project ships).
try:
    from figshare.figshare import Figshare
except ImportError:
    from figshare.figshare.figshare import Figshare
fshare = Figshare()

# Figshare article ids and the session name each one is stored under (paired by position)
article_ids = [14959449]
sess = ['Dlx1']
# The loop variable is `article_id` rather than `id`: notebook loop variables live in the
# global namespace, so naming it `id` would shadow the builtin for every later cell.
for article_id, s in zip(article_ids, sess):
    datapath = os.path.join('Downloaded_data', f'{s}')
    # An existing session folder is taken to mean the data was already fetched
    if os.path.isdir(datapath):
        print("Data already exists. Moving on.")
    else:
        print("Downloading data... Please wait, this might take up some time")        # Can take up to 10 minutes
        fshare.retrieve_files_from_article(article_id, directory=datapath)
        print("Data downloaded!")

### Data load
To illustrate how 'interpolate_channels' can be used to obtain the desired number of channels, we will simulate two cases using the Dlx1 session:
1. We are using a recording probe that extracts 4 channels, when we need 8.
2. Some channels are dead or have too much noise.

In [None]:
# Folder holding the downloaded Dlx1 session
path = os.path.join('Downloaded_data', 'Dlx1', 'figshare_14959449')

sf, expName, ref_channels, dead_channels = aux_fcn.load_info(path)

channels_map = aux_fcn.load_channels_map(path)
channels, shanks, ref_channels = aux_fcn.reformat_channels(channels_map, ref_channels)
LFP = aux_fcn.load_raw_data(path, expName, channels, verbose=False)
print('Sampling frequency: ', sf)
print('Shape of the original data', LFP.shape)

# Case 1: a 4-channel recording, simulated by keeping columns 0, 2, 4 and 6.
# Indexing with a list yields a new array (assuming LFP is a NumPy array), so LFP is untouched.
LFP_linear = LFP[:, [0, 2, 4, 6]]
print('Shape of the 4 channels simulated data: ', LFP_linear.shape)

# Case 2: dead channels, simulated by zeroing columns 2 and 5.
# This is done on a copy: zeroing LFP in place (and aliasing it as LFP_dead) would leave
# the original recording corrupted for every later cell that predicts on LFP.
LFP_dead = LFP.copy()
LFP_dead[:, [2, 5]] = 0
print('Sample of the simulated dead LFP: ', LFP_dead[0])

After interpolation, the data is ready to use in prediction

In [None]:
# Define channels: an 8-entry list. Non-negative entries are presumably column indices
# into the 4-channel input and -1 presumably marks a position to be interpolated;
# TODO confirm against aux_fcn.interpolate_channels
channels_interpolation = [0,-1,1,-1,2,-1,-1,3]

# Make interpolation
LFP_interpolated = aux_fcn.interpolate_channels(LFP_linear, channels_interpolation)
print('Shape of the interpolated LFP: ',LFP_interpolated.shape)

In [None]:
 # Define channels: positions 2 and 5 (the columns zeroed in LFP_dead) are set to -1
channels_interpolation = [0,1,-1,3,4,-1,6,7]

# Make interpolation
LFP_interpolated = aux_fcn.interpolate_channels(LFP_dead, channels_interpolation)
print('Value of the 1st sample of the interpolated LFP: ',LFP_interpolated[0])

# Using an ensemble model
In this section, we will show how to use an ensemble model that combines the output of the best models of each architecture. This model has better performance and more stability than the individual models. In this case, only the best ensemble model will be provided.

First, the output of the 5 selected models needs to be reshaped

In [None]:
# 5 outputs are generated: one per architecture, always using model number 1
architectures = ['XGBOOST', 'SVM', 'LSTM', 'CNN1D', 'CNN2D']
# The same 8 channels are used for every model, so the list is built once outside the loop
channels = [0, 1, 2, 3, 4, 5, 6, 7]
output = []
for architecture in architectures:
    # Stored under its own name: reusing `SWR_prob` here would overwrite the per-model
    # list of outputs that the threshold-sweep plotting cell indexes as SWR_prob[i][j].
    arch_prob, _ = rippl_AI.predict(LFP, sf, arch=architecture, model_number=1, channels=channels)
    output.append(arch_prob)
# Stack the outputs and transpose, which moves the architecture axis to the last position
ens_input = np.array(output).transpose()


Generating ensemble model output

In [None]:

# Feed the stacked outputs of the individual models to the ensemble model
prob_ens=rippl_AI.predict_ens(ens_input)

Plot performance

In [None]:
# F1 of the ensemble output over 10 thresholds between 0.1 and 1
fig, ax = plt.subplots()
th_arr = np.linspace(0.1, 1, 10)
F1_arr = np.zeros(shape=(len(th_arr)))
for k, th in enumerate(th_arr):
    det_ind = rippl_AI.get_intervals(prob_ens, threshold=th)
    # The third value returned by get_performance is the one kept as F1
    F1_arr[k] = aux_fcn.get_performance(det_ind, ripples)[2]

ax.plot(th_arr, F1_arr)
ax.set_title('Ensemble model')
ax.set_ylim(-0.05, 0.8)
ax.set(xlabel='Threshold', ylabel='F1')
