In [4]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, ConcatDataset
from torch_geometric.loader import DataLoader as GeometricDataLoader
from torch_geometric.data import Data
from torch_geometric.utils import dense_to_sparse
import torch_geometric.nn as pyg_nn
import os

from scipy.io import loadmat, savemat

In [5]:
# Load the saved beta data and the adjacency matrix from their .mat files
analysis_type = 'N170'

# loadmat returns a dict keyed by MATLAB variable name; keep the raw dicts around
datain_mat = loadmat(f'{analysis_type}_data.mat')
binatry_matrix_mat = loadmat('adjacency_matrix.mat')

# Pull out the arrays of interest as NumPy arrays
datain = np.array(datain_mat['data'])  # shape: [nBeta, nTime, nChan, nSubject] (adjust as needed)
binatry_matrix = np.array(binatry_matrix_mat['adjacency_matrix'])

# Size of each of the four leading axes of `datain`
nBeta, nTime, nChan, nSubject = (
    datain.shape[0],
    datain.shape[1],
    datain.shape[2],
    datain.shape[3],
)

In [None]:
# ===========================================
# 2) Set device & convert adjacency
# ===========================================
# Use the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Convert the dense adjacency matrix to sparse edge_index form;
# the second return value (edge attributes) is discarded
edge_index, _ = dense_to_sparse(
    torch.tensor(binatry_matrix, dtype=torch.float32)
)
print(device)

# ===========================================
# 3) Dataset builder
# ===========================================
def build_concat_dataset_for_beta(datain, beta_idx):
    """
    Build a ConcatDataset across all subjects for one beta index.

    Parameters
    ----------
    datain : array indexed as [nBeta, nTime, nChan, nSubject]
        The subject count is taken from the last of these four axes.
    beta_idx : int
        Index along the first axis selecting which beta to extract.

    Returns
    -------
    ConcatDataset
        One single-sample TensorDataset per subject, so its length equals
        ``datain.shape[3]``. Each item is a pair ``(x, x)`` (autoencoder
        setup: input == target) where ``x`` is a float32 tensor of shape
        [nChan, nTime].
    """
    # Take the subject count from the array that was actually passed in,
    # not from a notebook-level variable, so the function also works for
    # subsets of subjects and does not depend on another cell's state.
    n_subjects = datain.shape[3]

    subject_datasets = []
    for s in range(n_subjects):
        # Slice for one subject and one beta, shape [nTime, nChan], then
        # transpose so each row is a channel (node) and each column a time
        # point (feature): shape [nChan, nTime].
        single_beta_data = datain[beta_idx, :, :, s].T

        # Leading axis of size 1 makes this a one-sample TensorDataset whose
        # single item has shape [nChan, nTime].
        x_tensor = torch.tensor(single_beta_data, dtype=torch.float32).unsqueeze(0)

        # For an autoencoder, features == labels
        subject_datasets.append(TensorDataset(x_tensor, x_tensor))

    return ConcatDataset(subject_datasets)