# Occlusion Analysis

In [None]:
import utils
import torch
# ======= Graph Support Computation ======= #
def compute_graph_supports(adj_mat, filter_type='dual_random_walk', device=None):
    """
    Build the list of graph support matrices fed to the DCRNN layers.

    Args:
        adj_mat (np.ndarray or scipy sparse): Adjacency matrix.
        filter_type (str): 'laplacian' yields a single scaled-Laplacian
            support; 'dual_random_walk' yields two supports, the transposed
            random-walk matrices of adj_mat and of adj_mat.T (in that order).
        device (torch.device): Device the returned tensors are moved to.

    Returns:
        List[torch.FloatTensor]: Dense support matrices, in the order built.

    Raises:
        ValueError: If filter_type is neither 'laplacian' nor 'dual_random_walk'.
    """
    if filter_type == 'laplacian':
        sparse_supports = [utils.calculate_scaled_laplacian(adj_mat, lambda_max=None)]
    elif filter_type == 'dual_random_walk':
        forward_rw = utils.calculate_random_walk_matrix(adj_mat).T
        backward_rw = utils.calculate_random_walk_matrix(adj_mat.T).T
        sparse_supports = [forward_rw, backward_rw]
    else:
        raise ValueError(f'Unsupported filter type: {filter_type}')

    dense_supports = []
    for mat in sparse_supports:
        # .toarray(): the utils helpers presumably return scipy sparse matrices
        # (this call requires it) — densify before wrapping as a tensor.
        dense_supports.append(torch.FloatTensor(mat.toarray()).to(device))
    return dense_supports

# ======= Device Setup ======= #
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


In [None]:
import pickle
import numpy as np

# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
def _load_pickle(path):
    """Return the object unpickled from the file at *path*."""
    with open(path, 'rb') as fh:
        return pickle.load(fh)


def _flatten(nested):
    """Flatten exactly one level of nesting: [[a, b], [c]] -> [a, b, c]."""
    return [arr for group in nested for arr in group]


# Each pickle holds a nested list (list of lists of arrays).
adjfn = _load_pickle('af12.pkl')
adjgn = _load_pickle('ag12.pkl')
adjtn = _load_pickle('at12.pkl')
adjtc = _load_pickle('atc12.pkl')
adjcp = _load_pickle('ac12.pkl')
adjsp = _load_pickle('as12.pkl')
adjab = _load_pickle('aa12.pkl')

afn, agn, acp, asp, atn, atc, aab = map(
    _flatten, (adjfn, adjgn, adjcp, adjsp, adjtn, adjtc, adjab)
)

# Stack along axis 0 in the order fn, gn, ab, tn, tc, cp, sp
# (note: this differs from the load order above).
all_adjs = np.concatenate(
    [np.array(group, dtype=np.float32) for group in (afn, agn, aab, atn, atc, acp, asp)],
    axis=0,
)

# One list of dual-random-walk support tensors per adjacency matrix.
all_supports = [
    compute_graph_supports(np.squeeze(adj), filter_type='dual_random_walk', device=device)
    for adj in all_adjs
]

In [None]:
from collections import defaultdict

def _true_class_prob(model, x, seq_len, supports, cls):
    """Return the softmax probability the model assigns to class `cls` for the single sample in `x`."""
    logits = model(x, seq_len, supports)
    return torch.softmax(logits, dim=1)[0, cls].item()


def _nonzero_edges(support, num_nodes=None):
    """Return (src, tgt) int pairs of the nonzero entries of a 2-D support, in row-major order.

    When num_nodes is given, only the leading num_nodes x num_nodes block is scanned.
    """
    if num_nodes is not None:
        support = support[:num_nodes, :num_nodes]
    return [(src, tgt) for src, tgt in torch.nonzero(support).tolist()]


