### Sources:

- https://mlabonne.github.io/blog/graphsage/

In [1]:
import torch
torchversion = torch.__version__

# Install PyTorch Scatter, PyTorch Sparse, and PyTorch Geometric
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-{torchversion}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-{torchversion}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone


### GraphSAGE Model

In [2]:
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

class GraphSAGE(torch.nn.Module):
    def __init__(self, dim_in, dim_h, dim_out):
        super().__init__()
        self.sage1 = SAGEConv(dim_in, dim_h)
        self.sage2 = SAGEConv(dim_h, dim_out)
        
    def forward(self, x, edge_index):
        h = self.sage1(x, edge_index)
        h = torch.relu(h)
        h = F.dropout(h, p=0.5, training=self.training)
        h = self.sage2(x, edge_index)
        # Return (embedding, prediction)
        return h, F.log_softmax(h, dim=1)

    def fit(self, epochs, node_loader):
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(self.parameters(), lr=0.01, weight_decay=53-4)
        accuracy = lambda pred_y, y: ((pred_y == y).sum() / len(y)).item()

        self.train()
        for epoch in range(epochs + 1):
            # Train on batches
            for batch in node_loader:
                optimizer.zero_grad()
                _, pred = self(batch.x, batch.edge_index)
                
                loss = criterion(pred, batch.y)
                accuracy = accuracy(pred.argmax(dim=1), batch.y)

                loss.backward()
                optimizer.step()

            # Print metrics every 10 epochs
            if epoch % 10 == 0:
                print(f"Epoch          : {epoch}")
                print(f"Train loss     : {loss}")
                print(f"Train accuracy : {accuracy}")
                print("-" * 30)

### Dataset

In [3]:
from torch_geometric.datasets import Planetoid

dataset = Planetoid(root=".", name="Pubmed")
data = dataset[0]

# Print information about the dataset
print(f"Dataset: {dataset}")
print("-"*20)
print(f"Number of graphs   : {len(dataset)}")
print(f"Number of nodes    : {data.x.shape[0]}")
print(f"Number of edges    : {data.edge_index.shape[1]}")
print(f"Number of features : {data.x.shape[1]}")
print(f"Number of classes  : {dataset.num_classes}")

Dataset: Pubmed()
--------------------
Number of graphs   : 1
Number of nodes    : 19717
Number of edges    : 88648
Number of features : 500
Number of classes  : 3


In [10]:
from torch_geometric.data import Data
# Note: 
# The entire edge index wouldn't be necessary for a custom neighbor loader 
# implementation using MillenniumDB

X_train = data.x[data.train_mask]
y_train = data.y[data.train_mask]
data_train = Data(x=X_train, y=y_train, edge_index=data.edge_index)

X_test = data.x[data.test_mask]
y_test = data.y[data.test_mask]
data_test = Data(x=X_test, y=y_test, edge_index=data.edge_index)

### Neighbor loader

#### 1. PyTorch implementation (Just for inspiration)



In [17]:
from torch_geometric.loader import NeighborLoader
node_loader = NeighborLoader(
    data,
    num_neighbors=[5, 5],
    batch_size=128
)

for i, batch in enumerate(node_loader):
    print(len(batch), batch)
    if i>0 and i % 3 == 0:
        break
print(len(node_loader))

10 Data(x=[1475, 500], edge_index=[2, 1698], y=[1475], train_mask=[1475], val_mask=[1475], test_mask=[1475], n_id=[1475], e_id=[1698], input_id=[128], batch_size=128)
10 Data(x=[1400, 500], edge_index=[2, 1566], y=[1400], train_mask=[1400], val_mask=[1400], test_mask=[1400], n_id=[1400], e_id=[1566], input_id=[128], batch_size=128)
10 Data(x=[1413, 500], edge_index=[2, 1624], y=[1413], train_mask=[1413], val_mask=[1413], test_mask=[1413], n_id=[1413], e_id=[1624], input_id=[128], batch_size=128)
10 Data(x=[1487, 500], edge_index=[2, 1743], y=[1487], train_mask=[1487], val_mask=[1487], test_mask=[1487], n_id=[1487], e_id=[1743], input_id=[128], batch_size=128)
155


#### 2. Custom implementation

- The input nodes are the "seed nodes"
- Each iteration takes `batch_size` seed nodes and ends up getting neighbors according to the `num_neighbors` list
- The return type must be a `torch_geometric.data.Data` object with at least `x := [node_id, features]` and `y := [labels]` defined

In [None]:
# TODO

### Model instance

In [11]:
dim_in = data.x.shape[1]
dim_h = 64 # You can tune this
dim_out = dataset.num_classes
model = GraphSAGE(dim_in, dim_h, dim_out)
print(model)

GraphSAGE(
  (sage1): SAGEConv(500, 64, aggr=mean)
  (sage2): SAGEConv(64, 3, aggr=mean)
)


In [15]:
model.fit(
    epochs=300,
    node_loader=node_loader
)