# Fine-tune the pretrained CHGNet for better accuracy


In [None]:
from __future__ import annotations

# install CHGNet (only needed on Google Colab or if you didn't install CHGNet yet)
!pip install chgnet

In [None]:
import numpy as np
from pymatgen.core import Structure

from chgnet.model import CHGNet

# Load the pretrained CHGNet model; its predictions are used below to build
# the dummy fine-tuning labels.
chgnet = CHGNet.load()

CHGNet initialized with 400,438 parameters


## 1. Prepare Training Data


In [None]:
# Load the LiMnO2 example structure. First try the CIF located under the
# chgnet package ROOT; if that fails for any reason, download it from GitHub.
try:
    from chgnet import ROOT

    lmo = Structure.from_file(f"{ROOT}/examples/mp-18767-LiMnO2.cif")
except Exception:
    # Deliberately broad: this is a best-effort fallback for environments
    # (e.g. a pip install without the examples folder) where the local file
    # is missing or unreadable.
    from urllib.request import urlopen

    url = "https://raw.githubusercontent.com/CederGroupHub/chgnet/main/examples/mp-18767-LiMnO2.cif"
    # Use the response as a context manager so the connection is always
    # closed, even if reading or decoding fails.
    with urlopen(url) as response:
        cif = response.read().decode("utf-8")
    lmo = Structure.from_str(cif, fmt="cif")

We create a dummy fine-tuning dataset by adding random noise to CHGNet predictions.
To fine-tune on your own chemical system or AIMD data, modify the block below to load your own structures and labels.


In [None]:
# Accumulators for the dummy dataset: one entry per distorted structure.
structures, energies_per_atom, forces, stresses, magmoms = [], [], [], [], []

for _ in range(100):
    # Start from a fresh copy so the reference LiMnO2 structure stays untouched.
    structure = lmo.copy()

    # Randomly strain the cell, then displace every atom by a small amount.
    strain = np.random.uniform(-0.1, 0.1, size=3)
    structure.apply_strain(strain)
    structure.perturb(0.1)

    # Use the pretrained model's prediction as the "ground truth" to be noised.
    pred = chgnet.predict_structure(structure)
    structures.append(structure)

    # Energy label: prediction plus uniform noise.
    energy_noise = np.random.uniform(-0.1, 0.1, size=1)
    energies_per_atom.append(pred["e"] + energy_noise)

    # Force label: prediction plus element-wise uniform noise.
    force_noise = np.random.uniform(-0.01, 0.01, size=pred["f"].shape)
    forces.append(pred["f"] + force_noise)

    # Stress label: the factor -10 converts CHGNet's GPa output to the raw
    # VASP kbar convention before the noise is added.
    stress_noise = np.random.uniform(-0.05, 0.05, size=pred["s"].shape)
    stresses.append(pred["s"] * -10 + stress_noise)

    # Magnetic-moment label: prediction plus element-wise uniform noise.
    magmom_noise = np.random.uniform(-0.03, 0.03, size=pred["m"].shape)
    magmoms.append(pred["m"] + magmom_noise)

Note that CHGNet outputs stress in units of GPa; the factor of -10 converts it to
kbar in the raw VASP convention. We do this because, by default, the `StructureData`
dataset class expects VASP units.


## 2. Define DataSet


In [None]:
from chgnet.data.dataset import StructureData, get_train_val_test_loader

In [None]:
# Wrap the structures and their labels into a CHGNet dataset.
# `stresses` and `magmoms` are optional and may be passed as None.
dataset = StructureData(
    structures=structures,
    energies=energies_per_atom,
    forces=forces,
    stresses=stresses,
    magmoms=magmoms,
)

# Split into train / validation / test loaders. 90% of the data goes to
# training and 5% to validation.
split_settings = {"batch_size": 8, "train_ratio": 0.9, "val_ratio": 0.05}
train_loader, val_loader, test_loader = get_train_val_test_loader(
    dataset, **split_settings
)

100 structures imported


Here `batch_size` is set to 8 to accommodate GPUs with little memory. If more than 10 GB of memory is available, we highly recommend increasing `batch_size` for better speed.

If you have a very large number of structures (which is typical for AIMD), putting them all in a Python list can quickly run into memory issues. In this case we highly recommend pre-converting all the structures into graphs and saving them to disk, as shown in `examples/make_graphs.py`. You can then train CHGNet by loading the graphs from disk instead of from memory, using the `GraphData` class defined in `data/dataset.py`.


## 3. Define model and trainer


In [None]:
from chgnet.trainer import Trainer

# Load pretrained CHGNet
# (rebinds `chgnet` to a newly loaded model, which is the one that gets fine-tuned)
chgnet = CHGNet.load()

CHGNet initialized with 400,438 parameters


It's optional to freeze the weights inside some layers. This is a common technique to retain the learned knowledge during fine-tuning in large pretrained neural networks. You can choose the layers you want to freeze.


In [None]:
# Optionally fix the weights of some layers
for layer in [
    chgnet.atom_embedding,
    chgnet.bond_embedding,
    chgnet.angle_embedding,
    chgnet.bond_basis_expansion,
    chgnet.angle_basis_expansion,
    chgnet.atom_conv_layers[:-1],
    chgnet.bond_conv_layers,
    chgnet.angle_layers,
]:
    for param in layer.parameters():
        param.requires_grad = False

In [None]:
# Define Trainer
trainer = Trainer(
    model=chgnet,
    targets="efsm",
    optimizer="Adam",
    scheduler="CosLR",
    criterion="MSE",
    epochs=5,
    learning_rate=0,
    use_device="cpu",
    print_freq=6,
)

## 4. Start training


In [None]:
# Run the fine-tuning loop on the three data splits created earlier.
trainer.train(
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
)

Begin Training: using cpu device
training targets: efsm
Epoch: [0][1/12]	Time (0.353)  Data (0.012)  Loss 0.0041 (0.0041)  MAEs:  e 0.060 (0.060)  f 0.005 (0.005)  s 0.002 (0.002)  m 0.014 (0.014)  
Epoch: [0][6/12]	Time (0.401)  Data (0.011)  Loss 0.0048 (0.0043)  MAEs:  e 0.059 (0.059)  f 0.005 (0.005)  s 0.003 (0.003)  m 0.014 (0.014)  
Epoch: [0][12/12]	Time (0.386)  Data (0.015)  Loss 0.0015 (0.0041)  MAEs:  e 0.035 (0.056)  f 0.005 (0.005)  s 0.002 (0.002)  m 0.012 (0.014)  
*   e_MAE (0.060) 	f_MAE (0.005) 	s_MAE (0.002) 	m_MAE (0.013) 	
Epoch: [1][1/12]	Time (0.357)  Data (0.000)  Loss 0.0025 (0.0025)  MAEs:  e 0.043 (0.043)  f 0.005 (0.005)  s 0.002 (0.002)  m 0.015 (0.015)  
Epoch: [1][6/12]	Time (0.357)  Data (0.000)  Loss 0.0042 (0.0042)  MAEs:  e 0.055 (0.058)  f 0.005 (0.005)  s 0.003 (0.003)  m 0.012 (0.014)  
Epoch: [1][12/12]	Time (0.329)  Data (0.000)  Loss 0.0005 (0.0041)  MAEs:  e 0.017 (0.056)  f 0.005 (0.005)  s 0.003 (0.002)  m 0.015 (0.014)  
*   e_MAE (0.060) 	

After training, the trained model is saved in a directory named after today's date. It can also be accessed directly from the trainer:


In [None]:
# `model` is the model held by the trainer after training;
# `best_model` is the best model based on validation energy MAE.
model, best_model = trainer.model, trainer.best_model

## Extras 1: GGA / GGA+U compatibility


### Q: Why and when do you care about this?

**When**: If you want to fine-tune the pretrained CHGNet with your own GGA+U VASP calculations and you want to keep your VASP energies compatible with the pretraining dataset. If your dataset is so large that the pretrained knowledge does not matter to you, you can ignore this.

**Why**: CHGNet is trained on both GGA and GGA+U calculations from the Materials Project. Methods have been developed to make GGA and GGA+U energies compatible, so that the energies can be used universally for cross-chemistry comparison and phase-diagram construction. Please refer to:

https://journals.aps.org/prb/abstract/10.1103/PhysRevB.84.045115

Below we show an example to apply the compatibility.


In [None]:
# Imagine this is the VASP raw energy
vasp_raw_energy = -58.97

print(f"The raw total energy from VASP of LMO is: {vasp_raw_energy} eV")

The raw total energy from VASP of LMO is: -58.97 eV


You can look for the energy correction applied to each element in :

https://github.com/materialsproject/pymatgen/blob/v2023.2.28/pymatgen/entries/MP2020Compatibility.yaml

Here, two corrections apply to LiMnO2: the correction for Mn in transition-metal oxides and the oxide correction.


To demystify `MaterialsProject2020Compatibility`, basically all that's happening is:


In [None]:
# Per-atom energy corrections (eV/atom) applied by hand here.
Mn_correction_in_TMO = -1.668
oxide_correction = -0.687

# Look the atom counts up by element symbol instead of unpacking
# `composition.values()` positionally, which silently gives wrong numbers
# if the elements are stored in a different order than Li, Mn, O.
num_Mn = lmo.composition["Mn"]
num_O = lmo.composition["O"]

# Corrected energy = raw energy + (count x correction) for each corrected species.
corrected_energy = (
    vasp_raw_energy + num_Mn * Mn_correction_in_TMO + num_O * oxide_correction
)
print(f"The corrected total energy after MP2020 = {corrected_energy:.4} eV")

The corrected total energy after MP2020 = -65.05 eV


You can also apply the `MaterialsProject2020Compatibility` through pymatgen


In [None]:
from pymatgen.entries.compatibility import MaterialsProject2020Compatibility
from pymatgen.entries.computed_entries import ComputedStructureEntry, ComputedEntry

# Calculation metadata handed to the compatibility scheme: the Hubbard U
# value per element and the run type.
hubbard_u = {"Mn": 3.9, "O": 0, "Li": 0}
params = {"hubbards": hubbard_u, "run_type": "GGA+U"}

