# GraphSAGE

In [1]:
# Necessary to import from sibling directory
import sys
sys.path.append("..")

## Model

In [2]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv


class GraphSAGE(torch.nn.Module):
    def __init__(self, dim_in: int, dim_h: int, dim_out: int):
        super().__init__()
        self.sage1 = SAGEConv(dim_in, dim_h)
        self.sage2 = SAGEConv(dim_h, dim_out)

    def forward(
        self,
        x: torch.Tensor,  # [num_nodes, feature_size]
        edge_index: torch.Tensor,  # [2, feature_size]
    ):
        h = self.sage1(x, edge_index)
        h = torch.relu(h)
        h = F.dropout(h, p=0.5, training=self.training)
        h = self.sage2(h, edge_index)
        # Return (embedding, class_prediction)
        return h, F.log_softmax(h, dim=1)

    def fit(self, epochs, batch_loader):
        self.train()
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(self.parameters(), lr=0.01, weight_decay=5e-4)
        accuracy = lambda y_pred, y: ((y_pred == y).sum() / len(y)).item()
        for epoch in range(epochs + 1):
            # Train on batches
            for batch in batch_loader:
                print(batch)
                optimizer.zero_grad()
                _, out = self(batch.node_features, batch.edge_index)

                loss = criterion(out, batch.node_labels)
                acc = accuracy(out.argmax(dim=1), batch.node_labels)

                loss.backward()
                optimizer.step()

            # Print metrics every 50 epochs
            if epoch % 50 == 0:
                print(
                    f"Epoch: {epoch:>3} | Train loss: {loss:.3f} | Train acc: {acc*100:>6.2f}%"
                )

    def predict(self, X, edge_index):
        self.eval()
        return self(X, edge_index)


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from pymdb import MDBClient, BatchLoader

with MDBClient(host="127.0.0.1", port=8080) as client:
    batch_loader = BatchLoader(
        client=client,
        feature_store_name="github",
        num_seeds=64,
        batch_size=16,
        neighbor_sizes=[5, 5],
        seed=2023,
    )

    model = GraphSAGE(dim_in=128, dim_h=64, dim_out=2)
    model.fit(epochs=100, batch_loader=batch_loader)


Graph(node_features=torch.Size([296, 128]) node_labels=torch.Size([296]) edge_index=torch.Size([2, 308]))


AttributeError: 'Graph' object has no attribute 'y'