In [1]:
import os
import pickle
import yaml
import functools

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

import ml_collections
from clu import checkpoint


from molnet import utils, train_state, train
from molnet.data import input_pipeline_wds
from molnet.models import create_model
from configs import root_dirs
from analyses import make_predictions

from typing import Tuple

In [2]:
# Training run to analyse. Set the MOLNET_WORKDIR environment variable to point
# at another run (e.g. "/u/79/kurkil1/unix/work/molnet/runs/attention-adam-3e-4/"
# on the cluster) instead of editing this cell.
workdir = os.environ.get(
    "MOLNET_WORKDIR",
    "/Users/kurkil1/work/molnet/runs/attention-adam-3e-4/",
)

In [34]:
def update_old_config(config):
    """Translate a legacy run config into the model config ``create_model`` expects.

    Legacy runs describe the architecture with a single ``model.channels``
    list, a per-block ``model.kernel_size`` and ``model.attention_channels``.
    These values are read from ``config`` rather than hard-coded, so runs
    trained with a different architecture are rebuilt correctly. For the
    default legacy architecture (channels 16/32/64/128, kernel size 3,
    attention channels 16) the result matches the previous hard-coded config.

    Args:
        config: Legacy ``ml_collections.ConfigDict`` whose ``model`` section
            contains ``model_name``, ``output_channels``, ``channels`` and
            ``kernel_size`` (an int, or one int per block). It may also contain
            ``attention_channels`` (defaults to 16 per block) and
            ``return_attention_maps`` (defaults to True).

    Returns:
        A new ``ml_collections.ConfigDict`` with separate encoder/decoder
        settings.

    Raises:
        ValueError: if the number of kernel sizes or attention channels does
            not match the number of blocks in ``model.channels``.
    """
    legacy = config.model
    num_blocks = len(legacy.channels)

    encoder_channels = [int(c) for c in legacy.channels]

    kernel_size = legacy.kernel_size
    if isinstance(kernel_size, int):
        kernel_size = [kernel_size] * num_blocks
    if len(kernel_size) != num_blocks:
        raise ValueError(
            f"Expected {num_blocks} kernel sizes, got {len(kernel_size)}."
        )
    # Each legacy per-block kernel size is used along all three kernel axes
    # (the previous hard-coded value was [3, 3, 3] per block).
    encoder_kernel_size = [[int(k)] * 3 for k in kernel_size]

    attention_channels = legacy.get("attention_channels", None)
    if attention_channels is None:
        attention_channels = [16] * num_blocks
    attention_channels = [int(c) for c in attention_channels]
    if len(attention_channels) != num_blocks:
        raise ValueError(
            f"Expected {num_blocks} attention channels, "
            f"got {len(attention_channels)}."
        )

    model_config = ml_collections.ConfigDict()
    model_config.model_name = legacy.model_name
    model_config.output_channels = legacy.output_channels
    model_config.encoder_channels = encoder_channels
    # The decoder mirrors the encoder, going from the deepest block back up.
    model_config.decoder_channels = encoder_channels[::-1]
    model_config.encoder_kernel_size = encoder_kernel_size
    model_config.decoder_kernel_size = [
        list(k) for k in reversed(encoder_kernel_size)
    ]
    model_config.attention_channels = attention_channels
    # Activations were not recorded in legacy configs; these are the values
    # the legacy models were rebuilt with.
    model_config.conv_activation = "relu"
    model_config.attention_activation = "sigmoid"
    # Honour the caller's flag instead of always returning attention maps.
    model_config.return_attention_maps = bool(
        legacy.get("return_attention_maps", True)
    )

    return model_config

In [35]:
def load_from_workdir(
    workdir: str,
    return_attention: bool,
    verbose: bool = True,
):
    """Rebuild an evaluation state and the run config from a training directory.

    Reads ``<workdir>/config.yaml``, converts its legacy model section with
    ``update_old_config``, builds the model with ``create_model`` and restores
    the latest checkpoint found under ``<workdir>/checkpoints``.

    Args:
        workdir: Training run directory containing ``config.yaml`` and a
            ``checkpoints`` sub-directory.
        return_attention: Whether the rebuilt model should return attention
            maps alongside its predictions.
        verbose: If True (default, the previous behaviour), print the run config
            and the derived model config.

    Returns:
        ``(state, config)``: the restored ``EvaluationState`` (leaves converted
        to ``jnp`` arrays) and the run ``ConfigDict`` with ``root_dir`` set for
        this machine.
    """
    config_path = os.path.join(workdir, "config.yaml")
    # NOTE(review): the run config appears to contain ``!!python/tuple`` tags,
    # which yaml.safe_load rejects. unsafe_load can execute arbitrary code, so
    # only point this at trusted run directories.
    with open(config_path, "rt") as f:
        config = ml_collections.ConfigDict(yaml.unsafe_load(f))
    config.root_dir = root_dirs.get_root_dir()
    config.model.return_attention_maps = return_attention
    if verbose:
        print(config)

    model_config = update_old_config(config)
    # Set the flag on the model config explicitly so ``return_attention``
    # always reaches create_model, whatever update_old_config derives.
    model_config.return_attention_maps = return_attention
    if verbose:
        print(model_config)

    model = create_model(model_config)

    ckpt = checkpoint.Checkpoint(os.path.join(workdir, "checkpoints"))

    tx = utils.create_optimizer(config)
    restored_state = ckpt.restore(state=None)['state']

    state = train_state.EvaluationState.create(
        apply_fn=model.apply,
        params=restored_state['params'],
        batch_stats=restored_state['batch_stats'],
        tx=tx,
    )
    state = jax.tree_util.tree_map(jnp.asarray, state)

    return state, config

