In [1]:
import numpy as np
from kgcnn.data.datasets.MD17RevisedDataset import MD17RevisedDataset

In [2]:
# Load the "aspirin" entry of the revised MD17 dataset through kgcnn's dataset class.
dataset = MD17RevisedDataset("aspirin")
# The dataset length is reported as the number of steps.
print("Number of steps", len(dataset))

INFO:kgcnn.data.download:Checking and possibly downloading dataset with name MD17Revised
INFO:kgcnn.data.download:Dataset directory located at C:\Users\patri\.kgcnn\datasets
INFO:kgcnn.data.download:Dataset directory found. Done.
INFO:kgcnn.data.download:Dataset found. Done.
INFO:kgcnn.data.download:Directory for extraction exists. Done.
INFO:kgcnn.data.download:Not extracting tar File. Stopped.


Number of steps 100000


In [3]:
# Work on the first 500 entries only.
data = dataset[:500]
# Apply the "set_range" method to every graph, reading positions from "coords".
# NOTE(review): presumably this is what adds 'range_indices'/'range_attributes' — confirm against kgcnn docs.
data.map_list(method="set_range", node_coordinates="coords")
# Display which properties the first graph carries.
data[0].keys()

dict_keys(['coords', 'energies', 'forces', 'old_indices', 'old_energies', 'old_forces', 'nuclear_charges', 'range_indices', 'range_attributes'])

In [4]:
# Energies as a column: a trailing axis of length one is appended to the per-entry values.
eng = np.asarray(data.get("energies"))[..., np.newaxis]
eng.shape, eng[:3]

((500, 1),
 array([[-406276.78237402],
        [-406272.40203419],
        [-406282.16488753]]))

In [5]:
# Per-entry "nuclear_charges" values, handed to the scaler later on.
atoms = data.get("nuclear_charges")
# Forces are requested as a ragged tensor with inner shape (None, 3).
force_spec = {"name": "forces", "ragged": True, "shape": (None, 3)}
forces = data.tensor(force_spec)
forces.shape

TensorShape([500, None, 3])

In [6]:
from kgcnn.scaler.force import EnergyForceExtensiveScaler

# Fit on freshly derived raw energies rather than on `eng` itself: this cell rebinds
# `eng` to the scaled values, so fitting on `eng` would re-fit the scaler on already
# transformed data whenever the cell is executed a second time, and the fitted scaler
# could then no longer map predictions back to raw energies.
eng_raw = np.expand_dims(data.get("energies"), axis=-1)
scaler = EnergyForceExtensiveScaler(standardize_coordinates=False, standardize_scale=False)  # For testing no scale.
# Only the transformed energies are kept; training uses the original `forces` tensor.
# NOTE(review): with standardize_scale=False the forces are presumably returned unchanged — confirm.
_, eng, _ = scaler.fit_transform(X=None, y=eng_raw, force=forces, atomic_number=atoms)
eng[:3]

array([[-2.12796113],
       [ 2.25237869],
       [-7.51047464]])

In [7]:
from kgcnn.model.force import EnergyForceModel

In [8]:
# Keyword arguments forwarded to the SchNet `make_model` factory.
schnet_config = {
    "name": "Schnet",
    "inputs": [
        {"shape": (None,), "name": "nuclear_charges", "dtype": "float32", "ragged": True},
        {"shape": (None, 3), "name": "coords", "dtype": "float32", "ragged": True},
        {"shape": (None, 2), "name": "range_indices", "dtype": "int64", "ragged": True},
    ],
    "input_embedding": {"node": {"input_dim": 95, "output_dim": 64}},
    "make_distance": True,
    "expand_distance": True,
    "interaction_args": {
        "units": 128,
        "use_bias": True,
        "activation": "kgcnn>shifted_softplus",
        "cfconv_pool": "sum",
    },
    "node_pooling_args": {"pooling_method": "sum"},
    "depth": 4,
    "gauss_args": {"bins": 20, "distance": 4, "offset": 0.0, "sigma": 0.4},
    "verbose": 10,
    "last_mlp": {
        "use_bias": [True, True],
        "units": [128, 64],
        "activation": ["kgcnn>shifted_softplus", "kgcnn>shifted_softplus"],
    },
    "output_embedding": "graph",
    "output_to_tensor": True,
    "use_output_mlp": True,
    "output_mlp": {
        "use_bias": [True, True],
        "units": [64, 1],
        "activation": ["kgcnn>shifted_softplus", "linear"],
    },
}

