In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from shallow_fbcsp import ShallowFBCSPNet
import pandas as pd
from collections import OrderedDict
import math

In [None]:
def linear_kernel(X):
    """Return the linear-kernel Gram matrix of X, i.e. X multiplied by its transpose."""
    gram = X @ X.T  # entry (i, j) is the dot product of rows i and j
    return gram

def centering_matrix(K):
    """
    Double-center a square kernel matrix.

    Computes H @ K @ H with H = I - (1/n) * ones((n, n)), where n = K.shape[0].

    Parameters:
    - K: (n, n) torch.Tensor kernel matrix

    Returns:
    - (n, n) torch.Tensor with the same dtype and device as K
    """
    n = K.shape[0]
    # Build both pieces of H with K's dtype and device: a default torch.eye(n)
    # is float32 on the CPU, which makes the matmul below fail for CUDA or
    # float64 kernels.
    identity = torch.eye(n, dtype=K.dtype, device=K.device)
    ones = torch.ones((n, n), dtype=K.dtype, device=K.device)
    H = identity - (1.0 / n) * ones
    return H @ K @ H  # Centered kernel matrix

def compute_hsic(K_x, K_y):
    """
    Computes the Hilbert-Schmidt Independence Criterion (HSIC) from two
    precomputed kernel matrices.

    Both kernels are centered with ``centering_matrix`` and the value
    trace(Kc_x @ Kc_y) / (n - 1)^2 is returned, where n = K_x.shape[0].

    Parameters:
    - K_x: (n_samples, n_samples) torch.Tensor kernel matrix of the first representation
    - K_y: (n_samples, n_samples) torch.Tensor kernel matrix of the second representation

    Returns:
    - HSIC value as a 0-dim torch.Tensor (call ``.item()`` for a Python float)
    """
    K_x_centered = centering_matrix(K_x)
    K_y_centered = centering_matrix(K_y)
    n = K_x.shape[0]
    # torch.trace keeps the reduction in torch; np.trace would first convert
    # the tensor to a NumPy array, which fails for CUDA or grad-tracking tensors.
    hsic_value = torch.trace(K_x_centered @ K_y_centered) / ((n - 1) ** 2)
    return hsic_value
  
def compute_CKA(K_x,K_y):
    """
    Compute the CKA similarity from two precomputed kernel matrices.

    The score is HSIC(K_x, K_y) / sqrt(HSIC(K_x, K_x) * HSIC(K_y, K_y)),
    with every HSIC term obtained from ``compute_hsic``.

    Parameters:
    - K_x: (n_samples, n_samples) kernel matrix of the first representation
    - K_y: (n_samples, n_samples) kernel matrix of the second representation

    Returns:
    - the CKA score as a plain Python number (via ``.item()``)
    """
    cross_term = compute_hsic(K_x, K_y)
    self_term_x = compute_hsic(K_x, K_x)
    self_term_y = compute_hsic(K_y, K_y)
    # Normalise the cross term by the geometric mean of the two self terms.
    normaliser = math.sqrt(self_term_x * self_term_y)
    return (cross_term / normaliser).item()


In [None]:
# Define model parameters
in_chans = 22
n_classes = 4
# NOTE(review): n_channels duplicates in_chans and is not referenced in the visible cells.
n_channels = 22
input_window_samples = 1125
# Load two models for comparison
# `model` is restored from a saved checkpoint onto the CPU. weights_only=False
# makes torch.load unpickle arbitrary Python objects, so only load checkpoint
# files that come from a trusted source.
model= torch.load("braindecode_model_temponly_1.pth",weights_only = False,map_location=torch.device('cpu'))
# `model2` is constructed fresh here; no saved weights are loaded into it in this cell.
model2 =ShallowFBCSPNet(in_chans, n_classes, input_window_samples)



In [None]:
# Show the structure of the freshly constructed model (model2).
print(model2)

