In [109]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, ConcatDataset
from torch_geometric.loader import DataLoader as GeometricDataLoader
from torch_geometric.data import Data
from torch_geometric.utils import dense_to_sparse
import torch_geometric.nn as pyg_nn
import os
import matplotlib.pyplot as plt
from collections import Counter
from scipy.io import loadmat, savemat

In [78]:
# Load saved beta data from mat files
analysis_type = 'N170'
curr_dir = os.getcwd()
datain_mat = loadmat(os.path.join(curr_dir, analysis_type + '_data.mat'))
binatry_matrix_mat = loadmat(os.path.join(curr_dir, 'adjacency_matrix.mat'))
datain = datain_mat['data']
binatry_matrix = binatry_matrix_mat['adjacency_matrix']

datain = np.array(datain)          # shape: [nBeta, nTime, nChan, nSubject] (adjust as needed)
binatry_matrix = np.array(binatry_matrix)

# Parse shapes
nBeta     = datain.shape[0]
nTime     = datain.shape[1]
nChan     = datain.shape[2]
nSubject  = datain.shape[3]

In [79]:
# ===========================================
# 2) Set device & convert adjacency
# ===========================================
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
edge_index, _ = dense_to_sparse(torch.tensor(binatry_matrix, dtype=torch.float32))
print(device)

# ===========================================
# 3) Dataset builder
# ===========================================
def build_concat_dataset_for_beta(datain, beta_idx):
    """
    Build a ConcatDataset across all subjects for one beta index.
    `datain` is [nBeta, nTime, nChan, nSubject].
    Returns: ConcatDataset of length nSubject.
    """
    subject_datasets = []
    for s in range(nSubject):
        # Extract slice for one subject and one beta
        # shape: [nTime, nChan]
        single_beta_data = datain[beta_idx, :, :, s]

        # If each node is a channel and its feature dimension is time,
        # we might want single_beta_data to be [nChan, nTime].
        # If so, transpose here:
        single_beta_data = single_beta_data.T  # shape: [nChan, nTime]

        # For an autoencoder, features=labels
        x_tensor = torch.tensor(single_beta_data, dtype=torch.float32)
        x_tensor = x_tensor.unsqueeze(0)

        subject_datasets.append(TensorDataset(x_tensor, x_tensor))

    return ConcatDataset(subject_datasets)

cuda


In [80]:
# ===========================================
# 4) PyG Graph Dataset
# ===========================================
class EEGGraphDataset(torch.utils.data.Dataset):
    """
    Wraps a standard dataset so that each item is a PyG Data object:
      x -> node features (channels)
      edge_index -> adjacency
    """
    def __init__(self, eeg_data, edge_index):
        self.eeg_data = eeg_data
        self.edge_index = edge_index

    def __len__(self):
        return len(self.eeg_data)

    def __getitem__(self, idx):
        # item: (x, x) because it's autoencoder
        sample = self.eeg_data[idx]
        x = sample[0]  # shape: [nChan, nTime] (if you used transpose above)
        # Convert x -> float32, build PyG Data
        graph_data = Data(x=torch.tensor(x, dtype=torch.float32),
                          edge_index=self.edge_index)
        return graph_data


In [82]:
# ===========================================
# 5) Define GraphAutoencoder
# ===========================================
class GraphAutoencoder(nn.Module):
    def __init__(self, num_features, embedding_dim=64):
        """
        num_features = dimension of each node's feature vector.
                       If each channel is a node, and you pass
                       [nChan, nTime], then num_features = nTime.
        """
        super(GraphAutoencoder, self).__init__()
        # Example with two ARMAConv layers
        self.encoder_gcn1 = pyg_nn.ARMAConv(num_features, 128, 
                                            num_stacks=2, 
                                            num_layers=3,
                                            shared_weights=True)
        self.encoder_gcn2 = pyg_nn.ARMAConv(128, embedding_dim, 
                                            num_stacks=2, 
                                            num_layers=3,
                                            shared_weights=True)
        self.decoder_fc1 = nn.Linear(embedding_dim, 128)
        self.decoder_fc2 = nn.Linear(128, num_features)

    # def forward(self, x, edge_index):
    #     x = torch.relu(self.encoder_gcn1(x, edge_index))
    #     latent = torch.relu(self.encoder_gcn2(x, edge_index))
    #     x = torch.relu(self.decoder_fc1(latent))
    #     reconstructed = self.decoder_fc2(x)
    #     return reconstructed
    
    def encode(self, x, edge_index):
        x = torch.relu(self.encoder_gcn1(x, edge_index))
        latent = torch.relu(self.encoder_gcn2(x, edge_index))
        return latent
    
    def decode(self, latent):
        x = torch.relu(self.decoder_fc1(latent))
        reconstructed = self.decoder_fc2(x)
        return reconstructed
    
    def forward(self, x, edge_index):
        latent = self.encode(x, edge_index)
        reconstructed = self.decode(latent)
        return reconstructed, latent
    

