# PyTorch Tutorial: Graph Neural Networks (GNNs)

Graphs are everywhere: Social Networks (Facebook), Knowledge Graphs (Google), and Molecule Structures (Drug Discovery). **Graph Neural Networks (GNNs)** are deep learning models designed to learn directly from this non-Euclidean data.

## Learning Objectives
- Understand **Graphs**: Nodes, Edges, and Adjacency Matrices.
- Understand **Message Passing**: How nodes talk to neighbors.
- Implement a **GCN (Graph Convolutional Network)** using `torch_geometric`.
- Set up a **Link Prediction** task (the core of friend recommendation).

## 1. Vocabulary First

- **Node (Vertex)**: An entity (e.g., a User, a Product).
- **Edge (Link)**: A connection (e.g., "Friend of", "Bought").
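- **Feature Matrix ($X$)**: An $N \times F$ matrix whose row $i$ holds the $F$ attributes of node $i$ (e.g., Age, Location).
- **Adjacency Matrix ($A$)**: An $N \times N$ matrix where $A_{ij} = 1$ if node $i$ is connected to node $j$, and $0$ otherwise.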
- **Message Passing**: Aggregating information from neighbors to update a node's embedding.
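To make the vocabulary concrete, here is a minimal sketch in plain PyTorch (no `torch_geometric`) for a made-up 4-node graph. It builds the adjacency matrix $A$ from an edge list, defines a feature matrix $X$, and performs one message-passing step as the product $A X$, which sums each node's neighbor features. The graph and feature values are hypothetical.

```python
import torch

# Hypothetical undirected graph with 4 nodes and edges 0-1, 0-2, 2-3.
edges = [(0, 1), (0, 2), (2, 3)]
num_nodes = 4

# Adjacency matrix: A[i, j] = 1 if nodes i and j are connected.
A = torch.zeros(num_nodes, num_nodes)
for i, j in edges:
    A[i, j] = 1.0
    A[j, i] = 1.0  # undirected, so store both directions

# Feature matrix X: one row per node, one column per feature (made-up values).
X = torch.tensor([[1.0, 0.0],
                  [0.0, 1.0],
                  [2.0, 0.0],
                  [0.0, 3.0]])

# One message-passing step with sum aggregation:
# row i of A @ X is the sum of the feature rows of node i's neighbors.
messages = A @ X
print(messages)  # node 0 receives [0, 1] + [2, 0] = [2, 1]
```

Real GNN layers add a learned weight matrix, normalization, and a nonlinearity on top of this aggregation, but $A X$ is the core "gather from neighbors" operation.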

### Why Graphs Matter at Scale

Graphs are the natural data structure for relationships. Unlike images (grid) or text (sequence), graphs have **no fixed structure** — each node can have any number of neighbors.

**Real-world applications at top tech companies:**
- **Pinterest (PinSage)**: Recommends pins by learning embeddings of 3 billion nodes on a graph of pins and boards
- **Google (Knowledge Graph)**: Powers search results by reasoning over entities and relationships
- **Uber (Fraud Detection)**: Detects coordinated fraud rings by analyzing transaction graphs
- **Drug Discovery (DeepMind)**: Predicts molecular properties by treating molecules as graphs (atoms = nodes, bonds = edges)
- **Twitter/X**: Detects bot networks through graph structure analysis

### The Message Passing Intuition

Think of a party where you can only talk to people standing next to you:

1. **Round 1**: You hear what your direct neighbors say (1-hop information)
2. **Round 2**: Your neighbors relay what *their* neighbors said (2-hop information)
3. **Round 3**: Information from 3 hops away reaches you

Each GNN layer = one round of conversation. After L layers, each node has heard from neighbors up to L hops away.
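The "rounds of conversation" picture can be checked numerically. In this sketch (plain PyTorch, a hypothetical path graph 0-1-2-3), node $u$ can influence node $v$ after $k$ layers exactly when entry $(v, u)$ of $(A + I)^k$ is nonzero. The self-loops $I$ are an assumption that lets each node keep its own state between rounds, as most GNN layers do.

```python
import torch

# Hypothetical path graph 0 - 1 - 2 - 3 (undirected).
A = torch.tensor([[0, 1, 0, 0],
                  [1, 0, 1, 0],
                  [0, 1, 0, 1],
                  [0, 0, 1, 0]], dtype=torch.float)

# Self-loops let each node keep its own information between rounds.
A_hat = A + torch.eye(4)


def reach_after(num_layers: int) -> torch.Tensor:
    """Boolean matrix: entry [v, u] is True if u's features can
    influence v after `num_layers` rounds of message passing."""
    return torch.linalg.matrix_power(A_hat, num_layers) > 0


for k in (1, 2, 3):
    heard_from = reach_after(k)[0].nonzero().flatten().tolist()
    print(f"After {k} layer(s), node 0 has heard from nodes {heard_from}")
```

Node 0 hears from nodes 0 and 1 after one layer, adds node 2 after two layers, and only reaches node 3 (three hops away) after three layers.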

**The Over-Smoothing Problem**: If you stack too many layers (too many rounds of conversation), every node's embedding becomes nearly identical, because each round averages a node with its neighbors and repeated averaging washes out the differences between them. This is why most GNNs use only 2-4 layers, whereas CNNs such as ResNet can exceed 100 layers.
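A small experiment makes over-smoothing visible. This sketch uses a hypothetical 6-node ring graph and repeatedly applies plain mean aggregation with self-loops to random features. It deliberately leaves out the learned weights and nonlinearities that real GNN layers would add, so it isolates the effect of the aggregation itself, and prints how much the nodes still differ at several depths.

```python
import torch

torch.manual_seed(0)

# Hypothetical 6-node ring graph with self-loops.
n = 6
A = torch.zeros(n, n)
for i in range(n):
    A[i, (i + 1) % n] = 1.0
    A[(i + 1) % n, i] = 1.0
A_hat = A + torch.eye(n)

# Mean aggregation: divide each row by the node's degree (self-loop included).
P = A_hat / A_hat.sum(dim=1, keepdim=True)

X = torch.randn(n, 4)  # random initial node features


def spread(h: torch.Tensor) -> float:
    """How different the nodes are: per-feature std across nodes, averaged."""
    return h.std(dim=0).mean().item()


h = X
spreads = {}
for layer in range(1, 51):
    h = P @ h  # one round of mean aggregation
    if layer in (1, 2, 5, 10, 50):
        spreads[layer] = spread(h)
        print(f"layer {layer:2d}: spread across nodes = {spreads[layer]:.6f}")
```

The spread shrinks toward zero: after enough rounds every node holds essentially the same vector (on this ring, the average of the starting features), so the embeddings no longer say anything node-specific.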

```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data

print("Ready for Graphs!")
```

## 2. Creating a Simple Graph

Let's create a small social network with 3 people.
- Node 0 is friends with Node 1.
- Node 1 is friends with Node 2.

```python
# Edge Index (COO format): [Source Nodes, Target Nodes]
# Each undirected friendship is stored twice, once per direction (0->1 and 1->0).
edge_index = torch.tensor([
    [0, 1, 1, 2],
    [1, 0, 2, 1]
], dtype=torch.long)

# Node Features (e.g., Age, Activity Level) - 2 features per node
x = torch.tensor([
    [-1, 0], # Node 0
    [ 0, 1], # Node 1
    [ 1, 0]  # Node 2
], dtype=torch.float)

data = Data(x=x, edge_index=edge_index)
print(data)
```
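Because each friendship is stored as two directed edges, it is worth checking that a COO `edge_index` is symmetric before training. A minimal plain-PyTorch sketch on the same tensor (PyG's `Data` objects also provide an `is_undirected()` method for this check):

```python
import torch

# Same COO edge_index as above: row 0 = source nodes, row 1 = target nodes.
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)

# Collect edges as (source, target) pairs.
pairs = set(zip(edge_index[0].tolist(), edge_index[1].tolist()))

# An undirected graph in COO form must contain every edge in both directions.
missing = [(s, t) for (s, t) in pairs if (t, s) not in pairs]
print("undirected:", not missing)

# Degree = number of times a node appears as a source.
degree = torch.bincount(edge_index[0], minlength=3)
print("degree:", degree.tolist())  # node 1 has two friends
```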

## 3. Graph Convolutional Network (GCN)

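A GCN layer updates each node's representation with a degree-normalized sum of the transformed features of its neighbors and of the node itself:

$$ h_v^{(l+1)} = \sigma \left( \sum_{u \in \mathcal{N}(v) \cup \{v\}} \frac{1}{c_{uv}} W^{(l)} h_u^{(l)} \right), \qquad c_{uv} = \sqrt{\hat{d}_u \, \hat{d}_v} $$

where $\hat{d}_u$ is the degree of node $u$ counting its self-loop, $W^{(l)}$ is the layer's learned weight matrix, and $\sigma$ is a nonlinearity such as ReLU.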
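In matrix form, the GCN layer of Kipf & Welling is $H^{(l+1)} = \sigma(\hat{D}^{-1/2} \hat{A} \hat{D}^{-1/2} H^{(l)} W^{(l)})$, where $\hat{A} = A + I$ adds self-loops and $\hat{D}$ is the diagonal degree matrix of $\hat{A}$. The sketch below implements one such layer with dense tensors on the 3-node graph from Section 2, using random weights in place of learned ones and ReLU as $\sigma$. PyG's `GCNConv` applies the same normalization by default, plus a bias term.

```python
import torch


def gcn_layer(A: torch.Tensor, H: torch.Tensor, W: torch.Tensor) -> torch.Tensor:
    """One dense GCN layer: relu(D^-1/2 (A + I) D^-1/2 H W).

    Rows of H are node features, so W here plays the role of the
    transposed W^(l) in the per-node equation.
    """
    A_hat = A + torch.eye(A.shape[0])   # add self-loops
    deg = A_hat.sum(dim=1)              # degree including the self-loop
    d_inv_sqrt = deg.pow(-0.5)
    # Entry (v, u) is scaled by 1 / sqrt(deg_v * deg_u).
    A_norm = d_inv_sqrt[:, None] * A_hat * d_inv_sqrt[None, :]
    return torch.relu(A_norm @ H @ W)


# The 3-node friendship graph from Section 2 (0-1, 1-2) as a dense matrix.
A = torch.tensor([[0., 1., 0.],
                  [1., 0., 1.],
                  [0., 1., 0.]])
H = torch.tensor([[-1., 0.],
                  [ 0., 1.],
                  [ 1., 0.]])
torch.manual_seed(0)
W = torch.randn(2, 16)  # random stand-in for learned weights

out = gcn_layer(A, H, W)
print(out.shape)  # torch.Size([3, 16])
```

Node 0 has degree 2 (itself plus node 1) and node 1 has degree 3, so `out[0]` equals `relu((H[0] / 2 + H[1] / sqrt(6)) @ W)`: each term is weighted by one over the square root of the product of the two degrees.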

### GCN vs GAT vs GraphSAGE (The Three Major Architectures)

| Architecture | How It Aggregates Neighbors | Strengths | Weaknesses |
|-------------|---------------------------|-----------|------------|
| **GCN** | Weighted average (fixed weights based on degree) | Simple, fast, good baseline | Neighbor weights are fixed by degree, not learned |
| **GAT** (Graph Attention) | Learns attention weights per neighbor | Can prioritize important neighbors | More parameters, slower |
| **GraphSAGE** | Samples a fixed number of neighbors, then aggregates | Scales to huge graphs (billions of nodes) | Sampling introduces noise |

**GCN** is the simplest: each neighbor's message is divided by the square root of the product of the two nodes' degrees, so popular (high-degree) nodes contribute less per connection.

**GAT** adds attention: "Not all friends are equally important." It learns a weight for each edge that determines how much each neighbor contributes. Like Transformer self-attention, but on graph edges.
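A single attention head can be sketched in plain PyTorch using the scoring rule from the GAT paper, $e_{vu} = \text{LeakyReLU}(a^\top [W h_v \,\|\, W h_u])$, followed by a softmax over each node's neighbors (including itself). The weights `W` and `a` are random stand-ins for learned parameters, and the multi-head concatenation and output nonlinearity of a full GAT layer are omitted.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# 3-node graph 0-1, 1-2 with self-loops (a node may attend to itself).
A_hat = torch.tensor([[1., 1., 0.],
                      [1., 1., 1.],
                      [0., 1., 1.]])
H = torch.randn(3, 2)   # input node features
W = torch.randn(2, 8)   # shared linear projection (learned in practice)
a = torch.randn(16)     # attention vector over [Wh_v || Wh_u] (learned in practice)

Z = H @ W               # projected features, shape [3, 8]

# pairs[v, u] = [z_v || z_u] for every ordered pair of nodes.
n = Z.shape[0]
pairs = torch.cat([Z.unsqueeze(1).expand(n, n, -1),
                   Z.unsqueeze(0).expand(n, n, -1)], dim=-1)  # [3, 3, 16]
e = F.leaky_relu(pairs @ a, negative_slope=0.2)                # raw scores [3, 3]

# Mask non-neighbors, then normalize over each node's neighborhood.
e = e.masked_fill(A_hat == 0, float("-inf"))
alpha = torch.softmax(e, dim=1)  # alpha[v, u] = weight of neighbor u for node v

H_out = alpha @ Z                # attention-weighted sum of neighbor messages
print(alpha)
```

Each row of `alpha` sums to 1 over that node's neighbors, and non-neighbors (such as nodes 0 and 2) get exactly zero weight.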

**GraphSAGE** is built for scale: instead of aggregating ALL neighbors (impractical for a node with 1 million connections), it randomly samples a fixed number (e.g., 25) per layer. This makes it practical for production graphs with billions of nodes.
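The sampling idea can be sketched as a single GraphSAGE-style layer with the mean aggregator: sample up to $k$ neighbors, average their features, concatenate the result with the node's own features, and apply a linear map plus ReLU. The hub graph, dimensions, and random weights are hypothetical. This sketch samples without replacement (taking all neighbors when a node has fewer than $k$) and omits the L2 normalization of the output used in the original paper.

```python
import random

import torch

random.seed(0)
torch.manual_seed(0)

# Hypothetical adjacency list: node 0 is a hub with 100 neighbors.
neighbors = {0: list(range(1, 101))}
for u in range(1, 101):
    neighbors[u] = [0]

num_nodes = 101
in_dim, out_dim, num_samples = 4, 8, 25
H = torch.randn(num_nodes, in_dim)
W = torch.randn(2 * in_dim, out_dim)  # acts on [h_v || mean of sampled neighbors]


def sage_mean_layer(v: int):
    nbrs = neighbors[v]
    # Sample at most `num_samples` neighbors uniformly, without replacement.
    sampled = random.sample(nbrs, min(num_samples, len(nbrs)))
    h_neigh = H[sampled].mean(dim=0)
    out = torch.relu(torch.cat([H[v], h_neigh]) @ W)
    return out, sampled


out, sampled = sage_mean_layer(0)  # reads 25 of node 0's 100 neighbors
print(out.shape, len(sampled))
```

The cost per node is bounded by `num_samples` regardless of its true degree, which is what keeps the computation predictable on huge graphs.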

### Transductive vs Inductive Learning

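- **Transductive** (GCN as originally evaluated): Trained and evaluated on one fixed graph whose nodes, including the test nodes, are all visible during training. When new nodes appear, embeddings must be recomputed and the model is usually retrained. Good for: static graphs (knowledge graphs, citation networks).
- **Inductive** (GraphSAGE, GAT): Designed to generalize to nodes unseen during training, because the learned weights describe how to aggregate any neighborhood rather than being tied to specific nodes. Good for: dynamic graphs where new users/items appear constantly (social networks, e-commerce).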
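What makes a layer usable inductively is that its weights are shared by all nodes, so nothing in them depends on the number of nodes or on node IDs. The sketch below (plain PyTorch, a GraphSAGE-style mean aggregator, random weights standing in for trained ones) applies the same weights to the original 3-node graph and to a version where a hypothetical new user, node 3, has joined.

```python
import torch

torch.manual_seed(0)
# Weights are shared by every node, so their shapes do not depend on graph size.
W_self = torch.randn(2, 4)
W_neigh = torch.randn(2, 4)


def mean_agg_layer(A: torch.Tensor, X: torch.Tensor) -> torch.Tensor:
    """GraphSAGE-style layer: own features plus mean of neighbor features."""
    deg = A.sum(dim=1, keepdim=True).clamp(min=1)  # avoid dividing by zero
    neigh_mean = (A @ X) / deg
    return torch.relu(X @ W_self + neigh_mean @ W_neigh)


# Original 3-node graph from Section 2 (0-1, 1-2).
A3 = torch.tensor([[0., 1., 0.],
                   [1., 0., 1.],
                   [0., 1., 0.]])
X3 = torch.tensor([[-1., 0.], [0., 1.], [1., 0.]])

# Hypothetical new user (node 3) joins and befriends node 2.
A4 = torch.zeros(4, 4)
A4[:3, :3] = A3
A4[2, 3] = A4[3, 2] = 1.0
X4 = torch.cat([X3, torch.tensor([[0.5, 0.5]])])

out3 = mean_agg_layer(A3, X3)  # shape [3, 4]
out4 = mean_agg_layer(A4, X4)  # shape [4, 4], same weights, no retraining
print(out3.shape, out4.shape)
```

Node 0's output is the same in both graphs because its neighborhood did not change; the new node gets an embedding without any weights being retrained.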

```python
class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(2, 16) # Input: 2 features -> Hidden: 16
        self.conv2 = GCNConv(16, 2) # Hidden: 16 -> Output: 2 (Embedding)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)

        return x

model = GCN()
print(model)
```

## 4. Link Prediction (Recommendation)

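To recommend a friend, we score a candidate pair by the dot product of their node embeddings and pass the score through a sigmoid to get a probability:

$$ \text{score}(u, v) = h_u \cdot h_v, \qquad p(u, v) = \sigma\big(\text{score}(u, v)\big) $$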

### Why Dot Product Works for Recommendations

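Training with a link-prediction loss shapes the embedding space so that **connected nodes get embeddings with a high dot product**, while sampled unconnected pairs get a low one. After training, the dot product acts as a similarity score: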
- **High score**: Nodes are similar (likely to connect)
- **Low score**: Nodes are dissimilar (unlikely to connect)

This is the same principle as matrix-factorization collaborative filtering, where user and item embeddings are scored by a dot product. GNNs additionally let each embedding depend on node features (content-based signals) as well as on graph structure (collaborative signals).
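In practice, recommendations come from scoring many candidate pairs at once and ranking them. A minimal sketch, with random vectors standing in for trained GNN embeddings and a hypothetical 5-node friendship graph: compute all pairwise dot products as `H @ H.T`, mask out self-pairs and existing friends, and keep the top-$k$ candidates per node.

```python
import torch

torch.manual_seed(0)
num_nodes, dim = 5, 8
H = torch.randn(num_nodes, dim)  # stand-in for learned node embeddings

# Existing friendships (undirected): 0-1, 1-2, 3-4
A = torch.zeros(num_nodes, num_nodes, dtype=torch.bool)
for u, v in [(0, 1), (1, 2), (3, 4)]:
    A[u, v] = A[v, u] = True

# Never recommend a node to itself or to an existing friend.
mask = A | torch.eye(num_nodes, dtype=torch.bool)

scores = H @ H.T                                  # scores[u, v] = h_u . h_v
scores = scores.masked_fill(mask, float("-inf"))
top = scores.topk(k=2, dim=1)
recommended = top.indices                         # [num_nodes, 2]
probs = torch.sigmoid(top.values)                 # probabilities for reporting

print("Recommendations for node 0:", recommended[0].tolist())
```

Ranking is done on raw scores; the sigmoid is applied afterwards only to report probabilities, which does not change the order because the sigmoid is monotonic.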

### The Three Main GNN Tasks

1. **Node Classification**: Predict a label for each node (e.g., "Is this user a bot?")
2. **Link Prediction**: Predict whether an edge should exist (e.g., "Should we recommend user A follow user B?")
3. **Graph Classification**: Predict a label for an entire graph (e.g., "Is this molecule toxic?")
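All three tasks can share the same GNN encoder and differ only in the final "head". A minimal sketch with random embeddings standing in for GNN output; the head sizes are hypothetical, and mean pooling is just one common readout for graph classification (PyG offers `global_mean_pool` for batches of graphs).

```python
import torch

torch.manual_seed(0)
num_nodes, emb_dim, num_classes = 4, 8, 3
H = torch.randn(num_nodes, emb_dim)  # node embeddings from any GNN encoder

# 1. Node classification: a linear layer maps each node embedding to class logits.
node_head = torch.nn.Linear(emb_dim, num_classes)
node_logits = node_head(H)                        # [num_nodes, num_classes]

# 2. Link prediction: score a candidate edge (u, v) with a dot product.
u, v = 0, 3
link_prob = torch.sigmoid((H[u] * H[v]).sum())    # scalar in (0, 1)

# 3. Graph classification: pool all node embeddings into one graph vector.
graph_head = torch.nn.Linear(emb_dim, 2)          # e.g. toxic / non-toxic
graph_logits = graph_head(H.mean(dim=0))          # mean pooling -> [2]

print(node_logits.shape, link_prob.item(), graph_logits.shape)
```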

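```python
# Forward pass to get embeddings. eval() disables dropout so the output is
# deterministic; no_grad() skips gradient tracking since we are not training.
model.eval()
with torch.no_grad():
    embeddings = model(data)

# Predict link between Node 0 and Node 2 (who are NOT friends yet)
node_0 = embeddings[0]
node_2 = embeddings[2]

score = torch.dot(node_0, node_2)  # h_0 . h_2
prob = torch.sigmoid(score)

# The model has not been trained yet, so this probability is not meaningful.
print(f"Probability of connection between 0 and 2: {prob.item():.4f}")
```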
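The probability above comes from an untrained model. Below is a minimal training sketch for link prediction: a dense two-layer GCN-style encoder written in plain PyTorch (a hypothetical `DenseGCN` class standing in for the `GCN` model above, so the example runs without `torch_geometric`), trained with binary cross-entropy to score real edges high and non-edges low. The 6-node graph of two triangles joined by one edge is made up for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical graph: triangles (0, 1, 2) and (3, 4, 5) joined by edge 2-3.
edges = [(0, 1), (1, 2), (0, 2), (3, 4), (4, 5), (3, 5), (2, 3)]
n = 6
A = torch.zeros(n, n)
for u, v in edges:
    A[u, v] = A[v, u] = 1.0

# GCN normalization: D^-1/2 (A + I) D^-1/2
A_hat = A + torch.eye(n)
d_inv_sqrt = A_hat.sum(dim=1).pow(-0.5)
A_norm = d_inv_sqrt[:, None] * A_hat * d_inv_sqrt[None, :]

X = torch.eye(n)  # no node features here, so use one-hot node IDs


class DenseGCN(torch.nn.Module):
    """Hypothetical two-layer GCN-style encoder on a dense normalized adjacency."""

    def __init__(self, in_dim, hid_dim, out_dim):
        super().__init__()
        self.lin1 = torch.nn.Linear(in_dim, hid_dim)
        self.lin2 = torch.nn.Linear(hid_dim, out_dim)

    def forward(self, X, A_norm):
        h = F.relu(A_norm @ self.lin1(X))
        return A_norm @ self.lin2(h)


lp_model = DenseGCN(n, 16, 8)
optimizer = torch.optim.Adam(lp_model.parameters(), lr=0.01)

pos = torch.tensor(edges).T                           # [2, 7] real edges
# Negatives: every unconnected pair (only feasible for tiny graphs).
neg = torch.triu(1 - A_hat, diagonal=1).nonzero().T  # [2, 8]


def link_loss():
    H = lp_model(X, A_norm)
    pos_score = (H[pos[0]] * H[pos[1]]).sum(dim=1)    # dot products of real edges
    neg_score = (H[neg[0]] * H[neg[1]]).sum(dim=1)    # dot products of non-edges
    scores = torch.cat([pos_score, neg_score])
    labels = torch.cat([torch.ones_like(pos_score), torch.zeros_like(neg_score)])
    return F.binary_cross_entropy_with_logits(scores, labels)


initial_loss = link_loss().item()
for step in range(200):
    optimizer.zero_grad()
    loss = link_loss()
    loss.backward()
    optimizer.step()
final_loss = link_loss().item()
print(f"loss: {initial_loss:.3f} -> {final_loss:.3f}")
```

A realistic setup also holds out some edges for evaluation and samples negatives instead of enumerating every non-edge; PyG provides `RandomLinkSplit` and `negative_sampling` for these steps.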

## Key Takeaways

1. **Graphs** model relationships — they're the natural data structure for social networks, molecules, knowledge bases, and fraud detection.
2. **GNNs** learn embeddings through message passing — each layer aggregates information from neighbors, building richer representations.
3. **Over-smoothing limits depth** — unlike CNNs, GNNs work best with 2-4 layers. Too many layers cause all node embeddings to converge.
4. **Architecture choice depends on scale**: GCN for simplicity, GAT for importance-weighted neighbors, GraphSAGE for billion-node production graphs.
5. **Transductive vs Inductive**: If new nodes appear regularly (e-commerce, social media), use an inductive method like GraphSAGE.
6. **Link Prediction** underpins many graph-based recommendation systems, since learned embeddings capture both content features and graph structure simultaneously.