In [1]:
import os
import pickle
import yaml
import functools

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

import ml_collections
from clu import checkpoint


from molnet import utils, train_state, train
from molnet.data import input_pipeline
from molnet.models import create_model
from configs import root_dirs
from analyses import make_predictions

from typing import Tuple



In [2]:
# Run directory; later cells read config.yaml and checkpoints/ from it.
# Set the MOLNET_WORKDIR environment variable to analyse a different run
# without editing the notebook. The fallback is the original hard-coded run.
workdir = os.environ.get(
    "MOLNET_WORKDIR",
    "/u/79/kurkil1/unix/work/molnet/runs/attention-adam-3e-4-relu-z10/",
)

In [None]:
def load_from_workdir(
    workdir: str,
    return_attention: bool
):
    """Rebuild a run's model and evaluation state from its working directory.

    Reads ``config.yaml`` from ``workdir``, creates the model from
    ``config.model`` and restores the checkpoint found under
    ``workdir/checkpoints``.

    Args:
        workdir: Directory holding ``config.yaml`` and a ``checkpoints`` folder.
        return_attention: Value written to
            ``config.model.return_attention_maps`` before the model is created.

    Returns:
        A ``(state, config)`` tuple. ``state`` is the ``EvaluationState`` built
        from the restored ``params`` and ``batch_stats`` with every leaf passed
        through ``jnp.asarray``; ``config`` is the ``ConfigDict`` with
        ``root_dir`` and ``model.return_attention_maps`` overridden.
    """
    # NOTE(review): yaml.unsafe_load can construct arbitrary Python objects,
    # so only point this at run directories you trust.
    config_path = os.path.join(workdir, "config.yaml")
    with open(config_path, "rt") as config_file:
        raw_config = yaml.unsafe_load(config_file)

    config = ml_collections.ConfigDict(raw_config)
    config.root_dir = root_dirs.get_root_dir()
    config.model.return_attention_maps = return_attention

    # Build the network and open the run's checkpoint directory.
    model = create_model(config.model)
    ckpt = checkpoint.Checkpoint(os.path.join(workdir, "checkpoints"))

    optimizer = utils.create_optimizer(config)
    restored = ckpt.restore(state=None)['state']

    # Assemble the evaluation state from the restored variables.
    eval_state = train_state.EvaluationState.create(
        apply_fn=model.apply,
        params=restored['params'],
        batch_stats=restored['batch_stats'],
        tx=optimizer,
    )
    return jax.tree_util.tree_map(jnp.asarray, eval_state), config

In [None]:
# Restore the run with attention-map output switched on in the model config.
state, config = load_from_workdir(workdir=workdir, return_attention=True)

2024-11-11 10:49:58.895724: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1960] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


In [None]:
rng = jax.random.PRNGKey(0)
# Split off a dedicated key for the data pipeline and keep `rng` as the
# carry-over key, so no key is consumed twice.
datarng, rng = jax.random.split(rng)

# NOTE(review): presumably (start, end) molecule index ranges for each split;
# confirm against input_pipeline.get_datasets.
config.train_molecules = (0, 1000)
config.val_molecules = (1000, 2000)

# Hand the data pipeline its own key (`datarng`), not the carry-over `rng`.
ds = input_pipeline.get_datasets(
    datarng,
    config
)['val']

# Pull a single validation batch for the cells below.
batch = next(ds)

[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:1: Invalid control characters encountered in text.
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:3: Expected identifier, got: 336318792119298919
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:1: Invalid control characters encountered in text.
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:3: Expected identifier, got: 15664034803102229683
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing te

In [None]:
@jax.jit
def predict(
    state,
    batch,
):
    """Run the model in inference mode on one batch and score each image.

    Args:
        state: Object exposing ``apply_fn``, ``params`` and ``batch_stats``.
        batch: Mapping with an ``'images'`` entry (fed to the model) and an
            ``'atom_map'`` entry (the reference the predictions are scored
            against).

    Returns:
        ``(inputs, target, preds, attention, loss_by_image)``. ``target`` is
        ``batch['atom_map']`` cropped along its second-to-last axis to the last
        ``preds.shape[-2]`` entries; ``loss_by_image`` is the squared error
        averaged over axes 1 through 4.
    """
    images = batch['images']
    atom_map = batch['atom_map']

    variables = {'params': state.params, 'batch_stats': state.batch_stats}
    preds, attention = state.apply_fn(variables, images, training=False)

    # Keep only the trailing preds.shape[-2] entries of the reference along
    # axis -2 so that it lines up with the predictions.
    depth = preds.shape[-2]
    target = atom_map[..., -depth:, :]

    squared_error = (preds - target) ** 2
    loss_by_image = jnp.mean(squared_error, axis=(1, 2, 3, 4))
    return images, target, preds, attention, loss_by_image


In [None]:
# Forward pass on the sampled batch; `targets` receives the cropped reference
# that predict returns, not the raw batch['atom_map'].
inputs, targets, preds, attention, loss_by_image = predict(state, batch)

2024-11-11 10:52:12.907427: E external/xla/xla/service/slow_operation_alarm.cc:65] Trying algorithm eng0{} for conv (f32[16,32,128,128,10]{4,3,2,1,0}, u8[0]{0}) custom-call(f32[16,32,128,128,10]{4,3,2,1,0}, f32[32,32,3,3,3]{4,3,2,1,0}, f32[32]{0}), window={size=3x3x3 pad=1_1x1_1x1_1}, dim_labels=bf012_oi012->bf012, custom_call_target="__cudnn$convBiasActivationForward", backend_config={"cudnn_conv_backend_config":{"activation_mode":"kRelu","conv_result_scale":1,"leakyrelu_alpha":0,"side_input_scale":0},"force_earliest_schedule":false,"operation_queue_id":"0","wait_on_operation_queues":[]} is taking a while...
2024-11-11 10:52:13.636665: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 1.729431737s
Trying algorithm eng0{} for conv (f32[16,32,128,128,10]{4,3,2,1,0}, u8[0]{0}) custom-call(f32[16,32,128,128,10]{4,3,2,1,0}, f32[32,32,3,3,3]{4,3,2,1,0}, f32[32]{0}), window={size=3x3x3 pad=1_1x1_1x1_1}, dim_labels=bf012_oi012->bf012, custom_call_target="__cudnn$convB

In [None]:
# Report the shape of every array returned by predict, one per line.
print(
    f"inputs: {inputs.shape}",
    f"targets: {targets.shape}",
    f"preds: {preds.shape}",
    f"loss_by_image: {loss_by_image.shape}",
    sep="\n",
)
# `attention` is iterable: print one shape line per attention map.
for att in attention:
    print(f"attention: {att.shape}")

inputs: (8, 128, 128, 10, 1)
targets: (8, 128, 128, 10, 5)
preds: (8, 128, 128, 10, 5)
loss_by_image: (8,)
attention: (8, 8, 8, 10, 1)
attention: (8, 16, 16, 10, 1)
attention: (8, 32, 32, 10, 1)
attention: (8, 64, 64, 10, 1)
attention: (8, 128, 128, 10, 1)


In [3]:
# Run make_predictions for one batch, using <workdir>/analysis as the output
# directory.
make_predictions.make_predictions(
    workdir=workdir, outputdir=os.path.join(workdir, "analysis"), num_batches=1
)

2024-11-11 10:57:19.056028: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1960] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:1: Invalid control characters encountered in text.
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:3: Expected identifier, got: 336318792119298919
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnaps