### Next Steps for GraphSAGE Pipeline:
1. Set up train/validation/test splits
2. Implement heterogeneous GraphSAGE model
3. Train and evaluate venue prediction performance (supervised)
4. Make train/validation/test sets inductive. 
5. Use an unsupervised GraphSAGE to learn node embeddings
6. Implement a simple classifcation head to perform venue prediction (self supervised learning).
7. Extend model to perform link prediction 

In [None]:
from mpl_toolkits.mplot3d import Axes3D
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt # plotting
import numpy as np # linear algebra
import os # accessing directory structure
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import torch
import torch.nn.functional as F
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader
from torch_geometric.datasets import OGB_MAG
from torch_geometric.data import HeteroData
from torch_geometric.nn import SAGEConv, GraphSAGE, to_hetero
from torch_geometric.nn.conv import MessagePassing

from src.dataset_to_inductive import to_inductive
print("Current working directory:", os.getcwd())
%load_ext autoreload
%autoreload 2

Current working directory: c:\Users\gabri\GTFO_Onedrive\DTU_Code\GraphSSL
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [48]:
# Load the OGB-MAG dataset
print("Current working directory:", os.getcwd())
print("\nLoading OGB-MAG dataset...")

# Define data path relative to the workspace folder
data_path = os.path.join('data', 'ogb_mag')
os.makedirs(data_path, exist_ok=True)

# Load dataset
transform = ["to_undirected"] # insert preprocessing steps that should be applied to the data. It is common to include reverse edges.
preprocess = "transe" # specify how to obtain initial embeddings for nodes ("transe", "metapath2vec") are some options.
dataset = OGB_MAG(root=data_path, preprocess=preprocess, transform=transform)
# Get the heterogeneous graph data. Is of type HeteroData.
data = dataset.data

node_type = "paper" # target node type
data_inductive = to_inductive(data.clone(), node_type)


print(f"\nDataset loaded successfully!")
print(f"Data saved in: {os.path.abspath(data_path)}")

Current working directory: c:\Users\gabri\GTFO_Onedrive\DTU_Code\GraphSSL

Loading OGB-MAG dataset...


  data = dataset.data



Dataset loaded successfully!
Data saved in: c:\Users\gabri\GTFO_Onedrive\DTU_Code\GraphSSL\data\ogb_mag


In [52]:
print(torch. __version__)
print(torch.version.cuda)

2.9.0+cpu
None


## Data Loaders

Create NeighborLoader for mini-batch training with neighborhood sampling.

In [49]:
# Create NeighborLoader for mini-batch training
# This implements neighborhood sampling as described in the GraphSAGE paper

# Define sampling parameters
num_neighbors = [10, 10]  # Sample 10 neighbors at each of 2 layers (as in GraphSAGE paper)
batch_size = 1024

# Create train loader
train_loader = NeighborLoader(
    data_inductive,
    num_neighbors=num_neighbors,
    batch_size=batch_size,
    input_nodes=('paper', data_inductive['paper'].train_mask),
    shuffle=True,
)

# Create validation loader
val_loader = NeighborLoader(
    data_inductive,
    num_neighbors=num_neighbors,
    batch_size=batch_size,
    input_nodes=('paper', data_inductive['paper'].val_mask),
    shuffle=False,
)

# Create test loader
test_loader = NeighborLoader(
    data_inductive,
    num_neighbors=num_neighbors,
    batch_size=batch_size,
    input_nodes=('paper', data_inductive['paper'].test_mask),
    shuffle=False,
)

print(f"Train loader: {len(train_loader)} batches")
print(f"Val loader: {len(val_loader)} batches")
print(f"Test loader: {len(test_loader)} batches")

# Sample a batch to inspect structure
sample_batch = next(iter(train_loader))
print(f"\nSample batch structure:")
print(f"Node types: {sample_batch.node_types}")
print(f"Edge types: {sample_batch.edge_types}")
print(f"Paper nodes in batch: {sample_batch['paper'].num_nodes}")
print(f"Batch size (target nodes): {sample_batch['paper'].batch_size}")

Train loader: 615 batches
Val loader: 0 batches
Test loader: 0 batches


ImportError: 'NeighborSampler' requires either 'pyg-lib' or 'torch-sparse'

## GraphSAGE Model Implementation

Implementing the GraphSAGE architecture as described in "Inductive Representation Learning on Large Graphs" (Hamilton et al., 2017).

Key features:
- **Neighborhood Aggregation**: Sample and aggregate features from node neighborhoods
- **Layer-wise Propagation**: Stack multiple GraphSAGE layers
- **Heterogeneous Support**: Handle multiple node and edge types using `to_hetero`
- **Mean Aggregator**: Use mean aggregation (default in SAGEConv)

### Important Note on Heterogeneous Graphs

**Problem**: When converting to heterogeneous graphs with `to_hetero`, all node types must be updated during message passing. However, some node types (like `'author'`, `'institution'`, `'field_of_study'`) don't have initial features in OGB-MAG.

**Solution**: Add a `Linear` layer at the beginning that creates embeddings for ALL node types. The `-1` in `torch.nn.Linear(-1, hidden_channels)` allows PyTorch Geometric to automatically infer the input dimension for each node type after `to_hetero` conversion, creating separate linear layers for each node type.