def classwise_occlusion_analysis(
    model, data_loader, device, supports_orig, filter_type, num_classes,
    num_nodes=None,
):
    """
    Per-class occlusion importance of graph nodes and edges.

    For every sample, the probability of its true class is recorded. The
    probability is then recomputed after (a) zeroing one node's input
    features and (b) zeroing one nonzero entry of one support matrix. The
    drop (original minus occluded probability) is summed per true class and
    averaged over that class's sample count.

    Args:
        model: torch module called as model(x, seq_len, supports) and
            returning logits of shape (batch, num_classes).
        data_loader: iterable of (batch_x, batch_y, batch_seq_len) tensors.
            Node occlusion indexes axis 2 of batch_x, presumably
            (batch, seq_len, num_nodes, feat) — TODO confirm.
        device: torch device for the model, data and supports.
        supports_orig: list of 2-D support tensors shared by every sample.
        filter_type: unused; kept for interface compatibility.
        num_classes: number of classes (labels must lie in range(num_classes)).
        num_nodes: optional number of leading nodes to occlude and of leading
            rows/columns of each support to scan. Defaults to the node axis of
            the input and the full support matrices.

    Returns:
        (node_importance, edge_importance): dicts mapping class -> defaultdict.
        Node keys are node indices. Edge keys are (src, tgt) pairs; drops from
        different supports at the same (src, tgt) are summed under one key.
    """
    model.eval()
    model.to(device)

    supports_orig = [s.to(device) for s in supports_orig]
    # The supports never change, so find their nonzero entries only once.
    edge_lists = [_nonzero_edges(s, num_nodes) for s in supports_orig]

    node_importance = {cls: defaultdict(float) for cls in range(num_classes)}
    edge_importance = {cls: defaultdict(float) for cls in range(num_classes)}
    class_counts = defaultdict(int)

    with torch.no_grad():
        for batch_x, batch_y, batch_seq_len in data_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            batch_seq_len = batch_seq_len.to(device)
            n_occlude = batch_x.size(2) if num_nodes is None else num_nodes

            for i in range(batch_x.size(0)):
                # Baseline and occluded runs all use batch size 1, so the
                # probabilities are directly comparable.
                x_sample = batch_x[i:i + 1]
                sl = batch_seq_len[i:i + 1]
                y_true = int(batch_y[i].item())
                class_counts[y_true] += 1

                orig_prob = _true_class_prob(model, x_sample, sl, supports_orig, y_true)

                # --- Node occlusion: zero one node's features ---
                for node in range(n_occlude):
                    x_occ = x_sample.clone()
                    x_occ[:, :, node, :] = 0
                    occ_prob = _true_class_prob(model, x_occ, sl, supports_orig, y_true)
                    node_importance[y_true][node] += orig_prob - occ_prob

                # --- Edge occlusion: zero one nonzero support entry ---
                for s_idx, edges in enumerate(edge_lists):
                    for src, tgt in edges:
                        s_occ = list(supports_orig)
                        s_occ[s_idx] = supports_orig[s_idx].clone()
                        s_occ[s_idx][src, tgt] = 0
                        occ_prob = _true_class_prob(model, x_sample, sl, s_occ, y_true)
                        edge_importance[y_true][(src, tgt)] += orig_prob - occ_prob

    # Average the summed drops over the number of samples of each class.
    for cls in range(num_classes):
        count = class_counts[cls]
        if count == 0:
            continue
        for scores in (node_importance[cls], edge_importance[cls]):
            for key in scores:
                scores[key] /= count

    return node_importance, edge_importance

In [None]:
# Occlusion analysis on the test set. classwise_occlusion_analysis itself
# switches the model to eval mode and moves it to `device`.
#
# all_supports holds one list of support tensors per adjacency matrix, while
# the analysis expects a single list of support tensors shared by all samples.
# NOTE(review): index 0 is a placeholder choice — TODO confirm which graph's
# supports the model should be probed with.
SUPPORT_INDEX = 0

node_imp, edge_imp = classwise_occlusion_analysis(
    model=pre_model,
    data_loader=test_loader,
    device=device,
    supports_orig=all_supports[SUPPORT_INDEX],
    filter_type='dual_random_walk',
    num_classes=7,
)

# Expose the results under the names the visualization cell reads.
node_importance, edge_importance = node_imp, edge_imp

In [None]:
import numpy as np
import os
import sys
import pickle
import networkx as nx
import collections
import matplotlib
import matplotlib.pyplot as plt
from scipy.stats import rankdata


