# Training a polarizability model on QM7b

In [1]:
import re
from ase import io
import numpy as np


def trans(array=None):
    """Expand a 6-component symmetric tensor into a full 3x3 matrix.

    Mapping applied (indices into ``array``)::

        [[a0, a3, a4],
         [a3, a1, a5],
         [a4, a5, a2]]

    i.e. the order ``(xx, yy, zz, xy, xz, yz)``.
    NOTE(review): confirm this matches the component order used in the
    source data file.

    Args:
        array: Sequence or array holding exactly 6 numbers.

    Returns:
        A float ``(3, 3)`` symmetric ``numpy.ndarray``.

    Raises:
        TypeError: if ``array`` is None.
        ValueError: if ``array`` does not hold exactly 6 values. This guards
            against silently truncating a 9-component (full-matrix) input.
    """
    if array is None:
        raise TypeError("trans() requires a 6-component array, got None")
    values = np.asarray(array, dtype=float).ravel()
    if values.size != 6:
        raise ValueError(f"expected 6 tensor components, got {values.size}")
    xx, yy, zz, xy, xz, yz = values
    return np.array([
        [xx, xy, xz],
        [xy, yy, yz],
        [xz, yz, zz],
    ])

# Read every frame of the extended-XYZ file, then pull the three polarizability
# tensors out of each frame's header ("Properties...") line.
atoms_list = io.read('qm7b_coords.xyz', index=':')  # read all frames

# Matches key="v1 v2 ..." pairs whose key ends in "_pol"; only the quoted values are captured.
POL_PATTERN = re.compile(r'\w+_pol="([^"]+)"')
# NOTE(review): tensors are assigned by position (1st -> CCSD, 2nd -> B3LYP,
# 3rd -> SCAN0); confirm this is the order the keys appear in the file.
POL_KEYS = ('ccsd_pol', 'b3lyp_pol', 'scan0_pol')

property_list = []
with open('qm7b_coords.xyz') as f:
    for line in f:
        if not line.startswith('Properties'):
            continue
        matches = POL_PATTERN.findall(line)
        if len(matches) < len(POL_KEYS):
            raise ValueError(
                f"expected {len(POL_KEYS)} *_pol fields, found {len(matches)}: {line[:80]!r}"
            )
        # Each value string is expanded to a symmetric 3x3 matrix, then reshaped
        # to (1, 3, 3); the extra leading axis presumably lets per-molecule
        # tensors stack along axis 0 when batched.
        property_list.append({
            key: trans(np.array(text.split(), dtype=float)).reshape(1, 3, 3)
            for key, text in zip(POL_KEYS, matches)
        })

# The next cell hands both lists to add_systems side by side, so their lengths must agree.
assert len(property_list) == len(atoms_list), (
    f"{len(property_list)} property entries vs {len(atoms_list)} frames"
)

In [2]:
from schnetpack.data import ASEAtomsData,AtomsDataModule
from schnetpack.transform import ASENeighborList
import os
from ase.db import connect

# Build the ASE database once from the parsed frames/tensors, then wrap it in a
# schnetpack AtomsDataModule with train/val/test splits.
from schnetpack.transform import CastTo32

NEIGHBOR_CUTOFF = 5.0  # Å; must be at least the cutoff used by the model
N_VAL, N_TEST = 1000, 500

dbfile = os.path.join('.', 'qm7b.db')
if not os.path.exists(dbfile):
    new_dataset = ASEAtomsData.create(
        dbfile,
        distance_unit='Ang',
        property_unit_dict={
            'ccsd_pol': 'a.u.',
            'b3lyp_pol': 'a.u.',
            'scan0_pol': 'a.u.',
        }
    )
    new_dataset.add_systems(property_list, atoms_list)
    for p in new_dataset.available_properties:
        print('-', p)

db = connect(dbfile)
n = len(db)
if n <= N_VAL + N_TEST:
    raise ValueError(
        f"database has {n} entries; need more than {N_VAL + N_TEST} for the requested split"
    )

dataset = AtomsDataModule(
    datapath=dbfile,
    batch_size=1,
    num_train=n - N_VAL - N_TEST,
    num_val=N_VAL,
    num_test=N_TEST,
    transforms=[
        # The model's PairwiseDistances input module needs neighbor indices
        # (idx_i / idx_j), which a neighbor-list transform provides.
        ASENeighborList(cutoff=NEIGHBOR_CUTOFF),
        # The parsed tensors are float64 NumPy arrays; cast inputs to float32,
        # presumably to match the model parameters (the model casts outputs
        # back with CastTo64).
        CastTo32(),
    ],
)
dataset.prepare_data()
dataset.setup()



A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.0.0 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "/Users/xujinhong/opt/anaconda3/envs/schnet/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/Users/xujinhong/opt/anaconda3/envs/schnet/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/Users/xujinhong/opt/anaconda3/envs/schnet/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/Users/xujinhong/opt/anaconda3/envs/schnet/lib/python3.10/site-packages/traitlets




