In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
from torch.autograd import Variable
import numpy as np
import glob
import matplotlib.pyplot as plt
%matplotlib inline
import os
from torch_setupParameters import set_seed,set_device, getDefaultRNNArgs
from torch_dataPreprocessing import loadAllRealDatasets, prepareDataCubesForRNN
from torch_dataPreprocessing import normalizeSentenceDataCube, binTensor
from torch_dataPreprocessing import handBCI_Dataset, handBCI_SythDataset, combineSynthAndReal, gaussSmooth
from tfrecord.torch.dataset import MultiTFRecordDataset
from torch.optim.lr_scheduler import LambdaLR
from datetime import datetime
from collections import OrderedDict
from torchsummary import summary as netSummary

import importlib

In [2]:
# import torch_dataPreprocessing
# importlib.reload(torch_dataPreprocessing)

In [3]:
SEED = 2021
set_seed(seed=SEED)
DEVICE = set_device()

Random seed 2021 has been set.
GPU is enabled !


In [4]:
#point this towards the top level dataset directory
rootDir = '../handwritingBCIData/'
outDir = 'output/'
#train an RNN using data from these specified sessions
dataDirs = ['t5.2019.05.08','t5.2019.11.25','t5.2019.12.09','t5.2019.12.11','t5.2019.12.18',
            't5.2019.12.20','t5.2020.01.06','t5.2020.01.08','t5.2020.01.13','t5.2020.01.15']


#use this train/test partition 
cvPart = 'HeldOutTrials'

#name of the directory where this RNN run will be saved
rnnOutputDir = cvPart

## parameters
args = getDefaultRNNArgs(rootDir, cvPart, outDir)
#Configure the arguments for a multi-day RNN (that will have a unique input layer for each day)
for x in range(len(dataDirs)):
    args['sentencesFile_'+str(x)] = rootDir+'Datasets/'+dataDirs[x]+'/sentences.mat'
    args['singleLettersFile_'+str(x)] = rootDir+'Datasets/'+dataDirs[x]+'/singleLetters.mat'
    args['labelsFile_'+str(x)] = rootDir+'RNNTrainingSteps/Step2_HMMLabels/'+cvPart+'/'+dataDirs[x]+'_timeSeriesLabels.mat'
    args['syntheticDatasetDir_'+str(x)] = rootDir+'Datasets/'+dataDirs[x]+'/'+cvPart+'/'+dataDirs[x]+'_syntheticSentences/'
    args['cvPartitionFile_'+str(x)] = rootDir+'RNNTrainingSteps/trainTestPartitions_'+cvPart+'.mat'
    args['sessionName_'+str(x)] = dataDirs[x]

for t in range(30):  ## 10 days
    if 'labelsFile_'+str(t) not in args.keys():
        args['nDays'] = t
        break
if not os.path.isdir(args['outputDir']):
    os.mkdir(args['outputDir'])
    
#this weights each day equally (0.1 probability for each day) and allocates a unique input layer for each day (0-9)
args['dayProbability'] = '[0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1]'
args['dayToLayerMap'] = '[0,1,2,3,4,5,6,7,8,9]'
# args['verbose'] = True ## extra print-out information

args['mode'] = 'train' ## make sure it is set in 'train' mode
print('batchSize:', args['batchSize'])
print('synthBatchSize:', args['synthBatchSize'])
args['ForTestingOnly'] = False ## FOR DEBUGING. set "self.nDays = 2" (use 2 days of data for testing run)


batchSize: 32
synthBatchSize: 12


## torch dataloaders

