In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from braindecode.models import ShallowFBCSPNet, EEGConformer
import pandas as pd
from collections import OrderedDict

In [2]:
def linear_kernel(X):
    """Computes the linear kernel matrix for X."""
    return X @ X.T  # Dot product

def centering_matrix(K):
    """Centers the kernel matrix K."""
    n = K.shape[0]
    H = np.eye(n,dtype=K.dtype) - (1 / n) * np.ones((n, n), dtype=K.dtype)
    return H # Equivalent to HKH

def compute_hsic(X, Y, kernel_X = linear_kernel, kernel_Y = linear_kernel):
    """
    Computes the Hilbert-Schmidt Independence Criterion (HSIC).
    
    Parameters:
    - X: (n_samples, n_features_X) numpy array
    - Y: (n_samples, n_features_Y) numpy array
    - kernel_X: function to compute the kernel matrix for X
    - kernel_Y: function to compute the kernel matrix for Y
    
    Returns:
    - HSIC value (float)
    """
    X = X.to(torch.float32)
    Y = Y.to(torch.float32)
    
    K = kernel_X(X)
    L = kernel_Y(Y)
    
    K = K.cpu().numpy()
    L = L.cpu().numpy()
    
    H = centering_matrix(K)
   
    
    Kxy_centered = K @ H @ L @ H
    
    hsic_value = np.trace(Kxy_centered) / ((X.shape[0] - 1) ** 2)
    return hsic_value.item()
  
def compute_CKA(X,Y,kernel_X = linear_kernel,kernel_Y = linear_kernel):
  """
  compute CKA between two X,Y activations
  
  Parameters:
  - X: (n_samples, x_features)
  - Y: (n_samples, y_features)
  - kernel_X: kernel for X
  - kernel_Y: kernel for Y  
  """
  HSIC_KL = compute_hsic(X,Y,kernel_X,kernel_Y) 
  HSIC_KK = compute_hsic(X,X,kernel_X,kernel_X)
  HSIC_LL = compute_hsic(Y,Y,kernel_Y, kernel_Y)
  
  return HSIC_KL/(np.sqrt(HSIC_KK * HSIC_LL))


In [48]:
# Define model parameters
in_chans = 22
n_classes = 4
n_channels = 22
input_window_samples = 1000
# Load two models for comparison
model= torch.load("conformer_model.pth",weights_only = False)
model2 =ShallowFBCSPNet(in_chans, n_classes, input_window_samples)





In [37]:
print(model)

Layer (type (var_name):depth-idx)                            Input Shape               Output Shape              Param #                   Kernel Shape
EEGConformer (EEGConformer)                                  [1, 22, 1000]             [1, 4]                    --                        --
├─_PatchEmbedding (patch_embedding): 1-1                     [1, 1, 22, 1000]          [1, 61, 40]               --                        --
│    └─Sequential (shallownet): 2-1                          [1, 1, 22, 1000]          [1, 40, 1, 61]            --                        --
│    │    └─Conv2d (0): 3-1                                  [1, 1, 22, 1000]          [1, 40, 22, 976]          1,040                     [1, 25]
│    │    └─Conv2d (1): 3-2                                  [1, 40, 22, 976]          [1, 40, 1, 976]           35,240                    [22, 1]
│    │    └─BatchNorm2d (2): 3-3                             [1, 40, 1, 976]           [1, 40, 1, 976]           80             

In [38]:
def extract_model_activations(model, input_tensor):
    activations = OrderedDict()

    def get_activation(name):
        def hook(model, input, output):
            activations[name] = output.detach()
        return hook

    # Register hooks
    for name, layer in model.named_modules():
        #print(layer)
        layer.register_forward_hook(get_activation(name))


    # Run forward pass AFTER registering hooks
    model.eval()
   
    _ = model(input_tensor)  

    return activations  # Return collected activations

In [6]:
import numpy as np
from braindecode.datasets import MOABBDataset

subject_id = [1,2,3,4]
dataset = MOABBDataset(dataset_name="BNCI2014_001", subject_ids=[1,2,3,4,5,6,7,8,9])



from braindecode.preprocessing import (
    exponential_moving_standardize,
    preprocess,
    Preprocessor,
)

low_cut_hz = 4.0  # low cut frequency for filtering
high_cut_hz = 38.0  # high cut frequency for filtering
# Parameters for exponential moving standardization
factor_new = 1e-3
init_block_size = 1000

transforms = [
    Preprocessor("pick_types", eeg=True, meg=False, stim=False),  # Keep EEG sensors
    Preprocessor(
        lambda data, factor: np.multiply(data, factor),  # Convert from V to uV
        factor=1e6,
    ),
    Preprocessor("filter", l_freq=low_cut_hz, h_freq=high_cut_hz),  # Bandpass filter
    Preprocessor(
        exponential_moving_standardize,  # Exponential moving standardization
        factor_new=factor_new,
        init_block_size=init_block_size,
    ),
]

