<a href="https://colab.research.google.com/github/LukasPertl1/GNN_Explainability/blob/main/CombinedFeatures.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import random
import math
import itertools
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Install PyG and dependencies (if not already installed)
!pip install torch-geometric torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html

# PyTorch Geometric imports:
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINConv, global_mean_pool

# For network visualization (optional)
import networkx as nx
from torch_geometric.utils import to_networkx

###############################################################################
# 1) Node Embeddings (Set 1 functions)
###############################################################################

def random_onehot_embedding(num_categories=5, p=0.25):
    """
    Draws an uncorrelated node feature of dimension `num_categories`.

    With probability `p` the feature is "activated": a category index is drawn
    uniformly and the matching one-hot vector is returned. Otherwise the
    feature stays inactive and an all-zero vector is returned.

    Returns:
        (embedding, index): the float vector and the active category index,
        or None as index when the feature is inactive.
    """
    vector = torch.zeros(num_categories, dtype=torch.float)

    # Inactive branch first; the comparison is kept as `not (r < p)` so the
    # outcome matches the activated test for every value of p.
    if not random.random() < p:
        return vector, None

    chosen = random.randint(0, num_categories - 1)
    vector[chosen] = 1.0
    return vector, chosen


def correlated_onehot_embedding(prev_feature, num_categories=5, p=0.25,
                                base_distribution=None, transition_matrix=None):
    """
    Draws a node feature whose category may depend on the previous node's category.

    With probability `p` the feature is activated and a category is sampled:
      - from `base_distribution` when there is no previous category
        (`prev_feature is None`),
      - from row `prev_feature` of `transition_matrix` otherwise.
    Whenever the relevant distribution is not supplied, a uniform distribution
    over the categories is used instead.

    Returns:
        (embedding, index): a float vector of length `num_categories` (one-hot
        when activated, all zeros otherwise) and the sampled category index
        (None when not activated).
    """
    vector = torch.zeros(num_categories, dtype=torch.float)
    if not random.random() < p:
        return vector, None

    # Pick the sampling distribution; None means "fall back to uniform".
    if prev_feature is None:
        weights = base_distribution
    elif transition_matrix is not None:
        weights = transition_matrix[prev_feature]
    else:
        weights = None
    if weights is None:
        weights = torch.ones(num_categories) / num_categories

    chosen = torch.multinomial(weights, num_samples=1).item()
    vector[chosen] = 1.0
    return vector, chosen


def assign_node_embeddings_correlated(num_nodes, num_categories=5, p=0.25,
                                        base_distribution=None, transition_matrix=None):
    """
    Generates node embeddings for a graph using a simple Markov chain mechanism.

    Nodes are processed in index order. Each node receives a correlated one-hot
    vector (see `correlated_onehot_embedding`) conditioned on the category of
    the node processed immediately before it. A node that is not activated
    resets the chain: the next node is then sampled as if it had no predecessor.
    Every embedding is extended by one extra constant dimension (always 1).

    Args:
        num_nodes: Number of nodes to generate embeddings for.
        num_categories: Number of one-hot categories.
        p: Activation probability for each node.
        base_distribution: Optional category distribution used when there is
            no previous category.
        transition_matrix: Optional row-stochastic matrix used when there is
            a previous category.

    Returns:
        A float tensor of shape (num_nodes, num_categories+1). For
        `num_nodes <= 0` an empty tensor of shape (0, num_categories+1) is
        returned instead of failing inside `torch.stack`.
    """
    if num_nodes <= 0:
        # torch.stack raises on an empty list; return a well-shaped empty tensor.
        return torch.zeros((0, num_categories + 1), dtype=torch.float)

    rows = []
    prev_feature = None
    for _ in range(num_nodes):
        # The returned index is None for an inactive node, which resets the chain.
        emb, prev_feature = correlated_onehot_embedding(
            prev_feature, num_categories, p, base_distribution, transition_matrix
        )
        # Append an extra constant dimension = 1.
        rows.append(torch.cat([emb, torch.ones(1, dtype=emb.dtype)], dim=0))

    return torch.stack(rows, dim=0)


def compute_feature_vector(edge_index, node_embs, num_categories=5):
    """
    Builds the per-category "adjacent pair" indicator vector of a graph.

    Entry `c` of the result is 1 when at least one edge that is not a self-loop
    joins two nodes whose embedding component `c` both equal 1; otherwise it
    is 0. Only the first `num_categories` embedding components are inspected.

    Returns:
        A float tensor of length `num_categories`.
    """
    # Drop self-loops before looking for matching endpoints.
    keep = edge_index[0] != edge_index[1]
    heads = edge_index[0][keep]
    tails = edge_index[1][keep]

    flags = [
        1.0 if bool(((node_embs[heads, c] == 1.0) & (node_embs[tails, c] == 1.0)).any()) else 0.0
        for c in range(num_categories)
    ]
    return torch.tensor(flags, dtype=torch.float)