In [None]:
# Show the structure of the model restored from the checkpoint (model).
print(model)

In [None]:
def extract_model_activations(model, input_tensor):
    """
    Run one forward pass and collect the outputs of selected layers.

    Forward hooks are attached to the sub-modules named "conv_time" and
    "conv_spat", the model is put in eval mode, and ``input_tensor`` is passed
    through it once. The hooks are removed again before returning, so repeated
    calls do not accumulate hooks on the model.

    Note: a later cell defines another function with this same name (taking
    ``output_dir`` and ``batch_size``); running that cell replaces this one.

    Parameters:
    - model: torch.nn.Module to inspect
    - input_tensor: input passed to ``model(...)``

    Returns:
    - OrderedDict mapping layer name -> detached output tensor, in the order
      the hooked layers fired during the forward pass
    """
    activations = OrderedDict()

    def get_activation(name):
        def hook(module, inputs, output):
            activations[name] = output.detach()
        return hook

    # Keep the handles so the hooks can be detached once the pass is done.
    handles = [
        layer.register_forward_hook(get_activation(name))
        for name, layer in model.named_modules()
        if name in ("conv_time", "conv_spat")
    ]

    model.eval()
    try:
        # Only the activations are needed, so skip building the autograd graph.
        with torch.no_grad():
            model(input_tensor)
    finally:
        for handle in handles:
            handle.remove()

    return activations  # Return collected activations

In [None]:
import os
import torch

def extract_model_activations(model, input_tensor, output_dir, batch_size=128):
    """
    Run the model over ``input_tensor`` in batches and save hooked activations to disk.

    Forward hooks are attached to the sub-modules named "conv_time" and
    "conv_spat". After each batch, every captured activation is written to
    ``<output_dir>/<layer>_batch_<k>.pt`` with ``torch.save``, where k is the
    1-based batch number. The hooks are removed before returning (also when an
    error occurs), so calling this again on the same model does not stack hooks.

    Parameters:
    - model: torch.nn.Module to inspect (switched to eval mode)
    - input_tensor: tensor sliced along dimension 0 into batches
    - output_dir: directory for the .pt files; created if missing
    - batch_size: number of samples per forward pass

    Returns:
    - None; the results are the files written to ``output_dir``
    """
    os.makedirs(output_dir, exist_ok=True)

    activations = OrderedDict()

    def get_activation(name):
        def hook(module, inputs, output):
            activations[name] = output.detach()
        return hook

    # Register hooks for specific layers and keep the handles for cleanup.
    handles = [
        layer.register_forward_hook(get_activation(name))
        for name, layer in model.named_modules()
        if name in ("conv_time", "conv_spat")  # Modify as per your layer names
    ]

    model.eval()

    try:
        with torch.no_grad():
            for start in range(0, input_tensor.shape[0], batch_size):
                batch = input_tensor[start:start + batch_size]  # Select current batch
                model(batch)  # Forward pass; the hooks fill `activations`

                # 1-based batch number used in the file names.
                batch_idx = start // batch_size + 1
                for name, activation in activations.items():
                    print(f"saving: {name}_batch_{batch_idx}.pt")
                    torch.save(activation, os.path.join(output_dir, f"{name}_batch_{batch_idx}.pt"))

                # Drop the saved tensors before the next batch.
                activations.clear()
                torch.cuda.empty_cache()  # Optional: Clear GPU memory after each batch
    finally:
        for handle in handles:
            handle.remove()



In [None]:
import pickle
# Load the pickled test set from disk.
# SECURITY: pickle.load can execute arbitrary code embedded in the file; only
# open pickle files that were produced by a trusted source.
with open('Datasets/test_set.pkl', 'rb') as f:
    test_set = pickle.load(f)

In [None]:
# Stack the first element of every test_set item into one tensor along a new
# leading dimension. torch.from_numpy requires each test_set[i][0] to be a
# NumPy array, and torch.stack requires them all to have the same shape.
# NOTE(review): presumably test_set[i] is an (input, label, ...) tuple — confirm.
X = torch.stack([torch.from_numpy(test_set[i][0]) for i in range(len(test_set))])

# Number of samples = size of the stacked (leading) dimension.
total_samples = X.shape[0]
print(X.shape)  # Verify the tensor shape
print(type(X))  # Should output <class 'torch.Tensor'>


In [None]:
# Confirm which device the loaded model's parameters live on.
print(model.conv_time.weight.device)  # Check the device of the conv_time layer
print(next(model.parameters()).device)  # Device of the first parameter returned by model.parameters()


In [None]:

# Batch size used when extracting activations; the kernel-building cells below
# rely on the same value to locate the saved batch files.
batch_size = 100
total_samples = X.shape[0]
print(total_samples)

In [None]:
# Save the hooked activations of the checkpointed model, one .pt file per layer and batch.
extract_model_activations(model,X,output_dir="Datasets/activations/model1/",batch_size=batch_size)


In [None]:
# Save the hooked activations of the freshly constructed model into its own directory.
extract_model_activations(model2,X,output_dir="Datasets/activations/model2/",batch_size=batch_size)

In [None]:
# Sanity check: load one saved activation file (batch 3 of conv_time, model 1)
# and print its shape. Despite its name, `length` holds the loaded tensor.
length = torch.load(f"Datasets/activations/model1/conv_time_batch_{3}.pt")
print(length.shape)

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from collections import OrderedDict
# Drop the reference to the input tensor now that activations are on disk.
# NOTE: this makes the cell non-idempotent — running it a second time raises
# NameError because X no longer exists.
del X
# Plain lists keep the layer names in the order given here.
layer_names_model1 = ['conv_time', 'conv_spat']
# Both names refer to the same list object (no copy is made).
layer_names_model2 = layer_names_model1


print(layer_names_model1)
print(layer_names_model2)

In [None]:
import math
# True division: batch_nr is a float and may be fractional when total_samples
# is not a multiple of batch_size. The kernel functions below turn it into a
# loop bound with math.ceil(batch_nr + 1).
batch_nr = total_samples /batch_size

batch_nr

In [None]:

# Compute CKA for each layer from all activations



def compute_kernel_full_highmem(layer, batch_nr, model_number):
    """
    Build the linear kernel of one layer by holding all its activations in memory.

    Every saved batch file for ``layer`` of model ``model_number`` is loaded
    from ``Datasets/activations/``, flattened to (samples, features), and
    concatenated along the sample axis; the kernel is that matrix times its
    transpose.

    Parameters:
    - layer: layer name used in the file names
    - batch_nr: number of batches (may be fractional); files 1 .. ceil(batch_nr) are read
    - model_number: number used in the ``model<k>`` directory name

    Returns:
    - (n_samples, n_samples) tensor
    """

    def load_flat(idx):
        # Progress line is rewritten in place thanks to the leading carriage return.
        print(f"\rLoading batch {idx} for {layer} in model {model_number}", end='', flush=True)
        stored = torch.load(f"Datasets/activations/model{model_number}/{layer}_batch_{idx}.pt")
        return stored.reshape(stored.shape[0], -1)

    # One row per sample across all batches.
    stacked = torch.cat([load_flat(idx) for idx in range(1, math.ceil(batch_nr+1))], dim=0)
    print("\nFinal Activations Shape:", stacked.shape)

    return stacked @ stacked.T

# One kernel matrix per layer, keyed by layer name, kept in layer-list order.
model_1_kernels = OrderedDict()
model_2_kernels = OrderedDict()

# The third argument selects the "model1" / "model2" activation directory.
for layer in layer_names_model1:
    model_1_kernels[layer] = compute_kernel_full_highmem(layer, batch_nr,1)
for layer in layer_names_model2:
    model_2_kernels[layer] = compute_kernel_full_highmem(layer,batch_nr,2)







