In [32]:
from sklearn.linear_model import LinearRegression, Ridge, Lasso, ElasticNet
import numpy as np 
import torch
import pickle

In [26]:
# The TRF can be perceived as follows.
# HGA shape: (B, C, T), batch size, num channels, num time points.
# X shape: (B, H, T). H represents number of features. 


# For each ECoG channel i:
#   Train one regressor model.
#   For each time point t:
#       Collect the delayed features X[:, :, t-k] for every delay k = 1, ..., d.
#       HGA[:, i, t] = model.predict(those d delayed feature vectors, concatenated)

# To efficiently train the model above:
# Each model takes in: HGA[:, i, :] of shape (B, 1, T)
# X: (B, H, T)

# Reshape:  HGA => (B*T, 1); X => (B*T, H)
# Generate time delayed x features: X => (B*T, H* d)
# Train the model regressing (B*T, H*d) => (B*T, 1)
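# Resulting model weights: H*d coefficients per channel. The delayed features are
# stacked delay-major (all H features for delay 1, then all H for delay 2, ...),
# so reshape the weights to (d, H) and transpose to get the (H, d) layout noted in the paper.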

In [33]:
def load_pkl(filename):
    """Load and return the object stored in a pickle file.

    Args:
        filename: Path of the pickle file to read.

    Returns:
        The unpickled object.
    """
    # NOTE: pickle.load can execute arbitrary code; only use this on files
    # that come from a trusted source (e.g. written by save_pkl).
    # The `with` block closes the file on exit, so no explicit close() is needed.
    with open(filename, "rb") as f:
        return pickle.load(f)

def save_pkl(data, file):
    """Pickle ``data`` to the path ``file`` and print a confirmation message.

    Args:
        data: The object to serialize.
        file: Destination path. It is opened in binary write mode, so an
            existing file at that path is overwritten.
    """
    with open(file, mode='wb') as sink:
        pickle.dump(data, sink)

    confirmation = f"Data succesfully saved to {file}!"
    print(confirmation)

In [27]:
# TODO: Replace these placeholder dimensions with the actual data shape.
# B: batch size, C: number of ECoG channels, T: number of time points.
B = 8
C = 256
T = 100
# H: number of features in X.
H = 1024
# d: number of time delays used to build the regressors.
d = 60

# Small configuration for debugging.
# B, C, T = 1, 2, 100
# H = 4
# d = 10

# TODO: Load the data. Uniform random tensors stand in for the real recordings.
HGA = torch.rand(B, C, T)
X = torch.rand(B, H, T)

# Generate time delayed x features.
# The delayed design matrix depends only on X, not on the channel, so it is
# built once here instead of being rebuilt on every iteration of the loop.
# TODO: note that I'm assuming 0s for x if t - k < 0.
# TODO: note that I'm also assuming d = 1 means 1 delay in X. (i.e. d == frequency of X)
x_delayed = []
for k in range(1, d + 1):
    # Shift X right by k samples along time, with zeros appended to the front.
    # The pad is capped at T so that a delay longer than the recording gives
    # an all-zero feature block rather than a tensor with the wrong length.
    pad = min(k, T)
    feature = torch.concatenate((X.new_zeros((B, H, pad)), X[:, :, :T - pad]), dim=-1)
    x_delayed.append(feature)

# (B, H * d, T). Features are stacked delay-major along dim 1:
# index (k - 1) * H + h holds feature h at delay k.
x_delayed = torch.concatenate(x_delayed, dim=1)
assert x_delayed.shape == (B, H * d, T)

# (B*T, H*d), converted explicitly to a NumPy array for scikit-learn.
x_delayed = x_delayed.transpose(1, -1).reshape(B * T, H * d).detach().cpu().numpy()

for i in range(C):
    # Target for channel i: (B, T) flattened to (B*T, 1). Rows are ordered
    # batch-major (b * T + t), matching the rows of x_delayed.
    hga = HGA[:, i, :].reshape(B * T, 1).detach().cpu().numpy()

    # Train the model.
    # In the paper it used ridge.
    # A list of possible models: LinearRegression, Ridge, Lasso, ElasticNet
    model = Ridge(alpha=0.1)

    # TODO: add k-fold & train-val-test split if necessary.
    model.fit(x_delayed, hga)

    # Save the model.
    save_pkl(model, f"model_{i}.pkl")

    # To load the model:
    # model = load_pkl(f"model_{i}.pkl")

    # Delete the 'break' below to train every channel; with it in place only
    # channel 0 is fitted and saved.
    break