In [None]:
# Load the edge index tensor that was extracted from KNN
# edge_index has shape [2, num_edges] where:
#   - edge_index[0] = source node indices
#   - edge_index[1] = target node indices
import torch

# Load onto the CPU so the indices live on the same device as the node
# features created below, and restrict unpickling to plain tensor data:
# without weights_only=True, torch.load may execute arbitrary pickled code.
edge_index = torch.load(
    "./graph/edge_index_top20.pt", map_location="cpu", weights_only=True
)

# Fail early if the file does not hold a [2, num_edges] integer index tensor.
assert edge_index.dim() == 2 and edge_index.size(0) == 2, edge_index.shape
assert not torch.is_floating_point(edge_index), edge_index.dtype
print(edge_index.shape)


torch.Size([2, 387140])


In [None]:
# Initialize node features for the graph
# num_genes = number of nodes implied by the KNN edge_index (largest node index + 1),
#   so every index that appears in an edge has a matching row in x
# We create random embeddings of size [num_genes, d_model] to simulate gene embeddings
if edge_index.numel() == 0:
    raise ValueError("edge_index is empty; cannot infer the number of genes")
num_genes = int(edge_index.max().item()) + 1
d_model = 128   # small for testing

# x shape: [num_genes, d_model] - each gene gets a random feature vector
# A dedicated, seeded generator makes the features identical on every run
# without reseeding the global RNG.
feature_rng = torch.Generator().manual_seed(0)
x = torch.randn(num_genes, d_model, generator=feature_rng)


In [None]:
# Define a Graph Attention Network (GAT) with 2 layers
# GATv2Conv uses multi-head attention to aggregate information from neighboring nodes
# heads=4: use 4 attention heads for richer representational capacity
# concat=False: average the heads instead of concatenating (keeps dimensions constant)
import torch.nn as nn
from torch_geometric.nn import GATv2Conv

class GeneGAT(nn.Module):
    """Two-layer GATv2 network that maps node features of width ``dim`` to ``dim``.

    Both layers are built with ``concat=False``, so the attention heads are
    averaged rather than concatenated and the feature width stays ``dim``
    after every layer.

    Args:
        dim: Width of the input, hidden and output node features.
        heads: Number of attention heads used by each layer. Defaults to 4,
            the value that used to be hard-coded.
    """

    def __init__(self, dim: int, heads: int = 4) -> None:
        super().__init__()
        # First GAT layer: applies attention over KNN edges, outputs [num_genes, dim]
        self.gat1 = GATv2Conv(dim, dim, heads=heads, concat=False)
        # Second GAT layer: further refines embeddings with attention, outputs [num_genes, dim]
        self.gat2 = GATv2Conv(dim, dim, heads=heads, concat=False)

    def forward(self, x, edge_index):
        """Run both attention layers over the graph.

        Args:
            x: Node feature tensor; the notebook passes [num_genes, dim].
            edge_index: Edge tensor of shape [2, num_edges]
                (row 0 = source nodes, row 1 = target nodes).

        Returns:
            The output of the second GAT layer (no activation applied to it).
        """
        # Layer 1: attend to neighbors, apply ReLU activation
        hidden = self.gat1(x, edge_index).relu()
        # Layer 2: attend to neighbors again on refined features
        return self.gat2(hidden, edge_index)


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Instantiate the GAT model and run a forward pass
# This tests whether the KNN-extracted edges work correctly with the GATv2 layers
model = GeneGAT(d_model)

# Smoke test only: this cell never calls backward on `out`, so skip autograd
# bookkeeping instead of recording a graph over every edge and attention head.
with torch.no_grad():
    out = model(x, edge_index)

# Verify that input and output have the same shape
# (GAT preserves node count, only refines feature representations)
assert out.shape == x.shape, f"expected {tuple(x.shape)}, got {tuple(out.shape)}"
print(x.shape, "→", out.shape)


torch.Size([19357, 128]) → torch.Size([19357, 128])
