In [None]:
import numpy as np
import os
# Logging-related environment variables, set ahead of the TensorFlow import
# further down in this cell.
os.environ.update(TF_CPP_MIN_LOG_LEVEL="1", AUTOGRAPH_VERBOSITY="1")

# Root logger: a single stream handler with a timestamped message format
import logging

logger = logging.getLogger()
logger.handlers = []  # discard whatever handlers were attached before this cell ran
formatter = logging.Formatter(
    "%(asctime)s (%(levelname)s): %(message)s", "%Y-%m-%d %H:%M:%S"
)
ch = logging.StreamHandler()
ch.setFormatter(formatter)
logger.addHandler(ch)
logger.setLevel("INFO")

import tensorflow as tf
# TensorFlow logging configuration: AutoGraph verbosity level 1 and
# TensorFlow's own Python logger at level "WARN".
tf.autograph.set_verbosity(1)
tf.get_logger().setLevel("WARN")

# GemNet imports
from gemnet.model.gemnet import GemNet
from gemnet.training.data_container import DataContainer


# Custom molecule class to use molecules from ASE

In [None]:
class Molecule(DataContainer):
    """
    Implements the DataContainer but for a single molecule. Requires custom init method.

    ``DataContainer.__init__`` is intentionally not called. The attributes the
    inherited methods used here (``get_dtypes`` and ``__getitem__``) read are
    set by hand instead -- presumably this is the full set they need; verify
    against ``gemnet.training.data_container`` if that class changes.

    Parameters
    ----------
        R: array-like, shape (len(Z), 3)
            Atom positions, one row per atom.
        Z: array-like, shape (len(Z),)
            Atomic numbers, one entry per atom.
        cutoff: float
            Stored as ``self.cutoff`` for the inherited index computation
            (presumably the embedding cutoff radius -- TODO confirm).
        int_cutoff: float
            Stored as ``self.int_cutoff`` (presumably the interaction cutoff
            used for the quadruplet indices -- TODO confirm).
        triplets_only: bool
            If True, only the triplet-related index keys are registered;
            otherwise the quadruplet keys (``id4_*``, ``Kidx4``) are added too.

    Raises
    ------
        ValueError
            If ``R`` does not have shape ``(len(Z), 3)``.
    """
    def __init__(self, R, Z, cutoff, int_cutoff, triplets_only=False):
        # Index arrays that are always registered (edges and triplets).
        self.index_keys = [
            "batch_seg",
            "id_undir",
            "id_swap",
            "id_c",
            "id_a",
            "id3_expand_ba",
            "id3_reduce_ca",
            "Kidx3",
        ]
        if not triplets_only:
            # Additional index arrays for the quadruplet-based variant.
            self.index_keys += [
                "id4_int_b",
                "id4_int_a",
                "id4_reduce_ca",
                "id4_expand_db",
                "id4_reduce_cab",
                "id4_expand_abd",
                "Kidx4",
                "id4_reduce_intm_ca",
                "id4_expand_intm_db",
                "id4_reduce_intm_ab",
                "id4_expand_intm_ab",
            ]
        self.triplets_only = triplets_only
        self.cutoff = cutoff
        self.int_cutoff = int_cutoff
        self.keys = ["N", "Z", "R", "F", "E"]

        # Accept plain sequences as well as arrays; ndarray inputs are passed
        # through unchanged (same object, same dtype).
        R = np.asarray(R)
        Z = np.asarray(Z)
        # Explicit check instead of `assert`, which is skipped under `python -O`.
        if R.shape != (len(Z), 3):
            raise ValueError(
                f"R must have shape ({len(Z)}, 3) to match Z, got {R.shape}"
            )
        self.R = R
        self.Z = Z
        # A single molecule: one entry holding the number of atoms.
        self.N = np.array([len(Z)], dtype=np.int32)
        # Energy and forces are unknown here; zero placeholders fill the
        # "E" and "F" keys and are removed again in `get`.
        self.E = np.zeros(1, dtype=np.float32).reshape(1, 1)
        self.F = np.zeros((len(Z), 3), dtype=np.float32)

        # Atom offsets per molecule: [0, len(Z)] for this single-molecule case.
        self.N_cumsum = np.concatenate([[0], np.cumsum(self.N)])
        self.addID = False
        self.dtypes, dtypes2 = self.get_dtypes()
        self.dtypes.update(dtypes2)  # merge all dtypes in single dict

    def get(self):
        """
        Get the molecule representation in the expected format for the GemNet model.

        Returns
        -------
            data: dict
                The item at index 0 of this container with the "E" and "F"
                entries removed.
        """
        data = self[0]
        for var in ["E", "F"]:
            data.pop(var)  # not needed, i.e. not known -> the model is meant to predict these
        return data

# Set up the model and the data

In [None]:
# Model setup: files handed to the GemNet constructor / weight loader below
tf_weights_file = "./pretrained/best/ckpt"
scale_file = "./scaling_factors.json"
# These settings depend on the GemNet model that is loaded
direct_forces = False
triplets_only = False
cutoff, int_cutoff = 5.0, 10.0

# Data setup: take an example molecule from ASE and wrap it for the model
from ase.build import molecule as ase_molecule_db

mol = ase_molecule_db("C7NH5")
# Positions and atomic numbers are the two per-atom inputs Molecule takes
R, Z = mol.get_positions(), mol.get_atomic_numbers()

molecule = Molecule(
    R=R,
    Z=Z,
    cutoff=cutoff,
    int_cutoff=int_cutoff,
    triplets_only=triplets_only,
)

In [None]:
# Optional: uncomment the next line to run in eager mode
# tf.config.experimental_run_functions_eagerly(True)
# NOTE(review): newer TensorFlow 2.x releases name this
# tf.config.run_functions_eagerly -- confirm against the installed version.

# All arguments are passed by keyword, so their order does not matter.
# The values presumably have to match the checkpoint loaded afterwards.
model = GemNet(
    # settings defined in the model-setup cell
    cutoff=cutoff,
    int_cutoff=int_cutoff,  # no effect for GemNet-(d)T
    triplets_only=triplets_only,
    direct_forces=direct_forces,
    scale_file=scale_file,
    # num_* settings
    num_spherical=7,
    num_radial=6,
    num_blocks=4,
    num_before_skip=1,
    num_after_skip=1,
    num_concat=1,
    num_atom=2,
    num_targets=1,
    # emb_size_* settings
    emb_size_atom=128,
    emb_size_edge=128,
    emb_size_trip=64,
    emb_size_quad=32,
    emb_size_rbf=16,
    emb_size_cbf=16,
    emb_size_sbf=32,
    emb_size_bil_trip=64,
    emb_size_bil_quad=32,
)
# Load the weights stored at the configured checkpoint path
model.load_weights(tf_weights_file)

# Run the model

In [None]:
# Build the model inputs for the single molecule, then predict on them
model_inputs = molecule.get()
energy, forces = model.predict(model_inputs)

print("Energy [eV]", energy)
print("Forces [eV/°A]", forces)