###############################################################################
# 2) Candidate Transition Matrices (for correlated embeddings)
###############################################################################

def create_candidate_transition_matrices(num_categories, favored_groups, alpha=0.8, beta=0.2):
    """
    Creates one candidate transition matrix per favored group.

    In the matrix built for a group, row `i` has self-transition probability
    `alpha` when `i` belongs to the group and `beta` otherwise; the remaining
    probability mass is spread evenly over the other `num_categories - 1`
    categories, so every row sums to 1.

    With a single category there are no "other" categories to spread mass
    over, so the only valid row is [1.0]; `alpha`/`beta` are ignored in that
    case (the naive formula would divide by zero).

    Args:
        num_categories: Size of the (square) transition matrices.
        favored_groups: Iterable of groups, each a collection of category indices.
        alpha: Self-transition probability for indices inside the group.
        beta: Self-transition probability for indices outside the group.

    Returns:
        A list of (num_categories x num_categories) float tensors, one per group.
    """
    candidate_matrices = []
    for group in favored_groups:
        matrix = torch.zeros(num_categories, num_categories)
        for i in range(num_categories):
            if num_categories == 1:
                # Degenerate chain: the single state must transition to itself.
                matrix[i, i] = 1.0
                continue
            stay = alpha if i in group else beta
            # Fill the whole row with the off-diagonal share, then fix the diagonal.
            matrix[i, :] = (1 - stay) / (num_categories - 1)
            matrix[i, i] = stay
        candidate_matrices.append(matrix)
    return candidate_matrices


def favoured_group_calculator(num_features):
    """
    Splits the feature indices 0..num_features-1 into consecutive favored groups.

    Indices are paired up as [0, 1], [2, 3], ...; when the number of features
    is odd, the last index forms a group on its own.

    Returns:
        A list of groups, each a list of one or two indices.
    """
    groups = [[first, first + 1] for first in range(0, num_features - 1, 2)]
    if num_features % 2:
        # Odd count: the trailing feature has no partner.
        groups.append([num_features - 1])
    return groups

###############################################################################
# 3) Motif Graph Topology (Set 2 functions, extended with “pair”)
###############################################################################

def create_motif_edge_index(motif_type="triangle", chain_length=3, motif_dim=3):
    """
    Creates an edge index for a graph made of a motif with one chain per motif node.

    Topology:
      - Every node carries a self-loop.
      - "pair" is a single undirected edge between nodes 0 and 1.
      - "triangle" is a 3-cycle.
      - "square" and "pentagon" are built as fully connected subgraphs on
        4 and 5 nodes respectively (not as simple cycles).
      - Each motif node `i` is linked to a path of `chain_length` extra nodes.
        With `chain_length == 0` no chain nodes or chain edges are created.
    All edges are stored in both directions.

    Motif label:
      A list of length `motif_dim`. It is all zeros for "pair" and one-hot at
      index 0 / 1 / 2 for "triangle" / "square" / "pentagon".

    Args:
        motif_type: One of "pair", "triangle", "square", "pentagon".
        chain_length: Number of nodes in each attached chain (>= 0).
        motif_dim: Length of the motif label vector.

    Returns:
        edge_index: The edge list as a long tensor of shape (2, num_edges).
        total_nodes: The total number of nodes in the graph.
        motif_label: The label (as a list) for the motif.

    Raises:
        ValueError: If `motif_type` is unknown, `chain_length` is negative, or
            `motif_dim` is too small to hold the motif's one-hot index.
    """
    # motif name -> (number of motif nodes, one-hot index in the label or None)
    motif_specs = {
        "pair": (2, None),
        "triangle": (3, 0),
        "square": (4, 1),
        "pentagon": (5, 2),
    }
    if motif_type not in motif_specs:
        raise ValueError("Invalid motif type")
    motif_n, hot_index = motif_specs[motif_type]

    if chain_length < 0:
        raise ValueError("chain_length must be non-negative")

    # Build the label with the requested length so that all motif types agree.
    motif_label = [0] * motif_dim
    if hot_index is not None:
        if hot_index >= motif_dim:
            raise ValueError(
                "motif_dim={} is too small for motif type '{}'".format(motif_dim, motif_type)
            )
        motif_label[hot_index] = 1

    # Total nodes = motif nodes + one chain per motif node.
    total_nodes = motif_n + motif_n * chain_length

    # Self-loops for all nodes.
    edges = [(i, i) for i in range(total_nodes)]

    # Build the motif subgraph.
    if motif_type == "pair":
        edges.append((0, 1))
        edges.append((1, 0))
    elif motif_type in ("square", "pentagon"):
        # Fully connected subgraph.
        for i in range(motif_n):
            for j in range(i + 1, motif_n):
                edges.append((i, j))
                edges.append((j, i))
    else:  # triangle: use a cycle (which is complete for 3 nodes)
        for i in range(motif_n):
            j = (i + 1) % motif_n
            edges.append((i, j))
            edges.append((j, i))

    # Attach a chain to each motif node. Skipped entirely for chain_length == 0,
    # where the would-be first chain node does not exist.
    if chain_length > 0:
        for i in range(motif_n):
            start = motif_n + i * chain_length
            edges.append((i, start))
            edges.append((start, i))
            for j in range(chain_length - 1):
                a = start + j
                b = start + j + 1
                edges.append((a, b))
                edges.append((b, a))

    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    return edge_index, total_nodes, motif_label