In [12]:
# Spot-check the first five stored molecules and their polarizability tensors.
rows = list(db.select())
for db_row in rows[:5]:
    print(db_row.toatoms())
    for prop_name, tensor in db_row.data.items():
        print(prop_name, '\n', tensor)

Atoms(symbols='CH4', pbc=False)
ccsd_pol 
 [[[1.6801585e+01 1.9000000e-05 6.8000000e-05]
  [1.9000000e-05 1.6801594e+01 9.7000000e-05]
  [6.8000000e-05 9.7000000e-05 1.6801649e+01]]]
b3lyp_pol 
 [[[1.73150018e+01 1.45000000e-06 9.91700000e-05]
  [1.45000000e-06 1.73150224e+01 2.04580000e-04]
  [9.91700000e-05 2.04580000e-04 1.73151339e+01]]]
scan0_pol 
 [[[ 1.6942384e+01 -2.6700000e-04  2.8000000e-05]
  [-2.6700000e-04  1.6938517e+01  8.9000000e-05]
  [ 2.8000000e-05  8.9000000e-05  1.6938661e+01]]]
Atoms(symbols='NC4N', pbc=False)
ccsd_pol 
 [[[118.15987    0.         0.      ]
  [  0.        33.155203  -0.      ]
  [  0.        -0.        33.155203]]]
b3lyp_pol 
 [[[ 1.34583075e+02 -1.17650000e-04  3.14490000e-04]
  [-1.17650000e-04  3.35133380e+01 -0.00000000e+00]
  [ 3.14490000e-04 -0.00000000e+00  3.35133381e+01]]]
scan0_pol 
 [[[132.072771   0.         0.      ]
  [  0.        32.976133  -0.      ]
  [  0.        -0.        32.976142]]]
Atoms(symbols='OCSO2C2H2', pbc=False)
ccsd_

In [1]:
import os
import schnetpack as spk
import schnetpack.transform as trn

import torch
import torchmetrics
import pytorch_lightning as pl

# Train a PaiNN model to predict the CCSD polarizability tensor.
# Requires `dataset` (the AtomsDataModule) from the data-preparation cell above.
# After a kernel restart, re-run the notebook from the top; otherwise
# trainer.fit fails with NameError.
SEED = 42
pl.seed_everything(SEED, workers=True)  # reproducible weight init and batch shuffling

qm7tut = './qm7tut'
os.makedirs(qm7tut, exist_ok=True)

cutoff = 5.  # the dataset's neighbor list must use at least this cutoff
n_atom_basis = 128

pairwise_distance = spk.atomistic.PairwiseDistances()  # calculates pairwise distances between atoms
radial_basis = spk.nn.GaussianRBF(n_rbf=20, cutoff=cutoff)
# NOTE: despite the variable name, this representation is PaiNN, not SchNet.
schnet = spk.representation.PaiNN(
    n_atom_basis=n_atom_basis, n_interactions=3,
    radial_basis=radial_basis,
    cutoff_fn=spk.nn.CosineCutoff(cutoff)
)
pred_ccsd = spk.atomistic.Polarizability(n_in=n_atom_basis, polarizability_key='ccsd_pol')

nnpot = spk.model.NeuralNetworkPotential(
    representation=schnet,
    input_modules=[pairwise_distance],
    output_modules=[pred_ccsd],
    postprocessors=[trn.CastTo64()]
)
output_ccsd = spk.task.ModelOutput(
    name='ccsd_pol',
    loss_fn=torch.nn.MSELoss(),
    loss_weight=1.,
    metrics={
        "MAE": torchmetrics.MeanAbsoluteError()
    }
)
task = spk.task.AtomisticTask(
    model=nnpot,
    outputs=[output_ccsd],
    optimizer_cls=torch.optim.AdamW,
    optimizer_args={"lr": 1e-4}
)
logger = pl.loggers.TensorBoardLogger(save_dir=qm7tut)
callbacks = [
    spk.train.ModelCheckpoint(
        model_path=os.path.join(qm7tut, "best_inference_model"),
        save_top_k=1,
        monitor="val_loss"
    )
]

trainer = pl.Trainer(
    callbacks=callbacks,
    logger=logger,
    default_root_dir=qm7tut,
    max_epochs=3,  # for testing, we restrict the number of epochs
)
trainer.fit(task, datamodule=dataset)


A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.1.1 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "/Users/xujinhong/opt/anaconda3/envs/schnet/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/Users/xujinhong/opt/anaconda3/envs/schnet/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/Users/xujinhong/opt/anaconda3/envs/schnet/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/Users/xujinhong/opt/anaconda3/envs/schnet/lib/python3.10/site-packages/traitlets

NameError: name 'dataset' is not defined