In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, ConcatDataset
from torch_geometric.loader import DataLoader as GeometricDataLoader
from torch_geometric.data import Data
from torch_geometric.utils import dense_to_sparse
import torch_geometric.nn as pyg_nn
import os

from scipy.io import loadmat, savemat

In [25]:
# Test - load saved beta data from mat file
analysis_type = 'N170'
datain_mat = loadmat(analysis_type + '_data.mat')
binatry_matrix_mat = loadmat('adjacency_matrix.mat')
datain = datain_mat['data']
binatry_matrix = binatry_matrix_mat['adjacency_matrix']

datain = np.array(datain)          # shape: [nBeta, nTime, nChan, nSubject] (adjust as needed)
binatry_matrix = np.array(binatry_matrix)

# Parse shapes
nBeta     = datain.shape[0]
nTime     = datain.shape[1]
nChan     = datain.shape[2]
nSubject  = datain.shape[3]

In [4]:
# ===========================================
# 2) Set device & convert adjacency
# ===========================================
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
edge_index, _ = dense_to_sparse(torch.tensor(binatry_matrix, dtype=torch.float32))
print(device)

# ===========================================
# 3) Dataset builder
# ===========================================
def build_concat_dataset_for_beta(datain, beta_idx):
    """
    Build a ConcatDataset across all subjects for one beta index.
    `datain` is [nBeta, nTime, nChan, nSubject].
    Returns: ConcatDataset of length nSubject.
    """
    subject_datasets = []
    for s in range(nSubject):
        # Extract slice for one subject and one beta
        # shape: [nTime, nChan]
        single_beta_data = datain[beta_idx, :, :, s]

        # If each node is a channel and its feature dimension is time,
        # we might want single_beta_data to be [nChan, nTime].
        # If so, transpose here:
        single_beta_data = single_beta_data.T  # shape: [nChan, nTime]

        # For an autoencoder, features=labels
        x_tensor = torch.tensor(single_beta_data, dtype=torch.float32)
        x_tensor = x_tensor.unsqueeze(0)

        subject_datasets.append(TensorDataset(x_tensor, x_tensor))

    return ConcatDataset(subject_datasets)

cuda


In [5]:
# ===========================================
# 4) PyG Graph Dataset
# ===========================================
class EEGGraphDataset(torch.utils.data.Dataset):
    """
    Wraps a standard dataset so that each item is a PyG Data object:
      x -> node features (channels)
      edge_index -> adjacency
    """
    def __init__(self, eeg_data, edge_index):
        self.eeg_data = eeg_data
        self.edge_index = edge_index

    def __len__(self):
        return len(self.eeg_data)

    def __getitem__(self, idx):
        # item: (x, x) because it's autoencoder
        sample = self.eeg_data[idx]
        x = sample[0]  # shape: [nChan, nTime] (if you used transpose above)
        # Convert x -> float32, build PyG Data
        graph_data = Data(x=torch.tensor(x, dtype=torch.float32),
                          edge_index=self.edge_index)
        return graph_data


In [6]:
# ===========================================
# 5) Define GraphAutoencoder
# ===========================================
class GraphAutoencoder(nn.Module):
    def __init__(self, num_features, embedding_dim=64):
        """
        num_features = dimension of each node's feature vector.
                       If each channel is a node, and you pass
                       [nChan, nTime], then num_features = nTime.
        """
        super(GraphAutoencoder, self).__init__()
        # Example with two ARMAConv layers
        self.encoder_gcn1 = pyg_nn.ARMAConv(num_features, 128, 
                                            num_stacks=2, 
                                            num_layers=3,
                                            shared_weights=True)
        self.encoder_gcn2 = pyg_nn.ARMAConv(128, embedding_dim, 
                                            num_stacks=2, 
                                            num_layers=3,
                                            shared_weights=True)
        self.decoder_fc1 = nn.Linear(embedding_dim, 128)
        self.decoder_fc2 = nn.Linear(128, num_features)

    def forward(self, x, edge_index):
        x = torch.relu(self.encoder_gcn1(x, edge_index))
        latent = torch.relu(self.encoder_gcn2(x, edge_index))
        x = torch.relu(self.decoder_fc1(latent))
        reconstructed = self.decoder_fc2(x)
        return reconstructed

In [7]:
# ===========================================
# 6) Training and evaluation
# ===========================================
def train_model_all(model, train_loader, device, num_epochs):
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.SmoothL1Loss()

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        for batch_data in train_loader:
            batch_data = batch_data.to(device)
            optimizer.zero_grad()

            # Forward
            reconstructed = model(batch_data.x, batch_data.edge_index)
            loss = criterion(reconstructed, batch_data.x)

            # Backprop
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        epoch_loss /= len(train_loader)
        print(f"Beta Model Training - Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")
    return model

