In [7]:
import tempfile
import os
import yaml

import jax
import jax.numpy as jnp

import flax
#from flax.training import train_state
import flax.linen as nn
import optax
from clu import metrics, metric_writers

from typing import Any, Tuple, Optional, Dict

from molnet import utils, train, hooks, train_state
from molnet.models import create_model
from molnet.data import input_pipeline
from configs import test

In [12]:
class MolnetTrainState(train_state.TrainState):
    """Train state used by the quick-experiment cells of this notebook.

    Subclasses `molnet.train_state.TrainState` (the flax `train_state` import is
    commented out in the import cell, so `train_state` here is molnet's module)
    and declares the two extra fields passed to `create` when this class is used.
    """
    # NOTE(review): other cells call `train_state.TrainState.create(...)` with
    # `batch_stats` and `train_metrics` directly, so the base class may already
    # declare these fields — confirm whether this subclass is still needed.
    batch_stats: Dict[str, jnp.ndarray]  # the "batch_stats" collection returned by model.init
    train_metrics: Any  # a train.Metrics collection, merged with each step's metrics

# Writer with no logdir; a later cell rebinds `writer` to one under a temp workdir.
writer = metric_writers.create_default_writer()

In [2]:
# Quick-look setup: load the test config and a training-data iterator.
config = test.get_config()
# Allow the data root to be overridden via MOLNET_ROOT_DIR so the notebook runs on
# machines that do not have the original absolute path; the default is unchanged.
config.root_dir = os.environ.get("MOLNET_ROOT_DIR", '/l/data/molnet/atom_maps/')

rng = jax.random.PRNGKey(0)
datarng, rng = jax.random.split(rng)

# NOTE: despite its name, `datasets` here is only the *train* split; a later cell
# rebinds `datasets` to the full mapping returned by `get_datasets`.
datasets = input_pipeline.get_datasets(datarng, config)['train']

2024-11-05 09:28:56.295503: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1960] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:1: Invalid control characters encountered in text.
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:3: Expected identifier, got: 12126270835042841805


In [13]:
# Build the model and initialise it on the images of one training batch.
model = create_model(config)
x_init = next(datasets)["images"]

variables = model.init(rng, x_init, training=True)
params = variables["params"]
batch_stats = variables["batch_stats"]

# Wrap params, optimizer and (empty) metric accumulator in the notebook's train state.
state = MolnetTrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optax.adamw(learning_rate=0.01),  # fixed learning rate for this quick experiment
    batch_stats=batch_stats,
    train_metrics=train.Metrics.empty(),
)

# is_empty=True presumably tells the hook nothing has been merged into train_metrics yet.
log_hook = hooks.LogTrainingMetricsHook(writer, is_empty=True)

In [14]:
# Smoke-test loop: run 100 training steps, accumulating each step's metrics in the state.
for _ in range(100):
    batch = next(datasets)

    # Bind the per-step metrics to `batch_metrics`, not `metrics`, so the imported
    # `clu.metrics` module is not shadowed for every cell that runs afterwards.
    state, batch_metrics = train.train_step(state, batch)

    state = state.replace(
        train_metrics=state.train_metrics.merge(batch_metrics)
    )

