In [None]:
! unzip /content/drive/MyDrive/finetune_dataset.zip

Archive:  /content/drive/MyDrive/finetune_dataset.zip
   creating: finetune_dataset/
   creating: finetune_dataset/structure_1/
  inflating: finetune_dataset/structure_1/CONTCAR  
 extracting: finetune_dataset/structure_1/CHG  
  inflating: finetune_dataset/structure_1/vasprun.xml  
  inflating: finetune_dataset/structure_1/stdout.31977  
  inflating: finetune_dataset/structure_1/POSCAR  
  inflating: finetune_dataset/structure_1/INCAR  
  inflating: finetune_dataset/structure_1/PCDAT  
  inflating: finetune_dataset/structure_1/DOSCAR  
  inflating: finetune_dataset/structure_1/POTCAR  
  inflating: finetune_dataset/structure_1/KPOINTS  
  inflating: finetune_dataset/structure_1/sub.vasp  
  inflating: finetune_dataset/structure_1/EIGENVAL  
  inflating: finetune_dataset/structure_1/XDATCAR  
  inflating: finetune_dataset/structure_1/stderr.31977  
  inflating: finetune_dataset/structure_1/PROCAR  
  inflating: finetune_dataset/structure_1/IBZKPT  
  inflating: finetune_dataset/structu

# Installation

In [None]:
pip install chgnet

Collecting chgnet
  Downloading chgnet-0.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (16 kB)
Collecting ase>=3.23.0 (from chgnet)
  Downloading ase-3.26.0-py3-none-any.whl.metadata (4.1 kB)
Collecting nvidia-ml-py3>=7.352.0 (from chgnet)
  Downloading nvidia-ml-py3-7.352.0.tar.gz (19 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting pymatgen>=2024.9.10 (from chgnet)
  Downloading pymatgen-2025.6.14-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)
Collecting bibtexparser>=1.4.0 (from pymatgen>=2024.9.10->chgnet)
  Downloading bibtexparser-1.4.3.tar.gz (55 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m55.6/55.6 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting monty>=2025.1.9 (from pymatgen>=2024.9.10->chgnet)
  Downloading monty-2025.3.3-py3-none-any.whl.metadata (3.6 kB)
Collecting palettable>=3.3.3 (from pymatgen>=2024.9.10->c

In [None]:
pip install pymatgen==2023.9.10

Collecting pymatgen==2023.9.10
  Downloading pymatgen-2023.9.10.tar.gz (6.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.8/6.8 MB[0m [31m22.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting mp-api>=0.27.3 (from pymatgen==2023.9.10)
  Downloading mp_api-0.45.8-py3-none-any.whl.metadata (2.4 kB)
Collecting pybtex (from pymatgen==2023.9.10)
  Downloading pybtex-0.25.1-py2.py3-none-any.whl.metadata (2.2 kB)
Collecting maggma>=0.57.1 (from mp-api>=0.27.3->pymatgen==2023.9.10)
  Downloading maggma-0.72.0-py3-none-any.whl.metadata (11 kB)
Collecting emmet-core>=0.84.3rc6 (from mp-api>=0.27.3->pymatgen==2023.9.10)
  Downloading emmet_core-0.84.10rc2-py3-none-any.whl.metadata (2.9 kB)
Collecting latexcodec>=1.0.4 (from pybtex->pymatgen==2023.9.10)
  Downloading latexcodec-3.0.1-py3-none-any.whl.metada

In [1]:
from chgnet.model import CHGNet
import numpy as np
from pymatgen.core import Structure
import os
from chgnet.utils import read_json
from chgnet.data.dataset import StructureData

# Parse DFT outputs into CHGNet-readable formats

In [2]:
from chgnet.utils import parse_vasp_dir

# Convert each VASP run directory structure_<i> under parent_dir into
# training_data/chgnet_dataset_<i>.json via parse_vasp_dir.
parent_dir = "/content/finetune_dataset/"
os.makedirs("training_data", exist_ok=True)

# A run directory is only parsed when all of these files are present.
required_files = ["vasprun.xml", "OSZICAR"]

# Inclusive (first, last) index pairs of the structure_<i> directories.
ranges = [(0, 10)]

saved_path = "training_data"

for start, end in ranges:
    for i in range(start, end + 1):
        vasp_dir = f"structure_{i}"
        vasp_path = os.path.join(parent_dir, vasp_dir)

        if not os.path.isdir(vasp_path):
            print(f"{vasp_dir} is not a directory or does not exist.")
            continue

        missing_files = [
            file_name
            for file_name in required_files
            if not os.path.exists(os.path.join(vasp_path, file_name))
        ]
        if missing_files:
            print(f"Skipping {vasp_dir}: Missing required VASP files.")
            continue

        # Best effort: a directory that fails to parse is reported and skipped.
        try:
            dataset_dict = parse_vasp_dir(
                vasp_path,
                save_path=os.path.join(saved_path, f"chgnet_dataset_{i}.json"),
            )
        except Exception as err:
            print(f"Failed to convert {vasp_dir}: {err}")
        else:
            print(f"Successfully converted {vasp_dir}")

Successfully converted structure_0
Successfully converted structure_1
Successfully converted structure_2
Successfully converted structure_3
Successfully converted structure_4
Successfully converted structure_5
Successfully converted structure_6
Successfully converted structure_7
Successfully converted structure_8
Successfully converted structure_9
Successfully converted structure_10


In [3]:
# Load the per-directory CHGNet JSON files and merge them into flat,
# index-aligned label lists, then wrap them in a StructureData dataset.
directories_and_ranges = [
    {
        "directory": "training_data",
        "ranges": [(0, 10)]
    }
]

all_structures = []
all_energies = []
all_forces = []
all_stresses = []
all_magmoms = []

# Optional labels are only usable if EVERY loaded file provides one entry per
# structure; otherwise the merged lists would no longer line up by index with
# all_structures.
stress_in_every_file = True
magmom_in_every_file = True

for item in directories_and_ranges:
    parent_directory = item["directory"]
    ranges = item["ranges"]

    for start, end in ranges:
        for i in range(start, end + 1):
            json_path = os.path.join(parent_directory, f"chgnet_dataset_{i}.json")

            if not os.path.exists(json_path):
                print(f"No json file found at {json_path}")
                continue

            # Read and validate the whole file before touching the merged
            # lists, so a failure can never leave them partially extended.
            try:
                dataset_dict = read_json(json_path)
                print(f"Successfully read data from {json_path}")

                structures = [Structure.from_dict(struct) for struct in dataset_dict["structure"]]
                energies = dataset_dict["energy_per_atom"]
                forces = dataset_dict["force"]
                stresses = dataset_dict.get("stress") or None
                magmoms = dataset_dict.get("magmom") or None

                n_frames = len(structures)
                if len(energies) != n_frames or len(forces) != n_frames:
                    raise ValueError(
                        f"label count mismatch: {n_frames} structures, "
                        f"{len(energies)} energies, {len(forces)} forces"
                    )
                stress_ok = bool(stresses) and len(stresses) == n_frames
                magmom_ok = bool(magmoms) and len(magmoms) == n_frames
            except Exception as e:
                print(f"Failed to extract data from {json_path}: {e}")
                continue

            all_structures.extend(structures)
            all_energies.extend(energies)
            all_forces.extend(forces)

            if stress_ok:
                all_stresses.extend(stresses)
            else:
                stress_in_every_file = False

            if magmom_ok:
                all_magmoms.extend(magmoms)
            else:
                magmom_in_every_file = False

# Drop optional labels that are only partially available instead of passing
# misaligned lists to StructureData.
if all_stresses and not stress_in_every_file:
    print("Warning: stress labels missing or incomplete in some files; dropping all stresses.")
    all_stresses = []
if all_magmoms and not magmom_in_every_file:
    print("Warning: magmom labels missing or incomplete in some files; dropping all magmoms.")
    all_magmoms = []

print(f"Total structures extracted: {len(all_structures)}")

dataset = StructureData(
    structures=all_structures,
    energies=all_energies,
    forces=all_forces,
    stresses=all_stresses if all_stresses else None,
    magmoms=all_magmoms if all_magmoms else None
)

# Print the final dataset size
print(f"StructureData imported {len(dataset.structures)} structures")


Successfully read data from training_data/chgnet_dataset_0.json
Successfully read data from training_data/chgnet_dataset_1.json
Successfully read data from training_data/chgnet_dataset_2.json
Successfully read data from training_data/chgnet_dataset_3.json
Successfully read data from training_data/chgnet_dataset_4.json
Successfully read data from training_data/chgnet_dataset_5.json
Successfully read data from training_data/chgnet_dataset_6.json
Successfully read data from training_data/chgnet_dataset_7.json
Successfully read data from training_data/chgnet_dataset_8.json
Successfully read data from training_data/chgnet_dataset_9.json
Successfully read data from training_data/chgnet_dataset_10.json
Total structures extracted: 531
StructureData imported 531 structures
StructureData imported 531 structures


In [4]:
# Print the length of each merged list (structures, energies, forces,
# stresses, magmoms) so they can be compared at a glance.
for merged_list in (all_structures, all_energies, all_forces, all_stresses, all_magmoms):
    print(len(merged_list))

531
531
531
531
531


In [5]:
from pymatgen.entries.compatibility import MaterialsProject2020Compatibility
from pymatgen.entries.computed_entries import ComputedStructureEntry
from pymatgen.core.structure import Structure

# Per-atom energy corrections applied for every Mn and every O site.
# The constant names suggest the Mn-in-transition-metal-oxide and oxide terms
# of the MaterialsProject2020Compatibility scheme; values presumably in
# eV/atom. TODO confirm values and units.
Mn_correction_in_TMO = -1.668
oxide_correction = -0.687

corrected_energies = []

# Indices of the structures to correct (all of them).
ranges = list(range(0, len(all_structures)))

# Energies and structures are paired by index, so the lists must line up.
if len(all_energies) != len(all_structures):
    raise ValueError(
        f"Cannot pair energies with structures: {len(all_energies)} energies "
        f"vs {len(all_structures)} structures"
    )

for i in ranges:
    structure = all_structures[i]

    # all_energies holds energy per atom; scale back to the total energy.
    # The same index is used for structure and energy so the pairing stays
    # correct even if `ranges` is narrowed to a subset.
    vasp_raw_energy = all_energies[i] * structure.num_sites

    num_Mn = structure.composition.get("Mn", 0)
    num_O = structure.composition.get("O", 0)

    corrected_energy = (
        vasp_raw_energy + num_Mn * Mn_correction_in_TMO + num_O * oxide_correction
    )

    # Store the corrected energy per atom again.
    corrected_energies.append(corrected_energy / structure.num_sites)

# NOTE(review): `dataset` is built from `all_energies`, not from
# `corrected_energies`, in the cells visible here. Confirm which energies the
# fine-tuning is meant to use.
print(len(corrected_energies))

531


In [6]:
# Collect every (x, y, z) force row from all structures, then split the rows
# into one flat list per Cartesian component.
force_rows = [
    (fx, fy, fz)
    for structure_forces in all_forces
    for fx, fy, fz in structure_forces
]

all_force_x = [row[0] for row in force_rows]
all_force_y = [row[1] for row in force_rows]
all_force_z = [row[2] for row in force_rows]

print(f"all_force_x: {len(all_force_x)} values")
print(f"all_force_y: {len(all_force_y)} values")
print(f"all_force_z: {len(all_force_z)} values")

all_force_x: 23584 values
all_force_y: 23584 values
all_force_z: 23584 values


In [7]:
# Flatten every stress matrix, row by row, into a single list of components.
flattened_stresses = [
    component
    for stress_matrix in all_stresses
    for row in stress_matrix
    for component in row
]

print(f"Total number of elements: {len(flattened_stresses)}")

# Rescale every component by -0.1 (presumably kBar -> GPa with a sign flip;
# TODO confirm the unit convention).
flattened_stresses = [-0.1 * component for component in flattened_stresses]

Total number of elements: 4779


In [8]:
# Concatenate the magnetic-moment lists of all structures into one flat list.
flattened_magmoms = [
    moment
    for structure_magmoms in all_magmoms
    for moment in structure_magmoms
]

print(f"Total number of elements: {len(flattened_magmoms)}")

# Keep only the magnitude of each moment.
flattened_magmoms = list(map(abs, flattened_magmoms))

Total number of elements: 23584


# Finetuning

In [9]:
from chgnet.data.dataset import StructureData, get_train_val_test_loader
from chgnet.model import CHGNet
from chgnet.trainer import Trainer

# Start from the pretrained CHGNet weights.
chgnet = CHGNet.load()

# Build the train / validation / test loaders from `dataset`.
split_options = {"batch_size": 32, "train_ratio": 0.9, "val_ratio": 0.05}
train_loader, val_loader, test_loader = get_train_val_test_loader(
    dataset, **split_options
)

# Freeze the embeddings, basis expansions, bond/angle layers and all but the
# last atom convolution layer.
frozen_modules = [
    chgnet.atom_embedding,
    chgnet.bond_embedding,
    chgnet.angle_embedding,
    chgnet.bond_basis_expansion,
    chgnet.angle_basis_expansion,
    chgnet.atom_conv_layers[:-1],
    chgnet.bond_conv_layers,
    chgnet.angle_layers,
]
for module in frozen_modules:
    for weight in module.parameters():
        weight.requires_grad = False

# Trainer configuration, kept in one place so it is easy to tune.
trainer_options = {
    "targets": "efs",
    "optimizer": "Adam",
    "scheduler": "CosLR",
    "criterion": "MSE",
    "epochs": 5,
    "learning_rate": 1e-2,
    "use_device": "cpu",
    "print_freq": 1,
}
trainer = Trainer(model=chgnet, **trainer_options)

CHGNet v0.3.0 initialized with 412,525 parameters
CHGNet will run on cpu


In [None]:
# Run the fine-tuning loop on the three loaders; save_dir names the directory
# passed to the trainer for its outputs.
trainer.train(
    train_loader,
    val_loader,
    test_loader,
    save_dir="finetune_result",
)

Begin Training: using cpu device
training targets: efs
