### Installing dependencies

                                                    Project details                                                                
                                 
The data represents various brain activities: resting, math & story tasks, working memory, and motor tasks.

    The 'Intra' folder contains data from one subject, while the 'Cross' folder includes multiple subjects.

Each file is a matrix of shape 248 x 35624, where 248 represents the number of sensors, and 35624 represents time steps.

The files have the following format: “taskType subjectIdentifier number.h5”
where taskType can be rest, task motor, task story math, and task working memory.

In practice, these tasks correspond to the activities performed by the subjects:

    • Resting Task
Recording the subjects’ brain while in a relaxed resting
state.

    • Math & Story Task
Subject performs mental calculation and language
processing task.

    • Working Memory task
Subject performs a memorization task.

    • Motor Task
Subject performs a motor task, typically moving fingers
or feets

In [2]:
import h5py
import torch
import torch.nn as nn
import torch.optim as optim
from torch import FloatTensor, LongTensor
from typing import Tuple, List, Callable, Optional
from sklearn.metrics import accuracy_score
import os
import numpy as np
from tqdm import tqdm
import random
import pandas as pd

Reading data:

In [3]:
def get_dataset_name(file_name_with_dir):
    filename_without_dir = file_name_with_dir.split('/')[-1]
    temp = filename_without_dir.split('_')[:-1]
    dataset_name = "_".join(temp)
    return dataset_name

## Functions for data preprocessing

In [4]:
# min-max scaling
def minmax(trial):
    min = trial.min()
    max = trial.max()
    normalisedTrial = (trial - min)/(max-min)
    return normalisedTrial

#Z-score normalisation OPTIONAL
def zscore(trial):
    mean = trial.mean()
    sd = trial.std()
    normalisedTrial = (trial - mean)/sd 
    return normalisedTrial

#downsamples data by totaltimesteps/factor
def downsample(trial, factor):
    ds_trial = trial[:,::factor]
    return ds_trial



In [5]:
def preprocess_files(files = None, path = 'Final Project data/Cross/train', downsampling = 30):
    label_to_int = {'rest': 0, 'task_motor': 1, 'task_story_math': 2, 'task_working_memory': 3}

    cross_data_train = [] # Store data
    cross_data_train_labels = [] # Store labels (based on filename)

    if files == None:
        files = os.listdir(path)

    for file in files:
        file_path = f'{path}/{file}'
        
        with h5py.File(file_path, 'r') as h5_file:
            # obtain labels
            dataset_name = get_dataset_name(file_path)
            label = dataset_name.split('_')
            label.remove(label[len(label)-1])
            label = '_'.join(label)
            cross_data_train_labels.append(label_to_int[label])
            
            # obtain X_data
            matrix = h5_file.get(dataset_name)[()]
            normalisedMatrix = downsample(zscore(matrix), downsampling) # apply minmax normalisation and downsampling
            cross_data_train.append(normalisedMatrix.T) # Transpose
             
    X = torch.from_numpy(np.array(cross_data_train)).float()
    y = torch.tensor(cross_data_train_labels)        
            
    return X, y

## RNN model

In [6]:
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()
        
        # RNN layer
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        
        # Fully connected layer
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # Forward pass through RNN
        rnn, _ = self.rnn(x)
        
        # Only take the output from the final time step
        output = self.fc(rnn[:, -1, :])
        return output
    

In [7]:
# hyperparams: lr, hidden_size, downsampling
def train(path, lr = 0.001, hidden_size = 200, downsampling = 30, print_results = True):
    random.seed = 123 # Set seed for reproducability
    input_size = 248
    output_size = 4
    network = RNN(input_size, hidden_size, output_size)
    loss_fn = nn.CrossEntropyLoss()
    opt = optim.Adam(network.parameters(), lr=lr)

    files = os.listdir(path)
    random.shuffle(files) # Shuffle order of files
    current_samples = []
    n = 8
    batch_index = 1

    for i, file in tqdm(enumerate(files)):
        current_samples.append(file)
        if len(current_samples) == n or i == (len(files)-1):
            if print_results:
                print(f"training batch {batch_index}...")
            X_train, y_train = preprocess_files(current_samples, path=path, downsampling=downsampling) 
            current_samples = []
            
            network.train()
            opt.zero_grad()
            output = network(X_train)
            loss = loss_fn(output, y_train)
            loss.backward()
            opt.step()
            
            batch_index += 1
    return network

