# Example notebook for classification and spelling

This notebook will enable the user to:
1. Train a classifier on neural data from BRAVO1 during mimed speech attempts, make predictions on held-out test data, and calculate feature salience.
2. Use the classifier to make realtime predictions on the sentence spelling data.
3. Train an N-gram language model for use with our beam search algorithm.
4. Use the language model + beam search to improve the predictions.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# Load all the dependencies
import os
import torch
import argparse
import numpy as np
from data_wrangling import load_data, split_data
from machine_learning import run_standard_model, prediction_only
from bravo_ml.model_eval.torch_salience import get_saliences
from saving_functions import save_saliences, make_dataframe
import matplotlib.pyplot as plt
from os.path import join

#### Arguments. This defines the arguments that you can modify, should you want to. 
#### The necessary arguments to train the basic model are all provided below

In [None]:
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default=None, type=str, 
                    help='How the data is saved. Should be in the format '
                         '%%s_restofname.npy for X, Y, and blx (neural, labels, blocks)')

parser.add_argument('--date_dict', default=None, type=str,
                   help='A dictionary that maps the block numbers to the date they were recorded')

parser.add_argument('--end_date', default=None, type=str, 
                   help='last day of data to use, YYYY_MM_DD format')

parser.add_argument('--device', default='cpu', type=str, 
                   help='cpu for cpu, cuda to use GPU')

parser.add_argument('--model_saving_fp', default=None,
                    help='where to store the models + the base name. '
                         'This should be in the format filepath/<modelname>_%%d.pth '
                         'so that we can store a model for each CV if you want')

parser.add_argument('--feats_to_use', default='hga_and_lfs', type=str,
                    help='which features to use. Defaults to both (hga_and_lfs); '
                         'the other options are hga or lfs')

parser.add_argument('--data_dir', default='./data', help='where data is stored')



parser.add_argument('--balance_classes', action='store_true', help='equal number of classes per category')


#### The rest of the parameters shouldn't be changed here, but can be useful for running experiments.
parser.add_argument('--start_date', default=None, type=str, 
                   help='first day of data to use, YYYY_MM_DD format')
parser.add_argument('--data_split_scheme', default='random', type=str,
                   help='how to split up the data')
parser.add_argument('--load_pretrained_model', default=None, type=str, help='path of pretrained models to use for training. '
            'Useful for finetuning models')
parser.add_argument('--num_cvs', default=10, type=int, help='how many CVs to run')
parser.add_argument('--num_samples', default=None, type=int, help='max no. of samples to use for training');

Because the dataset is quite large, it may be useful to load only a portion of it. This speeds up training and makes it possible to train on machines with less than 32 GB of RAM. You can adjust the fraction of the training data to use with this argument:

In [None]:
parser.add_argument('--training_data_pct', default=0.5, type=float,  
                   help='What fraction of the data to use (between 0 and 1). This enables training on machines with less RAM');

If you reduce the size of the dataset, the predictions on the realtime data will be worse. 
To get better predictions, you can load a model that was pretrained on all the data up to the day before our realtime data was recorded, using the argument below:

In [None]:
parser.add_argument('--load_pretr_for_spell', action='store_true', help='load the pretrained model on full ds for spelling');

### Arguments for training the model on a small portion of our dataset
This should run on most computers with 16GB of RAM

In [None]:
# Load the settings we need 
exp_str = '--end_date 2021_07_20' # Use data collected prior to July 20th (ensures the data is only isolated-trial data)
# Change these paths if your data is stored elsewhere. If you followed the README instructions,
# the data should already be in these folders.
exp_str += ' --dataset %s_alpha_new_mimed.npy'
exp_str += ' --date_dict date_dictionary_mimed.pkl'
exp_str += ' --model_saving_fp ./model_checkpoints/demo_model_partial_data_%d.pth'
exp_str += ' --device cpu' # Change to cuda to use a GPU for training, if you installed a CUDA-enabled build of PyTorch.
exp_str += ' --balance_classes'
exp_str += ' --data_dir ./data/'
exp_str += ' --load_pretr_for_spell'

### Uncomment the cell below to change the arguments to those that replicate our pretrained model
This requires at least 32 GB of RAM since we're using a larger dataset.

In [None]:
# exp_str = '--end_date 2021_08_10' # Use data collected prior to August 10th (the session prior to realtime decoding we will use)
# exp_str += ' --dataset %s_alpha_new_mimed.npy'  #
# exp_str += ' --date_dict date_dictionary_mimed.pkl'
# exp_str += ' --model_saving_fp ./model_checkpoints/demo_model_%d.pth'
# exp_str += ' --data_dir ./data/'
# exp_str += ' --training_data_pct 1.0'
# exp_str += ' --device cpu'

#### Parse arguments

In [None]:
# Parse the arguments
args = vars(parser.parse_args(exp_str.split()))
# see the arguments
args
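
The pattern used above (building a command-line string, splitting it, and converting the parsed namespace to a dictionary) can be tried in isolation. The following is a minimal sketch with a hypothetical three-argument parser; `demo_parser`, `demo_str`, and `demo_args` are illustrative names and not part of the notebook's own code.

```python
import argparse

# Hypothetical parser with three arguments, mirroring the pattern used above.
demo_parser = argparse.ArgumentParser()
demo_parser.add_argument('--device', default='cpu', type=str)
demo_parser.add_argument('--training_data_pct', default=0.5, type=float)
demo_parser.add_argument('--balance_classes', action='store_true')

# Build the argument string piece by piece, as the notebook does with exp_str.
demo_str = '--device cuda'
demo_str += ' --balance_classes'

# split() turns the string into the token list argparse expects, and
# vars() converts the resulting Namespace into a plain dictionary.
demo_args = vars(demo_parser.parse_args(demo_str.split()))
print(demo_args)
```

Arguments that are not mentioned in the string keep their defaults, and `store_true` flags are `False` unless they appear in the string.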

# Step 1: Load the neural data, labels, and block numbers

The data should be in 3 files: 

X_alpha_new_mimed: The neural samples for each trial, in a tensor of shape N_trials x Timesteps x Channels. Channels 0-127 are the LFS activity from each electrode, and channels 128-255 are the HGA activity from each electrode. Each sample spans the neural activity from 2 s before the go cue to 4 s after it. The data has already been downsampled to 33.33 Hz, and the keras.utils.normalize function has been applied. The model training code selects the appropriate time window through a data augmentation.

blx_alpha_new_mimed: The associated block number for each trial. These are used to map each sample to the day it was recorded. This requires the date_dict argument, which maps these block numbers to the dates on which they were recorded.

Y_alpha_new_mimed: The associated label for each trial. Labels 0-25 correspond to A-Z. Label 26 is the motor command.

The data is saved with our decimation factor of 6 (200 Hz to 33.33 Hz) and normalization already applied.
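
To make the layout concrete, here is a small sketch that builds a synthetic array with the shape described above and pulls out the two feature streams. The array contents are random stand-ins, and `X_demo`, `Y_demo`, and `label_names` are hypothetical names used only for illustration. The number of timesteps is derived from the stated 6 s window and 33.33 Hz rate, which is an assumption about the saved files.

```python
import numpy as np

n_trials, n_channels = 4, 256         # small synthetic stand-in for the real data
sr = 200 / 6                          # 33.33 Hz after decimating 200 Hz by 6
n_timesteps = int(round(6 * sr))      # 2 s before the go cue + 4 s after it

rng = np.random.default_rng(0)
X_demo = rng.normal(size=(n_trials, n_timesteps, n_channels))
Y_demo = np.array([0, 25, 26, 7])

# Channels 0-127 hold LFS features, channels 128-255 hold HGA features.
lfs = X_demo[:, :, :128]
hga = X_demo[:, :, 128:]

# Time axis in seconds relative to the go cue.
t = np.arange(n_timesteps) / sr - 2

# Labels 0-25 map to letters; label 26 is the motor command.
label_names = list('abcdefghijklmnopqrstuvwxyz') + ['motor command']
print(lfs.shape, hga.shape, t[0])
print([label_names[y] for y in Y_demo])
```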

In [None]:
args['test_end_date'] = args['end_date']

# Step 1: Load in the dataset.
X, Y, blocks, dates = load_data(join(args['data_dir'], args['dataset']), args['start_date'], args['end_date'], 
                               join(args['data_dir'], args['date_dict']))

### Subsampling, balancing, and plotting a trial

In [None]:
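# Keep only the last `training_data_pct` fraction of the trials.
# The negative index counts back from the end of the arrays.
go_back_amt = int(-1 * X.shape[0] * args['training_data_pct'])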

In [None]:
X, Y, blocks, dates = X[go_back_amt:], Y[go_back_amt:], blocks[go_back_amt:], dates[go_back_amt:]
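
The negative-index slice above keeps the trailing fraction of the arrays. Here is a minimal sketch on a toy array, assuming the trials are ordered from oldest to newest (the notebook does not state the ordering explicitly). `X_demo`, `pct`, `go_back`, and `tiny` are hypothetical names.

```python
import numpy as np

X_demo = np.arange(10)   # ten stand-in "trials"; assumed ordered oldest to newest
pct = 0.5

go_back = int(-1 * X_demo.shape[0] * pct)   # -5
recent = X_demo[go_back:]                    # the last five trials
print(recent)

# Edge case: a very small fraction truncates to 0, and X[0:] keeps everything.
tiny = int(-1 * X_demo.shape[0] * 0.05)      # int(-0.5) == 0
print(X_demo[tiny:].shape[0])
```

Because `int()` truncates toward zero, fractions small enough to give an index of 0 silently keep the whole dataset rather than an empty one.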

In [None]:
from data_wrangling import balance_classes

In [None]:
if args['balance_classes']:
    X, Y, blocks, dates = balance_classes(X, Y, blocks, dates)
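
One common way to get an equal number of trials per class is to downsample every class to the size of the rarest one. The sketch below illustrates that idea only. It is not the `balance_classes` function from `data_wrangling`, which may use a different strategy, and `balance_by_downsampling` is a hypothetical helper.

```python
import numpy as np

def balance_by_downsampling(X, Y, seed=0):
    """Hypothetical helper: keep the same number of trials for every label."""
    rng = np.random.default_rng(seed)
    classes, counts = np.unique(Y, return_counts=True)
    n_keep = counts.min()  # size of the rarest class
    keep = np.concatenate([
        rng.choice(np.flatnonzero(Y == c), size=n_keep, replace=False)
        for c in classes
    ])
    keep.sort()  # preserve the original trial order
    return X[keep], Y[keep]

Y_demo = np.array([0, 0, 0, 1, 1, 2, 2, 2, 2])
X_demo = np.arange(len(Y_demo))[:, None]
Xb, Yb = balance_by_downsampling(X_demo, Y_demo)
print(np.unique(Yb, return_counts=True))
```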

In [None]:
print(X.shape)

In [None]:
# Time axis in seconds: decimation factor 6 at 200 Hz, starting 2 s before the go cue
t = np.arange(X.shape[1])*6/200 - 2
trial = 100
elec = 0
plt.plot(t, X[trial, :, elec], label='LFS data, elec 0')
plt.plot(t, X[trial, :, elec+128], label='HGA data, elec 0')
plt.axvline(0, label='cue onset', c='r')
plt.xlabel('time (s)')
plt.ylabel('neural activity (a.u.)')
plt.legend()
plt.title("Neural activity from a single trial, with decimation and normalization applied")
plt.show()

In [None]:
# You can try testing the model with just the HGA or just the LFS features to see how the accuracy changes (see fig 3). 
# Note that the model will then not work for the sentence predictions later in the notebook.
if args['feats_to_use'] == 'lfs':
    X = X[:, :, :128]
elif args['feats_to_use'] == 'hga': 
    X = X[:, :, 128:]

# Step 2: Train the neural network model [just on the mimed dataset]

There are two differences between this network and the one used in our evaluation: 

First, the one we used was pretrained on overt data, then fine-tuned on mimed data. See fig 4 for more details, but this only slightly improved performance and doubled the training time. 

Second, we used model ensembling to average the predictions of 10 models. To save training time for this example, we train only on mimed data and evaluate using one model. 
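
Ensembling by averaging can be sketched in a few lines. This example assumes that the ensemble averages the softmax probability vectors of the individual models. The notebook only states that the predictions of 10 models were averaged, so the exact quantity averaged is an assumption. The logits are random stand-ins.

```python
import numpy as np

def softmax(z):
    """Numerically stable softmax over the last axis."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
n_models, n_trials, n_classes = 10, 5, 27   # 26 letters + motor command

# Stand-in for the raw outputs of 10 independently trained models.
logits = rng.normal(size=(n_models, n_trials, n_classes))

probs = softmax(logits)                 # per-model probability vectors
ensemble_probs = probs.mean(axis=0)     # average across the model axis
ensemble_pred = ensemble_probs.argmax(axis=-1)
print(ensemble_pred)
```

The averaged vectors still sum to one for every trial, so they can be used wherever a single model's probabilities would be.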

In [None]:
# Initialize lists to save all the meaningful metrics we want
pred_list = []
label_list = []
cvs_list = []
blocks_list = []
prev_sal_start = 0 
# all_saliences = np.zeros_like(X)

#### This will train up the model. 
It takes about 13-14 epochs for our model to reach around 55% accuracy on the validation set. Keep in mind that we're not using all of the 9k samples we had, so the accuracy here will be a little worse. Each epoch should take around a minute. The final accuracy on the held-out test set should be around 50-60%. 

In [None]:
# Go through the cvs necessary. 
for cv in range(args['num_cvs']):
                
    # Use the 2nd CV fold if we're not using all the data. Gives slightly better performance :D 
    if args['training_data_pct'] != 1.0: 
        cv = 2
    
    # Get training and test blocks
    X_tr, Y_tr, blocks_tr, X_te, Y_te, blocks_te = split_data((X,Y, blocks), None,
                                                              args['data_split_scheme'], 
                                                              args['num_cvs'], cv, 
                                                              args['num_samples'])
    # Split the training set again to hold out a validation set.
    X_tr, Y_tr, blocks_tr, X_v, Y_v, blocks_v = split_data((X_tr, Y_tr, blocks_tr),
                                                          None,  
                                                           args['data_split_scheme'], 
                                                            args['num_cvs'], cv,
                                                          None)
    
    print('Train samples', X_tr.shape[0], 'Validation samples', X_v.shape[0], 'Test samples', X_te.shape[0])
        
    # Use the pretrained model if one was specified, with a lower learning rate for fine-tuning
    pretr_model = args['load_pretrained_model']
    lr = 1e-3 if pretr_model is None else 1e-4

    # Run the training loop. Get back the labels and predictions, as well as the trained model 
    labels, preds, te_loader, model = run_standard_model(X_tr, Y_tr, blocks_tr, 
                                                         X_te, Y_te, blocks_te,
                                                         X_v, Y_v, blocks_v, pretr_model, 
                                                         lr, cv,
                                                         args)
    
    
#     # If you run this for more CVs, this code lets you store all the 
#     # predictions and labels across each CV. 
#     pred_list.extend(preds)
#     label_list.extend(labels)
#     cvs_list.extend([cv]*len(labels))
#     blocks_list.extend(blocks_te)

    
    break # we only care about getting one model for now, but to run experiments etc. you can run 10. 

In [None]:
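# Convert both sides to arrays so the comparison is elementwise even if they are plain lists
test_acc = 100 * np.mean(np.array([np.argmax(p) for p in preds]) == np.asarray(labels))
print(f'Model training complete. The top-1 accuracy was {test_acc:.3f}% on the held out test data')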

# Step 3: Load data and metadata from a sentence spelling block. 

Here we will extract the chunks of data that occur every 2.5 seconds relative to the first go cue, and apply the model to each chunk. Then we will save the probability vector from each prediction for use during the beam search. 

We will also save the most likely character at each timepoint to display the greedy prediction. Because the spaces are added in during the beam search, the greedy decode will not contain spaces.

We will also plot a matrix that shows the prediction vector for every timepoint. You can see that the classifier often still assigns some probability to the correct character when its top prediction is wrong, which means a language model can improve the decoding.
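
Two ideas from this step can be sketched on toy values: how the 2.5 s windows map to sample indices, and how to check where the correct character ranks in a probability vector when it is not the top guess. The start time of 4.0 s and the 4-class probability vector are made-up values, and `char_window` and `rank_of_truth` are hypothetical helpers.

```python
import numpy as np

cue_time_diff = 2.5   # seconds between go cues
sr = 200              # sampling rate of the realtime stream in Hz

def char_window(k, start):
    """Sample indices [begin, end) of the k-th character window."""
    begin_t = start + k * cue_time_diff
    return int(begin_t * sr), int((begin_t + cue_time_diff) * sr)

def rank_of_truth(prob_vec, true_idx):
    """Rank of the true class: 1 = top guess, 2 = second guess, and so on."""
    order = np.argsort(prob_vec)[::-1]  # class indices, most likely first
    return int(np.flatnonzero(order == true_idx)[0]) + 1

# Windows for the first three characters of a sentence starting at 4.0 s.
windows = [char_window(k, start=4.0) for k in range(3)]
print(windows)

# Toy 4-class prediction: class 1 wins, but class 2 is a close second.
probs = np.array([0.05, 0.50, 0.40, 0.05])
print(rank_of_truth(probs, true_idx=2))
```

A true character that ranks second or third still carries useful probability mass, which is what the beam search exploits later.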

In [None]:
# imports relevant to this section
from scipy.signal import decimate
from machine_learning import normalize
import torch.nn.functional as F
import pickle
import fastwer

In [None]:

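if args['load_pretr_for_spell']:
    # Replace the model trained above with the one pretrained on the full dataset
    from machine_learning import get_model_shell
    model = get_model_shell(X_te, Y_te, blocks_te, 0, args)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine
    model.load_state_dict(torch.load('./model_checkpoints/demo_model_0.pth', map_location=args['device']))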



In [None]:
block = 2726 # The realtime block we're testing on. It was collected on 8/12/21, so none of our training data was from this day.

In [None]:
# open the provided data
with open(args['data_dir'] + '/realtime_spelling/block_%d_timing.pkl' %block,'rb') as f: 
    sentence_to_start = pickle.load(f)

# A single timecourse with all of the neural data in it    
neural = np.load(args['data_dir'] + '/realtime_spelling/block_%d_neural.npy' %block)

In [None]:
# Define the character set. 
characters = 'abcdefghijklmnopqrstuvwxyz$'
# put the model in eval mode
model = model.to(args['device']).eval()

# Step 4: Perform "greedy" spelling without full language modeling

In [None]:
def simulate_realtime_prediction(neural_slice, model): 
    """
    Input: Array of neural data at 200 Hz 
    Outputs: The model's prediction given that neural data

    """
    neural_slice = np.expand_dims(neural_slice, axis=0)
    # We need to decimate the data since it's at 200 Hz, and the model operates at 33.33 Hz
    neural_slice = decimate(neural_slice, 6, axis=1)

    # Normalize the data, normalizing the HGA and LFS streams separately. 
    # Note that the order changes, since there was a small discrepancy between the order of the
    # feature streams in the realtime data and the order used for model training. Swapping
    # the two halves fixes it. 
    new_data= np.zeros_like(neural_slice)
    new_data[:, :, 128:] = normalize(neural_slice[:, :, :128])
    new_data[:, :, :128] = normalize(neural_slice[:, :, 128:])

    # Predict using the model we just trained :D 
    with torch.no_grad():
        pred_vec = model(torch.from_numpy(new_data).float().to(args['device']))
        pred_vec = F.softmax(pred_vec, dim=-1)
    pred_vec = pred_vec.cpu().squeeze().numpy()

    return pred_vec
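
The swap-and-normalize step inside `simulate_realtime_prediction` can be illustrated on a tiny array. This sketch assumes that `normalize` performs an L2 normalization along the last axis, which is the default behavior of the `keras.utils.normalize` function mentioned earlier. The `normalize` imported from `machine_learning` may differ, and `l2_normalize` here is a hypothetical stand-in. Four channels per stream are used instead of 128 to keep the example small.

```python
import numpy as np

def l2_normalize(x, axis=-1, eps=1e-12):
    """Hypothetical stand-in: scale vectors along `axis` to unit L2 norm."""
    norm = np.linalg.norm(x, axis=axis, keepdims=True)
    return x / np.maximum(norm, eps)

rng = np.random.default_rng(0)
n_ch = 4                                        # channels per feature stream (128 in the notebook)
chunk = rng.normal(size=(1, 10, 2 * n_ch))      # (batch, time, channels)

# Normalize each stream on its own, then write it into the other half
# so that the stream order matches the order used during training.
swapped = np.zeros_like(chunk)
swapped[:, :, n_ch:] = l2_normalize(chunk[:, :, :n_ch])
swapped[:, :, :n_ch] = l2_normalize(chunk[:, :, n_ch:])

print(np.linalg.norm(swapped[:, :, :n_ch], axis=-1).round(3))
```

Normalizing the streams separately keeps one stream's amplitude from dominating the other's scale.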

In [None]:
pred_arrs = []
cers = []
cue_time_diff = 2.5 # Time between cues in seconds
sr = 200 # The original sampling rate of the data. 

for sent, start in sentence_to_start.items(): 
    sent += '$' # representing the motor command
    greedy = ''
    preds = []
    for k, char in enumerate(sent.replace(' ', '')): 
        
        
        # Get the neural slice out.
        char_start_time = k*cue_time_diff + start
        char_end_time = char_start_time + cue_time_diff
        char_start_ind = int(char_start_time*sr)
        char_stop_ind = int(char_end_time*sr)
        neural_slice = neural[char_start_ind:char_stop_ind]

        # Make the prediction the model would make
        pred_vec = simulate_realtime_prediction(neural_slice, model)

        # If the motor command has a probability greater than 0.8, stop the process.
        # Otherwise, store the prediction over the NATO codeword characters for the beam search.
        if pred_vec[-1] > .8:
            print('motor command decoded, system stopped')
            break
        else:
            preds.append(pred_vec[:-1])
            
            # Store the most likely character
            greedy += characters[np.argmax(pred_vec)]
            
    # Print out the ground truth
    sent = sent[:-1] # Remove the Motor command now
    print('Ground truth:', sent)
    
    # print the greedy prediction
    print('Greedy prediction:', greedy)
    
    # Compare the greedy prediction (hypothesis) with the ground-truth sentence without spaces (reference). 
    ref = sent.replace(' ', '')
    wer, cer = fastwer.score([greedy], [ref]), fastwer.score([greedy], [ref], char_level=True)
    print('Greedy CER: %.3f' % cer)
   
    
    # Append predictions to the list of predicted arrays. 
    pred_arrs.append(np.array(preds))
    plt.imshow((np.array(preds).T), cmap='Blues')
    plt.yticks(np.arange(26), characters[:-1])
    plt.xticks(np.arange(len(preds)), list(sent.replace(' ', '')))
    plt.xlabel('Intended Character')
    plt.ylabel('Decoder Probability of Each Char. Codeword')
    plt.show()
    print('---')
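
Stripped of the data handling, the greedy decoding rule in the loop above is: stop once the motor-command probability exceeds the threshold, otherwise emit the most likely character. A minimal sketch on hand-made probability vectors follows. `greedy_decode` and `peaked` are hypothetical helpers, and the vectors are synthetic rather than model outputs.

```python
import numpy as np

characters = 'abcdefghijklmnopqrstuvwxyz$'   # '$' marks the motor command

def greedy_decode(prob_vectors, stop_threshold=0.8):
    """Emit the most likely character per window until the motor command fires."""
    decoded = ''
    for p in prob_vectors:
        if p[-1] > stop_threshold:      # motor command decoded: stop spelling
            break
        decoded += characters[int(np.argmax(p))]
    return decoded

def peaked(idx, peak=0.9):
    """Synthetic 27-way probability vector with most of its mass on `idx`."""
    p = np.full(27, (1 - peak) / 26)
    p[idx] = peak
    return p

# 'h', 'i', then the motor command; the trailing 'a' is never reached.
demo = [peaked(7), peaked(8), peaked(26), peaked(0)]
print(greedy_decode(demo))
```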

# Step 5: Plug our vectors into our full beam search and language model

This will produce predictions with spaces inserted and with language modeling applied, which should greatly improve them. 
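
To give an intuition for why a language model helps, here is a toy beam search over a three-letter alphabet. It is a heavily simplified sketch, not the prefix search in `beam_search_simulation`: there are no spaces, no vocabulary constraint, and no `beta` or `final_lm_alpha` terms. The scoring rule (log emission probability plus `alpha` times an LM log score) is an assumption about the general shape of such searches, and `toy_lm` is a made-up, unnormalized scorer that simply prefers the string `cab`.

```python
import numpy as np

letters = 'abc'   # toy alphabet

def beam_search(emissions, lm_logprob, alpha=1.0, beam_width=3):
    """Keep the `beam_width` best prefixes after each timestep."""
    beams = [('', 0.0)]
    for probs in emissions:
        candidates = []
        for prefix, score in beams:
            for i, ch in enumerate(letters):
                new_score = score + np.log(probs[i]) + alpha * lm_logprob(prefix, ch)
                candidates.append((prefix + ch, new_score))
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_width]
    return beams

def toy_lm(prefix, ch):
    """Hypothetical LM score: high if `ch` continues the string 'cab'."""
    target = 'cab'
    k = len(prefix)
    likely = k < len(target) and target[:k] == prefix and target[k] == ch
    return np.log(0.8 if likely else 0.1)

# Classifier outputs for the intended string 'cab'; the second step is wrong.
emissions = np.array([
    [0.2, 0.1, 0.7],   # top guess 'c' (correct)
    [0.4, 0.5, 0.1],   # top guess 'b' (wrong), but 'a' is a close second
    [0.1, 0.8, 0.1],   # top guess 'b' (correct)
])

greedy = ''.join(letters[i] for i in emissions.argmax(axis=1))
best, _ = beam_search(emissions, toy_lm)[0]
print('greedy:', greedy, '| beam search:', best)
```

The greedy decode commits to the wrong second letter, while the beam search recovers the intended string because the LM score outweighs the small gap in emission probability.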

In [None]:
from language_models.train_ngram import get_ngram_lm

In [None]:
# Input arguments
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--vocab_filepath', default=None, type=str, help='where the new vocab is located. list of all the words')
parser.add_argument('--ngram_filepath', default=None, type=str, help='where to save the ngram model. end in .pkl') 
parser.add_argument('--ngram_corpus', default=None, type=str, help='path to the Ngram corpora for training')
parser.add_argument('--greedy', action='store_true', help='whether to use the weird greedy beam search or not')
parser.add_argument('--paradigm', default='mimed', type=str, help='use mimed or overt blocks')


# Language model + beam search parameters. 
parser.add_argument('--alpha', type=float, default=0.642)
parser.add_argument('--beta', type=float, default=10.524)
parser.add_argument('--final_lm_alpha', type=float, default=1.5268)
parser.add_argument('--beam_width', type=int, default=256)

# This will override the previously specified parameters, and use the parameters that were used in realtime during that block. 
parser.add_argument('--useRTparam', action='store_true', help='use the same params we used in RT')


# Respect certain parameters in RT. Not used here, but can be useful for running experiments that set alpha and beta to 0 (i.e. not use the LM)
parser.add_argument('--hardset_beams', action='store_true', help='lock beam sizes, otherwise they will update '
                    'across the demo days')
parser.add_argument('--hardset_lms', action='store_true', help='lock lm parameters, otherwise they will update '
                    'across the demo days')
parser.add_argument('--greedy_flag', action='store_true', help='greedy decoding');

In [None]:
exp_str = '--vocab_filepath ./language_models/word_vocab_1k.lex' # The vocabulary you want. To adjust vocabulary size, just use a
                                                                # vocabulary with more words. 
exp_str += ' --ngram_filepath ./language_models/1k_ngram.pkl' # Where to save the LM
exp_str += ' --ngram_corpus ./language_models/corpora/' # Contains text files that will be processed to train LM. 

exp_str += ' --useRTparam' # Use the beam search parameters that we used for this block in realtime. 
args_lm = vars(parser.parse_args(exp_str.split()));

### Train the Ngram language model (LM) that we're going to use during the beam search
This will take 3-5 minutes on a MacBook Pro.
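
As a rough picture of what an N-gram model stores, the sketch below counts word bigrams in a three-sentence corpus and turns them into add-one-smoothed probabilities. This is illustrative only: `get_ngram_lm` may use a different N-gram order, smoothing method, and file format, and `train_bigram` and `bigram_prob` are hypothetical helpers.

```python
from collections import Counter, defaultdict

def train_bigram(sentences):
    """Count how often each word follows each other word."""
    counts = defaultdict(Counter)
    for s in sentences:
        words = ['<s>'] + s.lower().split() + ['</s>']
        for prev, nxt in zip(words, words[1:]):
            counts[prev][nxt] += 1
    return counts

def bigram_prob(counts, prev, word, vocab_size):
    """P(word | prev) with add-one (Laplace) smoothing."""
    return (counts[prev][word] + 1) / (sum(counts[prev].values()) + vocab_size)

corpus = ['i am good', 'i am thirsty', 'you are good']
counts = train_bigram(corpus)

# Everything that can follow a word: the corpus words plus the end marker.
vocab = {w for s in corpus for w in s.split()} | {'</s>'}

p_seen = bigram_prob(counts, 'i', 'am', len(vocab))
p_unseen = bigram_prob(counts, 'i', 'are', len(vocab))
print(round(p_seen, 3), round(p_unseen, 3))
```

Smoothing keeps unseen word pairs from receiving zero probability, so a beam search can still consider them.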

In [None]:
# Step 1: Get the vocab/n_gram sorted
ngram_lm_fp = get_ngram_lm(args_lm['vocab_filepath'], 
                           args_lm['ngram_filepath'], 
                           args_lm['ngram_corpus'])


from language_models.autocomplete import * 
if ngram_lm_fp is None: 
    ngram_lm_fp = args_lm['ngram_filepath']
   
# Load the language models
lm = load_autocomplete(ngram_lm_fp) 

In [None]:
# Run the search. This will output the most likely sentence at each timepoint, which was provided to the participant as feedback. 
# It will also show the prediction at the final timepoint. 

In [None]:
# Get the sentences and predictions paired
sent_and_preds =[]
for sent, pred_arr in zip(sentence_to_start.keys(), pred_arrs):
    sent_and_preds.append((sent, pred_arr))


# Load the language model parameters
lm_config = {
    'alpha':args_lm['alpha'],
    'beta':args_lm['beta'],
    'final_lm_alpha':args_lm['final_lm_alpha'],
    'beam_width':args_lm['beam_width']
} 

# Run the simulation. 
from beam_search_simulation.run_simulation import run_prefix_search

gts, preds = run_prefix_search(sent_and_preds, 
                               lm_config, 
                               lm, 
                             [block]*len(sent_and_preds),
                               args_lm['vocab_filepath'],  
                               useRTparams=args_lm['useRTparam'], 
                              hardset_lms=args_lm['hardset_lms'], 
                              hardset_beams=args_lm['hardset_beams'],
                              greedy_flag= args_lm['greedy_flag'])

In [None]:
# Displays the ground-truth sentences
gts

In [None]:
# Displays the predicted sentences
preds

In [None]:
# Compute the primary performance metrics (word and character error rates). The CER should be low. 
import fastwer
wer = fastwer.score(preds, gts)
cer = fastwer.score(preds, gts, char_level=True)
print('block wer %.2f' %(wer), '%')
print('block cer %.2f' %(cer), '%')
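
For readers who want to see what these metrics measure, here is a minimal per-sentence sketch of word and character error rates based on the Levenshtein edit distance, expressed as percentages. It is not the `fastwer` implementation, whose corpus-level aggregation across sentences may differ. `edit_distance` and `error_rate` are hypothetical helpers.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two sequences (lists or strings)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        cur = [i]
        for j, h in enumerate(hyp, start=1):
            cur.append(min(prev[j] + 1,              # deletion
                           cur[j - 1] + 1,           # insertion
                           prev[j - 1] + (r != h)))  # substitution or match
        prev = cur
    return prev[-1]

def error_rate(hyp, ref, char_level=False):
    """Edit distance as a percentage of the reference length."""
    ref_units = list(ref) if char_level else ref.split()
    hyp_units = list(hyp) if char_level else hyp.split()
    return 100 * edit_distance(ref_units, hyp_units) / len(ref_units)

demo_wer = error_rate('i am god', 'i am good')                    # 1 of 3 words wrong
demo_cer = error_rate('i am god', 'i am good', char_level=True)   # 1 of 9 characters wrong
print('wer %.2f %%' % demo_wer)
print('cer %.2f %%' % demo_cer)
```

The error is normalized by the reference length, which is why the order of hypothesis and reference matters when the two differ in length.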