In [None]:
class MyGraphSAGE(torch.nn.Module):
    """
    GraphSAGE model for node classification on heterogeneous graphs.
    
    Architecture:
    - Linear projections for each node type (handles nodes without features)
    - Two GraphSAGE convolutional layers
    - ReLU activation between layers
    - Dropout for regularization
    - Supports heterogeneous graphs via to_hetero conversion
    """
    def __init__(self, hidden_channels, out_channels, num_layers=2, dropout=0.5):
        super().__init__()
        self.num_layers = num_layers
        self.dropout = dropout
        
        # Linear layer to project input features to hidden_channels
        # Use -1 to let PyG infer the input dimension for each node type
        self.lin = torch.nn.Linear(-1, hidden_channels)
        
        # Define GraphSAGE layers
        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers - 1):
            self.convs.append(SAGEConv((-1, -1), hidden_channels))
        self.convs.append(SAGEConv((-1, -1), out_channels))
        
    def forward(self, x, edge_index):
        # Project input features to hidden dimension
        x = self.lin(x).relu()
        
        # Apply GraphSAGE layers
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, edge_index)
        return x

# Initialize the model
hidden_channels = 256
num_classes = int(data_inductive['paper'].y.max().item() + 1)
num_layers = 2

# Create base model
model = MyGraphSAGE(
    hidden_channels=hidden_channels,
    out_channels=num_classes,
    num_layers=num_layers,
    dropout=0.5
)

# Convert to heterogeneous model. Makes GraphSAGE work for different node and edge types in the same graph
# aggr='mean' is more standard for GraphSAGE
model = to_hetero(model, data_inductive.metadata(), aggr='mean')

print(f"Model initialized with {hidden_channels} hidden channels")
print(f"Number of output classes: {num_classes}")
print(f"Number of layers: {num_layers}")
print(f"\nModel structure:")
print(model)

  self.validate()
  return transformer.transform()


ValueError: Cannot generate a graph node 'relu' for type 'author' since it does not exist. Please make sure that all node types get updated during message passing.

In [None]:
# Training and evaluation functions
def train_epoch(model, loader, optimizer, device):
    """Train the model for one epoch."""
    model.train()
    total_loss = 0
    total_correct = 0
    total_examples = 0
    
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        
        # Forward pass
        out = model(batch.x_dict, batch.edge_index_dict)
        
        # Get predictions for target nodes only (papers in the batch)
        pred = out['paper'][:batch['paper'].batch_size]
        y = batch['paper'].y[:batch['paper'].batch_size]
        
        # Compute loss
        loss = F.cross_entropy(pred, y)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        # Track metrics
        total_loss += loss.item() * pred.size(0) # float(loss)
        total_correct += int((pred.argmax(dim=-1) == y).sum())
        total_examples += pred.size(0)
    
    return total_loss / total_examples, total_correct / total_examples


@torch.no_grad()
def evaluate(model, loader, device):
    """Evaluate the model."""
    model.eval()
    total_loss = 0
    total_correct = 0
    total_examples = 0
    
    for batch in loader:
        batch = batch.to(device)
        
        # Forward pass
        out = model(batch.x_dict, batch.edge_index_dict)
        
        # Get predictions for target nodes only
        pred = out['paper'][:batch['paper'].batch_size]
        y = batch['paper'].y[:batch['paper'].batch_size]
        
        # Compute loss
        loss = F.cross_entropy(pred, y)
        
        # Track metrics
        total_loss += float(loss) * pred.size(0)
        total_correct += int((pred.argmax(dim=-1) == y).sum())
        total_examples += pred.size(0)
    
    return total_loss / total_examples, total_correct / total_examples

print("Training and evaluation functions defined.")

## Training Loop

Train the supervised GraphSAGE model for venue (label) prediction.

In [None]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Move model to device
model = model.to(device)

# Initialize optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

# Training parameters
num_epochs = 50
best_val_acc = 0
patience = 10
patience_counter = 0

# Training history
history = {
    'train_loss': [], 'train_acc': [],
    'val_loss': [], 'val_acc': []
}

print(f"\nStarting training for {num_epochs} epochs...")
print("=" * 80)

for epoch in range(1, num_epochs + 1):
    # Train
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, device)
    
    # Evaluate
    val_loss, val_acc = evaluate(model, val_loader, device)
    
    # Save history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    
    # Print progress
    if epoch % 5 == 0 or epoch == 1:
        print(f"Epoch {epoch:3d} | "
              f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")
    
    # Early stopping
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        patience_counter = 0
        # Save best model
        torch.save(model.state_dict(), 'best_model.pt')
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"\nEarly stopping at epoch {epoch}")
            break

print("=" * 80)
print(f"Training complete! Best validation accuracy: {best_val_acc:.4f}")

# Load best model for testing
model.load_state_dict(torch.load('best_model.pt'))
test_loss, test_acc = evaluate(model, test_loader, device)
print(f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.4f}")

## Visualize Training Results

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Plot loss
axes[0].plot(history['train_loss'], label='Train Loss', marker='o')
axes[0].plot(history['val_loss'], label='Val Loss', marker='s')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training and Validation Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Plot accuracy
axes[1].plot(history['train_acc'], label='Train Acc', marker='o')
axes[1].plot(history['val_acc'], label='Val Acc', marker='s')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy')
axes[1].set_title('Training and Validation Accuracy')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\nFinal Results:")
print(f"Best Validation Accuracy: {best_val_acc:.4f}")
print(f"Test Accuracy: {test_acc:.4f}")
print(f"Test Loss: {test_loss:.4f}")