# GraphSAGE

In [15]:
# Necessary to import from sibling directory
import sys
sys.path.append("..")

from pymdb import MDBClient, BatchLoader, NodeIterator, FeatureStoreManager

import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv


## Model

In [18]:
class GraphSAGE(torch.nn.Module):
    def __init__(self, dim_in: int, dim_h: int, dim_out: int, num_layers: int):
        super().__init__()
        if num_layers < 2:
            raise ValueError("Number of layers must be greater than 1")
        self.layers = torch.nn.ModuleList()
        self.layers.append(SAGEConv(dim_in, dim_h))
        for _ in range(num_layers - 2):
            self.layers.append(SAGEConv(dim_h, dim_h))
        self.layers.append(SAGEConv(dim_h, dim_out))

    def forward(
        self,
        node_features: torch.Tensor,  # [num_nodes, feature_size]
        edge_index: torch.Tensor,  # [2, feature_size]
    ):
        h = node_features
        for layer in self.layers[:-1]:
            h = layer(h, edge_index)
            h = F.relu(h)
            h = F.dropout(h, p=0.5, training=self.training)
        h = self.layers[-1](h, edge_index)
        return h, F.log_softmax(h, dim=1)  # (embedding, prediction)

    def train(self, epochs, batch_loader):
        self.train()
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(self.parameters(), lr=0.01, weight_decay=5e-4)
        accuracy = lambda y_pred, y: ((y_pred == y).sum() / len(y)).item()
        for epoch in range(epochs + 1):
            # Train on batches
            for batch in batch_loader:
                optimizer.zero_grad()
                _, out = self(batch.node_features, batch.edge_index)

                loss = criterion(out, batch.node_labels)
                acc = accuracy(out.argmax(dim=1), batch.node_labels)

                loss.backward()
                optimizer.step()

            # Print metrics every 50 epochs
            if epoch % 50 == 0:
                print(
                    f"Epoch: {epoch:>3} | Train loss: {loss:.3f} | Train acc: {acc*100:>6.2f}%"
                )

    def predict(
        self,
        initial_feature_store_name: str,
        final_feature_store_name: str,
        sampler: "Sampler",
        node_iterator: "NodeIterator",
    ):
        """
        Maybe it would be better to use something like a batch loader that receives a
        batch of node_ids as seed nodes and return a subgraph where:
        1. The first node_ids.size() features are the seed nodes features on the same order
        2. The remaining nodes are the neighbors of the seed nodes at depth 1
        3. The edges are the ones between the seed nodes and their neighbors

        Then the output would be the slice of the matrix [0:node_ids.size()]

        Notes:
        - Batches should have the original node_ids for storing the embeddings
        - Maybe FeatureStore should be closed after each usage in both batch loader and
          node iterator. This is because during the evaluation it will be modified and
          the changes should be saved between iterations
        - Maybe it is necessary to check if FeatureStore is closed before using it
          throwing an exception from C++
        """
        self.eval()
        # Pseudocode!

        # check stores exist

        prev_store = FeatureStore("temp1", self.dim_h)
        curr_store = FeatureStore("temp2", self.dim_h)
        final_store = FeatureStore(final_feature_store_name, self.dim_out)

        for batch in NodeIterator(initial_feature_store_name):
            h, _ = self(batch.node_features, batch.edge_index)
            prev_store[batch.node_ids] = h  # TODO: Implement this feature (multi_insert_tensor)

        # From now the variable prev_store is used for READ ONLY, while curr_store is 
        # used for WRITE ONLY
        for layer in self.layers[1:-1]:
            for batch in NodeIterator(prev_store.name):
                h, _ = layer(batch.node_features, batch.edge_index)
                h = F.relu(h)
                h = F.dropout(h, p=0.5, training=self.training)
                curr_store[batch.node_ids] = h  # TODO: Implement this feature (multi_insert_tensor)
            # Swap store references
            prev_store, curr_store = curr_store, prev_store

        for batch in NodeIterator(prev_store.name):
            h, _ = self.layers[-1](node_features, edge_index)
            final_store[batch.node_ids] = h  # TODO: Implement this feature (multi_insert_tensor)


## Train

In [19]:
with MDBClient() as client:
    batch_loader = BatchLoader(
        client=client,
        feature_store_name="github",
        num_seeds=256,
        batch_size=64,
        neighbor_sizes=[10, 5],
        seed=2023,
    )

    model = GraphSAGE(dim_in=128, dim_h=64, dim_out=2, num_layers=5)
    print(model)
    model.train(epochs=50, batch_loader=batch_loader)


GraphSAGE(
  (layers): ModuleList(
    (0): SAGEConv(128, 64, aggr=mean)
    (1): SAGEConv(64, 64, aggr=mean)
    (2): SAGEConv(64, 64, aggr=mean)
    (3): SAGEConv(64, 64, aggr=mean)
    (4): SAGEConv(64, 2, aggr=mean)
  )
)
Epoch:   0 | Train loss: 0.478 | Train acc:  83.86%
Epoch:  50 | Train loss: 0.259 | Train acc:  90.13%


## Predict

In [14]:
with MDBClient() as client:
    fsm = FeatureStoreManager(client)

    with fsm.open("github") as fs:
        for batch in NodeIterator(client, 5000):
            node_ids, edge_index = expand_neighbors(batch)
        
            node_features = fs[neighborhood]

            embedding, prediction = model.predict(node_features, edge_index)

