# Import relevant code and packages

In [None]:
# Native Python packages
import json
import logging
import os
import pickle
import time

# Third party libraries
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import torch

# Custom code
from data import data as data_class
from rnn_speech_detection import PyTorchEventDetector
from torch_model import speech_detector as lstm_model
from utils import FLAGS as FLAGS_class, load_parameters_and_directories, update_parameters_and_flags, \
    universal_parser, update_dependent_parameters, write_model_description_json, logger_config

%load_ext autoreload
%autoreload 2
%matplotlib inline

mpl.rcParams.update({'font.size': 14})

# Define the path of this repo
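Set this path to where you've placed the repo, keeping the trailing `/`, because later cells build file paths by appending to it. This directory should contain the subject folder (`bravo1/`, which holds the prepared `data` and the `results` output), the model configuration, and this notebook (`train_speech_detection_model.ipynb`).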

In [None]:
# Root of the speech_detection repo; edit this to match your machine.
repo_path = '/path/to/silent_spelling/speech_detection/'
# Later cells build paths as `repo_path + '...'`, so guarantee a trailing
# separator even if the path above was entered without one.
repo_path = os.path.join(repo_path, '')

# Define parameters relevant to training

In [None]:
# Start from the parser's defaults (no real command line in a notebook),
# then override the fields this demo depends on.
parser = universal_parser()
pargs = parser.parse_args(args=[])

demo_overrides = {
    # Paradigm of each utterance set listed below, in the same order
    'paradigms': '["mimed","mimed","mimed","attempted_motor"]',
    # Utterance sets that make up the training data
    'utterance_sets': '["alphabet1_2_with_right_v2","alphabet1_2_with_right","alphabet1_2","navigation1"]',
    # Destination for results and trained models
    'out_folder': 'results',
    # Label order
    'feature_labels': '["silence","speech","preparation","motor"]',
    # Key in the block-splits file selecting which training split to use
    'prediction_fold': 'demo',
    # wandb project name (only relevant when tracking with wandb)
    'project_name': 'demo',
    # Blocks per utterance set held out for validation
    'num_val_blocks': 1,
    # Train, using the prespecified fold splits
    'do_training': True,
    'prediction_folds': True,
    # Prefix of the model-configuration file, so several configurations can coexist
    'config_prefix': 'demo_',
    # Subject ID
    'subject': 'bravo1',
    # Filename of the block splits
    'block_split_filename': 'demo_block_splits.pkl',
    # Folder holding the prepared data
    'prepared_data_folder_name': 'data',
    # Run index: increase it to keep more than one run of the same model
    'run': 0,
    # Model implementation version
    'model_version': 'torch',
    # False-positive weight, as defined in the Supplement of the paper
    'false_positive_weight': 1.1,
}
for arg_name, arg_value in demo_overrides.items():
    setattr(pargs, arg_name, arg_value)

# The total label list is the same as the feature label list in this demo.
pargs.total_feature_labels = pargs.feature_labels

# Set other flags relevant to training

In [None]:
# Build the run-mode FLAGS. Most switches are copied verbatim from the parsed
# arguments; the remaining ones are fixed for this single-split training demo.
FLAGS = FLAGS_class()
for flag_name in (
    'do_inference',
    'do_training',
    'get_saliences',
    'prediction_folds',
    'create_predictions',
    'use_presaved_inference_blocks',
    'process_and_save_only',
    'use_hyperopted_params',
    'no_early_stopping_learning_rate_adjustment',
):
    setattr(FLAGS, flag_name, getattr(pargs, flag_name))
FLAGS.kfold_cross_validation = False
FLAGS.train_master_model = False
FLAGS.store_weights = True
FLAGS.save_logits = False
FLAGS.do_hyperopt = False

# logging.basicConfig is a no-op while the root logger already has handlers
# (e.g. from a previous run of this cell), so clear them before configuring.
for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)
logging.basicConfig(**logger_config)
logger = logging.getLogger('speech_detection')

# Load the parameters from the model-configuration JSON selected by
# pargs.config_prefix (set above), update them with the parsed arguments,
# and load the directory information.
params, dirs = load_parameters_and_directories(
    model_type='supervised_lstm_model',
    config_path=repo_path,
    pargs=pargs,
    config_prefix=pargs.config_prefix,
)

# Reconcile params and FLAGS with each other and with the parsed arguments.
params, FLAGS = update_parameters_and_flags(params, FLAGS, pargs=pargs)

In [None]:
# Point the directory object at this subject's folders. The names come from the
# parsed arguments (subject, prepared_data_folder_name, out_folder) instead of
# being repeated here as hardcoded 'bravo1/...' strings.
subject_dir = os.path.join(repo_path, pargs.subject)
dirs.prepared_data_dir = os.path.join(subject_dir, pargs.prepared_data_folder_name)  # where the prepared data is
dirs.general_output_dir = repo_path  # where to get general files
dirs.output = os.path.join(subject_dir, pargs.out_folder)  # where to save models

os.makedirs(dirs.output, exist_ok=True)

# Load prepared data and split into training and validation sets

In [None]:
# Block breakdowns for every subject, plus this subject's entry.
with open(repo_path + 'block_breakdowns.json', 'r') as breakdown_file:
    full_block_breakdowns = json.load(breakdown_file)
block_breakdown = full_block_breakdowns[params.subject_id]
# NOTE(review): `block_breakdown` is not used in this cell; the full mapping is
# what gets passed to load_prepared_data below — confirm which one it expects.

# A single train/validation split and no real-time class for this demo.
params.num_folds = 1
params.rt_class = None

# Per-fold output folder names: window{int_window:03d}_fold{k}_model_run{run}.
prefixes = [f'window{params.int_window:03d}_fold{fold_idx}' for fold_idx in range(params.num_folds)]
cur_model_run = pargs.run
dirs.fold_output_models = [
    os.path.join(dirs.output, f'{fold_prefix}_model_run{cur_model_run}')
    for fold_prefix in prefixes
]

# Build the data object; presumably _get_blocks fills in the training and
# validation block lists used further down — TODO confirm.
data = data_class(repo_path + 'log.log')
data._get_blocks(FLAGS=FLAGS, dirs=dirs, params=params, args=pargs)

# Only the first (and only) fold is trained here.
ifold = 0
fold_start = time.time()
prefix = prefixes[ifold]
dirs.fold_output_model = dirs.fold_output_models[ifold]

# Only the validation data is loaded now; training data is loaded on the fly
# for each epoch.
data.load_prepared_data(
    FLAGS=FLAGS,
    params=params,
    fold=ifold,
    block_breakdown=full_block_breakdowns,
    data_dir=dirs.prepared_data_dir,
)

try:
    os.mkdir(dirs.fold_output_model)
except FileExistsError:
    logger.info('Model folder already exists.')

# Record the full model configuration next to the model it describes.
write_model_description_json(
    params,
    data.blocks_for_training,
    data.blocks_for_validation,
    dirs.fold_output_model,
)

# Train a speech detection model with the demo data

If you wish to use a GPU for training, set `device = 'cuda'`.

In [None]:
# Training device: 'cpu', or 'cuda' to train on a GPU.
device = 'cpu'
# If CUDA was requested but is not available here, warn and fall back to the
# CPU rather than failing later inside model construction or training.
if device.startswith('cuda') and not torch.cuda.is_available():
    logging.getLogger('speech_detection').warning(
        "CUDA requested but not available; training on 'cpu' instead.")
    device = 'cpu'

The elapsed time for each training loop will print and be saved to the logger.

Estimated training times for the example dataset:
* Using a Tesla V100 GPU and Linux server using Ubuntu 20.04, each epoch (including data loading and minibatch updates) takes ~ 1 minute.
* Using only CPUs on a Linux server using Ubuntu 20.04, each epoch (including data loading and minibatch updates) takes ~ 3.5 minutes.
* Using only CPUs on a MacBook Pro (macOS 10.15.7), each epoch (including data loading and minibatch updates) takes ~ 2.5 minutes.

In [None]:
# Initialize the speech detection model. Set use_wandb=True (requires a wandb
# account) to track training/validation losses and accuracies under the project
# name specified at the top of the notebook.
# Use the same log path the data loader was given (repo_path + 'log.log')
# rather than a 'log.log' relative to the kernel's working directory.
speech_detection_model = lstm_model(params, FLAGS, repo_path + 'log.log',
                                    use_wandb=False, device=device)

# Log where outputs go and which data this fold uses.
logger.info(
    'Model and outputs saved to : {}'.format(dirs.output))
# NOTE(review): fold_start is set in the data-loading cell, so this figure also
# includes any time between running the two cells plus the model construction above.
logger.info('Data fold preparation time: {0:.3f} s'.format(
    time.time() - fold_start))
logger.info('Fold {} training blocks: {}'.format(
    ifold, data.blocks_for_training))
logger.info('Fold {} validation blocks: {}'.format(
    ifold, data.blocks_for_validation))
logger.info('neural val data shape')
logger.info([b.shape for b in data.neural_val])
logger.info('events val data shape')
logger.info([b.shape for b in data.events_val])
logger.info('*************************************')

train_start = time.time()

# Train the model. Training blocks are streamed from data_dir; the validation
# arrays loaded in the previous cell are passed in directly.
speech_detection_model.train(
    data.trial_map,
    data.neural_val,
    data.events_val,
    output_dir=dirs.output,
    output_model_dir=dirs.fold_output_model,
    model_tracking=True,
    shuffle_batch=True,
    data_dir=dirs.prepared_data_dir,
    training_blocks=data.blocks_for_training,
    validation_blocks=data.blocks_for_validation
)

logger.info('Done training!')
logger.info('Total training time: {0:.3f} s'.format(
    time.time() - train_start))

# Load a pre-trained speech detection model

In [None]:
# Two candidate models: the pretrained example (trained on more data) and the
# model trained above in this notebook. Pick one by assigning model_path.
pretrained_model_path = repo_path + 'pretrained_example_model'
trained_model_path = repo_path + 'bravo1/results/window100_fold0_model_run0'
model_path = pretrained_model_path  # switch to trained_model_path to use your own model

# Restore the detector in testing mode and build it for the CPU.
model = PyTorchEventDetector(restore_path=model_path, testing=True)
model.build(device='cpu');

In [None]:
# Load an example block of data
with open(repo_path + 'bravo1/data/2726-conv.pkl', 'rb') as f:
    data = pickle.load(f)
    ecog = data['ecog']
    events = data['events']
    sr = data['sr']

## Use the model to predict speech probabilities from a sentence-spelling block

In each trial of the sentence-spelling block, the participant uses silent-speech attempts to spell out the intended message.
At the end of the trial, the participant attempts to make a hand movement to finalize the sentence before proceeding to the next trial.
Therefore, the detected probabilities should show `speech` activity throughout the majority of each trial and `motor` activity at the end of each trial.

In [None]:
# Convert the ECoG block to a float tensor with a leading batch dimension of 1,
# run every model in model.models on it, and average their outputs.
# NOTE(review): assumes each model returns an array indexed (time, event), as
# the plotting cell expects (probs.shape[0] is time) — TODO confirm.
block_tensor = torch.tensor(ecog, dtype=torch.float32).unsqueeze(0).to('cpu')
with torch.no_grad():  # inference only; no autograd bookkeeping needed
    member_probs = [member(block_tensor) for member in model.models]
stack_probs = torch.stack(member_probs, dim=0).cpu().numpy()
probs = np.mean(stack_probs, axis=0)

In [None]:
def _draw_event_lines(ax, times, color, label):
    """Draw a dashed vertical line at each time in `times`.

    Only the first line carries `label`; the others use matplotlib's
    '_nolegend_' label, so the legend shows one entry per marker type
    instead of one per trial.
    """
    for i, t in enumerate(times):
        ax.axvline(x=t, linestyle='--', alpha=0.6, color=color,
                   label=label if i == 0 else '_nolegend_')


fig, ax = plt.subplots(figsize=(20, 6))

# Trial boundaries from the event table: rows with phase_num 1 / state_num 2
# mark trial starts, rows with phase_num 3 / state_num 2 mark trial ends.
target_pres = events.loc[(events['phase_num'] == 1) & (events['state_num'] == 2), 'elapsed_time'].values
trial_end = events.loc[(events['phase_num'] == 3) & (events['state_num'] == 2), 'elapsed_time'].values
_draw_event_lines(ax, target_pres, color='g', label='trial start')
_draw_event_lines(ax, trial_end, color='r', label='trial end')

# Stack each event's probability trace at its own integer vertical offset.
x = np.arange(probs.shape[0]) / sr
num_events = len(model.event_mapping)
for cur_event in range(num_events):
    ax.plot(x, cur_event + probs[:, cur_event])

# One y tick per event, centred on its band (derived from the number of events
# rather than a hardcoded 4).
ax.set(xlabel='Time (s)', yticks=0.5 + np.arange(num_events), yticklabels=model.event_mapping,
       title='Predicted event probabilities')
ax.legend(loc='upper left', bbox_to_anchor=(1, 1.01));