In [36]:
# Restore the trained model with attention maps enabled so they can be inspected below.
state, config = load_from_workdir(workdir, return_attention=True)

batch_size: 24
dataset: edafm
debug: false
eval_every_steps: 10000
learning_rate: 0.0003
learning_rate_schedule: constant
log_every_steps: 100
max_atoms: 54
model:
  attention_channels:
  - 16
  - 16
  - 16
  - 16
  channels:
  - 16
  - 32
  - 64
  - 128
  kernel_size:
  - 3
  - 3
  - 3
  - 3
  model_name: attention-unet
  output_channels: 5
  return_attention_maps: true
momentum: null
noise_std: 0.1
num_eval_steps: 1000
num_train_steps: 1000000
optimizer: adam
predict_every_steps: 10000
predict_num_batches: 2
predict_num_batches_at_end_of_training: 10
rng_seed: 0
root_dir: /Users/kurkil1/data/atom_maps
shuffle_datasets: true
train_molecules: !!python/tuple
- 0
- 230000
val_molecules: !!python/tuple
- 230000
- 264000

attention_activation: sigmoid
attention_channels:
- 16
- 16
- 16
- 16
conv_activation: relu
decoder_channels:
- 128
- 64
- 32
- 16
decoder_kernel_size:
- - 3
  - 3
  - 3
- - 3
  - 3
  - 3
- - 3
  - 3
  - 3
- - 3
  - 3
  - 3
encoder_channels:
- 16
- 32
- 64
- 128
encoder_k

In [37]:
# Draw one small batch from a tiny validation slice. The overrides are applied
# to a copy, so the run config returned by load_from_workdir keeps the run's
# real settings (batch size, molecule splits).
rng = jax.random.PRNGKey(0)
datarng, rng = jax.random.split(rng)  # NOTE(review): datarng is not used below
eval_config = config.copy_and_resolve_references()
eval_config.train_molecules = (0, 64)
eval_config.val_molecules = (64, 96)
eval_config.batch_size = 4
ds = input_pipeline_wds.get_datasets(
    eval_config
)['val']

batch = next(iter(ds))

for k, v in batch.items():
    print(k, v.shape)

images (4, 128, 128, 10, 1)
atom_map (4, 128, 128, 10, 5)
xyz (4, 54, 5)


In [38]:
@jax.jit
def predict(
    state,
    batch,
):
    """Run the model on one batch and compute a per-example mean squared error.

    Args:
        state: Evaluation state exposing ``apply_fn``, ``params`` and
            ``batch_stats``.
        batch: Mapping with ``'images'`` (model inputs) and ``'atom_map'``
            (regression targets, same shape as the predictions).

    Returns:
        ``(inputs, targets, preds, attention, loss_by_image)``.
        ``loss_by_image`` is the MSE over every non-batch axis, with one value
        per example. ``attention`` is the second element when the model
        returns a ``(preds, attention)`` tuple, or ``None`` when it returns
        predictions only.
    """
    inputs, targets = batch['images'], batch['atom_map']
    outputs = state.apply_fn(
        {'params': state.params, 'batch_stats': state.batch_stats},
        inputs,
        training=False,
    )
    # With return_attention_maps=True the model returns (preds, attention).
    # Otherwise it presumably returns preds alone — TODO confirm in create_model.
    if isinstance(outputs, tuple):
        preds, attention = outputs
    else:
        preds, attention = outputs, None
    # Average over all axes except the batch axis, so this works for any input
    # rank (the axes were previously hard-coded for 5-D tensors).
    reduce_axes = tuple(range(1, preds.ndim))
    loss_by_image = jnp.mean((preds - targets) ** 2, axis=reduce_axes)
    return inputs, targets, preds, attention, loss_by_image


In [39]:
# Forward pass on the sampled validation batch.
inputs, targets, preds, attention, loss_by_image = predict(state, batch)

In [40]:
# Sanity-check the shapes of everything predict() returned.
for label, value in (
    ("inputs", inputs),
    ("targets", targets),
    ("preds", preds),
    ("loss_by_image", loss_by_image),
):
    print(f"{label}: {value.shape}")
for att in attention:
    print(f"attention: {att.shape}")

inputs: (4, 128, 128, 10, 1)
targets: (4, 128, 128, 10, 5)
preds: (4, 128, 128, 10, 5)
loss_by_image: (4,)
attention: (4, 16, 16, 10, 1)
attention: (4, 32, 32, 10, 1)
attention: (4, 64, 64, 10, 1)
attention: (4, 128, 128, 10, 1)


In [3]:
# Run analyses.make_predictions on a single batch, directing its output to an
# "analysis_mac" folder inside the run directory.
analysis_outputdir = os.path.join(workdir, "analysis_mac")
make_predictions.make_predictions(
    workdir=workdir,
    outputdir=analysis_outputdir,
    num_batches=1,
)