In [1]:
import numpy as np
from kgcnn.data.datasets.MD17RevisedDataset import MD17RevisedDataset

In [2]:
dataset = MD17RevisedDataset("aspirin")
print("Number of steps", len(dataset))

INFO:kgcnn.data.download:Checking and possibly downloading dataset with name MD17Revised
INFO:kgcnn.data.download:Dataset directory located at C:\Users\patri\.kgcnn\datasets
INFO:kgcnn.data.download:Dataset directory found. Done.
INFO:kgcnn.data.download:Dataset found. Done.
INFO:kgcnn.data.download:Directory for extraction exists. Done.
INFO:kgcnn.data.download:Not extracting tar File. Stopped.


Number of steps 100000


In [3]:
data = dataset[:500]
data.map_list(method="set_range", node_coordinates="coords")
data[0].keys()

dict_keys(['coords', 'energies', 'forces', 'old_indices', 'old_energies', 'old_forces', 'nuclear_charges', 'range_indices', 'range_attributes'])

In [4]:
eng = np.expand_dims(data.get("energies"), axis=-1) - np.mean(data.get("energies"))
eng.shape

(500, 1)

In [5]:
forces = data.tensor({"name": "forces", "ragged": True, "shape": (None, 3)})
forces.shape

TensorShape([500, None, 3])

In [6]:
from kgcnn.model.force import EnergyForceModel

In [7]:
model = EnergyForceModel(
    module_name="kgcnn.literature.Schnet", 
    class_name="make_model", 
    config={
        "name": "Schnet",
        "inputs": [{"shape": (None,), "name": "nuclear_charges", "dtype": "float32", "ragged": True},
                   {"shape": (None, 3), "name": "coords", "dtype": "float32", "ragged": True},
                   {"shape": (None, 2), "name": "range_indices", "dtype": "int64", "ragged": True}],
        "input_embedding": {"node": {"input_dim": 95, "output_dim": 64}},
        "make_distance": True, "expand_distance": True,
        "interaction_args": {"units": 128, "use_bias": True,
                             "activation": "kgcnn>shifted_softplus", "cfconv_pool": "sum"},
        "node_pooling_args": {"pooling_method": "sum"},
        "depth": 4,
        "gauss_args": {"bins": 20, "distance": 4, "offset": 0.0, "sigma": 0.4},
        "verbose": 10,
        "last_mlp": {"use_bias": [True, True], "units": [128, 64],
                     "activation": ["kgcnn>shifted_softplus", "kgcnn>shifted_softplus"]},
        "output_embedding": "graph", "output_to_tensor": True,
        "use_output_mlp": True,
        "output_mlp": {"use_bias": [True, True], "units": [64, 1],
                       "activation": ["kgcnn>shifted_softplus", "linear"]}},
    output_to_tensor=False,
    output_squeeze_states=True,
)