In [5]:
allSynthDataLoaders = []
allRealDataLoaders = []
allValDataLoaders = []
daysWithValData = []
args['isTraining'] = True
for dayIdx in range(args['nDays']):
    ## real data
    print('Loading real data ', dayIdx)
    neuralData, targets, errWeights, binsPerTrial, cvIdx = prepareDataCubesForRNN(args['sentencesFile_'+str(dayIdx)],
                                                                          args['singleLettersFile_'+str(dayIdx)],
                                                                          args['labelsFile_'+str(dayIdx)],
                                                                          args['cvPartitionFile_'+str(dayIdx)],
                                                                          args['sessionName_'+str(dayIdx)],
                                                                          args['rnnBinSize'],
                                                                          args['timeSteps'],
                                                                          args['isTraining'])
    realDataSize = args['batchSize'] - args['synthBatchSize']
    trainIdx = cvIdx['trainIdx']
    valIdx = cvIdx['testIdx']
    print('create real dataset ', dayIdx)
    realData_train = handBCI_Dataset(args,neuralData[trainIdx,:,:], targets[trainIdx,:,:], errWeights[trainIdx,:],\
                               binsPerTrial[trainIdx,np.newaxis],\
                               addNoise=True)
    realDataTrain_Loader = torch.utils.data.DataLoader(realData_train, batch_size =realDataSize,shuffle=True, num_workers=0)
    
    if len(valIdx)==0:
        realDataVal_Loader = realDataTrain_Loader
    else:
        realData_val = handBCI_Dataset(args,neuralData[valIdx,:,:], targets[valIdx,:,:], errWeights[valIdx,:],\
                                   binsPerTrial[valIdx,np.newaxis],\
                                       addNoise=False)
        realDataVal_Loader = torch.utils.data.DataLoader(realData_val, batch_size =args['batchSize'],shuffle=True, num_workers=0)
        daysWithValData.append(dayIdx)
    allRealDataLoaders.append(realDataTrain_Loader)
    allValDataLoaders.append(realDataVal_Loader)
              
    ## sythetic data
    if args['synthBatchSize'] > 0:
        print('processing sythetic data ', dayIdx)
        synthDir = args['syntheticDatasetDir_'+str(dayIdx)]
        synth_obj = handBCI_SythDataset(synthDir, args)
        synth_ds = synth_obj.makeDataSet()
        synth_loader = torch.utils.data.DataLoader(synth_ds, batch_size=args['synthBatchSize'])
        allSynthDataLoaders.append(synth_loader)

Loading real data  0
create real dataset  0
processing sythetic data  0
Loading real data  1
create real dataset  1
processing sythetic data  1
Loading real data  2
create real dataset  2
processing sythetic data  2
Loading real data  3
create real dataset  3
processing sythetic data  3
Loading real data  4
create real dataset  4
processing sythetic data  4
Loading real data  5
create real dataset  5
processing sythetic data  5
Loading real data  6
create real dataset  6
processing sythetic data  6
Loading real data  7
create real dataset  7
processing sythetic data  7
Loading real data  8
create real dataset  8
processing sythetic data  8
Loading real data  9
create real dataset  9
processing sythetic data  9


In [6]:
# miniBatch = next(iter(allRealDataLoaders[0]))  

In [11]:
class charSeqNet(nn.Module):
    def __init__(self, args):
        super(charSeqNet, self).__init__()
        """
        """        
        #count how many days of data are specified
        self.nDays = args['nDays']
        self.args = args
        if self.args['seed']==-1:
            self.args['seed']=datetime.now().microsecond
        drop_prob = args['drop_prob']
        #define the dimensions of layers in the RNN
        self.batchSize = args['batchSize']
        nOutputs = 31 + 1 # 31 letters/punctations + 1 transition labels
        self.nInputs = 192        
        self.nUnits = 512 # args['nUnits']
        nUnits2 = 512
        self.nTimeSteps = args['timeSteps']
        self.rnnBinSize = args['rnnBinSize']
        self.skipLen = args['skipLen']
        self.outputDelay = args['outputDelay']
        self.inputLayers = {}
        nLinerOuput = 192
        for j in range(self.nDays):
            self.inputLayers['input_'+str(j)] = nn.Linear(self.nInputs, nLinerOuput, bias = True)

        self.input = self.inputLayers['input_0']
