# GraphSAGE

In [39]:
# Necessary to import from sibling directory
import sys

sys.path.append("..")

from pymdb import (
    MDBClient,
    TrainBatchLoader,
    EvalBatchLoader,
    FeatureStoreManager,
    Sampler,
)

import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
from typing import List


## Model

In [40]:
class GraphSAGE(torch.nn.Module):
    """GraphSAGE model built from a stack of ``SAGEConv`` layers.

    The stack is ``dim_in -> dim_h -> ... -> dim_h -> dim_out`` with
    ``num_layers`` convolutions in total. Every layer except the last one is
    followed by ReLU and dropout.

    Besides the usual ``forward``, the class offers:

    * ``fit``: mini-batch training driven by a batch loader.
    * ``compute_embeddings``: layer-by-layer inference over an evaluation
      loader, writing the result of each layer to a feature store.
    """

    # Feature stores created (and overwritten) by ``compute_embeddings``.
    _TEMP_STORE_NAMES = ("temp1", "temp2")
    _FINAL_STORE_NAME = "final"

    def __init__(
        self,
        dim_in: int,
        dim_h: int,
        dim_out: int,
        num_layers: int,
        dropout: float = 0.5,
    ):
        """Build the convolution stack.

        Args:
            dim_in: Size of the input node features.
            dim_h: Size of every hidden representation.
            dim_out: Size of the output of the last layer.
            num_layers: Total number of ``SAGEConv`` layers; must be at least 2.
            dropout: Dropout probability applied after every non-final layer
                while the module is in training mode.

        Raises:
            ValueError: If ``num_layers`` is smaller than 2.
        """
        if num_layers < 2:
            raise ValueError("Number of layers must be greater than 1")

        super().__init__()

        self.dim_in = dim_in
        self.dim_h = dim_h
        self.dim_out = dim_out
        self.num_layers = num_layers
        self.dropout = dropout

        self.convs = torch.nn.ModuleList()
        self.convs.append(SAGEConv(dim_in, dim_h))
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(dim_h, dim_h))
        self.convs.append(SAGEConv(dim_h, dim_out))

    def reset_parameters(self):
        """Re-initialise the parameters of every convolution."""
        for conv in self.convs:
            conv.reset_parameters()

    def _activate(self, h: torch.Tensor) -> torch.Tensor:
        """Apply the ReLU + dropout used after every non-final layer."""
        h = F.relu(h)
        # Dropout is only active while the module is in training mode.
        return F.dropout(h, p=self.dropout, training=self.training)

    def forward(
        self,
        node_features: torch.Tensor,  # [num_nodes, feature_size]
        edge_index: torch.Tensor,  # [2, num_edges]
    ):
        """Run the whole stack over one (sub)graph.

        Returns:
            A tuple ``(embedding, prediction)``: the raw output of the last
            layer and its row-wise ``log_softmax``.
        """
        h = node_features
        for layer in self.convs[:-1]:
            h = self._activate(layer(h, edge_index))
        h = self.convs[-1](h, edge_index)
        return h, F.log_softmax(h, dim=1)  # (embedding, prediction)

    def fit(
        self,
        epochs: int,
        train_batch_loader: "TrainBatchLoader",
    ):
        """Train the model with Adam and cross-entropy on the seed nodes.

        Only the first ``len(batch.seed_ids)`` rows of each batch contribute to
        the loss and to the reported accuracy.

        Args:
            epochs: Index of the last epoch. The loop covers epoch indices
                ``0..epochs`` inclusive, i.e. ``epochs + 1`` passes.
            train_batch_loader: Iterable of batches exposing ``seed_ids``,
                ``node_features``, ``edge_index`` and ``node_labels``.

        The metrics printed every 50 epochs are those of the most recent batch
        that was processed, not an average over the epoch.
        """
        self.train()
        # CrossEntropyLoss applies log_softmax itself; feeding it the
        # log-probabilities returned by forward() yields the same loss because
        # log_softmax is idempotent.
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(self.parameters(), lr=0.01, weight_decay=5e-4)

        # Stay None until a batch has been processed, so that an empty loader
        # produces a readable log line instead of an UnboundLocalError.
        loss = acc = None
        for epoch in range(epochs + 1):
            # Train on batches
            for batch in train_batch_loader:
                num_seeds = len(batch.seed_ids)
                labels = batch.node_labels[:num_seeds]
                optimizer.zero_grad()
                out = self(batch.node_features, batch.edge_index)[1][:num_seeds]
                loss = criterion(out, labels)
                acc = ((out.argmax(dim=1) == labels).sum() / len(labels)).item()
                loss.backward()
                optimizer.step()
            # Print metrics every 50 epochs
            if epoch % 50 == 0:
                if loss is None:
                    print(f"Epoch: {epoch:>3} | no training batches were produced")
                else:
                    print(
                        f"Epoch: {epoch:>3} | Train loss: {loss:.3f} | Train acc: {acc*100:>6.2f}%"
                    )

    def _evaluate_layer(
        self,
        layer,
        loader_args: dict,
        source_store_name: str,
        target_store,
        activate: bool,
    ):
        """Apply one layer to every evaluation batch and store the seed rows.

        Node features are read through an ``EvalBatchLoader`` bound to
        ``source_store_name``; the first ``len(batch.seed_ids)`` output rows
        are written to ``target_store`` under ``batch.seed_ids``.
        """
        for batch in EvalBatchLoader(
            **loader_args, feature_store_name=source_store_name
        ):
            num_seeds = len(batch.seed_ids)
            h = layer(batch.node_features, batch.edge_index)[:num_seeds]
            if activate:
                h = self._activate(h)
            target_store[batch.seed_ids] = h

    @staticmethod
    def _flush_store(fsm, open_stores: dict, key: str):
        """Close and reopen a store so that what was written can be read back."""
        handle = open_stores.pop(key)
        name = handle.name
        handle.close()
        open_stores[key] = fsm.open(name)

    def compute_embeddings(
        self,
        client: "MDBClient",
        initial_feature_store_name: str,
        batch_size: int,
        num_neighbors: List[int],
    ):
        """Compute embeddings layer by layer and write them to the "final" store.

        Each layer is evaluated over the full ``EvalBatchLoader`` before the
        next one starts. The output of a layer is written to a temporary
        feature store which then serves as the input of the following layer;
        two temporary stores ("temp1", "temp2") alternate between the read and
        write roles. Existing stores named "temp1", "temp2" or "final" are
        removed first.

        Args:
            client: Client handed to the feature store manager and the loaders.
            initial_feature_store_name: Store holding the input features.
            batch_size: Forwarded to ``EvalBatchLoader``.
            num_neighbors: Forwarded to ``EvalBatchLoader``.

        Open stores are always closed and the temporary stores are always
        removed, even when a pass fails. After a failure the "final" store may
        be empty or incomplete.
        """
        self.eval()
        ebl_args = {
            "client": client,
            "batch_size": batch_size,
            "num_neighbors": num_neighbors,
        }

        fsm = FeatureStoreManager(client)

        temp_a, temp_b = self._TEMP_STORE_NAMES
        final_key = self._FINAL_STORE_NAME
        feature_sizes = {
            temp_a: self.dim_h,
            temp_b: self.dim_h,
            final_key: self.dim_out,
        }

        # Drop leftovers from a previous run.
        store_names = fsm.list()
        for name in feature_sizes:
            if name in store_names:
                fsm.remove(name)

        open_stores = {}  # store name -> currently open handle
        try:
            for name, feature_size in feature_sizes.items():
                fsm.create(name=name, feature_size=feature_size)
            for name in feature_sizes:
                open_stores[name] = fsm.open(name)

            # Inference only: do not record an autograd graph.
            with torch.no_grad():
                print(f"Evaluating layer 1/{self.num_layers}")
                self._evaluate_layer(
                    self.convs[0],
                    ebl_args,
                    initial_feature_store_name,
                    open_stores[temp_a],
                    activate=True,
                )
                # Flush the store to disk
                self._flush_store(fsm, open_stores, temp_a)

                # From now on read_key names the store used for READ ONLY and
                # write_key the store used for WRITE ONLY.
                read_key, write_key = temp_a, temp_b
                for idx, layer in enumerate(self.convs[1:-1]):
                    print(f"Evaluating layer {idx + 2}/{self.num_layers}")
                    self._evaluate_layer(
                        layer,
                        ebl_args,
                        open_stores[read_key].name,
                        open_stores[write_key],
                        activate=True,
                    )
                    # Flush the store to disk
                    self._flush_store(fsm, open_stores, write_key)
                    # Swap store roles
                    read_key, write_key = write_key, read_key

                print(f"Evaluating layer {self.num_layers}/{self.num_layers}")
                self._evaluate_layer(
                    self.convs[-1],
                    ebl_args,
                    open_stores[read_key].name,
                    open_stores[final_key],
                    activate=False,
                )

            print(
                f"Done. Final embeddings are stored in the FeatureStore {open_stores[final_key].name}"
            )
        finally:
            for handle in open_stores.values():
                handle.close()

            # Remove temporary stores (they may be missing if creation failed)
            remaining = fsm.list()
            for name in self._TEMP_STORE_NAMES:
                if name in remaining:
                    fsm.remove(name)