In [8]:
# testing:
def test(network, paths):
    for path in paths:
        files = os.listdir(path)
        X, y = preprocess_files(files, path, 1)
        network.eval()

        test_output = network(X).detach().numpy()
        pred = np.argmax(test_output, axis=1) # to numpy
        y = y.numpy()
        return accuracy_score(pred, y)


In [9]:
def tune_hyperparams_rnn(data_type, lr_list, hidden_size_list, downsampling_list):
    path_training = f'Final Project data/{data_type}/train'
    if data_type == 'Cross':
        paths_testing = [ 'Final Project data/Cross/test1',  'Final Project data/Cross/test2',  'Final Project data/Cross/test3']
    else:
        paths_testing = ['Final Project data/Intra/test']
        
    results = pd.DataFrame(columns = ['lr', 'hidden_size', 'downsampling', 'acc'])
    
    for lr in lr_list:
        for hidden_size in hidden_size_list:
            for downsampling in downsampling_list:
                network = train(path_training, lr=lr, hidden_size=hidden_size, downsampling=downsampling, print_results=False)
                acc = test(network=network, paths=paths_testing)
                results = pd.concat([results, pd.DataFrame({'lr':[lr], 'hidden_size':[hidden_size], 'downsampling':[downsampling], 'acc':[acc]})])
    results.to_csv(f'results/{data_type}_results.csv')
    return results

In [18]:
tune_hyperparams_rnn('Cross', lr_list=[0.01, 0.001, 0.0001], hidden_size_list=[100, 150, 200, 250], downsampling_list=[1, 5, 15, 30])

0it [00:00, ?it/s]

training batch 1...


8it [00:04,  1.61it/s]

training batch 2...


16it [00:09,  1.69it/s]

training batch 3...


24it [00:14,  1.67it/s]

training batch 4...


32it [00:19,  1.57it/s]

training batch 5...


40it [00:25,  1.54it/s]

training batch 6...


48it [00:30,  1.52it/s]

training batch 7...


56it [00:36,  1.52it/s]

training batch 8...


64it [00:41,  1.54it/s]
  results = pd.concat([results, pd.DataFrame({'lr':[lr], 'hidden_size':[hidden_size], 'downsampling':[downsampling], 'acc':[acc]})])
0it [00:00, ?it/s]

training batch 1...


8it [00:03,  2.43it/s]

training batch 2...


16it [00:06,  2.36it/s]

training batch 3...


24it [00:09,  2.42it/s]

training batch 4...


32it [00:13,  2.38it/s]

training batch 5...


40it [00:16,  2.38it/s]

training batch 6...


48it [00:20,  2.37it/s]

training batch 7...


56it [00:23,  2.38it/s]

training batch 8...


64it [00:26,  2.37it/s]
0it [00:00, ?it/s]

training batch 1...


8it [00:03,  2.57it/s]

training batch 2...


16it [00:06,  2.68it/s]

training batch 3...


24it [00:08,  2.74it/s]

training batch 4...


32it [00:11,  2.77it/s]

training batch 5...


40it [00:14,  2.75it/s]

training batch 6...


48it [00:17,  2.72it/s]

training batch 7...


56it [00:20,  2.71it/s]

training batch 8...


64it [00:23,  2.70it/s]
0it [00:00, ?it/s]

training batch 1...


8it [00:06,  1.21it/s]

training batch 2...


16it [00:12,  1.24it/s]

training batch 3...


24it [00:19,  1.25it/s]

training batch 4...


32it [00:25,  1.24it/s]

training batch 5...


40it [00:32,  1.24it/s]

training batch 6...


48it [00:39,  1.22it/s]

training batch 7...


56it [00:45,  1.22it/s]

training batch 8...


64it [00:52,  1.22it/s]
0it [00:00, ?it/s]

training batch 1...


8it [00:04,  1.86it/s]

training batch 2...


16it [00:08,  1.97it/s]

training batch 3...


24it [00:11,  2.07it/s]

training batch 4...


32it [00:15,  2.08it/s]

training batch 5...


40it [00:19,  2.11it/s]

training batch 6...


48it [00:23,  2.08it/s]

training batch 7...


56it [00:27,  2.10it/s]

training batch 8...


64it [00:30,  2.08it/s]
0it [00:00, ?it/s]

training batch 1...


8it [00:03,  2.19it/s]

