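# GraphSAGE

Train a GraphSAGE node classifier on neighbour-sampled mini-batches from the `github` feature store, then use it to predict a class for every node.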

In [7]:
# Necessary to import from sibling directory
import sys
sys.path.append("..")

from pymdb import MDBClient, BatchLoader, NodeIterator, FeatureStoreManager

import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv


## Model

In [34]:
class GraphSAGE(torch.nn.Module):
    """Stack of ``SAGEConv`` layers for node classification.

    The first layer maps ``dim_in`` -> ``dim_h``, any middle layers keep ``dim_h``,
    and the last layer maps ``dim_h`` -> ``dim_out``. ReLU and dropout follow every
    layer except the last one.

    Args:
        dim_in: Number of input features per node.
        dim_h: Hidden embedding size.
        dim_out: Number of output classes.
        num_layers: Total number of ``SAGEConv`` layers; must be at least 2.
        dropout: Dropout probability applied between layers while training.

    Raises:
        ValueError: If ``num_layers < 2``.
    """

    def __init__(
        self,
        dim_in: int,
        dim_h: int,
        dim_out: int,
        num_layers: int,
        dropout: float = 0.5,
    ):
        super().__init__()
        if num_layers < 2:
            raise ValueError("Number of layers must be greater than 1")
        self.dropout = dropout
        self.layers = torch.nn.ModuleList()
        self.layers.append(SAGEConv(dim_in, dim_h))
        for _ in range(num_layers - 2):
            self.layers.append(SAGEConv(dim_h, dim_h))
        self.layers.append(SAGEConv(dim_h, dim_out))

    def forward(
        self,
        x: torch.Tensor,  # [num_nodes, dim_in]
        edge_index: torch.Tensor,  # [2, num_edges]
    ):
        """Run all layers and return ``(logits, log_probs)``.

        ``logits`` is the raw output of the last layer and ``log_probs`` is its
        row-wise ``log_softmax``.
        """
        h = x
        # ReLU + dropout after every layer but the last (including the first), so
        # even the minimal 2-layer model is non-linear.
        for layer in self.layers[:-1]:
            h = layer(h, edge_index)
            h = torch.relu(h)
            h = F.dropout(h, p=self.dropout, training=self.training)
        h = self.layers[-1](h, edge_index)
        return h, F.log_softmax(h, dim=1)  # (embedding, prediction)

    def fit(self, epochs, batch_loader, lr=0.01, weight_decay=5e-4, log_every=50):
        """Train with Adam for ``epochs`` passes over ``batch_loader``.

        Each batch must expose ``node_features``, ``edge_index`` and ``node_labels``.
        ``batch_loader`` is iterated once per epoch, so it must be re-iterable; a
        one-shot generator would leave every later epoch empty.

        Args:
            epochs: Number of passes over ``batch_loader``.
            batch_loader: Re-iterable source of training batches.
            lr: Adam learning rate.
            weight_decay: Adam weight decay.
            log_every: Print epoch-averaged loss/accuracy every this many epochs
                (the first and the final epoch are always printed).

        Raises:
            ValueError: If ``batch_loader`` yields no batches in some epoch.
        """
        self.train()
        optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay)
        for epoch in range(1, epochs + 1):
            loss_sum, correct, seen, num_batches = 0.0, 0, 0, 0
            for batch in batch_loader:
                optimizer.zero_grad()
                _, log_probs = self(batch.node_features, batch.edge_index)
                # forward() already returns log-probabilities, so NLL is the
                # matching loss (CrossEntropyLoss would re-apply log_softmax).
                loss = F.nll_loss(log_probs, batch.node_labels)
                loss.backward()
                optimizer.step()

                # Accumulate per-node totals so the logged metrics cover the
                # whole epoch rather than just its last batch.
                batch_size = len(batch.node_labels)
                loss_sum += loss.item() * batch_size
                correct += (log_probs.argmax(dim=1) == batch.node_labels).sum().item()
                seen += batch_size
                num_batches += 1

            if num_batches == 0:
                raise ValueError(
                    f"batch_loader yielded no batches in epoch {epoch}; "
                    "it must be re-iterable across epochs"
                )

            if epoch == 1 or epoch % log_every == 0 or epoch == epochs:
                denom = max(seen, 1)
                epoch_loss = loss_sum / denom
                epoch_acc = correct / denom
                print(
                    f"Epoch: {epoch:>3} | Train loss: {epoch_loss:.3f} | Train acc: {epoch_acc*100:>6.2f}%"
                )

    @torch.no_grad()
    def predict(self, X, edge_index):
        """Return ``(embedding, predicted_class)`` in eval mode, without autograd.

        Leaves the module in eval mode; ``fit`` switches it back to train mode.
        """
        self.eval()
        embedding, prediction = self(X, edge_index)
        return embedding, prediction.argmax(dim=1)


## Train

In [35]:
# Seed torch as well: BatchLoader gets its own seed, but model weight init and
# dropout draw from torch's global RNG and were otherwise unseeded.
SEED = 2023
torch.manual_seed(SEED)

with MDBClient() as client:
    batch_loader = BatchLoader(
        client=client,
        feature_store_name="github",
        num_seeds=256,
        batch_size=64,
        # NOTE(review): presumably one neighbour fan-out per hop; keep its length
        # equal to num_layers below — confirm against BatchLoader.
        neighbor_sizes=[10, 5],
        seed=SEED,
    )

    model = GraphSAGE(dim_in=128, dim_h=64, dim_out=2, num_layers=2)
    print(model)
    model.fit(epochs=100, batch_loader=batch_loader)


GraphSAGE(
  (layers): ModuleList(
    (0): SAGEConv(128, 64, aggr=mean)
    (1): SAGEConv(64, 2, aggr=mean)
  )
)


RuntimeError: mat1 and mat2 shapes cannot be multiplied (1250x128 and 64x2)

## Predict

In [16]:
# Run the trained `model` from the Train cell over every node in the "github"
# feature store and collect the predicted class of each node.
# NOTE(review): edge_index is empty, so SAGEConv has no neighbours to aggregate and
# predictions use each node's own features only — TODO load real edges if the
# graph structure should inform inference.
prediction_batches = []
with MDBClient() as client:
    fsm = FeatureStoreManager(client)
    fs = fsm.open("github")
    node_iterator = NodeIterator(client, 1000)  # presumably 1000 node ids per batch — confirm
    for batch in node_iterator:
        # 128 must match the model's dim_in.
        node_features = torch.zeros((len(batch), 128))
        for i, node_id in enumerate(batch):
            node_features[i] = fs[node_id]
        edge_index = torch.zeros((2, 0), dtype=torch.int64)
        with torch.no_grad():
            _, batch_predictions = model.predict(node_features, edge_index)
        prediction_batches.append(batch_predictions)

predictions = (
    torch.cat(prediction_batches)
    if prediction_batches
    else torch.empty(0, dtype=torch.int64)
)
# Number of nodes predicted per class.
predictions.bincount()