In [5]:

import schnetpack as spk
from schnetpack.data import ASEAtomsData
import schnetpack.transform as trn

import torch
import torchmetrics
import pytorch_lightning as pl

import os 

In [8]:
from schnetpack.datasets import rMD17
from schnetpack.transform import ASENeighborList

# Build the database path from individual components so that it resolves
# correctly on every operating system (not only with Windows separators).
filepath_db = os.path.join(os.getcwd(), 'data', 'rMD17', 'rMD17.db')

# Data module for the ethanol trajectories of rMD17.
# NOTE(review): confirm that the database holds at least num_train + num_val
# structures; otherwise the validation/test splits may end up empty.
ethanol_data = rMD17(
    filepath_db,
    molecule='ethanol',
    batch_size=10,
    num_train=100_000,
    num_val=10_000,
    transforms=[
        # neighbor list with a 5 (distance unit) cutoff
        trn.ASENeighborList(cutoff=5.),
        # subtract the mean energy from the energy targets
        trn.RemoveOffsets(rMD17.energy, remove_mean=True, remove_atomrefs=False),
        trn.CastTo32()
    ],
    num_workers=1,
    # Pinned host memory only makes sense when batches are copied to a GPU,
    # so enable it only if CUDA is actually available.
    pin_memory=torch.cuda.is_available(),
)


ethanol_data.prepare_data()
ethanol_data.setup()


100%|██████████| 10000/10000 [05:10<00:00, 32.19it/s]


## Run define_hessian_database.py here

In [17]:
# Paths are assembled from components (instead of hard-coded '\\' separators)
# so that the notebook also runs on non-Windows systems.
filepath_hessian_db = os.path.join(
    os.getcwd(), 'data', 'ene_grad_hess_1000eth', 'data.db'
)
filepath_no_hessian_db = os.path.join(
    os.getcwd(), 'data', 'ene_grad_hess_1000eth', 'data-no-hessian.db'
)

# Data module over the database that carries energy, forces and Hessian
# entries; the units of each property are declared explicitly below.
hessianData = spk.data.AtomsDataModule(
    filepath_hessian_db,
    distance_unit="Ang",
    property_units={"energy": "Hartree",
                    "forces": "Hartree/Bohr",
                    "hessian": "Hartree/Bohr/Bohr"
                    },
    batch_size=10,

    transforms=[
        trn.ASENeighborList(cutoff=5.),
        # subtract the mean energy from the energy targets
        trn.RemoveOffsets("energy", remove_mean=True, remove_atomrefs=False),
        trn.CastTo32()
    ],

    # fixed split sizes: 800 / 100 / 100
    num_train=800,
    num_val=100,
    num_test=100,

    #pin_memory=True, # set to false, when not using a GPU
)
hessianData.prepare_data()
hessianData.setup()

100%|██████████| 80/80 [00:11<00:00,  7.09it/s]


In [3]:
# Report the number of entries in the full dataset, the training subset and
# the training index list, then show the first ten training indices.
for sized in (hessianData.dataset, hessianData.train_dataset, hessianData.train_idx):
    print(len(sized))
print(hessianData.train_idx[:10])


1000
100000
100000
[14155, 77555, 45811, 39209, 26021, 40902, 74535, 71920, 62531, 37630]


In [14]:
# Summarise the size of the full database and of every split.
split_summary = [
    ('Number of reference calculations:', hessianData.dataset),
    ('Number of train data:', hessianData.train_dataset),
    ('Number of validation data:', hessianData.val_dataset),
    ('Number of test data:', hessianData.test_dataset),
]
for label, subset in split_summary:
    print(label, len(subset))

# List every property stored in the database.
print('Available properties:')
for prop_name in hessianData.dataset.available_properties:
    print('-', prop_name)

Number of reference calculations: 1000
Number of train data: 800
Number of validation data: 100
Number of test data: 100
Available properties:
- energy
- forces
- hessian


In [15]:
# List the property names for which units are registered on the data module.
for unit_key in hessianData.property_units:
    print('-', unit_key)

- energy
- forces
- hessian


In [21]:
# Request the validation data loader from the data module; as the last
# expression of the cell, its repr is what gets displayed.
hessianData.val_dataloader()


<schnetpack.data.loader.AtomsLoader at 0x22149fcb610>

In [18]:
# ---------------------------------------------------------------------------
# Model definition
# ---------------------------------------------------------------------------
# The cutoff is shared by the radial basis and the cosine cutoff function; it
# has the same value as the one given to trn.ASENeighborList in the data modules.
cutoff = 5.
n_atom_basis = 30

pairwise_distance = spk.atomistic.PairwiseDistances() # calculates pairwise distances between atoms
radial_basis = spk.nn.GaussianRBF(n_rbf=20, cutoff=cutoff)
paiNN = spk.representation.PaiNN(
    n_atom_basis=n_atom_basis,
    n_interactions=3,
    radial_basis=radial_basis,
    cutoff_fn=spk.nn.CosineCutoff(cutoff)
)

# Output heads: energy, forces (derived from the "energy" key) and polarizability.
pred_energy = spk.atomistic.Atomwise(n_in=n_atom_basis, output_key="energy")
pred_forces = spk.atomistic.Forces(energy_key="energy", force_key="forces")
# NOTE(review): the polarizability head is part of the model, but its
# ModelOutput (output_polarizability) is not passed to the task below, so it
# does not contribute to the loss or the metrics -- confirm this is intended.
pred_polarizability = spk.atomistic.Polarizability(n_in = n_atom_basis, polarizability_key = "polarizability")