###############################################################################
# 4) Combined Graph Data Generation
###############################################################################

def generate_graph_data_combined(num_samples=1000, feature_dim=5, p=0.35,
                                 chain_length_min=2, chain_length_max=7,
                                 candidate_matrices=None, base_distribution=None,
                                 motif_dim=3):
    """
    Generates a synthetic dataset of motif graphs with correlated node features.

    For each sample:
      - A chain length is drawn uniformly from
        [chain_length_min, chain_length_max] (inclusive).
      - A motif type is drawn: with probability 0.5 it is forced to "pair";
        otherwise it is drawn uniformly from
        "pair", "triangle", "square", "pentagon". ("pair" is therefore more
        frequent than each of the other motifs.)
      - The topology comes from create_motif_edge_index.
      - If `candidate_matrices` is given, one of them is drawn at random and
        used as the transition matrix for the node features.
      - Node features come from assign_node_embeddings_correlated and have
        dimension (feature_dim + 1).

    The target vector is the concatenation of:
      (a) the binary vector computed from adjacent node pairs (length = feature_dim)
      (b) the motif label (length = motif_dim)

    Returns:
        A list of Data objects (PyTorch Geometric format).
    """
    motif_choices = ["pair", "triangle", "square", "pentagon"]
    dataset = []

    for _ in range(num_samples):
        # The order of the random draws below is part of the behaviour
        # (it determines what a given seed produces).
        chain_length = random.randint(chain_length_min, chain_length_max)
        motif_type = "pair" if random.random() < 0.5 else random.choice(motif_choices)

        edge_index, total_nodes, motif_label = create_motif_edge_index(
            motif_type, chain_length, motif_dim
        )

        transition_matrix = (
            None if candidate_matrices is None else random.choice(candidate_matrices)
        )

        x = assign_node_embeddings_correlated(
            total_nodes,
            num_categories=feature_dim,
            p=p,
            base_distribution=base_distribution,
            transition_matrix=transition_matrix,
        )

        # Target = [pairwise feature indicators | motif label].
        y = torch.cat(
            [
                compute_feature_vector(edge_index, x, num_categories=feature_dim),
                torch.tensor(motif_label, dtype=torch.float),
            ],
            dim=0,
        )

        sample = Data(x=x, edge_index=edge_index, y=y)
        sample.num_nodes = total_nodes
        dataset.append(sample)

    return dataset

###############################################################################
# 5) Pure Graph Check (Total target vector is one-hot)
###############################################################################

def is_pure_graph_combined(target_vec):
    """
    A graph is considered "pure" only if its entire target vector is one-hot,
    i.e. exactly one element is 1 and all others are 0.

    The check counts ones and zeros explicitly. Testing only that the elements
    sum to 1 would also accept vectors such as [0.5, 0.5] or [2, -1], which
    are not one-hot.

    Returns:
        True if `target_vec` is one-hot, False otherwise.
    """
    num_ones = (target_vec == 1).sum().item()
    num_zeros = (target_vec == 0).sum().item()
    return num_ones == 1 and num_zeros == target_vec.numel() - 1

###############################################################################
# 6) GIN Model (Using Set 2 style)
###############################################################################

