In [2]:
import os
import tempfile
import functools
from typing import Any, Dict, Tuple

import jax
import jax.numpy as jnp

from absl import logging

import flax
import flax.jax_utils
import optax
import chex
from clu import metric_writers, parameter_overview

from molnet import train_state, hooks, train, utils, loss
from molnet.models import create_model

from molnet.data import input_pipeline
from configs import test



In [3]:
@functools.partial(jax.pmap, axis_name="device")
def eval_step(
    state: train_state.TrainState,
    images,
    atom_map,
):
    """Runs one evaluation forward pass per device and gathers the loss.

    The model is applied with ``training=False`` using the parameters and
    batch statistics held in ``state``. Before ``loss.mse`` is computed, the
    target is cropped to its last ``preds.shape[-2]`` entries along the
    second-to-last axis so that it lines up with the prediction.

    Args:
        state: Train state providing ``apply_fn``, ``params`` and ``batch_stats``.
        images: Model inputs for this device.
        atom_map: Targets for this device.

    Returns:
        The result of ``train.Metrics.gather_from_model_output`` over the
        ``"device"`` axis, built from the batch loss.
    """
    model_variables = {'params': state.params, 'batch_stats': state.batch_stats}
    predictions = state.apply_fn(model_variables, images, training=False)

    # Keep only the trailing target slices that the model actually predicts.
    num_pred_slices = predictions.shape[-2]
    targets = atom_map[..., -num_pred_slices:, :]

    return train.Metrics.gather_from_model_output(
        axis_name="device",
        loss=loss.mse(predictions, targets),
    )

# Averages every leaf of a device-replicated pytree across devices.
cross_replica_mean = jax.pmap(lambda x: jax.lax.pmean(x, 'x'), 'x')


def evaluate_model(
    state: train_state.TrainState,
    datasets,
    num_eval_steps: int,
):
    """Evaluate over all datasets.

    Args:
        state: Device-replicated train state (it is passed straight to the
            pmapped ``eval_step`` and ``cross_replica_mean``).
        datasets: Mapping from split name to an iterator of per-device batches;
            each batch must provide ``'images'`` and ``'atom_map'``.
        num_eval_steps: Number of device-stacked batches to evaluate per split.

    Returns:
        Dict mapping ``"<split>_eval"`` to the unreplicated, merged metrics of
        that split.

    Raises:
        StopIteration: If a split's iterator runs out before ``num_eval_steps``
            device-stacked batches have been produced.
    """
    # `replace` returns a new state, so the result has to be kept; otherwise the
    # averaged batch statistics are silently thrown away. The sync does not
    # depend on the split, so it is done once up front.
    state = state.replace(batch_stats=cross_replica_mean(state.batch_stats))

    eval_metrics = {}
    for split, data_iterator in datasets.items():
        split_metrics = flax.jax_utils.replicate(train.Metrics.empty())

        # Build the device-batching generator once per split instead of
        # creating (and discarding) a new one on every step.
        device_batches = device_batch(data_iterator)

        for _ in range(num_eval_steps):
            batch = next(device_batches)

            # Evaluate the model.
            batch_metrics = eval_step(state, batch['images'], batch['atom_map'])
            split_metrics = split_metrics.merge(batch_metrics)

        eval_metrics[split + "_eval"] = flax.jax_utils.unreplicate(split_metrics)

    return eval_metrics
          
@functools.partial(jax.pmap, axis_name="device")
def train_step(
    state: train_state.TrainState,
    images,
    atom_map,
    rng: chex.PRNGKey,
):
    """Performs one optimisation step on each device.

    The loss is ``loss.mse`` between the model output (applied with
    ``training=True`` and ``batch_stats`` mutable) and ``atom_map`` cropped to
    its last ``preds.shape[-2]`` entries along the second-to-last axis.
    Gradients are averaged over the ``"device"`` axis before being applied.

    Args:
        state: Train state providing ``apply_fn``, ``params`` and ``batch_stats``.
        images: Model inputs for this device.
        atom_map: Targets for this device.
        rng: PRNG key; not used inside this function.

    Returns:
        A ``(state, metrics)`` tuple: the state after applying the averaged
        gradients and storing the updated batch statistics, and the metrics
        gathered over the ``"device"`` axis from the batch loss.
    """

    def compute_loss(params):
        variables = {'params': params, 'batch_stats': state.batch_stats}
        predictions, mutated = state.apply_fn(
            variables,
            images,
            training=True,
            mutable='batch_stats',
        )
        # Crop the target to the trailing slices covered by the prediction.
        num_pred_slices = predictions.shape[-2]
        targets = atom_map[..., -num_pred_slices:, :]
        return loss.mse(predictions, targets), (predictions, mutated)

    # Loss value, auxiliary outputs and per-device gradients in one pass.
    value_and_grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
    (batch_loss, (_, mutated)), local_grads = value_and_grad_fn(state.params)

    # Every device applies the same (mean) gradient.
    synced_grads = jax.lax.pmean(local_grads, axis_name="device")
    new_state = state.apply_gradients(grads=synced_grads).replace(
        batch_stats=mutated["batch_stats"]
    )

    metrics = train.Metrics.gather_from_model_output(
        axis_name="device",
        loss=batch_loss,
    )
    return new_state, metrics


