In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, ConcatDataset
from torch_geometric.loader import DataLoader as GeometricDataLoader
from torch_geometric.data import Data
from torch_geometric.utils import dense_to_sparse, subgraph, k_hop_subgraph
import torch_geometric.nn as pyg_nn
from scipy.io import loadmat, savemat
import os
import matplotlib.pyplot as plt


In [52]:
# Load saved beta data from mat file
analysis_type = 'N170'
curr_dir = os.getcwd()
datain_mat = loadmat(os.path.join(curr_dir, analysis_type + '_data.mat'))
binatry_matrix_mat = loadmat(os.path.join(curr_dir, 'adjacency_matrix.mat'))
datain = datain_mat['data']
binatry_matrix = binatry_matrix_mat['adjacency_matrix']

datain = np.array(datain)          # shape: [nBeta, nTime, nChan, nSubject] (adjust as needed)
binatry_matrix = np.array(binatry_matrix)

# Parse shapes
nBeta     = datain.shape[0]
nTime     = datain.shape[1]
nChan     = datain.shape[2]
nSubject  = datain.shape[3]

In [53]:
# ===========================================
# 2) Set device & convert adjacency
# ===========================================
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
edge_index, _ = dense_to_sparse(torch.tensor(binatry_matrix, dtype=torch.float32))
print(device)

# ===========================================
# 3) Dataset builder
# ===========================================
def build_concat_dataset_for_beta(datain, beta_idx):
    """
    Build a ConcatDataset across all subjects for one beta index.
    `datain` is [nBeta, nTime, nChan, nSubject].
    Returns: ConcatDataset of length nSubject.
    """
    subject_datasets = []
    for s in range(nSubject):
        # Extract slice for one subject and one beta
        # shape: [nTime, nChan]
        single_beta_data = datain[beta_idx, :, :, s]

        # If each node is a channel and its feature dimension is time,
        # we might want single_beta_data to be [nChan, nTime].
        # If so, transpose here:
        single_beta_data = single_beta_data.T  # shape: [nChan, nTime]

        # For an autoencoder, features=labels
        x_tensor = torch.tensor(single_beta_data, dtype=torch.float32)
        x_tensor = x_tensor.unsqueeze(0)

        subject_datasets.append(TensorDataset(x_tensor, x_tensor))

    return ConcatDataset(subject_datasets)

cuda


In [54]:
# ===========================================
# 4) PyG Graph Dataset
# ===========================================
class EEGGraphDataset(torch.utils.data.Dataset):
    """
    Wraps a standard dataset so that each item is a PyG Data object:
      x -> node features (channels)
      edge_index -> adjacency
    """
    def __init__(self, eeg_data, edge_index):
        self.eeg_data = eeg_data
        self.edge_index = edge_index

    def __len__(self):
        return len(self.eeg_data)

    def __getitem__(self, idx):
        # item: (x, x) because it's autoencoder
        sample = self.eeg_data[idx]
        x = sample[0]  # shape: [nChan, nTime] (if you used transpose above)
        # Convert x -> float32, build PyG Data
        graph_data = Data(x=torch.tensor(x, dtype=torch.float32),
                          edge_index=self.edge_index)
        return graph_data


In [55]:
# ===========================================
# 5) Define GraphAutoencoder
# ===========================================
class GraphAutoencoder(nn.Module):
    def __init__(self, num_features, embedding_dim=64):
        """
        num_features = dimension of each node's feature vector.
                       If each channel is a node, and you pass
                       [nChan, nTime], then num_features = nTime.
        """
        super(GraphAutoencoder, self).__init__()
        # Example with two ARMAConv layers
        self.encoder_gcn1 = pyg_nn.ARMAConv(num_features, 128, 
                                            num_stacks=2, 
                                            num_layers=3,
                                            shared_weights=True)
        self.encoder_gcn2 = pyg_nn.ARMAConv(128, embedding_dim, 
                                            num_stacks=2, 
                                            num_layers=3,
                                            shared_weights=True)
        self.decoder_fc1 = nn.Linear(embedding_dim, 128)
        self.decoder_fc2 = nn.Linear(128, num_features)

    def forward(self, x, edge_index):
        x = torch.relu(self.encoder_gcn1(x, edge_index))
        latent = torch.relu(self.encoder_gcn2(x, edge_index))
        x = torch.relu(self.decoder_fc1(latent))
        reconstructed = self.decoder_fc2(x)
        return reconstructed

In [56]:
# ===========================================
# 6) Training and evaluation
# ===========================================
def train_model_all(model, train_loader, device, num_epochs):
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.SmoothL1Loss()

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        for batch_data in train_loader:
            batch_data = batch_data.to(device)
            optimizer.zero_grad()

            # Forward
            reconstructed = model(batch_data.x, batch_data.edge_index)
            loss = criterion(reconstructed, batch_data.x)

            # Backprop
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        epoch_loss /= len(train_loader)
        print(f"Beta Model Training - Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")
    return model

def evaluation_model(model, data_sample, edge_index, device):
    """
    data_sample: shape [nChan, nTime], or [nTime, nChan],
                 whichever you used in training.
    Returns: (original_np, reconstructed_np, error_np)
    """
    model.eval()
    with torch.no_grad():
        x = torch.tensor(data_sample, dtype=torch.float32, device=device)
        eidx = edge_index.to(device)
        reconstructed = model(x, eidx)
    orig = x.cpu().numpy()
    recon = reconstructed.cpu().numpy()
    error = orig - recon
    return orig, recon, error


In [64]:
# Channel Masking

# Define channel masking function with 'zero' and 'mean' methods
def channel_masking(x_n, ch_idx, mode='mean'):
    x_masked = x_n.copy()
    if mode == 'zero':
        x_masked[ch_idx, :] = 0.0
    # Mean of all other channels
    elif mode == 'mean':
        channel_mean = x_n.mean(axis=0)
        x_masked[ch_idx, :] = channel_mean
    return x_masked

# Apply channel masking and compute importance
def apply_channel_masking(model, sample_n, edge_index, device, mode='mean'):
    model.eval()
    with torch.no_grad():
        x = torch.tensor(sample_n, dtype=torch.float32, device=device)
        recon = model(x, edge_index.to(device))
        base_lose = nn.SmoothL1Loss()(recon, x).item()
        
        nChan = sample_n.shape[0]
        imp = np.zeros(nChan)
        
        # Iterate over channels to compute importance
        # The higher the loss increase, the more important the channel
        for c in range(nChan):
            x_masked = channel_masking(sample_n, c, mode=mode)
            x_masked = torch.tensor(x_masked, dtype=torch.float32, device=device)
            recon_masked = model(x_masked, edge_index.to(device))
            masked_loss = nn.SmoothL1Loss()(recon_masked, x_masked).item()
            imp[c] = abs(masked_loss - base_lose)
            
    return imp

In [65]:
# Channel dropping

# Define channel dropping method
def channel_dropping(model, sample_n, edge_index, device):
    model.eval()
    with torch.no_grad():
        x = torch.tensor(sample_n, dtype=torch.float32, device=device)
        recon = model(x, edge_index.to(device))
        base_loss = nn.SmoothL1Loss()(recon, x).item()

        nChan = sample_n.shape[0]
        imp = np.zeros(nChan)
        all_channels = torch.arange(nChan, device=device)
        e_idx = edge_index.to(device)
        
        for c in range(nChan):
            mask = torch.ones(nChan, dtype=torch.bool, device=device)
            mask[c] = False
            masked_channels = all_channels[mask]
            # Create subgraph with masked channels, dropping the channel c from the graph
            masked_edge_index = subgraph(masked_channels, e_idx, relabel_nodes=True)[0]
            # Drop the channel c from the data x
            masked_x = x[masked_channels, :]
            recon_masked = model(masked_x, masked_edge_index)
            masked_loss = nn.SmoothL1Loss()(recon_masked, masked_x).item()
            imp[c] = abs(masked_loss - base_loss)
            
    return imp

In [69]:
# Channel region hard masking