def equiangular_frame(out_dim, hidden_dim):
    """
    Returns a fixed weight matrix with a symmetric, evenly spread configuration
    of `out_dim` row vectors in `hidden_dim` dimensions, for some special cases.

    This helps in setting up the final linear layer in a specific configuration.

    Supported (out_dim, hidden_dim) pairs:
      - (3, 2), (4, 2), (5, 2), (6, 2): vertices of a regular polygon on the unit circle.
      - (4, 3): four vectors arranged as a tetrahedral frame (rows are not unit length).
      - (6, 3): the 6 vertices of a regular octahedron.

    Returns:
        A float tensor of shape (out_dim, hidden_dim).

    Raises:
        ValueError: If no frame is implemented for the requested shape.
    """
    if out_dim == 3 and hidden_dim == 2:
        return torch.tensor([
            [1.0, 0.0],
            [-0.5, math.sqrt(3)/2],
            [-0.5, -math.sqrt(3)/2]
        ])
    if out_dim == 4 and hidden_dim == 2:
        return torch.tensor([
            [1.0, 0.0],
            [0.0, 1.0],
            [-1.0, 0.0],
            [0.0, -1.0]
        ])
    if out_dim in (5, 6) and hidden_dim == 2:
        # Regular pentagon / hexagon: vertex k sits at angle k * 2*pi / out_dim.
        return torch.tensor([
            [math.cos(k * 2*math.pi/out_dim), math.sin(k * 2*math.pi/out_dim)]
            for k in range(out_dim)
        ])
    if out_dim == 4 and hidden_dim == 3:
        return torch.tensor([
            [1.0, 0.0, -math.sqrt(0.5)],
            [-1.0, 0.0, -math.sqrt(0.5)],
            [0.0, 1.0, math.sqrt(0.5)],
            [0.0, -1.0, math.sqrt(0.5)]
        ])
    if out_dim == 6 and hidden_dim == 3:
        # The 6 vertices of a regular octahedron in 3D.
        return torch.tensor([
            [1.0,  0.0,  0.0],
            [-1.0, 0.0,  0.0],
            [0.0,  1.0,  0.0],
            [0.0, -1.0,  0.0],
            [0.0,  0.0,  1.0],
            [0.0,  0.0, -1.0]
        ])

    # Reached only when no case above matched, so callers can rely on the
    # exception (rather than a silent None) to trigger their fallback.
    raise ValueError("Equiangular frame not implemented for (out_dim={}, hidden_dim={}).".format(out_dim, hidden_dim))


def initialize_output_weights(W, out_dim, hidden_dim):
    """
    Initializes the weights of the final linear layer in place.

    Uses an equiangular frame if one is available for (out_dim, hidden_dim);
    otherwise, defaults to orthogonal initialization. "Not available" is
    recognised both when `equiangular_frame` raises ValueError and when it
    returns None, so an unsupported shape always ends in the fallback rather
    than in an AttributeError.

    Args:
        W: The weight tensor/parameter to initialise, shape (out_dim, hidden_dim).
        out_dim: Number of output units (rows of W).
        hidden_dim: Size of the last hidden representation (columns of W).
    """
    try:
        eq_frame = equiangular_frame(out_dim, hidden_dim)
    except ValueError:
        eq_frame = None

    if eq_frame is None:
        nn.init.orthogonal_(W)
        return

    # Copy without recording the write in autograd.
    with torch.no_grad():
        W.copy_(eq_frame.to(W.device).type_as(W))


class SimpleGIN(nn.Module):
    """
    A simple GIN model with multiple layers.

    Each GIN layer wraps a two-layer MLP (Linear -> ReLU -> Linear) and has a
    trainable epsilon. A ReLU is applied between GIN layers but not after the
    last one. Node embeddings are mean-pooled per graph and passed through a
    final linear layer that outputs raw logits for each target dimension.
    The output dimension is (feature_dim + motif_dim).

    The final layer's weight is initialised by `initialize_output_weights` and
    starts frozen (only its bias is trainable); callers unfreeze it later by
    setting `lin_out.weight.requires_grad = True`.

    Note:
        The input dimension is (feature_dim + 1) because each node embedding has an extra constant dimension.

    Args:
        in_dim: Dimension of the input node features.
        layer_dims: Output dimension of each GIN layer, in order.
        out_dim: Number of output logits.
    """
    def __init__(self, in_dim=5, layer_dims=(6, 6, 2), out_dim=8):
        super().__init__()

        # Build GIN layers.
        self.gin_layers = nn.ModuleList()
        prev_dim = in_dim
        for hidden_dim in layer_dims:
            mlp = nn.Sequential(
                nn.Linear(prev_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim)
            )
            self.gin_layers.append(GINConv(mlp, train_eps=True))
            prev_dim = hidden_dim

        # Final linear layer.
        self.lin_out = nn.Linear(prev_dim, out_dim, bias=True)
        initialize_output_weights(self.lin_out.weight, out_dim, prev_dim)

        # Freeze final layer weights for Phase 1 (unfreeze later); the bias stays trainable.
        self.lin_out.weight.requires_grad = False
        self.lin_out.bias.requires_grad = True

    def _node_embeddings(self, x, edge_index):
        # Shared encoder: run every GIN layer, with ReLU between (not after) layers.
        last = len(self.gin_layers) - 1
        for i, layer in enumerate(self.gin_layers):
            x = layer(x, edge_index)
            if i < last:
                x = F.relu(x)
        return x

    def forward(self, x, edge_index, batch):
        # Graph-level logits: encode nodes, mean-pool per graph, apply the output layer.
        graph_repr = global_mean_pool(self._node_embeddings(x, edge_index), batch)
        return self.lin_out(graph_repr)

    def get_graph_repr(self, x, edge_index, batch):
        # Obtain graph-level representation (for analysis purposes).
        return global_mean_pool(self._node_embeddings(x, edge_index), batch)

    def get_hidden_embeddings(self, x, edge_index, batch):
        # Returns node-level embeddings from the final hidden layer.
        # `batch` is accepted for signature parity with the other methods; it is not used.
        return self._node_embeddings(x, edge_index)

