In [1]:
import os

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from clu import metric_writers

import optax

from molnet.models import create_model
from molnet.data import input_pipeline
from molnet import train_state, loss, hooks

from configs import test



In [2]:
# Directory handed to the metric writer (and, further down, to the prediction hook).
workdir = './test'
writer = metric_writers.create_default_writer(workdir)

config = test.get_config()
# NOTE(review): machine-specific absolute path; adjust when running elsewhere.
config.root_dir = '/l/data/molnet/atom_maps/'

# JAX PRNG keys should never be consumed twice. Derive independent keys for the
# data pipeline and for parameter initialisation instead of passing the same
# `rng` to both.
rng = jax.random.PRNGKey(0)
rng, data_rng, init_rng = jax.random.split(rng, 3)
datasets = input_pipeline.get_datasets(data_rng, config)
train_iter = datasets['train']

model = create_model(config)
# One training batch is pulled only to provide an example input for `init`;
# note that this advances `train_iter` by one batch.
init_batch = next(train_iter)
variables = model.init(init_rng, init_batch['images'], training=True)
params = variables['params']
batch_stats = variables['batch_stats']

# The initial parameters double as the "best" parameters until something
# better is recorded.
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    batch_stats=batch_stats,
    tx=optax.adamw(1e-3),
    best_params=params,
    metrics_for_best_params={},
    step_for_best_params=0,
)

2024-11-05 13:59:52.703336: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1960] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:1: Invalid control characters encountered in text.
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:3: Expected identifier, got: 12126270835042841805
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSna

In [3]:
@jax.jit
def predict_step(state, batch):
    """Run the model in inference mode on one batch and score it per image.

    Args:
        state: object exposing `apply_fn`, `params` and `batch_stats`.
        batch: mapping with an 'images' entry (fed to the model) and an
            'atom_map' entry (the regression target).

    Returns:
        Tuple `(inputs, target, preds, loss_by_image)`. `target` is the
        'atom_map' entry cropped to its last `preds.shape[-2]` entries along
        the second-to-last axis, and `loss_by_image` is the squared error
        averaged over axes 1-4, i.e. one value per entry of the leading axis.
    """
    images = batch['images']
    atom_map = batch['atom_map']

    variables = {'params': state.params, 'batch_stats': state.batch_stats}
    preds = state.apply_fn(variables, images, training=False)

    # Keep only as many trailing target slices as the model actually predicts.
    n_pred_slices = preds.shape[-2]
    target = atom_map[..., -n_pred_slices:, :]

    squared_error = (preds - target) ** 2
    loss_by_image = jnp.mean(squared_error, axis=(1, 2, 3, 4))
    return images, target, preds, loss_by_image

def predict_with_state(state, dataset, num_batches=1):
    """Run `predict_step` on several batches and stack the results.

    Args:
        state: forwarded unchanged to `predict_step`.
        dataset: iterator from which batches are drawn with `next`.
        num_batches: how many batches to pull from `dataset`.

    Returns:
        Tuple `(inputs, targets, preds, losses)`, each the concatenation
        (along the leading axis) of the corresponding per-batch outputs.
    """
    # One bucket per output of `predict_step`, kept in its return order.
    collected = {'inputs': [], 'targets': [], 'preds': [], 'losses': []}

    for _ in range(num_batches):
        step_outputs = predict_step(state, next(dataset))
        for bucket, value in zip(collected.values(), step_outputs):
            bucket.append(value)

    return tuple(jnp.concatenate(chunks) for chunks in collected.values())

In [4]:
# Evaluate two batches from the validation iterator and report array shapes.
inputs, targets, preds, losses = predict_with_state(
    state, datasets['val'], num_batches=2
)

print(
    f"inputs: {inputs.shape}",
    f"targets: {targets.shape}",
    f"preds: {preds.shape}",
    sep="\n",
)

