In [1]:
import os
from typing import Any, Dict, Iterable, Tuple, Optional, Union

from absl import logging
from clu import checkpoint
from clu import metric_writers
from clu import metrics
from clu import parameter_overview
from clu import periodic_actions
import e3nn_jax as e3nn
import flax
import flax.core
import flax.linen as nn
from flax.training import train_state
import jax
import jax.numpy as jnp
import jraph
import ml_collections
import numpy as np
import optax
import random
import tensorflow as tf

import datatypes
import input_pipeline
import models

from configs import graphmlp, graphnet
from input_pipeline import get_datasets, ase_atoms_to_jraph_graph, generative_sequence
from train import *
from qm9 import load_qm9

2023-02-15 02:04:24.491311: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-02-15 02:04:26.135717: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-02-15 02:04:26.135882: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory


In [4]:
## This takes a very long time: it iterates over every molecule returned by load_qm9.

import pickle

key = jax.random.PRNGKey(0)
qm9_data = load_qm9("qm9_data")
subgraphs = []
# Collect graphs of partially-assembled molecules.
# NOTE(review): the same `key` is passed to generative_sequence for every
# molecule -- confirm whether a fresh key per molecule is intended.
ct = 0  # number of molecules processed so far
for mol in qm9_data:
    # 3.5 is presumably the neighbor cutoff (it matches "cutoff=3.5" in the
    # output filename below) -- TODO confirm against ase_atoms_to_jraph_graph.
    mol_graph = ase_atoms_to_jraph_graph(mol, 3.5)
    for subgraph in generative_sequence(key, mol_graph):
        subgraphs.append(subgraph)
    ct += 1
# NOTE(review): Python's `random` module is not seeded in this cell, so the
# shuffled order differs between runs even though the JAX key is fixed.
random.shuffle(subgraphs)

# pickle writes bytes, so the file has to be opened in binary mode ("wb");
# text mode ("w") makes pickle.dump raise TypeError. The `with` block also
# guarantees the file is flushed and closed once the dump finishes.
with open("qm9_subgraphs_seed=0_cutoff=3.5.p", "wb") as pickle_file:
    pickle.dump(subgraphs, pickle_file)

In [2]:
# Small-scale variant: build the subgraph pool from the first 100 molecules only.
key = jax.random.PRNGKey(0)
qm9_data = load_qm9("qm9_data")

# Gather the graphs of partially-assembled molecules; `ct` ends up holding
# the number of molecules that were processed.
subgraphs = []
ct = 0
for ct, mol in enumerate(qm9_data[:100], start=1):
    mol_graph = ase_atoms_to_jraph_graph(mol, 3.5)
    subgraphs.extend(generative_sequence(key, mol_graph))
random.shuffle(subgraphs)



In [3]:
# Per-batch budgets handed to jraph.dynamically_batch: node count, edge count
# and graph count.
batch_limits = {
    "n_node": 100 * 64,
    "n_edge": 100 * 64,
    "n_graph": 64,
}
gen = jraph.dynamically_batch(subgraphs, **batch_limits)

In [13]:
# Materialize every batch so they can be counted.
# NOTE(review): if `gen` is a one-shot generator this drains it, so it has to
# be rebuilt before being used again.
gen_list = [*gen]
len(gen_list)

16

In [4]:
# Load the configuration object provided by the `graphmlp` module in `configs`;
# it is passed to train_and_evaluate as the hyperparameter config.
config = graphmlp.get_config()