training batch 2...


16it [00:06,  2.41it/s]

training batch 3...


24it [00:10,  2.38it/s]

training batch 4...


32it [00:14,  2.25it/s]

training batch 5...


40it [00:17,  2.32it/s]

training batch 6...


48it [00:20,  2.32it/s]

training batch 7...


56it [00:24,  2.32it/s]

training batch 8...


64it [00:27,  2.32it/s]
0it [00:00, ?it/s]

training batch 1...


8it [00:08,  1.02s/it]

training batch 2...


16it [00:16,  1.06s/it]

training batch 3...


24it [00:26,  1.12s/it]

training batch 4...


32it [00:36,  1.15s/it]

training batch 5...


40it [00:44,  1.12s/it]

training batch 6...


48it [00:52,  1.06s/it]

training batch 7...


56it [00:59,  1.03s/it]

training batch 8...


64it [01:07,  1.05s/it]
0it [00:00, ?it/s]

training batch 1...


8it [00:04,  1.80it/s]

training batch 2...


16it [00:08,  1.84it/s]

training batch 3...


24it [00:12,  1.88it/s]

training batch 4...


32it [00:19,  1.55it/s]

training batch 5...


40it [00:23,  1.67it/s]

training batch 6...


48it [00:28,  1.70it/s]

training batch 7...


56it [00:32,  1.73it/s]

training batch 8...


64it [00:36,  1.73it/s]
0it [00:00, ?it/s]

training batch 1...


8it [00:03,  2.05it/s]

training batch 2...


16it [00:07,  2.24it/s]

training batch 3...


24it [00:10,  2.31it/s]

training batch 4...


32it [00:13,  2.34it/s]

training batch 5...


40it [00:17,  2.24it/s]

training batch 6...


48it [00:21,  2.13it/s]

training batch 7...


56it [00:25,  2.13it/s]

training batch 8...


64it [00:29,  2.17it/s]
0it [00:00, ?it/s]

training batch 1...


8it [00:05,  1.37it/s]

training batch 2...


16it [00:11,  1.42it/s]

training batch 3...


24it [00:20,  1.08it/s]

training batch 4...


32it [00:25,  1.27it/s]

training batch 5...


40it [00:30,  1.39it/s]

training batch 6...


48it [00:35,  1.47it/s]

training batch 7...


56it [00:39,  1.53it/s]

training batch 8...


64it [00:44,  1.44it/s]
0it [00:00, ?it/s]

training batch 1...


8it [00:02,  2.79it/s]

training batch 2...


16it [00:05,  2.80it/s]

training batch 3...


24it [00:08,  2.85it/s]

training batch 4...


32it [00:11,  2.86it/s]

training batch 5...


40it [00:14,  2.84it/s]

training batch 6...


48it [00:17,  2.78it/s]

training batch 7...


56it [00:19,  2.78it/s]

training batch 8...


64it [00:23,  2.77it/s]
0it [00:00, ?it/s]

training batch 1...


8it [00:03,  2.03it/s]

training batch 2...


16it [00:07,  2.29it/s]

training batch 3...


24it [00:10,  2.30it/s]

training batch 4...


32it [00:13,  2.43it/s]

training batch 5...


40it [00:16,  2.48it/s]

training batch 6...


48it [00:19,  2.49it/s]

training batch 7...


56it [00:23,  2.50it/s]

training batch 8...


64it [00:26,  2.45it/s]
0it [00:00, ?it/s]

training batch 1...


8it [00:07,  1.03it/s]

training batch 2...


16it [00:14,  1.09it/s]

training batch 3...


24it [00:21,  1.15it/s]

training batch 4...


32it [00:27,  1.17it/s]

training batch 5...


40it [00:34,  1.19it/s]

training batch 6...


48it [00:41,  1.19it/s]

training batch 7...


56it [00:47,  1.20it/s]

training batch 8...


64it [00:54,  1.18it/s]
0it [00:00, ?it/s]

training batch 1...


8it [00:04,  1.68it/s]

training batch 2...


16it [00:08,  1.90it/s]

training batch 3...


24it [00:13,  1.76it/s]

training batch 4...


32it [00:18,  1.67it/s]

training batch 5...


40it [00:24,  1.58it/s]

training batch 6...


48it [00:28,  1.63it/s]

training batch 7...


56it [00:33,  1.61it/s]

training batch 8...


