In [2]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from shallow_fbcsp import ShallowFBCSPNet
import pandas as pd
from collections import OrderedDict

In [3]:
def linear_kernel(X):
    """Computes the linear kernel matrix for X."""
    return torch.matmul(X,X.T)  # Dot product

def centering_matrix(K):
    """Centers the kernel matrix K."""
    n = K.shape[0]
    H = np.eye(n,dtype=K.dtype) - (1 / n) * np.ones((n, n), dtype=K.dtype)
    return H # Equivalent to HKH

def compute_hsic(X, Y, kernel_X = linear_kernel, kernel_Y = linear_kernel):
    """
    Computes the Hilbert-Schmidt Independence Criterion (HSIC).
    
    Parameters:
    - X: (n_samples, n_features_X) numpy array
    - Y: (n_samples, n_features_Y) numpy array
    - kernel_X: function to compute the kernel matrix for X
    - kernel_Y: function to compute the kernel matrix for Y
    
    Returns:
    - HSIC value (float)
    """
    X = X.to(torch.float32)
    Y = Y.to(torch.float32)
    
    K = kernel_X(X)
    L = kernel_Y(Y)
    
    K = K.cpu().numpy()
    L = L.cpu().numpy()
    
    H = centering_matrix(K)

    Kxy_centered = K @ H @ L @ H
    
    hsic_value = np.trace(Kxy_centered) / ((X.shape[0] - 1) ** 2)
    return hsic_value.item()
  
def compute_CKA(X,Y,kernel_X = linear_kernel,kernel_Y = linear_kernel):
  """
  compute CKA between two X,Y activations
  
  Parameters:
  - X: (n_samples, x_features)
  - Y: (n_samples, y_features)
  - kernel_X: kernel for X
  - kernel_Y: kernel for Y  
  """
  HSIC_KL = compute_hsic(X,Y,kernel_X,kernel_Y) 
  HSIC_KK = compute_hsic(X,X,kernel_X,kernel_X)
  HSIC_LL = compute_hsic(Y,Y,kernel_Y, kernel_Y)
  
  return HSIC_KL/(np.sqrt(HSIC_KK * HSIC_LL))


In [4]:
# Define model parameters
in_chans = 22
n_classes = 4
n_channels = 22
input_window_samples = 1000
# Load two models for comparison
model= torch.load("braindecode_model_temponly_1.pth",weights_only = False,map_location=torch.device('cpu'))
model2 =ShallowFBCSPNet(in_chans, n_classes, input_window_samples)



In [5]:
print(model)

Layer (type (var_name):depth-idx)        Input Shape               Output Shape              Param #                   Kernel Shape
ShallowFBCSPNet (ShallowFBCSPNet)        [1, 22, 1125]             [1, 4]                    --                        --
├─SafeLog (pool_nonlin_exp): 1-1         [1, 22, 1125]             [1, 22, 1125]             --                        --
├─Ensure4d (ensuredims): 1-2             [1, 22, 1125]             [1, 22, 1125, 1]          --                        --
├─Rearrange (dimshuffle): 1-3            [1, 22, 1125, 1]          [1, 1, 1125, 22]          --                        --
├─Conv2d (conv_time): 1-4                [1, 1, 1125, 22]          [1, 40, 1101, 22]         1,040                     [25, 1]
├─Conv2d (conv_spat): 1-5                [1, 40, 1101, 22]         [1, 40, 1101, 1]          35,200                    [1, 22]
├─BatchNorm2d (bnorm): 1-6               [1, 40, 1101, 1]          [1, 40, 1101, 1]          80                        --
├─Ex

In [6]:
def extract_model_activations(model, input_tensor):
    activations = OrderedDict()

    def get_activation(name):
        def hook(model, input, output):
            activations[name] = output.detach()
        return hook

    # Register hooks
    for name, layer in model.named_modules():
        if(name == "conv_time" or name == "conv_spat"):
            layer.register_forward_hook(get_activation(name))


    # Run forward pass AFTER registering hooks
    model.eval()
   
    _ = model(input_tensor)  

    return activations  # Return collected activations

In [7]:
import os
import torch

def extract_model_activations(model, input_tensor, output_dir, batch_size=128):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    activations = OrderedDict()

    def get_activation(name):
        def hook(model, input, output):
            activations[name] = output.detach()
        return hook

    # Register hooks for specific layers
    for name, layer in model.named_modules():
        if name == "conv_time" or name == "conv_spat":  # Modify as per your layer names
            layer.register_forward_hook(get_activation(name))

    model.eval()

    with torch.no_grad():
        for i in range(0, input_tensor.shape[0], batch_size):
            batch = input_tensor[i:i + batch_size]  # Select current batch
            _ = model(batch)  # Forward pass through the model

            # Save activations after each batch
            for name, activation in activations.items():
                batch_idx = i // batch_size + 1  # This determines the batch number
                print(f"saving: {name}_batch_{batch_idx}.pt")
                torch.save(activation, os.path.join(output_dir, f"{name}_batch_{batch_idx}.pt"))
            
            # Clear activations list after saving
            activations.clear()
            torch.cuda.empty_cache()  # Optional: Clear GPU memory after each batch



