<a href="https://colab.research.google.com/github/adelsuh/cs224_final_project/blob/main/graph_structure_aug.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# CS224W Final Project: Tutorial on the Augmentation of Graphs in PyG using Structural Augmentations

### Jerry Chan, Jihee Suh, John So

Data augmentation is a widely used technique that leverages existing data to further train a model, improving its performance and generalization. For structured data formats such as images, augmentation methods can be quite straightforward, including operations like cropping, resizing, rotating, and adding noise. These augmentations are useful for reducing overfitting to the training dataset and adding invariance to certain transformations, such as color shifts, different camera models, and even different camera poses.

PyG provides support for dataset augmentations, which primarily inherit from the `torch_geometric.transforms` class. In this Colab, we will cover how augmentations can modify the graph structure to improve performance in inductive graph learning settings.

### Notebook setup: install PyG + torch

In [None]:
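import torch
torch_version = str(torch.__version__)
if "2.4.0" not in torch_version:
  # Pin torch 2.4.0. If this installs a different version, restart the runtime and
  # re-run this cell so that `torch_version` matches the installed release.
  !pip install torch==2.4.0 -q
print(torch_version)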

In [2]:
# Prebuilt torch-scatter / torch-sparse wheels are indexed by torch (and CUDA) version
wheel_src = f"https://data.pyg.org/whl/torch-{torch_version}.html"
!pip install torch-scatter -f $wheel_src -q
!pip install torch-sparse -f $wheel_src -q
!pip install torch-geometric -q
!pip install ogb -q

In [None]:
import os
import random
import numpy as np
import torch
import seaborn as sns
import pandas as pd
from tqdm import tqdm
from functools import partial


def seed_everything(seed):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(42)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"using device: {device}")

### Setting up the dataset and tasks

The code below sets the hyperparameters used for the model, training, and data loading.


In [None]:
# Model settings
hidden_dim = 128 #@param {type: "integer"}
num_layers = 4 #@param {type: "integer"}
# Training settings
learning_rate = 0.0001 #@param {type: "number"}
num_epochs = 25 #@param {type: "integer"}

# Dataloader settings
batch_size = 32 #@param {type: "integer"}
fan_out = 30 #@param {type: "integer", hint: "Used in neighborhood sampling to sample a subgraph"}
dataloader_num_workers = 2 #@param {type: "integer"}

print(f"""
Running training with the following configuration:
   hidden_dim: {hidden_dim}
   num_layers: {num_layers}
   learning_rate: {learning_rate}
   num_epochs: {num_epochs}
   batch_size: {batch_size}
""")

### Task: Node property prediction with GraphSAGE

For this tutorial, we will consider the [ogbn-arxiv](https://ogb.stanford.edu/docs/nodeprop/) dataset from the Open Graph Benchmark (OGB). This dataset consists of ~170K nodes and ~1.2M directed edges. Each node represents an arXiv CS paper, and each directed edge indicates that one paper cites another. The task is to predict each paper's subject area, one of 40 classes (e.g., cs.AI, cs.OS), from a 128-dimensional feature vector obtained by averaging the word embeddings of the paper's title and abstract.

Run the block below to create the dataset. The first time it runs, it downloads the dataset files (OGB may ask you to confirm the download).

**Note**: this block loads the full dataset into RAM every time it runs, so running it repeatedly can exhaust the notebook's memory.

In [None]:
from ogb.nodeproppred import PygNodePropPredDataset
dataset = PygNodePropPredDataset(name='ogbn-arxiv', root='./arxiv/')

We'll approach this as an *inductive* node prediction problem: we want a single network, trained on many sampled subgraphs, to generalize to nodes whose labels it never saw during training (and, in principle, to entirely new graphs). To support this, `ogbn-arxiv` provides time-based dataset splits:

- **train split**: CS papers published up until 2017
- **validation split**: CS papers published during 2018
- **test split**: CS papers published during or after 2019.

We train on subgraphs sampled around the training nodes, use validation accuracy to select the best epoch (and to compare augmentation settings), and report test accuracy at that epoch.

In [None]:
split_idx = dataset.get_idx_split()
# we use the full validation and test splits (no subsampling)

print(f"""
Summary of the ogbn-arxiv dataset:
  Number of graphs: {len(dataset)}
  Number of features: {dataset.num_features}
  Number of classes: {dataset.num_classes}
  Length of each split:
    Training: {len(split_idx['train'])}
    Validation: {len(split_idx['valid'])}
    Test: {len(split_idx['test'])}
""")

Now, let's create some dataloaders. Full-batch training on graphs with 100K+ nodes quickly becomes expensive in memory and compute, especially as GNNs get deeper. Neighbor sampling addresses this by training on mini-batches of small subgraphs sampled around a set of seed nodes, and PyG supports it natively through the `NeighborLoader` class.

The `num_neighbors` argument controls how neighbors are sampled. It is a list of $k$ integers, so sampling runs for $k$ hops: the $i$-th entry is the maximum number of neighbors sampled for each node reached in the previous hop (the first entry applies to the seed nodes themselves). We set $k$ equal to the depth of our GNN, so each sampled subgraph spans the $k$-hop neighborhood that a $k$-layer GNN aggregates over. Capping the fan-out bounds the size of every subgraph while preserving the local structure the model needs. Within each batch, the seed nodes come first and their count is stored in `batch.batch_size`; the remaining nodes are sampled neighbors that provide context but are not supervised.
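To get a feel for how `fan_out` and the number of layers interact, here is a back-of-the-envelope helper that computes the worst-case number of nodes in one mini-batch. `max_sampled_nodes` is a hypothetical name, not part of PyG, and the bound assumes every node has at least `fan_out` neighbors and that no node is sampled twice.

```python
def max_sampled_nodes(batch_size, num_neighbors):
    """Upper bound on the nodes in one NeighborLoader mini-batch.

    Assumes every node has at least num_neighbors[i] neighbors and that no
    node is sampled twice, so real batches are usually much smaller.
    """
    total = batch_size      # the seed nodes themselves
    frontier = batch_size   # nodes reached in the previous hop
    for fan_out in num_neighbors:
        frontier *= fan_out  # each frontier node adds at most fan_out neighbors
        total += frontier
    return total


print(max_sampled_nodes(batch_size=32, num_neighbors=[30] * 4))  # 26813792
```

With `batch_size = 32`, `fan_out = 30`, and 4 layers, the bound (about 26.8M) far exceeds the ~170K nodes in ogbn-arxiv, so in practice each subgraph is capped by the graph itself, and nodes with fewer than 30 neighbors contribute fewer samples. Lowering `fan_out` or `num_layers` is the most direct way to shrink each batch.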

In [None]:
from torch_geometric.loader import NeighborLoader

data = dataset[0]
train_loader = NeighborLoader(
    data,
    input_nodes=split_idx['train'],
    num_neighbors=[fan_out] * num_layers,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
    num_workers=dataloader_num_workers
)
val_loader = NeighborLoader(
    data,
    input_nodes=split_idx['valid'],
    num_neighbors=[fan_out] * num_layers,
    batch_size=batch_size,
    shuffle=True,
    num_workers=dataloader_num_workers,
)
test_loader = NeighborLoader(
    data,
    input_nodes=split_idx['test'],
    num_neighbors=[fan_out] * num_layers,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0
)

In [None]:
print(f"Example batch:")
train_batch = next(iter(train_loader))
print(train_batch)

### Training and Evaluation Utilities

Given this dataset, let's now choose a GNN model. For this tutorial, we use a simple yet powerful architecture: [GraphSAGE](https://arxiv.org/abs/1706.02216) (Hamilton et al., 2017), which learns node embeddings by combining each node's own features with features aggregated from its neighborhood. These embeddings can serve a variety of prediction problems, including our node classification task.

PyG includes a native implementation of [GraphSAGE](https://pytorch-geometric.readthedocs.io/en/2.5.3/generated/torch_geometric.nn.models.GraphSAGE.html), which we use to build our predictor. The GraphSAGE model computes deep node embeddings, and a small classification head (dropout, ReLU, and a linear layer) maps each embedding to per-class logits.
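For intuition, here is a minimal plain-PyTorch sketch of a single GraphSAGE layer with mean aggregation, following the update rule given in the PyG `SAGEConv` documentation, $\mathbf{x}_i' = \mathbf{W}_1 \mathbf{x}_i + \mathbf{W}_2 \cdot \operatorname{mean}_{j \in \mathcal{N}(i)} \mathbf{x}_j$. `MeanSAGELayer` is a hypothetical name and details such as where the bias lives are assumptions; the predictor below uses PyG's `GraphSAGE`, which stacks such layers with ReLU activations between them.

```python
import torch


class MeanSAGELayer(torch.nn.Module):
    """Sketch of one GraphSAGE layer with mean aggregation (illustrative only).

    Computes h_i' = W_self h_i + W_neigh * mean_{j -> i} h_j, where an edge
    (j, i) in edge_index sends a message from source j to target i.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.lin_self = torch.nn.Linear(in_dim, out_dim, bias=False)
        self.lin_neigh = torch.nn.Linear(in_dim, out_dim)

    def forward(self, x, edge_index):
        src, dst = edge_index
        # sum the features of all incoming neighbors for each target node
        agg = torch.zeros_like(x).index_add_(0, dst, x[src])
        # count incoming edges; clamp so nodes without neighbors divide by 1
        deg = torch.zeros(x.size(0), dtype=x.dtype, device=x.device)
        deg.index_add_(0, dst, torch.ones_like(dst, dtype=x.dtype))
        agg = agg / deg.clamp(min=1).unsqueeze(-1)
        return self.lin_self(x) + self.lin_neigh(agg)


x = torch.randn(5, 8)  # 5 nodes, 8 features each
edge_index = torch.tensor([[0, 1, 2, 3], [4, 4, 4, 0]])
layer = MeanSAGELayer(in_dim=8, out_dim=16)
print(layer(x, edge_index).shape)  # torch.Size([5, 16])
```

Node 4 averages the features of nodes 0, 1, and 2, while nodes with no incoming edges (1, 2, 3) rely only on their own features.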

In [11]:
from torch_geometric.nn.models import GraphSAGE

input_dim = dataset.num_features

def get_model():
    class GraphSAGENodeClassification(torch.nn.Module):
        def __init__(self, input_dim, hidden_dim, num_layers, num_classes):
            super(GraphSAGENodeClassification, self).__init__()
            self.graph_sage = GraphSAGE(in_channels = input_dim, hidden_channels = hidden_dim, num_layers=num_layers)
            self.cls_head = torch.nn.Sequential(
                torch.nn.Dropout(0.1),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden_dim, num_classes),
            )
            self.loss_fn = torch.nn.CrossEntropyLoss()

        def forward(self, x, edge_index):
            h = self.graph_sage(x, edge_index)
            return self.cls_head(h)

    model = GraphSAGENodeClassification(input_dim, hidden_dim, num_layers, dataset.num_classes)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    model.to(device)
    return model, optimizer

Next, we define two helper functions, `train_one_epoch` and `test`. Both take two optional arguments:
- `transform`: a function applied to each mini-batch before the forward pass, such as one that removes nodes or adds features.
- `filter_output_fn`: a function `(logits, batch) -> logits` applied to the model output. It is intended to handle any extra information introduced by the transformation.

Both helpers compute the loss and accuracy only on the seed nodes of each batch, i.e., the first `batch.batch_size` nodes.
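As a toy illustration of the `transform` / `filter_output_fn` contract, the hypothetical pair below shuffles the node order of a batch and then restores the original order of the logits, so that the first `batch.batch_size` rows again correspond to the seed nodes. Only `x` and `edge_index` are permuted; `y` is left alone because the training loop reads labels in the original order. A `SimpleNamespace` stands in for a PyG batch. Since GraphSAGE is permutation-equivariant, this pair should leave predictions unchanged; it only demonstrates how the two hooks fit together.

```python
import torch
from types import SimpleNamespace


def permute_nodes(batch):
    """Hypothetical transform: shuffle node order and remember the permutation."""
    num_nodes = batch.x.size(0)
    # new position k holds old node perm[k]
    perm = torch.randperm(num_nodes, device=batch.x.device)
    # inverse permutation: old node i now sits at position inv[i]
    inv = torch.empty_like(perm)
    inv[perm] = torch.arange(num_nodes, device=perm.device)
    batch.x = batch.x[perm]
    batch.edge_index = inv[batch.edge_index]  # relabel both endpoints of every edge
    batch.perm = perm
    return batch


def unpermute_logits(logits, batch):
    """Hypothetical filter_output_fn: put logits back in the original node order."""
    restored = torch.empty_like(logits)
    restored[batch.perm] = logits
    return restored


# toy batch: 6 nodes with 2 features each and a small path graph
x = torch.arange(12, dtype=torch.float).view(6, 2)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])
batch = permute_nodes(SimpleNamespace(x=x.clone(), edge_index=edge_index.clone()))
identity_logits = batch.x  # stand-in for model(batch.x, batch.edge_index)
print(torch.equal(unpermute_logits(identity_logits, batch), x))  # True
```

With real batches, the pair would be passed as `train(model, optimizer, num_epochs, transform=permute_nodes, filter_output_fn=unpermute_logits)`.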


In [12]:
# training process
def train_one_epoch(model,
                    dataloader,
                    optimizer,
                    transform=None,
                    filter_output_fn=None):
    """
    Run one epoch of training on the model on the given dataset.

    Args:
        model (torch.nn.Module): The model to train.
        dataloader (NeighborLoader): The dataloader for the training split.
        optimizer (torch.optim.Optimizer): The optimizer used to update the model.
        transform: if specified, a function applied to each batch before the forward pass.
        filter_output_fn: if specified, a function (logits, batch) -> logits applied to the model output.
    """
    model.train()

    # define stats
    total_loss = 0
    total_correct = 0
    num_examples = 0

    for batch in tqdm(dataloader):
        # the first `num_seeds` nodes of a NeighborLoader batch are the seed
        # (supervised) nodes; the rest are sampled neighbors
        num_seeds = batch.batch_size
        batch = batch.to(device)

        # transform batch if needed
        if transform is not None:
            batch = transform(batch)

        # forward pass
        optimizer.zero_grad()
        logits = model(batch.x, batch.edge_index)
        if filter_output_fn is not None:
            logits = filter_output_fn(logits, batch)

        # select supervision nodes
        labels = batch.y.squeeze(-1)[:num_seeds]
        logits = logits[:num_seeds]

        # backward pass
        loss = model.loss_fn(logits, labels)
        loss.backward()
        optimizer.step()

        # log stats (the loss is a mean over seed nodes, so weight it by num_seeds)
        total_loss += loss.item() * num_seeds
        total_correct += logits.argmax(dim=-1).eq(labels).sum().item()
        num_examples += num_seeds

    loss = total_loss / num_examples
    acc = total_correct / num_examples
    return loss, acc

# test process
@torch.no_grad()
def test(model,
         dataloader,
         transform=None,
         filter_output_fn=None,
         apply_transform=True):
    """
    Calculate metrics for the model on the given dataset.

    Args:
        model (torch.nn.Module): The model to evaluate.
        dataloader (NeighborLoader): The dataloader for the evaluation split.
        transform: if specified and apply_transform is True, a function applied to each batch.
        filter_output_fn: if specified and apply_transform is True, a function (logits, batch) -> logits applied to the model output.
        apply_transform: whether to use the arguments transform and filter_output_fn.
    """
    model.eval()

    # define stats
    total_loss = 0
    total_correct = 0
    num_examples = 0

    for batch in tqdm(dataloader):
        num_seeds = batch.batch_size
        batch = batch.to(device)

        # transform batch if needed
        if apply_transform and (transform is not None):
            batch = transform(batch)

        # forward pass
        logits = model(batch.x, batch.edge_index)
        if apply_transform and (filter_output_fn is not None):
            logits = filter_output_fn(logits, batch)

        # select supervision nodes and compute loss
        labels = batch.y.squeeze(-1)[:num_seeds]
        logits = logits[:num_seeds]
        loss = model.loss_fn(logits, labels)

        # log stats (the loss is a mean over seed nodes, so weight it by num_seeds)
        total_loss += loss.item() * num_seeds
        total_correct += logits.argmax(dim=-1).eq(labels).sum().item()
        num_examples += num_seeds

    loss = total_loss / num_examples
    acc = total_correct / num_examples
    return loss, acc

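The `train` function below ties everything together: each epoch, it trains on the training split, evaluates on the validation and test splits, and tracks the epoch with the highest validation accuracy (`best_val_ind`).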

In [13]:
def train(model,
          optimizer,
          num_epochs,
          transform=None,
          filter_output_fn=None,
          apply_transform_at_test=True):
    all_train_acc, all_val_acc, all_test_acc = [], [], []
    best_val_ind, best_val_acc = 0, 0
    for epoch in range(num_epochs):
        print(f'Epoch: {epoch+1:02d}')

        # training
        train_loss, train_acc = train_one_epoch(model,
                                                train_loader,
                                                optimizer,
                                                transform,
                                                filter_output_fn)
        val_loss, val_acc = test(model,
                                 val_loader,
                                 transform,
                                 filter_output_fn=filter_output_fn,
                                 apply_transform=apply_transform_at_test)
        test_loss, test_acc = test(model,
                                   test_loader,
                                   transform,
                                   filter_output_fn=filter_output_fn,
                                   apply_transform=apply_transform_at_test)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_val_ind = epoch

        print(f'Train {train_loss:.4f} ({100.0 * train_acc:.2f}%) | Val {val_loss:.4f} ({100.0 * val_acc:.2f}%) | Test {test_loss:.4f} ({100.0 * test_acc:.2f}%)')

        all_train_acc.append(train_acc)
        all_val_acc.append(val_acc)
        all_test_acc.append(test_acc)

    return {
        'all_train_acc': np.array(all_train_acc),
        'all_val_acc': np.array(all_val_acc),
        'all_test_acc': np.array(all_test_acc),
        'best_val_ind': best_val_ind,
        'model': model
    }

As a first example, let's run a baseline: a GraphSAGE network trained without any augmentation.

In [None]:
model, optimizer = get_model()
results = train(model, optimizer, num_epochs=num_epochs)

best_bl_train_acc = results['all_train_acc'][results['best_val_ind']]
best_bl_val_acc = results['all_val_acc'][results['best_val_ind']]
best_bl_test_acc = results['all_test_acc'][results['best_val_ind']]

## Training Augmentation

Beyond one-off modifications of the graph itself, it is useful to augment batches dynamically during training, analogous to random cropping, blurring, and color shifting for images. To do this, we use functions from `torch_geometric.utils`. In this section, we introduce two training augmentations: `dropout_edge` and `mask_feature`. Each one randomly perturbs every sampled mini-batch, which discourages the model from over-relying on specific edges or features and can reduce overfitting and improve generalization. The augmentations are applied only during training: we pass `apply_transform_at_test=False` to `train`, so validation and test batches use the unmodified graph.


### Dropout Edge
The `dropout_edge` function randomly removes edges from the graph. It operates on the `edge_index` matrix and returns:
* A modified `edge_index` with some edges dropped.
* A boolean tensor `edge_mask` over the original edges, indicating which edges were retained (`True`) or dropped (`False`).

Key arguments:

* `p`: The probability of dropping each edge (default: 0.5).
* `force_undirected`: If `True`, both directions of an edge are dropped or kept together, so an undirected `edge_index` stays undirected.
* `training`: If `False`, no edges are dropped.
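Conceptually, edge dropout (without `force_undirected`) makes one independent keep/drop decision per column of `edge_index`. The sketch below reproduces that behavior in plain PyTorch for illustration only; `dropout_edge_sketch` is a hypothetical name, and it omits `force_undirected` and `training`, so the experiments below use PyG's `dropout_edge`.

```python
import torch


def dropout_edge_sketch(edge_index, p=0.5):
    """Illustrative edge dropout: keep each edge independently with probability 1 - p.

    Returns the filtered edge_index and a boolean mask over the original edges
    (True = kept), mirroring the outputs described for dropout_edge.
    """
    if not 0.0 <= p <= 1.0:
        raise ValueError(f"p must be in [0, 1], got {p}")
    # one uniform draw per edge (column of edge_index); keep the edge if the draw >= p
    edge_mask = torch.rand(edge_index.size(1), device=edge_index.device) >= p
    return edge_index[:, edge_mask], edge_mask


torch.manual_seed(0)
edge_index = torch.tensor([[0, 1, 1, 2, 3], [1, 0, 2, 1, 0]])
kept, mask = dropout_edge_sketch(edge_index, p=0.4)
print(mask)  # which of the 5 edges survived this draw
print(kept)  # the surviving columns of edge_index
```

Because a fresh mask is drawn for every call, each mini-batch sees a different subset of its sampled edges.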

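The following code trains one model for each edge dropout probability $p \in \{0.1, 0.2, 0.3, 0.4, 0.5\}$ and records the validation and test accuracy at the epoch with the best validation accuracy.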

In [None]:
from torch_geometric.utils import dropout_edge

def dropout_edge_batch(batch, p):
    # drop each edge of the sampled subgraph with probability p; the kept-edge mask is unused
    batch.edge_index, _ = dropout_edge(batch.edge_index, p=p)
    return batch

rows = []
for prob in [0.1, 0.2, 0.3, 0.4, 0.5]:
    model, optimizer = get_model()
    print(f"Training with edge dropout probability of {prob}\n")
    transform = partial(dropout_edge_batch, p=prob)
    result = train(model, optimizer, transform=transform, apply_transform_at_test=False, num_epochs=num_epochs)
    rows.append({
        "edge_prob":prob,
        "test_acc": result['all_test_acc'][result['best_val_ind']],
        "val_acc": result['all_val_acc'][result['best_val_ind']]
    })

df = pd.DataFrame(rows)
os.makedirs("out", exist_ok=True)  # make sure the output directory exists
df.to_csv("out/dropout_edge_result.csv")
df

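Now let's plot test accuracy as a function of the edge dropout probability, with $p = 0$ standing for the no-augmentation baseline.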

In [None]:
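df = pd.read_csv('out/dropout_edge_result.csv', index_col=0)
# add the no-augmentation baseline (p = 0) from the baseline run: [edge_prob, test_acc, val_acc]
df.loc[len(df)] = [0, best_bl_test_acc, best_bl_val_acc]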

sns.set_style("ticks")

ax = sns.lineplot(x='edge_prob', y='test_acc', data=df)

ax.set(xlabel='Edge Dropout Probability', ylabel='Test Accuracy')
ax.set_yticks([0.54, 0.55, 0.56], ["54%", "55%", "56%"])
ax.set_ylim(0.538, 0.562)
sns.despine()

### Mask Feature

The `mask_feature` function randomly masks node features. It takes the node feature matrix `x` as input and returns the masked features together with a boolean mask marking which entries were kept (`True`) or masked (`False`).

Key arguments:
* `p`: The probability of masking a feature.
* `fill_value`: The value used to replace masked features (default: 0).
* `mode`: The masking scheme.

There are three masking modes:
* `col` (default): Masks entire feature columns across all nodes.
* `row`: Masks all features of selected nodes.
* `all`: Masks individual features independently.
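To make the three modes concrete, the sketch below mimics them in plain PyTorch: the modes differ only in the shape of the random keep-mask, which is then broadcast over `x`. `mask_feature_sketch` is a hypothetical helper for illustration; the experiments below use PyG's `mask_feature`.

```python
import torch


def mask_feature_sketch(x, p=0.5, mode="col", fill_value=0.0):
    """Illustrative feature masking; True in the returned mask means 'kept'."""
    if mode == "col":    # one draw per feature column, shared by all nodes
        keep = torch.rand(1, x.size(1), device=x.device) >= p
    elif mode == "row":  # one draw per node, shared by all of its features
        keep = torch.rand(x.size(0), 1, device=x.device) >= p
    elif mode == "all":  # an independent draw per (node, feature) entry
        keep = torch.rand_like(x) >= p
    else:
        raise ValueError(f"unknown mode: {mode}")
    # the mask broadcasts over x; masked entries are replaced by fill_value
    return x.masked_fill(~keep, fill_value), keep


torch.manual_seed(0)
x = torch.ones(4, 3)
for mode in ["col", "row", "all"]:
    masked_x, keep = mask_feature_sketch(x, p=0.5, mode=mode)
    print(mode, tuple(keep.shape))
    print(masked_x)
```

Under `col` mode a masked column is zeroed for every node in the batch, so the model temporarily loses that input dimension entirely, whereas `all` mode spreads the masking across individual entries.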


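The following code trains one model per masking probability $p \in \{0.05, 0.1, 0.2, 0.3, 0.4, 0.5\}$. It uses the `all` mode, which is the default of our `mask_feature_batch` helper (PyG's `mask_feature` itself defaults to `col`).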

In [None]:
from torch_geometric.utils import mask_feature
from functools import partial
import pandas as pd

def mask_feature_batch(batch, p, mode="all"):
    masked_x, feature_mask = mask_feature(batch.x, p = p, mode = mode)
    batch.x = masked_x
    return batch

rows = []
for masking_prob in [0.05, 0.1, 0.2, 0.3, 0.4, 0.5]:
    model, optimizer = get_model()
    print(f"Training with masking probability of {masking_prob}\n")
    transform = partial(mask_feature_batch, p=masking_prob)
    results = train(model, optimizer, transform=transform, apply_transform_at_test=False, num_epochs=num_epochs)
    rows.append({
        "masking_prob": masking_prob,
        "test_acc": results['all_test_acc'][results['best_val_ind']],
        "val_acc": results['all_val_acc'][results['best_val_ind']]
    })

df = pd.DataFrame(rows)
os.makedirs("out", exist_ok=True)  # make sure the output directory exists
df.to_csv("out/masking_prob_results.csv")
df

The following code compares the three masking modes, `mode` $\in$ ["all", "col", "row"], at a fixed masking probability of $p = 0.05$.

In [None]:
rows = []
for masking_mode in ["all", "col", "row"]:
    model, optimizer = get_model()
    print(f"Training with masking mode {masking_mode}\n")
    transform = partial(mask_feature_batch, p=0.05, mode=masking_mode)
    results = train(model, optimizer, transform=transform, apply_transform_at_test=False, num_epochs=num_epochs)
    rows.append({
        "masking_mode": masking_mode,
        "test_acc": results['all_test_acc'][results['best_val_ind']],
        "val_acc": results['all_val_acc'][results['best_val_ind']]
    })

df = pd.DataFrame(rows)
os.makedirs("out", exist_ok=True)  # make sure the output directory exists
df.to_csv("out/masking_mode_results.csv")
df