# GraphSAGE

**Important (TODO):** optimize the repeated `[: batch.num_seeds]` slicing calls in the model code below.

In [2]:
# Necessary to import from sibling directory
import sys

sys.path.append("..")

from pymdb import (
    MDBClient,
    TrainBatchLoader,
    EvalBatchLoader,
    FeatureStoreManager,
    Sampler,
)

import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv


## Model

In [23]:
class GraphSAGE(torch.nn.Module):
    """Multi-layer GraphSAGE node classifier built from ``SAGEConv`` layers.

    The layer stack is ``dim_in -> dim_h -> ... -> dim_h -> dim_out``. Every
    layer except the last one is followed by ReLU and dropout.

    Args:
        dim_in: Size of the input node features.
        dim_h: Size of every hidden representation.
        dim_out: Size of the output layer; class log-probabilities are
            computed over this dimension.
        num_layers: Total number of ``SAGEConv`` layers; must be at least 2.
        dropout: Dropout probability applied after every hidden layer while
            the module is in training mode.

    Raises:
        ValueError: If ``num_layers`` is smaller than 2.
    """

    def __init__(
        self,
        dim_in: int,
        dim_h: int,
        dim_out: int,
        num_layers: int,
        dropout: float = 0.5,
    ):
        if num_layers < 2:
            raise ValueError("Number of layers must be greater than 1")

        super().__init__()

        self.dim_in = dim_in
        self.dim_h = dim_h
        self.dim_out = dim_out
        self.num_layers = num_layers
        self.dropout = dropout

        # Input layer, (num_layers - 2) hidden layers, output layer.
        self.convs = torch.nn.ModuleList()
        self.convs.append(SAGEConv(dim_in, dim_h))
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(dim_h, dim_h))
        self.convs.append(SAGEConv(dim_h, dim_out))

    def reset_parameters(self):
        """Re-initialise the parameters of every convolution layer."""
        for conv in self.convs:
            conv.reset_parameters()

    def forward(
        self,
        node_features: torch.Tensor,  # [num_nodes, feature_size]
        edge_index: torch.Tensor,  # [2, num_edges]
    ):
        """Run the full layer stack on one (sub)graph.

        Args:
            node_features: Node feature matrix, one row per node.
            edge_index: Edge list as expected by ``SAGEConv``.

        Returns:
            Tuple ``(embedding, prediction)``: the raw output of the last
            layer and its row-wise log-softmax.
        """
        h = node_features
        for layer in self.convs[:-1]:
            h = layer(h, edge_index)
            h = F.relu(h)
            h = F.dropout(h, p=self.dropout, training=self.training)
        h = self.convs[-1](h, edge_index)
        return h, F.log_softmax(h, dim=1)  # (embedding, prediction)

    def fit(
        self,
        epochs: int,
        batch_loader: "TrainBatchLoader",
        lr: float = 0.01,
        weight_decay: float = 5e-4,
    ):
        """Train the model with Adam on the seed nodes of every batch.

        Each batch must expose ``node_features``, ``edge_index``,
        ``node_labels`` and ``num_seeds``; only the first ``num_seeds`` rows
        of the output and of the labels contribute to the loss.

        Note that the loop runs ``epochs + 1`` passes (epoch indices
        ``0..epochs``), so that the final epoch index is also reported.

        Args:
            epochs: Index of the last epoch to run.
            batch_loader: Iterable of training batches; iterated once per epoch.
            lr: Learning rate passed to Adam.
            weight_decay: Weight decay passed to Adam.
        """
        self.train()
        # forward() already returns log-probabilities, so NLLLoss is the
        # matching criterion (CrossEntropyLoss would apply log-softmax again).
        criterion = torch.nn.NLLLoss()
        optimizer = torch.optim.Adam(
            self.parameters(), lr=lr, weight_decay=weight_decay
        )
        for epoch in range(epochs + 1):
            # Metrics of the last batch of this epoch; stay None when the
            # loader yields nothing, so the report below is skipped instead
            # of failing on unbound names.
            last_loss = None
            last_acc = None
            # Train on batches
            for batch in batch_loader:
                num_seeds = batch.num_seeds
                seed_labels = batch.node_labels[:num_seeds]  # sliced once per batch

                optimizer.zero_grad()
                _, log_probs = self(batch.node_features, batch.edge_index)
                seed_log_probs = log_probs[:num_seeds]
                loss = criterion(seed_log_probs, seed_labels)
                loss.backward()
                optimizer.step()

                last_loss = loss.item()
                hits = (seed_log_probs.argmax(dim=1) == seed_labels).sum()
                last_acc = (hits / len(seed_labels)).item()
            # Print metrics every 50 epochs
            if epoch % 50 == 0 and last_loss is not None:
                print(
                    f"Epoch: {epoch:>3} | Train loss: {last_loss:.3f} | Train acc: {last_acc*100:>6.2f}%"
                )

    def predict(
        self,
        fsm: "FeatureStoreManager",
        batch_loader: "EvalBatchLoader",
    ):
        """Evaluate the model layer by layer and persist the final embeddings.

        Two temporary stores (``"temp1"``, ``"temp2"``) receive the hidden
        outputs, and the store ``"final"`` receives the output of the last
        layer. Pre-existing stores with these names are removed first. The
        temporary stores are removed again before returning, and every opened
        store is closed even if a pass raises.

        Each batch must expose ``node_features``, ``edge_index``,
        ``num_seeds`` and ``seed_ids``.

        Args:
            fsm: Manager used to list, create, open and remove feature stores.
            batch_loader: Iterable of evaluation batches; iterated once per
                layer.
        """
        self.eval()

        temp_names = ("temp1", "temp2")
        final_name = "final"

        # Drop leftovers of a previous run so that create() starts clean.
        existing = fsm.list()
        for name in (*temp_names, final_name):
            if name in existing:
                fsm.remove(name)

        fsm.create(name=temp_names[0], feature_size=self.dim_h)
        fsm.create(name=temp_names[1], feature_size=self.dim_h)
        fsm.create(name=final_name, feature_size=self.dim_out)

        opened = []
        try:
            prev_store = fsm.open(temp_names[0])
            opened.append(prev_store)
            curr_store = fsm.open(temp_names[1])
            opened.append(curr_store)
            final_store = fsm.open(final_name)
            opened.append(final_store)

            # NOTE(review): every pass below feeds `batch.node_features` to its
            # layer and nothing reads `prev_store`, so layers after the first
            # only receive the previous layer's output if the loader already
            # serves it - confirm against EvalBatchLoader.

            # Inference only: no autograd graph should be built or stored.
            # Dropout is omitted because the module is in eval mode, where
            # F.dropout(..., training=False) is the identity.
            with torch.no_grad():
                print(f"Evaluating layer 1/{self.num_layers}")
                for batch in batch_loader:
                    h = self.convs[0](batch.node_features, batch.edge_index)
                    h = F.relu(h[: batch.num_seeds])
                    # TODO: implement batch.seed_ids (then num_seeds is not needed)
                    # TODO: implement FeatureStore::multi_insert_tensor()
                    prev_store[batch.seed_ids] = h

                # From now the variable prev_store is used for READ ONLY, while
                # curr_store is used for WRITE ONLY
                for idx, layer in enumerate(self.convs[1:-1]):
                    print(f"Evaluating layer {idx + 2}/{self.num_layers}")
                    for batch in batch_loader:
                        h = layer(batch.node_features, batch.edge_index)
                        h = F.relu(h[: batch.num_seeds])
                        curr_store[batch.seed_ids] = h
                    # Swap store references
                    prev_store, curr_store = curr_store, prev_store

                print(f"Evaluating layer {self.num_layers}/{self.num_layers}")
                for batch in batch_loader:
                    h = self.convs[-1](batch.node_features, batch.edge_index)
                    # Keep only the seed rows so they line up with seed_ids,
                    # as in the passes above.
                    final_store[batch.seed_ids] = h[: batch.num_seeds]

            print(
                f"Done. Final embeddings are stored in the FeatureStore {final_store.name}"
            )
        finally:
            for store in opened:
                store.close()
            # Remove temporary stores
            for name in temp_names:
                fsm.remove(name)