INFO:kgcnn.utils.models:Updated model kwargs:
INFO:kgcnn.utils.models:{'name': 'Schnet', 'inputs': ListWrapper([DictWrapper({'shape': (None,), 'name': 'nuclear_charges', 'dtype': 'float32', 'ragged': True}), DictWrapper({'shape': (None, 3), 'name': 'coords', 'dtype': 'float32', 'ragged': True}), DictWrapper({'shape': (None, 2), 'name': 'range_indices', 'dtype': 'int64', 'ragged': True})]), 'input_embedding': {'node': {'input_dim': 95, 'output_dim': 64}}, 'make_distance': True, 'expand_distance': True, 'interaction_args': {'units': 128, 'use_bias': True, 'activation': 'kgcnn>shifted_softplus', 'cfconv_pool': 'sum'}, 'node_pooling_args': {'pooling_method': 'sum'}, 'depth': 4, 'gauss_args': {'bins': 20, 'distance': 4, 'offset': 0.0, 'sigma': 0.4}, 'verbose': 10, 'last_mlp': {'use_bias': ListWrapper([True, True]), 'units': ListWrapper([128, 64]), 'activation': ListWrapper(['kgcnn>shifted_softplus', 'kgcnn>shifted_softplus'])}, 'output_embedding': 'graph', 'output_to_tensor': True, 'use_out

In [8]:
x_tensor = data.tensor(model.model_config["inputs"])
print([x.shape for x in x_tensor])

[TensorShape([500, None]), TensorShape([500, None, 3]), TensorShape([500, None, 2])]


In [9]:
test_out = model.predict(x_tensor)





In [10]:
test_out[1].shape

TensorShape([500, None, 3])

In [11]:
from kgcnn.metrics.loss import RaggedMeanAbsoluteError
from tensorflow.keras.optimizers import Adam

In [12]:
model.compile(
    loss=["mean_absolute_error", RaggedMeanAbsoluteError()],
    optimizer=Adam(learning_rate=5e-04),
    metrics=[],
    loss_weights=[1, 1],
)

hist = model.fit(
    x_tensor, [eng, forces],
    shuffle=True,
    batch_size=64,
    epochs=1000,
    verbose=2,
)

Epoch 1/500




8/8 - 10s - loss: 26.5653 - output_1_loss: 4.7830 - output_2_loss: 21.7823 - 10s/epoch - 1s/step
Epoch 2/500
8/8 - 1s - loss: 26.5488 - output_1_loss: 4.7665 - output_2_loss: 21.7823 - 527ms/epoch - 66ms/step
Epoch 3/500
8/8 - 1s - loss: 26.5440 - output_1_loss: 4.7618 - output_2_loss: 21.7823 - 524ms/epoch - 65ms/step
Epoch 4/500
8/8 - 1s - loss: 26.5475 - output_1_loss: 4.7653 - output_2_loss: 21.7822 - 527ms/epoch - 66ms/step
Epoch 5/500
8/8 - 1s - loss: 26.5436 - output_1_loss: 4.7613 - output_2_loss: 21.7822 - 525ms/epoch - 66ms/step
Epoch 6/500
8/8 - 1s - loss: 26.5457 - output_1_loss: 4.7635 - output_2_loss: 21.7822 - 529ms/epoch - 66ms/step
Epoch 7/500
8/8 - 1s - loss: 26.5422 - output_1_loss: 4.7600 - output_2_loss: 21.7821 - 527ms/epoch - 66ms/step
Epoch 8/500
8/8 - 1s - loss: 26.5477 - output_1_loss: 4.7655 - output_2_loss: 21.7821 - 521ms/epoch - 65ms/step
Epoch 9/500
8/8 - 1s - loss: 26.5451 - output_1_loss: 4.7629 - output_2_loss: 21.7822 - 525ms/epoch - 66ms/step
Epoch 1

Epoch 74/500
8/8 - 1s - loss: 25.5110 - output_1_loss: 4.7336 - output_2_loss: 20.7773 - 521ms/epoch - 65ms/step
Epoch 75/500
8/8 - 1s - loss: 25.1463 - output_1_loss: 4.7253 - output_2_loss: 20.4211 - 522ms/epoch - 65ms/step
Epoch 76/500
8/8 - 1s - loss: 25.1487 - output_1_loss: 4.7157 - output_2_loss: 20.4330 - 519ms/epoch - 65ms/step
Epoch 77/500
8/8 - 1s - loss: 25.0069 - output_1_loss: 4.6504 - output_2_loss: 20.3566 - 519ms/epoch - 65ms/step
Epoch 78/500
8/8 - 1s - loss: 24.7274 - output_1_loss: 4.6450 - output_2_loss: 20.0824 - 524ms/epoch - 66ms/step
Epoch 79/500
8/8 - 1s - loss: 25.2262 - output_1_loss: 5.3595 - output_2_loss: 19.8667 - 522ms/epoch - 65ms/step
Epoch 80/500
8/8 - 1s - loss: 25.0352 - output_1_loss: 4.7068 - output_2_loss: 20.3284 - 525ms/epoch - 66ms/step
Epoch 81/500
8/8 - 1s - loss: 25.0837 - output_1_loss: 5.1217 - output_2_loss: 19.9620 - 520ms/epoch - 65ms/step
Epoch 82/500
8/8 - 1s - loss: 24.7326 - output_1_loss: 4.6303 - output_2_loss: 20.1023 - 530ms/e

8/8 - 1s - loss: 16.3571 - output_1_loss: 6.3447 - output_2_loss: 10.0125 - 522ms/epoch - 65ms/step
Epoch 147/500
8/8 - 1s - loss: 15.8265 - output_1_loss: 6.0008 - output_2_loss: 9.8256 - 527ms/epoch - 66ms/step
Epoch 148/500
8/8 - 1s - loss: 17.8102 - output_1_loss: 7.6291 - output_2_loss: 10.1811 - 517ms/epoch - 65ms/step
Epoch 149/500
8/8 - 1s - loss: 17.5786 - output_1_loss: 6.9700 - output_2_loss: 10.6085 - 516ms/epoch - 65ms/step
Epoch 150/500
8/8 - 1s - loss: 16.7490 - output_1_loss: 6.3414 - output_2_loss: 10.4075 - 517ms/epoch - 65ms/step
Epoch 151/500
8/8 - 1s - loss: 16.3537 - output_1_loss: 6.1209 - output_2_loss: 10.2328 - 524ms/epoch - 66ms/step
Epoch 152/500
8/8 - 1s - loss: 15.9146 - output_1_loss: 5.9768 - output_2_loss: 9.9378 - 526ms/epoch - 66ms/step
Epoch 153/500
8/8 - 1s - loss: 15.6515 - output_1_loss: 6.0348 - output_2_loss: 9.6166 - 528ms/epoch - 66ms/step
Epoch 154/500
8/8 - 1s - loss: 15.4284 - output_1_loss: 6.0012 - output_2_loss: 9.4273 - 527ms/epoch - 66

Epoch 219/500
8/8 - 1s - loss: 13.3249 - output_1_loss: 5.6594 - output_2_loss: 7.6655 - 518ms/epoch - 65ms/step
Epoch 220/500
8/8 - 1s - loss: 13.3396 - output_1_loss: 5.6244 - output_2_loss: 7.7152 - 517ms/epoch - 65ms/step
Epoch 221/500
8/8 - 1s - loss: 12.6510 - output_1_loss: 5.0920 - output_2_loss: 7.5590 - 519ms/epoch - 65ms/step
Epoch 222/500
8/8 - 1s - loss: 12.6177 - output_1_loss: 5.1269 - output_2_loss: 7.4909 - 519ms/epoch - 65ms/step
Epoch 223/500
8/8 - 1s - loss: 13.7434 - output_1_loss: 6.2106 - output_2_loss: 7.5328 - 520ms/epoch - 65ms/step
Epoch 224/500
8/8 - 1s - loss: 14.1703 - output_1_loss: 6.1060 - output_2_loss: 8.0643 - 530ms/epoch - 66ms/step
Epoch 225/500
8/8 - 1s - loss: 13.8099 - output_1_loss: 5.9285 - output_2_loss: 7.8813 - 532ms/epoch - 67ms/step
Epoch 226/500
8/8 - 1s - loss: 13.1387 - output_1_loss: 5.4483 - output_2_loss: 7.6905 - 533ms/epoch - 67ms/step
Epoch 227/500
8/8 - 1s - loss: 12.6539 - output_1_loss: 5.1836 - output_2_loss: 7.4703 - 535ms/e

Epoch 292/500
8/8 - 1s - loss: 12.0368 - output_1_loss: 5.3186 - output_2_loss: 6.7182 - 523ms/epoch - 65ms/step
Epoch 293/500
8/8 - 1s - loss: 10.9711 - output_1_loss: 4.3355 - output_2_loss: 6.6356 - 517ms/epoch - 65ms/step
Epoch 294/500
8/8 - 1s - loss: 11.3380 - output_1_loss: 4.6803 - output_2_loss: 6.6577 - 520ms/epoch - 65ms/step
Epoch 295/500
8/8 - 1s - loss: 11.3038 - output_1_loss: 4.6236 - output_2_loss: 6.6802 - 517ms/epoch - 65ms/step
Epoch 296/500
8/8 - 1s - loss: 11.7199 - output_1_loss: 4.8974 - output_2_loss: 6.8226 - 519ms/epoch - 65ms/step
Epoch 297/500
8/8 - 1s - loss: 11.1145 - output_1_loss: 4.3756 - output_2_loss: 6.7389 - 519ms/epoch - 65ms/step
Epoch 298/500
8/8 - 1s - loss: 11.7205 - output_1_loss: 4.9484 - output_2_loss: 6.7721 - 519ms/epoch - 65ms/step
Epoch 299/500
8/8 - 1s - loss: 12.0737 - output_1_loss: 5.2403 - output_2_loss: 6.8333 - 518ms/epoch - 65ms/step
Epoch 300/500
8/8 - 1s - loss: 11.6448 - output_1_loss: 4.8932 - output_2_loss: 6.7516 - 515ms/e

Epoch 365/500
8/8 - 1s - loss: 11.7487 - output_1_loss: 5.2484 - output_2_loss: 6.5003 - 526ms/epoch - 66ms/step
Epoch 366/500
8/8 - 1s - loss: 11.6445 - output_1_loss: 5.1140 - output_2_loss: 6.5305 - 528ms/epoch - 66ms/step
Epoch 367/500
8/8 - 1s - loss: 10.9319 - output_1_loss: 4.3724 - output_2_loss: 6.5595 - 526ms/epoch - 66ms/step
Epoch 368/500
8/8 - 1s - loss: 10.4108 - output_1_loss: 3.9216 - output_2_loss: 6.4892 - 528ms/epoch - 66ms/step
Epoch 369/500
8/8 - 1s - loss: 9.9100 - output_1_loss: 3.6405 - output_2_loss: 6.2695 - 524ms/epoch - 66ms/step
Epoch 370/500
8/8 - 1s - loss: 10.0121 - output_1_loss: 3.7251 - output_2_loss: 6.2871 - 533ms/epoch - 67ms/step
Epoch 371/500
8/8 - 1s - loss: 10.4091 - output_1_loss: 4.0337 - output_2_loss: 6.3753 - 527ms/epoch - 66ms/step
Epoch 372/500
8/8 - 1s - loss: 9.8619 - output_1_loss: 3.6101 - output_2_loss: 6.2518 - 521ms/epoch - 65ms/step
Epoch 373/500
8/8 - 1s - loss: 10.3140 - output_1_loss: 3.8380 - output_2_loss: 6.4760 - 527ms/epo

Epoch 438/500
8/8 - 1s - loss: 8.8847 - output_1_loss: 3.0354 - output_2_loss: 5.8493 - 533ms/epoch - 67ms/step
Epoch 439/500
8/8 - 1s - loss: 8.5202 - output_1_loss: 2.6610 - output_2_loss: 5.8592 - 530ms/epoch - 66ms/step
Epoch 440/500
8/8 - 1s - loss: 8.4921 - output_1_loss: 2.5652 - output_2_loss: 5.9269 - 526ms/epoch - 66ms/step
Epoch 441/500
8/8 - 1s - loss: 9.1723 - output_1_loss: 3.2308 - output_2_loss: 5.9415 - 549ms/epoch - 69ms/step
Epoch 442/500
8/8 - 1s - loss: 9.8588 - output_1_loss: 3.8131 - output_2_loss: 6.0457 - 552ms/epoch - 69ms/step
Epoch 443/500
8/8 - 1s - loss: 9.1890 - output_1_loss: 3.1940 - output_2_loss: 5.9950 - 579ms/epoch - 72ms/step
Epoch 444/500
8/8 - 1s - loss: 8.8465 - output_1_loss: 2.8890 - output_2_loss: 5.9575 - 618ms/epoch - 77ms/step
Epoch 445/500
8/8 - 1s - loss: 9.0632 - output_1_loss: 3.1234 - output_2_loss: 5.9399 - 552ms/epoch - 69ms/step
Epoch 446/500
8/8 - 1s - loss: 9.8408 - output_1_loss: 3.7932 - output_2_loss: 6.0476 - 530ms/epoch - 66

In [13]:
from kgcnn.utils.plots import plot_train_test_loss

In [14]:
plot_train_test_loss(hist)

TypeError: 'History' object is not subscriptable