In [None]:
import os
import rippl_AI
import aux_fcn
import matplotlib.pyplot as plt
import numpy as np

### Data download
Four uLED sessions will be downloaded: Amigo2 and Som2 will be used for training, and Dlx1 and Thy7 for validation.


In [None]:
from figshare.figshare import Figshare
fshare = Figshare()

# uLED sessions to download and their figshare article ids (paired by position).
article_ids = [16847521, 16856137, 14959449, 14960085]
sess = ['Amigo2', 'Som2', 'Dlx1', 'Thy7']
# The loop variables are named article_id/session (not `id`/`s`) so the builtin
# `id` is not shadowed in the notebook namespace once this cell has run.
for article_id, session in zip(article_ids, sess):
    datapath = os.path.join('Downloaded_data', session)
    # A session whose folder already exists is skipped, so re-running the cell is cheap.
    if os.path.isdir(datapath):
        print(f"{session} session already exists. Moving on.")
    else:
        print("Downloading data... Please wait, this might take up some time")  # Can take up to 10 minutes
        fshare.retrieve_files_from_article(article_id, directory=datapath)
        print("Data downloaded!")

### Data load
The LFPs of the training sessions will be appended to one list, and their ripple detection times to another; the validation sessions are collected in the same way.
This is the input format required by `rippl_AI.prepare_training_data`.

In [None]:
# Training and validation sessions are each collected into a list of LFPs and a
# parallel list of ground-truth ripple times. To retrain on your own data, edit
# the (session name, figshare article id) pairs below or build the lists yourself.
TRAIN_SESSIONS = [('Amigo2', 16847521), ('Som2', 16856137)]
VAL_SESSIONS = [('Dlx1', 14959449), ('Thy7', 14960085)]


def load_sessions(sessions, data_dir='Downloaded_data'):
    """Load every session in ``sessions`` with ``aux_fcn.load_lab_data``.

    Parameters
    ----------
    sessions : iterable of (str, int)
        (session name, figshare article id) pairs. Each session is read from
        ``<data_dir>/<session name>/figshare_<article id>``.
    data_dir : str
        Root folder that holds the downloaded sessions.

    Returns
    -------
    (list, list)
        The LFPs and the ground-truth ripple times returned by
        ``aux_fcn.load_lab_data``, in the same order as ``sessions``.
    """
    lfps, gts = [], []
    for name, article_id in sessions:
        session_path = os.path.join(data_dir, name, f'figshare_{article_id}')
        lfp, gt = aux_fcn.load_lab_data(session_path)
        lfps.append(lfp)
        gts.append(gt)
    return lfps, gts


train_LFPs, train_GTs = load_sessions(TRAIN_SESSIONS)
val_LFPs, val_GTs = load_sessions(VAL_SESSIONS)



The training sessions are concatenated, while the validation sessions are kept as separate sessions.

In [None]:
# Concatenate the training sessions and keep the validation sessions separate.
# NOTE(review): sf=30000 is presumably the sampling frequency (Hz) of the loaded
# LFPs — confirm it matches your recordings before reusing this cell.
retrain_LFP_norm, retrain_GT, val_LFP_norm, val_GT = rippl_AI.prepare_training_data(
    train_LFPs, train_GTs, val_LFPs, val_GTs, sf=30000
)

# Retraining examples for the different models

### XGBOOST
XGBOOST does not require any additional parameters.

In [None]:
# XGBOOST is retrained without a `parameters` argument.
rippl_AI.retrain_model(
    retrain_LFP_norm,
    retrain_GT,
    val_LFP_norm,
    val_GT,
    arch='XGBOOST',
    save_path=os.path.join('retrained_models', 'XGBOOST_retrained1'),
)

### SVM
SVM has only one parameter:
- 'Undersampler proportion': controls the number of windows without ripples (negatives) that are used to train the model, following the formula: Undersampler proportion = (positive windows) / (negative windows). A value of 1 means the same number of positive and negative windows. Low values can lead to overfitting.

In [None]:
# NOTE(review): the key below is spelled 'Unsersampler proportion', while the
# markdown description of this parameter says 'Undersampler proportion'.
# Confirm which spelling rippl_AI.retrain_model actually reads.
params = {'Unsersampler proportion': 0.1}

rippl_AI.retrain_model(
    retrain_LFP_norm,
    retrain_GT,
    val_LFP_norm,
    val_GT,
    arch='SVM',
    parameters=params,
    save_path=os.path.join('retrained_models', 'SVM_retrained1'),
)

### LSTM 
LSTM has two training parameters:
- 'Epochs': the number of times the training data is fed to the model.
- 'Training batch': the number of windows processed before the weights are updated during training. Higher values prevent large weight oscillations.

In [None]:
# 2 passes over the training data; weights are updated after every 32 windows.
params = {
    'Epochs': 2,
    'Training batch': 32,
}
rippl_AI.retrain_model(
    retrain_LFP_norm,
    retrain_GT,
    val_LFP_norm,
    val_GT,
    arch='LSTM',
    parameters=params,
    save_path=os.path.join('retrained_models', 'LSTM_retrained1'),
)

### CNN2D
CNN2D shares its training parameters with the LSTM architecture.

In [None]:
# 1 pass over the training data; weights are updated after every 64 windows.
params = {
    'Epochs': 1,
    'Training batch': 64,
}
rippl_AI.retrain_model(
    retrain_LFP_norm,
    retrain_GT,
    val_LFP_norm,
    val_GT,
    arch='CNN2D',
    parameters=params,
    save_path=os.path.join('retrained_models', 'CNN2D_retrained1'),
)

### CNN1D
CNN1D shares its training parameters with LSTM and CNN2D.

In [None]:
# 2 passes over the training data; weights are updated after every 32 windows.
params = {
    'Epochs': 2,
    'Training batch': 32,
}
rippl_AI.retrain_model(
    retrain_LFP_norm,
    retrain_GT,
    val_LFP_norm,
    val_GT,
    arch='CNN1D',
    parameters=params,
    save_path=os.path.join('retrained_models', 'CNN1D_retrained1'),
)