In [2]:
import torch, torchvision
import pandas as pds
import numpy as np
import os, time, copy
import matplotlib.pyplot as plt

### Data Formats
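* Input: (lag, 135 * (262/lag), freq), i.e. (sequence length, batch, features) as PyTorch recurrent layers expect with `batch_first=False`
    * 135: Number of participants
    * 262: Number of timesteps in the feedback window (timestep 888 - timestep 626)
    * lag: lag time hyperparameter (the sequence length fed to the recurrent layer)
    * freq: number of frequencies per sample, used as the input feature size (hyperparameter)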
    
* Output: (lag, 135 * (262/lag), 4)
    * Target values are the same within a single batch
    * 4: theta bands at 4, 5, 6 and 7 Hz
    
* LNR evoked data
    * FC1
    * FC2

### Model Architectures
* LSTM
* RNN
* GRU

In [None]:
# Configuration and data loading.
# Produces `inputs` as a float32 tensor of shape (lag, batch_size, freq): the
# (seq_len, batch, input_size) layout nn.LSTM / nn.RNN / nn.GRU expect when
# batch_first=False. Also produces the initial hidden/cell states used by the model cells.
SEED = 0
torch.manual_seed(SEED)  # hidden_state / cell_state below are random; seed so runs are reproducible

# Overridable via environment variable so the notebook is not tied to one machine.
PATH_TO_DATA = os.environ.get('EEG_POWER_PATH', '/home/matt/eeg_data/power')

N_PARTICIPANTS = 135
N_TIMESTEPS = 262  # timestep 888 - timestep 626

lag = 5
if (N_PARTICIPANTS * N_TIMESTEPS) % lag:
    raise ValueError(
        f"lag={lag} must evenly divide {N_PARTICIPANTS} * {N_TIMESTEPS} timesteps"
    )
# Integer division: torch.randn and reshape need int sizes, whereas the original
# `135*(262/lag)` produced a float. The value is unchanged whenever it was integral.
batch_size = N_PARTICIPANTS * N_TIMESTEPS // lag
freq = 5
hidden_size = 20
num_layers = 1
output_bands = 4

# pandas has no `load_csv`; the loader is `read_csv`. Convert to float32 because the
# recurrent layers' weights are float32 and reject NumPy arrays / float64 tensors.
# NOTE(review): assumes the CSV (with a header row, pandas' default) holds exactly
# lag * batch_size * freq numeric values whose row-major order matches
# (lag, batch, freq) -- TODO confirm against the script that exported the data.
inputs = torch.as_tensor(
    pds.read_csv(PATH_TO_DATA).to_numpy(), dtype=torch.float32
).reshape(lag, batch_size, freq)

# Initial states are (num_layers, batch, hidden_size), as the recurrent layers require.
hidden_state = torch.randn(num_layers, batch_size, hidden_size)
cell_state = torch.randn(num_layers, batch_size, hidden_size)

In [None]:
"""
LSTM
"""
lstm = torch.nn.LSTM(freq, hidden_size, num_layers)
l_output, (l_hn, l_cn) = lstm(inputs, (hidden_state, cell_state))
l_fcc = torch.nn.Linear(l_output.shape[0], output_bands)
print(l_fcc)

In [None]:
"""
RNN
"""
rnn = torch.nn.RNN(freq, hidden_size, num_layers)
r_output, r_hn = rnn(inputs, hidden_state)
r_fcc = torch.nn.Linear(r_output.shape[0], output_bands)
print(r_fcc)

In [None]:
"""
GRU
"""
gru = torch.nn.GRU(freq, hidden_size, num_layers)
g_output, g_hn = gru(inputs, hidden_state)
g_fcc = torch.nn.Linear(r_output.shape[0], output_bands)
print(g_fcc)