LOADING CHGNet

In [1]:
import numpy as np
from pymatgen.core import Structure

from chgnet.model import CHGNet

# If the above line fails in Google Colab due to numpy version issue,
# please restart the runtime, and the problem will be solved

np.set_printoptions(precision=4, suppress=True)

chgnet = CHGNet.load()
chgnet.eval()
# Alternatively you can read your own model
# chgnet = CHGNet.from_file(model_path)

CHGNet v0.3.0 initialized with 412,525 parameters
CHGNet will run on cpu


CHGNet(
  (composition_model): AtomRef(
    (fc): Linear(in_features=94, out_features=1, bias=False)
  )
  (graph_converter): CrystalGraphConverter(algorithm='fast', atom_graph_cutoff=6, bond_graph_cutoff=3)
  (atom_embedding): AtomEmbedding(
    (embedding): Embedding(94, 64)
  )
  (bond_basis_expansion): BondEncoder(
    (rbf_expansion_ag): RadialBessel(
      (smooth_cutoff): CutoffPolynomial()
    )
    (rbf_expansion_bg): RadialBessel(
      (smooth_cutoff): CutoffPolynomial()
    )
  )
  (bond_embedding): Linear(in_features=31, out_features=64, bias=False)
  (bond_weights_ag): Linear(in_features=31, out_features=64, bias=False)
  (bond_weights_bg): Linear(in_features=31, out_features=64, bias=False)
  (angle_basis_expansion): AngleEncoder(
    (fourier_expansion): Fourier()
  )
  (angle_embedding): Linear(in_features=31, out_features=64, bias=False)
  (atom_conv_layers): ModuleList(
    (0-3): 4 x AtomConv(
      (activation): SiLU()
      (twoBody_atom): GatedMLP(
        (mlp_c

Predicting Magnetic Moments Using CHGNet

In [3]:
import csv
import re
import numpy as np
from tqdm import tqdm

from pymatgen.core import Structure, Lattice
from chgnet.model.model import CHGNet


# ============================================================
# CIF string → pymatgen Structure
# ============================================================
def structure_from_str(cif_string):
    lines = cif_string.strip().split("\n")

    # lattice parameters (fixed format in your CSV)
    a, b, c = map(float, lines[2].split()[2:5])
    alpha, beta, gamma = map(float, lines[3].split()[1:4])

    species, coords = [], []
    reading_atoms = False

    for line in lines:
        line = line.strip()

        # atom line: index element x y z
        if re.match(r"^\d+\s+[A-Za-z]+\s+[-\d\.Ee+]+\s+[-\d\.Ee+]+\s+[-\d\.Ee+]+", line):
            reading_atoms = True

        if not reading_atoms:
            continue

        parts = line.split()
        if len(parts) < 5:
            continue

        species.append(parts[1])
        coords.append(list(map(float, parts[2:5])))

    lattice = Lattice.from_parameters(a, b, c, alpha, beta, gamma)
    return Structure(lattice, species, coords)


# ============================================================
# Load pretrained CHGNet (v0.3.0)
# ============================================================
model = CHGNet.load()
model.eval()


# ============================================================
# Files
# ============================================================
INPUT_CSV = "my_new_data1/co_fm.csv"
OUTPUT_CSV = "my_new_data1/co_fm_chgnet_magmom.csv"


# ============================================================
# Run inference
# ============================================================
results = []

with open(INPUT_CSV, encoding="utf-8-sig") as f:
    rows = list(csv.reader(f))

for idx, row in tqdm(enumerate(rows), total=len(rows)):
    try:
        cif_id = row[0].strip().lstrip("\ufeff")
        cif_string = row[1]

        structure = structure_from_str(cif_string)

        # em = energy + magnetic moments
        pred = model.predict_structure(structure, task="em")

        site_magmoms = np.asarray(pred["m"], dtype=float)  # <-- CORRECT KEY
        total_magmom = float(site_magmoms.sum())

        results.append([
            cif_id,
            total_magmom,
            site_magmoms.tolist()
        ])

    except Exception as e:
        print(f"[ERROR] row {idx} ({row[0]}): {e}")
        results.append([row[0], None, None])


# ============================================================
# Save output
# ============================================================
with open(OUTPUT_CSV, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow([
        "cif_id",
        "chgnet_total_magmom_muB",
        "chgnet_site_magmoms_muB"
    ])
    writer.writerows(results)

print(f"\n✓ DONE. Saved → {OUTPUT_CSV}")


CHGNet v0.3.0 initialized with 412,525 parameters
CHGNet will run on cpu


100%|██████████| 2944/2944 [02:37<00:00, 18.67it/s]


✓ DONE. Saved → my_new_data1/co_fm_chgnet_magmom.csv





Extracting Magnetization from Material Project (MP)

In [4]:
from pymatgen.ext.matproj import MPRester
import pandas as pd

df1 = pd.read_csv("my_new_data1/co_fm_chgnet_magmom.csv")
material_ids = df1.iloc[:, 0].tolist()

api_key = "h8IvqxfS25TD7LzN"
m = MPRester(api_key)

# Define properties to fetch
properties = [
    "task_id", "total_magnetization", "volume"
    #, "total_magnetization", "structure",
    #"spacegroup.crystal_system", "nsites", "spacegroup.number",
    #"magnetism.ordering", "formation_energy_per_atom"
]

# Fetch data for specific material_ids
data_s = m.query(criteria={"task_id": {"$in": material_ids}}, properties=properties)

# Convert to DataFrame
df1 = pd.DataFrame(data_s)

# Save results
df1.to_csv("MP_co_fm_chgnet_magmom.csv", index=False)
print(f"Fetched {df1.shape[0]} materials. Saved to 'saved'")

100%|██████████| 2667/2667 [00:04<00:00, 601.82it/s]

Fetched 2667 materials. Saved to 'saved'





Comparing Compounds Total Magnetization from CHGNet(Summed up magnetic moments) and Total Magnetization from MP

In [None]:
# ============================================================
# STEP 1–8: Load, merge, normalize, and evaluate CHGNet vs MP
# ============================================================

import pandas as pd
import numpy as np
from sklearn.metrics import r2_score
from sklearn.metrics import mean_absolute_error
from scipy.stats import pearsonr

# -------------------------------
# STEP 1 — Load CSV files
# -------------------------------
chgnet_df = pd.read_csv("co_fm_chgnet_magmom.csv")
mp_df = pd.read_csv("MP_co_fm_chgnet_magmom.csv")

# -------------------------------
# STEP 2 — Standardize MP IDs
# -------------------------------
# CHGNet file: first column is mp-id (may contain BOM)
chgnet_df.rename(columns={chgnet_df.columns[0]: "mp_id"}, inplace=True)
chgnet_df["mp_id"] = chgnet_df["mp_id"].astype(str).str.strip().str.lstrip("\ufeff")

# MP file
mp_df.rename(columns={"task_id": "mp_id"}, inplace=True)
mp_df["mp_id"] = mp_df["mp_id"].astype(str).str.strip()

# -------------------------------
# STEP 3 — Keep required MP data
# -------------------------------
mp_df = mp_df[[
    "mp_id",
    "total_magnetization",
    "volume"
]].rename(columns={
    "total_magnetization": "mp_total_magmom_muB",
    "volume": "mp_volume_A3"
})

# -------------------------------
# STEP 4 — Inner join on mp_id
# -------------------------------
merged = pd.merge(
    chgnet_df,
    mp_df,
    on="mp_id",
    how="inner"
)

print(f"Merged samples: {len(merged)}")

# -------------------------------
# STEP 5 — Volume-normalized magnetization
# -------------------------------
merged["chgnet_magmom_muB_per_A3"] = (
    merged["chgnet_total_magmom_muB"] / merged["mp_volume_A3"]
)

merged["mp_magmom_muB_per_A3"] = (
    merged["mp_total_magmom_muB"] / merged["mp_volume_A3"]
)

# -------------------------------
# STEP 6 — Drop invalid entries
# -------------------------------
valid = merged.replace([np.inf, -np.inf], np.nan).dropna(
    subset=[
        "chgnet_magmom_muB_per_A3",
        "mp_magmom_muB_per_A3"
    ]
)

print(f"Valid samples: {len(valid)}")

# -------------------------------
# STEP 7 — Metrics
# -------------------------------
y_true = valid["mp_magmom_muB_per_A3"].values
y_pred = valid["chgnet_magmom_muB_per_A3"].values

r2 = r2_score(y_true, y_pred)
cc, _ = pearsonr(y_true, y_pred)

# -------------------------------
# STEP 8 — Report
# -------------------------------
print(f"R²  = {r2:.4f}")
print(f"CC  = {cc:.4f}")

mae = mean_absolute_error(y_true, y_pred)

print(f"MAE = {mae:.6f} μB/Å³")

