In [35]:
import jraph
import jax
import jax.numpy as jnp
import e3nn_jax as e3nn
import numpy as np

In [36]:
graphs = jraph.GraphsTuple(
    nodes=jnp.array([[1, 2, 3], [10, 2, 3], [4, 5, 6], [7, 8, 9]]),
    edges=jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
    globals=jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
    n_node=jnp.array([2, 1, 1]),
    n_edge=jnp.array([1, 1, 1]),
    senders=jnp.array([0, 0, 0]),
    receivers=jnp.array([0, 0, 0]),
)

In [42]:
def sample_until_completion(rng, graph):
    num_nodes = graph.nodes.shape[0]

    # node_embeddings = self.embeddings(graph)
    node_embeddings = graph.nodes

    # focus_logits = self.focus(node_embeddings)
    focus_logits = jnp.array([1.1, 0.9])
    focus_probs = jax.nn.softmax(focus_logits)

    # Sample the focus node.
    rng, focus_rng = jax.random.split(rng)
    focus_node_indices = jax.random.choice(focus_rng, jnp.eye(num_nodes), p=focus_probs)

    # Get the embeddings of the focus node.
    focus_node_embeddings = node_embeddings[focus_node_indices]

    # Get the species logits.
    # species_logits = self.target_species(focus_node_embeddings)
    NUM_ELEMENTS = 5
    species_logits = jnp.array([1.1, 0.9, 0.5, 0.3, 0.1])
    species_probs = jax.nn.softmax(species_logits)

    # Sample the target species.
    rng, species_rng = jax.random.split(rng)
    target_species = jax.random.choice(species_rng, NUM_ELEMENTS, p=species_probs)

    # Get the position coefficients.
    position_coeffs = self.target_position(
        focus_node_embeddings, target_species
    )
    # Compute the position signal projected to a spherical grid for each radius.
    position_signal = e3nn.to_s2grid(
        position_coeffs,
        res_beta,
        res_alpha,
        quadrature="gausslegendre",
        normalization="integral",
        p_val=1,
        p_arg=-1,
    )

    # Integrate the position signal over each sphere to get the normalizing factors for the radii.
    # For numerical stability, we subtract out the maximum value over all spheres before exponentiating.
    position_max = jnp.max(
        position_signal.grid_values, axis=(-3, -2, -1), keepdims=True
    )
    sphere_normalizing_factors = position_signal.apply(
        lambda pos: jnp.exp(pos - position_max)
    ).integrate()
    sphere_normalizing_factors = sphere_normalizing_factors.array.squeeze(axis=-1)
    radius_probs = jax.nn.softmax(sphere_normalizing_factors)

    # sphere_normalizing_factors is of shape (num_graphs, num_radii)
    assert radius_probs.shape == (
        num_graphs,
        num_radii,
    )

    rng, species_rng = jax.random.split(rng)
    radius = jax.random.choice(radius_rng, NUM_RADII, p=radius_probs)

    # Get the position signal for the sampled radius.
    position = position_signal.sample()

    # Create a new molecule with the sampled species and position.
    return datatypes.Fragment()



Array([0., 1., 0.], dtype=float32)

In [13]:
focus = graphs.nodes.sum(axis=1)
focus, e3nn.scatter_max(focus, nel=graphs.n_node)

(Array([ 6, 15, 15, 24], dtype=int32),
 Array([15., 15., 24.], dtype=float32, weak_type=True))

In [3]:
nodes = jnp.asarray([2, 3, 1, 4])
segment_max = e3nn.scatter_max(nodes, nel=graphs.n_node)
segment_max_expanded = e3nn.scatter_max(nodes, map_back=True, nel=graphs.n_node)
print(segment_max_expanded)
# segment_max_expanded = jnp.asarray([segment_max[0], segment_max[0], segment_max[1], segment_max[2]])

expected = 0 + jnp.log(1 + e3nn.scatter_sum(jnp.exp(nodes - 0), nel=graphs.n_node))
computed = segment_max + jnp.log(jnp.exp(-segment_max) + e3nn.scatter_sum(jnp.exp(nodes - segment_max_expanded), nel=graphs.n_node))

expected, computed

[3. 3. 1. 4.]


(Array([3.3490124, 1.3132616, 4.01815  ], dtype=float32),
 Array([3.3490121, 1.3132616, 4.01815  ], dtype=float32))

In [4]:
lmax = 3
irreps = e3nn.Irreps(e3nn.Irrep.iterator(lmax))
e3nn.IrrepsArray(irreps=irreps, array=jnp.ones((10, irreps.dim)))

1x0e+1x0o+1x1o+1x1e+1x2e+1x2o+1x3o+1x3e
[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
  1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
  1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
  1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
  1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
  1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
  1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
  1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
  1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
  1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1

In [5]:
e3nn.scatter_sum(data=graphs.nodes, nel=graphs.n_node)

Array([[2., 4., 6.],
       [4., 5., 6.],
       [7., 8., 9.]], dtype=float32, weak_type=True)