## Train

In [21]:
with MDBClient() as client:
    # Ask the sampler for 256 seed ids; they are handed to the batch loader below.
    seed_sampler = Sampler(client=client)
    train_seed_ids = seed_sampler.get_seed_ids(num_seeds=256)

    train_loader = TrainBatchLoader(
        client=client,
        feature_store_name="github",
        seed_ids=train_seed_ids,
        batch_size=64,
        num_neighbors=[5, 5],
    )

    # The model is created and trained while the client connection is open.
    model = GraphSAGE(dim_in=128, dim_h=64, dim_out=16, num_layers=5)
    model.fit(batch_loader=train_loader, epochs=500)


Epoch:   0 | Train loss: 1.179 | Train acc:  71.88%
Epoch:  50 | Train loss: 0.147 | Train acc:  95.31%
Epoch: 100 | Train loss: 0.115 | Train acc:  93.75%
Epoch: 150 | Train loss: 0.119 | Train acc:  92.19%
Epoch: 200 | Train loss: 0.270 | Train acc:  96.88%
Epoch: 250 | Train loss: 0.493 | Train acc:  93.75%
Epoch: 300 | Train loss: 0.028 | Train acc:  98.44%
Epoch: 350 | Train loss: 0.367 | Train acc:  93.75%
Epoch: 400 | Train loss: 0.177 | Train acc:  95.31%
Epoch: 450 | Train loss: 0.069 | Train acc:  96.88%
Epoch: 500 | Train loss: 0.013 | Train acc: 100.00%