In [8]:
import pickle
with open('Datasets/test_set.pkl', 'rb') as f:
    test_set = pickle.load(f)

In [9]:
X = torch.stack([torch.from_numpy(test_set[i][0]) for i in range(len(test_set))])


print(X.shape)  # Verify the tensor shape
print(type(X))  # Should output <class 'torch.Tensor'>


torch.Size([2592, 22, 1125])
<class 'torch.Tensor'>


In [10]:
print(model.conv_time.weight.device)  # Check the device of the conv_time layer
print(next(model.parameters()).device)


cpu
cpu


In [11]:
batch_size = 64
save_every_n_batches = 10

In [None]:
extract_model_activations(model,X,output_dir="Datasets/activations/model1/",batch_size=batch_size)


saving: conv_time_batch_1.pt
saving: conv_spat_batch_1.pt
saving: conv_time_batch_2.pt
saving: conv_spat_batch_2.pt
saving: conv_time_batch_3.pt
saving: conv_spat_batch_3.pt
saving: conv_time_batch_4.pt
saving: conv_spat_batch_4.pt
saving: conv_time_batch_5.pt
saving: conv_spat_batch_5.pt
saving: conv_time_batch_6.pt
saving: conv_spat_batch_6.pt
saving: conv_time_batch_7.pt
saving: conv_spat_batch_7.pt
saving: conv_time_batch_8.pt
saving: conv_spat_batch_8.pt
saving: conv_time_batch_9.pt
saving: conv_spat_batch_9.pt
saving: conv_time_batch_10.pt
saving: conv_spat_batch_10.pt
saving: conv_time_batch_11.pt
saving: conv_spat_batch_11.pt
saving: conv_time_batch_12.pt
saving: conv_spat_batch_12.pt
saving: conv_time_batch_13.pt
saving: conv_spat_batch_13.pt
saving: conv_time_batch_14.pt
saving: conv_spat_batch_14.pt
saving: conv_time_batch_15.pt
saving: conv_spat_batch_15.pt
saving: conv_time_batch_16.pt
saving: conv_spat_batch_16.pt
saving: conv_time_batch_17.pt
saving: conv_spat_batch_17.p

In [None]:
extract_model_activations(model2,X,output_dir="Datasets/activations/model2/",batch_size=batch_size)

saving: conv_time_batch_1.pt
saving: conv_spat_batch_1.pt
saving: conv_time_batch_2.pt
saving: conv_spat_batch_2.pt
saving: conv_time_batch_3.pt
saving: conv_spat_batch_3.pt
saving: conv_time_batch_4.pt
saving: conv_spat_batch_4.pt
saving: conv_time_batch_5.pt
saving: conv_spat_batch_5.pt
saving: conv_time_batch_6.pt
saving: conv_spat_batch_6.pt
saving: conv_time_batch_7.pt
saving: conv_spat_batch_7.pt
saving: conv_time_batch_8.pt
saving: conv_spat_batch_8.pt
saving: conv_time_batch_9.pt
saving: conv_spat_batch_9.pt
saving: conv_time_batch_10.pt
saving: conv_spat_batch_10.pt
saving: conv_time_batch_11.pt
saving: conv_spat_batch_11.pt
saving: conv_time_batch_12.pt
saving: conv_spat_batch_12.pt
saving: conv_time_batch_13.pt
saving: conv_spat_batch_13.pt
saving: conv_time_batch_14.pt
saving: conv_spat_batch_14.pt
saving: conv_time_batch_15.pt
saving: conv_spat_batch_15.pt
saving: conv_time_batch_16.pt
saving: conv_spat_batch_16.pt
saving: conv_time_batch_17.pt
saving: conv_spat_batch_17.p

In [16]:
len = torch.load(f"Datasets/activations/model1/conv_time_batch_{3}.pt")
print(len.shape)

torch.Size([64, 40, 1101, 22])


In [17]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from collections import OrderedDict

# Use OrderedDict to preserve the order of layer names
layer_names_model1 = ['conv_time', 'conv_spat']
layer_names_model2 = layer_names_model1


print(layer_names_model1)
print(layer_names_model2)

['conv_time', 'conv_spat']
['conv_time', 'conv_spat']


In [21]:
import math
batch_nr = X.shape[0] /batch_size

batch_nr

40.5

