# SVM parameter exploration
This notebook is a template for finding the SVM model best suited for your needs

In [None]:
# Dependencies used throughout the notebook
import os
import importlib
import matplotlib.pyplot as plt
import numpy as np
import sklearn as sk
from sklearn import svm,calibration
from imblearn.under_sampling import RandomUnderSampler
import sys
import inspect
# The parent of the current working directory is put first on the import path;
# presumably that is where the rippl_AI and aux_fcn modules imported below live (confirm for your checkout).
# parent_dir is also reused by later cells to build the data and model paths.
parent_dir=os.path.dirname(os.getcwd())
sys.path.insert(0,parent_dir )
import rippl_AI
import aux_fcn
# Reload the project modules so that the current kernel uses their latest code
importlib.reload(aux_fcn)
importlib.reload(rippl_AI)

### Data download
4 uLED sessions will be downloaded: Amigo2 and Som2 will be used for training; Dlx1 and Thy7 for validation.


In [None]:
import os
from figshare.figshare.figshare import Figshare

# Client used to retrieve the files attached to each figshare article
fshare = Figshare()

# figshare article ids and the session each one belongs to (same order in both lists)
article_ids = [16847521,16856137,14959449,14960085]
sess=['Amigo2','Som2','Dlx1','Thy7']

# `article_id` is used instead of `id` so the builtin id() is not shadowed for the rest of the notebook
for article_id,session_name in zip(article_ids,sess):
    datapath = os.path.join(parent_dir,'Downloaded_data',session_name)
    # An existing session folder is treated as an already completed download.
    # NOTE(review): if a previous download was interrupted the folder may exist but be incomplete;
    # delete it manually to force a new download - confirm how retrieve_files_from_article behaves.
    if os.path.isdir(datapath):
        print(f"{session_name} session already exists. Moving on.")
    else:
        print("Downloading data... Please wait, this might take up some time")        # Can take up to 10 minutes
        fshare.retrieve_files_from_article(article_id,directory=datapath)
        print("Data downloaded!")

### Data load
The LFPs of the training sessions will be appended together in a list. The same will be done with the ripple detection times.
This is the input format required by the training data parser.

In [None]:
# Training and validation sessions are collected in separate lists. Replace this cell with your own data loading
# Each entry is (session folder, figshare subfolder) inside <parent_dir>/Downloaded_data
training_sessions=[('Amigo2','figshare_16847521'),
                   ('Som2','figshare_16856137')]
validation_sessions=[('Dlx1','figshare_14959449'),
                     ('Thy7','figshare_14960085')]

# One LFP and one ground-truth entry per training session, in the order listed above
train_LFPs=[]
train_GTs=[]
for session,subfolder in training_sessions:
    path=os.path.join(parent_dir,'Downloaded_data',session,subfolder)
    LFP,GT=aux_fcn.load_lab_data(path)
    train_LFPs.append(LFP)
    train_GTs.append(GT)

## Append all your validation sessions
val_LFPs=[]
val_GTs=[]
for session,subfolder in validation_sessions:
    path=os.path.join(parent_dir,'Downloaded_data',session,subfolder)
    LFP,GT=aux_fcn.load_lab_data(path)
    val_LFPs.append(LFP)
    val_GTs.append(GT)

# Turn the raw session lists into the training / validation inputs used by the following cells.
# sf=30000 is presumably the sampling frequency (Hz) of the loaded recordings - confirm for your own data.
x_training,GT_training,x_val_list,GT_val_list=rippl_AI.prepare_training_data(train_LFPs,train_GTs,val_LFPs,val_GTs,sf=30000)

## SVM training parameters

#### Parameters:
* Channels:  Number of channels that will be used to train the model, extracted from the data shape defined in the previous cell
* Timesteps: Number of samples that will be used to generate a single prediction
* Undersampler proportion: Proportion of True/False samples. Using all the samples demands heavy resource usage, but the more data used, the better the model generalizes

In [None]:
# Hyperparameter grid explored by the training cell: every listed value of one key
# is combined with every listed value of the other
conf = {
    # Possible values: 1,2,4,8,16,32...
    'timesteps': [1, 2, 4],
    # Possible values: 1, 0.5, 0.1 ... (0,1]
    'undersampler proportion': [1],
}

### Training

