In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import glob
import matplotlib.pyplot as plt
%matplotlib inline
import os
from torch_setupParameters import set_seed,set_device, getDefaultRNNArgs
from torch_dataPreprocessing import loadAllRealDatasets, prepareDataCubesForRNN
from torch_dataPreprocessing import normalizeSentenceDataCube, binTensor
from torch_dataPreprocessing import handBCI_Dataset, handBCI_SythDataset, gaussSmooth
from tfrecord.torch.dataset import MultiTFRecordDataset

In [None]:
# Seed value for this run, handed to the project helper `set_seed`
# (imported from torch_setupParameters; presumably seeds the torch/numpy RNGs — confirm there).
SEED = 2021
set_seed(seed=SEED)
# Device handle returned by the project helper; later used as `.to(DEVICE)` when the network is built.
DEVICE = set_device()

In [None]:
#point this towards the top level dataset directory
rootDir = '../handwritingBCIData/'
outDir = 'output/'
#train an RNN using data from these specified sessions
dataDirs = ['t5.2019.05.08','t5.2019.11.25','t5.2019.12.09','t5.2019.12.11','t5.2019.12.18',
            't5.2019.12.20','t5.2020.01.06','t5.2020.01.08','t5.2020.01.13','t5.2020.01.15']


#use this train/test partition 
cvPart = 'HeldOutTrials'

#name of the directory where this RNN run will be saved
rnnOutputDir = cvPart

## parameters
args = getDefaultRNNArgs(rootDir, cvPart, outDir)
#Configure the arguments for a multi-day RNN (that will have a unique input layer for each day)
for x, sessionName in enumerate(dataDirs):
    dayKey = str(x)
    args['sentencesFile_'+dayKey] = rootDir+'Datasets/'+sessionName+'/sentences.mat'
    args['singleLettersFile_'+dayKey] = rootDir+'Datasets/'+sessionName+'/singleLetters.mat'
    args['labelsFile_'+dayKey] = rootDir+'RNNTrainingSteps/Step2_HMMLabels/'+cvPart+'/'+sessionName+'_timeSeriesLabels.mat'
    args['syntheticDatasetDir_'+dayKey] = rootDir+'Datasets/'+sessionName+'/'+cvPart+'/'+sessionName+'_syntheticSentences/'
    args['cvPartitionFile_'+dayKey] = rootDir+'RNNTrainingSteps/trainTestPartitions_'+cvPart+'.mat'
    args['sessionName_'+dayKey] = sessionName

#count the configured days: the number of consecutive 'labelsFile_<k>' keys starting at k=0.
#(no upper bound, so 'nDays' is always set, however many sessions are listed)
nDays = 0
while 'labelsFile_'+str(nDays) in args.keys():
    nDays += 1
args['nDays'] = nDays

#create the output directory, including any missing parent directories
os.makedirs(args['outputDir'], exist_ok=True)

#this weights each day equally (1/nDays probability for each day) and allocates a unique input layer
#for each day (0..nDays-1). Both values are stored as strings such as '[0.1,0.1,...]' and '[0,1,...]',
#derived from the day count so they stay in sync with dataDirs.
dayWeight = 1.0/nDays if nDays else 0.0
args['dayProbability'] = '[' + ','.join([str(dayWeight)]*nDays) + ']'
args['dayToLayerMap'] = '[' + ','.join(str(d) for d in range(nDays)) + ']'
# args['verbose'] = True ## extra print-out information

args['mode'] = 'train' ## make sure it is set in 'train' mode
print('batchSize:', args['batchSize'])
print('synthBatchSize:', args['synthBatchSize'])
args['ForTestingOnly'] = False ## FOR DEBUGING. set "self.nDays = 2" (use 2 days of data for testing run)


## torch dataloaders

