In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
!nvidia-smi
import os
# Restrict this process to GPU index 2 (see the nvidia-smi listing printed by this cell).
# This is set before torch is imported in the next cell.
# NOTE(review): the index is hard-coded for the original machine - change it to a GPU that exists on yours.
os.environ["CUDA_VISIBLE_DEVICES"] = str(2)

Tue Aug 22 23:33:48 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.199.02   Driver Version: 470.199.02   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA TITAN V      Off  | 00000000:3B:00.0 Off |                  N/A |
| 28%   31C    P8    24W / 250W |   3076MiB / 12066MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA TITAN V      Off  | 00000000:5E:00.0 Off |                  N/A |
| 28%   32C    P8    24W / 250W |   3143MiB / 12066MiB |      0%      Default |
|       

In [3]:

import numpy as np
import pandas as pd
import argparse
from os.path import join
import torchaudio
import torch
from torchaudio.models import decoder
from torchaudio.models.decoder import download_pretrained_files
print('torch version', torch.__version__)
print('torch audio version', torchaudio.__version__, '>= 0.12.0 needed')
curdir = './'
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.utils.data import DataLoader, TensorDataset
import copy
import wandb

# Change to your data dir. 
data_dir = './data/'
# Use the GPU when one is visible to torch, otherwise fall back to the CPU automatically
# so the notebook also runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Set up the experiment, can change the hyperparameters as you see fit or edit for your own models.
# Every command-line flag is declared once in the table below as (flag, add_argument keywords)
# and registered on the parser in that order.
parser = argparse.ArgumentParser()

_ARGUMENT_SPECS = [
    ('--decimation', dict(type=int, default=6,
                          help='How much to downsample neural data')),
    ('--hidden_dim', dict(type=int, default=512,
                          help="how many hid.units in model")),
    ('--lr', dict(type=float, default=1e-3,
                  help='learning rate')),
    ('--ks', dict(type=int, default=2,
                  help='ks of input conv')),
    ('--num_layers', dict(type=int, default=3,
                          help='number of layers')),
    ('--dropout', dict(type=float, default=0.6,
                       help='dropout amount')),
    ('--feat_stream', dict(type=str, default='both',
                           help='which stream. both, hga, or raw')),
    ('--bs', dict(type=int, default=64,
                  help='batch size')),
    ('--smooth', dict(type=int, default=0,
                      help='how much smoothing to apply.')),
    # store_false: args['no_normalize'] is True unless the flag is passed on the command line.
    ('--no_normalize', dict(action='store_false',
                            help='Normalize the neural data or not')),
    ('--LM_WEIGHT', dict(type=float, default=3.23,
                         help='how much the LM is weighted during beam search')),
    ('--WORD_SCORE', dict(type=float, default=-.26,
                          help='word insertion score for beam')),
    ('--beam_width', dict(type=int, default=100,
                          help='beam size to use')),
    ('--checkpoint_dir', dict(type=str, default=None,
                              help='where 2 save model')),
    ('--feedforward', dict(action='store_true',
                           help='no bidirectional')),
    ('--pretrained', dict(type=str, default=None,
                          help='path to a pretrained model to load')),
    ('--train_amt', dict(type=float, default=1.0,
                         help='amt of train data to use')),
    ('--samples_to_trim', dict(type=int, default=0,
                               help='num samps back to go (to shorten window)')),
    ('--ndense', dict(type=int, default=40,
                      help='Use a different number of classes for a transfer model (useful for 50 phrase transfer.)')),
    # No default is given, so this is None unless the flag is supplied.
    ('--num_50', dict(type=int,
                      help='only used for 500 phrases')),
]

for _flag, _options in _ARGUMENT_SPECS:
    parser.add_argument(_flag, **_options)

  from .autonotebook import tqdm as notebook_tqdm


torch version 1.12.0
torch audio version 0.12.0 >= 0.12.0 needed


_StoreAction(option_strings=['--num_50'], dest='num_50', nargs=None, const=None, default=None, type=<class 'int'>, choices=None, required=False, help='only used for 500 phrases', metavar=None)

In [4]:
# Command-line style overrides for this run; any flag not listed here keeps its parser default.
exp_str = '--hidden_dim 512 --ks 2 --dropout 0.6 --num_layers 3 --num_50 500 --samples_to_trim 0'