# Transform the data
preprocess(dataset, transforms, n_jobs=-1)




  warn('Preprocessing choices with lambda functions cannot be saved.')


<braindecode.datasets.moabb.MOABBDataset at 0x1611e30ac60>

In [10]:
from braindecode.preprocessing import create_windows_from_events

trial_start_offset_seconds = -0.5
trial_stop_offset_seconds = -0.5
# Extract sampling frequency, check that they are same in all datasets
sfreq = dataset.datasets[0].raw.info["sfreq"]
assert all([ds.raw.info["sfreq"] == sfreq for ds in dataset.datasets])
# Calculate the trial start offset in samples.
trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)
trial_stop_offset_samples = int(trial_stop_offset_seconds * sfreq)

# Create windows using braindecode function for this. It needs parameters to define how
# trials should be used.
windows_dataset = create_windows_from_events(
    dataset,
    trial_start_offset_samples=trial_start_offset_samples,
    trial_stop_offset_samples=trial_stop_offset_samples,
    preload=True,
)

Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']

In [57]:
splitted = windows_dataset.split("session")
train_set = splitted['0train']  # Session train
test_set = splitted['1test'] 

In [58]:
from torch.utils.data import DataLoader
test_loader = DataLoader(test_set, batch_size=64)
dummy = next(iter(test_loader))
dummy_input = dummy[0]

In [59]:
#dummy_input = torch.randn(64, in_chans, input_window_samples)

model_1_activations = extract_model_activations(model,dummy_input)
model_2_activations = extract_model_activations(model2,dummy_input)


In [60]:
def visualize_comparison(model_1_activations, model_2_activations):
    # Extract the activation tensors from both models
    act_1 = model_1_activations.get('conv_time_spat')
    act_2 = model_2_activations.get('conv_time_spat')

    if act_1 is not None and act_2 is not None:
        # Average activation across all channels
        avg_act_1 = torch.mean(act_1, dim=1).squeeze(0).cpu().numpy()  # Average across channels (C dimension)
        avg_act_2 = torch.mean(act_2, dim=1).squeeze(0).cpu().numpy()
        

        # Plot side-by-side comparison of average activations
        fig, axes = plt.subplots(1, 2, figsize=(12, 6))

        # Plot for model 1
        axes[0].imshow(avg_act_1, cmap='viridis')
        axes[0].axis('off')
        axes[0].set_title("Model 1 - Average Activation")

        # Plot for model 2
        axes[1].imshow(avg_act_2, cmap='viridis')
        axes[1].axis('off')
        axes[1].set_title("Model 2 - Average Activation")

        plt.tight_layout()
        plt.show()

        # Plot difference between activations (model 1 - model 2)
        activation_diff = avg_act_1 - avg_act_2
        plt.imshow(activation_diff, cmap='seismic', vmin=-np.max(np.abs(activation_diff)), vmax=np.max(np.abs(activation_diff)))
        plt.colorbar()
        plt.axis('off')
        plt.title("Activation Difference (Model 1 - Model 2)")
        plt.show()

    else:
        print("Could not extract activations for 'conv_time_spat'.")

In [61]:
model_1_activations