#         self.relu1 = nn.ReLU()
        self.gru1 = nn.GRU(nLinerOuput, self.nUnits, 1, \
                                 batch_first=True, dropout=0)
        self.gru2 = nn.GRU(self.nUnits, nUnits2, 1, \
                             batch_first=True, dropout=0)

        self.fc1 = nn.Linear(nUnits2, nOutputs, bias = True)


    def forward(self, x):
#       x shape: [args['batchSize'], args['timeSteps'], nInputs]        
        x = torch.transpose(x, 1,2) ## swap time dimension with channel dimenion to get [bsize, chan, time]
        if self.args['smoothInputs']==1: ## smooth
            x = gaussSmooth(x, kernelSD=4/self.rnnBinSize)
        x = torch.transpose(x, 1,2) #swap time dimension with channel dimenion to get [bsize, time,chan]
        x = x.clone().float().detach().requires_grad_(True)
#         x = Variable(x.clone().float().detach(), requires_grad=True)
        x = self.input(x.to('cuda'))
#         x = self.relu1(x)  ## Add a relu layer
#         h0 = self.init_hidden(self.batchSize)
        x, h = self.gru1(x)
        x, h = self.gru2(x[:, 0::self.skipLen,:], h)  # downsample x's time dimension
#         x = torch.repeat_interleave(x.detach(), self.skipLen, dim=1)## upsample x's time dimension
        output = self.fc1(torch.repeat_interleave(x.detach(), self.skipLen, dim=1)) ## upsample x's time dimension
        return output
        ## Weights initialization

    def init_hidden(self,  batch_size):
        hidden = self.gru1.weight_hh_l0.new(1, batch_size, self.nUnits).zero_().to('cuda')
        return hidden

In [12]:
## Uncomment the lines below to train your network
charSeq_net = charSeqNet(args).to(DEVICE)
## initialization 
for p in charSeq_net.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)
#         nn.init.xavier_normal_(p.weight)
#         p.bias.data.zero_()
# for ii, layer in enumerate(charSeq_net.parameters()):
#     if type(layer) == nn.Linear:
#         nn.init.xavier_uniform_(layer, gain=nn.init.calculate_gain('relu'))
print("Total Parameters in Network {:10d}".format(sum(p.numel() for p in charSeq_net.parameters())))

Total Parameters in Network    2713824


In [13]:
print(charSeq_net)

charSeqNet(
  (input): Linear(in_features=192, out_features=192, bias=True)
  (gru1): GRU(192, 512, batch_first=True)
  (gru2): GRU(512, 512, batch_first=True)
  (fc1): Linear(in_features=512, out_features=32, bias=True)
)


In [14]:
netSummary(charSeq_net,input_size = (1200,192))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1            [-1, 1200, 192]          37,056
               GRU-2  [[-1, 1200, 512], [-1, 2, 512]]               0
               GRU-3  [[-1, 240, 512], [-1, 2, 512]]               0
            Linear-4             [-1, 1200, 32]          16,416
Total params: 53,472
Trainable params: 53,472
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.88
Forward/backward pass size (MB): 5757.95
Params size (MB): 0.20
Estimated Total Size (MB): 5759.03
----------------------------------------------------------------


In [None]:
criterion_letter = nn.CrossEntropyLoss(reduction = 'none') # we need to apply error weight matrix before averaging
transit_sigmoid = nn.Sigmoid()
criterion_transit = nn.MSELoss() ## mean square error for transit
optimizer = torch.optim.AdamW(charSeq_net.parameters(),lr=1e-4, betas=(0.9, 0.999), eps=1e-08,\
                           weight_decay = args['l2scale'], amsgrad = False) # weight_decay for l2 regurlization