In [5]:
# Parse the experiment string into a plain dict of hyperparameters; flags that are not
# present in exp_str keep the defaults declared on the parser.
# When running this as a Python script rather than a notebook, call parser.parse_args()
# with no argument so that the real command line is read instead of exp_str.
args = vars(parser.parse_args(exp_str.split()))

In [6]:
# Start a Weights & Biases run for logging and record the parsed hyperparameters as its config.
wandb.init(project='b3_50_phrase_pub', config=args)

#### Experiment is set up, now lets load in the neural data and corresponding labels.
##### Labels has several columns
##### ph_label - list, each item is the phonemes from that utterance, in order <br> 
##### txt_label - the ground truth text label for that utterance <br>
##### length - the length of the utterance, in samples, at 200 Hz. <br> It must be divided by the decimation factor (6 by default) to line up with the neural data; that is done later in this notebook.

##### Please ignore the rest of the columns, they were from a previous nb. 
# Each row in this df (e.g row 0) corresponds to trial 0 of the neural data

# Training labels first, then the test labels.
labels_tr = pd.read_hdf(join(data_dir, 'Y_50_phrase_pub_tr.h5'))
# This test data uses the realtime blocks we used in the publication.
labels_te = pd.read_hdf(join(data_dir, 'Y_50_phrase_pub_te.h5'))

# Remember how many test rows there are: after concatenation they are the last te_len rows.
te_len = len(labels_te)
labels = pd.concat([labels_tr, labels_te], ignore_index=True)
labels.head()

# Unique ground-truth phrases, and every space-separated word they contain.
all_labs = set(labels['txt_label'].values)
all_words = [word for phrase in all_labs for word in phrase.split(' ')]
print('vocab size', len(set(all_words)), 'words')

### NEURAL DATA PREPROCESSING. This can be a very big deal for accuracy, so I suggest tinkering with parameters.

# from scipy.signal import decimate

def normalize(x, axis=-1, order=2):
    """Scale a Numpy array to unit norm along one axis.

    Args:
        x: Numpy array to normalize.
        axis: axis along which the norm is computed.
        order: order of the norm (e.g. `order=2` for the L2 norm).

    Returns:
        A new array equal to `x` divided by its norm along `axis`. Slices whose
        norm is exactly zero are divided by 1 instead, so they come back unchanged.
    """
    norms = np.atleast_1d(np.linalg.norm(x, ord=order, axis=axis))
    # Replace zero norms with 1 so that all-zero slices do not trigger a division by zero.
    zero_norm = norms == 0
    norms[zero_norm] = 1
    # Re-insert the reduced axis so the norms broadcast against x.
    divisor = np.expand_dims(norms, axis)
    return x / divisor


import os

# Load the test and train neural data and stack them along the first axis, train first,
# so the test trials are the last rows of X (the same order the labels were concatenated in).
X_te = np.load(os.path.join(data_dir, 'X_50_phrase_pub_te.npy'))
X = np.concatenate(
    (np.load(os.path.join(data_dir, 'X_50_phrase_pub_tr.npy')), X_te),
    axis=0,
)
    

In [29]:
# Ensure no test in train: every 50th test trial (the last te_len rows of X) is compared
# against all training trials and must not be identical to any of them.
X_train_rows = X[:-te_len]
for k in range(len(X) - te_len, len(X), 50):
    # Total absolute difference between test trial k and each training trial.
    dists = np.abs(X[k] - X_train_rows).sum(axis=(1, 2))
    assert not np.any(dists == 0)

5600
5650
5700


In [7]:
# The neural data and the label frame must stay aligned: one label row per trial in X.
assert len (X) == len(labels)

### Note - some models will get really good performance, others not, this training is a bit sensitive to random init
After about 30 epochs expect to see a WER around 10% or lower. 
This is the WER we saw prior to testing the model (we tested for the paper only once, after training with various random inits and choosing the best-performing model). This will correspond to a 3-4% WER on the test set - the test set appears to have had very strong performance, potentially because of the use of the avatar during this trial. Additional investigation is needed!

In [9]:
from os.path import join

# The last axis of X is split into two equal halves; the first half is the 'hga'
# stream and the second half is the 'raw' stream (see the feat_stream selection below).
half = X.shape[-1] // 2

# NOTE: --no_normalize is a store_false flag, so this key is True (normalize) unless the flag was passed.
if not args['no_normalize']:
    print('no normalization.')
else:
    print('normalizing data')
    # Each half is normalized independently and written back into X in place.
    X[:, :, :half] = normalize(X[:, :, :half])
    X[:, :, half:] = normalize(X[:, :, half:])