def evaluation_model(model, data_sample, edge_index, device):
    """
    data_sample: shape [nChan, nTime], or [nTime, nChan],
                 whichever you used in training.
    Returns: (original_np, reconstructed_np, error_np)
    """
    model.eval()
    with torch.no_grad():
        x = torch.tensor(data_sample, dtype=torch.float32, device=device)
        eidx = edge_index.to(device)
        reconstructed = model(x, eidx)
    orig = x.cpu().numpy()
    recon = reconstructed.cpu().numpy()
    error = orig - recon
    return orig, recon, error


In [8]:
# Channel Masking

def channel_masking(x_n, ch_idx, mode='zero'):
    x_masked = x_n.copy()
    if mode == 'zero':
        x_masked[ch_idx, :] = 0.0
    elif mode == 'noise':
        mu = x_n[ch_idx, :].mean()
        sigma = x_n[ch_idx, :].std()
        x_masked[ch_idx, :] = np.random.normal(mu, sigma, size=x_n[ch_idx, :].shape)
    return x_masked

def apply_channel_masking(model, sample_n, edge_index, device, mode):
    model.eval()
    with torch.no_grad():
        x = torch.tensor(sample_n, dtype=torch.float32, device=device)
        recon = model(x, edge_index.to(device))
        base_lose = nn.SmoothL1Loss()(recon, x).item()
        
        nChan = sample_n.shape[0]
        imp = np.zeros(nChan)
        
        for c in range(nChan):
            x_masked_n = channel_masking(sample_n, c, mode=mode)
            x_masked = torch.tensor(x_masked_n, dtype=torch.float32, device=device)
            recon_masked = model(x_masked, edge_index.to(device))
            masked_loss = nn.SmoothL1Loss()(recon_masked, x_masked).item()
            imp[c] = masked_loss - base_lose
            
    return imp

In [9]:
# Channel dropping
from torch_geometric.utils import subgraph


def channel_dropping(model, sample_n, edge_index, device):
    model.eval()
    with torch.no_grad():
        x = torch.tensor(sample_n, dtype=torch.float32, device=device)
        recon = model(x, edge_index.to(device))
        base_loss = nn.SmoothL1Loss()(recon, x).item()

        nChan = sample_n.shape[0]
        imp = np.zeros(nChan)
        all_channels = torch.arange(nChan, device=device)
        e_idx = edge_index.to(device)
        
        for c in range(nChan):
            mask = torch.ones(nChan, dtype=torch.bool, device=device)
            mask[c] = False
            masked_channels = all_channels[mask]
            masked_edge_index = subgraph(masked_channels, e_idx, relabel_nodes=True)[0]
            masked_x = x[masked_channels, :]
            recon_masked = model(masked_x, masked_edge_index)
            masked_loss = nn.SmoothL1Loss()(recon_masked, masked_x).item()
            imp[c] = masked_loss - base_loss
            
    return imp

In [13]:
num_features = nTime   # if each channel is a node, and feature vector = [nTime]
num_epochs   = 200     # number of training - set as needed
batch_size   = 4;      # dataset/batch to update hyperparameters

models = []

imp_zero_list = []
imp_drop_list = []
for iBeta in range(nBeta):
    print(f"\n=== Building dataset/model for Beta {iBeta} ===")
    # Build dataset & loader
    full_dataset_i  = build_concat_dataset_for_beta(datain, iBeta)
    graph_dataset_i = EEGGraphDataset(full_dataset_i, edge_index)
    full_loader_i   = GeometricDataLoader(graph_dataset_i, batch_size=batch_size, shuffle=True)

    # Init model & train
    model_i = GraphAutoencoder(num_features=num_features, embedding_dim=32)
    model_i = train_model_all(model_i, full_loader_i, device, num_epochs=num_epochs)
    models.append(model_i)

    imp_zero_per_subject = []
    imp_drop_per_subject = []
    for s in range(len(full_dataset_i)):
        x_s = full_dataset_i[s][0].numpy()  # [nChan, nTime]
        imp_zero = apply_channel_masking(model_i, x_s, edge_index, device, mode='zero')
        imp_drop = channel_dropping(model_i, x_s, edge_index, device)
        imp_zero_per_subject.append(imp_zero)
        imp_drop_per_subject.append(imp_drop)
        
    imp_zero_list.append(np.array(imp_zero_per_subject))
    imp_drop_list.append(np.array(imp_drop_per_subject))
    
all_imp_zero = np.array(imp_zero_list)
all_imp_drop = np.array(imp_drop_list)




=== Building dataset/model for Beta 0 ===


  graph_data = Data(x=torch.tensor(x, dtype=torch.float32),


Beta Model Training - Epoch [1/200], Loss: 0.7662
Beta Model Training - Epoch [2/200], Loss: 0.5737
Beta Model Training - Epoch [3/200], Loss: 0.4263
Beta Model Training - Epoch [4/200], Loss: 0.3755
Beta Model Training - Epoch [5/200], Loss: 0.3442
Beta Model Training - Epoch [6/200], Loss: 0.3307
Beta Model Training - Epoch [7/200], Loss: 0.3215
Beta Model Training - Epoch [8/200], Loss: 0.3093
Beta Model Training - Epoch [9/200], Loss: 0.2891
Beta Model Training - Epoch [10/200], Loss: 0.2682
Beta Model Training - Epoch [11/200], Loss: 0.2469
Beta Model Training - Epoch [12/200], Loss: 0.2305
Beta Model Training - Epoch [13/200], Loss: 0.2190
Beta Model Training - Epoch [14/200], Loss: 0.2137
Beta Model Training - Epoch [15/200], Loss: 0.2035
Beta Model Training - Epoch [16/200], Loss: 0.1971
Beta Model Training - Epoch [17/200], Loss: 0.1799
Beta Model Training - Epoch [18/200], Loss: 0.1713
Beta Model Training - Epoch [19/200], Loss: 0.1661
Beta Model Training - Epoch [20/200], Lo

In [None]:
# save importances results
save_dir = os.path.join(analysis_type, 'Masking Results')
os.makedirs(save_dir, exist_ok=True)

save_subdir_chan = os.path.join(save_dir, 'Channel Masking')
os.makedirs(save_subdir_chan, exist_ok=True)

np.save(os.path.join(save_subdir_chan, 'imp_zero.npy'), all_imp_zero)
np.save(os.path.join(save_subdir_chan, 'imp_drop.npy'), all_imp_drop)

In [26]:
# Visualize importances of channel masking zero method
import matplotlib.pyplot as plt

values = all_imp_zero.mean(axis=1)  # mean across subjects
plt.figure(figsize=(16,13))
plt.imshow(values, aspect='auto', origin='lower', cmap='Reds')
plt.colorbar(label='Mean importance')
plt.xticks(ticks=np.arange(nChan), labels=np.arange(1, nChan+1))
plt.xlabel('Channel')
plt.ylabel('Beta')
plt.title(analysis_type + '_Channel Importance_zero')
plt.tight_layout()
plt.savefig(os.path.join(save_subdir_chan, '_Channel Importance_zero.png'))
#plt.show()
plt.close()

for iBeta in range(nBeta):
    values = all_imp_zero[iBeta]
    plt.figure(figsize=(16,13))
    plt.imshow(values, aspect='auto', origin='lower', cmap='Reds')
    plt.colorbar(label='Importance')
    plt.xticks(ticks=np.arange(nChan), labels=np.arange(1, nChan+1))
    plt.xlabel('Channel')
    plt.ylabel('Subject')
    plt.title(f'Beta {iBeta}_{analysis_type}_Channel Importance_zero')
    plt.tight_layout()
    plt.savefig(os.path.join(save_subdir_chan, f'Beta_{iBeta}_Channel_Importance_zero.png'))
    #plt.show()
    plt.close()

In [24]:
# Visualize importances of channel masking dropping method
import matplotlib.pyplot as plt

values = all_imp_drop.mean(axis=1)  # mean across subjects
plt.figure(figsize=(16,13))
plt.imshow(values, aspect='auto', origin='lower', cmap='Reds')
plt.colorbar(label='Mean importance')
plt.xticks(ticks=np.arange(nChan), labels=np.arange(1, nChan+1))
plt.xlabel('Channel')
plt.ylabel('Beta')
plt.title(analysis_type+'_Channel Importance_drop')
plt.tight_layout()
plt.savefig(os.path.join(save_subdir_chan, '_Channel Importance_drop.png'))
#plt.show()
plt.close()

for iBeta in range(nBeta):
    values = all_imp_drop[iBeta]
    plt.figure(figsize=(16,13))
    plt.imshow(values, aspect='auto', origin='lower', cmap='Reds')
    plt.colorbar(label='Importance')
    plt.xticks(ticks=np.arange(nChan), labels=np.arange(1, nChan+1))
    plt.xlabel('Channel')
    plt.ylabel('Subject')
    plt.title(f'Beta {iBeta}_{analysis_type}_Channel Importance_drop')
    plt.tight_layout()
    plt.savefig(os.path.join(save_subdir_chan, f'Beta_{iBeta}_Channel_Importance_drop.png'))
    #plt.show()
    plt.close()