In [None]:
def computeFrameAccuracy(rnnOutput, targets, errWeight, outputDelay):
    """
    Computes a frame-by-frame accuracy percentage given the rnnOutput and the targets, while ignoring
    frames that are masked-out by errWeight and accounting for the RNN's outputDelay. 
    """
    #Select all columns but the last one (which is the character start signal) and align rnnOutput to targets
    #while taking into account the output delay. 
    bestClass = np.argmax(rnnOutput[:,outputDelay:,0:-1], axis=2)
    indicatedClass = np.argmax(targets[:,0:-outputDelay,0:-1], axis=2)
    bw = errWeight[:,0:-outputDelay]

    #Mean accuracy is computed by summing number of accurate frames and dividing by total number of valid frames (where bw == 1)
    acc = np.sum(bw*np.equal(np.squeeze(bestClass), np.squeeze(indicatedClass)))/np.sum(bw)
    
    return acc

In [None]:
# def _validationDiagnostics(self, i, nBatchesPerVal, lr, totalSeconds, runResultsTrain, trainAcc):
#     """
#     Runs a single minibatch on the validation data and returns performance statistics and a snapshot of key variables for
#     diagnostic purposes. The snapshot file can be loaded and plotted by an outside program for real-time feedback of how
#     the training process is going.
#     """
#     #Randomly select a day that has validation data; if there is no validation data, then just use the last days' training data
#     if self.daysWithValData==[]:
#         dayNum = self.nDays-1
#         datasetNum = dayNum*2
#     else:
#         randIdx = np.random.randint(len(self.daysWithValData))
#         dayNum = self.daysWithValData[randIdx]
#         datasetNum = 1+dayNum*2 #odd numbers are the validation partitions

#     runResults = self._runBatch(datasetNum=datasetNum, dayNum=dayNum, lr=lr, computeGradient=True, doGradientUpdate=False)

#     valAcc = computeFrameAccuracy(runResults['logitOutput'], 
#                             runResults['targets'],
#                             runResults['batchWeight'], 
#                             self.args['outputDelay'])

#     print('Val Batch: ' + str(i) + '/' + str(self.args['nBatchesToTrain']) + ', valErr: ' + str(runResults['err']) + ', trainErr: ' + str(runResultsTrain['err']) + ', Val Acc.: ' + str(valAcc) + ', Train Acc.: ' + str(trainAcc) + ', grad: ' + str(runResults['gradNorm']) + ', learnRate: ' + str(lr) + ', time: ' + str(totalSeconds))

#     outputSnapshot = {}
#     outputSnapshot['inputs'] = runResults['inputFeatures'][0,:,:]
#     outputSnapshot['rnnUnits'] = runResults['output'][0,:,:]
#     outputSnapshot['charProbOutput'] = runResults['logitOutput'][0,:,0:-1]
#     outputSnapshot['charStartOutput'] = scipy.special.expit(runResults['logitOutput'][0,self.args['outputDelay']:,-1])
#     outputSnapshot['charProbTarget'] = runResults['targets'][0,:,0:-1]
#     outputSnapshot['charStartTarget'] = runResults['targets'][0,:,-1]
#     outputSnapshot['errorWeight'] = runResults['batchWeight'][0,:]

#     return [i, runResults['err'], runResults['gradNorm'], valAcc], outputSnapshot

In [None]:

lr_lambda = lambda epoch: args['learnRateStart']*(1 - epoch/args['nBatchesToTrain'])
scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)

In [None]:
#https://discuss.pytorch.org/t/why-cant-i-see-grad-of-an-intermediate-variable/94/17
#  auto grad
#https://pytorch.org/docs/stable/autograd.html

In [None]:
 ## Traing..
nPredicts = args['timeSteps']-args['outputDelay']
train_acc = []
# dayIdx = np.random.randint(2) #args['nDays']
dtStart = datetime.now()
scale_transit_loss = 5
dayIdx = 6 #np.random.randint(2)
for epoch in range(args['nBatchesToTrain']):
#     dtStart = datetime.now()

    # grab a batch of data
