# Fine-tune the pretrained CHGNet for better accuracy


In [9]:
import random
import numpy as np
import torch
from pymatgen.core import Structure
from chgnet.model import CHGNet

## 1. Prepare Training Data


In [None]:
from collections import defaultdict
from chgnet.utils import read_json
from pymatgen.core.structure import Structure

# JSON files to merge into a single training set. Each file is read as a dict
# keyed by label name; the keys used below are "structure", "energy_per_atom",
# "force" and (optionally) "stress".
sources = ["chgnet_dataset_dhh_I-42d.json", "chgnet_dataset_dhh_with_stresses.json"]

merged = defaultdict(list)

for source_path in sources:
    source_dict = read_json(source_path)
    for key, value in source_dict.items():
        # List values are concatenated across files; any other value is
        # treated as a single entry and appended.
        if isinstance(value, list):
            merged[key].extend(value)
        else:
            merged[key].append(value)

dataset_dict = dict(merged)

structures = [Structure.from_dict(struct) for struct in dataset_dict["structure"]]
energies_per_atom = dataset_dict["energy_per_atom"]
forces = dataset_dict["force"]
# A missing or empty "stress" entry collapses to None.
stresses = dataset_dict.get("stress") or None
magmoms = None

# Keys are merged independently, so a source file that lacks one of the labels
# (for example "stress") would leave that label list shorter than `structures`
# and silently pair labels with the wrong structures. Fail loudly instead.
n_structures = len(structures)
for label_name, labels in (
    ("energy_per_atom", energies_per_atom),
    ("force", forces),
    ("stress", stresses),
):
    if labels is not None and len(labels) != n_structures:
        raise ValueError(
            f"Label '{label_name}' has {len(labels)} entries but there are "
            f"{n_structures} structures; check that every file in `sources` "
            "provides this label for each of its structures."
        )

print(len(structures))

297


Note that the stress output from CHGNet is in units of GPa; multiplying it by -10
converts it to kbar, the raw unit used by VASP.
If you're using stress labels from VASP, you don't need to do any unit conversion:
the StructureData dataset class takes VASP units.


## 2. Define DataSet


In [None]:
from chgnet.data.dataset import StructureData, get_train_val_test_loader

In [None]:
# Wrap the structures and their labels in a dataset object.
dataset = StructureData(
    structures=structures,
    energies=energies_per_atom,
    forces=forces,
    stresses=stresses,  # can be None
    magmoms=magmoms,  # can be None
)

# Split into train / validation / test loaders.
loaders = get_train_val_test_loader(
    dataset, batch_size=8, train_ratio=0.8, val_ratio=0.1
)
train_loader, val_loader, test_loader = loaders

# Record which dataset indices each loader's sampler holds, so the split can
# be looked up again later.
train_idx, val_idx, test_idx = (
    np.array(loader.sampler.indices) for loader in loaders
)

np.savez(
    "dhh_split_indices.npz", train_idx=train_idx, val_idx=val_idx, test_idx=test_idx
)

The training set is used to optimize CHGNet through gradient descent, the validation set is used to monitor the validation error at the end of each epoch, and the test set is used to report the final test error at the end of training. The test set is optional.

The `batch_size` is set to 8 to accommodate small GPU memory. If more than 10 GB of memory is available, we highly recommend increasing `batch_size` for better speed.

If you have a very large number (>100K) of structures (which is typical for AIMD), putting them all in a Python list can quickly run into memory issues. In this case we highly recommend pre-converting all the structures into graphs and saving them as shown in `examples/make_graphs.py`. Then train CHGNet directly by loading the graphs from disk instead of memory, using the `GraphData` class defined in `data/dataset.py`.


## 3. Define model and trainer


In [None]:
from chgnet.model import CHGNet
from chgnet.trainer import Trainer

# Load the pretrained CHGNet checkpoint registered under the name "r2scan";
# this is the starting point for fine-tuning.
chgnet = CHGNet.load(model_name="r2scan")