In [None]:

# Quick sanity check on the conv_time kernel of model 1: shape and presence of exact zeros.
print(model_1_kernels["conv_time"].shape)
print(
    "The tensor contains at least one zero value."
    if (model_1_kernels["conv_time"] == 0).any()
    else "No zero values in the tensor."
)


In [None]:
# CKA for every (model-1 layer, model-2 layer) pair, keyed by that ordered pair.
cka_results = OrderedDict()

for layer1, K_x in model_1_kernels.items():
    for layer2, K_y in model_2_kernels.items():
        cka_value = compute_CKA(K_x, K_y)
        # The key order matters: the first name is the model-1 layer, the second the model-2 layer.
        cka_results[(layer1, layer2)] = cka_value
        print(f"CKA({layer1}, {layer2}): {cka_value}")

In [None]:

# Rows index model-1 layers, columns index model-2 layers.
n_layers = len(layer_names_model1)
matrix = np.zeros((len(layer_names_model1), len(layer_names_model2)))

# Each (layer1, layer2) key is an ordered (model-1 layer, model-2 layer) pair.
# The cross-model matrix is not symmetric in general, so only entry [i, j] is
# written; mirroring into [j, i] would overwrite the value of the swapped pair.
for (layer1, layer2), similarity in cka_results.items():
    i = layer_names_model1.index(layer1)
    j = layer_names_model2.index(layer2)
    matrix[i, j] = similarity

df = pd.DataFrame(matrix, index=layer_names_model1, columns=layer_names_model2)

# Plot the heatmap
plt.figure(figsize=(10, 8))
sns.heatmap(df, annot=True, cmap='magma', fmt='.2f', square=True, linewidths=0.5, cbar=True, vmin=0, vmax=1)
plt.title('Average CKA Similarity Heatmap')
# Label the axes by model so the two off-diagonal cells can be told apart.
plt.xlabel('Model 2 layer')
plt.ylabel('Model 1 layer')
plt.show()

In [None]:
def compute_kernel_full_lowmem(layer, batch_nr, model_number, total_samples,batch_size, use_cuda=False):
    """
    Computes the full linear kernel matrix block by block from saved batch files.

    Only two batches of activations are held at a time. Because the kernel is
    symmetric, only blocks on or below the diagonal are computed; each
    off-diagonal block is mirrored into its transposed position, and the
    diagonal block reuses the batch that is already loaded.

    Parameters:
    - layer: layer name used in the file names
    - batch_nr: number of batches (may be fractional); files 1 .. ceil(batch_nr) are read
    - model_number: number used in the ``model<k>`` directory name
    - total_samples: side length of the resulting kernel matrix
    - batch_size: number of samples per saved batch (the last batch may be smaller)
    - use_cuda: compute on CUDA when it is available; otherwise the CPU is used

    Returns:
    - (total_samples, total_samples) float32 tensor on the CPU
    """
    device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")

    def load_flat(idx):
        # Load one saved batch and flatten it to (samples, features).
        acts = torch.load(f"Datasets/activations/model{model_number}/{layer}_batch_{idx}.pt").to(device)
        return acts.reshape(acts.shape[0], -1)

    def bounds(idx):
        # Sample range covered by 1-based batch idx, clipped for a short last batch.
        start = (idx - 1) * batch_size
        return start, min(start + batch_size, total_samples)

    full_kernel = torch.zeros((total_samples, total_samples), dtype=torch.float32, device=device)

    for batch_idx in range(1, math.ceil(batch_nr + 1)):
        print(f"\rLoading batch {batch_idx} for {layer} in model {model_number}", end='', flush=True)

        batch_activations = load_flat(batch_idx)
        start_idx_col, end_idx_col = bounds(batch_idx)

        for batch_idx2 in range(1, batch_idx + 1):  # Compute only lower triangle
            on_diagonal = batch_idx2 == batch_idx
            # The diagonal block pairs the batch with itself; no need to reload it.
            other_activations = batch_activations if on_diagonal else load_flat(batch_idx2)
            start_idx_row, end_idx_row = bounds(batch_idx2)

            kernel_block = batch_activations @ other_activations.T  # Matrix multiplication
            full_kernel[start_idx_col:end_idx_col, start_idx_row:end_idx_row] = kernel_block
            if not on_diagonal:
                full_kernel[start_idx_row:end_idx_row, start_idx_col:end_idx_col] = kernel_block.T  # Use symmetry
    return full_kernel.cpu()