def get_spectral_graph_positions(adj_mx=None):
    """
    Return 2-D drawing positions for the 19 EEG electrode nodes (indices 0..18).

    Args:
        adj_mx (array-like, optional): 19 x 19 adjacency matrix indexable as
            adj_mx[i][j]. When given, an undirected edge i-j (i != j) is added
            wherever adj_mx[i][j] > 0, and the layout is computed on that graph.
            When omitted, every pair of distinct nodes is connected.

    Returns:
        dict[int, tuple]: node index -> (x, y). This is networkx's spectral
        layout with each point mapped (x, y) -> (y, -x).

    NOTE(review): on the default complete graph, all non-trivial Laplacian
    eigenvalues coincide, so the spectral coordinates are an arbitrary
    eigenbasis rather than a scalp-shaped arrangement. Pass adj_mx (e.g. a
    distance-based electrode graph) for a geometrically meaningful layout.
    """
    num_nodes = 19

    eeg_viz = nx.Graph()
    eeg_viz.add_nodes_from(range(num_nodes))

    for i in range(num_nodes):
        for j in range(num_nodes):
            if i == j:
                continue  # self-edges are not included in the visualization
            if adj_mx is None or adj_mx[i][j] > 0:
                eeg_viz.add_edge(i, j)

    pos = nx.spectral_layout(eeg_viz)
    # Swap and negate the axes: (x, y) -> (y, -x).
    return {node: (y, -x) for node, (x, y) in pos.items()}


# # def draw_graph_weighted_edge(
# #         adj_mx,
# #         node_id_dict,
# #         pos_spec,
# #         is_directed,
# #         title='',
# #         save_dir=None,
# #         fig_size=(
# #             12,
# #             8),
# #     node_color='Red',
# #     font_size=20,
# #         plot_colorbar=False):
# def draw_graph_weighted_edge(
#         node_id_dict,
#         pos_spec,
#         is_directed,
#         title='',
#         save_dir=None,
#         fig_size=(
#             12,
#             8),
#     node_color='Red',
#     font_size=20,
#         plot_colorbar=False):
#     """
#     Draw a graph with weighted edges
#     Args:
#         adj_mx: Adjacency matrix for the graph, shape (num_nodes, num_nodes)
#         node_id_dict: dict, key is node name, value is node index
#         pos_spec: Graph node position specs from function get_spectral_graph_positions
#         is_directed: If True, draw directed graphs
#         title: str, title of the figure
#         save_dir: Dir to save the plot
#         fig_size: figure size

#     """
#     eeg_viz = nx.DiGraph() if is_directed else nx.Graph()
#     node_id_label = collections.defaultdict()

#     for i in range(19):
#         eeg_viz.add_node(i)

#     for k, v in node_id_dict.items():
#         node_id_label[v] = k

#     # Add edges
#     for i in range(19):
#         for j in range(19):  # since it's now directed
#             if i != j:
#                 eeg_viz.add_edge(i, j, weight=adj_mx[i, j])

#     edges, weights = zip(*nx.get_edge_attributes(eeg_viz, 'weight').items())

#     # Change the color scales below
#     k = 3
#     cmap = plt.cm.Greys(np.linspace(0, 1, (k + 1) * len(weights)))
#     cmap = matplotlib.colors.ListedColormap(cmap[len(weights):-1:(k - 1)])

#     plt.figure(figsize=fig_size)
#     nx.draw_networkx(eeg_viz, pos_spec, labels=node_id_label, with_labels=True,
#                      edgelist=edges, edge_color=rankdata(weights),
#                      width=fig_size[1] / 2, edge_cmap=cmap, font_weight='bold',
#                      node_color=node_color,
#                      node_size=250 * (fig_size[0] + fig_size[1]),
#                      font_color='white',
#                      font_size=font_size)
#     plt.title(title, fontsize=font_size)
#     plt.axis('off')
#     if plot_colorbar:
#         sm = plt.cm.ScalarMappable(
#             cmap=cmap, norm=plt.Normalize(
#                 vmin=0, vmax=1))
#         sm.set_array([])
#         plt.colorbar(sm)
#     plt.tight_layout()
#     if save_dir is not None:
#         plt.savefig(save_dir, dpi=300)

