Below is one practical approach to (1) capture intermediate embeddings from your GNN (so you can compare them via PCA/t-SNE/CKA), and (2) systematically compare models trained at different MSn depths (e.g., cut_tree_at_level=2 vs. cut_tree_at_level=5). This does not rely on PyG’s “Explainability” module, but rather on extracting and saving hidden representations, then applying any off-the-shelf representation-similarity method (such as CKA).

1) Modify Your Model to Output Intermediate Embeddings

Right now, your GNNRetrievalModel only returns the final fingerprint. We can add a separate method (e.g., forward_with_embeds) that also returns hidden embeddings at various points:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from massspecgym.models.base import Stage, MassSpecGymModel
from massspecgym.models.retrieval.base import RetrievalMassSpecGymModel

class GNNRetrievalModel(RetrievalMassSpecGymModel):
    def __init__(
        self,
        hidden_channels: int = 128,
        out_channels: int = 4096,  # Fingerprint size
        node_feature_dim: int = 1039,  # Adjust based on your 'spec.x' feature size
        *args,
        **kwargs
    ):
        """GNN-based retrieval model for MSn spectral trees."""
        super().__init__(*args, **kwargs)
        
        # GCN Layers
        self.conv1 = GCNConv(in_channels=node_feature_dim, out_channels=hidden_channels)
        self.conv2 = GCNConv(in_channels=hidden_channels, out_channels=hidden_channels)
        
        # Fully Connected Layers for Fingerprint Prediction
        self.fc = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, out_channels),
            nn.Sigmoid()  # Fingerprint bits between 0 and 1
        )
    
    def forward(self, data):
        """
        Original forward pass to predict molecular fingerprint.
        """
        # -> Just calls a helper that returns final output
        x_out, _, _, _ = self.forward_with_embeds(data)
        return x_out
    
    def forward_with_embeds(self, data):
        """
        Extended forward pass that also returns intermediate embeddings:

        Returns:
          x_out: Final [batch_size, fp_size] fingerprint.
          x1:    Node embeddings after conv1 + ReLU   (shape: [num_nodes, hidden_channels]).
          x2:    Node embeddings after conv2 + ReLU   (shape: [num_nodes, hidden_channels]).
          x_pool: Graph-level embedding after global_mean_pool (shape: [batch_size, hidden_channels]).
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        
        # First GCN
        x = self.conv1(x, edge_index)
        x1 = F.relu(x)  # Node embeddings after conv1

        # Second GCN
        x = self.conv2(x1, edge_index)
        x2 = F.relu(x)  # Node embeddings after conv2
        
        # Pool to graph-level
        x_pool = global_mean_pool(x2, batch)  # [batch_size, hidden_channels]
        
        # Final fingerprint
        x_out = self.fc(x_pool)  # [batch_size, fp_size]
        
        return x_out, x1, x2, x_pool
    
    def step(self, batch: dict, stage: Stage) -> dict:
        """
        Training/Validation/Test step.
        """
        data = batch['spec']          # PyG DataBatch
        fp_true = batch['mol']        # [batch_size, fp_size]
        cands = batch['candidates']   # [total_candidates, fp_size]
        batch_ptr = batch['batch_ptr']  # shape: [batch_size]

        fp_pred = self.forward(data)  # [batch_size, fp_size]
        
        # MSE loss
        loss = F.mse_loss(fp_pred, fp_true)
        
        # Evaluate retrieval scores
        fp_pred_repeated = fp_pred.repeat_interleave(batch_ptr, dim=0)  # [total_candidates, fp_size]
        scores = F.cosine_similarity(fp_pred_repeated, cands)           # [total_candidates]
        
        return {'loss': loss, 'scores': scores}

	•	forward_with_embeds(...) is the key addition, returning:
	•	x_out: the final fingerprint (same shape as your original forward).
	•	x1: node embeddings after the first GCN+ReLU.
	•	x2: node embeddings after the second GCN+ReLU.
	•	x_pool: the pooled graph embedding (before the final MLP that produces the fingerprint).

2) A Helper Function to Collect Embeddings Over a Dataloader

We want to run our model on a set of examples (like the test set) and stack all embeddings. If we only care about the graph-level embeddings (x_pool or x_out), we can store one row per example. If you also want node-level embeddings (x1, x2), note that each graph has a different number of nodes, so the rows cannot simply be stacked one per example; you need to keep data.batch alongside them to know which graph each node belongs to.

Below is an example focusing on the graph-level embeddings (x_pool and x_out) because that’s simpler to compare across examples. If you also need node-level embeddings, you’d store them in a list-of-lists structure or carefully track the batch index.

import numpy as np

def collect_graph_embeddings(model, data_loader, device='cpu'):
    """
    Runs `model.forward_with_embeds(...)` on each batch in `data_loader`,
    collecting graph-level embeddings (both x_pool and final x_out).
    
    Returns:
       - all_pool: np.array of shape [num_graphs, hidden_channels]
       - all_out:  np.array of shape [num_graphs, fp_size]
    """
    model.eval()
    model.to(device)
    
    all_pool = []
    all_out = []
    
    with torch.no_grad():
        for batch_dict in data_loader:
            data = batch_dict['spec'].to(device)
            
            x_out, x1, x2, x_pool = model.forward_with_embeds(data)
            # x_out:  [batch_size, fp_size]
            # x_pool: [batch_size, hidden_channels]
            
            # Convert to CPU numpy
            all_pool.append(x_pool.cpu().numpy())
            all_out.append(x_out.cpu().numpy())
    
    all_pool = np.concatenate(all_pool, axis=0)  # shape = [num_samples, hidden_channels]
    all_out = np.concatenate(all_out, axis=0)    # shape = [num_samples, fp_size]

    return all_pool, all_out
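
If you do need the node-level embeddings (x1, x2), one possible way to keep them organized is to split the stacked node matrix back into one array per graph using the batch vector. The sketch below is only an illustration: the helper name split_nodes_by_graph is hypothetical, and it assumes the batch vector is sorted in non-decreasing order, which is how PyG lays out data.batch for a standard DataBatch.

```python
import numpy as np

def split_nodes_by_graph(node_emb, batch_index):
    """Split a [num_nodes, dim] array into a list with one array per graph.

    Assumption: `batch_index` is sorted in non-decreasing order, so all
    nodes of graph 0 come first, then graph 1, and so on.
    """
    node_emb = np.asarray(node_emb)
    batch_index = np.asarray(batch_index)
    counts = np.bincount(batch_index)      # number of nodes in each graph
    boundaries = np.cumsum(counts)[:-1]    # row indices where a new graph starts
    return np.split(node_emb, boundaries, axis=0)

# Toy batch: 3 graphs with 2, 1 and 3 nodes, embedding dimension 4.
node_emb = np.arange(24, dtype=np.float32).reshape(6, 4)
batch_index = np.array([0, 0, 1, 2, 2, 2])
per_graph = split_nodes_by_graph(node_emb, batch_index)
```

Inside the collection loop you could then call something like split_nodes_by_graph(x2.cpu().numpy(), data.batch.cpu().numpy()) and extend a Python list with the result, so that entry i of the list holds the node embeddings of the i-th graph.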

3) Train (or Load) Multiple Models at Different MSn Depths

You mentioned you have code like:

# Example: cut_tree_at_level=2
dataset_msn_l2 = MSnRetrievalDataset(
    pth=file_mgf,
    candidates_pth=file_json,
    featurizer=featurizer,
    mol_transform=MolFingerprinter(fp_size=fp_size),
    max_allowed_deviation=0.005,
    cut_tree_at_level=2
)
data_module_msn_l2 = MassSpecDataModule(
    dataset=dataset_msn_l2,
    batch_size=50,
    split_pth=split_file,
    num_workers=0,
)
model_l2 = GNNRetrievalModel(...)  # same architecture

# Train or load checkpoint
trainer.fit(model_l2, datamodule=data_module_msn_l2)
trainer.save_checkpoint("gnn_l2.ckpt")

Repeat for cut_tree_at_level=3, 4, or 5, storing each as gnn_l3.ckpt, etc.

4) Collect Embeddings on a Common Test Set

To compare embeddings, you ideally want to feed the same input molecules to each model. For best apples-to-apples comparison, do one of the following:
	1.	Use the same test split file for each dataset, so each “batch” has the same molecules (though the deeper MSn trees have more nodes).
	2.	Or create a common test list of molecule IDs and build the PyG Data objects.

Below is an illustration using each model’s own test loader. If your splits are indeed identical and the test loaders do not shuffle, the same batch order will correspond to the same molecules. If not, you’ll need a consistent alignment step.

# Suppose we have test loaders for each dataset:
data_module_msn_l2.setup(stage="test")
test_loader_l2 = data_module_msn_l2.test_dataloader()

data_module_msn_l3.setup(stage="test")
test_loader_l3 = data_module_msn_l3.test_dataloader()

# Load the trained models from checkpoint
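model_l2 = GNNRetrievalModel(...)
# map_location="cpu" lets a checkpoint trained on GPU load on a CPU-only machine
model_l2.load_state_dict(torch.load("gnn_l2.ckpt", map_location="cpu")["state_dict"])

model_l3 = GNNRetrievalModel(...)
model_l3.load_state_dict(torch.load("gnn_l3.ckpt", map_location="cpu")["state_dict"])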

# Collect embeddings
pool_l2, out_l2 = collect_graph_embeddings(model_l2, test_loader_l2, device='cpu')
pool_l3, out_l3 = collect_graph_embeddings(model_l3, test_loader_l3, device='cpu')

	•	pool_l2: shape = [num_graphs, hidden_channels], the graph embeddings obtained by mean-pooling the output of the second GCN layer.
	•	out_l2: shape = [num_graphs, fp_size], the final fingerprints.

Same for L3, etc.

If your test sets truly contain the same molecules in the same order, then pool_l2[i] corresponds to the same molecule as pool_l3[i]. If not, you have to realign them by an identifier (like batch_dict['identifier']).
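
If the order does differ, the realignment step could look roughly like the sketch below. It assumes you collected one identifier per graph alongside the embeddings (for example from batch_dict['identifier'], if your batches expose such a field) and that identifiers are unique within each model's list. The function name align_by_identifier and the toy data are hypothetical.

```python
import numpy as np

def align_by_identifier(ids_a, emb_a, ids_b, emb_b):
    """Keep only identifiers present in both lists and return the two
    embedding matrices with rows in the same (ids_a) order.

    Assumption: identifiers are unique within ids_a and within ids_b.
    """
    row_in_b = {identifier: row for row, identifier in enumerate(ids_b)}
    shared = [identifier for identifier in ids_a if identifier in row_in_b]
    rows_a = [row for row, identifier in enumerate(ids_a) if identifier in row_in_b]
    rows_b = [row_in_b[identifier] for identifier in shared]
    return shared, np.asarray(emb_a)[rows_a], np.asarray(emb_b)[rows_b]

# Toy example: the two models saw overlapping molecules in different orders.
ids_l2 = ["m1", "m2", "m3"]
ids_l3 = ["m3", "m1", "m4"]
emb_l2 = np.array([[1.0, 0.0], [2.0, 0.0], [3.0, 0.0]])
emb_l3 = np.array([[30.0, 1.0], [10.0, 1.0], [40.0, 1.0]])

shared_ids, aligned_l2, aligned_l3 = align_by_identifier(ids_l2, emb_l2, ids_l3, emb_l3)
```

After this step, row i of aligned_l2 and row i of aligned_l3 refer to the same molecule, which is what CKA and similar row-wise comparisons require.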

5) Do PCA, t-SNE, or CKA on the Resulting Embeddings

With these numpy arrays, you can proceed to compare the embeddings. For example, CKA using your code snippet:

from cka import linear_CKA, kernel_CKA  # or wherever your CKA code is

print("Compare final graph embeddings from L2 vs L3 (out-layer) with linear CKA:")
cka_out = linear_CKA(out_l2, out_l3)
print("CKA on final fingerprint outputs:", cka_out)

print("Compare the pooled hidden state from L2 vs L3 with RBF kernel CKA:")
cka_pool = kernel_CKA(pool_l2, pool_l3)
print("CKA on pooled hidden layer:", cka_pool)
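
If you don't have a CKA implementation at hand, a minimal NumPy sketch of linear CKA in its feature-space form is shown below. It is not your cka module (whose linear_CKA may be implemented differently); the name linear_cka and the toy data are only for illustration. Both inputs must have one row per example, with rows referring to the same examples in the same order.

```python
import numpy as np

def linear_cka(x, y):
    """Linear CKA between x [n, p] and y [n, q], rows = the same n examples.

    Feature-space form: ||Y^T X||_F^2 / (||X^T X||_F * ||Y^T Y||_F),
    computed after centering each feature column.
    """
    x = np.asarray(x, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)
    if x.shape[0] != y.shape[0]:
        raise ValueError("x and y must have the same number of rows (examples)")
    x = x - x.mean(axis=0, keepdims=True)
    y = y - y.mean(axis=0, keepdims=True)
    cross = np.linalg.norm(y.T @ x, ord="fro") ** 2
    norm_x = np.linalg.norm(x.T @ x, ord="fro")
    norm_y = np.linalg.norm(y.T @ y, ord="fro")
    return float(cross / (norm_x * norm_y))

# Sanity check on synthetic data: a rotated and rescaled copy should score ~1,
# while an unrelated random matrix should score much lower.
rng = np.random.default_rng(0)
a = rng.normal(size=(200, 16))
rotation, _ = np.linalg.qr(rng.normal(size=(16, 16)))
b = 3.0 * a @ rotation
c = rng.normal(size=(200, 16))

cka_same = linear_cka(a, b)
cka_rand = linear_cka(a, c)
```

The sanity check reflects why CKA is convenient here: it is invariant to rotations and uniform rescaling of the embedding space, so two models can agree even if their hidden units are arranged differently.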

Or if you want to do t-SNE:

from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

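# perplexity must be smaller than the number of samples; a fixed random_state makes reruns comparable
tsne_embeddings = TSNE(n_components=2, perplexity=30, random_state=0).fit_transform(pool_l2)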
plt.scatter(tsne_embeddings[:,0], tsne_embeddings[:,1], s=5)
plt.title("TSNE of GNN-l2 Pooled Embeddings")
plt.show()

You could do the same for pool_l3, then compare visually.
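
PCA works the same way and is deterministic, which can make side-by-side plots easier to reproduce. Below is a small NumPy-only sketch (you could equally use sklearn's PCA); the helper name pca_project and the synthetic fake_pool array are placeholders for an array such as pool_l2.

```python
import numpy as np

def pca_project(embeddings, n_components=2):
    """Project rows of `embeddings` onto their top principal components.

    Returns the projected coordinates and the fraction of total variance
    explained by each kept component.
    """
    x = np.asarray(embeddings, dtype=np.float64)
    x = x - x.mean(axis=0, keepdims=True)             # center each feature
    _, singular_values, vt = np.linalg.svd(x, full_matrices=False)
    projected = x @ vt[:n_components].T               # [n, n_components]
    variance = singular_values ** 2
    explained_ratio = variance[:n_components] / variance.sum()
    return projected, explained_ratio

# Synthetic stand-in for a pooled embedding matrix of shape [num_graphs, hidden_channels].
rng = np.random.default_rng(0)
fake_pool = rng.normal(size=(100, 128)) * np.linspace(5.0, 0.1, 128)
coords, ratio = pca_project(fake_pool, n_components=2)
```

You can then plot coords[:, 0] against coords[:, 1] with plt.scatter, as in the t-SNE example. Keep in mind that each model has its own embedding axes, so plots from different models are best compared qualitatively (cluster structure, spread) rather than point by point.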

6) Extending to More Depths or More Models

You can do exactly the same steps for a list of depths: [2, 3, 4, 5]. You’ll end up with 4 arrays of embeddings. Then run pairwise CKA or do multi-dimensional scaling, etc. Some approaches:
	•	Pairwise: CKA(pool_l2, pool_l3), CKA(pool_l2, pool_l4), CKA(pool_l3, pool_l5), etc.
	•	Matrix: Build a matrix of shape (4×4), where entry (i, j) = CKA(pool_i, pool_j), then do a heatmap.
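
The matrix approach could be sketched as follows. The helper pairwise_similarity_matrix is hypothetical and accepts any symmetric similarity function; the compact _standin_linear_cka is only there to keep the example self-contained, and you would normally pass your own linear_CKA or kernel_CKA instead. The random arrays stand in for pool_l2 ... pool_l5.

```python
import numpy as np

def pairwise_similarity_matrix(embeddings_by_depth, similarity_fn):
    """Build a symmetric matrix of similarities between all pairs of depths.

    embeddings_by_depth: dict mapping depth -> [num_graphs, dim] array,
                         with rows aligned across depths.
    similarity_fn:       callable (x, y) -> float, assumed symmetric.
    """
    depths = sorted(embeddings_by_depth)
    matrix = np.zeros((len(depths), len(depths)))
    for i, depth_i in enumerate(depths):
        for j in range(i, len(depths)):
            value = similarity_fn(embeddings_by_depth[depth_i],
                                  embeddings_by_depth[depths[j]])
            matrix[i, j] = matrix[j, i] = value   # fill both triangles
    return depths, matrix

def _standin_linear_cka(x, y):
    """Compact linear CKA used only as a placeholder similarity."""
    x = x - x.mean(axis=0)
    y = y - y.mean(axis=0)
    return np.linalg.norm(y.T @ x) ** 2 / (
        np.linalg.norm(x.T @ x) * np.linalg.norm(y.T @ y)
    )

# Random stand-ins for the pooled embeddings at depths 2 to 5.
rng = np.random.default_rng(0)
fake = {depth: rng.normal(size=(50, 8)) for depth in (2, 3, 4, 5)}
depths, cka_matrix = pairwise_similarity_matrix(fake, _standin_linear_cka)
```

For the heatmap, plt.imshow(cka_matrix, vmin=0, vmax=1) with the depths as tick labels is usually enough to see which depths produce similar representations.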

Putting It All Together
	1.	Add a forward_with_embeds(...) method returning intermediate embeddings.
	2.	Train the same GNN architecture at different MSn tree cut depths.
	3.	Collect embeddings on the same test set.
	4.	Compare these embeddings with PCA, t-SNE, or CKA.

This way, you’ll see numerically and visually how the learned representations differ (or stay similar) when you add deeper MSn levels, without relying on the internal PyG explainers. This is particularly straightforward for tasks like retrieval, where you care about final or near-final graph embeddings.