OrderedDict([('patch_embedding.shallownet.0',
              tensor([[[[-2.0768e-01, -2.7353e-01, -3.1864e-01,  ...,  8.8347e-01,
                          8.0451e-01,  6.2091e-01],
                        [-3.1832e-01, -4.3382e-01, -5.3684e-01,  ...,  6.2133e-01,
                          5.5530e-01,  4.5094e-01],
                        [-3.4574e-01, -3.6233e-01, -3.3809e-01,  ...,  6.6708e-01,
                          6.3388e-01,  5.5785e-01],
                        ...,
                        [-5.4564e-01, -4.7337e-01, -4.5283e-01,  ..., -7.0529e-01,
                         -2.0815e-01,  2.6597e-01],
                        [-5.9131e-01, -4.2916e-01, -3.7088e-01,  ..., -6.7632e-01,
                         -9.2145e-02,  4.3535e-01],
                        [-5.7815e-01, -3.4144e-01, -2.1975e-01,  ..., -7.3054e-01,
                         -1.6290e-01,  3.8313e-01]],
              
                       [[ 6.8241e-01,  7.8988e-01,  8.8708e-01,  ..., -4.6273e-02,
                

In [62]:
print("Captured Activations for Model 1:")
for layer_name, activation in model_1_activations.items():

    print(f"{layer_name}: {activation.shape}")
    #print(activation)

Captured Activations for Model 1:
patch_embedding.shallownet.0: torch.Size([64, 40, 22, 976])
patch_embedding.shallownet.1: torch.Size([64, 40, 1, 976])
patch_embedding.shallownet.2: torch.Size([64, 40, 1, 976])
patch_embedding.shallownet.3: torch.Size([64, 40, 1, 976])
patch_embedding.shallownet.4: torch.Size([64, 40, 1, 61])
patch_embedding.shallownet.5: torch.Size([64, 40, 1, 61])
patch_embedding.shallownet: torch.Size([64, 40, 1, 61])
patch_embedding.projection.0: torch.Size([64, 40, 1, 61])
patch_embedding.projection.1: torch.Size([64, 61, 40])
patch_embedding.projection: torch.Size([64, 61, 40])
patch_embedding: torch.Size([64, 61, 40])
transformer.0.0.fn.0: torch.Size([64, 61, 40])
transformer.0.0.fn.1.queries: torch.Size([64, 61, 40])
transformer.0.0.fn.1.keys: torch.Size([64, 61, 40])
transformer.0.0.fn.1.values: torch.Size([64, 61, 40])
transformer.0.0.fn.1.att_drop: torch.Size([64, 10, 61, 61])
transformer.0.0.fn.1.projection: torch.Size([64, 61, 40])
transformer.0.0.fn.1: t

In [64]:
print("Captured Activations for Model 2:")
for layer_name, activation in model_2_activations.items():
    
        print(f"{layer_name}: {activation.shape}")
        print(activation)

Captured Activations for Model 2:
ensuredims: torch.Size([64, 22, 1000, 1])
tensor([[[[-1.0934],
          [-1.3873],
          [-1.6339],
          ...,
          [ 0.5334],
          [ 0.1577],
          [-0.3081]],

         [[-0.7164],
          [-1.2801],
          [-1.6891],
          ...,
          [ 0.7532],
          [ 0.2864],
          [-0.4522]],

         [[-0.4254],
          [-1.0408],
          [-1.5866],
          ...,
          [ 0.5883],
          [ 0.1377],
          [-0.4401]],

         ...,

         [[-0.0504],
          [-1.0458],
          [-1.7917],
          ...,
          [ 0.8610],
          [ 0.6860],
          [ 0.1684]],

         [[-0.0465],
          [-1.0722],
          [-1.8145],
          ...,
          [ 0.6980],
          [ 0.6135],
          [ 0.2038]],

         [[ 0.0146],
          [-0.9425],
          [-1.6523],
          ...,
          [ 0.9797],
          [ 0.7577],
          [ 0.1864]]],


        [[[-0.1932],
          [-0.2962],
       

In [None]:
# Use OrderedDict to preserve the order of layer names
layer_names_model1 = list(model_1_activations.keys())
layer_names_model2 = list(model_2_activations.keys())

cka_similarities = OrderedDict()

# Compute CKA for each layer from model1 and model2
for layer1 in layer_names_model1:
    for layer2 in layer_names_model2:
        activations_x = model_1_activations[layer1].reshape(model_1_activations[layer1].shape[0], -1)
        activations_y = model_2_activations[layer2].reshape(model_2_activations[layer2].shape[0], -1)
        cka_score = compute_CKA(activations_x, activations_y)
        print(f"inner layer {layer1} done")
        
        cka_similarities[(layer1, layer2)] = cka_score
    print(f"layer {layer1} done")

layer_names = list(OrderedDict.fromkeys(layer_names_model1 + layer_names_model2))

# Initialize a square matrix for the CKA similarities
n_layers = len(layer_names)

matrix = np.zeros((n_layers, n_layers))

# Fill the matrix with the CKA similarity values
for (layer1, layer2), similarity in cka_similarities.items():
    i = layer_names.index(layer1)
    j = layer_names.index(layer2)
    matrix[i, j] = similarity
    matrix[j, i] = similarity 

# Create a DataFrame for better visualization
df = pd.DataFrame(matrix, index=layer_names, columns=layer_names)

# Plot the heatmap
plt.figure(figsize=(10, 8))
sns.heatmap(df, annot=False, cmap='magma', fmt='.2f', square=True, linewidths=0.5, cbar=True)
plt.title('CKA Similarity Heatmap')
plt.xlabel('Layer')
plt.ylabel('Layer')
plt.show()


inner layer patch_embedding.shallownet.0 done
inner layer patch_embedding.shallownet.0 done
inner layer patch_embedding.shallownet.0 done
inner layer patch_embedding.shallownet.0 done
inner layer patch_embedding.shallownet.0 done
inner layer patch_embedding.shallownet.0 done
inner layer patch_embedding.shallownet.0 done
inner layer patch_embedding.shallownet.0 done
inner layer patch_embedding.shallownet.0 done
inner layer patch_embedding.shallownet.0 done
inner layer patch_embedding.shallownet.0 done
inner layer patch_embedding.shallownet.0 done
inner layer patch_embedding.shallownet.0 done
layer patch_embedding.shallownet.0 done
inner layer patch_embedding.shallownet.1 done
inner layer patch_embedding.shallownet.1 done
inner layer patch_embedding.shallownet.1 done
inner layer patch_embedding.shallownet.1 done
inner layer patch_embedding.shallownet.1 done
inner layer patch_embedding.shallownet.1 done
inner layer patch_embedding.shallownet.1 done
inner layer patch_embedding.shallownet.1

IndexError: index 145 is out of bounds for axis 1 with size 13