In [1]:
import sys, os
sys.path.append("..")
from importlib import reload
import src.dft as dft
import src.utils as util
from jax_md import space
import src.io as io

import jax.numpy as jnp
import jax
from jax import vmap, grad, jacobian
from tqdm import tqdm



In [2]:
from typing import Any, NamedTuple, Iterable, Mapping, Union, Optional

# A pytree of arrays: a single array, or an arbitrarily nested iterable /
# mapping whose leaves are arrays.
ArrayTree = Union[jnp.ndarray, Iterable['ArrayTree'], Mapping[Any, 'ArrayTree']]

# Immutable container for one graph or a batch of graphs. Field order matters:
# positional construction (`GraphsTuple(*args, ...)`) relies on it.
GraphsTuple = NamedTuple(
    "GraphsTuple",
    [
        ("nodes", Optional[ArrayTree]),        # per-node features
        ("edges", Optional[ArrayTree]),        # per-edge features
        ("receivers", Optional[jnp.ndarray]),  # integer node indices
        ("senders", Optional[jnp.ndarray]),    # integer node indices
        ("globals", Optional[ArrayTree]),      # per-graph features
        ("n_node", jnp.ndarray),               # integer node count per graph
        ("n_edge", jnp.ndarray),               # integer edge count per graph
        ("e_order", Optional[jnp.ndarray]),    # per-edge ordering data
        ("e_mask", jnp.ndarray),               # per-edge validity mask
        ("n_mask", jnp.ndarray),               # per-node validity mask
    ],
)

In [3]:
# Data directory for this notebook (the loaded frames have 201 atoms each).
dataloc = "../data/sio2_201/"

# Bind the project's load/save helpers to the data directory. Loading drops a
# "_redo" marker from the path -- presumably so a re-run directory still reads
# the original inputs (confirm); for the path above it is a no-op.
_load_root = dataloc.replace("_redo", "")
savefile = util.fileloc(io.savefile, dataloc)
loadfile = util.fileloc(io.loadfile, _load_root)

In [4]:
def _batch(graphs, np_):
    """Returns batched graph given a list of graphs and a numpy-like module.

    All per-node / per-edge / per-graph fields are concatenated leaf-wise.
    Index fields are shifted so they remain valid in the concatenated arrays:
    ``senders``/``receivers`` by the node count of all preceding graphs, and
    ``e_order`` by the edge count (``len(senders)``) of all preceding graphs.

    Args:
      graphs: non-empty sequence of ``GraphsTuple`` to combine.
      np_: array module providing ``cumsum``, ``array``, ``sum`` and
        ``concatenate`` (e.g. ``jax.numpy``).

    Returns:
      A single ``GraphsTuple`` holding every input graph.
    """
    # Node offsets caused by concatenating the nodes arrays.
    offsets = np_.cumsum(
        np_.array([0] + [np_.sum(g.n_node) for g in graphs[:-1]]))

    # Edge offsets caused by concatenating the edge arrays (used for e_order).
    edge_order_offsets = np_.cumsum(
        np_.array([0] + [len(g.senders) for g in graphs[:-1]]))

    def _map_concat(nests):
        concat = lambda *args: np_.concatenate(args)
        # jax.tree_multimap was removed from JAX; tree_util.tree_map accepts
        # several trees and maps over them in lockstep. A None field in every
        # graph stays None.
        return jax.tree_util.tree_map(concat, *nests)

    def _shift(tree, offset):
        # Mapping over the tree (rather than `tree + offset`) lets the
        # Optional e_order field be None without raising.
        return jax.tree_util.tree_map(lambda leaf: leaf + offset, tree)

    return GraphsTuple(
        n_node=np_.concatenate([g.n_node for g in graphs]),
        n_edge=np_.concatenate([g.n_edge for g in graphs]),
        nodes=_map_concat([g.nodes for g in graphs]),
        edges=_map_concat([g.edges for g in graphs]),
        e_mask=_map_concat([g.e_mask for g in graphs]),
        n_mask=_map_concat([g.n_mask for g in graphs]),
        e_order=_map_concat(
            [_shift(g.e_order, o) for g, o in zip(graphs, edge_order_offsets)]),
        globals=_map_concat([g.globals for g in graphs]),
        senders=np_.concatenate([g.senders + o for g, o in zip(graphs, offsets)]),
        receivers=np_.concatenate(
            [g.receivers + o for g, o in zip(graphs, offsets)]))


In [5]:
def pad_with_graphs(graph: GraphsTuple,
                    n_node: int,
                    n_edge: int,
                    n_graph: int = 2) -> GraphsTuple:
    """Pads a ``GraphsTuple`` to size by adding computation preserving graphs.

    The ``GraphsTuple`` is padded by first adding a dummy graph which contains
    the padding nodes and edges, and then empty graphs without nodes or edges.
    The empty graphs and the dummy graph do not interfere with the graphnet
    calculations on the original graph, and so are computation preserving.
    The padding graph requires at least one node and one graph.

    The extra fields of this ``GraphsTuple`` are padded too: ``e_order`` and
    ``e_mask`` get ``pad_n_edge`` zero entries and ``n_mask`` gets
    ``pad_n_node`` zero entries (``False`` for boolean masks), so padding can
    be masked out downstream.

    This function does not support jax.jit, because the shape of the output
    is data-dependent.

    Args:
      graph: the ``GraphsTuple`` to be padded.
      n_node: the number of nodes in the padded ``GraphsTuple``; must be
        strictly greater than the number of nodes in ``graph``.
      n_edge: the number of edges in the padded ``GraphsTuple``; must be at
        least the number of edges in ``graph``.
      n_graph: the number of graphs in the padded ``GraphsTuple``. Default is 2,
        which is the lowest possible value, because we always have at least one
        graph in the original ``GraphsTuple`` and we need one dummy graph for
        the padding.

    Raises:
      ValueError: if the passed ``n_graph`` is smaller than 2.
      RuntimeError: if the given ``GraphsTuple`` is too large for the given
        padding.

    Returns:
      A padded ``GraphsTuple``.
    """
    xp = jnp  # array module used to build the padding arrays
    tree_map = jax.tree_util.tree_map  # jax.tree_map is deprecated/removed

    if n_graph < 2:
        raise ValueError(
            f'n_graph is {n_graph}, which is smaller than minimum value of 2.')
    graph = jax.device_get(graph)
    pad_n_node = int(n_node - xp.sum(graph.n_node))
    pad_n_edge = int(n_edge - xp.sum(graph.n_edge))
    pad_n_graph = int(n_graph - graph.n_node.shape[0])
    if pad_n_node <= 0 or pad_n_edge < 0 or pad_n_graph <= 0:
        raise RuntimeError(
            'Given graph is too large for the given padding. difference: '
            f'n_node {pad_n_node}, n_edge {pad_n_edge}, n_graph {pad_n_graph}')

    # One dummy graph takes all padding nodes/edges; the rest are empty graphs.
    pad_n_empty_graph = pad_n_graph - 1

    # Zero arrays matching each leaf's trailing shape and dtype.
    # NOTE(review): a 64-bit leaf dtype without jax_enable_x64 presumably
    # triggers the "zeros" dtype warning seen when building graphs -- confirm.
    tree_nodes_pad = (
        lambda leaf: xp.zeros((pad_n_node,) + leaf.shape[1:], dtype=leaf.dtype))
    tree_edges_pad = (
        lambda leaf: xp.zeros((pad_n_edge,) + leaf.shape[1:], dtype=leaf.dtype))
    tree_globs_pad = (
        lambda leaf: xp.zeros((pad_n_graph,) + leaf.shape[1:], dtype=leaf.dtype))

    padding_graph = GraphsTuple(
        n_node=xp.concatenate(
            [xp.array([pad_n_node], dtype=xp.int32),
             xp.zeros(pad_n_empty_graph, dtype=xp.int32)]),
        n_edge=xp.concatenate(
            [xp.array([pad_n_edge], dtype=xp.int32),
             xp.zeros(pad_n_empty_graph, dtype=xp.int32)]),
        nodes=tree_map(tree_nodes_pad, graph.nodes),
        edges=tree_map(tree_edges_pad, graph.edges),
        globals=tree_map(tree_globs_pad, graph.globals),
        # Padding edges all connect node 0 of the dummy graph to itself.
        senders=xp.zeros(pad_n_edge, dtype=xp.int32),
        receivers=xp.zeros(pad_n_edge, dtype=xp.int32),
        e_order=tree_map(tree_edges_pad, graph.e_order),
        e_mask=tree_map(tree_edges_pad, graph.e_mask),
        n_mask=tree_map(tree_nodes_pad, graph.n_mask),
    )
    return _batch([graph, padding_graph], np_=xp)

