In [1]:
import functools

import jax
import jax.numpy as jnp

import flax
import flax.linen as nn
import optax
from clu import metrics

import matplotlib.pyplot as plt

from molnet.data import input_pipeline_online
from molnet.models import create_model
from molnet import train_state
from molnet import loss

from configs.tests import attention_test
from configs import root_dirs

from typing import Any, Dict, Tuple, Callable



In [2]:
# Load the experiment configuration defined in configs/tests/attention_test.
config = attention_test.get_config()
# Override the config's root directory with the one reported by configs.root_dirs.
config.root_dir = root_dirs.get_root_dir()

@flax.struct.dataclass
class Metrics(metrics.Collection):
    """Metric collection used by the training step; holds a single "loss" metric."""
    # Averaging metric built from the model-output entry named "loss".
    loss: metrics.Average.from_output("loss") # type: ignore

In [3]:
# --- Data -------------------------------------------------------------------
# Only the training split is used in this notebook.
ds = input_pipeline_online.get_datasets(config)['train']

# --- Randomness -------------------------------------------------------------
# Fixed seed; one half of the split key initialises the model, the other is
# kept in `rng`.
init_rng, rng = jax.random.split(jax.random.PRNGKey(0))

# --- Model ------------------------------------------------------------------
model = create_model(config.model)

# Placeholder input used only to initialise the model variables.
# NOTE(review): int() truncates, and float division can land just below an
# integer (e.g. 0.3 / 0.1 == 2.9999999999999996 -> 2). Confirm this matches
# how the data pipeline derives the z dimension for the configured z_cutoff.
dummy_shape = (1, 128, 128, int(config.z_cutoff / 0.1), 1)
dummy_input = jnp.empty(dummy_shape)

variables = model.init(init_rng, dummy_input, training=False)
params, batch_stats = variables["params"], variables["batch_stats"]

# --- Objective and optimiser --------------------------------------------------
loss_fn = loss.get_loss_function(config.loss_fn)
tx = optax.adam(learning_rate=3e-4)

2024-12-04 16:47:31.386644: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1960] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:1: Invalid control characters encountered in text.
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:3: Expected identifier, got: 1479210051963560085


In [4]:
# Bundle the model's apply function, its initial variables and the optimiser
# into the project's TrainState.
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    batch_stats=batch_stats,
    tx=tx,
    # The "best params" bookkeeping starts from the freshly initialised values.
    # NOTE(review): presumably these fields track the best checkpoint seen
    # during training — confirm against molnet.train_state.
    best_params=params,
    step_for_best_params=0,
    metrics_for_best_params={},
)

In [5]:
@functools.partial(jax.jit, static_argnums=(2,))
def train_step(
    state: train_state.TrainState,
    batch: Dict[str, Any],
    loss_fn: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]
) -> Tuple[train_state.TrainState, metrics.Collection]:
    """Perform one jitted optimisation step on a single batch.

    Args:
        state: Current train state; supplies ``apply_fn``, ``params`` and
            ``batch_stats``.
        batch: Mapping that must contain ``"images"`` (model input) and
            ``"atom_map"`` (target passed to the loss).
        loss_fn: Callable ``loss_fn(preds, target)`` whose result is
            differentiated. It is a static jit argument, so it must be
            hashable and a different callable triggers recompilation.

    Returns:
        A tuple ``(new_state, batch_metrics)``: the state after applying the
        gradients and the refreshed batch statistics, and a ``Metrics``
        collection built from this batch's loss.
    """

    def objective(params):
        # Forward pass in training mode; 'batch_stats' is marked mutable so
        # the updated statistics are returned alongside the predictions.
        model_vars = {'params': params, 'batch_stats': state.batch_stats}
        predictions, mutated = state.apply_fn(
            model_vars,
            batch["images"],
            training=True,
            mutable='batch_stats',
        )
        # The mutated collections ride along as the auxiliary output.
        return loss_fn(predictions, batch["atom_map"]), mutated

    # Loss value, auxiliary output and gradients w.r.t. the parameters.
    differentiate = jax.value_and_grad(objective, has_aux=True)
    (loss_value, mutated), grads = differentiate(state.params)

    # Optimiser update together with the refreshed batch statistics.
    updated_state = state.apply_gradients(
        grads=grads,
        batch_stats=mutated["batch_stats"],
    )

    step_metrics = Metrics.single_from_model_output(loss=loss_value)

    return updated_state, step_metrics


In [6]:
# Smoke-test the training step: run a few optimisation steps and print the
# per-batch loss metric after each one.
# The result is bound to `batch_metrics` rather than `metrics` so that the
# imported `clu.metrics` module is not shadowed; shadowing it would break
# re-running the cells that reference `metrics.Collection` / `metrics.Average`.
for _ in range(10):
    batch = next(ds)
    state, batch_metrics = train_step(state, batch, loss_fn)
    print(batch_metrics.loss)


[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:1: Invalid control characters encountered in text.
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:3: Expected identifier, got: 8757321662871330594
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:1: Invalid control characters encountered in text.
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:3: Expected identifier, got: 12317595142917291262
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing t

Metric.from_output.<locals>.FromOutput(total=Array(0.03872421, dtype=float32), count=Array(1, dtype=int32))
Metric.from_output.<locals>.FromOutput(total=Array(0.02137289, dtype=float32), count=Array(1, dtype=int32))
Metric.from_output.<locals>.FromOutput(total=Array(0.01785921, dtype=float32), count=Array(1, dtype=int32))
Metric.from_output.<locals>.FromOutput(total=Array(0.01249757, dtype=float32), count=Array(1, dtype=int32))
Metric.from_output.<locals>.FromOutput(total=Array(0.00908747, dtype=float32), count=Array(1, dtype=int32))
Metric.from_output.<locals>.FromOutput(total=Array(0.00700767, dtype=float32), count=Array(1, dtype=int32))
Metric.from_output.<locals>.FromOutput(total=Array(0.00600772, dtype=float32), count=Array(1, dtype=int32))
Metric.from_output.<locals>.FromOutput(total=Array(0.00491124, dtype=float32), count=Array(1, dtype=int32))
Metric.from_output.<locals>.FromOutput(total=Array(0.00467337, dtype=float32), count=Array(1, dtype=int32))
Metric.from_output.<locals>.