# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SageLayer(nn.Module):
    """
    One GraphSAGE-style layer.

    Node features and adjacency-aggregated neighbor features are each passed
    through their own linear projection; the two projections are averaged
    and a ReLU is applied to the result.
    """
    def __init__(self, in_features, out_features):
        """
        Create the two linear projections.

        Args:
            in_features: size of each input node feature vector.
            out_features: size of each output node feature vector.
        """
        super().__init__()
        # Projection of a node's own features.
        self.fc_self = nn.Linear(in_features, out_features)
        # Projection of the features aggregated from its neighbors.
        self.fc_neigh = nn.Linear(in_features, out_features)

    def forward(self, input, adj):
        """
        Compute the layer output.

        Args:
            input: 2-D node feature matrix whose last dimension is
                ``in_features``.
            adj: 2-D matrix left-multiplied onto ``input`` with a matrix
                product. The product is a per-node neighbor mean only if
                the rows of ``adj`` are normalized by the caller; this layer
                does not normalize it.

        Returns:
            ReLU of the element-wise average of the self projection and the
            neighbor projection.
        """
        # Row i of the product is the adj-weighted sum of all node features.
        aggregated = adj.mm(input)
        own_branch = self.fc_self(input)
        neigh_branch = self.fc_neigh(aggregated)
        # Equal-weight blend of the two branches.
        blended = torch.add(own_branch, neigh_branch).div(2)
        return F.relu(blended)

class GraphSAGE(nn.Module):
    """
    Two stacked SageLayer modules applied with the same adjacency matrix.
    """
    def __init__(self, in_features, hidden_features, out_features):
        """
        Build the two layers.

        Args:
            in_features: size of each input node feature vector.
            hidden_features: feature size between the first and second layer.
            out_features: feature size produced by the second layer.
        """
        super().__init__()
        self.layer1 = SageLayer(in_features, hidden_features)
        self.layer2 = SageLayer(hidden_features, out_features)

    def forward(self, input, adj):
        """
        Run both layers in sequence.

        Args:
            input: node feature matrix passed to the first layer.
            adj: adjacency matrix passed unchanged to both layers.

        Returns:
            The output of the second layer.
        """
        hidden = self.layer1(input, adj)
        return self.layer2(hidden, adj)

# Example usage:
# Build a random undirected graph and random node features.
N = 5  # Number of nodes
F_in = 3  # Input feature dimension
# Sample a random binary edge mask. Gaussian entries would give negative
# "edge weights", and L1 normalization divides by the sum of absolute values,
# so rows with mixed signs would not sum to 1.
adj = (torch.rand((N, N)) < 0.5).float()
adj = torch.triu(adj, diagonal=1)  # Strict upper triangle: no self-loops
adj = adj + adj.t()  # Mirror it so the graph is undirected (symmetric)
# Row-normalize so adj @ X is the mean over each node's neighbors. Rows then
# sum to 1; an isolated node keeps an all-zero row.
adj = F.normalize(adj, p=1, dim=1)
node_features = torch.randn((N, F_in))

# Initialize and apply GraphSAGE
graphsage_model = GraphSAGE(F_in, 4, 2)
new_features = graphsage_model(node_features, adj)
print(new_features)