In [6]:
def PADGRAPH(graph, max_edges):
    """Pad a single graph with one extra node and ``max_edges + 1`` edges.

    If the graph has more edges than the budget allows, the budget is grown by
    ``int(0.1 * max_edges) + 1`` repeatedly until it fits. This is the same
    schedule as the previous retry-on-exception loop, so graphs that fit always
    get exactly ``max_edges + 1`` edges.

    Args:
      graph: unpadded ``GraphsTuple`` holding one graph.
      max_edges: edge budget. ``None`` uses the graph's own edge count.

    Returns:
      The padded ``GraphsTuple`` produced by ``pad_with_graphs``.

    Raises:
      RuntimeError: from ``pad_with_graphs`` when the graph cannot be padded
        for a reason other than its edge count (e.g. ``n_node`` already lists
        two or more graphs). A bare ``except`` used to turn this into unbounded
        recursion.
    """
    n_edge_total = int(jnp.sum(graph.n_edge))
    if max_edges is None:
        max_edges = n_edge_total
    # pad_with_graphs rejects a negative edge padding, i.e. it needs
    # max_edges + 1 >= n_edge_total. The step is at least 1 so the loop
    # always terminates (identical to the old step for max_edges >= 0).
    while max_edges + 1 < n_edge_total:
        max_edges += max(int(0.1 * max_edges) + 1, 1)
    return pad_with_graphs(graph, graph.n_node.sum() + 1, max_edges + 1)

def mkgraph(*args, mass=None, L=None, max_edges=None, atoms=None, **kwargs):
    """Build a masked, padded ``GraphsTuple`` from graph fields.

    Args:
      *args: positional ``GraphsTuple`` fields, forwarded unchanged.
      mass: optional per-type lookup array. When given, the graph's nodes get
        a ``"mass"`` entry equal to ``mass[nodes["type"]]``.
      L, atoms: accepted and ignored (presumably so a stored graph dict can be
        splatted in directly -- confirm).
      max_edges: edge budget forwarded to ``PADGRAPH``.
      **kwargs: ``GraphsTuple`` fields; must include ``senders`` and
        ``n_node`` (and ``nodes`` when ``mass`` is given).

    Returns:
      The padded ``GraphsTuple`` returned by ``PADGRAPH``.
    """
    if mass is not None:
        nodes = kwargs["nodes"]
        # Build a new nodes dict instead of assigning into the caller's one:
        # kwargs["nodes"] is the same object as the caller's data (e.g. an
        # entry of the loaded `graphs` list), which must not be modified.
        kwargs = {**kwargs, "nodes": {**nodes, "mass": mass[nodes["type"]]}}
    # samegraph adds the all-True edge/node masks to the raw fields.
    graph = samegraph(*args, **kwargs)
    return PADGRAPH(graph, max_edges)

def samegraph(*args, L=None, atoms=None, **kwargs):
    """Wrap raw graph fields in an unpadded ``GraphsTuple`` with full masks.

    ``L`` and ``atoms`` are accepted and dropped. Every edge (one per entry of
    ``senders``) and every node (``sum(n_node)`` of them) is marked valid.
    """
    edge_mask = jnp.ones(kwargs["senders"].shape, dtype=bool)
    node_mask = jnp.ones(jnp.sum(kwargs["n_node"]), dtype=bool)
    return GraphsTuple(*args, e_mask=edge_mask, n_mask=node_mask, **kwargs)


In [11]:
graphs, _ = loadfile("graphs_dicts.pkl", tag="graphs/")

# Largest edge count over all frames; used below as the shared edge budget
# when padding each graph.
max_edges = max(len(g["senders"]) for g in graphs)
print(max_edges)

3557


In [10]:
first_graph = graphs[0]
first_graph["nodes"]["position"].shape

(201, 3)

In [8]:
# Per-type mass lookup (presumably indexed by node "type": O=16, Si=28 -- confirm).
mass = jnp.array([16.0, 28.0])

# Every 10th frame, masked and padded to max_edges + 1 edges.
# NOTE(review): `mass` is not passed to mkgraph, so no "mass" node field is
# added; pass `mass=mass` if that was intended (changes the saved data).
GRAPHS = [
    mkgraph(**g, max_edges=max_edges, globals=None)
    for g in tqdm(graphs[::10])
]

  lax_internal._check_user_dtype_supported(dtype, "zeros")
100%|███████████████████████████████████████████████████████████████████████| 868/868 [00:11<00:00, 75.17it/s]


In [9]:
# Persist the padded graphs (every 10th frame) as "mkgraphs_dicts.pkl" under
# the "graphs/" tag; presumably resolved relative to `dataloc` by util.fileloc -- confirm.
savefile("mkgraphs_dicts.pkl", GRAPHS, tag="graphs/")