[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:1: Invalid control characters encountered in text.
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:3: Expected identifier, got: 12126270835042841805
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:1: Invalid control characters encountered in text.
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:3: Expected identifier, got: 12126270835042841805


[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:1: Invalid control characters encountered in text.
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:3: Expected identifier, got: 12126270835042841805


In [3]:
# Create a writer that logs under a fresh temporary workdir and record the hyperparameters.
workdir = tempfile.mkdtemp()

writer = metric_writers.create_default_writer(workdir)
writer.write_hparams(config.to_dict())

# Save the config to the workdir. Dump `config.to_dict()` (the same plain mapping
# passed to write_hparams above) rather than the config object itself, so the file
# holds a plain nested mapping instead of a serialized config object.
# NOTE(review): tuple values would still be emitted with a python/tuple tag.
config_path = os.path.join(workdir, "config.yaml")
with open(config_path, "w") as f:
    yaml.dump(config.to_dict(), f)


In [52]:

# Build the input pipelines from a key derived from the configured seed.
print("Loading datasets.")
rng, data_rng = jax.random.split(jax.random.PRNGKey(config.rng_seed))
datasets = input_pipeline.get_datasets(data_rng, config)
train_ds = datasets["train"]  # iterator of batch dicts; other splits stay in `datasets`


Loading datasets.


[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:1: Invalid control characters encountered in text.
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:3: Expected identifier, got: 12126270835042841805


In [53]:

# Create model
print("Creating model.")
# Initialise on a real batch of images from the training iterator.
x_init = next(train_ds)['images']
rng, init_rng = jax.random.split(rng)
model = create_model(config)

variables = model.init(init_rng, x_init, training=True)
# `variables` carries both the trainable params and the batch statistics.
params, batch_stats = variables["params"], variables["batch_stats"]


Creating model.


[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:1: Invalid control characters encountered in text.
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:3: Expected identifier, got: 12126270835042841805
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:1: Invalid control characters encountered in text.
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:3: Expected identifier, got: 12126270835042841805
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing 

In [54]:

# Optimizer comes from the config; the train state starts with an empty metric
# accumulator and with the "best params" bookkeeping pointing at the initial params.
tx = utils.create_optimizer(config)

state = train_state.TrainState.create(
    apply_fn=model.apply,
    tx=tx,
    params=params,
    batch_stats=batch_stats,
    train_metrics=train.Metrics.empty(),
    best_params=params,
    step_for_best_params=0,
    metrics_for_best_params={},
)
print(state.train_metrics)


Metrics(_reduction_counter=_ReductionCounter(value=Array(1, dtype=int32)), loss=Metric.from_output.<locals>.FromOutput(total=Array(0., dtype=float32), count=Array(0, dtype=int32)))


In [50]:
# Run a handful of steps, folding each step's metrics into state.train_metrics
# and showing the running accumulator after every step.
for step_idx in range(5):
    batch = next(train_ds)
    state, batch_metrics = train.train_step(state, batch)
    state = state.replace(train_metrics=state.train_metrics.merge(batch_metrics))
    print(state.train_metrics)

Metrics(_reduction_counter=_ReductionCounter(value=Array(9, dtype=int32)), loss=Metric.from_output.<locals>.FromOutput(total=Array(3.8863802, dtype=float32), count=Array(8, dtype=int32)))
Metrics(_reduction_counter=_ReductionCounter(value=Array(10, dtype=int32)), loss=Metric.from_output.<locals>.FromOutput(total=Array(4.306643, dtype=float32), count=Array(9, dtype=int32)))
Metrics(_reduction_counter=_ReductionCounter(value=Array(11, dtype=int32)), loss=Metric.from_output.<locals>.FromOutput(total=Array(4.6948643, dtype=float32), count=Array(10, dtype=int32)))
Metrics(_reduction_counter=_ReductionCounter(value=Array(12, dtype=int32)), loss=Metric.from_output.<locals>.FromOutput(total=Array(5.1066995, dtype=float32), count=Array(11, dtype=int32)))
Metrics(_reduction_counter=_ReductionCounter(value=Array(13, dtype=int32)), loss=Metric.from_output.<locals>.FromOutput(total=Array(5.5362916, dtype=float32), count=Array(12, dtype=int32)))


In [55]:

# Create hooks
print("Creating hooks.")
log_hook = hooks.LogTrainingMetricsHook(writer)

# Training loop
print("Starting training loop.")

# NOTE(review): periodic logging and metric accumulation are disabled in this cell.
# The disabled version called `log_hook(state)` whenever
# `step % config.log_every_steps == 0`, merged `batch_metrics` into
# `state.train_metrics` via `state.replace(...)`, and set `log_hook.is_empty = False`.
# Until that is restored, `log_hook` is unused and each step's raw metrics are printed.
for step in range(config.num_train_steps):
    batch = next(train_ds)
    state, batch_metrics = train.train_step(state, batch)
    print(batch_metrics)

Creating hooks.
Starting training loop.
Metrics(_reduction_counter=_ReductionCounter(value=Array(1, dtype=int32)), loss=Metric.from_output.<locals>.FromOutput(total=Array(1.4025555, dtype=float32), count=Array(1, dtype=int32)))
Metrics(_reduction_counter=_ReductionCounter(value=Array(1, dtype=int32)), loss=Metric.from_output.<locals>.FromOutput(total=Array(1.3746758, dtype=float32), count=Array(1, dtype=int32)))
Metrics(_reduction_counter=_ReductionCounter(value=Array(1, dtype=int32)), loss=Metric.from_output.<locals>.FromOutput(total=Array(1.3682373, dtype=float32), count=Array(1, dtype=int32)))
Metrics(_reduction_counter=_ReductionCounter(value=Array(1, dtype=int32)), loss=Metric.from_output.<locals>.FromOutput(total=Array(1.289912, dtype=float32), count=Array(1, dtype=int32)))
Metrics(_reduction_counter=_ReductionCounter(value=Array(1, dtype=int32)), loss=Metric.from_output.<locals>.FromOutput(total=Array(1.2165849, dtype=float32), count=Array(1, dtype=int32)))
Metrics(_reduction_co