It's optional to freeze the weights inside some layers. This is a common technique to retain the learned knowledge during fine-tuning in large pretrained neural networks. You can choose the layers you want to freeze.


In [None]:
# Start from a fully trainable model so that the freeze list below is the only
# thing deciding which parameters stay fixed.
for weight in chgnet.parameters():
    weight.requires_grad = True

# Modules whose parameters are held fixed during fine-tuning: the embeddings,
# basis expansions, bond weights, bond/angle layers, and every atom
# convolution layer except the last one.
frozen_modules = [
    chgnet.atom_embedding,
    chgnet.bond_basis_expansion,
    chgnet.bond_embedding,
    chgnet.bond_weights_ag,
    chgnet.bond_weights_bg,
    chgnet.angle_basis_expansion,
    chgnet.angle_embedding,
    chgnet.bond_conv_layers,
    chgnet.angle_layers,
    *chgnet.atom_conv_layers[:-1],
]
for module in frozen_modules:
    for weight in module.parameters():
        weight.requires_grad = False

# Report what is still trainable: first the parameter names, then the counts.
trainable = [
    (name, weight)
    for name, weight in chgnet.named_parameters()
    if weight.requires_grad
]
for name, _ in trainable:
    print(name)

total_trainable = sum(weight.numel() for _, weight in trainable)
total = sum(weight.numel() for weight in chgnet.parameters())
print("trainable:", total_trainable, "of", total, f"({100*total_trainable/total:.2f}%)")


In [None]:
# Define Trainer
trainer = Trainer(
    model=chgnet,
    # What to fit: "e", "f" and "s" are the energy, force and stress targets.
    targets="efs",
    allow_missing_labels=False,
    # Loss: equal weight on each of the three targets.
    criterion="MSE",
    energy_loss_ratio=1,
    force_loss_ratio=1,
    stress_loss_ratio=1,
    # Optimization schedule.
    optimizer="SGD",
    scheduler="CosLR",
    learning_rate=5e-4,
    epochs=30,
    # Runtime settings.
    use_device="cpu",
    print_freq=10,
)

## 4. Start training


In [None]:
# Run the fine-tuning loop on the three loaders built earlier.
trainer.train(
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
)

In [None]:
import json

# Persist the trainer's per-epoch history as JSON so it can be re-plotted
# without re-running training.
history = trainer.training_history
history_json = json.dumps(history)

with open("dhh_mae_history.json", "w") as history_file:
    history_file.write(history_json)

In [None]:
import matplotlib.pyplot as plt

# One stacked subplot per target, each showing the train and validation curves
# stored in `history`.
targets = ["e", "f", "s"]
titles = {"e": "Energy MAE", "f": "Forces MAE", "s": "Stresses MAE"}

# The epoch axis is taken from the length of the first target's train curve.
n_epochs = len(history[targets[0]]["train"])
epochs = np.arange(n_epochs)

# squeeze=False always returns a 2-D array of axes, so flattening it gives a
# 1-D sequence even when there is only a single target.
fig, axes = plt.subplots(
    nrows=len(targets),
    ncols=1,
    figsize=(7, 3.6 * len(targets)),
    sharex=True,
    squeeze=False,
)
axes = axes.ravel()

for ax, t in zip(axes, targets):
    target_history = history[t]
    for split, marker in (("train", "o"), ("val", "s")):
        ax.plot(epochs, target_history[split], marker=marker, label=split)
    ax.set_title(titles.get(t, f"{t.upper()} MAE"))
    ax.set_ylabel(f"{t.upper()} MAE")
    ax.grid(True, alpha=0.3)
    ax.legend()

# The x axis is shared, so only the bottom panel needs a label.
axes[-1].set_xlabel("Epoch")
fig.tight_layout()
fig.savefig("mae_all_targets_vs_epoch.png", dpi=200)
plt.show()

After training, the trained model can be found in a directory named after today's date. Alternatively, it can be accessed directly from the trainer:


In [None]:
# `model` is the trainer's current model; `best_model` is the best model based
# on validation energy MAE.
model, best_model = trainer.model, trainer.best_model