# GraphSAGE

In [1]:
import sys

# Necessary to import from sibling directory
sys.path.append("..")

from typing import TYPE_CHECKING

from pymdb import (
    MDBClient,
    TrainGraphLoader,
    EvalGraphLoader,
    SamplingGraphLoader,
    TensorStore,
    Sampler,
)

import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

if TYPE_CHECKING:
    from pymdb import GraphLoader


## Model

In [2]:
class GraphSAGE(torch.nn.Module):
    def __init__(
        self,
        dim_in: int,
        dim_h: int,
        dim_out: int,
        num_layers: int,
    ):
        if num_layers < 2:
            raise ValueError("Number of layers must be greater than 1")
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.convs.append(SAGEConv(dim_in, dim_h))
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(dim_h, dim_h))
        self.convs.append(SAGEConv(dim_h, dim_out))

        self.dim_in = dim_in
        self.dim_h = dim_h
        self.dim_out = dim_out
        self.num_layers = num_layers

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()

    def forward(
        self,
        node_features: torch.Tensor,  # [num_nodes, feature_size]
        edge_index: torch.Tensor,  # [2, num_edges]
    ) -> torch.Tensor:
        h = node_features
        for layer in self.convs[:-1]:
            h = layer(h, edge_index)
            h = F.relu(h)
            h = F.dropout(h, p=0.5, training=self.training)
        return self.convs[-1](h, edge_index)  # Embedding


## Trainers

### fit_sup

In [3]:
def fit_sup(
    model: "GraphSAGE",
    epochs: int,
    graph_loader: "GraphLoader",
):
    model.train()
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    for epoch in range(epochs + 1):
        total_seeds = 0
        total_loss = 0
        total_correct = 0
        for batch in graph_loader:
            optimizer.zero_grad()

            out = model(batch.node_features, batch.edge_index)[: batch.num_seeds]
            y_pred = F.log_softmax(out, dim=-1)
            y_true = batch.node_labels[: batch.num_seeds]

            loss = criterion(y_pred, y_true)
            loss.backward()
            optimizer.step()

            total_seeds += batch.num_seeds
            total_loss += loss
            total_correct += y_pred.argmax(dim=-1).eq(y_true).sum().item()

        if epoch % 50 == 0 or epoch == epochs:
            epoch_loss = total_loss / len(graph_loader)
            epoch_acc = 100 * total_correct / total_seeds
            print(f"Epoch {epoch: 4d} | Loss: {epoch_loss:.4f} | Acc: {epoch_acc:.2f}%")


In [4]:
with MDBClient() as client:
    # Parameters
    model_args = {
        "dim_in": 128,
        "dim_h": 64,
        "dim_out": 16,
        "num_layers": 5,
    }
    graph_loader_args = {
        "batch_size": 64,
        "num_neighbors": [5, 5],
        "tensor_store_name": "github",
    }
    epochs = 300
    num_seeds = 256

    model = GraphSAGE(**model_args)

    # Initialize both graph loaders
    sampler = Sampler(client=client)
    seed_ids = sampler.get_seed_ids(num_seeds)
    tgl = TrainGraphLoader(client=client, seed_ids=seed_ids, **graph_loader_args)
    sgl = SamplingGraphLoader(client=client, num_seeds=num_seeds, **graph_loader_args)

    # Train with both graph loaders
    for loader in [sgl, tgl]:
        print(f"Training with {loader.__class__.__name__}...\n")
        fit_sup(
            model=model,
            epochs=epochs,
            graph_loader=tgl,
        )
        print()
        model.reset_parameters()


Training with SamplingGraphLoader...

Epoch    0 | Loss: 2.2864 | Acc: 44.14%
Epoch   50 | Loss: 0.2673 | Acc: 95.70%
Epoch  100 | Loss: 0.1027 | Acc: 97.27%
Epoch  150 | Loss: 0.1501 | Acc: 95.31%
Epoch  200 | Loss: 0.2160 | Acc: 96.09%
Epoch  250 | Loss: 0.2130 | Acc: 95.31%
Epoch  300 | Loss: 0.1035 | Acc: 96.88%

Training with TrainGraphLoader...

Epoch    0 | Loss: 2.3623 | Acc: 39.45%
Epoch   50 | Loss: 0.1095 | Acc: 95.31%
Epoch  100 | Loss: 0.2151 | Acc: 96.48%
Epoch  150 | Loss: 0.3328 | Acc: 97.27%
Epoch  200 | Loss: 0.0917 | Acc: 96.88%
Epoch  250 | Loss: 0.2114 | Acc: 96.48%
Epoch  300 | Loss: 0.0968 | Acc: 98.05%



### fit_sup_with_embeddings

This trainer will use the original node features concatenated with the latest version of the computed embeddings. If no embeding is found, a tensor of zeroes is used. The model must have `dim_features + dim_embedding` dimensions on each layer until the final layer of `dim_embedding`.

In [7]:
def fit_sup_with_embeddings(
    client: "MDBClient",
    model: "GraphSAGE",
    epochs: int,
    graph_loader: "GraphLoader",
):
    if model.dim_in < model.dim_out:
        raise ValueError("Model must have dim_in >= dim_out")
    if model.dim_in != model.dim_h:
        raise ValueError("Model must have dim_in == dim_h")

    dim_features = model.dim_in - model.dim_out
    dim_embeddings = model.dim_out

    model.train()
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    # Create a temporary feature store to store the latest embeddings
    temp_store_name = "temp"
    if TensorStore.exists(client, temp_store_name):
        TensorStore.remove(client, temp_store_name)
    TensorStore.create(client, temp_store_name, dim_embeddings)

    temp_store = TensorStore(client, temp_store_name)

    for epoch in range(epochs + 1):
        total_seeds = 0
        total_loss = 0
        total_correct = 0
        for batch in graph_loader:
            optimizer.zero_grad()

            # Allocate space for concatenation
            concat_matrix = torch.zeros(
                len(batch.node_ids), dim_features + dim_embeddings
            )
            # Insert original node features
            concat_matrix[:, :dim_features] = batch.node_features
            # Insert latest available embeddings from temporary store
            for idx, node_id in enumerate(batch.node_ids):
                if node_id in temp_store:
                    concat_matrix[idx, dim_features:] = temp_store[node_id]
            out = model(concat_matrix, batch.edge_index)[: batch.num_seeds]

            # Update seed node embeddings in temporary store
            temp_store[batch.node_ids[: batch.num_seeds]] = out

            y_pred = F.log_softmax(out, dim=-1)
            y_true = batch.node_labels[: batch.num_seeds]

            loss = criterion(y_pred, y_true)
            loss.backward()
            optimizer.step()

            total_seeds += batch.num_seeds
            total_loss += loss
            total_correct += y_pred.argmax(dim=-1).eq(y_true).sum().item()

        if epoch % 50 == 0 or epoch == epochs:
            epoch_loss = total_loss / len(graph_loader)
            epoch_acc = 100 * total_correct / total_seeds
            print(f"Epoch {epoch: 4d} | Loss: {epoch_loss:.4f} | Acc: {epoch_acc:.2f}%")

    # Cleanup
    temp_store.close()
    TensorStore.remove(client, temp_store_name)


In [8]:
with MDBClient() as client:
    # Parameters
    dim_features = 128
    dim_out = 16
    model_args = {
        "dim_in": dim_features + dim_out,
        "dim_h": dim_features + dim_out,
        "dim_out": dim_out,
        "num_layers": 5,
    }
    graph_loader_args = {
        "batch_size": 64,
        "num_neighbors": [5, 5],
        "tensor_store_name": "github",
    }
    epochs = 300
    num_seeds = 256

    model = GraphSAGE(**model_args)

    # Initialize both graph loaders
    sampler = Sampler(client=client)
    seed_ids = sampler.get_seed_ids(num_seeds)
    tgl = TrainGraphLoader(client=client, seed_ids=seed_ids, **graph_loader_args)
    sgl = SamplingGraphLoader(client=client, num_seeds=num_seeds, **graph_loader_args)

    # Train with both graph loaders
    for loader in [sgl, tgl]:
        print(f"Training with {loader.__class__.__name__}...\n")
        fit_sup_with_embeddings(
            client=client,
            model=model,
            epochs=300,
            graph_loader=loader,
        )
        print()
        model.reset_parameters()


Training with SamplingGraphLoader...

Epoch    0 | Loss: 1.8828 | Acc: 52.34%
Epoch   50 | Loss: 0.6591 | Acc: 74.22%
Epoch  100 | Loss: 0.6948 | Acc: 69.92%
Epoch  150 | Loss: 0.6245 | Acc: 69.53%
Epoch  200 | Loss: 0.6230 | Acc: 75.00%
Epoch  250 | Loss: 0.5563 | Acc: 75.39%
Epoch  300 | Loss: 0.6150 | Acc: 74.61%

Training with TrainGraphLoader...

Epoch    0 | Loss: 1.7887 | Acc: 47.27%
Epoch   50 | Loss: 0.6263 | Acc: 67.97%
Epoch  100 | Loss: 0.7791 | Acc: 72.27%
Epoch  150 | Loss: 0.6041 | Acc: 74.61%
Epoch  200 | Loss: 0.7046 | Acc: 74.61%
Epoch  250 | Loss: 0.5786 | Acc: 75.00%
Epoch  300 | Loss: 0.5773 | Acc: 75.39%

