# GraphSAGE

In [4]:
# Necessary to import from sibling directory
import sys
sys.path.append("..")

from pymdb import MDBClient, BatchLoader

import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv


## Model

In [5]:
class GraphSAGE(torch.nn.Module):
    def __init__(self, dim_in: int, dim_h: int, dim_out: int):
        super().__init__()
        self.sage1 = SAGEConv(dim_in, dim_h)
        self.sage2 = SAGEConv(dim_h, dim_out)

    def forward(
        self,
        x: torch.Tensor,  # [num_nodes, feature_size]
        edge_index: torch.Tensor,  # [2, feature_size]
    ):
        h = self.sage1(x, edge_index)
        h = torch.relu(h)
        h = F.dropout(h, p=0.5, training=self.training)
        h = self.sage2(h, edge_index)
        # Return (embedding, class_prediction)
        return h, F.log_softmax(h, dim=1)

    def fit(self, epochs, batch_loader):
        self.train()
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(self.parameters(), lr=0.01, weight_decay=5e-4)
        accuracy = lambda y_pred, y: ((y_pred == y).sum() / len(y)).item()
        for epoch in range(epochs + 1):
            # Train on batches
            for batch in batch_loader:
                optimizer.zero_grad()
                _, out = self(batch.node_features, batch.edge_index)

                loss = criterion(out, batch.node_labels)
                acc = accuracy(out.argmax(dim=1), batch.node_labels)

                loss.backward()
                optimizer.step()

            # Print metrics every 50 epochs
            if epoch % 50 == 0:
                print(
                    f"Epoch: {epoch:>3} | Train loss: {loss:.3f} | Train acc: {acc*100:>6.2f}%"
                )

    def predict(self, X, edge_index):
        self.eval()
        return self(X, edge_index)


## Train

In [6]:
with MDBClient(host="127.0.0.1", port=8080) as client:
    batch_loader = BatchLoader(
        client=client,
        feature_store_name="github",
        num_seeds=64,
        batch_size=16,
        neighbor_sizes=[10, 5],
        seed=2023,
    )

    model = GraphSAGE(dim_in=128, dim_h=64, dim_out=2)
    model.fit(epochs=200, batch_loader=batch_loader)


Epoch:   0 | Train loss: 0.417 | Train acc:  88.93%
Epoch:  50 | Train loss: 0.241 | Train acc:  89.89%
Epoch: 100 | Train loss: 0.203 | Train acc:  92.47%
Epoch: 150 | Train loss: 0.220 | Train acc:  94.01%
Epoch: 200 | Train loss: 0.226 | Train acc:  91.81%


## Predict

In [17]:
with MDBClient(host="127.0.0.1", port=8080) as client:
    PRED_BATCH_SIZE = 1000
    batch_loader = BatchLoader(
        client=client,
        feature_store_name="github",
        num_seeds=sys.maxsize,
        batch_size=PRED_BATCH_SIZE,
        neighbor_sizes=[10, 5],
        seed=2023,
    )

    accuracy = lambda y_pred, y: ((y_pred == y).sum() / len(y)).item()
    results = list()
    for batch in batch_loader:
        emb, out = model.predict(batch.node_features, batch.edge_index)
        y_pred = out[:PRED_BATCH_SIZE].argmax(axis=1)
        y_true = batch.node_labels[:PRED_BATCH_SIZE]
        acc = accuracy(y_pred, y_true)
        results.append(acc)
        print(f"Batch accuracy: {acc}")

    print(f"Mean accuracy: {sum(results) / len(results)}")

Batch accuracy: 0.8799999952316284
Batch accuracy: 0.9020000100135803
Batch accuracy: 0.875
Batch accuracy: 0.8640000224113464
Batch accuracy: 0.890999972820282
Batch accuracy: 0.902999997138977
Batch accuracy: 0.8679999709129333
Batch accuracy: 0.8799999952316284
Batch accuracy: 0.8899999856948853
Batch accuracy: 0.8840000033378601
Batch accuracy: 0.9010000228881836
Batch accuracy: 0.9010000228881836
Batch accuracy: 0.8899999856948853
Batch accuracy: 0.8939999938011169
Batch accuracy: 0.871999979019165
Batch accuracy: 0.8820000290870667
Batch accuracy: 0.921999990940094
Batch accuracy: 0.8700000047683716
Batch accuracy: 0.8610000014305115
Batch accuracy: 0.9020000100135803
Batch accuracy: 0.9020000100135803
Batch accuracy: 0.9010000228881836
Batch accuracy: 0.9110000133514404
Batch accuracy: 0.8960000276565552
Batch accuracy: 0.8899999856948853
Batch accuracy: 0.9190000295639038
Batch accuracy: 0.9279999732971191
Batch accuracy: 0.906000018119812
Batch accuracy: 0.8820000290870667
Bat

In [12]:
from pymdb import FeatureStoreManager

with MDBClient(host="127.0.0.1", port=8080) as client:
    print("a")


TypeError: object of type 'int' has no len()

In [13]:
print(batch)

Graph(node_features=[8708, 128] node_labels=[8708] edge_index=[2, 15813])