# Compute k-hop neighborhood channels for each channel
def compute_neighborhood_channels(edge_index, k=1):
    neighborhood_channels = []
    for c in range(edge_index.max().item() + 1):
        channels, _, _, _ = k_hop_subgraph(c, k, edge_index)
        neighborhood_channels.append(channels.cpu().numpy())
    return neighborhood_channels

# Define region hard masking function with 'zero' and 'mean' methods
def region_hard_masking(x_n, neig_list, mode='mean'):
    x_masked = x_n.copy()

    # Get the rest of the channels except the neighborhood list
    nChan = x_n.shape[0]
    mask_neig = np.zeros(nChan, dtype=bool)
    mask_neig[neig_list] = True
    mask_rest = ~mask_neig
    
    if mode == 'zero':
        x_masked[mask_neig, :] = 0.0
    # The mean of the rest of the channels
    elif mode == 'mean':
        if mask_rest.any():
            rest_mean = x_n[mask_rest, :].mean(axis=0)
        # If no rest channels, use overall mean
        else:
            rest_mean = x_n.mean(axis=0)
        x_masked[mask_neig, :] = rest_mean
    return x_masked

# Apply region hard masking and compute importance
def apply_region_hard_masking(model, sample_n, edge_index, device, mode='mean', k=1):
    model.eval()
    with torch.no_grad():
        neighborhood_channels = compute_neighborhood_channels(edge_index, k)
        x = torch.tensor(sample_n, dtype=torch.float32, device=device)
        recon = model(x, edge_index.to(device))
        base_loss = nn.SmoothL1Loss()(recon, x).item()

        nChan = sample_n.shape[0]
        imp = np.zeros(nChan)
        
        for c in range(nChan):
            neig_list = neighborhood_channels[c]
            x_masked = region_hard_masking(sample_n, neig_list, mode=mode)
            x_masked = torch.tensor(x_masked, dtype=torch.float32, device=device)
            recon_masked = model(x_masked, edge_index.to(device))
            masked_loss = nn.SmoothL1Loss()(recon_masked, x_masked).item()
            imp[c] = abs(masked_loss - base_loss)
    return imp

In [70]:
# Subject masking
def apply_subject_masking(model, full_dataset_i, edge_index, device, mode='mean'):
    model.eval()
    nSubjects = len(full_dataset_i)
    imp = np.zeros(nSubjects)
    
    x_list = []
    for s in range(nSubjects):
        x_list.append(full_dataset_i[s][0].numpy())
    # Shape: [nSubject, nChan, nTime]
    x_all = np.stack(x_list, axis=0)
    x_sum = x_all.sum(axis=0)
    
    for s in range(nSubjects):
        x_s =  x_list[s]
        x_s = torch.tensor(x_s, dtype=torch.float32, device=device)
        recon = model(x_s, edge_index.to(device))
        base_loss = nn.SmoothL1Loss()(recon, x_s).item()
        
        if mode == 'zero':
            x_masked = x_s.cpu().numpy()
            x_masked[:, :] = 0.0
        elif mode == 'mean':
            mean_other = (x_sum - x_list[s]) / (nSubjects - 1)
            x_masked = mean_other
        x_masked = torch.tensor(x_masked, dtype=torch.float32, device=device)
        recon_masked = model(x_masked, edge_index.to(device))
        mask_loss = nn.SmoothL1Loss()(recon_masked, x_masked).item()
        imp[s] = abs(mask_loss - base_loss)
    return imp

In [75]:
num_features = nTime   # if each channel is a node, and feature vector = [nTime]
num_epochs   = 200     # number of training - set as needed
batch_size   = 4;      # dataset/batch to update hyperparameters

models = []