#     plt.show()

In [None]:
def visualize_importance_graph(node_importance, edge_importance, class_label='Class', figsize=(12, 8)):
    """
    Draw the 19-electrode graph with nodes and edges coloured by occlusion importance.

    Args:
        node_importance (dict): node index -> importance; missing nodes count as 0.
        edge_importance (dict): directed (src, tgt) -> importance; missing pairs
            count as 0. The drawn graph is undirected, so each undirected edge
            is coloured by the SUM of its two directed scores (i, j) and (j, i).
        class_label (str): Text appended to the plot title.
        figsize (tuple): Matplotlib figure size.
    """
    import matplotlib.pyplot as plt
    import networkx as nx

    eegnodes = {'EEG FP1': 0,
                'EEG FP2': 1,
                'EEG F3': 2,
                'EEG F4': 3,
                'EEG C3': 4,
                'EEG C4': 5,
                'EEG P3': 6,
                'EEG P4': 7,
                'EEG O1': 8,
                'EEG O2': 9,
                'EEG F7': 10,
                'EEG F8': 11,
                'EEG T3': 12,
                'EEG T4': 13,
                'EEG T5': 14,
                'EEG T6': 15,
                'EEG FZ': 16,
                'EEG CZ': 17,
                'EEG PZ': 18}
    # index -> short label, e.g. 0 -> 'FP1'
    node_labels = {idx: name.split(' ')[-1] for name, idx in eegnodes.items()}
    pos_spec = get_spectral_graph_positions()

    G = nx.Graph()
    G.add_nodes_from(range(19))
    G.add_edges_from((i, j) for i in range(19) for j in range(19) if i != j)

    node_colors = np.array([node_importance.get(i, 0.0) for i in range(19)])
    # G is undirected, so each pair appears once in G.edges(). The occlusion
    # scores are keyed by directed pairs, so sum both directions instead of
    # silently dropping one of them.
    edge_colors = np.array([
        edge_importance.get((i, j), 0.0) + edge_importance.get((j, i), 0.0)
        for i, j in G.edges()
    ])

    # Colour ranges span the observed values.
    node_vmin, node_vmax = node_colors.min(), node_colors.max()
    edge_vmin, edge_vmax = edge_colors.min(), edge_colors.max()

    fig, ax = plt.subplots(figsize=figsize)
    nx.draw_networkx_nodes(
        G, pos_spec, node_color=node_colors, cmap='Reds',
        node_size=1200, ax=ax, vmin=node_vmin, vmax=node_vmax
    )
    nx.draw_networkx_edges(
        G, pos_spec, edge_color=edge_colors, edge_cmap=plt.cm.Blues,
        width=3, ax=ax, edge_vmin=edge_vmin, edge_vmax=edge_vmax
    )
    nx.draw_networkx_labels(G, pos_spec, labels=node_labels, font_color='black', ax=ax, font_size=12)
    ax.set_title(f"Importance Visualization - {class_label}")
    ax.axis('off')

    # Colorbars built from mappables matching the node/edge colour scales.
    sm_nodes = plt.cm.ScalarMappable(cmap='Reds', norm=plt.Normalize(vmin=node_vmin, vmax=node_vmax))
    sm_nodes.set_array([])
    plt.colorbar(sm_nodes, ax=ax, orientation='vertical', label='Node Importance')

    sm_edges = plt.cm.ScalarMappable(cmap='Blues', norm=plt.Normalize(vmin=edge_vmin, vmax=edge_vmax))
    sm_edges.set_array([])
    plt.colorbar(sm_edges, ax=ax, orientation='vertical', label='Edge Importance')

    plt.tight_layout()
    plt.show()

In [None]:
# Plot the occlusion importances for one class. node_imp / edge_imp are the
# results returned by classwise_occlusion_analysis in the analysis cell.
class_id = 0
visualize_importance_graph(node_imp[class_id], edge_imp[class_id], class_label=f"Class {class_id}")