In [None]:
# Per-day loader lists; entries are appended in the order the days are processed below,
# so position 0 belongs to the first day of the loop, not necessarily to day 0.
allSynthDataLoaders, allRealDataLoaders = [], []
allValDataLoaders, daysWithValData = [], []
args['isTraining'] = True
for dayIdx in [2,3]: #range(args['nDays']):
    daySuffix = str(dayIdx)

    ## real data
    print('Loading real data ', dayIdx)
    neuralData, targets, errWeights, binsPerTrial, cvIdx = prepareDataCubesForRNN(
        args['sentencesFile_'+daySuffix],
        args['singleLettersFile_'+daySuffix],
        args['labelsFile_'+daySuffix],
        args['cvPartitionFile_'+daySuffix],
        args['sessionName_'+daySuffix],
        args['rnnBinSize'],
        args['timeSteps'],
        args['isTraining'])
    # real trials fill whatever part of the batch is not taken by synthetic data
    realDataSize = args['batchSize'] - args['synthBatchSize']
    trainIdx = cvIdx['trainIdx']
    valIdx = cvIdx['testIdx']

    print('create real dataset ', dayIdx)
    realData_train = handBCI_Dataset(
        args,
        neuralData[trainIdx,:,:],
        targets[trainIdx,:,:],
        errWeights[trainIdx,:],
        binsPerTrial[trainIdx,np.newaxis],
        addNoise=True)
    realDataTrain_Loader = torch.utils.data.DataLoader(
        realData_train, batch_size=realDataSize, shuffle=True, num_workers=0)

    if len(valIdx) > 0:
        # held-out trials exist for this day: build a noise-free validation set from them
        realData_val = handBCI_Dataset(
            args,
            neuralData[valIdx,:,:],
            targets[valIdx,:,:],
            errWeights[valIdx,:],
            binsPerTrial[valIdx,np.newaxis],
            addNoise=False)
        realDataVal_Loader = torch.utils.data.DataLoader(
            realData_val, batch_size=args['batchSize'], shuffle=True, num_workers=0)
        daysWithValData.append(dayIdx)
    else:
        # no held-out trials: the training loader doubles as the validation loader
        realDataVal_Loader = realDataTrain_Loader
    allRealDataLoaders.append(realDataTrain_Loader)
    allValDataLoaders.append(realDataVal_Loader)

    ## synthetic data
    print('processing sythetic data ', dayIdx)
    synthDir = args['syntheticDatasetDir_'+daySuffix]
    synth_obj = handBCI_SythDataset(synthDir, args)
    synth_ds = synth_obj.makeDataSet()
    synth_loader = torch.utils.data.DataLoader(synth_ds, batch_size=args['synthBatchSize'])
    allSynthDataLoaders.append(synth_loader)
  

