# In [1]:
# File: train_gnn_analyst_with_minibatching.py

import torch
import torch.nn.functional as F
import torch_geometric.transforms as T
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric.datasets import OGB_MAG
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import HGTConv

# --- 1. Load the Dataset ---
# PyG downloads the raw files on first use and caches them under 'data/'.
#
# The rest of this script relies on the heterogeneous PyG API
# (`data['paper']`, `data.metadata()`, typed `input_nodes` in NeighborLoader),
# so the graph must be a `HeteroData` object. PyG's own `OGB_MAG` dataset
# provides that, including boolean train/val/test masks on the 'paper' nodes.
#
# ogbn-mag only ships raw features for 'paper' nodes. HGTConv needs an input
# feature matrix for every node type, so preprocess='metapath2vec' is used to
# attach pre-computed embeddings to the remaining node types.
dataset = OGB_MAG(
    root='data',
    preprocess='metapath2vec',
    transform=T.ToUndirected(),
)
data = dataset[0]

# Define the training nodes (we only train on the 'paper' nodes in this benchmark).
# The boolean mask is converted to a 1-D tensor of node indices.
train_idx = data['paper'].train_mask.nonzero(as_tuple=False).view(-1)

# --- 2. Create the Efficient Data Loader ---
# This is the key component for large-scale graph training.
# It handles the neighbor sampling for us.
train_loader = NeighborLoader(
    data,
    # One entry per hop: [10, 10] samples up to 10 neighbors at the 1st hop
    # and up to 10 at the 2nd hop. An entry of -1 would keep all neighbors.
    # Two hops matches the two HGTConv layers of the model below.
    num_neighbors=[10, 10],
    batch_size=2048,  # Number of seed 'paper' nodes per mini-batch.
    input_nodes=('paper', train_idx),
    shuffle=True,
    num_workers=4,  # Use multiple CPU processes to prepare batches.
)

# --- 3. Define the GNN Model ---
# We use HGTConv because it's designed for heterogeneous graphs like ogbn-mag.
class GNN_Analyst(torch.nn.Module):
    """Two-layer HGT (Heterogeneous Graph Transformer) classifier for 'paper' nodes.

    Args:
        hidden_channels: Feature width produced by the first HGT layer.
            HGTConv requires this to be divisible by ``num_heads``.
        out_channels: Feature width produced by the second HGT layer
            (the number of target classes).
        num_heads: Number of attention heads in the first HGT layer.
        metadata: Optional ``(node_types, edge_types)`` tuple describing the
            graph schema. When omitted, the schema of the module-level
            ``data`` object is used, which is the original behaviour.
    """

    def __init__(self, hidden_channels, out_channels, num_heads=8, metadata=None):
        super().__init__()
        if metadata is None:
            metadata = data.metadata()
        # in_channels=-1 lets HGTConv infer each node type's input width lazily
        # on the first forward pass.
        self.conv1 = HGTConv(-1, hidden_channels, metadata, heads=num_heads)
        # HGTConv splits `out_channels` across its heads instead of
        # concatenating them, so conv1 emits `hidden_channels` features per
        # node (not hidden_channels * num_heads).
        self.conv2 = HGTConv(hidden_channels, out_channels, metadata, heads=1)

    def forward(self, x_dict, edge_index_dict):
        """Return the second-layer output for the 'paper' nodes.

        Args:
            x_dict: Mapping from node type to that type's feature matrix.
            edge_index_dict: Mapping from edge type to its edge index tensor.

        Returns:
            The 'paper' entry of the second HGT layer's output dictionary.
        """
        x_dict = self.conv1(x_dict, edge_index_dict)
        x_dict = {key: x.relu() for key, x in x_dict.items()}
        x_dict = self.conv2(x_dict, edge_index_dict)
        return x_dict['paper']  # We only care about the output for 'paper' nodes

# --- 4. The Training Loop ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GNN_Analyst(hidden_channels=64, out_channels=dataset.num_classes).to(device)

# The first HGT layer is built with in_channels=-1, so its weights are lazy and
# only receive concrete shapes during the first forward pass. Push a single
# batch through the model (without tracking gradients) so every parameter is
# materialised before it is handed to the optimizer.
with torch.no_grad():
    _init_batch = next(iter(train_loader)).to(device)
    model(_init_batch.x_dict, _init_batch.edge_index_dict)
    del _init_batch

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

def train():
    """Run one training epoch over ``train_loader``.

    Uses the module-level ``model``, ``optimizer``, ``device`` and
    ``train_loader``.

    Returns:
        The cross-entropy loss averaged over all seed nodes seen in the epoch.
    """
    model.train()
    total_loss = total_examples = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        # The model only sees the small subgraph for this batch
        out = model(batch.x_dict, batch.edge_index_dict)

        # We only calculate loss on the seed nodes of the batch; the loss is
        # taken over the first `batch_size` 'paper' rows.
        batch_size = batch['paper'].batch_size
        # cross_entropy expects 1-D class indices. Flattening makes this work
        # for labels stored as [N] as well as OGB-style [N, 1].
        target = batch['paper'].y[:batch_size].reshape(-1)
        loss = F.cross_entropy(out[:batch_size], target)

        loss.backward()
        optimizer.step()

        # Weight each batch's mean loss by its seed count so the epoch value
        # is a per-node average even when the last batch is smaller.
        total_loss += float(loss) * batch_size
        total_examples += batch_size

    return total_loss / total_examples

# Run the epochs, reporting the mean training loss after each one.
print("Starting GNN Analyst training...")
num_epochs = 10  # A full run would use more epochs.
for epoch in range(1, num_epochs + 1):
    loss = train()
    print(f'Epoch {epoch:02d}, Loss: {loss:.4f}')

# --- 5. Save the Final Model ---
# Only the weights (state dict) are written; this file is the reusable
# artifact for the document synthesis pipeline.
trained_weights = model.state_dict()
torch.save(trained_weights, 'gnn_analyst_mag_pretrained.pt')
print("\nTraining complete. Model saved to 'gnn_analyst_mag_pretrained.pt'")

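# Output of the recorded notebook run:
#   ModuleNotFoundError: No module named 'ogb'
# The `ogb` package was not installed in that environment; install it
# (e.g. `pip install ogb`) before running this script.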