## Train

In [41]:
# Sample 256 seed nodes, wrap them in a training loader over the "github"
# feature store and fit a 5-layer GraphSAGE on the resulting batches.
with MDBClient() as client:
    sampler = Sampler(client=client)
    seeds_ids = sampler.get_seed_ids(num_seeds=256)

    loader_config = {
        "feature_store_name": "github",
        "batch_size": 64,
        "num_neighbors": [5, 5],
    }
    tbl = TrainBatchLoader(client=client, seed_ids=seeds_ids, **loader_config)

    model = GraphSAGE(dim_in=128, dim_h=64, dim_out=16, num_layers=5)
    model.fit(train_batch_loader=tbl, epochs=1)


Epoch:   0 | Train loss: 1.001 | Train acc:  71.88%


## Compute embeddings

In [42]:
# Run layer-wise inference with the trained model, reading the input features
# from the "github" feature store.
with MDBClient() as client:
    initial_feature_store_name, batch_size, num_neighbors = "github", 1000, [5, 5]

    model.compute_embeddings(
        client,
        initial_feature_store_name=initial_feature_store_name,
        batch_size=batch_size,
        num_neighbors=num_neighbors,
    )


Evaluating layer 1/5
Evaluating layer 2/5


Exception: The node id "2334102065170612224" does not exist in the feature store

In [45]:
# Sanity check: count how many seed nodes the evaluation loader yields in
# total when iterating over the "github" feature store.
with MDBClient() as client:
    ebl = EvalBatchLoader(
        client=client,
        feature_store_name="github",
        batch_size=1000,
        num_neighbors=[5, 5],
    )
    tot = sum(len(batch.seed_ids) for batch in ebl)
    print(tot)

37700