imp_chan_zero_list = []
imp_chan_drop_list = []
imp_chan_mean_list = []
imp_region_mean_list = []
imp_subject_mean_list = []
for iBeta in range(nBeta):
    print(f"\n=== Building dataset/model for Beta {iBeta} ===")
    # Build dataset & loader
    full_dataset_i  = build_concat_dataset_for_beta(datain, iBeta)
    graph_dataset_i = EEGGraphDataset(full_dataset_i, edge_index)
    full_loader_i   = GeometricDataLoader(graph_dataset_i, batch_size=batch_size, shuffle=True)

    # Init model & train
    model_i = GraphAutoencoder(num_features=num_features, embedding_dim=32)
    model_i = train_model_all(model_i, full_loader_i, device, num_epochs=num_epochs)
    models.append(model_i)
    
    imp_subject_mean = apply_subject_masking(model_i, full_dataset_i, edge_index, device, mode='mean')
    imp_subject_mean_list.append(np.array(imp_subject_mean))

    imp_chan_zero_per_subject = []
    imp_chan_drop_per_subject = []
    imp_chan_mean_per_subject = []
    imp_region_mean_per_subject = []
    for s in range(len(full_dataset_i)):
        x_s = full_dataset_i[s][0].numpy()  # [nChan, nTime]
        imp_chan_zero = apply_channel_masking(model_i, x_s, edge_index, device, mode='zero')
        imp_chan_mean = apply_channel_masking(model_i, x_s, edge_index, device, mode='mean')
        imp_chan_drop = channel_dropping(model_i, x_s, edge_index, device)
        imp_region_mean = apply_region_hard_masking(model_i, x_s, edge_index, device, mode='mean', k=1)
        imp_chan_zero_per_subject.append(imp_chan_zero)
        imp_chan_drop_per_subject.append(imp_chan_drop)
        imp_chan_mean_per_subject.append(imp_chan_mean)
        imp_region_mean_per_subject.append(imp_region_mean)

    imp_chan_zero_list.append(np.array(imp_chan_zero_per_subject))
    imp_chan_drop_list.append(np.array(imp_chan_drop_per_subject))
    imp_chan_mean_list.append(np.array(imp_chan_mean_per_subject))
    imp_region_mean_list.append(np.array(imp_region_mean_per_subject))

all_imp_chan_zero = np.array(imp_chan_zero_list)
all_imp_chan_drop = np.array(imp_chan_drop_list)
all_imp_chan_mean = np.array(imp_chan_mean_list)
all_imp_region_mean = np.array(imp_region_mean_list)
all_imp_subject_mean = np.array(imp_subject_mean_list)



