In [None]:
import tensorflow as tf
import ase
import numpy as np
import itertools
import jax
import jraph

import sys
sys.path.append("..")

import input_pipeline
import input_pipeline_tf
import fragments
import datatypes

In [None]:
def get_unbatched_tetris_datasets(
    config
):
    """Loads the Tetris pieces as a single, unbatched tf.data.Dataset of fragments.

    Each of the eight Tetris pieces (taken from the e3nn Tetris example) is
    converted to an ``ase.Atoms`` molecule whose atoms all have atomic number 1.
    Each molecule is then turned into a ``jraph.GraphsTuple`` with
    ``input_pipeline.ase_atoms_to_jraph_graph`` and split into fragments with
    ``fragments.generate_fragments``. The fragments of all pieces are
    concatenated into one dataset.

    Args:
        config: Currently unused; kept so the signature matches other dataset
            loaders.

    Returns:
        A ``tf.data.Dataset`` whose elements are ``jraph.GraphsTuple`` fragments
        with ``datatypes.FragmentsNodes`` nodes and ``datatypes.FragmentsGlobals``
        globals. The dataset can be iterated any number of times.

    Raises:
        ValueError: If no fragments were generated from the pieces.
    """
    # Taken from e3nn Tetris example.
    # https://docs.e3nn.org/en/stable/examples/tetris_gate.html
    pieces = [
        [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, 1, 0)],   # chiral_shape_1
        [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, -1, 0)],  # chiral_shape_2
        [(0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0)],   # square
        [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3)],   # line
        [(0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0)],   # corner
        [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0)],   # L
        [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 1)],   # T
        [(0, 0, 0), (1, 0, 0), (1, 1, 0), (2, 1, 0)],   # zigzag
    ]

    # Convert to molecules, and then jraph.GraphsTuples.
    pieces_as_molecules = [
        ase.Atoms(numbers=[1] * 4, positions=np.array(piece)) for piece in pieces
    ]
    pieces_as_graphs = [
        input_pipeline.ase_atoms_to_jraph_graph(molecule, [1], nn_cutoff=1.1)
        for molecule in pieces_as_molecules
    ]

    # Materialize all fragments up front rather than keeping a one-shot
    # generator. Peeking at a shared generator to read dtypes would drop that
    # fragment from the dataset. tf.data also calls `fragment_yielder` once per
    # iteration, so the source must be re-iterable or later passes are empty.
    all_fragments = list(
        itertools.chain.from_iterable(
            fragments.generate_fragments(
                jax.random.PRNGKey(0),
                graph,
                n_species=1,
                nn_tolerance=0.01,
                max_radius=1.01,
                mode="nn",
            )
            for graph in pieces_as_graphs
        )
    )
    if not all_fragments:
        raise ValueError("No fragments were generated for the Tetris pieces.")

    def fragment_yielder():
        # A fresh generator over the stored list on every call.
        yield from all_fragments

    # The first fragment serves only as a template for the signature's dtypes.
    example = all_fragments[0]
    dataset = tf.data.Dataset.from_generator(
        fragment_yielder,
        output_signature=jraph.GraphsTuple(
            nodes=datatypes.FragmentsNodes(
                positions=tf.TensorSpec(shape=(None, 3), dtype=example.nodes.positions.dtype),
                species=tf.TensorSpec(shape=(None,), dtype=example.nodes.species.dtype),
                focus_and_target_species_probs=tf.TensorSpec(
                    shape=(None, 1),
                    dtype=example.nodes.focus_and_target_species_probs.dtype,
                ),
            ),
            globals=datatypes.FragmentsGlobals(
                target_positions=tf.TensorSpec(shape=(1, 3), dtype=example.globals.target_positions.dtype),
                target_species=tf.TensorSpec(shape=(1,), dtype=example.globals.target_species.dtype),
                stop=tf.TensorSpec(shape=(1,), dtype=example.globals.stop.dtype),
            ),
            edges=tf.TensorSpec(shape=(None,), dtype=example.edges.dtype),
            receivers=tf.TensorSpec(shape=(None,), dtype=example.receivers.dtype),
            senders=tf.TensorSpec(shape=(None,), dtype=example.senders.dtype),
            n_node=tf.TensorSpec(shape=(None,), dtype=example.n_node.dtype),
            n_edge=tf.TensorSpec(shape=(None,), dtype=example.n_edge.dtype),
        )
    )

    return dataset

# Build the unbatched Tetris fragment dataset once, then print each fragment.
tetris_fragments = get_unbatched_tetris_datasets(None)
for x in tetris_fragments:
    print(x)