In [None]:
# Desired sampling frequency of the models
sf=1250
# Seed shared by the undersampler and the SVM so that re-running the cell reproduces the same models
random_seed=0
th_arr=np.linspace(0.1,0.9,9)   # Thresholds at which every model is evaluated
model_name_arr=[]           # To plot in the next cell
model_arr=[]                # Actual model array, used in the next validation section
n_channels=x_training.shape[1]
timesteps_arr=conf['timesteps']
undersampler_arr=conf['undersampler proportion']
l_ts=len(timesteps_arr)
l_us=len(undersampler_arr)
n_iters=l_ts*l_us           # One model per (timesteps, undersampler proportion) pair
perf_train_arr=np.zeros(shape=(n_iters,len(th_arr),3)) # Performance array, (n_models x n_th x 3 ) [P R F1]
perf_test_arr=np.zeros_like(perf_train_arr)
timesteps_arr_ploting=[]            # Timesteps of every model; the validation cell needs them to call predict
print(f'{n_channels} channels will be used to train the SVM models')

print(f'{n_iters} models will be trained')

# Folder where the trained models are saved; created here so that saving does not depend on it already existing
models_dir=os.path.join(parent_dir,'explore_models')
os.makedirs(models_dir,exist_ok=True)


def _events_to_signal(events,n_samples,sf):
    """Return a 0/1 signal of length n_samples that is 1 inside every (start, end) event.

    GT is in the shape (n_events x 2), a y output signal with the same length as x is required.
    Event times are multiplied by sf to obtain sample indexes.
    """
    signal=np.zeros(shape=(n_samples))
    for ev in events:
        signal[int(sf*ev[0]):int(sf*ev[1])]=1
    return signal


def _make_windows(x,y,timesteps,n_channels):
    """Group consecutive samples into windows of `timesteps` samples.

    Trailing samples that do not fill a whole window are dropped. Returns the flattened
    windows (n_windows x timesteps*n_channels) and the labels given by aux_fcn.rec_signal,
    which presumably reduces each window of labels to a single label - confirm in aux_fcn.
    """
    x_win=x[:len(x)-len(x)%timesteps].reshape(-1,timesteps*n_channels)
    y_win_aux=y[:len(y)-len(y)%timesteps].reshape(-1,timesteps)
    return x_win,aux_fcn.rec_signal(y_win_aux)


def _expand_windows(window_values,timesteps):
    """Repeat every per-window value `timesteps` times, giving a (n_windows*timesteps x 1 x 1) per-sample signal."""
    return np.repeat(window_values,timesteps).reshape(-1,1,1)


x_test_or,GT_test,x_train_or,GT_train=aux_fcn.split_data(x_training,GT_training,split=0.7,sf=sf)

y_test_or=_events_to_signal(GT_test,len(x_test_or),sf)
y_train_or=_events_to_signal(GT_train,len(x_train_or),sf)


for i_ts,timesteps in enumerate(timesteps_arr):

    x_train,y_train=_make_windows(x_train_or,y_train_or,timesteps,n_channels)
    x_test,y_test=_make_windows(x_test_or,y_test_or,timesteps,n_channels)

    for i_us,undersampler_prop in enumerate(undersampler_arr):
        # Row of this model in the performance arrays (named i_model to avoid shadowing the builtin iter)
        i_model=i_ts*l_us+i_us
        print(f"\nIteration {i_model+1} out of {n_iters}")
        print(f'Time steps: {timesteps}, Undersampler proportion: {undersampler_prop}')

        # Balance the training windows; the seed makes the random selection repeatable
        rus = RandomUnderSampler(sampling_strategy=undersampler_prop,random_state=random_seed)
        x_train_us, y_train_us = rus.fit_resample(x_train, y_train)

        # LinearSVC has no predict_proba; the calibration wrapper provides it
        clf = calibration.CalibratedClassifierCV(svm.LinearSVC(random_state=random_seed))

        # Training
        clf.fit(x_train_us, y_train_us)
        model_arr.append(clf)
        # Prediction. One value per window (probability of the positive class)
        test_signal = clf.predict_proba(x_test)[:,1]
        train_signal=clf.predict_proba(x_train)[:,1]
        # One value per window is not compatible with the functions that extract beginning and end times,
        # so every window value is repeated for each of its samples
        y_train_predict=_expand_windows(train_signal,timesteps)
        y_test_predict=_expand_windows(test_signal,timesteps)

        for i,th in enumerate(th_arr):
            # Predicted indexes are divided by sf to be comparable with the ground-truth event times.
            # Only the first three values returned by get_performance are kept ([P R F1]);
            # the third positional argument (0) is kept from the original call - presumably a matching threshold, confirm in aux_fcn
            # Test
            ytest_pred_ind=aux_fcn.get_predictions_index(y_test_predict,th)/sf
            perf_test_arr[i_model,i]=aux_fcn.get_performance(ytest_pred_ind,GT_test,0)[0:3]
            # Train
            ytrain_pred_ind=aux_fcn.get_predictions_index(y_train_predict,th)/sf
            perf_train_arr[i_model,i]=aux_fcn.get_performance(ytrain_pred_ind,GT_train,0)[0:3]

        # Saving the model
        model_name=f"SVM_Ch{n_channels}_Ts{timesteps:03d}_Us{undersampler_prop:1.2f}"

        aux_fcn.fcn_save_pickle(os.path.join(models_dir,model_name),clf)
        model_name_arr.append(model_name)
        timesteps_arr_ploting.append(timesteps)

### Plot training results

In [None]:
# Plot training results: one row per trained model.
# squeeze=False keeps `axs` two-dimensional, so axs[i,j] also works when a single model was trained
fig,axs=plt.subplots(n_iters,2,figsize=(10,2*n_iters),sharey='col',sharex='col',squeeze=False)

for i in range(n_iters):
    # Left column: precision-recall curve over the thresholds.
    # The performance arrays store [P R F1] along the last axis, so recall (index 1) goes on x
    # and precision (index 0) on y, matching the axis labels set below
    axs[i,0].plot(perf_train_arr[i,:,1],perf_train_arr[i,:,0],'k.-')
    axs[i,0].plot(perf_test_arr[i,:,1],perf_test_arr[i,:,0],'b.-')
    # Right column: F1 as a function of the threshold
    axs[i,1].plot(th_arr,perf_train_arr[i,:,2],'k.-')
    axs[i,1].plot(th_arr,perf_test_arr[i,:,2],'b.-')
    axs[i,0].set_title(model_name_arr[i])
    axs[i,0].set_ylabel('Precision')
    axs[i,1].set_ylabel('F1')
# x axes are shared per column, so only the bottom row is labelled
axs[-1,0].set_xlabel('Recall')
axs[-1,1].set_xlabel('Threshold')
axs[0,0].legend(['Training','Test'])
plt.show()

### Validation

In [None]:
# For loop iterating over the models
importlib.reload(rippl_AI)
importlib.reload(aux_fcn)
# One row per model: training/test F1 on the left, validation F1 on the right.
# squeeze=False keeps `axs` two-dimensional, so axs[n_m,j] also works when a single model was trained
fig,axs=plt.subplots(n_iters,2,figsize=(10,2*n_iters),sharey='col',sharex='col',squeeze=False)
for n_m,model in enumerate(model_arr):
    F1_arr=np.zeros(shape=(len(x_val_list),len(th_arr))) #(n_val_sess x n_th) Array where the F1 val of each sesion will be stored
    for n_sess,LFP in enumerate(x_val_list):
        # Only the first element returned by predict is used as the prediction signal.
        # NOTE(review): sf=1250 is hard-coded here while the indexes below are divided by the `sf`
        # defined in the training cell - keep both values in sync
        val_pred=rippl_AI.predict(LFP,sf=1250,arch='SVM',new_model=model,n_channels=n_channels,n_timesteps=timesteps_arr_ploting[n_m])[0]
        for i,th in enumerate(th_arr):
            val_pred_ind=aux_fcn.get_predictions_index(val_pred,th)/sf
            # Index 2 of the returned performance is F1
            F1_arr[n_sess,i]=aux_fcn.get_performance(val_pred_ind,GT_val_list[n_sess],verbose=False)[2]

    # Left column: F1 obtained during training (black) and test (blue)
    axs[n_m,0].plot(th_arr,perf_train_arr[n_m,:,2],'k.-')
    axs[n_m,0].plot(th_arr,perf_test_arr[n_m,:,2],'b.-')
    # Right column: F1 of every validation session plus their mean (black)
    for n_sess,F1 in enumerate(F1_arr):
        axs[n_m,1].plot(th_arr,F1,label=f'Validation session {n_sess+1}')
    axs[n_m,1].plot(th_arr,np.mean(F1_arr,axis=0),'k.-',label='Mean')
    axs[n_m,0].set_title(model_name_arr[n_m])
    # Both columns show F1 as a function of the threshold
    axs[n_m,0].set_ylabel('F1')
    axs[n_m,1].set_ylabel('F1')
# x axes are shared per column, so only the bottom row is labelled
axs[-1,0].set_xlabel('Threshold')
axs[-1,1].set_xlabel('Threshold')
axs[0,0].legend(['Training','Test'])
axs[0,1].legend()
plt.show()
    