In [84]:
# Latent ablation Part
def compute_baseline_loss(model, data_loader, device):
    model.eval()
    total_loss, n = 0, 0
    criterion = nn.SmoothL1Loss()
    with torch.no_grad():
        for data in data_loader:
            data = data.to(device)
            recon, _ = model(data.x, data.edge_index)
            loss = criterion(recon, data.x)
            total_loss += loss.item()
            n += 1
    return total_loss / max(n, 1)

def apply_latent_ablation(model, data_loader, device, embedding_dim):
    model.eval()
    with torch.no_grad():
        baseline_loss = compute_baseline_loss(model, data_loader, device)
        imp = np.zeros(embedding_dim)
        criterion = nn.SmoothL1Loss()
        
        for dim in range(embedding_dim):
            print(f"Ablating latent dimension {dim+1}/{embedding_dim}")
            total_loss, n = 0, 0
            for data in data_loader:
                data = data.to(device)
                latent = model.encode(data.x, data.edge_index)
                latent_ablated = latent.clone()
                latent_ablated[:, dim] = 0  # Ablate one dimension
                recon = model.decode(latent_ablated)
                loss = criterion(recon, data.x)
                total_loss += loss.item()
                n += 1
            avg_loss = total_loss / max(n, 1)
            imp[dim] = avg_loss - baseline_loss
    return imp

In [None]:
# Module ablation

def apply_module_ablation(model, data_loader, device, module_name, ratio=0.3):
    modules = dict(model.named_modules())
    if module_name not in modules:
        raise ValueError(f"Module {module_name} is not found in model.")
    target_module = modules[module_name]
    mask = None
    
    def hook(module, input, output):
        # Make mask only once
        nonlocal mask
        
        if mask is None:
            num = output.size(-1)
            k = int(num * ratio)
            # Randomly select k indices to ablate
            ind = torch.randperm(num, device=output.device)[:k]
            mask = torch.ones(num, device=output.device)
            mask[ind] = 0.0
        return output * mask
    
    handle = target_module.register_forward_hook(hook)
    ablated_loss = compute_baseline_loss(model, data_loader, device)
    handle.remove()
    return ablated_loss

In [86]:
# Edge ablation

# Randomly remove a ratio of edges from the graph and compute loss increase
def make_random_edge_ablation(edge_index, drop_ratio):
    edge_index = edge_index.cpu()
    E = edge_index.size(1)
    num_drop = int(E * drop_ratio)
    # Randomly select edges to drop
    ind = torch.randperm(E)
    mask = torch.ones(E, dtype=torch.bool)
    mask[ind[:num_drop]] = False
    edge_index_dropped = edge_index[:, mask]
    return edge_index_dropped

# Remove edges conneted to the slected node
def make_node_edge_ablation(edge_index, node_idx):
    edge_index = edge_index.cpu()
    mask = (edge_index[0] != node_idx) & (edge_index[1] != node_idx)
    edge_index_dropped = edge_index[:, mask]
    return edge_index_dropped


In [87]:
# ===========================================
# 6) Training and evaluation
# ===========================================
def train_model_all(model, train_loader, device, num_epochs):
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.SmoothL1Loss()

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        for batch_data in train_loader:
            batch_data = batch_data.to(device)
            optimizer.zero_grad()

            # Forward
            reconstructed, _ = model(batch_data.x, batch_data.edge_index)
            loss = criterion(reconstructed, batch_data.x)

            # Backprop
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        epoch_loss /= len(train_loader)
        print(f"Beta Model Training - Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")
    return model