[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:1: Invalid control characters encountered in text.
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:3: Expected identifier, got: 9108667528181848758
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:1: Invalid control characters encountered in text.
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:3: Expected identifier, got: 9108667528181848758
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing te

inputs: (8, 128, 128, 10, 1)
targets: (8, 128, 128, 10, 5)
preds: (8, 128, 128, 10, 5)


In [29]:
# Directory that receives the figures written by the plotting cells.
outdir = './test/'
os.makedirs(name=outdir, exist_ok=True)


In [21]:
# Render one multi-panel figure per sample: input slices, per-species
# prediction/target slices, and prediction/target summed over species.
titles = ['H', 'C', 'N', 'O', 'F']

n_samples = inputs.shape[0]
# Grid dimensions are read from the prediction array instead of being hard-coded.
n_slices, n_species = preds.shape[-2:]
# Number of input slices shown in the input column.
# NOTE(review): only the first `n_input_rows` slices of the input stack are
# drawn even though the column is labelled Far -> Close; confirm whether
# slices spanning the whole stack were intended.
n_input_rows = 5

column_titles = [
    'Input',
    'Prediction',
    'Target',
    'Prediction (sum over species)',
    'Target (sum over species)',
]

for sample in range(n_samples):
    inp = inputs[sample]
    target = targets[sample]
    pred = preds[sample]
    # Deliberately not called `loss`: that name is the module imported from
    # molnet at the top of the notebook and must not be overwritten here.
    sample_loss = float(losses[sample])

    fig = plt.figure(figsize=(18, 10), layout='constrained')
    subfigs = fig.subfigures(1, 5, wspace=0.07, width_ratios=[1, 2, 2, 1, 1])

    fig.suptitle(f'mse: {sample_loss:.4f}', fontsize=16)
    for subfig, column_title in zip(subfigs, column_titles):
        subfig.suptitle(column_title)

    axs_input = subfigs[0].subplots(n_input_rows, 1)
    axs_pred = subfigs[1].subplots(n_slices, n_species)
    axs_target = subfigs[2].subplots(n_slices, n_species)
    axs_pred_sum = subfigs[3].subplots(n_slices, 1)
    axs_target_sum = subfigs[4].subplots(n_slices, 1)

    # Per-species grids: rows are slices, columns are species.
    for i in range(n_slices):
        for j in range(n_species):
            for axs, volume in ((axs_pred, pred), (axs_target, target)):
                ax = axs[i, j]
                ax.imshow(volume[..., i, j], cmap='gray')
                ax.set_xticks([])
                ax.set_yticks([])

    # Each input panel is drawn exactly once.
    for k in range(n_input_rows):
        axs_input[k].imshow(inp[..., k, 0], cmap='gray')

    # Species-summed columns.
    for i in range(n_slices):
        ps = axs_pred_sum[i].imshow(jnp.sum(pred[..., i, :], axis=-1), cmap='gray')
        ts = axs_target_sum[i].imshow(jnp.sum(target[..., i, :], axis=-1), cmap='gray')

        for ax in (axs_pred_sum[i], axs_target_sum[i]):
            ax.set_aspect('equal')
            ax.set_xticks([])
            ax.set_yticks([])

    # Far/Close labels on the first and last row of every column (set once).
    for axs in (axs_input, axs_pred_sum, axs_target_sum):
        axs[0].set_ylabel('Far')
        axs[-1].set_ylabel('Close')
    for axs in (axs_pred, axs_target):
        axs[0, 0].set_ylabel('Far')
        axs[-1, 0].set_ylabel('Close')

    # `ps` / `ts` are the images of the last slice, so each colorbar reflects
    # that slice's colour scale (every imshow above is autoscaled on its own).
    subfigs[3].colorbar(ps, ax=axs_pred_sum, location='right', shrink=0.5)
    subfigs[4].colorbar(ts, ax=axs_target_sum, location='right', shrink=0.5)

    for j, title in zip(range(n_species), titles):
        axs_pred[0, j].set_title(title)
        axs_target[0, j].set_title(title)

    # Join with os.path so a trailing slash in `outdir` does not yield '//'.
    fig.savefig(os.path.join(outdir, f'sample_{sample}.png'))
    plt.close(fig)


In [9]:
# Save a compact three-panel overview per sample: the last input slice
# (channel 0) next to the prediction and target summed over their last two axes.
n_samples = inputs.shape[0]

for sample in range(n_samples):
    inp = inputs[sample, ..., -1, 0]
    pred = preds[sample].sum(axis=(-1, -2))
    target = targets[sample].sum(axis=(-1, -2))

    fig, axs = plt.subplots(1, 3, figsize=(18, 6))

    panels = ((inp, 'Input'), (pred, 'Prediction'), (target, 'Target'))
    for ax, (image, panel_title) in zip(axs, panels):
        ax.imshow(image, cmap='gray')
        ax.set_title(panel_title)
        ax.set_xticks([])
        ax.set_yticks([])

    plt.savefig(f'{outdir}/total_{sample:02}.png')
    plt.close()


In [4]:
# Prediction hook whose predict function evaluates two batches drawn from the
# validation iterator.
hook = hooks.PredictionHook(
    workdir=workdir,
    predict_fn=lambda state: predict_with_state(state, datasets['val'], num_batches=2),
    writer=writer,
)

In [5]:
# Invoke the prediction hook once on the current `state`.
hook(state)

[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:1: Invalid control characters encountered in text.
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:3: Expected identifier, got: 9108667528181848758
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:1: Invalid control characters encountered in text.
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:3: Expected identifier, got: 9108667528181848758
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing te

(8, 128, 128) (8, 128, 128) (8, 128, 128)