64it [00:38,  1.65it/s]
0it [00:00, ?it/s]

training batch 1...


8it [00:04,  1.99it/s]

training batch 2...


16it [00:07,  2.20it/s]

training batch 3...


24it [00:10,  2.29it/s]

training batch 4...


32it [00:13,  2.37it/s]

training batch 5...


40it [00:17,  2.44it/s]

training batch 6...


48it [00:20,  2.44it/s]

training batch 7...


56it [00:23,  2.42it/s]

training batch 8...


64it [00:27,  2.35it/s]
0it [00:00, ?it/s]

training batch 1...


8it [00:08,  1.08s/it]

training batch 2...


16it [00:19,  1.26s/it]

training batch 3...


24it [00:29,  1.22s/it]

training batch 4...


32it [00:37,  1.16s/it]

training batch 5...


40it [00:45,  1.10s/it]

training batch 6...


48it [00:53,  1.07s/it]

training batch 7...


56it [01:02,  1.06s/it]

training batch 8...


64it [01:10,  1.10s/it]
0it [00:00, ?it/s]

training batch 1...


8it [00:04,  1.68it/s]

training batch 2...


16it [00:08,  1.82it/s]

training batch 3...


24it [00:13,  1.78it/s]

training batch 4...


32it [00:17,  1.83it/s]

training batch 5...


40it [00:22,  1.82it/s]

training batch 6...


48it [00:26,  1.84it/s]

training batch 7...


56it [00:30,  1.86it/s]

training batch 8...


64it [00:34,  1.84it/s]
0it [00:00, ?it/s]

training batch 1...


8it [00:03,  2.08it/s]

training batch 2...


16it [00:07,  2.17it/s]

training batch 3...


24it [00:11,  2.18it/s]

training batch 4...


32it [00:14,  2.25it/s]

training batch 5...


40it [00:17,  2.27it/s]

training batch 6...


48it [00:21,  2.21it/s]

training batch 7...


56it [00:25,  2.23it/s]

training batch 8...


64it [00:28,  2.23it/s]
0it [00:00, ?it/s]

training batch 1...


8it [00:05,  1.40it/s]

training batch 2...


16it [00:11,  1.43it/s]

training batch 3...


24it [00:17,  1.34it/s]

training batch 4...


32it [00:25,  1.23it/s]

training batch 5...


40it [00:31,  1.21it/s]

training batch 6...


48it [00:38,  1.17it/s]

training batch 7...


56it [00:46,  1.16it/s]

training batch 8...


64it [00:52,  1.22it/s]
0it [00:00, ?it/s]

training batch 1...


8it [00:05,  1.50it/s]

training batch 2...


16it [00:11,  1.43it/s]

training batch 3...


24it [00:16,  1.43it/s]

training batch 4...


32it [00:22,  1.39it/s]

training batch 5...


40it [00:29,  1.33it/s]

training batch 6...


48it [00:34,  1.37it/s]

training batch 7...


56it [00:40,  1.40it/s]

training batch 8...


64it [00:45,  1.42it/s]
0it [00:00, ?it/s]

training batch 1...


8it [00:04,  1.73it/s]

training batch 2...


16it [00:09,  1.74it/s]

training batch 3...


24it [00:13,  1.77it/s]

training batch 4...


32it [00:18,  1.76it/s]

training batch 5...


40it [00:22,  1.74it/s]

training batch 6...


48it [00:27,  1.78it/s]

training batch 7...


56it [00:31,  1.78it/s]

training batch 8...


64it [00:36,  1.77it/s]
0it [00:00, ?it/s]

training batch 1...


8it [00:11,  1.42s/it]

training batch 2...


16it [00:22,  1.41s/it]

training batch 3...


24it [00:33,  1.41s/it]

training batch 4...


32it [00:45,  1.43s/it]

training batch 5...


40it [00:57,  1.43s/it]

training batch 6...


48it [01:08,  1.45s/it]

training batch 7...


56it [01:21,  1.49s/it]

training batch 8...


64it [01:33,  1.45s/it]
0it [00:00, ?it/s]

training batch 1...


8it [00:06,  1.19it/s]

training batch 2...


16it [00:13,  1.21it/s]

training batch 3...


24it [00:19,  1.27it/s]

training batch 4...


32it [00:25,  1.24it/s]

training batch 5...


40it [00:32,  1.23it/s]