In [23]:
cka_accumulated_scores = { (layer1, layer2): 0 for layer1 in layer_names_model1 for layer2 in layer_names_model2 }
model_1_activations = OrderedDict()
model_2_activations = OrderedDict()
acum = 0
l1 = ("conv_time","conv_time")
sum = 0.0
# Compute CKA for each layer from all activations
for batch_idx in range(1,math.ceil(batch_nr)):
    # Get activations for the current batch
    model_1_activations[layer_names_model1[0]] = torch.load(f"Datasets/activations/model1/{layer_names_model1[0]}_batch_{batch_idx}.pt")
    model_1_activations[layer_names_model1[1]] = torch.load(f"Datasets/activations/model1/{layer_names_model1[1]}_batch_{batch_idx}.pt")
    model_2_activations[layer_names_model1[0]] = torch.load(f"Datasets/activations/model2/{layer_names_model1[0]}_batch_{batch_idx}.pt")
    model_2_activations[layer_names_model1[1]] = torch.load(f"Datasets/activations/model2/{layer_names_model1[1]}_batch_{batch_idx}.pt")
    for layer1 in layer_names_model1:
        print(f"Processing layer {layer1} (batch {batch_idx})")
        for layer2 in layer_names_model2:
            # Reshape activations to 2D for CKA computation
            activations_x = model_1_activations[layer1].reshape(model_1_activations[layer1].shape[0], -1)
            activations_y = model_2_activations[layer2].reshape(model_2_activations[layer2].shape[0], -1)

            # Compute CKA score for the current layer pair
            cka_score = compute_CKA(activations_x, activations_y)
            print("cka score:", cka_score)
            if(layer1,layer2) == l1:
                sum += cka_score
            # Accumulate the CKA score for the current layer pair
            cka_accumulated_scores[(layer1, layer2)] += cka_score
            if (layer1,layer2) == l1:
                acum +=1
# Average the CKA scores across all batches
for layer_pair in cka_accumulated_scores:
    cka_accumulated_scores[layer_pair] = cka_accumulated_scores[layer_pair]/acum



Processing layer conv_time (batch 1)
cka score: 0.9981855898709294
cka score: 0.9946553066113416
Processing layer conv_spat (batch 1)
cka score: 0.9729664115123107
cka score: 0.9702039629189194
Processing layer conv_time (batch 2)
cka score: 0.9981085299251141
cka score: 0.9946001746610357
Processing layer conv_spat (batch 2)
cka score: 0.9730804609726298
cka score: 0.9712006789976095
Processing layer conv_time (batch 3)
cka score: 0.9979505645343671
cka score: 0.9952355499042429
Processing layer conv_spat (batch 3)
cka score: 0.9748745379927689
cka score: 0.973738462854181
Processing layer conv_time (batch 4)
cka score: 0.9979191137858031
cka score: 0.9943232442062301
Processing layer conv_spat (batch 4)
cka score: 0.9751846802944475
cka score: 0.9726522147892205
Processing layer conv_time (batch 5)
cka score: 0.9971241933303283
cka score: 0.9941771775729712
Processing layer conv_spat (batch 5)
cka score: 0.977711197094353
cka score: 0.9785039565885231
Processing layer conv_time (batc

In [185]:
model_1_kernels = OrderedDict()
model_2_kernels = OrderedDict()

# Compute CKA for each layer from all activations
for layer in layer_names_model1:
    activations_list = []
    for batch_idx in range(1,math.ceil(batch_nr)):
    # Get activations for the current batch
        batch_activations = torch.load(f"Datasets/activations/model1/{layer_names_model1[0]}_batch_{batch_idx}.pt")
        print(batch_activations.shape)
        activations_list.append(batch_activations)
        print(len(activations_list[0]))
    
    # Concatenate the activations across all batches along the 0th dimension (samples)
    model_1_kernels[layer] = torch.cat(activations_list, dim=0)
    print(len(model_1_kernels[layer]))





torch.Size([128, 40, 1101, 22])
128
torch.Size([128, 40, 1101, 22])
128
torch.Size([32, 40, 1101, 22])
128
288
torch.Size([128, 40, 1101, 22])
128
torch.Size([128, 40, 1101, 22])
128
torch.Size([32, 40, 1101, 22])
128
288


In [24]:
# Initialize a square matrix for the CKA similarities
n_layers = len(layer_names_model1)# + len(layer_names_model2)
matrix = np.zeros((n_layers, n_layers))

# Fill the matrix with the average CKA similarity values
for (layer1, layer2), similarity in cka_accumulated_scores.items():
    i = layer_names_model1.index(layer1) if layer1 in layer_names_model1 else len(layer_names_model1) + layer_names_model2.index(layer1)
    j = layer_names_model2.index(layer2) if layer2 in layer_names_model2 else len(layer_names_model2) + layer_names_model1.index(layer2)
    
    matrix[i, j] = similarity
    matrix[j, i] = similarity  # Symmetric matrix

# Create a DataFrame for better visualization
df = pd.DataFrame(matrix, index=layer_names_model1, columns=layer_names_model1)

# Plot the heatmap
plt.figure(figsize=(10, 8))
sns.heatmap(df, annot=True, cmap='magma', fmt='.2f', square=True, linewidths=0.5, cbar=True, vmin=0, vmax=1)
plt.title('Average CKA Similarity Heatmap')
plt.xlabel('Layer')
plt.ylabel('Layer')
plt.show()

TypeError: 'Tensor' object is not callable