def device_batch(
    batch_iterator
):
    """Groups consecutive items into device-sized, stacked batches.

    Collects ``jax.local_device_count()`` consecutive items from
    ``batch_iterator`` and yields them as one pytree whose leaves are stacked
    along a new leading axis. Items left over at the end that do not fill a
    complete group are not yielded.

    Args:
        batch_iterator: Iterable of pytrees with identical structure.

    Yields:
        One pytree per complete group, with every leaf stacked on axis 0.
    """
    group_size = jax.local_device_count()
    pending = []
    for item in batch_iterator:
        pending.append(item)
        if len(pending) == group_size:
            yield jax.tree_util.tree_map(
                lambda *leaves: jnp.stack(leaves, axis=0), *pending
            )
            pending = []


In [5]:
config = test.get_config()
# Dataset location. Set MOLNET_ROOT_DIR to point at another copy of the data
# instead of editing this notebook; the default is the original path.
config.root_dir = os.environ.get('MOLNET_ROOT_DIR', '/l/data/molnet/atom_maps')
# A fresh temporary directory is created on every run of this cell.
config.workdir = tempfile.mkdtemp()

# Create writer for logs
writer = metric_writers.create_default_writer(config.workdir)
writer.write_hparams(config.to_dict())


2024-11-04 12:31:51.131422: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1960] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


In [6]:
# Seed the PRNG from the config and split off a key for the data pipeline.
rng, data_rng = jax.random.split(jax.random.PRNGKey(config.rng_seed))

# Build all dataset splits and keep a handle on the training split.
datasets = input_pipeline.get_datasets(data_rng, config)
train_ds = datasets["train"]


[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:1: Invalid control characters encountered in text.
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:3: Expected identifier, got: 12126270835042841805


In [30]:
# Create model
# One batch is taken from the training iterator to initialise the model variables.
x_init = next(train_ds)['images']
rng, init_rng = jax.random.split(rng)
model = create_model(config)

variables = model.init(init_rng, x_init, training=True)
params = variables["params"]
batch_stats = variables["batch_stats"]
parameter_overview.log_parameter_overview(params)

# Create optimizer
tx = utils.create_optimizer(config)

# Create training state; the freshly initialised parameters also serve as the
# initial "best" parameters.
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=tx,
    batch_stats=batch_stats,
    best_params=params,
    step_for_best_params=0,
    metrics_for_best_params={},
    train_metrics=train.Metrics.empty(),
)

[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:1: Invalid control characters encountered in text.
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:3: Expected identifier, got: 12126270835042841805
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:1: Invalid control characters encountered in text.
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:3: Expected identifier, got: 12126270835042841805
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing 

In [31]:
# Checkpoints live under the working directory; only one is kept.
checkpoint_path = os.path.join(config.workdir, "checkpoint")
checkpoint_hook = hooks.CheckpointHook(checkpoint_path, max_keep=1)

# Restore or initialise the state, and record the step it starts from.
state = checkpoint_hook.restore_or_init(state)
initial_step = state.get_step()

# Replicate states across devices
state = flax.jax_utils.replicate(state)

train_metrics_hook = hooks.LogTrainMetricsHook(writer)


def _evaluate_all_splits(state):
    """Evaluates ``state`` on every split for ``config.num_eval_steps`` steps."""
    return evaluate_model(state, datasets, config.num_eval_steps)


evaluate_model_hook = hooks.EvaluateModelHook(
    evaluate_model_fn=_evaluate_all_splits,
    writer=writer,
)

In [33]:
# Run the evaluation hook once; `state` at this point is the device-replicated state.
# NOTE(review): the saved output of this cell is a ValueError ("The truth value of an
# array with more than one element is ambiguous"). Presumably something in the hook
# treats a per-device (replicated) field as a scalar; confirm in hooks.EvaluateModelHook.
state = evaluate_model_hook(state)

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

In [None]:
# Pull one device-stacked batch from the training iterator for inspection.
batch = next(device_batch(train_ds))

In [None]:
# Show the name and array shape of every entry in the batch.
for name, array in batch.items():
    print(name, array.shape)