# select feature stream
stream = args['feat_stream']
if stream == 'hga':
    X = X[:, :, :half]
elif stream == 'raw':
    X = X[:, :, half:]
print('final X shape', X.shape)

# Optionally drop the last samples_to_trim samples along axis 1 of every trial.
n_trim = args['samples_to_trim']
if n_trim > 0:
    X = X[:, :-n_trim, :]
print('trimmed X', X.shape)

# The following code just preprocesses the phoneme labels. Can be changed to letters or sentence pieces 
# pretty easily.
from data_loading_utilities.clean_labels import clean_labels

labels, all_ph = clean_labels(labels)
# Give every distinct phoneme except the '|' separator an integer id, assigned in sorted order.
phone_enc = {ph: idx for idx, ph in enumerate(sorted(set(all_ph) - {'|'}))}

#### Setup the files needed for the torchaudio CTC decoder. You can see more on this here, including how to adapt it to letters. 
# https://pytorch.org/audio/main/tutorials/asr_inference_with_ctc_decoder_tutorial.html

# Fetch the pretrained "librispeech-4-gram" files through torchaudio.
# NOTE(review): `files` is not referenced by the beam-search decoder configured below, which reads
# its lexicon, tokens and LM from paths under `curdir` - confirm whether this download is still needed.
files  = download_pretrained_files("librispeech-4-gram") # Download a 4-gram librispeech LM, may take a sec.

# Here we're building a lexicon. Basically we just want to say how each word is pronounced.
# Each entry maps a word to its space-separated phonemes followed by the '|' token.
lex = {}
for k, v in zip(labels['txt_label'], labels['ph_label']):
    if '|' not in v:
        # No '|' among the phonemes: the whole text label gets a single pronunciation.
        lex[k] = ' '.join(v) + ' |'
    else:
        # Split the phoneme sequence at every '|' and pair the pieces with the words, in order.
        per_word = '_'.join(v).split('|')
        for kk, vv in zip(k.split(' '), per_word):
            if kk != '':
                lex[kk] = vv.replace('_', ' ').strip() + ' |'

# One "<word> <pronunciation>" line per entry; only lines longer than 3 characters are kept.
strings = [word + ' ' + pron for word, pron in lex.items()]
strings = [s for s in strings if len(s) > 3]

lexicon_path = join(curdir, "for_ctc/lexicon_phrases_50.txt")
# Make sure the output directory exists, and let the context manager close the file
# even if writing fails part-way through.
os.makedirs(os.path.dirname(lexicon_path), exist_ok=True)
with open(lexicon_path, "w") as f:
    f.writelines([s + '\n' for s in strings])

print('example lexicon items')
for s in strings[:5]:
    print(s)
print('vocabulary size:', len(strings))

# Token inventory for the CTC decoder: '-' and '|' first, then every phoneme in phone_enc order.
# It is written to disk one token per line.
tokens = ['-', '|', *phone_enc]
with open(join(curdir, 'for_ctc/tokens_phrases_50.txt'), 'w') as f:
    f.write(''.join(t + '\n' for t in tokens))

# Final encoder for labels to go from tokens to labels (token -> position in `tokens`).
enc_final = dict(zip(tokens, range(len(tokens))))

# Initialize beam search 

from torchaudio.models.decoder import ctc_decoder
from torchaudio.functional import edit_distance
# Set up a beam search decoder. Will take a second the first time.
# It reads the lexicon and token files written above and the custom 50-phrase LM under curdir.
beam_search_decoder = ctc_decoder(
    lexicon=join(curdir, 'for_ctc/lexicon_phrases_50.txt'),
    tokens=join(curdir, 'for_ctc/tokens_phrases_50.txt'),
    lm=join(curdir, 'custom_lms/50_phrase_lm.binary'),
    nbest=3,
    # Honour the --beam_width flag (default 100) instead of a hard-coded 100,
    # so the command-line option is no longer silently ignored.
    beam_size=args['beam_width'],
    lm_weight=args['LM_WEIGHT'],
    word_score=args['WORD_SCORE'],
    sil_token='|',
    unk_word='<unk>',
)

# Get a greedy decoder ready as well,can be useful to see how much LM is helping.
from train.ctc_decoding import GreedyCTCDecoder

