In [1]:
import functools
import os
import pickle
from typing import Any, Dict, Iterable, Iterator, Optional, Tuple, Union
import chex
import flax
import jax
import jax.numpy as jnp
import jraph
import ml_collections
import optax
import yaml
from absl import logging
import matplotlib.pyplot as plt
import sys
sys.path.append('../')


from clu import (
    checkpoint,
    metric_writers,
    metrics,
    parameter_overview,
    periodic_actions,
)
from flax.training import train_state

from symphony import datatypes, models, loss
from symphony.data import input_pipeline_tf
from configs.silica import allegro

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the Allegro configuration from `configs.silica`.
config = allegro.get_config()

# Directory that receives the logs, the saved config and the checkpoints.
# The default is the original author's location; set SYMPHONY_WORKDIR to run
# the notebook on another machine without editing this cell.
workdir = os.environ.get(
    "SYMPHONY_WORKDIR",
    "/data/NFS/potato/songk/spherical-harmonic-net/workdirs/silica-allegro",
)
dataset = "silica"

# We only support single-host training.
assert jax.process_count() == 1

# Helper for evaluation.
def evaluate_model_helper(
    eval_state: train_state.TrainState,
    step: int,
    rng: chex.PRNGKey,
    is_final_eval: bool,
) -> Dict[str, metrics.Collection]:
    """Evaluates `eval_state` on the evaluation splits and logs the results.

    Relies on the notebook-level names `report_progress`, `datasets`,
    `config` and `writer`, which must exist by the time this is called.
    NOTE(review): `evaluate_model` and `add_prefix_to_keys` are neither defined
    nor imported in the visible cells -- confirm where they come from.

    Args:
        eval_state: State forwarded to `evaluate_model`.
        step: Step under which the scalars are written.
        rng: PRNG key forwarded to `evaluate_model`.
        is_final_eval: Selects the `*_eval_final` splits instead of `*_eval`.

    Returns:
        The mapping returned by `evaluate_model`, in which the entry of every
        evaluated split has been replaced by the result of its `.compute()`.
    """
    # Final eval splits are usually different.
    splits = (
        ["train_eval_final", "val_eval_final", "test_eval_final"]
        if is_final_eval
        else ["train_eval", "val_eval", "test_eval"]
    )

    # Run the evaluation, timed under the "eval" label.
    with report_progress.timed("eval"):
        eval_metrics = evaluate_model(
            eval_state, datasets, splits, rng, config.loss_kwargs
        )

    # Reduce each split's metrics and write them with the split name as prefix.
    for split_name in splits:
        computed = eval_metrics[split_name].compute()
        eval_metrics[split_name] = computed
        writer.write_scalars(step, add_prefix_to_keys(computed, split_name))
    writer.flush()

    return eval_metrics

# Set up the metric writer for this run and record the hyperparameters.
writer = metric_writers.create_default_writer(workdir)
writer.write_hparams(config.to_dict())

# Store the config inside the work directory so the run can be reproduced.
config_path = os.path.join(workdir, "config.yml")
with open(config_path, mode="w") as config_file:
    yaml.dump(config, stream=config_file)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [3]:
# Build the datasets (a mapping keyed by split name) from a dedicated PRNG key.
logging.info("Obtaining datasets.")
# Seed the root key from the config, then split off the key used for the data.
rng, dataset_rng = jax.random.split(jax.random.PRNGKey(config.rng_seed))
datasets = input_pipeline_tf.get_datasets(dataset_rng, config)

Retrieving SummaryDoc documents: 100%|██████████| 344/344 [00:00<00:00, 2845839.40it/s]


In [5]:
# Inspect a single training example. Looping over the whole split prints every
# element in full, which floods the output and never finishes if the split
# repeats.
struct = next(iter(datasets["train"]))
print(struct)