###############################################################################
# 7) Training and Evaluation Functions
###############################################################################

def train_one_epoch(model, loader, optimizer, criterion, importance, feature_dim, motif_dim, device):
    """
    Runs one optimisation pass over `loader` and returns the mean batch loss.

    `criterion` must return an unreduced loss of shape [batch_size, out_dim].
    Every graph's loss row is scaled by a graph-level weight derived from its
    target vector and the two importance scalars:
      - importance[0] is the pair importance (for the computed feature part,
        the first `feature_dim` target entries).
      - importance[1] is the motif importance (the remaining target entries).

    Weight per graph:
      - only the pair part has an active entry   -> importance[0]
      - only the motif part has an active entry  -> importance[1]
      - both parts are active                    -> max of the two
      - neither part is active                   -> 1

    `motif_dim` is accepted for interface symmetry but is not used here; the
    motif part is taken to be everything after the first `feature_dim` entries.

    Returns:
        The average over batches of the weighted mean loss.
    """
    model.train()
    running_loss = 0.0

    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()

        logits = model(data.x, data.edge_index, data.batch)  # [batch_size, out_dim]
        n_graphs, n_outputs = logits.size()
        target = data.y.float().view(n_graphs, n_outputs)
        loss = criterion(logits, target)  # [batch_size, out_dim]

        # Which parts of the target are switched on for each graph?
        has_pair = target[:, :feature_dim].sum(dim=1) > 0
        has_motif = target[:, feature_dim:].sum(dim=1) > 0

        # Start from weight 1 and overwrite by mask. The "both" case is written
        # last so that it wins over the two single-part assignments.
        weights = torch.ones(n_graphs, device=target.device)
        weights[has_pair] = importance[0]
        weights[has_motif] = importance[1]
        weights[has_pair & has_motif] = max(importance[0], importance[1])

        # Broadcast the per-graph weight across that graph's output entries.
        batch_loss = (loss * weights.unsqueeze(1)).mean()

        batch_loss.backward()
        optimizer.step()
        running_loss += batch_loss.item()

    return running_loss / len(loader)


def evaluate(model, loader, device, feature_dim):
    """
    Evaluates the model on test data.

    Computes the mean unweighted BCE-with-logits loss and the mean accuracy
    over all graphs, then prints and returns per-target averages.

    Only graphs whose entire target vector is one-hot (pure graphs) are used for
    collecting averages of predictions and hidden embeddings; they are grouped
    by their target vector.

    The graph representations of a batch are computed at most once per batch
    (and only if the batch contains a pure graph), instead of re-running the
    model for every pure graph.

    `feature_dim` is accepted for interface compatibility but is not used.

    Returns:
        avg_loss: Mean loss per graph.
        avg_accuracy: Mean per-graph accuracy.
        preds_dict: target tuple -> list of thresholded prediction tensors.
        avg_embeddings: target tuple -> mean graph representation.
        avg_predictions: target tuple -> mean thresholded prediction.
    """
    model.eval()
    total_loss = 0.0
    total_samples = 0
    total_correct = 0
    criterion = nn.BCEWithLogitsLoss(reduction='none')

    preds_dict = {}       # key: target tuple -> list of prediction tensors
    graph_repr_dict = {}  # key: target tuple -> list of graph representation tensors

    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            logits = model(data.x, data.edge_index, data.batch)
            batch_size, out_dim = logits.size()
            target = data.y.float().view(batch_size, out_dim)
            loss = criterion(logits, target)
            total_loss += loss.mean().item() * batch_size
            total_samples += batch_size

            pred = (torch.sigmoid(logits) > 0.5).float()
            # NOTE(review): this is the fraction of matching target entries per
            # graph (element-wise accuracy), although the summary line below
            # labels it "exact match".
            total_correct += (pred == target).float().mean(dim=1).sum().item()

            # Move once per batch instead of once per graph.
            target_cpu = target.cpu()
            pred_cpu = pred.cpu()
            batch_repr = None  # computed lazily, on the first pure graph of the batch

            for i in range(batch_size):
                t_vec = target_cpu[i]
                if not is_pure_graph_combined(t_vec):
                    continue
                if batch_repr is None:
                    batch_repr = model.get_graph_repr(data.x, data.edge_index, data.batch).cpu()
                key = tuple(int(round(x)) for x in t_vec.tolist())
                preds_dict.setdefault(key, []).append(pred_cpu[i])
                graph_repr_dict.setdefault(key, []).append(batch_repr[i])

    avg_loss = total_loss / total_samples
    avg_accuracy = total_correct / total_samples
    print(f"Overall Loss: {avg_loss:.4f}, Accuracy (exact match): {avg_accuracy*100:.2f}%")

    print("\n=== Average Predictions (Pure Graphs) ===")
    avg_predictions = {}
    for k, preds in preds_dict.items():
        avg_pred = torch.stack(preds).float().mean(dim=0)
        avg_predictions[k] = avg_pred
        print(f"Target {k} => Average Prediction: {avg_pred.tolist()}")

    print("\n=== Average Hidden Embeddings (Pure Graphs) ===")
    avg_embeddings = {}
    for k, reps in graph_repr_dict.items():
        avg_repr = torch.stack(reps).mean(dim=0)
        avg_embeddings[k] = avg_repr
        print(f"Target {k} => Average Hidden Embedding: {avg_repr.tolist()}")

    return avg_loss, avg_accuracy, preds_dict, avg_embeddings, avg_predictions

###############################################################################
# 8) Geometric Analysis Functions (generalized for m-dimensional hidden embeddings)
###############################################################################

def geometry_of_representation(num_active_targets, preds_dict_active, avg_embeddings_active):
    """
    Scores how well-separated the average embeddings of the active targets are.

    Only `avg_embeddings_active` is read; the other two arguments are accepted
    but not used.

    Let n be the number of active embeddings and m the dimensionality of each.
    The ideal pairwise angle is:
      - n <= m:     ideal_angle = π/2
      - n == m+1:   ideal_angle = arccos(-1/m)
      - n > m+1:    ideal_angle = arccos( sqrt((n-m)/(m*(n-1))) )

    For every embedding with non-zero norm, the smallest angle to any other
    non-zero embedding is computed. With a tolerance of 0.6 rad, the geometry
    is n if every such smallest angle is at least (ideal_angle - tolerance);
    otherwise it is 0. Embeddings with zero norm are ignored in the angle
    computation but still count towards n.

    An embedding whose smallest angle is below 0.1 rad is counted as collapsed.

    Returns:
        geometry: n if the separation condition holds, otherwise 0.
        collapsed: The number of collapsed embeddings.
    """
    vectors = list(avg_embeddings_active.values())
    n = len(vectors)
    if n == 0:
        return 0, 0
    m = len(vectors[0])

    if n > m + 1:
        ideal_angle = np.arccos(np.sqrt((n - m) / (m * (n - 1))))
    elif n == m + 1:
        ideal_angle = np.arccos(-1.0 / m)
    else:
        ideal_angle = np.pi / 2

    tol = 0.6  # Tolerance in radians.

    arrays = [np.array(vec) for vec in vectors]
    norms = [np.linalg.norm(arr) for arr in arrays]

    # Smallest angle from each non-zero embedding to any other non-zero embedding.
    nearest_angles = []
    for i, (v, norm_v) in enumerate(zip(arrays, norms)):
        if norm_v == 0:
            continue
        angles = [
            np.arccos(np.clip(np.dot(v, w) / (norm_v * norm_w), -1.0, 1.0))
            for j, (w, norm_w) in enumerate(zip(arrays, norms))
            if j != i and norm_w != 0
        ]
        if angles:
            nearest_angles.append(min(angles))

    collapsed = sum(1 for angle in nearest_angles if angle < 0.1)
    well_separated = all(angle >= (ideal_angle - tol) for angle in nearest_angles)
    geometry = n if well_separated else 0

    return geometry, collapsed


