<a href="https://colab.research.google.com/github/adelsuh/cs224_final_project/blob/main/graph_structure_aug.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# CS224W Final Project: Tutorial on the Augmentation of Graphs in PyG

### Jerry Chan, Jihee Suh, John So

### Notebook setup: install PyG + torch

In [1]:
import torch
torch_version = str(torch.__version__)
if "2.4.0" not in torch_version:
  !pip install torch==2.4.0 -q
print(torch_version)

2.4.0+cu121


In [2]:
scatter_src = f"https://pytorch-geometric.com/whl/torch-{torch_version}.html"
sparse_src = f"https://pytorch-geometric.com/whl/torch-{torch_version}.html"
!pip install torch-scatter -f $scatter_src -q
!pip install torch-sparse -f $sparse_src -q
!pip install torch-geometric -q
!pip install ogb -q

In [3]:
import os
import random
import numpy as np
import torch

def seed_everything(seed):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(42)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"using device: {device}")

using device: cuda


### Setting up the dataset and tasks

The code below sets the hyperparameters used for data loading and training.


In [15]:
# Model settings
hidden_dim = 128 #@param {type: "integer"}
num_layers = 2 #@param {type: "integer"}
# Training settings
learning_rate = 0.0001 #@param {type: "number"}
num_epochs = 10 #@param {type: "integer"}

# Dataloader settings
batch_size = 32 #@param {type: "integer"}
fan_out = 10 #@param {type: "integer", hint: "Used in neighborhood sampling to sample a subgraph"}
dataloader_num_workers = 2 #@param {type: "integer"}

print(f"""
Running training with the following configuration:
   hidden_dim: {hidden_dim}
   num_layers: {num_layers}
   learning_rate: {learning_rate}
   num_epochs: {num_epochs}
   batch_size: {batch_size}
""")


Running training with the following configuration:
   hidden_dim: 128
   num_layers: 2
   learning_rate: 0.0001
   num_epochs: 10
   batch_size: 32



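### About the task: node classification on ogbn-products

We use `ogbn-products` from the [Open Graph Benchmark](https://ogb.stanford.edu/) (OGB). It is an Amazon product co-purchasing network: nodes are products, and an edge connects two products that are purchased together. The task is to predict the category of each product, which is a 47-way node classification problem over 100-dimensional node features.

Run the below block to create the dataset. If this is your first time loading the dataset, it will additionally prompt you to download files.

**Note**: this block loads the dataset into RAM each time it is called! Calling it multiple times will likely consume all of the notebook's RAM, so run it only once per session.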

In [4]:
from ogb.nodeproppred import PygNodePropPredDataset
dataset = PygNodePropPredDataset(name='ogbn-products', root='./products/')



In [17]:
split_idx = dataset.get_idx_split()
# keep only the first 1000 validation and test nodes to speed up evaluation
split_idx['test'] = split_idx['test'][:1000]
split_idx['valid'] = split_idx['valid'][:1000]

print(f"""
Summary of the OGBN Products dataset:
  Number of graphs: {len(dataset)}
  Number of features: {dataset.num_features}
  Number of classes: {dataset.num_classes}
  Length of each split:
    Training: {len(split_idx['train'])}
    Validation: {len(split_idx['valid'])}
    Test: {len(split_idx['test'])}
""")


Summary of the OGBN Products dataset:
  Number of graphs: 1
  Number of features: 100
  Number of classes: 47
  Length of each split:
    Training: 196615
    Validation: 1000
    Test: 1000



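Now, let's create some dataloaders!

With roughly 2.4 million nodes, `ogbn-products` is generally too large to push through a GNN in a single full-batch pass on a Colab GPU, so we train on mini-batches of sampled subgraphs instead. `NeighborLoader` picks `batch_size` seed nodes from `input_nodes` and then samples up to `fan_out` neighbors per node for each of the `num_layers` hops, which bounds the size of every batch regardless of node degree. The seed nodes come first in the returned batch, which is why the training loop later in this notebook only scores the first `batch_size` rows.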

In [19]:
from torch_geometric.loader import NeighborLoader

data = dataset[0]
train_loader = NeighborLoader(
    data,
    input_nodes=split_idx['train'],
    num_neighbors=[fan_out] * num_layers,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
    num_workers=dataloader_num_workers
)
# Evaluation loaders do not need shuffling: the metrics are averaged over all seed nodes.
val_loader = NeighborLoader(
    data,
    input_nodes=split_idx['valid'],
    num_neighbors=[fan_out] * num_layers,
    batch_size=batch_size,
    shuffle=False,
    num_workers=dataloader_num_workers,
)
test_loader = NeighborLoader(
    data,
    input_nodes=split_idx['test'],
    num_neighbors=[fan_out] * num_layers,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0
)



In [21]:
# Inspect one sampled training batch
print("Example batch:")
train_batch = next(iter(train_loader))
print(train_batch)

Example batch:
Data(num_nodes=3143, edge_index=[2, 3400], x=[3143, 100], y=[3143, 1], n_id=[3143], e_id=[3400], input_id=[32], batch_size=32)
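
To make the batch above less opaque, here is a dependency-free sketch of the idea behind neighbor sampling. It is a simplified illustration, not PyG's implementation: the function `sample_subgraph`, the toy adjacency list, and the fixed random seed are all hypothetical. It shows two properties the rest of the notebook relies on: the seed nodes occupy the first positions of the sampled node list, and no node receives more than `fan_out` sampled neighbors.

```python
import random


def sample_subgraph(adj, seeds, fan_out, num_hops, rng):
    """Sample up to `fan_out` neighbors per node for `num_hops` hops.

    Returns the sampled node list (seeds first) and the sampled edges
    as (neighbor, target) pairs.
    """
    nodes = list(seeds)      # seed nodes occupy the first positions
    seen = set(seeds)
    edges = []
    frontier = list(seeds)
    for _ in range(num_hops):
        next_frontier = []
        for target in frontier:
            neighbors = adj.get(target, [])
            k = min(fan_out, len(neighbors))
            for nb in rng.sample(neighbors, k):
                edges.append((nb, target))  # messages flow neighbor -> target
                if nb not in seen:
                    seen.add(nb)
                    nodes.append(nb)
                    next_frontier.append(nb)
        frontier = next_frontier
    return nodes, edges


# Toy undirected graph as an adjacency list (hypothetical data).
toy_adj = {
    0: [1, 2, 3],
    1: [0, 4],
    2: [0, 5],
    3: [0],
    4: [1],
    5: [2],
}
toy_nodes, toy_edges = sample_subgraph(
    toy_adj, seeds=[0, 4], fan_out=2, num_hops=2, rng=random.Random(0)
)
print(toy_nodes[:2])  # the two seed nodes come first: [0, 4]
```

In the real batch printed above, `batch_size=32` plays the role of the two seeds here: the first 32 of the 3143 nodes are the ones whose labels are predicted, and the remaining nodes only supply context.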


### Training and Evaluation Utilities

In [23]:
from torch_geometric.nn.models import GraphSAGE

input_dim = dataset[0].x.shape[1]

def get_model():
    class GraphSAGENodeClassification(torch.nn.Module):
        def __init__(self, input_dim, hidden_dim, num_layers, num_classes):
            super(GraphSAGENodeClassification, self).__init__()
            self.graph_sage = GraphSAGE(in_channels = input_dim, hidden_channels = hidden_dim, num_layers=num_layers)
            self.cls_head = torch.nn.Sequential(
                torch.nn.Dropout(0.1),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden_dim, num_classes),
            )
            self.loss_fn = torch.nn.CrossEntropyLoss()

        def forward(self, x, edge_index):
            h = self.graph_sage(x, edge_index)
            return self.cls_head(h)

    model = GraphSAGENodeClassification(input_dim, hidden_dim, num_layers, dataset.num_classes)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    model.to(device)
    return model, optimizer

Next, we define a simple training loop and evaluation function:


In [30]:
from tqdm import tqdm

# training process
def train_one_epoch(model, dataloader, optimizer, transform=None):
    model.train()

    # define stats
    total_loss = 0
    total_correct = 0
    num_examples = 0

    for batch in tqdm(dataloader):

        # number of seed nodes; they occupy the first rows of the batch
        batch_size = batch.batch_size
        batch = batch.to(device)
        # transform batch if needed
        if transform is not None:
            batch = transform(batch)

        # forward pass
        optimizer.zero_grad()
        logits = model(batch.x, batch.edge_index)[:batch_size]

        # backward pass
        labels = batch.y[:batch_size].squeeze(-1)
        loss = model.loss_fn(logits, labels)
        loss.backward()
        optimizer.step()

        # log stats
        total_loss += loss.item() * batch_size
        total_correct += logits.argmax(dim=-1).eq(labels).sum().item()
        num_examples += batch_size

    loss = total_loss / num_examples
    acc = total_correct / num_examples
    return loss, acc

# test process
@torch.no_grad()
def test(model, dataloader, transform=None, apply_transform=True):
    model.eval()

    # define stats
    total_loss = 0
    total_correct = 0
    num_examples = 0

    for batch in tqdm(dataloader):

        # transform batch if needed
        batch_size = batch.batch_size
        batch = batch.to(device)
        if apply_transform and (transform is not None):
            batch = transform(batch)

        # forward pass
        logits = model(batch.x, batch.edge_index)[:batch_size]
        labels = batch.y[:batch_size].squeeze(-1)
        loss = model.loss_fn(logits, labels)

        # log stats
        total_loss += loss.item() * batch_size
        total_correct += logits.argmax(dim=-1).eq(labels).sum().item()
        num_examples += batch_size

    loss = total_loss / num_examples
    acc = total_correct / num_examples
    return loss, acc
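
Both functions above accumulate `loss.item() * batch_size` rather than averaging the per-batch losses directly. The short sketch below illustrates why, using made-up numbers: when the last batch is smaller than the others (here, 196615 training nodes at a batch size of 32 leave a final batch of 7), a plain mean of batch means over-weights that small batch. The helper `weighted_mean` is a hypothetical name used only for this illustration.

```python
def weighted_mean(batch_means, batch_sizes):
    """Per-example mean recovered from per-batch means and batch sizes."""
    total = sum(mean * size for mean, size in zip(batch_means, batch_sizes))
    return total / sum(batch_sizes)


# Made-up per-batch losses; the last batch is small and has a high loss.
batch_means = [0.5, 0.5, 2.0]
batch_sizes = [32, 32, 8]

per_example = weighted_mean(batch_means, batch_sizes)  # (16 + 16 + 16) / 72
naive = sum(batch_means) / len(batch_means)            # (0.5 + 0.5 + 2.0) / 3

print(f"per-example mean: {per_example:.4f}")
print(f"mean of batch means: {naive:.4f}")
```

The same reasoning applies to accuracy, which is why `total_correct` is divided by `num_examples` rather than by the number of batches.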

To train and evaluate our model, call the function below!

In [31]:
def train(model, optimizer, num_epochs, transform=None, apply_transform_at_test=True):
    best_val_acc = 0
    for epoch in range(1, num_epochs + 1):
        print(f'Epoch: {epoch:02d}')

        # training
        train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, transform)
        print(f'Train Loss: {train_loss:.4f}, Train Accuracy: {100.0 * train_acc:.2f}%')

        # validation
        val_loss, val_acc = test(model, val_loader, transform, apply_transform=apply_transform_at_test)
        print(f'Val Loss: {val_loss:.4f}, Val Accuracy: {100.0 * val_acc:.2f}%')

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            print('New best validation accuracy, saving model...')
            torch.save(model.state_dict(), 'best_model.pth')


    print(f'Best Validation Accuracy: {100.0 * best_val_acc:.2f}%')

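    # evaluate the best checkpoint on the test set, using the same transform policy as validation
    model.load_state_dict(torch.load('best_model.pth', weights_only=True))
    test_loss, test_final_acc = test(model, test_loader, transform, apply_transform=apply_transform_at_test)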
    print(f'Test Accuracy: {100.0 * test_final_acc:.2f}%')
    return {
        'test_acc': test_final_acc,
        'val_acc': best_val_acc,
        'model': model
    }

To illustrate some example usage, let's run one epoch of our train function.

In [33]:
model, optimizer = get_model()
results = train(model, optimizer, num_epochs=1, transform=None, apply_transform_at_test=False)
# print the scalar metrics, skipping the model object
for key, value in results.items():
    if key == 'model':
        continue
    print(f'{key}: {value}')

Epoch: 01


100%|██████████| 6145/6145 [02:31<00:00, 40.64it/s]


Train Loss: 0.9142, Train Accuracy: 76.87%


100%|██████████| 32/32 [00:01<00:00, 29.88it/s]


Val Loss: 0.6181, Val Accuracy: 84.10%
New best validation accuracy, saving model...
Best Validation Accuracy: 84.10%


100%|██████████| 32/32 [00:00<00:00, 64.61it/s]

Test Accuracy: 67.30%
test_acc: 0.673
val_acc: 0.841





## Graph Structure Augmentation

Modifying the **structure** of the graph is another powerful way to improve the performance of GNNs, because what a GNN can learn depends directly on the connectivity over which messages are passed.

To better motivate graph structure augmentations, let’s first revisit the core idea of GNNs: message passing. At each layer, nodes aggregate information from their neighbors, gradually building a representation that reflects their local neighborhood structure. In theory, deeper networks should be able to capture broader relationships in the graph, integrating information from distant nodes. In practice, two related problems get in the way:

- **Over-smoothing**:    Recall that a GNN with $k$ layers aggregates information from each node's $k$-hop neighborhood. Thus, as the network deepens, node representations increasingly mix, and after many layers, nodes tend to converge to very similar representations. This “blending” means that the network struggles to distinguish between nodes, especially in large, densely connected graphs. In extreme cases, the output becomes almost uniform across all nodes, rendering the GNN ineffective for tasks like classification or clustering.

- **Global relationships**:    Capturing long-range dependencies requires a receptive field that reaches distant nodes. Stacking more GNN layers widens the receptive field, but it also exacerbates the over-smoothing problem noted above, so there is a trade-off between depth and effective information propagation.

Graph structure augmentations tackle these challenges head-on by altering the graph’s connectivity, introducing extra nodes and/or edges to improve the flow of information across the graph, mitigate over-smoothing, and enable GNNs to better capture both local and global patterns.
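
The over-smoothing effect described above can be seen on a toy example. The sketch below is an illustration under simplifying assumptions, not a GNN: it applies plain mean aggregation (each node averages itself with its neighbors, with no learned weights or nonlinearity) to scalar features on a six-node path graph. The helper names and the toy graph are hypothetical. The gap between the largest and smallest feature value shrinks with every round, which is the "blending" that makes nodes hard to tell apart.

```python
def mean_aggregate(features, adj):
    """One round of mean aggregation over each node and its neighbors."""
    new_features = []
    for node, neighbors in enumerate(adj):
        group = [features[node]] + [features[nb] for nb in neighbors]
        new_features.append(sum(group) / len(group))
    return new_features


def spread(features):
    """Gap between the largest and smallest node feature."""
    return max(features) - min(features)


# Path graph 0-1-2-3-4-5 with two clearly separated groups of nodes.
path_adj = [[1], [0, 2], [1, 3], [2, 4], [3, 5], [4]]
features = [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]

spreads = [spread(features)]
for _ in range(20):
    features = mean_aggregate(features, path_adj)
    spreads.append(spread(features))

for layer in (0, 1, 5, 20):
    print(f"after {layer:2d} rounds: spread = {spreads[layer]:.3f}")
```

A trained GNN has weights and nonlinearities that can partly counteract this, so the sketch shows the tendency rather than what any particular model will do.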

### Half Hop

Half-Hop (introduced in [Azabou et al., 2023](https://arxiv.org/abs/2308.09198)) augments message passing by inserting intermediate "slow nodes" on the edges between connected nodes. Each slow node is initialized with an interpolation of the features of its two endpoints (controlled by `alpha`), and `p` controls how much of the graph is half-hopped. Slowing messages down in this way mitigates over-smoothing and improves performance, especially in scenarios where neighboring nodes have different labels. The PyG documentation can be found [here](https://pytorch-geometric.readthedocs.io/en/stable/generated/torch_geometric.transforms.HalfHop.html).
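
To make the rewiring concrete, here is a minimal, dependency-free sketch of a half-hop style transform on scalar node features. It is a simplification under stated assumptions, not PyG's `HalfHop`: the function `half_hop` is hypothetical, it samples each edge independently with probability `p` (the library may sample differently), and the three replacement edges (source to slow node, slow node to destination, destination to slow node) reflect our reading of the method rather than a guaranteed match with the implementation.

```python
import random


def half_hop(features, edges, alpha=0.5, p=1.0, rng=None):
    """Insert a slow node on each sampled directed edge (src, dst)."""
    rng = rng or random.Random()
    new_features = list(features)
    new_edges = []
    slow_nodes = []
    for src, dst in edges:
        # keep self-loops and edges that are not sampled unchanged
        if src == dst or rng.random() >= p:
            new_edges.append((src, dst))
            continue
        slow = len(new_features)  # slow nodes are appended after the originals
        new_features.append(alpha * features[src] + (1 - alpha) * features[dst])
        slow_nodes.append(slow)
        # the direct edge is replaced by a detour through the slow node
        new_edges += [(src, slow), (slow, dst), (dst, slow)]
    return new_features, new_edges, slow_nodes


toy_features = [0.0, 1.0, 4.0]
toy_edges = [(0, 1), (1, 2)]
hh_features, hh_edges, hh_slow = half_hop(toy_features, toy_edges, alpha=0.5, p=1.0)
print(hh_features)  # [0.0, 1.0, 4.0, 0.5, 2.5]
print(hh_slow)      # [3, 4]
```

Because the slow nodes are appended after the original nodes, the seed nodes of a sampled batch keep their positions, so slicing the model output with `[:batch_size]` still selects the right rows.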

In [None]:
import os

import pandas as pd
from torch_geometric.transforms import HalfHop

os.makedirs("out", exist_ok=True)  # make sure the results directory exists

# alpha=0.5 initializes each slow node with an equal blend of its source and
# destination features; p is swept below, and p=1.0 half-hops the whole graph.
rows = []
for hh_prob in [0.1, 0.5, 1.0]:
    hh_transform = HalfHop(alpha=0.5, p=hh_prob)
    model, optimizer = get_model()
    print(f"Training with edge probability of {hh_prob}\n")
    result = train(model,
                   optimizer,
                   num_epochs=num_epochs,
                   transform=hh_transform,
                   apply_transform_at_test=False)
    rows.append({
        # column name kept as "masking_prob" because the plotting cell reads it
        "masking_prob": hh_prob,
        "test_acc": result['test_acc'],
        "val_acc": result['val_acc']
    })

df = pd.DataFrame(rows)
df.to_csv("out/masking_prob_results.csv")
df

Training with masking probability of 0.1

Epoch: 01


100%|██████████| 6145/6145 [02:33<00:00, 39.93it/s]


Train Loss: 0.9359, Train Accuracy: 76.16%


100%|██████████| 32/32 [00:01<00:00, 30.24it/s]


Val Loss: 0.5973, Val Accuracy: 84.10%
New best validation accuracy, saving model...
Epoch: 02


100%|██████████| 6145/6145 [02:30<00:00, 40.71it/s]


Train Loss: 0.5623, Train Accuracy: 85.20%


100%|██████████| 32/32 [00:01<00:00, 30.68it/s]


Val Loss: 0.5476, Val Accuracy: 85.50%
New best validation accuracy, saving model...
Epoch: 03


100%|██████████| 6145/6145 [02:34<00:00, 39.80it/s]


Train Loss: 0.5003, Train Accuracy: 86.62%


100%|██████████| 32/32 [00:01<00:00, 30.85it/s]


Val Loss: 0.5039, Val Accuracy: 86.70%
New best validation accuracy, saving model...
Epoch: 04


100%|██████████| 6145/6145 [02:31<00:00, 40.48it/s]


Train Loss: 0.4665, Train Accuracy: 87.37%


100%|██████████| 32/32 [00:01<00:00, 30.68it/s]


Val Loss: 0.4782, Val Accuracy: 87.00%
New best validation accuracy, saving model...
Epoch: 05


100%|██████████| 6145/6145 [02:29<00:00, 40.97it/s]


Train Loss: 0.4389, Train Accuracy: 88.04%


100%|██████████| 32/32 [00:01<00:00, 28.12it/s]


Val Loss: 0.4469, Val Accuracy: 88.30%
New best validation accuracy, saving model...
Epoch: 06


100%|██████████| 6145/6145 [02:30<00:00, 40.74it/s]


Train Loss: 0.4206, Train Accuracy: 88.51%


100%|██████████| 32/32 [00:01<00:00, 27.90it/s]


Val Loss: 0.4343, Val Accuracy: 88.00%
Epoch: 07


100%|██████████| 6145/6145 [02:29<00:00, 41.08it/s]


Train Loss: 0.4054, Train Accuracy: 88.88%


100%|██████████| 32/32 [00:01<00:00, 30.64it/s]


Val Loss: 0.4445, Val Accuracy: 88.00%
Epoch: 08


100%|██████████| 6145/6145 [02:30<00:00, 40.94it/s]


Train Loss: 0.3921, Train Accuracy: 89.16%


100%|██████████| 32/32 [00:01<00:00, 30.77it/s]


Val Loss: 0.4255, Val Accuracy: 88.30%
Epoch: 09


100%|██████████| 6145/6145 [02:29<00:00, 41.00it/s]


Train Loss: 0.3827, Train Accuracy: 89.37%


100%|██████████| 32/32 [00:01<00:00, 20.70it/s]


Val Loss: 0.3868, Val Accuracy: 89.10%
New best validation accuracy, saving model...
Epoch: 10


100%|██████████| 6145/6145 [02:37<00:00, 38.99it/s]


Train Loss: 0.3710, Train Accuracy: 89.68%


100%|██████████| 32/32 [00:01<00:00, 22.87it/s]


Val Loss: 0.3900, Val Accuracy: 89.30%
New best validation accuracy, saving model...
Best Validation Accuracy: 89.30%


100%|██████████| 32/32 [00:00<00:00, 38.69it/s]


Test Accuracy: 73.20%
Training with masking probability of 0.5

Epoch: 01


100%|██████████| 6145/6145 [02:39<00:00, 38.59it/s]


Train Loss: 1.0564, Train Accuracy: 73.09%


100%|██████████| 32/32 [00:01<00:00, 30.26it/s]


Val Loss: 0.6344, Val Accuracy: 82.80%
New best validation accuracy, saving model...
Epoch: 02


100%|██████████| 6145/6145 [02:32<00:00, 40.35it/s]


Train Loss: 0.6630, Train Accuracy: 82.51%


100%|██████████| 32/32 [00:01<00:00, 30.54it/s]


Val Loss: 0.5484, Val Accuracy: 85.10%
New best validation accuracy, saving model...
Epoch: 03


100%|██████████| 6145/6145 [02:41<00:00, 38.07it/s]


Train Loss: 0.5886, Train Accuracy: 84.18%


100%|██████████| 32/32 [00:01<00:00, 30.22it/s]


Val Loss: 0.5146, Val Accuracy: 85.40%
New best validation accuracy, saving model...
Epoch: 04


100%|██████████| 6145/6145 [02:31<00:00, 40.60it/s]


Train Loss: 0.5466, Train Accuracy: 85.11%


100%|██████████| 32/32 [00:01<00:00, 19.48it/s]


Val Loss: 0.4855, Val Accuracy: 87.60%
New best validation accuracy, saving model...
Epoch: 05


100%|██████████| 6145/6145 [02:34<00:00, 39.79it/s]


Train Loss: 0.5147, Train Accuracy: 85.84%


100%|██████████| 32/32 [00:01<00:00, 29.63it/s]


Val Loss: 0.4653, Val Accuracy: 87.20%
Epoch: 06


100%|██████████| 6145/6145 [02:30<00:00, 40.75it/s]


Train Loss: 0.4910, Train Accuracy: 86.38%


100%|██████████| 32/32 [00:01<00:00, 30.50it/s]


Val Loss: 0.4345, Val Accuracy: 87.80%
New best validation accuracy, saving model...
Epoch: 07


100%|██████████| 6145/6145 [02:33<00:00, 39.96it/s]


Train Loss: 0.4717, Train Accuracy: 87.03%


100%|██████████| 32/32 [00:01<00:00, 28.99it/s]


Val Loss: 0.4365, Val Accuracy: 88.30%
New best validation accuracy, saving model...
Epoch: 08


100%|██████████| 6145/6145 [02:33<00:00, 39.99it/s]


Train Loss: 0.4559, Train Accuracy: 87.28%


100%|██████████| 32/32 [00:01<00:00, 30.28it/s]


Val Loss: 0.4226, Val Accuracy: 87.70%
Epoch: 09


100%|██████████| 6145/6145 [02:34<00:00, 39.74it/s]


Train Loss: 0.4452, Train Accuracy: 87.53%


100%|██████████| 32/32 [00:01<00:00, 31.16it/s]


Val Loss: 0.3967, Val Accuracy: 88.60%
New best validation accuracy, saving model...
Epoch: 10


100%|██████████| 6145/6145 [02:36<00:00, 39.36it/s]


Train Loss: 0.4311, Train Accuracy: 87.92%


100%|██████████| 32/32 [00:01<00:00, 29.71it/s]


Val Loss: 0.4192, Val Accuracy: 88.10%
Best Validation Accuracy: 88.60%


100%|██████████| 32/32 [00:00<00:00, 53.59it/s]


Test Accuracy: 70.90%
Training with masking probability of 1.0

Epoch: 01


  2%|▏         | 128/6145 [00:05<02:30, 39.94it/s]

In [None]:
import seaborn as sns

# `df` holds the Half-Hop sweep from the cell above, which already saved it to CSV
sns.lineplot(data=df, x="masking_prob", y="test_acc")

### Virtual Node

VirtualNode (introduced in [Gilmer et al., 2017](https://arxiv.org/abs/1704.01212)) appends a single virtual node to a homogeneous graph and connects it to every other node. The virtual node serves as a global scratch space that each node both reads from and writes to in every step of message passing. This lets information travel between any two nodes in at most two hops, however far apart they are in the original graph.
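
The shortcut effect is easy to check on a toy graph. The sketch below is a plain-Python illustration with hypothetical helper names (`add_virtual_node`, `hop_distance`), not PyG's `VirtualNode` transform: it adds one extra node connected in both directions to every existing node, then measures the hop distance between the two ends of a six-node path before and after.

```python
from collections import deque


def add_virtual_node(num_nodes, edges):
    """Append one node and connect it to every existing node in both directions."""
    virtual = num_nodes  # the virtual node takes the next free index
    new_edges = list(edges)
    for node in range(num_nodes):
        new_edges.append((node, virtual))
        new_edges.append((virtual, node))
    return num_nodes + 1, new_edges


def hop_distance(num_nodes, edges, start, goal):
    """Breadth-first search distance from start to goal, or None if unreachable."""
    adj = [[] for _ in range(num_nodes)]
    for src, dst in edges:
        adj[src].append(dst)
    dist = {start: 0}
    queue = deque([start])
    while queue:
        node = queue.popleft()
        if node == goal:
            return dist[node]
        for nb in adj[node]:
            if nb not in dist:
                dist[nb] = dist[node] + 1
                queue.append(nb)
    return None


# Path graph 0-1-2-3-4-5 stored as directed edges in both directions.
path_edges = [(i, i + 1) for i in range(5)] + [(i + 1, i) for i in range(5)]

before = hop_distance(6, path_edges, 0, 5)
n_aug, aug_edges = add_virtual_node(6, path_edges)
after = hop_distance(n_aug, aug_edges, 0, 5)
print(before, after)  # 5 2
```

With `num_layers = 2`, as configured in this notebook, a message from node 0 could never reach node 5 on the original path, but it can after the virtual node is added.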

In [None]:
import os

import pandas as pd
from torch_geometric.transforms import VirtualNode

os.makedirs("out", exist_ok=True)  # make sure the results directory exists

rows = []

vn_transform = VirtualNode()
model, optimizer = get_model()
print("Training with virtual node\n")
result = train(model,
               optimizer,
               num_epochs=num_epochs,
               transform=vn_transform,
               apply_transform_at_test=False)
rows.append({
    "transform": "virtual_node",
    "test_acc": result['test_acc'],
    "val_acc": result['val_acc']
})

df = pd.DataFrame(rows)
df.to_csv("out/virtual_node_results.csv")
df


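### Add Random Edge

Random edge addition perturbs each sampled subgraph by adding new edges between randomly chosen node pairs. Here we use `add_random_edge` from `torch_geometric.utils`, where `p` is the ratio of added edges to existing edges.

In [None]:
import os
from functools import partial

import pandas as pd
from torch_geometric.utils import add_random_edge

os.makedirs("out", exist_ok=True)  # make sure the results directory exists

def add_random_edges_batch(batch, p):
    # add_random_edge returns the augmented edge_index and the newly added edges
    batch.edge_index, _ = add_random_edge(batch.edge_index, p=p, num_nodes=batch.num_nodes)
    return batch

rows = []
for prob in [0.1, 0.2, 0.3, 0.4, 0.5]:
    model, optimizer = get_model()
    print(f"Training with edge addition ratio of {prob}\n")
    transform = partial(add_random_edges_batch, p=prob)
    result = train(model, optimizer, num_epochs=num_epochs, transform=transform, apply_transform_at_test=False)
    rows.append({
        "edge_prob": prob,
        "test_acc": result['test_acc'],
        "val_acc": result['val_acc']
    })

df = pd.DataFrame(rows)
df.to_csv("out/edge_prob_results.csv")
df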

In [None]:
sns.lineplot(data=df, x="edge_prob", y="test_acc")
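
For intuition about what the `add_random_edge` call in the sweep above does to a batch, here is a dependency-free sketch of random edge addition. It is an illustration under assumptions, not PyG's implementation: the function `add_random_edges` is hypothetical, it treats `p` as the ratio of added edges to existing edges, and it rejects self-loops and edges that already exist.

```python
import random


def add_random_edges(edges, num_nodes, p, rng):
    """Add round(p * len(edges)) new directed edges between random node pairs."""
    existing = set(edges)
    num_to_add = round(len(edges) * p)
    # never ask for more new edges than the graph has room for
    used = len({e for e in existing if e[0] != e[1]})
    num_to_add = min(num_to_add, num_nodes * (num_nodes - 1) - used)
    added = []
    while len(added) < num_to_add:
        src, dst = rng.randrange(num_nodes), rng.randrange(num_nodes)
        if src != dst and (src, dst) not in existing:
            existing.add((src, dst))
            added.append((src, dst))
    return list(edges) + added, added


# Ring of 10 nodes stored as directed edges in both directions (20 edges).
ring = [(i, (i + 1) % 10) for i in range(10)]
ring += [(dst, src) for src, dst in ring]

new_edges, added = add_random_edges(ring, num_nodes=10, p=0.5, rng=random.Random(0))
print(len(ring), len(added), len(new_edges))  # 20 10 30
```

Larger values of `p` inject more shortcuts but also more noise, which is the trade-off the sweep over `prob` above is meant to probe.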