#     miniBatch = next(iter(allRealDataLoaders[dayIdx]))
#     X = miniBatch['inputs']
#     target_raw = miniBatch['labels']
#     erws = miniBatch['errWeights'] 
    
    X, target_raw, erws = combineSynthAndReal(iter(allSynthDataLoaders[dayIdx]),iter(allRealDataLoaders[dayIdx])) ## change 0 to dayIdx in full mode
#     totalSeconds = (datetime.now()-dtStart).total_seconds()

#     print('data loading time:', totalSeconds)

    charSeq_net.input = charSeq_net.inputLayers['input_'+str(dayIdx)].to('cuda') ## day specific input
    optimizer.zero_grad() ## clear gradients befor new forward pass
    output = charSeq_net(X) ## forward pass

    ## here we accounting for the output delay
    target = target_raw[:,0:-args['outputDelay'],:]
    bw = erws[:,0:-args['outputDelay']]
    logits = output[:,args['outputDelay']:,:]
    
    ## seperate out characters and transit signal (last column)
    output_transit = Variable(logits[:,:,-1].to('cuda'), requires_grad=True)
    output_letters = Variable(logits[:,:,0:-1].to('cuda'), requires_grad=True)
    target_transit = Variable(target[:,:,-1].to('cuda'), requires_grad=True)
    target_letters = Variable(target[:,:,0].to('cuda'), requires_grad=True)
    
#     output_ = output[:,args['outputDelay']:,:][:,:,-1]
#     target_ = target_raw[:,0:-args['outputDelay'],:][:,:,-1]
    ## compute loss
    loss_letters=[]
    for t in range(args['timeSteps']-args['outputDelay']):  ## need to apply the error weight per timestep per batch 
        loss_letters.append(criterion_letter(output_letters[:,t,:],\
                                             target_letters[:,t].long())*bw[:,t].to('cuda')/nPredicts)
  
    #loss_letters.backward(retain_graph=True)
    loss_letters = torch.stack(loss_letters, dim=0).sum(dim=0).mean(dim=0) ## sum over averaged time then mean across batch
    loss_transit = criterion_transit(transit_sigmoid(output[:,args['outputDelay']:,:][:,:,-1]),\
                                     target_raw[:,0:-args['outputDelay'],:][:,:,-1].float().to('cuda'))
    loss = loss_letters + scale_transit_loss*loss_transit

    ## backprop
    loss.backward()

#     clip_grad_norm_(charSeq_net.parameters(), args['clip_grads']) ## clip gradient to prevent explosion
    optimizer.step()          # weight update
#     print(charSeq_net.fc1.weight.data.grad)
#     print(charSeq_net.input.weight.data.grad)
    scheduler.step()          ## learning rate update
    
    ## report 

    if epoch % args['batchesPerVal'] == 0:
        with torch.no_grad():
            trainAcc = computeFrameAccuracy(output.detach().cpu().numpy(), 
                            target_raw.detach().cpu().numpy(),
                            erws.detach().cpu().numpy(), 
                            args['outputDelay'])
            totalSeconds = (datetime.now()-dtStart).total_seconds()
            train_acc.append(trainAcc)
#             totalPars = sum(p.numel() for p in charSeq_net.parameters())
            print(f'epoc#{str(epoch):3} day#{str(dayIdx):2} train-loss:{loss.item():.3f} train-Acc:{trainAcc:.2%} train-time:{totalSeconds/60.0:.1f}')

#         batchValStats[valSetIdx,0:4], outputSnapshot = \
#         self._validationDiagnostics(i, args['batchesPerVal'], lr,\
#                                     totalSeconds, runResultsTrain, trainAcc)
              

In [None]:
plt.figure()
val_epochs = np.arange(0, len(train_acc)*args['batchesPerVal'],args['batchesPerVal'])
plt.plot(val_epochs,train_acc);

In [None]:
print(charSeq_net.input.weight.grad, charSeq_net.input.bias.grad)

In [None]:
print(charSeq_net.gru1.weights.grad, charSeq_net.gru1.bias.grad)