# GraphSAGE

In [1]:
# Necessary to import from sibling directory
import sys
sys.path.append("..")

from pymdb import MDBClient, BatchLoader

import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv


## Model

In [2]:
class GraphSAGE(torch.nn.Module):
    def __init__(self, dim_in: int, dim_h: int, dim_out: int):
        super().__init__()
        self.sage1 = SAGEConv(dim_in, dim_h)
        self.sage2 = SAGEConv(dim_h, dim_out)

    def forward(
        self,
        x: torch.Tensor,  # [num_nodes, feature_size]
        edge_index: torch.Tensor,  # [2, feature_size]
    ):
        h = self.sage1(x, edge_index)
        h = torch.relu(h)
        h = F.dropout(h, p=0.5, training=self.training)
        h = self.sage2(h, edge_index)
        # Return (embedding, class_prediction)
        return h, F.log_softmax(h, dim=1)

    def fit(self, epochs, batch_loader):
        self.train()
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(self.parameters(), lr=0.01, weight_decay=5e-4)
        accuracy = lambda y_pred, y: ((y_pred == y).sum() / len(y)).item()
        for epoch in range(epochs + 1):
            # Train on batches
            for batch in batch_loader:
                optimizer.zero_grad()
                _, out = self(batch.node_features, batch.edge_index)

                loss = criterion(out, batch.node_labels)
                acc = accuracy(out.argmax(dim=1), batch.node_labels)

                loss.backward()
                optimizer.step()

            # Print metrics every 50 epochs
            if epoch % 50 == 0:
                print(
                    f"Epoch: {epoch:>3} | Train loss: {loss:.3f} | Train acc: {acc*100:>6.2f}%"
                )

    def predict(self, X, edge_index):
        self.eval()
        return self(X, edge_index)


## Train

In [3]:
with MDBClient(host="127.0.0.1", port=8080) as client:
    batch_loader = BatchLoader(
        client=client,
        feature_store_name="github",
        num_seeds=256,
        batch_size=64,
        neighbor_sizes=[10, 5],
        seed=2023,
    )

    model = GraphSAGE(dim_in=128, dim_h=64, dim_out=2)
    model.fit(epochs=500, batch_loader=batch_loader)


Epoch:   0 | Train loss: 0.432 | Train acc:  85.17%
Epoch:  50 | Train loss: 0.226 | Train acc:  91.51%
Epoch: 100 | Train loss: 0.209 | Train acc:  92.04%
Epoch: 150 | Train loss: 0.206 | Train acc:  92.52%
Epoch: 200 | Train loss: 0.173 | Train acc:  93.63%
Epoch: 250 | Train loss: 0.180 | Train acc:  93.71%
Epoch: 300 | Train loss: 0.178 | Train acc:  92.96%
Epoch: 350 | Train loss: 0.209 | Train acc:  91.99%
Epoch: 400 | Train loss: 0.172 | Train acc:  93.39%
Epoch: 450 | Train loss: 0.166 | Train acc:  93.92%
Epoch: 500 | Train loss: 0.165 | Train acc:  93.79%