# Arguments of the energy/force wrapper: which model factory to use and how outputs are returned.
model_config = {
    "module_name": "kgcnn.literature.Schnet",
    "class_name": "make_model",
    "output_as_dict": True,
    "config": schnet_config,
    "output_to_tensor": False,
    "output_squeeze_states": True,
}
model = EnergyForceModel(**model_config)

INFO:kgcnn.model.utils:Updated model kwargs:
INFO:kgcnn.model.utils:{'name': 'Schnet', 'inputs': ListWrapper([DictWrapper({'shape': (None,), 'name': 'nuclear_charges', 'dtype': 'float32', 'ragged': True}), DictWrapper({'shape': (None, 3), 'name': 'coords', 'dtype': 'float32', 'ragged': True}), DictWrapper({'shape': (None, 2), 'name': 'range_indices', 'dtype': 'int64', 'ragged': True})]), 'input_embedding': {'node': {'input_dim': 95, 'output_dim': 64}}, 'make_distance': True, 'expand_distance': True, 'interaction_args': {'units': 128, 'use_bias': True, 'activation': 'kgcnn>shifted_softplus', 'cfconv_pool': 'sum'}, 'node_pooling_args': {'pooling_method': 'sum'}, 'depth': 4, 'gauss_args': {'bins': 20, 'distance': 4, 'offset': 0.0, 'sigma': 0.4}, 'verbose': 10, 'last_mlp': {'use_bias': ListWrapper([True, True]), 'units': ListWrapper([128, 64]), 'activation': ListWrapper(['kgcnn>shifted_softplus', 'kgcnn>shifted_softplus'])}, 'output_embedding': 'graph', 'output_to_tensor': True, 'use_outpu

In [9]:
# Build the model inputs from the input specification stored on the model.
input_specs = model.model_config["inputs"]
x_tensor = data.tensor(input_specs)
print([tensor.shape for tensor in x_tensor])

[TensorShape([500, None]), TensorShape([500, None, 3]), TensorShape([500, None, 2])]


In [10]:
# Forward pass over all inputs before the model is trained; the result is inspected in the next cell.
test_out = model.predict(x_tensor)





In [11]:
# Shape of every entry in the prediction dict.
[prediction.shape for prediction in test_out.values()]

[(500, 1), TensorShape([500, None, 3])]

In [12]:
from kgcnn.metrics.loss import RaggedMeanAbsoluteError
from tensorflow.keras.optimizers import Adam

In [13]:
# Train on energies and forces jointly with a mean-absolute-error loss per output.
model.compile(
    loss={"energy": "mean_absolute_error", "force": RaggedMeanAbsoluteError()},
    optimizer=Adam(learning_rate=5e-04),
    metrics=None,
    # Weights are keyed by output name, exactly like `loss`, so the force term is
    # weighted 20x regardless of the order in which Keras enumerates the outputs.
    loss_weights={"energy": 1, "force": 20},
)

# Targets are passed under the same output names used in `loss` above.
hist = model.fit(
    x_tensor, {"energy": eng, "force": forces},
    shuffle=True,
    batch_size=64,
    epochs=1500,
    verbose=2,
)

In [14]:
from kgcnn.utils.plots import plot_train_test_loss

In [16]:
# plot_train_test_loss([hist]);

In [24]:
from kgcnn.md.base import MolDynamicsModelPredictor, MolDynamicsModelPostprocessorBase

In [25]:
class ExtensiveEnergyForceScalerPostProcessor(MolDynamicsModelPostprocessorBase):
    """Map a single-structure prediction back to unscaled energies and forces.

    The wrapped scaler is called through ``inverse_transform`` with a batch of
    exactly one structure, and the rescaled energy and forces are written back
    into the prediction dict.

    Args:
        scaler: Object exposing ``inverse_transform(X, y, force, atomic_number)``
            that returns a ``(X, y, force)`` triple. If ``None``, predictions are
            passed through unchanged.
        energy_key (str): Key of the energy entry in the prediction dict.
        force_key (str): Key of the force entry in the prediction dict.
        atomic_number_key (str): Key of the atomic numbers in the input graph.
            Graphs that store them under another name (for example
            ``"nuclear_charges"``) need this argument set explicitly.
    """

    # NOTE(review): the base class initializer is not invoked, as in the original
    # code; confirm whether `MolDynamicsModelPostprocessorBase` needs `super().__init__()`.
    def __init__(self, scaler=None, energy_key="energy", force_key="forces", atomic_number_key="node_number"):
        self.scaler = scaler
        self.energy_key = energy_key
        self.force_key = force_key
        self.atomic_number_key = atomic_number_key

    def __call__(self, x, y):
        """Rescale the prediction ``y`` that belongs to the input graph ``x``.

        Args:
            x (dict): Input graph; must contain ``atomic_number_key``.
            y (dict): Model prediction; must contain ``energy_key`` and ``force_key``.

        Returns:
            dict: ``y`` itself, updated in place with the rescaled energy and forces.
        """
        if self.scaler is None:
            # Nothing to undo without a scaler.
            return y
        # The scaler works on batches, so each per-structure property is wrapped
        # into a batch of one, atomic numbers included.
        _, energy, forces = self.scaler.inverse_transform(
            X=None,
            y=[y[self.energy_key]],
            force=[y[self.force_key]],
            atomic_number=[x[self.atomic_number_key]],
        )
        # Strip the batch dimension again and store the results under the same keys.
        y.update({self.energy_key: energy[0], self.force_key: forces[0]})
        return y

In [20]:
# Same "set_range" step that was applied to the training data, declared for the predictor.
range_preprocessor = {"method": "set_range", "node_coordinates": "coords"}
# Predictor outputs "energy"/"forces" are taken from the model outputs "energy"/"force".
output_mapping = {"energy": "energy", "forces": "force"}

dyn_model = MolDynamicsModelPredictor(
    model=model,
    model_inputs=model_config["config"]["inputs"],
    model_outputs=output_mapping,
    graph_preprocessors=[range_preprocessor],
    model_postprocessors=[],
)

In [21]:
# Run the predictor on ten entries that lie outside the first 500 used for training.
dyn_model(dataset[500:510])

[{'energy': array([-0.06797101], dtype=float32),
  'forces': array([[-1.8314649e-03,  3.6576260e-03, -4.5695604e-04],
         [ 3.7442998e-03,  1.2417401e-03, -9.9266137e-05],
         [ 1.5709931e-03,  2.0535409e-03, -2.2504595e-05],
         [-4.3749745e-04, -1.3863982e-03,  2.2349555e-04],
         [ 1.7847682e-04,  3.5513099e-04, -2.2926083e-04],
         [ 1.6972204e-03,  1.5914502e-03, -8.5451647e-06],
         [ 3.7249760e-04, -8.1027835e-04,  1.1499882e-03],
         [-2.2408157e-04,  1.7101439e-03, -7.0502341e-04],
         [-2.6327354e-04,  4.7402042e-03, -5.9944917e-03],
         [-3.0927006e-03,  5.9563218e-04,  1.5192307e-03],
         [-2.6641127e-03, -4.0802620e-03,  1.3006423e-03],
         [-2.3875667e-03, -3.1099624e-03,  1.2218144e-03],
         [ 5.0539947e-03, -3.7004778e-03,  2.4854958e-03],
         [ 2.8393797e-03,  2.9339902e-03, -1.2315839e-03],
         [-4.6455525e-03, -9.5430762e-05, -5.1842636e-04],
         [ 5.2910987e-03,  1.1143128e-03,  4.9953500e-04