### Imports and loading from files

In [None]:
from google.colab import drive
drive.mount('/content/drive')
%cd drive/MyDrive/MB_Grader

!pip install torch_geometric

Mounted at /content/drive
/content/drive/MyDrive/MB_Grader
Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m40.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [None]:
import Batched_GNN_T5 as network

import numpy as np
import pandas as pd
import pickle

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
from torch_geometric.explain import Explainer, GNNExplainer

from collections import Counter
from torch.utils.data import WeightedRandomSampler
from torch.utils.data import Dataset
from torch_geometric.loader import DataLoader
from torch_geometric.data import Batch

In [None]:
# Load the pickled train/test splits and the global graph.
# NOTE(review): pickle.load can execute arbitrary code; only open trusted files.
def _read_pickle(path):
    """Return the object stored in the pickle file at `path`."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)

X_train = _read_pickle('Training_data/X_train.pkl')
X_test = _read_pickle('Training_data/X_test.pkl')
y_train = _read_pickle('Training_data/y_train.pkl')
y_test = _read_pickle('Training_data/y_test.pkl')

global_graph = _read_pickle('Training_data/global_graph.pkl')

In [None]:
# Build the classifier with 11 output classes on top of the global graph, then
# restore the saved weights. map_location sends every stored tensor to the CPU,
# so the checkpoint loads on machines without a GPU.
model = network.GNNClassifier(num_classes=11, global_graph=global_graph)
model.load_state_dict(torch.load('Weights/gnn_T5_weights_firsttry.pth', map_location=torch.device('cpu')))

# Device placement and eval mode are set in the explainer cell further down;
# the lines below are the earlier, now-disabled version of that step.
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model.to(device)
# print(f"Using device: {device}")
# model.eval()  # switch to evaluation mode

<All keys matched successfully>

### Actual explainer

In [None]:
# Optional: adjust depending on your dataset and model
# NOTE(review): EMBED_SIZE is not referenced in the visible cells.
EMBED_SIZE = 32
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
# Switch to evaluation mode before explaining / ablating.
model.eval()

# Column ranges of the node-feature matrix x for each feature group, in column
# order: x = [coords | positions | hold_types | orientations].
feature_indices = {
    'coords':        slice(0, 2),
    'positions':  slice(2, 3),
    'hold_types':    slice(3, 8),
    'orientations':     slice(8, 16),
}

class ExplainerAdapter(nn.Module):
    """Expose a Data-consuming model through an ``(x, edge_index)`` signature.

    The wrapped model is called with a single ``torch_geometric.data.Data``
    object. This adapter rebuilds that object on every call from the given
    ``x`` and ``edge_index`` plus the remaining fields of a template graph.
    """

    def __init__(self, base_model, base_data):
        """Store the wrapped model and the template graph.

        Args:
            base_model: Model invoked as ``base_model(data)``.
            base_data: Graph supplying ``edge_attr``, ``y`` and ``batch``.
        """
        super().__init__()
        self.base_model = base_model
        # Template graph: every field except x / edge_index is copied from it.
        self.base_data = base_data

    def forward(self, x, edge_index):
        """Rebuild a Data object around ``x`` / ``edge_index`` and run the model."""
        template = self.base_data
        graph = torch_geometric.data.Data(
            x=x,
            edge_index=edge_index,
            # edge_attr and batch are optional on the template; default to None.
            edge_attr=getattr(template, 'edge_attr', None),
            y=template.y,
            batch=getattr(template, 'batch', None),
        )
        return self.base_model(graph)

# --- FEATURE ATTRIBUTION USING GNNEXPLAINER ---
def explain_graph(model, climb):
    """Attribute one climb's prediction to the input feature groups.

    Runs GNNExplainer (25 epochs) on ``model`` wrapped in ``ExplainerAdapter``
    and sums the learned node-feature mask over each column range listed in
    ``feature_indices``.

    Args:
        model: Trained classifier; the adapter calls it with a Data object.
        climb: Graph to explain; moved with ``climb.to(device)`` before use.

    Returns:
        Tuple ``(node_mask, group_scores, normalized)``:
            node_mask: the explanation's node mask, [num_nodes, num_features].
            group_scores: dict of feature-group name -> summed mask value.
            normalized: the same scores divided by their total; all 0.0 when
                the total is zero.
    """
    climb = climb.to(device)
    adapter = ExplainerAdapter(model, climb)

    explainer = Explainer(
        model=adapter,
        # Alternative settings tried before: epochs=100, lr=0.01.
        algorithm=GNNExplainer(epochs=25),
        explanation_type='model',
        node_mask_type='attributes',  # alternative: 'object'
        edge_mask_type='object',
        model_config=dict(
            mode='multiclass_classification',
            task_level='graph',
            # run_feature_ablation applies softmax to this model's output, so
            # the output is unnormalised scores, not probabilities. Declaring
            # 'probs' would make the explainer take the log of those scores.
            # NOTE(review): confirm against the model's forward().
            return_type='raw',
        ),
    )

    # explanation_type='model' explains the model's own prediction, so no
    # target is passed.
    explanation = explainer(x=climb.x, edge_index=climb.edge_index)

    node_mask = explanation.node_mask  # [num_nodes, num_features]

    group_scores = {
        name: node_mask[:, columns].sum().item()
        for name, columns in feature_indices.items()
    }

    total = sum(group_scores.values())
    # Guard against an all-zero mask, which would otherwise divide by zero.
    normalized = {
        name: (score / total if total else 0.0)
        for name, score in group_scores.items()
    }

    return node_mask, group_scores, normalized


# --- FEATURE MASKING ABLATION ---
import copy

def mask_feature_group(x, group_slice, method="zero"):
    """Return a copy of ``x`` with the columns in ``group_slice`` masked.

    Args:
        x: 2-D feature tensor. It is cloned; the input is never modified.
        group_slice: Column slice (or index) selecting the group to mask.
        method: ``"zero"`` overwrites the columns with 0; ``"noise"``
            overwrites them with Gaussian noise scaled by 0.1.

    Returns:
        The masked clone of ``x``.

    Raises:
        ValueError: If ``method`` is neither ``"zero"`` nor ``"noise"``.
    """
    masked = x.clone()
    if method == "zero":
        masked[:, group_slice] = 0
    elif method == "noise":
        masked[:, group_slice] = torch.randn_like(masked[:, group_slice]) * 0.1
    else:
        # An unknown method used to return the unmasked clone silently, which
        # made the ablation report "no effect" for every group.
        raise ValueError(f"Unknown mask method {method!r}; expected 'zero' or 'noise'")
    return masked

def run_feature_ablation(model, climb, mask_method="zero"):
    """Score each feature group by how much masking it shifts the prediction.

    The climb is deep-copied and moved to ``device``, the model's softmax
    output is recorded as the baseline, and then each group in
    ``feature_indices`` is masked in turn with ``mask_feature_group``. The
    baseline and every masked prediction are printed.

    Args:
        model: Classifier called as ``model(graph)``; put into eval mode here.
        climb: Graph to ablate; the caller's object is not modified.
        mask_method: Forwarded to ``mask_feature_group`` as ``method``.

    Returns:
        dict of feature-group name -> KL divergence of the baseline
        distribution from the masked distribution (``batchmean`` reduction).
    """
    graph = copy.deepcopy(climb).to(device)  # private copy, safe to edit
    model.eval()

    with torch.no_grad():
        baseline_probs = F.softmax(model(graph), dim=1)
        baseline_class = torch.argmax(baseline_probs, dim=1).item()

    print(f"\nOriginal prediction: class={baseline_class}, probs={baseline_probs.cpu().numpy().round(3)}")

    scores = {}
    for group_name, column_slice in feature_indices.items():
        ablated = copy.deepcopy(graph)
        ablated.x = mask_feature_group(ablated.x, column_slice, method=mask_method)

        with torch.no_grad():
            # The small offset keeps the later log() finite.
            ablated_probs = F.softmax(model(ablated), dim=1) + 1e-8
            ablated_class = torch.argmax(ablated_probs, dim=1).item()

        # Report what the model predicts without this group.
        print(f"Masked group: {group_name}")
        print(f"→ Masked class: {ablated_class}, probs={ablated_probs.cpu().numpy().round(3)}")

        # kl_div takes log-probabilities as input and probabilities as target.
        divergence = F.kl_div(ablated_probs.log(), baseline_probs, reduction="batchmean")
        scores[group_name] = divergence.item()

    return scores

### Run it

In [None]:
# Feature-ablation study on the first 10 test climbs: run_feature_ablation
# prints the baseline and masked predictions, then the per-group KL scores
# it returns are printed here.
for i in range(10):
  #mask, group, normed = explain_graph(model, X_test[i])
  #print(i, X_test[i].y.item(), normed)
  s = run_feature_ablation(model, X_test[i])
  print()
  print(s)
  print()


Original prediction: class=0, probs=[[0.999 0.001 0.    0.    0.    0.    0.    0.    0.    0.    0.   ]]
Masked group: coords
→ Masked class: 1, probs=[[0.409 0.508 0.083 0.    0.    0.    0.    0.    0.    0.    0.   ]]
Masked group: positions
→ Masked class: 0, probs=[[0.646 0.354 0.    0.    0.    0.    0.    0.    0.    0.    0.   ]]
Masked group: hold_types
→ Masked class: 5, probs=[[0.001 0.033 0.075 0.    0.    0.89  0.    0.    0.    0.    0.   ]]
Masked group: orientations
→ Masked class: 7, probs=[[0.    0.    0.    0.    0.054 0.002 0.    0.944 0.    0.    0.   ]]

{'coords': 0.8865851759910583, 'positions': 0.4305455982685089, 'hold_types': 7.288563251495361, 'orientations': 18.403623580932617}


Original prediction: class=3, probs=[[0.003 0.016 0.04  0.932 0.009 0.    0.    0.    0.    0.    0.   ]]
Masked group: coords
→ Masked class: 4, probs=[[0.009 0.035 0.3   0.148 0.382 0.122 0.001 0.001 0.001 0.001 0.   ]]
Masked group: positions
→ Masked class: 2, probs=[[0.015 0

In [None]:
# GNNExplainer attribution on the first 10 test climbs. Each printed line is
# the climb index, its label (y), and the normalised per-group mask share.
for i in range(10):
  mask, group, normed = explain_graph(model, X_test[i])
  print(i, X_test[i].y.item(), normed)

0 0 {'coords': 0.0, 'positions': 0.0, 'hold_types': 0.5278516594497401, 'orientations': 0.4721483405502599}
1 3 {'coords': 0.0, 'positions': 0.0, 'hold_types': 0.5465827458557185, 'orientations': 0.45341725414428147}
2 4 {'coords': 0.0, 'positions': 0.0, 'hold_types': 0.36506771461023285, 'orientations': 0.6349322853897671}
3 0 {'coords': 0.0, 'positions': 0.0, 'hold_types': 0.6261906641345216, 'orientations': 0.3738093358654783}
4 4 {'coords': 0.0, 'positions': 0.0, 'hold_types': 0.5272980756458483, 'orientations': 0.47270192435415165}
5 0 {'coords': 0.0, 'positions': 0.0, 'hold_types': 0.4912918684357621, 'orientations': 0.5087081315642379}
6 2 {'coords': 0.0, 'positions': 0.0, 'hold_types': 0.6254861848644986, 'orientations': 0.3745138151355014}
7 0 {'coords': 0.0, 'positions': 0.0, 'hold_types': 0.5155877687695005, 'orientations': 0.48441223123049953}
8 0 {'coords': 0.0, 'positions': 0.0, 'hold_types': 0.47805017520728565, 'orientations': 0.5219498247927143}
9 6 {'coords': 0.0, 'po

### other

In [None]:
### NEXT STEP IS TO IDENTIFY WHAT THE WEIGHTS OF EVERY HOLD / SLICE ACTUALLY ARE AND ESSENTIALLY WHERE THIS NON-USE OF COORDS+ORIENTATIONS IS HAPPENING

In [None]:
# from captum.attr import IntegratedGradients
# from captum.attr import visualization as viz

# ig = IntegratedGradients(model)
# attributions = ig.attribute(inputs=(X_test[69]),
#                             target=Y_test[69],
#                             n_steps=50)

# # attributions[0] is feature‐wise importance for each node
# fig, ax = viz.visualize_feature_importances(
#     attributions[0].cpu().detach().numpy(),
#     feature_names=[f"f{i}" for i in range(x.size(1))]
# )

In [None]:
#group

In [None]:
# mask.shape, X_test[69].x.shape

# for x, m in zip(X_test[69].x, mask):
#   print(x)
#   print(m)
#   print()

In [None]:
# Bare `raise` with no active exception raises RuntimeError.
# NOTE(review): presumably a deliberate stop so "Run All" halts here — confirm.
raise

RuntimeError: No active exception to reraise

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np

class FeatureActivationLogger:
    """Record activation statistics of the model's per-group input embeddings.

    For every analysed climb, the node features are passed through the model's
    embedding sub-modules for each feature group (coords, positions,
    hold_types, orientations). The mean absolute activation and the fraction
    of exactly-zero activations are stored per group.
    """

    GROUP_NAMES = ('coords', 'positions', 'hold_types', 'orientations')

    def __init__(self, model, device='cpu', verbose=True, plot_histograms=False):
        """Move ``model`` to ``device`` and start with empty statistics.

        Args:
            model: Model exposing the embedding sub-modules and ``epsilon``
                used in ``analyze_one_climb``.
            device: Device for both the model and each analysed climb.
            verbose: Print the per-climb statistics when True.
            plot_histograms: Show an activation histogram per group when True.
        """
        self.model = model.to(device)
        self.device = device
        self.verbose = verbose
        self.plot_histograms = plot_histograms

        self.reset()

    def reset(self):
        """Discard all collected statistics."""
        self.all_stats = {name: [] for name in self.GROUP_NAMES}

    def analyze_one_climb(self, climb_data):
        """Compute and store activation statistics for one climb.

        The embedding steps are run by calling the model's sub-modules
        directly. NOTE(review): this presumably mirrors the start of the
        model's own forward pass — confirm against Batched_GNN_T5.

        Args:
            climb_data: Graph with node features ``x``; columns 0-1 are
                coordinates, column 2 the position, 3-7 hold types and 8-15
                orientations.

        Returns:
            dict of group name -> ``{'mean_abs': float, 'zero_fraction': float}``
            for this climb. The same dicts are appended to ``self.all_stats``.
        """
        self.model.eval()
        climb_data = climb_data.to(self.device)

        with torch.no_grad():
            # Step 1: forward up to the feature embeddings.
            x = climb_data.x

            # Coordinates are cast to integer indices for the embedding lookup.
            coords_y = self.model.coordinate_embedding_y(x[:, 0].to(torch.long))
            coords_x = self.model.coordinate_embedding_x(x[:, 1].to(torch.long))
            coords = torch.cat([coords_y, coords_x], dim=1)
            coords = F.leaky_relu(self.model.coordinate_smoosh(coords))

            positions = self.model.position_embedding_1(x[:, 2].to(torch.long))
            positions = F.leaky_relu(positions)
            # position_embedding_2 is intentionally not applied (disabled in
            # the original cell).

            # Embeddings of the hold-type / orientation columns are divided by
            # the per-node column sum; epsilon avoids division by zero.
            hold_types = x[:, 3:8]
            hold_types_sum = hold_types.sum(dim=1, keepdim=True)
            hold_types = self.model.hold_type_one_hot_embedding(hold_types)
            hold_types = hold_types / (hold_types_sum + self.model.epsilon)

            orientations = x[:, 8:16]
            orientations_sum = orientations.sum(dim=1, keepdim=True)
            orientations = self.model.orientation_one_hot_embedding(orientations)
            orientations = orientations / (orientations_sum + self.model.epsilon)

            # Step 2: per-group statistics.
            stats = {}
            activations = (coords, positions, hold_types, orientations)
            for group_name, tensor in zip(self.GROUP_NAMES, activations):
                mean_abs = tensor.abs().mean().item()
                zero_fraction = (tensor == 0).float().mean().item()

                stats[group_name] = {
                    'mean_abs': mean_abs,
                    'zero_fraction': zero_fraction
                }

                # Save to the overall stats used by summarize().
                self.all_stats[group_name].append(stats[group_name])

                if self.verbose:
                    print(f"--- {group_name} ---")
                    print(f"Mean absolute value: {mean_abs:.6f}")
                    print(f"Fraction zeros: {zero_fraction:.4f}")

                if self.plot_histograms:
                    plt.hist(tensor.cpu().numpy().flatten(), bins=30)
                    plt.title(f"{group_name} activation histogram")
                    plt.xlabel("Value")
                    plt.ylabel("Frequency")
                    plt.show()

        # Previously the per-climb stats were built and then discarded.
        return stats

    def summarize(self):
        """Print the per-group averages over every climb analysed so far."""
        print("\n\n==== Summary across all climbs ====")
        for group_name, records in self.all_stats.items():
            print(f"\nFeature Group: {group_name}")
            if not records:
                # np.mean([]) would print nan and emit a RuntimeWarning.
                print("No climbs analyzed yet.")
                continue

            mean_abs_values = [record['mean_abs'] for record in records]
            zero_fractions = [record['zero_fraction'] for record in records]

            print(f"Avg Mean Absolute Activation: {np.mean(mean_abs_values):.6f}")
            print(f"Avg Zero Fraction: {np.mean(zero_fractions):.4f}")

In [None]:
# Collect embedding-activation statistics for the first 5 test climbs, printing
# each climb's per-group numbers (verbose=True) without histograms.
logger = FeatureActivationLogger(model, device=device, verbose=True, plot_histograms=False)

for i in range(5):  # Assuming X_test is a list of Data objects
    logger.analyze_one_climb(X_test[i])
    print()

# After analyzing many climbs:
logger.summarize()

In [None]:
# Bare `raise` with no active exception raises RuntimeError.
# NOTE(review): presumably a deliberate stop so "Run All" halts here — confirm.
raise

In [None]:
# Show which torch_geometric version the kernel is actually using.
print(torch_geometric.__version__)

In [None]:
class ExplainerAdapter(nn.Module):
    """Expose a Data-consuming model through an ``(x, edge_index)`` signature.

    The wrapped model is called with a single ``torch_geometric.data.Data``
    object. This adapter rebuilds that object on every call from the given
    ``x`` and ``edge_index`` plus the remaining fields of a template graph.

    This re-declares the adapter already defined in the explainer cell earlier
    in the notebook, so this scratch section can be run on its own.
    """

    def __init__(self, base_model, base_data):
        """Store the wrapped model and the template graph.

        Args:
            base_model: Model invoked as ``base_model(data)``.
            base_data: Graph supplying ``edge_attr``, ``y`` and ``batch``.
        """
        super().__init__()
        self.base_model = base_model
        # Template graph: every field except x / edge_index is copied from it.
        self.base_data = base_data

    def forward(self, x, edge_index):
        """Rebuild a Data object around ``x`` / ``edge_index`` and run the model."""
        template = self.base_data
        graph = torch_geometric.data.Data(
            x=x,
            edge_index=edge_index,
            # edge_attr and batch are optional on the template; default to None.
            edge_attr=getattr(template, 'edge_attr', None),
            y=template.y,
            batch=getattr(template, 'batch', None),
        )
        return self.base_model(graph)

# Instantiate it, using the first test climb as the template graph that
# supplies edge_attr, y and batch on every forward call.
adapter = ExplainerAdapter(model, X_test[0])

In [None]:
# Explain the first test climb with GNNExplainer at its default settings.
# The climb is moved onto the model's device first, as explain_graph does.
# The adapter is rebuilt around the moved climb so its template fields are on
# that device too; otherwise the forward pass can mix CPU and GPU tensors.
climb = X_test[0].to(device)
adapter = ExplainerAdapter(model, climb)

explainer = Explainer(
    model=adapter,
    algorithm=GNNExplainer(),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='graph',
        # run_feature_ablation applies softmax to this model's output, so the
        # output is unnormalised scores, not probabilities.
        # NOTE(review): confirm against the model's forward().
        return_type='raw',
    ),
)

# explanation_type='model' explains the model's own prediction, so no target
# is passed.
explanation = explainer(
    x=climb.x,
    edge_index=climb.edge_index,
)



In [None]:
# Collapse the node-feature mask into one score per feature group, then
# express each score as a share of the total.
node_mask = explanation.node_mask  # [num_nodes, num_features]

group_scores = {
    name: node_mask[:, columns].sum().item()
    for name, columns in feature_indices.items()
}

total = sum(group_scores.values())
# Guard against an all-zero mask, which would otherwise divide by zero.
normalized = {
    name: (score / total if total else 0.0)
    for name, score in group_scores.items()
}

In [None]:
# Display the normalised per-group shares computed in the previous cell.
normalized

In [None]:
# Display the first test climb to inspect its fields.
X_test[0]