## Predict

In [7]:
# NOTE(review): the keyword arguments passed here (client, initial_store_name,
# sampler, batch_size) do not match the `predict(self, fsm, batch_loader)`
# signature defined in the Model cell of this notebook, so this call looks
# stale relative to that definition - confirm and update the call before
# relying on a fresh "Restart & Run All".
with MDBClient() as client:
    model.predict(
        client=client,
        initial_store_name="github",
        sampler=None,
        batch_size=1,
    )


Evaluating layer 1/5
Evaluating layer 2/5
Evaluating layer 3/5
Evaluating layer 4/5
Evaluating layer 5/5


In [9]:
# Inspect the stored final embeddings for the first batch of 100 node ids.
# NOTE(review): NodeIterator is not among the names imported at the top of this
# notebook - confirm where it comes from and add it to the import cell.
with MDBClient() as client:
    node_iterator = NodeIterator(client=client, batch_size=100)
    first_batch = next(node_iterator)
    fsm = FeatureStoreManager(client)
    final_store = fsm.open("final")
    try:
        # Read the rows once and reuse them for both printouts.
        first_embeddings = final_store[first_batch]
        print(first_embeddings.shape)
        print(first_embeddings)
    finally:
        # Release the store even if the lookup or printing fails.
        final_store.close()

torch.Size([100, 16])
tensor([[  5.0263,   6.7727, -10.8682,  ..., -10.0266, -10.0361,  -9.7131],
        [  6.5758,   9.1326, -14.5334,  ..., -13.4076, -13.3578, -12.9348],
        [ 11.7308,  16.7017, -26.4572,  ..., -24.3571, -24.2932, -23.4111],
        ...,
        [  8.8362,  12.4285, -19.7243,  ..., -18.1675, -18.1388, -17.4762],
        [  2.6628,   2.3697,  -4.4604,  ...,  -4.4059,  -4.3617,  -4.4770],
        [  8.3151,  11.6126, -18.5166,  ..., -17.0441, -16.9884, -16.4348]])