def compute_full_kernels(layer_names_model1, layer_names_model2, batch_nr, total_samples,batch_size):
    """
    Build the low-memory kernel for every listed layer of model 1 and model 2.

    Returns a pair of plain dicts (model 1 first, then model 2), each mapping
    layer name -> kernel matrix from ``compute_kernel_full_lowmem``.
    """

    def kernels_for(model_number, layer_names):
        # One kernel per layer, in the order the names are given.
        return {
            name: compute_kernel_full_lowmem(name, batch_nr, model_number, total_samples, batch_size)
            for name in layer_names
        }

    # The tuple is evaluated left to right, so model 1 is processed before model 2.
    return kernels_for(1, layer_names_model1), kernels_for(2, layer_names_model2)

# These placeholders are replaced below by the plain dicts that
# compute_full_kernels returns.
model_1_kernels = OrderedDict()
model_2_kernels = OrderedDict()
# The same layer list is used for both models.
layer_names_model = layer_names_model1

    
# Recompute all kernels with the block-wise (low-memory) routine.
model_1_kernels, model_2_kernels = compute_full_kernels(layer_names_model, layer_names_model, batch_nr,total_samples,batch_size)

In [None]:

# Same sanity check for the block-wise kernel: shape and presence of exact zeros.
print(model_1_kernels["conv_time"].shape)
print(
    "The tensor contains at least one zero value."
    if (model_1_kernels["conv_time"] == 0).any()
    else "No zero values in the tensor."
)


In [None]:
# CKA for every (model-1 layer, model-2 layer) pair, keyed by that ordered pair.
# This replaces the cka_results produced by the earlier cell of the same form.
cka_results = OrderedDict()

for layer1, K_x in model_1_kernels.items():
    for layer2, K_y in model_2_kernels.items():
        cka_value = compute_CKA(K_x, K_y)
        # The key order matters: the first name is the model-1 layer, the second the model-2 layer.
        cka_results[(layer1, layer2)] = cka_value
        print(f"CKA({layer1}, {layer2}): {cka_value}")

In [None]:
# Initialize the CKA matrix: rows index model-1 layers, columns index model-2 layers.
n_layers = len(layer_names_model1)
matrix = np.zeros((len(layer_names_model1), len(layer_names_model2)))

# Fill the matrix with the CKA similarity values.
# Each (layer1, layer2) key is an ordered (model-1 layer, model-2 layer) pair.
# The cross-model matrix is not symmetric in general, so only entry [i, j] is
# written; mirroring into [j, i] would overwrite the value of the swapped pair.
for (layer1, layer2), similarity in cka_results.items():
    i = layer_names_model1.index(layer1)
    j = layer_names_model2.index(layer2)
    matrix[i, j] = similarity

# Create a DataFrame for better visualization
df = pd.DataFrame(matrix, index=layer_names_model1, columns=layer_names_model2)

# Plot the heatmap
plt.figure(figsize=(10, 8))
sns.heatmap(df, annot=True, cmap='magma', fmt='.2f', square=True, linewidths=0.5, cbar=True, vmin=0, vmax=1)
plt.title('Average CKA Similarity Heatmap')
# Label the axes by model so the two off-diagonal cells can be told apart.
plt.xlabel('Model 2 layer')
plt.ylabel('Model 1 layer')
plt.show()