nnpot = spk.model.NeuralNetworkPotential(
    representation=paiNN,
    input_modules=[pairwise_distance],
    output_modules=[pred_energy, pred_forces, pred_polarizability],
    postprocessors=[
        trn.CastTo64(),
        # add the mean energy back onto the predicted energy
        trn.AddOffsets("energy", add_mean=True, add_atomrefs=False)
    ]
)

# ---------------------------------------------------------------------------
# Training targets: MSE loss and MAE metric per property
# ---------------------------------------------------------------------------
output_energy = spk.task.ModelOutput(
    name="energy",
    loss_fn=torch.nn.MSELoss(),
    loss_weight=0.01,
    metrics={
        "MAE": torchmetrics.MeanAbsoluteError()
    }
)

output_forces = spk.task.ModelOutput(
    name="forces",
    loss_fn=torch.nn.MSELoss(),
    loss_weight=0.99,
    metrics={
        "MAE": torchmetrics.MeanAbsoluteError()
    }
)

# Defined with zero loss weight and currently not handed to the task.
output_polarizability = spk.task.ModelOutput(
    name="polarizability",
    loss_fn=torch.nn.MSELoss(),
    loss_weight=0,
    metrics={
        "MAE": torchmetrics.MeanAbsoluteError()
    }
)


task = spk.task.AtomisticTask(
    model=nnpot,
    outputs=[output_energy, output_forces],
    optimizer_cls=torch.optim.AdamW,
    optimizer_args={"lr": 1e-4}
)

# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
# Assemble the directory from components so the path is valid on every OS, and
# make sure it exists before the logger and the checkpoint callback write to it.
directory_training = os.path.join(os.getcwd(), "maxim", "data", "ene_grad_hess_1000eth")
os.makedirs(directory_training, exist_ok=True)
filepath_model = os.path.join(directory_training, "best_inference_model")

logger = pl.loggers.TensorBoardLogger(save_dir=directory_training)
callbacks = [
    # keep only the single best model according to the validation loss
    spk.train.ModelCheckpoint(
        model_path=filepath_model,
        save_top_k=1,
        monitor="val_loss"
    )
]

trainer = pl.Trainer(
    callbacks=callbacks,
    logger=logger,
    default_root_dir=directory_training,
    max_epochs=5, # for testing, we restrict the number of epochs
)
trainer.fit(task, datamodule=hessianData)

c:\Users\maxim\AppData\Local\Programs\Python\Python311\Lib\site-packages\pytorch_lightning\utilities\parsing.py:199: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name    | Type                   | Params
---------------------------------------------------
0 | model   | NeuralNetworkPotential | 43.1 K
1 | outputs | ModuleList             | 0     
---------------------------------------------------
43.1 K    Trainable params
0         Non-trainable params
43.1 K    Total params
0.172     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\maxim\AppData\Local\Programs\Python\Python311\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:436: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.
c:\Users\maxim\AppData\Local\Programs\Python\Python311\Lib\site-packages\pytorch_lightning\utilities\data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 10. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
c:\Users\maxim\AppData\Local\Programs\Python\Python311\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:436: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


In [None]:
from ase import Atoms

# Device used for inference; switch to torch.device("cuda") when a GPU is available.
device = "cpu"

# Load the model stored by the checkpoint callback. The file holds a fully
# pickled model object rather than a plain state dict, so weights_only=False
# is passed explicitly (recent PyTorch releases default to weights_only=True
# and would refuse to unpickle it). Only load files you created yourself:
# unpickling can execute arbitrary code.
best_model = torch.load(filepath_model, map_location=device, weights_only=False)

# Set up the converter from ASE Atoms to model inputs. The neighbor list
# reuses the cutoff the model was built with instead of repeating the literal.
converter = spk.interfaces.AtomsConverter(
    neighbor_list=trn.ASENeighborList(cutoff=cutoff), dtype=torch.float32, device=device
)


# Create an Atoms object from the first test structure.
# NOTE(review): the structure is taken from ethanol_data, while the model was
# fitted with hessianData -- confirm that this is the intended test set.
structure = ethanol_data.test_dataset[0]
atoms = Atoms(
    numbers=structure[spk.properties.Z], positions=structure[spk.properties.R]
)

# Convert the atoms to SchNetPack inputs and perform the prediction.
inputs = converter(atoms)
results = best_model(inputs)

print(results)

{'energy': tensor([-97078.5350], dtype=torch.float64, grad_fn=<AddBackward0>), 'forces': tensor([[  7.2651,  57.8519, -58.1546],
        [ -9.6054,   2.3938,  28.5954],
        [ 14.8588, -39.4147,  87.0120],
        [ 11.8051,  -9.7743,  -2.4411],
        [ -1.5639,  -6.6200,   3.7865],
        [ -4.2221, -12.5190,  11.4114],
        [  2.5417,  -3.6735,  -3.0063],
        [ 12.2374,   7.4088, -21.9501],
        [-33.3167,   4.3472, -45.2533]], dtype=torch.float64,
       grad_fn=<ToCopyBackward0>)}


In [26]:
# Bind the data module to a second name and print its default repr.
# NOTE(review): looks like a scratch/debugging cell -- confirm whether the
# `temp` alias is still needed anywhere.
temp = hessianData
print(temp)

<schnetpack.data.atoms.ASEAtomsData object at 0x000001490E82D690>
