In [1]:
import os
import yaml
import tqdm

import ase
from ase import io
from ase import data
from ase import db
from ase.visualize.plot import plot_atoms

from skimage import feature
import numpy as np

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

import ml_collections
from clu import checkpoint


from molnet import utils, train_state, train
from molnet.data import input_pipeline_online
from molnet.models import create_model
from configs import root_dirs
from analyses import make_predictions

from typing import Tuple

# Lookup from integer element index to chemical symbol; the position of a
# symbol in the tuple is its index.
INDEX_TO_SYMBOL = dict(enumerate(('H', 'C', 'N', 'O', 'F')))



In [14]:
# Run directory to analyse. `config.yaml` and the `checkpoints/` folder are
# read from it, and the exported xyz files are written beneath it.
# Exactly one assignment is kept active; other run directories that can be
# swapped in:
#   "/u/79/kurkil1/unix/work/molnet/runs/bf16-augs-rebias-adam-3e-4-z10-reverse-z/"
#   "/Users/kurkil1/work/molnet/runs/attention-adam-3e-4/"
workdir = "/u/79/kurkil1/unix/work/molnet/runs/bf16-augs-rebias-adam-3e-4-z20-interp/"

In [15]:
def load_from_workdir(
    workdir: str,
    return_attention: bool
):
    """Rebuild the model and its evaluation state from a training run folder.

    Args:
        workdir: Run directory containing ``config.yaml`` and a
            ``checkpoints`` sub-directory.
        return_attention: Value stored in
            ``config.model.return_attention_maps`` before the model is built.

    Returns:
        A ``(state, config)`` tuple: the ``EvaluationState`` holding the
        restored ``params`` and ``batch_stats`` (leaves converted with
        ``jnp.asarray``), and the ``ConfigDict`` used to build the model.
    """
    config_path = os.path.join(workdir, "config.yaml")
    # NOTE(review): unsafe_load can construct arbitrary Python objects; only
    # point this at run directories you trust.
    with open(config_path, "rt") as config_file:
        raw_config = yaml.unsafe_load(config_file)

    config = ml_collections.ConfigDict(raw_config)
    # The stored root_dir is replaced by the one resolved for "afms_rebias".
    config.root_dir = root_dirs.get_root_dir("afms_rebias")
    config.model.return_attention_maps = return_attention

    print(config)

    # Build the network, then pull the saved training state from disk.
    model = create_model(config.model)
    ckpt = checkpoint.Checkpoint(os.path.join(workdir, "checkpoints"))
    optimizer = utils.create_optimizer(config)
    restored = ckpt.restore(state=None)['state']

    # Only the parameters and batch statistics are taken from the checkpoint.
    eval_state = train_state.EvaluationState.create(
        apply_fn=model.apply,
        params=restored['params'],
        batch_stats=restored['batch_stats'],
        tx=optimizer,
    )

    return jax.tree_util.tree_map(jnp.asarray, eval_state), config

In [16]:
# Restore the evaluation state and run config; attention maps are not requested.
state, config = load_from_workdir(workdir, return_attention=False)

batch_size: 12
cutout_probs:
- 0.5
- 0.3
- 0.1
- 0.05
- 0.05
dataset: afms_rebias
debug: false
eval_every_steps: 2000
gaussian_factor: 5.0
interpolate_input_z: 20
learning_rate: 0.0003
learning_rate_schedule: constant
learning_rate_schedule_kwargs:
  decay_steps: 50000
  init_value: 0.0003
  peak_value: 0.0006
  warmup_steps: 2000
log_every_steps: 100
loss_fn: mse
max_atoms: 54
max_shift_per_slice: 0.02
model:
  attention_activation: sigmoid
  attention_channels:
  - 32
  - 32
  - 32
  - 32
  - 32
  conv_activation: relu
  decoder_channels:
  - 256
  - 128
  - 64
  - 32
  - 16
  decoder_kernel_size:
  - - 3
    - 3
    - 3
  - - 3
    - 3
    - 3
  - - 3
    - 3
    - 3
  - - 3
    - 3
    - 3
  - - 3
    - 3
    - 3
  dtype: bfloat16
  encoder_channels:
  - 16
  - 32
  - 64
  - 128
  - 256
  encoder_kernel_size:
  - - 3
    - 3
    - 3
  - - 3
    - 3
    - 3
  - - 3
    - 3
    - 3
  - - 3
    - 3
    - 3
  - - 3
    - 3
    - 3
  model_name: Attention-UNet
  output_activation: null


In [17]:
# Seed the PRNG and split off a separate key for data handling
# (datarng is not consumed by the cells shown in this notebook section).
rng = jax.random.PRNGKey(0)
datarng, rng = jax.random.split(rng)

# Override dataset options for this analysis inside config.unlocked().
with config.unlocked():
    config.target_z_cutoff = 2.0
    # Further overrides that are currently disabled:
    #   config.z_cutoff = 1.0
    #   config.interpolate_z = None
    #   config.train_molecules = (0, 80000)
    #   config.val_molecules = (80000, 100000)
    #   config.max_shift_per_slice = 0.02

# Only the 'val' entry of the returned datasets is kept.
ds = input_pipeline_online.get_full_molecule_datasets(config)['val']

# Draw one batch to inspect the array shapes. NOTE: this advances `ds`, so
# later consumers of the iterator start from the following batch.
batch = next(ds)
for name, array in batch.items():
    print(name, array.shape)

