In [1]:
import sys
sys.path.append('../')

import jraph
import jax.numpy as jnp
import jax.random as jr
import jax

import equiformer.graphs as graphs
import equiformer.layers as layers
import equiformer.examples.tetris as tetris

from jax.lax import gather

In [2]:
# Smoke test: build the Tetris shape classifier and push a random graph through it.
# Separate, fixed PRNG keys are used for the model and for the random graph.
classifier_key = jr.PRNGKey(1)
graph_key = jr.PRNGKey(0)

classifier = tetris.ShapeClassifier(classifier_key)

# NOTE(review): presumably 100 nodes, 200 edges and a feature spec {0: 1}
# (one degree-0 channel) — confirm against graphs.create_rand_graph.
graph = graphs.create_rand_graph(100, 200, {0: 1}, graph_key)
out_graph = classifier(graph)

In [3]:
# Classifier output. In the run shown it is an Array of 8 values summing to ~1,
# so presumably per-class probabilities (softmax) — TODO confirm.
# NOTE(review): despite its name, `out_graph` holds an array, not a graph.
out_graph

Array([0.14783327, 0.1108437 , 0.15374263, 0.16561948, 0.10789527,
       0.08064634, 0.16812624, 0.06529311], dtype=float32)

In [11]:
def create_connected_graph(coords: jnp.ndarray, add_self_edges: bool = False):
    """Create a fully connected (all-to-all) single graph from node coordinates.

    Args:
        coords: Node coordinates; nodes are indexed along axis -2
            (presumably shape (n_node, 3) — TODO confirm).
        add_self_edges: If True, also add an edge from every node to itself.

    Returns:
        A ``jraph.GraphsTuple`` with node features
        ``{-1: coords, 0: ones of shape (n_node, 1, 1)}``, edge features from
        ``graphs.create_edge_features``, and ``n_node`` / ``n_edge`` stored as
        length-1 arrays (one graph). This matches the convention used by
        ``jraph.get_fully_connected_graph``, so the result can be batched.
        Edges are grouped by receiver, in increasing receiver order.
    """
    n_node = coords.shape[-2]
    node_ids = jnp.arange(n_node)

    if add_self_edges:
        # For receiver i the senders are 0, 1, ..., n_node - 1.
        receivers = jnp.repeat(node_ids, n_node)
        senders = jnp.tile(node_ids, n_node)
    else:
        # For receiver i the senders are i+1, i+2, ..., i-1 (mod n_node):
        # every other node exactly once and never i itself. `offsets.size`
        # (rather than n_node - 1) keeps n_node == 0 well defined.
        offsets = jnp.arange(1, n_node)
        receivers = jnp.repeat(node_ids, offsets.size)
        senders = (receivers + jnp.tile(offsets, n_node)) % n_node
    n_edge = senders.shape[0]

    edges = graphs.create_edge_features(coords, senders, receivers)

    return jraph.GraphsTuple(
        nodes={-1: coords, 0: jnp.ones((n_node, 1, 1), dtype=float)},
        edges=edges,
        receivers=receivers,
        senders=senders,
        globals=None,
        # jraph expects per-graph counts as arrays of shape [n_graph], not
        # scalars; plain ints break jraph.batch and related utilities.
        n_node=jnp.array([n_node]),
        n_edge=jnp.array([n_edge]),
    )

In [14]:
# Convert every Tetris block definition into a float JAX array of node
# coordinates (block 0 renders as a 4x3 array below).
blocks = list(map(lambda block: jnp.array(block, dtype=float), tetris.TETRIS_BLOCKS))

example_block = blocks[0]

In [18]:
import jraph._src.utils as jutils

In [19]:
# Build an all-to-all graph (no self-edges) over the example block using jraph's
# public helper rather than the private `jraph._src.utils` module. The node count
# is taken from the block instead of being hard-coded, so any entry of `blocks` works.
n_block_nodes = example_block.shape[0]
tetris_graph = jraph.get_fully_connected_graph(
    n_node_per_graph=n_block_nodes,
    n_graph=1,
    node_features={-1: example_block, 0: jnp.ones((n_block_nodes, 1, 1))},
    add_self_edges=False,
)

  num_node_features = jax.tree_leaves(node_features)[0].shape[0]


In [20]:
# Inspect the jraph-built graph. For the 4-node example block it has 12 directed
# edges (no self-edges). `edges` is None because no edge features were passed.
tetris_graph

GraphsTuple(nodes={-1: Array([[0., 0., 0.],
       [0., 0., 1.],
       [1., 0., 0.],
       [1., 1., 0.]], dtype=float32), 0: Array([[[1.]],

       [[1.]],

       [[1.]],

       [[1.]]], dtype=float32)}, edges=None, receivers=Array([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3], dtype=int32), senders=Array([1, 2, 3, 2, 3, 0, 3, 0, 1, 0, 1, 2], dtype=int32), globals=None, n_node=Array([4], dtype=int32), n_edge=Array([12], dtype=int32))