def active_targets_in_representation(target_dim, avg_predictions, avg_embeddings):
    """
    Compares the average predictions with the entire target vector.

    For each target key, all `target_dim` entries are compared with the
    corresponding average prediction using the absolute difference:
      - the target is "active" if every entry differs by less than 0.5;
      - the target is "accurate" if every entry differs by less than 0.3.

    The absolute value matters for the "active" test: with a signed difference,
    entries whose target is 0 would pass no matter how large the prediction is,
    so false positives on the other dimensions would never be checked.

    Args:
        target_dim: Number of leading entries of each target/prediction to compare.
        avg_predictions: target tuple -> tensor of average predictions.
        avg_embeddings: target tuple -> average embedding; must contain every
            key that turns out to be active.

    Returns:
       num_active_targets: Number of active target configurations.
       preds_dict_active: Dictionary of predictions for active targets.
       avg_embeddings_active: Dictionary of average embeddings for active targets.
       num_accurate_targets: Count of targets with predictions very close to the true value.
    """
    num_active_targets = 0
    num_accurate_targets = 0
    preds_dict_active = {}
    avg_embeddings_active = {}

    sigma_accurate = 0.3
    sigma_active = 0.5

    for key, preds in avg_predictions.items():
        # Absolute per-entry error over all target_dim elements.
        errors = [abs(key[i] - preds[i].item()) for i in range(target_dim)]

        if all(err < sigma_active for err in errors):
            num_active_targets += 1
            preds_dict_active[key] = preds
            avg_embeddings_active[key] = avg_embeddings[key]
        if all(err < sigma_accurate for err in errors):
            num_accurate_targets += 1

    return num_active_targets, preds_dict_active, avg_embeddings_active, num_accurate_targets


def structure_of_representation(target_dim, avg_predictions, avg_embeddings, final_loss):
    """
    Summarises one trained model's representation as a single result row.

    The active/accurate targets are determined over the entire target vector,
    then the geometry of the active targets' average embeddings is scored.

    Returns:
        A list: [target_dim, num_active_targets, num_accurate_targets, geometry, collapsed, final_loss]
        When the geometry score is not positive, the geometry slot holds the
        string "Failed" instead and nothing is printed.
    """
    (num_active_targets,
     preds_dict_active,
     avg_embeddings_active,
     num_accurate_targets) = active_targets_in_representation(
        target_dim, avg_predictions, avg_embeddings
    )
    geometry, collapsed = geometry_of_representation(
        num_active_targets, preds_dict_active, avg_embeddings_active
    )

    # Guard clause: a non-positive geometry marks the run as failed.
    if not geometry > 0:
        return [target_dim, num_active_targets, num_accurate_targets, "Failed", collapsed, final_loss]

    category_with_loss = [target_dim, num_active_targets, num_accurate_targets, geometry, collapsed, final_loss]
    print(f"Category_with_loss: [target_dim, num_active_targets, num_accurate_targets, geometry, collapsed, final_loss] = {category_with_loss}")
    return category_with_loss


def geometry_analysis(results):
    """
    Buckets the final loss of every experiment result by its configuration.

    A result is a row [target_dim, num_active, num_accurate, geometry,
    collapsed, final_loss]. Rows are keyed by the tuple
    (num_active, num_accurate, geometry, collapsed). A row that is None, or
    that contains the string "Failed" anywhere, goes into the "Failed" bucket
    (with loss None for a None row).

    Returns:
        A dictionary mapping configuration keys to lists of loss values.
    """
    grouped = {}
    for entry in results:
        if entry is None:
            grouped.setdefault("Failed", []).append(None)
        elif any(item == "Failed" for item in entry):
            grouped.setdefault("Failed", []).append(entry[5])
        else:
            config = (entry[1], entry[2], entry[3], entry[4])
            grouped.setdefault(config, []).append(entry[5])
    return grouped


def summarize_config_losses(config_losses):
    """
    Reduces each configuration's list of losses to its mean and standard deviation.

    None entries are dropped before the statistics are computed; the standard
    deviation is NumPy's default (population) one.

    Returns:
        A dictionary mapping configuration keys to (avg_loss, std_loss) tuples.
    """
    summary = {}
    for config, losses in config_losses.items():
        recorded = [value for value in losses if value is not None]
        summary[config] = (np.mean(recorded), np.std(recorded))
    return summary

###############################################################################
# 9) Experiment Runner
###############################################################################