[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:1: Invalid control characters encountered in text.
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:3: Expected identifier, got: 18085879115725111131
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:1: Invalid control characters encountered in text.
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:3: Expected identifier, got: 1400486575639093112
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing t

images (12, 128, 128, 20, 1)
xyz (12, 54, 5)
sw (12, 2, 3)
atom_map (12, 128, 128, 20, 5)


In [18]:
def grid_to_mol(
    grid: jnp.ndarray,
    peak_threshold: float = 0.5,
    z_cutoff: float = 1.0,
) -> ase.Atoms:
    """Turn an element-resolved grid into an ``ase.Atoms`` object.

    Local maxima of ``grid`` are interpreted as atoms. The first three peak
    coordinates become the position and the fourth coordinate is looked up
    in ``INDEX_TO_SYMBOL`` to obtain the element.

    Args:
        grid: 4-D array; the last axis is the element channel and the
            second-to-last axis is the one treated as z.
        peak_threshold: Passed to ``peak_local_max`` as ``threshold_rel``
            (minimum peak intensity relative to the grid maximum).
        z_cutoff: z coordinate assigned to the highest atom; all atoms are
            shifted rigidly along z so that the top one ends up here.

    Returns:
        An ``ase.Atoms`` with cell ``[16, 16, 0]``. When no peak is found the
        returned object contains zero atoms.
    """
    # Reverse the z axis (second-to-last) before peak detection.
    grid = grid[..., ::-1, :]

    # min_distance acts on all four axes, including the element channel;
    # exclude_border=0 keeps peaks that lie on the edge of the grid.
    peaks = feature.peak_local_max(
        grid,
        min_distance=5,
        exclude_border=0,
        threshold_rel=peak_threshold
    )

    # Peak columns 0 and 1 are swapped to give (x, y), then indices are scaled
    # by the grid spacing. NOTE(review): spacings presumably in Angstrom
    # (128 * 0.125 matches the 16-unit cell) -- confirm.
    xyz_from_peaks = peaks[:, [1, 0, 2]] * (.125, .125, .1)
    elem_from_peaks = peaks[:, 3]

    mol = ase.Atoms(
        positions=xyz_from_peaks,
        symbols=[INDEX_TO_SYMBOL[i] for i in elem_from_peaks],
        cell=[16, 16, 0],
    )

    # A flat grid (e.g. all zeros) yields no peaks; taking .max() of the empty
    # z column would raise ValueError, so the shift is skipped in that case.
    if len(mol) > 0:
        mol.positions[:, 2] -= mol.get_positions()[:, 2].max() - z_cutoff
    return mol

In [19]:
@jax.jit
def predict(
    state,
    batch,
):
    """Run the model in inference mode on one batch.

    Args:
        state: Object exposing ``apply_fn``, ``params`` and ``batch_stats``.
        batch: Mapping with the keys ``'images'``, ``'atom_map'`` and ``'xyz'``.

    Returns:
        The tuple ``(images, atom_map, predictions, xyz)``; everything except
        the predictions is passed through from ``batch`` unchanged.
    """
    images = batch['images']
    atom_maps = batch['atom_map']

    # Parameters and batch statistics are handed to the model together;
    # training=False selects evaluation behaviour.
    variables = {'params': state.params, 'batch_stats': state.batch_stats}
    predictions = state.apply_fn(variables, images, training=False)

    return images, atom_maps, predictions, batch["xyz"]

In [20]:
# Relative peak threshold used for both targets and predictions; the output
# folder name is derived from it.
PEAK_THRESHOLD = 0.3
# Upper bound on the number of batches that are exported.
MAX_BATCHES = 1000

outputdir = os.path.join(workdir, f"ase_{PEAK_THRESHOLD}")
# Create the output directory
os.makedirs(outputdir, exist_ok=True)

# Two multi-frame extxyz files are written:
# 1. Molecules extracted from the target atom maps
# 2. Molecules extracted from the predicted atom maps
xyz_target = os.path.join(outputdir, "target.xyz")
xyz_pred = os.path.join(outputdir, "pred.xyz")

# NOTE: `ds` is a stateful iterator; batches already drawn from it by earlier
# cells are not revisited here.
for i in tqdm.tqdm(range(MAX_BATCHES)):
    try:
        batch = next(ds)
    except StopIteration:
        print("End of dataset")
        break
    inputs, targets, preds, xyzs = predict(state, batch)
    target_mols = [
        grid_to_mol(t, z_cutoff=1.0, peak_threshold=PEAK_THRESHOLD)
        for t in targets
    ]
    pred_mols = [
        grid_to_mol(p, z_cutoff=1.0, peak_threshold=PEAK_THRESHOLD)
        for p in preds
    ]

    # The first batch overwrites whatever an earlier run left behind and the
    # following batches append, so re-running this cell does not duplicate
    # molecules. If `ds` is already exhausted nothing is written and existing
    # files stay untouched.
    append = i > 0
    io.write(xyz_target, target_mols, format="extxyz", append=append)
    io.write(xyz_pred, pred_mols, format="extxyz", append=append)


  0%|          | 0/1000 [00:00<?, ?it/s]

2025-02-04 18:01:20.149186: E external/xla/xla/service/slow_operation_alarm.cc:65] Trying algorithm eng0{} for conv (bf16[12,128,16,16,20]{4,3,2,1,0}, u8[0]{0}) custom-call(bf16[12,288,16,16,20]{4,3,2,1,0}, bf16[128,288,3,3,3]{4,3,2,1,0}, bf16[128]{0}), window={size=3x3x3 pad=1_1x1_1x1_1}, dim_labels=bf012_oi012->bf012, custom_call_target="__cudnn$convBiasActivationForward", backend_config={"cudnn_conv_backend_config":{"activation_mode":"kNone","conv_result_scale":1,"leakyrelu_alpha":0,"side_input_scale":0},"force_earliest_schedule":false,"operation_queue_id":"0","wait_on_operation_queues":[]} is taking a while...
2025-02-04 18:01:20.646019: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 1.496926946s
Trying algorithm eng0{} for conv (bf16[12,128,16,16,20]{4,3,2,1,0}, u8[0]{0}) custom-call(bf16[12,288,16,16,20]{4,3,2,1,0}, bf16[128,288,3,3,3]{4,3,2,1,0}, bf16[128]{0}), window={size=3x3x3 pad=1_1x1_1x1_1}, dim_labels=bf012_oi012->bf012, custom_call_target="__c

End of dataset



