# Fine-tune the pretrained CHGNet for better accuracy


In [1]:
import random
import numpy as np
import torch
from pymatgen.core import Structure
from chgnet.model import CHGNet

## 1. Prepare Training Data


In [3]:
from chgnet.utils import read_json
from pymatgen.core.structure import Structure

# Location of the fine-tuning dataset (JSON). The default is the original
# machine-specific path; set the CHGNET_DATASET_JSON environment variable to
# point at the file on another machine instead of editing this cell.
import os

DATASET_JSON = os.environ.get(
    "CHGNET_DATASET_JSON",
    "/home/phy_cmp/python_projects/CHGnet/chgnet_dataset_dhh_I-42d.json",
)
dataset_dict = read_json(DATASET_JSON)

# Rebuild pymatgen Structure objects from their serialized dict form.
structures = [Structure.from_dict(struct) for struct in dataset_dict["structure"]]
energies_per_atom = dataset_dict["energy_per_atom"]
forces = dataset_dict["force"]
# Optional labels: a missing key and an empty value both become None.
stresses = dataset_dict.get("stress") or None
magmoms = dataset_dict.get("magmom") or None

# Fail early, with a clear message, if any label list does not line up
# one-to-one with the structures.
n_structures = len(structures)
for label_name, labels in (
    ("energy_per_atom", energies_per_atom),
    ("force", forces),
    ("stress", stresses),
    ("magmom", magmoms),
):
    if labels is not None and len(labels) != n_structures:
        raise ValueError(
            f"dataset has {n_structures} structures but {len(labels)} "
            f"'{label_name}' entries"
        )

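Note that CHGNet outputs stress in units of GPa, whereas VASP reports raw stress in kbar; a factor of -10 converts CHGNet stress to the VASP raw unit.
If you are using stress labels from VASP, you don't need to do any unit conversion:
the `StructureData` dataset class takes in VASP units.
In this notebook the stress labels are read directly from the dataset file, so make sure they are already in VASP units.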


## 2. Define DataSet


In [4]:
from chgnet.data.dataset import StructureData, get_train_val_test_loader

In [8]:
# Seed the Python, NumPy and PyTorch RNGs before anything is shuffled so that
# a fresh "Restart & Run All" reproduces the same train/val/test split.
# Which of these generators the chgnet helpers draw from is not visible in
# this notebook, so all three are seeded.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Wrap the structures and their labels in CHGNet's dataset class.
dataset = StructureData(
    structures=structures,
    energies=energies_per_atom,
    forces=forces,
    stresses=stresses,  # can be None
    magmoms=magmoms,  # can be None
)

# 80% train / 10% validation; the remainder is left for the test loader.
train_loader, val_loader, test_loader = get_train_val_test_loader(
    dataset, batch_size=8, train_ratio=0.8, val_ratio=0.1
)

# Record which dataset entries ended up in each split so that later analysis
# can refer to exactly the same partition.
train_idx = np.array(train_loader.sampler.indices)
val_idx = np.array(val_loader.sampler.indices)
test_idx = np.array(test_loader.sampler.indices)

np.savez(
    "dhh_split_indices.npz",
    train_idx=train_idx,
    val_idx=val_idx,
    test_idx=test_idx,
)

StructureData imported 297 structures


The training set is used to optimize CHGNet through gradient descent; the validation set is used to monitor the validation error at the end of each epoch; and the test set is used to report the final test error at the end of training. The test set is optional.

The `batch_size` is set to 8 to suit GPUs with little memory. If more than 10 GB of memory is available, we highly recommend increasing `batch_size` for better speed.

If you have a very large number (>100K) of structures (which is typical for AIMD), putting them all in a Python list can quickly run into memory issues. In this case we highly recommend pre-converting all the structures into graphs and saving them as shown in `examples/make_graphs.py`. You can then train CHGNet by loading the graphs from disk instead of from memory, using the `GraphData` class defined in `data/dataset.py`.


## 3. Define model and trainer


In [9]:
from chgnet.model import CHGNet
from chgnet.trainer import Trainer

# Start from the published pretrained weights; the "r2scan" checkpoint is
# the one selected for fine-tuning here.
chgnet = CHGNet.load(
    model_name="r2scan",
)

CHGNet vr2scan initialized with 412,525 parameters
CHGNet will run on cpu


It's optional to freeze the weights inside some layers. This is a common technique to retain the learned knowledge during fine-tuning in large pretrained neural networks. You can choose the layers you want to freeze.


In [None]:
# Reset first: make every weight trainable so that re-running this cell
# always yields the same frozen/trainable partition.
for weight in chgnet.parameters():
    weight.requires_grad = True

# Sub-modules whose weights stay fixed during fine-tuning: the embeddings,
# basis expansions, bond weights, bond/angle layers, and every atom
# convolution layer except the last one.
frozen_modules = [
    chgnet.atom_embedding,
    chgnet.bond_basis_expansion,
    chgnet.bond_embedding,
    chgnet.bond_weights_ag,
    chgnet.bond_weights_bg,
    chgnet.angle_basis_expansion,
    chgnet.angle_embedding,
    chgnet.bond_conv_layers,
    chgnet.angle_layers,
    *chgnet.atom_conv_layers[:-1],
]
for module in frozen_modules:
    for weight in module.parameters():
        weight.requires_grad = False

# Report the names of the parameters that remain trainable.
trainable_names = [
    name for name, weight in chgnet.named_parameters() if weight.requires_grad
]
for name in trainable_names:
    print(name)

# Summarize how much of the network will actually be updated.
total = 0
total_trainable = 0
for weight in chgnet.parameters():
    count = weight.numel()
    total += count
    if weight.requires_grad:
        total_trainable += count
print("trainable:", total_trainable, "of", total, f"({100*total_trainable/total:.2f}%)")


=== ОБУЧАЕМЫЕ ПАРАМЕТРЫ ===
composition_model.fc.weight
atom_conv_layers.3.twoBody_atom.mlp_core.layers.0.weight
atom_conv_layers.3.twoBody_atom.mlp_core.layers.0.bias
atom_conv_layers.3.twoBody_atom.mlp_core.layers.3.weight
atom_conv_layers.3.twoBody_atom.mlp_core.layers.3.bias
atom_conv_layers.3.twoBody_atom.mlp_gate.layers.0.weight
atom_conv_layers.3.twoBody_atom.mlp_gate.layers.0.bias
atom_conv_layers.3.twoBody_atom.mlp_gate.layers.3.weight
atom_conv_layers.3.twoBody_atom.mlp_gate.layers.3.bias
atom_conv_layers.3.twoBody_atom.bn1.weight
atom_conv_layers.3.twoBody_atom.bn1.bias
atom_conv_layers.3.twoBody_atom.bn2.weight
atom_conv_layers.3.twoBody_atom.bn2.bias
atom_conv_layers.3.mlp_out.layers.1.weight
site_wise.weight
site_wise.bias
readout_norm.weight
readout_norm.bias
mlp.layers.0.weight
mlp.layers.0.bias
mlp.layers.2.weight
mlp.layers.2.bias
mlp.layers.4.weight
mlp.layers.4.bias
mlp.layers.7.weight
mlp.layers.7.bias
trainable: 50208 of 412525 (12.17%)


In [20]:
# Hyperparameters for the fine-tuning run, gathered in one place so they are
# easy to inspect and tweak before the Trainer is built.
trainer_settings = {
    "targets": "efs",  # the e / f / s columns reported in the training log
    "optimizer": "Adam",
    "scheduler": "CosLR",
    "criterion": "MSE",
    "epochs": 50,
    "learning_rate": 3e-4,
    "use_device": "cpu",
    "print_freq": 6,
}

# Build the Trainer around the (partially frozen) pretrained model.
trainer = Trainer(model=chgnet, **trainer_settings)

## 4. Start training


In [21]:
# Run the fine-tuning loop, passing the train, validation and test loaders
# created earlier in the notebook.
trainer.train(train_loader, val_loader, test_loader)

Begin Training: using cpu device
training targets: efs
Epoch: [0][1/30] | Time (2.521)(0.001) | Loss 0.5004(0.5004) | MAE e 0.313(0.313)  f 0.101(0.101)  s 0.767(0.767)  
Epoch: [0][6/30] | Time (2.274)(0.000) | Loss 0.4899(0.3868) | MAE e 0.240(0.277)  f 0.115(0.096)  s 0.681(0.619)  
Epoch: [0][12/30] | Time (2.437)(0.001) | Loss 0.7838(0.5419) | MAE e 0.380(0.308)  f 0.190(0.104)  s 0.892(0.743)  
Epoch: [0][18/30] | Time (2.500)(0.000) | Loss 0.2320(0.5225) | MAE e 0.262(0.297)  f 0.050(0.098)  s 0.520(0.747)  
Epoch: [0][24/30] | Time (2.590)(0.000) | Loss 0.2085(0.5820) | MAE e 0.200(0.291)  f 0.053(0.100)  s 0.491(0.789)  
Epoch: [0][30/30] | Time (2.535)(0.000) | Loss 0.2623(0.5541) | MAE e 0.185(0.277)  f 0.091(0.094)  s 0.618(0.771)  
*   e_MAE (0.236) 	f_MAE (0.075) 	s_MAE (0.713) 	
Epoch: [1][1/30] | Time (2.651)(0.001) | Loss 0.2644(0.2644) | MAE e 0.233(0.233)  f 0.059(0.059)  s 0.514(0.514)  
Epoch: [1][6/30] | Time (2.622)(0.001) | Loss 0.5513(0.3572) | MAE e 0.323(0.24

After training, the trained model can be found in a directory named after today's date. It can also be accessed directly from the trainer object:


In [22]:
# `model` is the network as held by the trainer; `best_model` is the best
# model based on validation energy MAE.
model, best_model = trainer.model, trainer.best_model