training batch 6...


48it [00:37,  1.35it/s]

training batch 7...


56it [00:43,  1.31it/s]

training batch 8...


64it [00:50,  1.28it/s]
0it [00:00, ?it/s]

training batch 1...


8it [00:06,  1.29it/s]

training batch 2...


16it [00:12,  1.33it/s]

training batch 3...


24it [00:18,  1.25it/s]

training batch 4...


32it [00:24,  1.32it/s]

training batch 5...


40it [00:30,  1.36it/s]

training batch 6...


48it [00:35,  1.40it/s]

training batch 7...


56it [00:40,  1.42it/s]

training batch 8...


64it [00:46,  1.37it/s]
0it [00:00, ?it/s]

training batch 1...


8it [00:15,  1.98s/it]

training batch 2...


16it [00:31,  1.97s/it]

training batch 3...


24it [00:47,  1.98s/it]

training batch 4...


32it [01:03,  1.99s/it]

training batch 5...


40it [01:19,  1.98s/it]

training batch 6...


48it [01:34,  1.97s/it]

training batch 7...


56it [01:50,  1.97s/it]

training batch 8...


64it [02:07,  1.99s/it]
0it [00:00, ?it/s]

training batch 1...


8it [00:07,  1.06it/s]

training batch 2...


16it [00:16,  1.03s/it]

training batch 3...


24it [00:23,  1.05it/s]

training batch 4...


32it [00:31,  1.01it/s]

training batch 5...


40it [00:39,  1.02it/s]

training batch 6...


48it [00:46,  1.05it/s]

training batch 7...


56it [00:52,  1.11it/s]

training batch 8...


64it [00:59,  1.08it/s]
0it [00:00, ?it/s]

training batch 1...


8it [00:04,  1.63it/s]

training batch 2...


16it [00:09,  1.70it/s]

training batch 3...


24it [00:13,  1.77it/s]

training batch 4...


32it [00:18,  1.81it/s]

training batch 5...


40it [00:22,  1.86it/s]

training batch 6...


48it [00:26,  1.86it/s]

training batch 7...


56it [00:30,  1.85it/s]

training batch 8...


64it [00:35,  1.82it/s]


Unnamed: 0,lr,hidden_size,downsampling,acc
0,0.01,100,5,0.9375
0,0.01,100,15,0.8125
0,0.01,100,30,0.75
0,0.01,200,5,0.625
0,0.01,200,15,0.875
0,0.01,200,30,0.75
0,0.01,300,5,0.4375
0,0.01,300,15,0.3125
0,0.01,300,30,0.3125
0,0.001,100,5,0.875


In [10]:
tune_hyperparams_rnn('Intra', lr_list=[0.01, 0.001, 0.0001], hidden_size_list=[100, 150, 200, 250], downsampling_list=[1, 5, 15, 30])

32it [00:19,  1.62it/s]
  results = pd.concat([results, pd.DataFrame({'lr':[lr], 'hidden_size':[hidden_size], 'downsampling':[downsampling], 'acc':[acc]})])
32it [00:14,  2.23it/s]
32it [00:16,  1.94it/s]
32it [00:14,  2.17it/s]
32it [00:19,  1.60it/s]
32it [00:17,  1.84it/s]
32it [00:18,  1.74it/s]
32it [00:15,  2.00it/s]
32it [00:16,  1.98it/s]
32it [00:15,  2.10it/s]
32it [00:19,  1.67it/s]
32it [00:16,  1.95it/s]
32it [00:19,  1.61it/s]
32it [00:14,  2.19it/s]
32it [00:21,  1.51it/s]
32it [00:15,  2.03it/s]
32it [00:15,  2.10it/s]
32it [00:13,  2.43it/s]
32it [00:16,  1.97it/s]
32it [00:14,  2.25it/s]
32it [00:16,  1.90it/s]
32it [00:15,  2.03it/s]
32it [00:21,  1.48it/s]
32it [00:18,  1.76it/s]


Unnamed: 0,lr,hidden_size,downsampling,acc
0,0.01,100,15,1.0
0,0.01,100,30,1.0
0,0.01,150,15,1.0
0,0.01,150,30,0.75
0,0.01,200,15,1.0
0,0.01,200,30,1.0
0,0.01,250,15,1.0
0,0.01,250,30,0.625
0,0.001,100,15,0.75
0,0.001,100,30,0.75
