In [1]:
import torch, torchvision
import pandas as pds
import numpy as np
import os, time, copy
import matplotlib.pyplot as plt

### Data Formats
* Input: (lag, 135 * (262/lag), freq)
    * 135: Number of participants
    * 262: Feedback onset time (timestep 888 - timestep 626)
    * lag: lag time hyperparameter
    * freq: frequencies hyperparameter chosen for number of samples
    
* Output: (lag, 135 * (262/lag), 4)
    * same values for a single batch
    * 4 theta bands (4, 5, 6, 7)
    
* LNR evoked data
    * FC1
    * FC2

### Model Architectures
* LSTM
* RNN
* GRU

In [4]:
def get_dataset(data_filepath):
    data = np.load(data_filepath)

    tensor_x = torch.Tensor(data['x'])
    tensor_y = torch.Tensor(data['y'])

    dataset = TensorDataset(tensor_x,tensor_y)

    return dataset

In [2]:
inp = torch.from_numpy(np.load('/home/matt/eeg_data/eeg_dataset_training2/x.npy'))
outputs = torch.from_numpy(np.load('/home/matt/eeg_data/eeg_dataset_training2/y.npy'))

lag = 5
# batch_size = 135*(262/lag)
batch_size = 52#26416
freq = 8
hidden_size = 20
num_layers = 1
output_bands = 4
hidden_state = torch.randn(1, batch_size, hidden_size, dtype=torch.double)
cell_state = torch.randn(1, batch_size, hidden_size, dtype=torch.double)

training_dataset = get_dataset('/home/matt/eeg_data/eeg_dataset_training2.npz')
training_loader = DataLoader(dataset=training_dataset,batch_size=batch_size,shuffle=True)

In [3]:
"""
LSTM
"""
lstm = torch.nn.LSTM(freq, hidden_size=hidden_size, num_layers=num_layers, batch_first=True).double()
l_output, (l_hn, l_cn) = lstm(inp, (hidden_state, cell_state))
squeezed_l_out = torch.squeeze(l_output[:,4,:], 1)
l_fcc = torch.nn.Linear(squeezed_l_out.shape[1], output_bands).double()
# print(squeezed_l_out)
l_out = l_fcc(squeezed_l_out)
print(l_out)

RuntimeError: Expected hidden[0] size (1, 26416, 20), got [1, 52, 20]

In [None]:
"""
RNN
"""
rnn = torch.nn.RNN(freq, hidden_size=hidden_size, num_layers=num_layers, batch_first=True).double()
r_output, r_hn = rnn(inp, hidden_state)
squeezed_r_out = torch.squeeze(r_output[:,4,:], 1)
r_fcc = torch.nn.Linear(squeezed_r_out.shape[1], output_bands).double()
r_out = r_fcc(squeezed_r_out)
print(r_out)

In [None]:
"""
GRU
"""
gru = torch.nn.GRU(freq, hidden_size=hidden_size, num_layers=num_layers, batch_first=True).double()
g_output, g_hn = gru(inp, hidden_state)
squeezed_g_out = torch.squeeze(g_output[:,4,:], 1)
g_fcc = torch.nn.Linear(squeezed_g_out.shape[1], output_bands).double()
g_out = g_fcc(squeezed_g_out)
print(g_out)