GraphsTuple(nodes=FragmentsNodes(positions=<tf.Tensor: shape=(7680, 3), dtype=float64, numpy=
array([[ 4.335634  ,  0.        ,  2.50762626],
       [ 2.40706596, 10.96284522,  5.16731411],
       [-1.59176299,  6.90761463,  5.35004013],
       ...,
       [ 0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ]])>, species=<tf.Tensor: shape=(7680,), dtype=int64, numpy=array([0, 0, 1, ..., 0, 0, 0])>, focus_and_target_species_probs=<tf.Tensor: shape=(7680, 2), dtype=float64, numpy=
array([[0. , 0.1],
       [0. , 0.1],
       [0.1, 0. ],
       ...,
       [0. , 0. ],
       [0. , 0. ],
       [0. , 0. ]])>), edges=<tf.Tensor: shape=(46080,), dtype=float64, numpy=array([1., 1., 1., ..., 0., 0., 0.])>, receivers=<tf.Tensor: shape=(46080,), dtype=int64, numpy=array([  9,   9,   9, ..., 649, 649, 649])>, senders=<tf.Tensor: shape=(46080,), dtype=int64, numpy=array([ 11,   8,  13, ..., 649, 649, 649])>, globals=

In [None]:
# Create and initialize the network.
logging.info("Initializing network.")

# `next()` requires an iterator, but the dataset object is only iterable, so
# calling `next()` on it directly raises a TypeError. Request a NumPy iterator
# so the batches are NumPy arrays rather than tf.Tensors.
train_iter = datasets["train"].as_numpy_iterator()

# The first batch is consumed here and only used to initialise the parameters.
init_graphs = next(train_iter)
net = models.create_model(config, run_in_evaluation_mode=False)

rng, init_rng = jax.random.split(rng)
params = jax.jit(net.init)(init_rng, init_graphs)
parameter_overview.log_parameter_overview(params)

In [None]:
# Build the optimizer from the config.
# NOTE(review): `create_optimizer` is neither defined nor imported in the
# visible cells -- confirm where it comes from before running on a fresh kernel.
tx = create_optimizer(config)

# Bundle the jitted forward pass, the parameters and the optimizer into the
# training state.
train_apply_fn = jax.jit(net.apply)
state = train_state.TrainState.create(
    apply_fn=train_apply_fn,
    params=params,
    tx=tx,
)

# The evaluation state is a copy of the training state whose apply function
# comes from a separately created model.
eval_net = models.create_model(config, run_in_evaluation_mode=False)
eval_apply_fn = jax.jit(eval_net.apply)
eval_state = state.replace(apply_fn=eval_apply_fn)

In [None]:
# Checkpointing: resume from the latest checkpoint under `checkpoint_dir`, or
# start from the freshly created state. The best state seen during training is
# tracked alongside the current one.
checkpoint_dir = os.path.join(workdir, "checkpoints")
ckpt = checkpoint.Checkpoint(checkpoint_dir, max_to_keep=5)
initial_checkpoint_contents = {
    "state": state,
    "best_state": state,
    "step_for_best_state": 1.0,
    "metrics_for_best_state": None,
}
restored = ckpt.restore_or_initialize(initial_checkpoint_contents)
state, best_state = restored["state"], restored["best_state"]
step_for_best_state = restored["step_for_best_state"]
metrics_for_best_state = restored["metrics_for_best_state"]

# Without recorded metrics, any validation loss counts as an improvement.
min_val_loss = (
    float("inf")
    if metrics_for_best_state is None
    else metrics_for_best_state["val_eval"]["total_loss"]
)

# Training continues with the step after the one stored in the state.
initial_step = int(state.step) + 1

# Replicate the training, best and evaluation states across devices.
state, best_state, eval_state = (
    flax.jax_utils.replicate(unreplicated)
    for unreplicated in (state, best_state, eval_state)
)

# Hooks called periodically during training.
report_progress = periodic_actions.ReportProgress(
    num_train_steps=config.num_train_steps,
    writer=writer,
)
profile = periodic_actions.Profile(logdir=workdir, every_secs=10800)
hooks = [report_progress, profile]

# Begin training loop.
logging.info("Starting training.")
# NOTE(review): `Metrics` is neither defined nor imported in the visible cells
# -- confirm where it comes from.
train_metrics = flax.jax_utils.replicate(Metrics.empty())
train_metrics_empty = True

# Containers for per-step diagnostics; nothing in the visible cells fills them.
all_grad_norms = []
all_param_norms = []
all_params = []
all_focus_and_atom_type_losses = []
all_num_nodes = []
all_num_edges = []

# Main training loop: periodic logging, periodic evaluation + checkpointing,
# then one optimisation step per iteration.
# NOTE(review): `Metrics`, `add_prefix_to_keys`, `device_batch` and `train_step`
# are neither defined nor imported in the visible cells -- confirm where they
# come from before running on a fresh kernel.

# Bind `step` up front so the code after the loop still has a value when the
# range below is empty (e.g. the restored checkpoint is already at
# `config.num_train_steps`); otherwise that code raises a NameError.
step = initial_step - 1
for step in range(initial_step, config.num_train_steps + 1):
    # Log, if required. Logging and evaluation are forced on the first and
    # last step of this run.
    first_or_last_step = step in [initial_step, config.num_train_steps]
    if step % config.log_every_steps == 0 or first_or_last_step:
        if not train_metrics_empty:
            writer.write_scalars(
                step,
                add_prefix_to_keys(flax.jax_utils.unreplicate(train_metrics).compute(), "train"),
            )
        # Start a fresh accumulation window after every logging point.
        train_metrics = flax.jax_utils.replicate(Metrics.empty())
        train_metrics_empty = True

    # Evaluate on validation and test splits, if required.
    if step % config.eval_every_steps == 0 or first_or_last_step:
        # Evaluate the current training parameters.
        eval_state = eval_state.replace(params=state.params)
        rng, eval_rng = jax.random.split(rng)
        eval_metrics = evaluate_model_helper(
            eval_state,
            step,
            eval_rng,
            is_final_eval=False,
        )

        # Note best state seen so far.
        # Best state is defined as the state with the lowest validation loss.
        if eval_metrics["val_eval"]["total_loss"] < min_val_loss:
            min_val_loss = eval_metrics["val_eval"]["total_loss"]
            metrics_for_best_state = eval_metrics
            best_state = state
            step_for_best_state = step
            logging.info("New best state found at step %d.", step)

        # Save the current state and best state seen so far, both as pickled
        # parameters and as a full checkpoint.
        with open(os.path.join(checkpoint_dir, f"params_{step}.pkl"), "wb") as f:
            pickle.dump(flax.jax_utils.unreplicate(state.params), f)
        with open(os.path.join(checkpoint_dir, "params_best.pkl"), "wb") as f:
            pickle.dump(flax.jax_utils.unreplicate(best_state.params), f)
        ckpt.save(
            {
                "state": flax.jax_utils.unreplicate(state),
                "best_state": flax.jax_utils.unreplicate(best_state),
                "step_for_best_state": step_for_best_state,
                "metrics_for_best_state": metrics_for_best_state,
            }
        )

    # Get a batch of graphs; stop training early once the data runs out.
    try:
        graphs = next(device_batch(train_iter))

    except StopIteration:
        logging.info("No more training data. Continuing with final evaluation.")
        break

    # Perform one step of training.
    with jax.profiler.StepTraceAnnotation("train_step", step_num=step):
        # One PRNG key per local device.
        step_rng, rng = jax.random.split(rng)
        step_rngs = jax.random.split(step_rng, jax.local_device_count())
        state, batch_metrics = train_step(
            graphs,
            state,
            config.loss_kwargs,
            step_rngs,
            config.add_noise_to_positions,
            config.position_noise_std,
        )

        # Update metrics.
        train_metrics = train_metrics.merge(batch_metrics)
        train_metrics_empty = False

    # Quick indication that training is happening.
    logging.log_first_n(logging.INFO, "Finished training step %d.", 10, step)
    for hook in hooks:
        hook(step)

# Training is over: report which step produced the best state, then evaluate
# that state on the final evaluation splits.
logging.info(
    "Evaluating best state from step %d at the end of training.",
    step_for_best_state,
)
eval_state = eval_state.replace(params=best_state.params)

rng, eval_rng = jax.random.split(rng)
final_metrics_for_best_state = evaluate_model_helper(
    eval_state, step, eval_rng, is_final_eval=True
)

# Checkpoint the best state and the metrics recorded for it during training.
# The pickled parameters are kept for easy access during evaluation.
with report_progress.timed("checkpoint"):
    best_params_path = os.path.join(checkpoint_dir, "params_best.pkl")
    with open(best_params_path, "wb") as params_file:
        pickle.dump(flax.jax_utils.unreplicate(best_state.params), params_file)

    final_checkpoint = {
        "state": flax.jax_utils.unreplicate(state),
        "best_state": flax.jax_utils.unreplicate(best_state),
        "step_for_best_state": step_for_best_state,
        "metrics_for_best_state": metrics_for_best_state,
    }
    ckpt.save(final_checkpoint)