# Explain the GNN predictions

## 0. Preparations
### Imports

In [1]:
import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn.models import GraphSAGE

# Snippets from my Master Thesis repository so that the non-relevant parts of the code are as short as possible
from src.data_loading import get_spatial_data
from src.graph_construction import build_radius_delaunay_graph

### Check for GPU

In [2]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


### Define a Random Number Generator

In [3]:
# Dedicated, seeded CPU generator so that the random train/test split is reproducible.
# `manual_seed` returns the generator itself, so seeding can be chained onto construction.
rng = torch.Generator().manual_seed(42)

<torch._C.Generator at 0x7f71efd33d30>

## 1. Graph Construction and Model Training

### Define Hyperparameters

All hyperparameters were chosen based on the results of the evaluation that forms the main part of the thesis.

In [4]:
# ---- Graph construction (Radius-Delaunay) ----
graph_construction_radius = 1.8225e-4
# No self-loops: we want to minimise the influence of a node's own gene expression profile
self_loops = False
edge_weights = False
# Fraction of nodes randomly assigned to training; every remaining node is used for testing
train_frac = 0.8

# ---- GNN model (GraphSAGE) ----
num_layers = 3
hidden_channels = 128
activation = "relu"
dropout_rate = 0.2

# ---- Training algorithm (Adam) ----
lr = 1e-2
weight_decay = 1e-5
n_epochs = 144

### Load Data

In [5]:
cells, genes = get_spatial_data("intestine")

# Model inputs: 2D cell positions, per-cell gene expression features and the cluster labels
cell_coordinates = cells[["x", "y"]].values
features = genes.values
true_labels = cells["cluster_id"].values

# Cell-type names, one row per distinct (cluster_id, cell_type) pair, sorted by cluster id.
# NOTE(review): this yields exactly one name per label only if every cluster maps to a
# single cell type — TODO confirm for the dataset.
ordered_names = (
    cells[["cluster_id", "cell_type"]]
    .drop_duplicates()
    .set_index("cluster_id")
    .sort_index()["cell_type"]
    .values
)

print("Cell coordinates: ", cell_coordinates.shape)
print("Features: ", features.shape)
print("True labels: ", true_labels.shape)  # Must be 1D array!

Cell coordinates:  (7416, 2)
Features:  (7416, 241)
True labels:  (7416,)


### Construct Radius-Delaunay Graph

This is the best-performing graph construction method overall, so we use it for the rest of this notebook.

In [6]:
graph = build_radius_delaunay_graph(
    positions=cell_coordinates,
    features=features,
    labels=true_labels,  # TODO: Change to `features` since this is our new target
    radius=graph_construction_radius,
    include_self_loops=self_loops,
    add_distance=edge_weights,
    library="pyg",  # The only available option since the code snippet is not the complete version of the normal code
)

# Basic statistics of the constructed graph
num_labels = np.unique(true_labels).size
num_nodes, num_edges = graph.num_nodes, graph.num_edges

print("Number of labels:", num_labels)
print("Number of nodes:", num_nodes)
print("Number of edges:", num_edges)
print("Average degree:", num_edges / num_nodes)

Number of labels: 19
Number of nodes: 7416
Number of edges: 32636
Average degree: 4.400755124056095


### Train/Test Split

Since we do not tune any hyperparameters here and are not primarily interested in the performance of the model itself, we omit the validation set and use only the test set for evaluation.

In [7]:
def assign_random_splits(graph: Data) -> Data:
    """Return a copy of ``graph`` with random, disjoint train/test node masks.

    The nodes are shuffled with a permutation drawn from the module-level generator ``rng``.
    The first ``int(n * train_frac)`` of them are marked in ``train_mask`` and all remaining
    nodes in ``test_mask``, where ``n`` is the number of nodes of the given graph. The input
    graph is not modified; the split is applied to a clone.

    Args:
        graph: The PyG graph to split.

    Returns:
        A clone of ``graph`` with boolean ``train_mask`` and ``test_mask`` attributes.
    """
    graph = graph.clone()
    # Size everything from the graph itself instead of the global `num_nodes`, so the function
    # stays correct for any graph it is given.
    n = graph.num_nodes
    permutation = torch.randperm(n, generator=rng)
    n_train = int(n * train_frac)

    # Not all datatypes are allowed for the index: Only bool, byte or long
    graph.train_mask = torch.zeros(n, dtype=torch.bool)
    graph.test_mask = torch.zeros(n, dtype=torch.bool)
    graph.train_mask[permutation[:n_train]] = True
    graph.test_mask[permutation[n_train:]] = True
    return graph

# TODO