def evaluation_model(model, data_sample, edge_index, device):
    """
    data_sample: shape [nChan, nTime], or [nTime, nChan],
                 whichever you used in training.
    Returns: (original_np, reconstructed_np, error_np)
    """
    model.eval()
    with torch.no_grad():
        x = torch.tensor(data_sample, dtype=torch.float32, device=device)
        eidx = edge_index.to(device)
        reconstructed, _ = model(x, eidx)
    orig = x.cpu().numpy()
    recon = reconstructed.cpu().numpy()
    error = orig - recon
    return orig, recon, error

In [88]:
# ===========================================
# 7) Main pipeline
#    1) For each beta, build dataset/loader.
#    2) Initialize & train a separate model.
#    3) Reconstruct each subject's data for that beta.
#    4) Store reconstruction in a final array.
# ===========================================

num_features = nTime   # if each channel is a node, and feature vector = [nTime]
num_epochs   = 200     # number of training - set as needed
batch_size   = 4;      # dataset/batch to update hyperparameters
modules_name = [
        "encoder_gcn1",
        "encoder_gcn2",
        "decoder_fc1",
        "decoder_fc2",
    ]


models = []
imp_latent_list = []
imp_module_list = []
imp_rand_edge_list = []
imp_node_edge_list = []

for iBeta in range(nBeta):
    print(f"\n=== Building dataset/model for Beta {iBeta} ===")
    # Build dataset & loader
    full_dataset_i  = build_concat_dataset_for_beta(datain, iBeta)
    graph_dataset_i = EEGGraphDataset(full_dataset_i, edge_index)
    full_loader_i   = GeometricDataLoader(graph_dataset_i, batch_size=batch_size, shuffle=True)

    # Init model & train
    model_i = GraphAutoencoder(num_features=num_features, embedding_dim=32)
    model_i = train_model_all(model_i, full_loader_i, device, num_epochs=num_epochs)
    models.append(model_i)
    
    # Apply latent ablation
    imp_latent = apply_latent_ablation(model_i, full_loader_i, device, embedding_dim=32)
    imp_latent_list.append(np.array(imp_latent))
    
    # Apply module ablation
    ratio = 0.3
    baseline_loss = compute_baseline_loss(model_i, full_loader_i, device)
    module_ablation_res = {}
    for mod_name in modules_name:
        ablated_loss = apply_module_ablation(model_i, full_loader_i, device, mod_name, ratio=ratio)
        module_ablation_res[mod_name] = ablated_loss - baseline_loss
    imp_module_list.append(module_ablation_res)
    
    # Apply edge ablation
    # Random edge ablation with various drop ratios
    drop_ratios = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    imp_rand_edge = []
    for dr in drop_ratios:
        edge_index_rand = make_random_edge_ablation(edge_index, dr)
        rand_graph_dataset_i = EEGGraphDataset(full_dataset_i, edge_index_rand)
        rand_full_loader_i = GeometricDataLoader(rand_graph_dataset_i, batch_size=batch_size, shuffle=True)
        edge_rand_loss = compute_baseline_loss(model_i, rand_full_loader_i, device)
        # Relative importance
        imp_rand_edge.append(abs(edge_rand_loss - baseline_loss)/baseline_loss)
    imp_rand_edge_list.append(np.array(imp_rand_edge))
        
    # Node-based edge ablation
    imp_node_edge = []
    for c in range(nChan):
        edge_index_node_ablated = make_node_edge_ablation(edge_index, c)
        node_graph_dataset_i = EEGGraphDataset(full_dataset_i, edge_index_node_ablated)
        node_full_loader_i = GeometricDataLoader(node_graph_dataset_i, batch_size=batch_size, shuffle=True)
        edge_node_loss = compute_baseline_loss(model_i, node_full_loader_i, device)
        imp_node_edge.append(abs(edge_node_loss - baseline_loss))
    imp_node_edge_list.append(np.array(imp_node_edge))
        
        
    