# Bundle the structure and its raw VASP energy into a pymatgen entry.
cse = ComputedStructureEntry(
    structure=lmo,
    energy=vasp_raw_energy,
    parameters=params,
)

# Apply the MP2020 correction scheme to the entry; POTCAR checking is disabled.
mp2020_compat = MaterialsProject2020Compatibility(check_potcar=False)
mp2020_compat.process_entries(cse)

print(
    f"The total energy of LMO after MP2020Compatibility correction = {cse.energy:.4} eV"
)

The total energy of LMO after MP2020Compatibility correction = -65.05 eV


Now use this corrected energy as labels to tune CHGNet, you're good to go!


## Extras 2: AtomRef


If you want to fine-tune CHGNet to DFT labels that are even less compatible with the Materials Project, such as the r2SCAN functional or other DFT codes like Gaussian or QE, more tricks are needed to retain as much as possible of the information learned during pretraining.

For example, formation energy can be a well-compatible property across different functionals. In CHGNet, we use an Atom_Ref operation, which is a formation-energy-like calculation of the per-element contribution to the total energy.

When fine-tuning to other functionals that might have large discrepancies in elemental energies, we recommend refitting the AtomRef, so that fine-tuning of the graph layers can focus on the energy contribution from atom-atom interactions instead of on meaningless atom reference energies.

Below is an example of how to refit the AtomRef layer:


In [None]:
# Show the reference-energy parameters of the pretrained composition model.
print("The pretrained Atom_Ref (per atom reference energy):")
for atom_ref_weight in chgnet.composition_model.parameters():
    print(atom_ref_weight)

The pretrained Atom_Ref (per atom reference energy):
Parameter containing:
tensor([[ -3.4431,  -0.1279,  -2.8300,  -3.4737,  -7.4946,  -8.2354,  -8.1611,
          -8.3861,  -5.7498,  -0.0236,  -1.7406,  -1.6788,  -4.2833,  -6.2002,
          -6.1315,  -5.8405,  -3.8795,  -0.0703,  -1.5668,  -3.4451,  -7.0549,
          -9.1465,  -9.2594,  -9.3514,  -8.9843,  -8.0228,  -6.4955,  -5.6057,
          -3.4002,  -0.9217,  -3.2499,  -4.9164,  -4.7810,  -5.0191,  -3.3316,
           0.5130,  -1.4043,  -3.2175,  -7.4994,  -9.3816, -10.4386,  -9.9539,
          -7.9555,  -8.5440,  -7.3245,  -5.2771,  -1.9014,  -0.4034,  -2.6002,
          -4.0054,  -4.1156,  -3.9928,  -2.7003,   2.2170,  -1.9671,  -3.7180,
          -6.8133,  -7.3502,  -6.0712,  -6.1699,  -5.1471,  -6.1925, -11.5829,
         -15.8841,  -5.9994,  -6.0798,  -5.9513,  -6.0400,  -5.9773,  -2.5091,
          -6.0767, -10.6666, -11.8761, -11.8491, -10.7397,  -9.6100,  -8.4755,
          -6.2070,  -3.0337,   0.4726,  -1.6425,  -3.129

In [None]:
# A list of structures / graphs
structures = [
    lmo,
    Structure(
        species=["Li", "Mn", "Mn", "O", "O", "O"],
        lattice=np.random.rand(3, 3),
        coords=np.random.rand(6, 3),
    ),
    Structure(
        species=["Li", "Li", "Mn", "O", "O", "O"],
        lattice=np.random.rand(3, 3),
        coords=np.random.rand(6, 3),
    ),
    Structure(
        species=["Li", "Mn", "Mn", "O", "O", "O", "O"],
        lattice=np.random.rand(3, 3),
        coords=np.random.rand(7, 3),
    ),
]

# A list of energy_per_atom values (random values here)
energies_per_atom = [5.5, 6, 4.8, 5.6]

In [None]:
from chgnet.model.composition_model import AtomRef

# Build a second AtomRef layer and initialize it via `initialize_from_MPtrj`,
# then print the first three entries of each parameter as a quick check.
print("We initialize another identical AtomRef layers")
new_atom_ref = AtomRef(is_intensive=True)
new_atom_ref.initialize_from_MPtrj()
for ref_weight in new_atom_ref.parameters():
    leading_entries = ref_weight[:, :3]
    print(leading_entries)

We initialize another identical AtomRef layers
tensor([[-3.4431, -0.1279, -2.8300]], grad_fn=<SliceBackward0>)


In [None]:
# Refit the reference energies to the example structures and their
# energy-per-atom labels, then print the resulting parameters.
new_atom_ref.fit(structures, energies_per_atom)

print("After refitting, the AtomRef looks like:")
for refit_weight in new_atom_ref.parameters():
    print(refit_weight)

After refitting, the AtomRef looks like:
Parameter containing:
tensor([[ 0.0000e+00,  0.0000e+00,  4.2667e+00, -3.3299e-15,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  2.9999e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.1467e+01,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  