In [None]:
class charSeqNet(nn.Module):
    """Two stacked single-layer GRUs with a separate linear input layer per day.

    Data flow in `forward`: day-specific Linear(192 -> 192) -> GRU -> GRU ->
    Linear(nUnits -> 31) -> log-softmax over the last (output) axis.
    """

    def __init__(self, args):
        """Build the layers from the run configuration.

        Args:
            args: dict-like configuration. Keys read here: 'nDays', 'seed',
                'drop_prob', 'nUnits', 'rnnBinSize'. `forward` also reads
                'smoothInputs'. If args['seed'] == -1 it is overwritten in
                place with the current microsecond count.
        """
        super(charSeqNet, self).__init__()
        #count how many days of data are specified
        self.nDays = args['nDays']
        self.args = args
        if self.args['seed']==-1:
            # Local import: the notebook's import cell does not bring in datetime.
            from datetime import datetime
            self.args['seed']=datetime.now().microsecond
        drop_prob = args['drop_prob']
        #define the dimensions of layers in the RNN
        nOutputs = 31
        nInputs = 192
        nUnits = args['nUnits']
        self.rnnBinSize = args['rnnBinSize']

        # One input layer per day. They must live in an nn.ModuleList (not a plain list)
        # so that their weights appear in parameters()/state_dict() and follow .to(device).
        # expected input shape (per the original note): [args['batchSize'], args['timeSteps'], nInputs]
        self.inputLayers = nn.ModuleList(
            [nn.Linear(nInputs, nInputs, bias = True) for _ in range(self.nDays)]
        )

        # NOTE(review): with num_layers=1 PyTorch documents `dropout` as having no effect
        # (and warns when it is non-zero) — confirm whether an explicit nn.Dropout is wanted.
        self.gru1 = torch.nn.GRU(nInputs, nUnits, 1, \
                                 batch_first=True, dropout=drop_prob)

        self.gru2 = torch.nn.GRU(nUnits, nUnits, 1, \
                                 batch_first=True, dropout=drop_prob)

        self.fc1 = nn.Linear(nUnits, nOutputs, bias = True)

        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x, dayIdx):
        """Run one batch through the network.

        Args:
            x: input features whose last axis has size 192.
            dayIdx: index of the day-specific input layer to use.

        Returns:
            Log-probabilities, normalised over the last (output) axis.
        """
        if self.args['smoothInputs']==1: ## smooth
            x = torch.Tensor(gaussSmooth(x, kernelSD=4/self.rnnBinSize))
        # Follow whatever device the module currently lives on instead of assuming CUDA.
        device = self.fc1.weight.device
        x = self.inputLayers[dayIdx](x.to(device)) ## day specific input layer
        x, _ = self.gru1(x)
        ## TODO: downsample x's time dimension
        x, _ = self.gru2(x)
        ## TODO: upsample
        x = self.fc1(x)
        # fc1 puts the output classes on the last axis, so normalise there (not over axis 1).
        return F.log_softmax(x, dim=-1)

    def fit(self, nSteps=100, dataLoaders=None):
        """Run `nSteps` gradient steps on mini-batches of real data.

        Args:
            nSteps: number of mini-batch updates.
            dataLoaders: list of per-day loaders yielding dicts with 'inputs' and
                'labels'; defaults to the notebook-level `allRealDataLoaders`.
        """
        ## TO be continued ...
        if dataLoaders is None:
            dataLoaders = allRealDataLoaders
        device = self.fc1.weight.device
        # Without an optimizer step the gradients only accumulate and nothing is learned.
        # TODO: learning rate is Adam's default — replace with the configured schedule.
        optimizer = torch.optim.Adam(self.parameters())
        for _ in range(nSteps):
            dayIdx = np.random.randint(self.nDays)
            miniBatch = next(iter(dataLoaders[0])) ## change 0 to dayIdx in full mode
            X = miniBatch['inputs']
            target = miniBatch['labels'].to(device)
            # TODO: miniBatch['errWeights'] is not used in the loss yet.
            optimizer.zero_grad()
            output = self(X, dayIdx)
            ## output is 32 dimensions, target is 31?
            # NOTE(review): nn.CrossEntropyLoss treats axis 1 as the class axis, while
            # `output` carries classes on its last axis — permute once the label layout is confirmed.
            loss = self.criterion(output, target)
            loss.backward()
            optimizer.step()

    def train(self, mode=None):
        """Run the training loop, or switch train/eval mode when `mode` is given.

        A bare `train()` keeps this class's original meaning and runs `fit()`.
        Passing a bool delegates to `nn.Module.train(mode)`; this is what makes
        `.eval()` (which calls `train(False)`) work. Note that, unlike a plain
        nn.Module, a bare `train()` does not change the mode; use `train(True)`.
        """
        if mode is None:
            return self.fit()
        return super(charSeqNet, self).train(mode)

In [None]:
# miniBatch = next(iter(allRealDataLoaders[0]))

In [None]:
## Build the network, move it to DEVICE, and report how many parameters it registers
charSeq_net = charSeqNet(args).to(DEVICE)
nParams = sum(p.numel() for p in charSeq_net.parameters())
print("Total Parameters in Network {:10d}".format(nParams))

In [None]:
# Starts the training loop: charSeqNet defines its own `train`, so this bare call
# is not just nn.Module's train-mode switch.
charSeq_net.train()