all_imp_latent = np.array(imp_latent_list)  # [nBeta, embedding_dim]
all_imp_module = np.array(imp_module_list)  # [nBeta, nModules]
all_imp_rand_edge = np.array(imp_rand_edge_list)  # [nBeta, nDropRatios]
all_imp_node_edge = np.array(imp_node_edge_list)  # [nBeta, nChan]  
        


=== Building dataset/model for Beta 0 ===


  graph_data = Data(x=torch.tensor(x, dtype=torch.float32),


Beta Model Training - Epoch [1/200], Loss: 0.7552
Beta Model Training - Epoch [2/200], Loss: 0.5720
Beta Model Training - Epoch [3/200], Loss: 0.4417
Beta Model Training - Epoch [4/200], Loss: 0.3927
Beta Model Training - Epoch [5/200], Loss: 0.3517
Beta Model Training - Epoch [6/200], Loss: 0.3163
Beta Model Training - Epoch [7/200], Loss: 0.2918
Beta Model Training - Epoch [8/200], Loss: 0.2681
Beta Model Training - Epoch [9/200], Loss: 0.2498
Beta Model Training - Epoch [10/200], Loss: 0.2346
Beta Model Training - Epoch [11/200], Loss: 0.2226
Beta Model Training - Epoch [12/200], Loss: 0.2107
Beta Model Training - Epoch [13/200], Loss: 0.2001
Beta Model Training - Epoch [14/200], Loss: 0.1878
Beta Model Training - Epoch [15/200], Loss: 0.1761
Beta Model Training - Epoch [16/200], Loss: 0.1685
Beta Model Training - Epoch [17/200], Loss: 0.1637
Beta Model Training - Epoch [18/200], Loss: 0.1604
Beta Model Training - Epoch [19/200], Loss: 0.1590
Beta Model Training - Epoch [20/200], Lo

In [89]:
# Save the ablation results
save_dir = os.path.join(curr_dir,analysis_type, 'Ablation')
os.makedirs(save_dir, exist_ok=True)

save_subdir_latent = os.path.join(save_dir, 'Latent Ablation')
os.makedirs(save_subdir_latent, exist_ok=True)

save_subdir_module = os.path.join(save_dir, 'Module Ablation')
os.makedirs(save_subdir_module, exist_ok=True)

save_subdir_edge = os.path.join(save_dir, 'Edge Ablation')
os.makedirs(save_subdir_edge, exist_ok=True)


imp_latent = {'all_imp_latent': all_imp_latent}
imp_module = {'all_imp_module': all_imp_module}
imp_node_edge = {'all_imp_node_edge': all_imp_node_edge}

savemat(os.path.join(save_subdir_latent, 'imp_latent.mat'), imp_latent)
savemat(os.path.join(save_subdir_module, f'imp_module_ratio_{ratio}.mat'), imp_module)
savemat(os.path.join(save_subdir_edge, 'imp_node_edge.mat'), imp_node_edge)

In [90]:
# Visualize the latent ablation results

for iBeta in range(nBeta):
    value = all_imp_latent[iBeta]
    plt.figure(figsize=(16,13))
    plt.bar(range(len(value)), value)
    plt.xticks(range(len(value)), range(1, len(value)+1))
    plt.xlabel('Latent Dimension')
    plt.ylabel('Importance')
    plt.title(f'Latent Dimension Importance for Beta {iBeta+1}')
    plt.tight_layout()
    #plt.show()
    plt.savefig(os.path.join(save_subdir_latent, f'Latent_Importance_Beta_{iBeta+1}.png'))
    plt.close()
    
    sorted_values = np.argsort(value)[:5]
    print(f'lowest 5 important latent dimensions for Beta {iBeta+1}: {sorted_values}')

lowest 5 important latent dimensions for Beta 1: [ 3 25 14  6 31]
lowest 5 important latent dimensions for Beta 2: [ 1 29 20 18  3]
lowest 5 important latent dimensions for Beta 3: [18 24 25 28 13]
lowest 5 important latent dimensions for Beta 4: [ 5 22 14  7 10]
lowest 5 important latent dimensions for Beta 5: [ 2 27 23 24 19]


In [105]:
# Vislualize the module ablation results

for iBeta in range(nBeta):
    module_imp_dict = imp_module_list[iBeta]
    modules = module_imp_dict.keys()
    values = module_imp_dict.values()

    plt.figure(figsize=(16,13))
    plt.bar(range(len(values)), values)
    plt.xticks(range(len(values)), modules)
    plt.xlabel('Module')
    plt.ylabel('Importance')
    plt.title(f'Module Importance for Beta {iBeta+1} (Ratio={ratio})')
    plt.tight_layout()
    #plt.show()
    plt.savefig(os.path.join(save_subdir_module, f'Module_Importance_Beta_{iBeta+1}_Ratio_{ratio}.png'))
    plt.close()

In [111]:
# Visualize the edge ablation results
for iBeta, list in enumerate(all_imp_rand_edge):
    print(f"--- Beta {iBeta+1} Random Edge Ablation Relative Importance ---")
    for dr, imp in zip(drop_ratios, list):
        print(f"Drop Ratio: {dr}, Relative Importance: {imp}")
        
        
print("\n--- Node-based Edge Ablation Importance ---")
top5_counter = Counter()
bottom5_counter = Counter()

for iBeta in range(nBeta):
    value = all_imp_node_edge[iBeta]
    plt.figure(figsize=(16,13))
    plt.bar(range(len(value)), value)
    plt.xticks(range(len(value)), range(1, len(value)+1))
    plt.xlabel('Channel Index')
    plt.ylabel('Importance')
    plt.title(f'Node-based Edge Ablation Importance for Beta {iBeta+1}')
    plt.tight_layout()
    #plt.show()
    plt.savefig(os.path.join(save_subdir_edge, f'Node_Edge_Importance_Beta_{iBeta+1}.png'))
    plt.close()
    
    sorted_values = np.argsort(value)
    high_ind = sorted_values[-5:]
    low_ind = sorted_values[:5]
    top5_counter.update(high_ind)
    bottom5_counter.update(low_ind)
    
    high_vals = value[high_ind]
    low_vals = value[low_ind]
    print(f'Highest 5 important channels for Beta {iBeta+1}:')
    for ind, val in zip(high_ind, high_vals):
        print(f'Channel {ind+1} | Importance: {val:.6f}')
    print(f'Lowest 5 important channels for Beta {iBeta+1}:')
    for ind, val in zip(low_ind, low_vals):
        print(f'Channel {ind+1} | Importance: {val:.6f}')

top5_sorted = sorted(top5_counter.items(), key=lambda x: x[1], reverse=True)
bottom5_sorted = sorted(bottom5_counter.items(), key=lambda x: x[1], reverse=True)
print("\n--- Most important channels across all Betas ---")
for ind, count in top5_sorted:
    print(f'Channel {ind+1} | Count in Top 5: {count}')
print("\n--- Least important channels across all Betas ---")
for ind, count in bottom5_sorted:
    print(f'Channel {ind+1} | Count in Bottom 5: {count}')

--- Beta 1 Random Edge Ablation Relative Importance ---
Drop Ratio: 0.1, Relative Importance: 0.01274361371123829
Drop Ratio: 0.2, Relative Importance: 0.052862541599235126
Drop Ratio: 0.3, Relative Importance: 0.119014362149056
Drop Ratio: 0.4, Relative Importance: 0.24051603848006323
Drop Ratio: 0.5, Relative Importance: 0.44857139357864767
Drop Ratio: 0.6, Relative Importance: 0.8574574997160506
Drop Ratio: 0.7, Relative Importance: 1.687515832284541
Drop Ratio: 0.8, Relative Importance: 2.1114712738508574
Drop Ratio: 0.9, Relative Importance: 3.4974194816185933
Drop Ratio: 1.0, Relative Importance: 3.7620232115512917
--- Beta 2 Random Edge Ablation Relative Importance ---
Drop Ratio: 0.1, Relative Importance: 0.013741325092458715
Drop Ratio: 0.2, Relative Importance: 0.0724824051856708
Drop Ratio: 0.3, Relative Importance: 0.06736292216290493
Drop Ratio: 0.4, Relative Importance: 0.13040633314709352
Drop Ratio: 0.5, Relative Importance: 0.27569662084599217
Drop Ratio: 0.6, Relative