=== Building dataset/model for Beta 0 ===
Beta Model Training - Epoch [1/200], Loss: 0.7608
Beta Model Training - Epoch [2/200], Loss: 0.5626


  graph_data = Data(x=torch.tensor(x, dtype=torch.float32),


Beta Model Training - Epoch [3/200], Loss: 0.4204
Beta Model Training - Epoch [4/200], Loss: 0.3673
Beta Model Training - Epoch [5/200], Loss: 0.3407
Beta Model Training - Epoch [6/200], Loss: 0.3208
Beta Model Training - Epoch [7/200], Loss: 0.3054
Beta Model Training - Epoch [8/200], Loss: 0.2844
Beta Model Training - Epoch [9/200], Loss: 0.2616
Beta Model Training - Epoch [10/200], Loss: 0.2408
Beta Model Training - Epoch [11/200], Loss: 0.2237
Beta Model Training - Epoch [12/200], Loss: 0.2086
Beta Model Training - Epoch [13/200], Loss: 0.1988
Beta Model Training - Epoch [14/200], Loss: 0.1890
Beta Model Training - Epoch [15/200], Loss: 0.1785
Beta Model Training - Epoch [16/200], Loss: 0.1701
Beta Model Training - Epoch [17/200], Loss: 0.1667
Beta Model Training - Epoch [18/200], Loss: 0.1589
Beta Model Training - Epoch [19/200], Loss: 0.1571
Beta Model Training - Epoch [20/200], Loss: 0.1554
Beta Model Training - Epoch [21/200], Loss: 0.1473
Beta Model Training - Epoch [22/200], 

In [77]:
# save importances results
save_dir = os.path.join(curr_dir, analysis_type, 'Masking')
os.makedirs(save_dir, exist_ok=True)

save_subdir_chan = os.path.join(save_dir, 'Channel Masking')
os.makedirs(save_subdir_chan, exist_ok=True)

save_subdir_region = os.path.join(save_dir, 'Region Masking')
os.makedirs(save_subdir_region, exist_ok=True)

save_subdir_subject = os.path.join(save_dir, 'Subject Masking')
os.makedirs(save_subdir_subject, exist_ok=True)

imp_chan_zero_data = {'all_imp_chan_zero': all_imp_chan_zero}
imp_chan_drop_data = {'all_imp_chan_drop': all_imp_chan_drop}
imp_chan_mean_data = {'all_imp_chan_mean': all_imp_chan_mean}
imp_region_mean_data = {'all_imp_region_mean': all_imp_region_mean}
imp_subject_mean_data = {'all_imp_subject_mean': all_imp_subject_mean}

savemat(os.path.join(save_subdir_chan, 'imp_chan_zero.mat'), imp_chan_zero_data)
savemat(os.path.join(save_subdir_chan, 'imp_chan_drop.mat'), imp_chan_drop_data)
savemat(os.path.join(save_subdir_chan, 'imp_chan_mean.mat'), imp_chan_mean_data)
savemat(os.path.join(save_subdir_region, 'imp_region_mean.mat'), imp_region_mean_data)
savemat(os.path.join(save_subdir_subject, 'imp_subject_mean.mat'), imp_subject_mean_data)


In [None]:
# Visualize importances of channel masking zero method

values = all_imp_chan_zero.mean(axis=1)  # mean across subjects
plt.figure(figsize=(16,13))
plt.imshow(values, aspect='auto', origin='lower', cmap='Reds')
plt.colorbar(label='Mean importance')
plt.xticks(ticks=np.arange(nChan), labels=np.arange(1, nChan+1))
plt.xlabel('Channel')
plt.ylabel('Beta')
plt.title(analysis_type + '_Channel Importance_zero')
plt.tight_layout()
plt.savefig(os.path.join(save_subdir_chan, analysis_type + '_Channel Importance_zero.png'))
#plt.show()
plt.close()

for iBeta in range(nBeta):
    values = all_imp_chan_zero[iBeta]
    plt.figure(figsize=(16,13))
    plt.imshow(values, aspect='auto', origin='lower', cmap='Reds')
    plt.colorbar(label='Importance')
    plt.xticks(ticks=np.arange(nChan), labels=np.arange(1, nChan+1))
    plt.xlabel('Channel')
    plt.ylabel('Subject')
    plt.title(f'Beta {iBeta}_{analysis_type}_Channel Importance_zero')
    plt.tight_layout()
    plt.savefig(os.path.join(save_subdir_chan, analysis_type + f'_Beta_{iBeta}_Channel_Importance_zero.png'))
    #plt.show()
    plt.close()

In [None]:
# Visualize importances of channel masking mean method

values = all_imp_chan_mean.mean(axis=1)  # mean across subjects
plt.figure(figsize=(16,13))
plt.imshow(values, aspect='auto', origin='lower', cmap='Reds')
plt.colorbar(label='Mean importance')
plt.xticks(ticks=np.arange(nChan), labels=np.arange(1, nChan+1))
plt.xlabel('Channel')
plt.ylabel('Beta')
plt.title(analysis_type + '_Channel Importance_mean')
plt.tight_layout()
plt.savefig(os.path.join(save_subdir_chan, analysis_type + '_Channel Importance_mean.png'))
#plt.show()
plt.close()

for iBeta in range(nBeta):
    values = all_imp_chan_mean[iBeta]
    plt.figure(figsize=(16,13))
    plt.imshow(values, aspect='auto', origin='lower', cmap='Reds')
    plt.colorbar(label='Importance')
    plt.xticks(ticks=np.arange(nChan), labels=np.arange(1, nChan+1))
    plt.xlabel('Channel')
    plt.ylabel('Subject')
    plt.title(f'Beta {iBeta}_{analysis_type}_Channel Importance_mean')
    plt.tight_layout()
    plt.savefig(os.path.join(save_subdir_chan, analysis_type + f'_Beta_{iBeta}_Channel_Importance_mean.png'))
    #plt.show()
    plt.close()

In [None]:
# Visualize importances of channel masking dropping method

values = all_imp_chan_drop.mean(axis=1)  # mean across subjects
plt.figure(figsize=(16,13))
plt.imshow(values, aspect='auto', origin='lower', cmap='Reds')
plt.colorbar(label='Mean importance')
plt.xticks(ticks=np.arange(nChan), labels=np.arange(1, nChan+1))
plt.xlabel('Channel')
plt.ylabel('Beta')
plt.title(analysis_type+'_Channel Importance_drop')
plt.tight_layout()
plt.savefig(os.path.join(save_subdir_chan, analysis_type + '_Channel Importance_drop.png'))
#plt.show()
plt.close()

for iBeta in range(nBeta):
    values = all_imp_chan_drop[iBeta]
    plt.figure(figsize=(16,13))
    plt.imshow(values, aspect='auto', origin='lower', cmap='Reds')
    plt.colorbar(label='Importance')
    plt.xticks(ticks=np.arange(nChan), labels=np.arange(1, nChan+1))
    plt.xlabel('Channel')
    plt.ylabel('Subject')
    plt.title(f'Beta {iBeta}_{analysis_type}_Channel Importance_drop')
    plt.tight_layout()
    plt.savefig(os.path.join(save_subdir_chan, analysis_type + f'_Beta_{iBeta}_Channel_Importance_drop.png'))
    #plt.show()
    plt.close()

In [None]:
# Visualize importances of region hard masking

values = all_imp_region_mean.mean(axis=1)
plt.figure(figsize=(16,13))
plt.imshow(values, aspect='auto', origin='lower', cmap='Reds')
plt.colorbar(label='Mean importance')
plt.xticks(ticks=np.arange(nChan), labels=np.arange(1, nChan+1))
plt.xlabel('Channel')
plt.ylabel('Beta')
plt.title(analysis_type+'_Region Importance_mean')
plt.tight_layout()
plt.savefig(os.path.join(save_subdir_region, analysis_type + '_Region Importance_mean.png'))
#plt.show()
plt.close()

for iBeta in range(nBeta):
    values = all_imp_region_mean[iBeta]
    plt.figure(figsize=(16,13))
    plt.imshow(values, aspect='auto', origin='lower', cmap='Reds')
    plt.colorbar(label='Importance')
    plt.xticks(ticks=np.arange(nChan), labels=np.arange(1, nChan+1))
    plt.xlabel('Channel')
    plt.ylabel('Subject')
    plt.title(f'Beta {iBeta}_{analysis_type}_Region Importance_mean')
    plt.tight_layout()
    plt.savefig(os.path.join(save_subdir_region, analysis_type + f'_Beta_{iBeta}_Region_Importance_mean.png'))
    #plt.show()
    plt.close()

In [81]:
# Visualize importances of subject masking

plt.figure(figsize=(16,13))
plt.imshow(all_imp_subject_mean, aspect='auto', origin='lower', cmap='Reds')
plt.colorbar(label='Importance')
plt.xticks(ticks=np.arange(nSubject), labels=np.arange(1, nSubject+1))
plt.xlabel('Subject')
plt.ylabel('Beta')
plt.title(analysis_type+'_Subject Importance_mean')
plt.tight_layout()
plt.savefig(os.path.join(save_subdir_subject, analysis_type + '_Subject Importance_mean.png'))
#plt.show()
plt.close()

for iBeta in range(nBeta):
    values = all_imp_subject_mean[iBeta]
    plt.figure(figsize=(16,13))
    plt.imshow(values.reshape(1, -1), aspect='auto', origin='lower', cmap='Reds')
    plt.colorbar(label='Importance')
    plt.xticks(ticks=np.arange(nSubject), labels=np.arange(1, nSubject+1))
    plt.xlabel('Subject')
    plt.title(f'Beta {iBeta}_{analysis_type}_Subject Importance_mean')
    plt.tight_layout()
    plt.savefig(os.path.join(save_subdir_subject, analysis_type + f'_Beta_{iBeta}_Subject_Importance_mean.png'))
    #plt.show()
    plt.close()