In [None]:
# I think the easiest way to implement this is to wrap the initialised GNN model we want to use
# This should then only be used in combination with gradient descent versions that do not use any batches but work on every node individually

class MaskedGNN(torch.nn.Module):
    """Wrapper around a GNN that is intended to mask the node being predicted.

    NOTE: The masking itself is not implemented yet (see the TODO in ``forward``).
    Currently the wrapped model is called unchanged.
    """

    # Masking strategies accepted by the constructor
    _MASKING_TYPES = ("ones", "avg")

    def __init__(self, model: torch.nn.Module, masking_type: str = "ones"):
        """Initializes the masker model

        Args:
            model: Any GNN model whose forward pass takes ``(x, edge_index)``.
            masking_type: Either "ones" or "avg" which could be the cell type average as discussed. Defaults to "ones".

        Raises:
            ValueError: If ``masking_type`` is not one of the supported values.
        """
        # nn.Module.__init__ must run before a submodule is assigned. Otherwise PyTorch raises
        # AttributeError, and the wrapped model's parameters would not be registered, so they
        # would be neither optimized nor moved by `.to(device)`.
        super().__init__()
        if masking_type not in self._MASKING_TYPES:
            raise ValueError(
                f"Unknown masking_type {masking_type!r}; expected one of {self._MASKING_TYPES}"
            )
        self.model = model
        self.masking_type = masking_type

    def forward(self, x, edge_index):
        """Forward function that masks the node that is used for prediction

        Args:
            x: The input node features.
            edge_index: The edge indices (the graph or more specifically the adjacency list representation that is used by PyG)

        Returns:
            The predictions of the wrapped model.
        """

        # TODO: Find the index of the node that should be predicted and mask it

        return self.model(x, edge_index)  # Do the rest

In [8]:
# The configurations chosen in the following are chosen as the overall best after the main grid search
def train_model(graph: Data) -> tuple[torch.nn.Module, Data]:
    """Initialize the masked GraphSAGE model, train it and report its test accuracy.

    The graph receives a fresh random train/test split via ``assign_random_splits`` and is
    moved to ``device`` together with the model. The model is trained full-batch with Adam
    for ``n_epochs`` epochs, using cross-entropy on the training nodes. It is then
    evaluated (eval mode, no gradients) on the test nodes.

    Args:
        graph: The PyG graph providing ``x``, ``edge_index`` and the labels ``y``.

    Returns:
        The trained model (a ``MaskedGNN`` wrapping ``GraphSAGE``, left in eval mode) and the
        graph with its ``train_mask``/``test_mask``, located on ``device``.
    """
    model = MaskedGNN(
        model=GraphSAGE(
            in_channels=features.shape[1],
            out_channels=num_labels,  # TODO: Change to input size as we want to essentially predict the input
            hidden_channels=hidden_channels,
            num_layers=num_layers,
            dropout=dropout_rate,
            act=activation,
            jk="last",  # This adds a final linear layer after the GNN layers
            # arguments that are passed to the Convolutional layer
            aggr="max",
            normalize=True,
            root_weight=True,
            project=True,
            bias=True,
        ),
        masking_type="ones",
    )

    # Move model to the correct device memory. Needed for training on the GPU.
    model = model.to(device)

    # The split is drawn with the CPU generator `rng`. Afterwards the whole graph (features,
    # edges, labels and masks) must live on the model's device, otherwise the forward pass
    # fails with a device mismatch when `device` is CUDA.
    graph = assign_random_splits(graph).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    # Training loop: exactly `n_epochs` epochs
    for _ in range(n_epochs):

        # TODO: Change from full batch to single batch
        # Can be done with the NeighborLoader in PyG: https://pytorch-geometric.readthedocs.io/en/latest/modules/loader.html#torch_geometric.loader.NeighborLoader
        # An Example: https://github.com/pyg-team/pytorch_geometric/blob/master/examples/ogbn_products_sage.py

        model.train()
        optimizer.zero_grad()
        out = model(graph.x, graph.edge_index)
        loss = F.cross_entropy(out[graph.train_mask], graph.y[graph.train_mask])  # TODO: Change to something that is fitting for a regression task. MSE probably...
        loss.backward()
        optimizer.step()

    # Test the model with dropout disabled and without building an autograd graph
    model.eval()
    with torch.no_grad():
        pred = model(graph.x, graph.edge_index).argmax(dim=-1)
    test_mask = graph.test_mask
    acc = int((pred[test_mask] == graph.y[test_mask]).sum()) / int(test_mask.sum())
    print(f"Finished training with an accuracy of {acc:.4f} for the test set")

    return model, graph