# Implementation: Global Pooling (Graph Classification)

**Goal**: Reduce the N node embeddings of each graph to a single graph-level vector (here, by averaging them).

In [None]:
import torch

# 1. Toy mini-batch of two graphs stacked row-wise, one feature per node:
#    rows 0-1 are graph A (features 1, 2), rows 2-4 are graph B (3, 4, 5).
x = torch.tensor([[1.0], [2.0], [3.0], [4.0], [5.0]])

# 2. Batch vector: entry k is the index of the graph that node k belongs to
#    (two nodes of graph 0 followed by three nodes of graph 1).
batch = torch.tensor([0] * 2 + [1] * 3)

# 3. Global Mean Pooling
def global_mean_pool(x, batch, size=None):
    """Average node features per graph, yielding one embedding per graph.

    A simplified stand-in for ``torch_geometric.nn.global_mean_pool``; in real
    projects prefer the library version.

    Args:
        x: Node feature tensor of shape ``(num_nodes, *feature_dims)``.
        batch: 1-D integer tensor of length ``num_nodes`` giving, for each
            node, the index of the graph it belongs to (``0 .. num_graphs-1``).
        size: Optional number of graphs. Defaults to ``batch.max() + 1``
            (or 0 for an empty batch); pass it explicitly when trailing graphs
            may have no nodes.

    Returns:
        Tensor of shape ``(num_graphs, *feature_dims)`` on ``x``'s device.
        A floating-point ``x`` keeps its dtype. Graphs with no nodes get an
        all-zero row instead of NaN, matching PyTorch Geometric.
    """
    if size is None:
        # ``batch.max()`` raises on an empty tensor, so guard that case.
        size = int(batch.max().item()) + 1 if batch.numel() > 0 else 0

    # Sum each graph's node features in one vectorized scatter-add instead of
    # a Python loop that builds one boolean mask per graph.
    sums = torch.zeros((size, *x.shape[1:]), dtype=x.dtype, device=x.device)
    sums.index_add_(0, batch, x)

    # Nodes per graph; clamping to 1 turns empty graphs into 0 / 1 = 0.
    counts = torch.bincount(batch, minlength=size).clamp(min=1)
    # Shape counts as (size, 1, ..., 1) so it broadcasts over feature dims.
    counts = counts.view(-1, *([1] * (x.dim() - 1))).to(sums.dtype)
    return sums / counts

pooled = global_mean_pool(x, batch)
print("Graph Embeddings:")
print(pooled)
print("Graph A Mean: (1+2)/2 = 1.5")
print("Graph B Mean: (3+4+5)/3 = 4.0")

# Verify the hand-computed means printed above against the actual result, so
# the explanation cannot silently drift from what the function returns.
assert torch.allclose(pooled, torch.tensor([[1.5], [4.0]]))

## Conclusion
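Each graph is now represented by a single fixed-size vector, regardless of how many nodes it has. These embeddings can therefore be fed into a standard classifier (e.g., an MLP).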