In [7]:
def train_and_evaluate(
    config: ml_collections.ConfigDict,
    workdir: str,
    train_data: Iterable[Any],
) -> train_state.TrainState:
    """Execute the model training loop.

    Args:
      config: Hyperparameter configuration for training and evaluation.
      workdir: Directory where the TensorBoard summaries, profiles and
        checkpoints are written to.
      train_data: Iterable of batched graphs (for example the generator
        returned by `jraph.dynamically_batch`, or a list of such batches).
        The first batch is consumed to initialize the network parameters and
        is not used for a training step; each following batch feeds exactly
        one training step.

    Returns:
      The train state (which includes the `.params`). If `train_data` runs out
      before `config.num_train_steps` is reached, training stops early (with a
      logged warning) and the state reached so far is returned.
    """
    # We only support single-host training.
    assert jax.process_count() == 1

    # Create writer for logs.
    writer = metric_writers.create_default_writer(workdir)
    writer.write_hparams(config.to_dict())

    # Wrap the data in an iterator once, so that generators and plain
    # sequences (e.g. a list of batches) are handled the same way below.
    logging.info("Obtaining datasets.")
    train_iter = iter(train_data)

    # Create and initialize the network.
    logging.info("Initializing network.")
    rng = jax.random.PRNGKey(0)
    rng, init_rng = jax.random.split(rng)
    # Pull the initialization batch from `train_iter`, not from `train_data`:
    # calling next() on a non-iterator iterable such as a list raises TypeError.
    init_graphs = next(train_iter)
    init_graphs = replace_globals(init_graphs)
    init_net = create_model(config, deterministic=True)
    params = jax.jit(init_net.init)(init_rng, init_graphs)
    parameter_overview.log_parameter_overview(params)

    # Create the optimizer.
    tx = create_optimizer(config)

    # Create the training state.
    net = create_model(config, deterministic=False)
    state = train_state.TrainState.create(apply_fn=net.apply, params=params, tx=tx)

    # Set up checkpointing of the model.
    checkpoint_dir = os.path.join(workdir, "checkpoints")
    ckpt = checkpoint.Checkpoint(checkpoint_dir, max_to_keep=2)
    state = ckpt.restore_or_initialize(state)
    # Resume from the step after the restored (or freshly created) state.
    initial_step = int(state.step) + 1

    # Create the evaluation state, corresponding to a deterministic model.
    # NOTE(review): `eval_state` is built but never used in this function.
    eval_net = create_model(config, deterministic=True)
    eval_state = state.replace(apply_fn=eval_net.apply)

    # Hooks called periodically during training.
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_train_steps, writer=writer
    )
    profiler = periodic_actions.Profile(num_profile_steps=5, logdir=workdir)
    hooks = [report_progress, profiler]

    # Begin training loop.
    logging.info("Starting training.")
    # Running aggregate of the per-step metrics; merged after every step.
    train_metrics = None
    for step in range(initial_step, config.num_train_steps + 1):
        # Split PRNG key, to ensure different 'randomness' for every step.
        rng, dropout_rng = jax.random.split(rng)

        # Perform one step of training.
        with jax.profiler.StepTraceAnnotation("train", step_num=step):
            try:
                batch = next(train_iter)
            except StopIteration:
                # A finite data source ran out before the configured number
                # of steps: stop cleanly instead of leaking StopIteration.
                logging.warning(
                    "Training data exhausted before step %d of %d; stopping early.",
                    step,
                    config.num_train_steps,
                )
                break
            graphs = jax.tree_util.tree_map(jnp.asarray, batch)
            state, metrics_update = train_step(
                state,
                graphs,
                rngs={"dropout": dropout_rng},
                loss_kwargs=config.loss_kwargs.to_dict(),
            )

            # Update metrics.
            if train_metrics is None:
                train_metrics = metrics_update
            else:
                train_metrics = train_metrics.merge(metrics_update)

        # Quick indication that training is happening.
        logging.log_first_n(logging.INFO, "Finished training step %d.", 10, step)
        for hook in hooks:
            hook(step)

    return state

In [12]:
# Build a fresh batch generator from `subgraphs` and hand it to the training
# loop; outputs are written under the "graphmlp" working directory.
gen = jraph.dynamically_batch(subgraphs, n_node=100 * 64, n_edge=100 * 64, n_graph=64)
train_and_evaluate(config, workdir="graphmlp", train_data=gen)

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[6400,3])>with<DynamicJaxprTrace(level=0/1)>
The error occurred while tracing the function init at /home/songk/anaconda3/envs/sh-net/lib/python3.10/site-packages/flax/linen/module.py:1346 for jit. This concrete value was not available in Python because it depends on the value of the argument 'args'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError