In [5]:
import mlx.core as mx
import mlx.nn as nn

In [6]:
class MessagePassing(nn.Module):
    def __init__(self, aggr=None):
        self.aggr = aggr

    def __call__(self, x, edge_index, **kwargs):
        pass

    def propagate(self, x, edge_index, **kwargs):
        # process arguments and create *_kwargs
        
        src_idx, dst_idx = edge_index
        x_i = x[src_idx]
        x_j = x[dst_idx]

        # Message
        messages = self.message(x_i, x_j) #**msg_kwargs)

        # Aggregate
        aggregated = self.aggregate(messages, dst_idx) #**agg_kwargs)

        # Update
        output = self.update(aggregated) #**upd_kwargs)

        return output

    def message(self, x_i, x_j, **kwargs):
        return x_i

    def aggregate(self, messages, indices, **kwargs):
        if self.aggr == "add":
            nb_unique_indices = _unique(indices)
            empty_tensor = mx.zeros((nb_unique_indices, messages.shape[-1]))
            update_dim = (messages.shape[0], 1, messages.shape[1])
            return mx.scatter_add_(empty_tensor, [indices], messages.reshape(update_dim), [0], None)

    def update(self, aggregated, **kwargs):
        raise NotImplementedError

    
def _unique(array):
    return len(set(array.tolist()))

In [7]:
class GCNLayer(MessagePassing):
    def __init__(self, x_dim, h_dim, bias=True):
        super().__init__(aggr="add")
        
        self.linear = nn.Linear(x_dim, h_dim, bias)

    def __call__(self, x, edge_index, **kwargs):
        x = self.linear(x)
        x = self.propagate(x=x, edge_index=edge_index)

        return x

In [8]:
class GCN(nn.Module):
    def __init__(self, x_dim, h_dim, out_dim, nb_layers=2, dropout=0.5, bias=True):
        super(GCN, self).__init__()

        layer_sizes = [x_dim] + [h_dim] * nb_layers + [out_dim]
        self.gcn_layers = [
            GCNLayer(in_dim, out_dim, bias)
            for in_dim, out_dim in zip(layer_sizes[:-1], layer_sizes[1:])
        ]
        self.dropout = nn.Dropout(p=dropout)

    def __call__(self, x, adj):
        for layer in self.gcn_layers[:-1]:
            x = nn.relu(layer(x, adj))
            x = self.dropout(x)

        x = self.gcn_layers[-1](x, adj)
        return x

In [50]:
gcn = GCN(1, 32)

x = mx.array([[1], [2], [3], [4]])
edge_index = mx.array([
    [0, 0, 0, 1, 2],
    [1, 2, 3, 0, 0]
])
# expect: [5, 1, 1, 1]

gcn(x, edge_index)

In [2]:
import torch
from torch_geometric.nn import GCNConv

x_torch = torch.tensor([[1], [2], [3], [4]])
edge_index_torch = torch.tensor([
    [0, 0, 0, 1, 2],
    [1, 2, 3, 0, 0]
])

gcn_torch = GCNConv()

TypeError: GCNConv.__init__() missing 2 required positional arguments: 'in_channels' and 'out_channels'