In [2]:
import msprime
from IPython.display import SVG, display
import networkx as nx
import jax
import jax.numpy as nmp
from jax.nn import log_softmax
import equinox as eqx
import itertools as it
import optax
import tskit
import numpy as np
from collections import defaultdict
from functools import partial

Here, I play around a little bit with torch GNN trying to get a more sophisticated version of message passing to work.

In [19]:
class MLPMessageLayer(eqx.Module):
  phi: eqx.nn.MLP
  gamma: eqx.nn.MLP

  def __init__(self, in_dim, hidden_dim = 64, *, key):
    k1, k2 = jax.random.split(key)

    self.phi = eqx.nn.MLP(
      in_size=2 * in_dim, # (features of the node and its neighbor)
      out_size= hidden_dim,
      width_size= hidden_dim,
      depth = 2,
      activation=jax.nn.relu,
      final_activation=jax.nn.relu,
      key = k1
    )

    self.gamma = eqx.nn.MLP(
      in_size=in_dim + hidden_dim,
      out_size=hidden_dim,
      width_size=hidden_dim,
      depth=2,
      activation=jax.nn.relu,
      final_activation=None,
      key=k2
    )

  def __call__(self, X, senders, receivers):
    """
    X: [N, D] node features
    senders: [E] edge source indices
    receivers: [E] edge target indices
    """
    # gather features
    X_j = X[senders]
    X_i = X[receivers]

    m_ji = self.phi(nmp.concatenate([X_i, X_j], axis=-1))

    # aggregate by mean
    N = X.shape[0]
    m_sum = jax.ops.segment_sum(m_ji, receivers, N)
    deg = jax.ops.segment_sum(nmp.ones_like(receivers), receivers, N)
    deg = nmp.maximum(deg, 1)[:, None]
    m_i = m_sum / deg

    return self.gamma(nmp.concatenate([X, m_i], axis=-1))

In [25]:
class MPNNEncoder(eqx.Module):
  layers: list

  def __init__(self, in_dim, hidden_dim, num_layers, *, key):
    keys = jax.random.split(key, num_layers)
    self.layers = [MLPMessageLayer(in_dim, hidden_dim, key=keys[0])]
    for k in keys[1:]:
      self.layers.append(MLPMessageLayer(hidden_dim, hidden_dim, key = k))
  
  def __call__(self, X, senders, receivers):
    for layer in self.layers:
      X = layer(X, senders, receivers)
    return X

In [26]:
class SPRGPredictor(eqx.Module):
  encoder: MPNNEncoder
  head: eqx.nn.MLP

  def __init__(self, in_dim, hidden_dim, num_layers, *, key):
    k_enc, k_head = jax.random.split(key)
    self.encoder = MPNNEncoder(in_dim, hidden_dim, num_layers, key = k_enc)
    self.head = eqx.nn.MLP(
      in_size = 2*hidden_dim,
      out_size= 1, # this will be whole edge prediction, where it starts, where it ends. Right now, we are doing logit prediction for all possib  le edge starts, so the cardinality of this set will be E.
      width_size=hidden_dim,
      depth=2,
      activation=jax.nn.relu,
      final_activation=None,
      key = k_head
    )
  
  def __call__(self, X, senders, receivers):
    H = self.encoder(X, senders, receivers)
    h_i, h_j = H[receivers], H[senders]
    logits = self.head(nmp.concatenate([h_i, h_j], axis=-1)).squeeze(-1)
    return logits

In [27]:
key = jax.random.PRNGKey(0)

In [29]:
model = SPRGPredictor(in_dim=100, hidden_dim=64, num_layers=8, key=key)
model(X=node_feats)

NameError: name 'node_feats' is not defined

Training loops

In [None]:
@eqx.filter_value_and_grad
def loss(model, X, senders, receivers, y):
  logits = model(X, senders, receivers)
  preds = nmp.argmax()