In [None]:
from typing import Sequence, Tuple, Callable

import os
import sys

from absl import flags
from absl import app
from absl import logging
import ase
import ase.data
from ase.db import connect
import ase.io
import ase.visualize
import jax
import jax.numpy as jnp
import jraph
import numpy as np
import tqdm
import chex
import optax
import time

sys.path.append("..")

import analyses.analysis as analysis
from symphony import datatypes
from symphony.data import input_pipeline
from symphony import models

# Module-level handle to absl's global flag registry.
FLAGS = flags.FLAGS

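# Worked example of the segment-id update in append_predictions_new.
# Parenthesized entries are padding (dummy) nodes claimed by unstopped graphs:
# graphs 0 and 2 claim the nodes at indices 6 and 8, while graph 1 stopped,
# so index 7 stays in the padding graph (id 3). A stable sort then moves each
# new node right after its graph's existing nodes.
# stop            [False True False]
# segment_ids     [0 0 0 1 1 2 (3) 3 (3) 3 3]
# new_segment_ids [0 0 0 1 1 2 (0) 3 (2) 3 3]
# sorted          [0 0 0 (0) 1 1 2 (2) 3 3 3]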


@jax.jit
def append_predictions_new(
    preds: datatypes.Predictions, padded_fragment: datatypes.Fragments, nn_cutoff: float
) -> datatypes.Fragments:
    """Appends one predicted atom per unstopped graph to a padded fragment batch.

    For every valid (non-padding) graph ``i`` whose ``stop`` flag is False, the
    padding node at index ``num_valid_nodes + i`` is claimed by graph ``i``. It
    is moved to ``positions[focus_indices[i]] + position_vectors[i]`` and given
    species ``target_species[i]``. Nodes are then stably re-sorted by segment id
    so each graph's nodes are contiguous again. Finally, the edge list is rebuilt
    from pairwise distances.

    Args:
      preds: Model predictions. Only ``preds.globals.stop``, ``focus_indices``,
        ``position_vectors`` and ``target_species`` are read. Only the first
        ``num_valid_graphs`` entries of each are used, so trailing entries for
        the padding graph (if present) are ignored. ``focus_indices`` index
        directly into ``padded_fragment.nodes.positions``.
      padded_fragment: Batch whose last graph is the single padding graph, as
        laid out by ``jraph.pad_with_graphs``. The padding graph must hold at
        least ``num_valid_graphs`` nodes, because graph ``i`` claims padding
        node ``i``. This cannot be checked under ``jit``.
      nn_cutoff: Two distinct real nodes of the same graph are connected when
        their distance is in ``(0, nn_cutoff]``.

    Returns:
      A fragment with the same array shapes as ``padded_fragment``. Positions,
      species, ``n_node``, ``n_edge``, ``senders`` and ``receivers`` are updated.
      At most ``padded_fragment.senders.shape[0]`` edges are kept (extra
      qualifying pairs are silently dropped). Unused edge slots hold -1.
    """
    num_nodes = padded_fragment.nodes.positions.shape[0]
    num_graphs = padded_fragment.n_node.shape[0]
    num_edges = padded_fragment.senders.shape[0]
    # Exactly one trailing padding graph is assumed.
    num_padding_graphs = 1
    num_valid_graphs = num_graphs - num_padding_graphs
    num_padding_nodes = padded_fragment.n_node[-1]
    num_valid_nodes = num_nodes - num_padding_nodes

    # Slice every per-graph prediction to the valid graphs, so they all share the
    # shape of `stop` even when predictions also cover the padding graph.
    stop = preds.globals.stop[:num_valid_graphs]
    focuses = preds.globals.focus_indices[:num_valid_graphs]
    position_vectors = preds.globals.position_vectors[:num_valid_graphs]
    target_species = preds.globals.target_species[:num_valid_graphs]
    num_unstopped_graphs = (~stop).sum()

    # Graph i may claim the i-th padding node as its new atom.
    dummy_nodes_indices = num_valid_nodes + jnp.arange(num_valid_graphs)

    # Claimed dummy nodes join their graph; stopped graphs' dummies keep their
    # original segment id.
    segment_ids = models.get_segment_ids(padded_fragment.n_node, num_nodes)
    dummy_new_segment_ids = jnp.where(
        stop, segment_ids[dummy_nodes_indices], jnp.arange(num_valid_graphs)
    )
    segment_ids = segment_ids.at[dummy_nodes_indices].set(dummy_new_segment_ids)

    # Place each claimed dummy node at (focus position + predicted offset).
    positions = padded_fragment.nodes.positions
    target_positions = positions[focuses] + position_vectors
    dummy_new_positions = jnp.where(
        stop[:, None], positions[dummy_nodes_indices], target_positions
    )
    positions = positions.at[dummy_nodes_indices].set(dummy_new_positions)

    # Give each claimed dummy node its predicted species.
    species = padded_fragment.nodes.species
    dummy_new_species = jnp.where(stop, species[dummy_nodes_indices], target_species)
    species = species.at[dummy_nodes_indices].set(dummy_new_species)

    # A stable sort keeps the original order within each graph. Every new node
    # has a larger index than all valid nodes, so it lands after its graph's
    # existing nodes.
    sort_indices = jnp.argsort(segment_ids, kind='stable')
    segment_ids = segment_ids[sort_indices]
    positions = positions[sort_indices]
    species = species[sort_indices]

    # Pairwise distances; `> 0` excludes self-edges.
    distance_matrix = jnp.linalg.norm(
        positions[None, :, :] - positions[:, None, :], axis=-1
    )
    # After sorting, the first (valid + newly appended) nodes are the real ones;
    # everything after them belongs to the padding graph.
    num_real_nodes = num_valid_nodes + num_unstopped_graphs
    is_real = jnp.arange(num_nodes) < num_real_nodes
    valid_edges = (
        (distance_matrix > 0)
        & (distance_matrix <= nn_cutoff)
        & (segment_ids[None, :] == segment_ids[:, None])  # no cross-graph edges
        & is_real[None, :]
        & is_real[:, None]
    )
    # Fixed `size` keeps shapes static under jit; surplus pairs are truncated.
    senders, receivers = jnp.nonzero(valid_edges, size=num_edges, fill_value=-1)

    # Per-graph node and edge counts. A -1 fill index wraps to the last node, so
    # unused edge slots are counted toward that node's graph (the padding graph
    # whenever it still has a node left).
    n_node = jnp.bincount(segment_ids, length=num_graphs)
    n_edge = jnp.bincount(segment_ids[senders], length=num_graphs)

    return padded_fragment._replace(
        nodes=padded_fragment.nodes._replace(
            positions=positions,
            species=species,
        ),
        n_node=n_node,
        n_edge=n_edge,
        senders=senders,
        receivers=receivers,
    )

In [None]:
# Toy predictions for a batch of three graphs: graph 1 predicts "stop", while
# graphs 0 and 2 each request one new atom offset from their focus node.
preds = datatypes.Predictions(
    nodes=None,
    edges=None,
    senders=None,
    receivers=None,
    n_node=None,
    n_edge=None,
    globals=datatypes.GlobalPredictions(
        stop=jnp.array([False, True, False]),
        focus_indices=jnp.array([0, 1, 3]),
        position_vectors=jnp.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]),
        target_species=jnp.array([1, 2, 3]),
        # Distribution outputs are irrelevant when appending atoms.
        stop_logits=None,
        stop_probs=None,
        position_logits=None,
        position_coeffs=None,
        position_probs=None,
    ),
)

# Six nodes on the x axis at x = 1..6, split into graphs of 1, 2 and 3 nodes.
fragment = datatypes.Fragments(
    nodes=datatypes.NodesInfo(
        positions=jnp.asarray([[x, 0, 0] for x in range(1, 7)], dtype=jnp.float32),
        species=jnp.asarray([1] * 3 + [2] * 3),
    ),
    edges=jnp.ones(5),
    senders=jnp.asarray([0, 2, 1, 3, 3]),
    receivers=jnp.asarray([0, 1, 2, 4, 4]),
    n_node=jnp.array([1, 2, 3]),
    n_edge=jnp.array([1, 2, 2]),
    globals=None,
)
fragment

In [None]:
# Pad 6 nodes / 5 edges / 3 graphs up to 10 / 20 / 4. This adds one padding
# graph holding the 4 extra nodes.
padded_fragment = jraph.pad_with_graphs(fragment, n_node=10, n_edge=20, n_graph=4)
# `jax.tree_map` is deprecated (removed in recent JAX); use the jax.tree_util form.
padded_fragment = jax.tree_util.tree_map(jnp.asarray, padded_fragment)
padded_fragment

In [None]:
# Same-graph nodes within distance 1.5 of each other are linked by an edge.
append_predictions_new(preds, padded_fragment, nn_cutoff=1.5)