def run_experiment(feature_dim, motif_dim, train_loader, test_loader, phase1_epochs, phase2_epochs, lr, importance, device):
    """
    Trains one freshly initialised SimpleGIN in two phases and analyses it.

      - Phase 1: `phase1_epochs` epochs with the final layer's weight frozen
        (its bias still updates).
      - Phase 2: the final layer's weight is unfrozen, a new Adam optimizer is
        created, and training continues for `phase2_epochs` epochs.

    The model uses hidden layer sizes [8, 8, 3], an input dimension of
    (feature_dim + 1) and an output dimension of (feature_dim + motif_dim).
    After training, the model is evaluated on `test_loader` and the
    representation analysis is run on the evaluation averages.

    Returns:
        The representation analysis result.
    """
    total_target_dim = feature_dim + motif_dim

    # IMPORTANT: The input dimension for the GIN is (feature_dim + 1)
    model = SimpleGIN(
        in_dim=feature_dim + 1,
        layer_dims=[8, 8, 3],  # Adjust hidden layer sizes as desired.
        out_dim=total_target_dim,
    ).to(device)
    criterion = nn.BCEWithLogitsLoss(reduction='none')

    def _train(optimizer, num_epochs):
        # Run `num_epochs` training epochs with the given optimizer.
        for _ in range(num_epochs):
            train_one_epoch(model, train_loader, optimizer, criterion,
                            importance, feature_dim, motif_dim, device)

    # ---------------------- Phase 1 ----------------------
    _train(torch.optim.Adam(model.parameters(), lr=lr), phase1_epochs)

    # ---------------------- Phase 2 ----------------------
    model.lin_out.weight.requires_grad = True
    if model.lin_out.bias is not None:
        model.lin_out.bias.requires_grad = True
    _train(torch.optim.Adam(model.parameters(), lr=lr), phase2_epochs)

    avg_loss, _, _, avg_embeddings, avg_predictions = evaluate(
        model, test_loader, device, feature_dim
    )
    return structure_of_representation(total_target_dim, avg_predictions, avg_embeddings, avg_loss)

###############################################################################
# 10) Main Script
###############################################################################

def main():
    """
    Entry point: builds the synthetic datasets, repeats the two-phase training
    experiment, and prints per-configuration loss statistics.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ---------------------- Parameters ----------------------
    feature_dim = 3      # Dimension for the computed (pairwise) feature part.
    motif_dim = 3        # Dimension for the motif label part.
    p = 0.23             # Activation probability for node embeddings.
    phase1_epochs, phase2_epochs = 5, 10
    lr = 0.01
    batch_size = 64
    num_experiments = 100
    chain_length_min, chain_length_max = 5, 8

    # ---------------------- Prepare Transition Matrices ----------------------
    favored_groups = favoured_group_calculator(num_features=feature_dim)
    candidate_matrices = create_candidate_transition_matrices(
        feature_dim, favored_groups, alpha=0.17, beta=0.17
    )
    base_distribution = torch.ones(feature_dim) / feature_dim

    # ---------------------- Data Generation ----------------------
    # Train and test sets share every generation setting except their size.
    generation_settings = dict(
        feature_dim=feature_dim,
        p=p,
        chain_length_min=chain_length_min,
        chain_length_max=chain_length_max,
        candidate_matrices=candidate_matrices,
        base_distribution=base_distribution,
        motif_dim=motif_dim,
    )

    print("Generating training data...")
    train_data = generate_graph_data_combined(num_samples=30000, **generation_settings)

    print("Generating test data...")
    test_data = generate_graph_data_combined(num_samples=15000, **generation_settings)

    # ---------------------- Data Loaders ----------------------
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

    # ---------------------- Importance ----------------------
    # Two scalars rather than a per-feature vector: one for pair importance
    # and one for motif importance. If both are active in a graph, the higher
    # value is used.
    importance = (17.0, 10.0)  # Order is (pair, motif)

    # ---------------------- Run Experiments ----------------------
    results = []
    for exp in range(num_experiments):
        print(f"\nRunning experiment {exp+1}/{num_experiments}...")
        results.append(
            run_experiment(feature_dim, motif_dim, train_loader, test_loader,
                           phase1_epochs, phase2_epochs, lr, importance, device)
        )

    print("\nAll experiments completed. Results:")
    for res in results:
        print(res)

    # ---------------------- Geometry Analysis ----------------------
    summary = summarize_config_losses(geometry_analysis(results))
    print("\n=== Geometry Analysis ===")
    print("Configuration -> (avg loss, std loss):")
    for config, stats in summary.items():
        print(f"{config} : {stats}")

if __name__ == '__main__':
    main()

Looking in links: https://data.pyg.org/whl/torch-2.5.1+cu124.html
Collecting torch-geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 63.1/63.1 kB 5.4 MB/s eta 0:00:00
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-2.5.0%2Bcu124/torch_scatter-2.1.2%2Bpt25cu124-cp311-cp311-linux_x86_64.whl (10.8 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 10.8/10.8 MB 41.1 MB/s eta 0:00:00
Collecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-2.5.0%2Bcu124/torch_sparse-0.6.18%2Bpt25cu124-cp311-cp311-linux_x86_64.whl (5.2 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 5.2/5.2 MB 33.2 MB/s eta 0:00:00
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m61.5 MB/s[0m e