# Two decoder instances are created: one from `tokens` directly and one from the keys of
# enc_final (which was built from `tokens`, so its keys follow the same order).
# `greedy` is the instance handed to train_loop further down.
greedy_decoder = GreedyCTCDecoder(tokens)
greedy = GreedyCTCDecoder(labels=list(enc_final))

# Get neural data/targets ready for training

# Prepare neural and target data for CTC loss

# Encode every utterance as a list of token ids, wrapped in the '|' token on both sides.
sil_id = enc_final['|']
y_final = [
    [sil_id] + [enc_final[ph] for ph in phones] + [sil_id]
    for phones in labels['ph_label']
]

# Right-pad all targets with -1 up to the longest one, and remember the true lengths.
targ_lengths = np.array([len(seq) for seq in y_final])
y_final_ = np.full((len(y_final), targ_lengths.max()), -1.0)
for row, seq in enumerate(y_final):
    y_final_[row, :len(seq)] = seq
Y = y_final_

# Input lengths: adjust the label lengths for decimation, subtract the trimmed samples,
# then cap at the time dimension of X because some lengths may be a sample over.
lens = np.array([length // args['decimation'] for length in labels['length']])
lens = lens - args['samples_to_trim']
lens = np.array([min(n, X.shape[1]) for n in lens])

# Finalize the lengths.
outlens = targ_lengths
gt_text = labels['txt_label'].values

# Set up the train / validation / test index splits.
# The test split is always the last `te_len` trials of X (the test recordings concatenated
# after the training data). Of the remaining trials, in shuffled order, the final 200 are
# used for validation and the rest for training.
print(X.shape, Y.shape)

inds = np.arange(len(X))
np.random.seed(1337)
np.random.shuffle(inds)

# The split does not depend on the fold number, so it is computed once; a set makes the
# membership test constant-time instead of scanning a list for every index.
te_inds = sorted(inds)[-te_len:]
te_lookup = set(te_inds)
non_test = [i for i in inds if i not in te_lookup]
val_inds = non_test[-200:]
tr_inds = non_test[:-200]

# Ten identical (train, val, test) entries are kept so that code iterating over `trainsets`
# sees the same structure as before; each entry gets its own list objects.
trainsets = [(list(tr_inds), list(val_inds), list(te_inds)) for _ in range(10)]

### Train the neural network. Every 3 trials the wer/cer are evaluated. 

from train.ctc_trainer import train_loop
from models.cnn_rnn import FlexibleCnnRnnClassifier


def _make_dataset(*arrays):
    """Bundle numpy arrays into a TensorDataset, converting a copy of each array to a tensor."""
    return TensorDataset(*[torch.from_numpy(arr.copy()) for arr in arrays])


for train, val, test in trainsets:
    # Optionally keep only the leading fraction (--train_amt) of the training indices.
    train_amt = int(len(train) * args['train_amt'])
    print('num samples', train_amt)
    wandb.log({'num_samples': train_amt})
    train = train[:train_amt]
    print(len(train), train_amt)

    # Slice neural data, targets, input lengths and target lengths for each split.
    X_tr, Y_tr, lens_tr, outlens_tr = X[train], Y[train], lens[train], outlens[train]
    X_v, Y_v, lens_v, outlens_v = X[val], Y[val], lens[val], outlens[val]
    X_te, Y_te, lens_te, outlens_te = X[test], Y[test], lens[test], outlens[test]
    # The original trial indices travel with each sample, for loading text labels.
    inds_tr, inds_v, inds_te = np.array(train), np.array(val), np.array(test)

    # Make datasets
    train_dset = _make_dataset(X_tr, Y_tr, lens_tr, outlens_tr, inds_tr)
    test_dset = _make_dataset(X_te, Y_te, lens_te, outlens_te, inds_te)
    val_dset = _make_dataset(X_v, Y_v, lens_v, outlens_v, inds_v)

    # TODO: Add transforms from torchaudio.transforms
    batch_size = args['bs']
    train_loader = DataLoader(train_dset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dset, batch_size=batch_size, shuffle=False)

    # Initialize the model.
    bidirectional = not args['feedforward']
    use_pretrained = args['pretrained'] is not None
    # Only the bidirectional model sizes its output from --ndense when pretrained weights
    # will be loaded; in every other case the output size is the number of tokens.
    if bidirectional and use_pretrained:
        n_targ = args['ndense']
    else:
        n_targ = len(enc_final)

    model = FlexibleCnnRnnClassifier(
        rnn_dim=args['hidden_dim'],
        KS=args['ks'],
        num_layers=args['num_layers'],
        dropout=args['dropout'],
        n_targ=n_targ,
        bidirectional=bidirectional,
        in_channels=X_tr.shape[-1],
    )

    if use_pretrained:
        # Load the pretrained weights, then swap in a fresh output layer sized for the current tokens.
        model.load_state_dict(torch.load(join(curdir, args['pretrained'])))
        # NOTE(review): the 2*hidden_dim input size presumably matches the bidirectional model only;
        # confirm against FlexibleCnnRnnClassifier before combining --pretrained with --feedforward.
        model.dense = nn.Linear(2 * args['hidden_dim'], len(enc_final))

        if args['transfer_audio']:
            # Replace the input convolution so it accepts this dataset's channel count.
            model.preprocessing_conv = torch.nn.Conv1d(
                in_channels=X_te.shape[-1],
                out_channels=args['hidden_dim'],
                kernel_size=args['ks'],
                stride=args['ks'],
            )

    # The optimizer is created after any layer replacement so it covers the new parameters.
    optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'])
    model = model.to(device)
    model = train_loop(
        model, train_loader, val_loader, optimizer, device,
        gt_text, greedy, beam_search_decoder, tokens,
        start_eval=0, wandb_log=True, checkpoint_dir=args['checkpoint_dir'],
    )

    # Currently we only use one fold for model dev.
    break

normalizing data
final X shape (5750, 217, 506)
trimmed X (5750, 217, 506)
example lexicon items
can K AE N |
i AY |
get G EH T |
my M AY |
medicine M EH D AH S AH N |
vocabulary size: 119
token lm False
(5750, 217, 506) (5750, 37)
num samples 5400
5400 5400
net wer 1.0
net cer 1.0
greedy per 0.7703390893464609
beam per 0.795643583454428
gt do not make me laugh
t  wer: 1.00 cer: 1.00
gt is there anything i can do
t  wer: 1.00 cer: 1.00
gt well it sure looks like it
t  wer: 1.00 cer: 1.00
gt it sounds good to me
t  wer: 1.00 cer: 1.00
gt how are things going for you
t  wer: 1.00 cer: 1.00
gt help me with the chair
t  wer: 1.00 cer: 1.00
gt i might check that out tomorrow
t  wer: 1.00 cer: 1.00
gt was there something else
t  wer: 1.00 cer: 1.00
gt hand that to me please
t  wer: 1.00 cer: 1.00
gt wait for the rest of them
t  wer: 1.00 cer: 1.00
gt just a minute let me think about that
t  wer: 1.00 cer: 1.00
epoch 0 tr loss: 0.052 te_loss: 0.053
epoch 1 tr loss: 0.039 te_loss: 0.048
epoch 

KeyboardInterrupt: 

In [11]:
# Run on test set. For early stopping, this model used a different strategy (some predictions were left to end), see paper for more details.
# Same call as for training, but with test_loader in the evaluation slot, a single epoch and
# train=False (presumably evaluation only - confirm in train.ctc_trainer.train_loop).
model = train_loop(
    model, train_loader, test_loader, optimizer, device,
    gt_text, greedy, beam_search_decoder, tokens,
    start_eval=0, max_epochs=1, wandb_log=True,
    checkpoint_dir=args['checkpoint_dir'], train=False,
)

net wer 0.03133333333333333
net cer 0.027093025283347862
greedy per 0.06600146498863235
beam per 0.027620318114687025
gt i am glad you are here
t i am glad you are here wer: 0.00 cer: 0.00
gt can i get my medicine
t can i get my medicine wer: 0.00 cer: 0.00
gt it sounds good to me
t it sounds good to me wer: 0.00 cer: 0.00
gt i think this is pretty good
t i think this is pretty good wer: 0.00 cer: 0.00
gt how is the weather today
t how is the weather today wer: 0.00 cer: 0.00
gt i thought it would be good for me
t i thought it would be good for me wer: 0.00 cer: 0.00
gt thank you it is looking good
t thank you it is looking good wer: 0.00 cer: 0.00
gt how are things going for you
t how are things going for you wer: 0.00 cer: 0.00
gt i might check that out tomorrow
t i might check that out for you wer: 0.33 cer: 0.19
gt what are you looking for
t what are you looking for wer: 0.00 cer: 0.00
gt i think this is pretty good
t i think this is pretty good wer: 0.00 cer: 0.00
epoch 0 tr loss: