In [None]:
# %pip install chgnet
# %pip install matgl==1.2.6
# %pip install matpes==0.0.3

In [None]:
# %load ../src/data_preprocessing.py
import numpy as np
import matplotlib.pyplot as plt
import json 
import os
import sys
from pymatgen.core.structure import Structure
from chgnet.data.dataset import StructureData, get_train_val_test_loader
import torch
from typing import List, Optional, Union

class Filter:
    """
    Filter a dataset of material records by composition, symmetry, band gap and material type.

    Every record is a dict read with ``dict.get``. The keys consulted are "nelements",
    "elements", "symmetry" (a dict holding "crystal_system") and "bandgap". A key that is
    missing or explicitly set to None is tolerated rather than raising.

    Attributes:
        data (List[dict]): A list of dictionaries containing material data.

    Methods:
        filter: Filters the dataset based on specified criteria.
        _filter_by_material_type: Filters the dataset by material type.

    Example:
        data = [
            {"elements": ["Fe", "O"], "nelements": 2, "bandgap": 1.5, "symmetry": {"crystal_system": "cubic"}},
            {"elements": ["Si"], "nelements": 1, "bandgap": 1.1, "symmetry": {"crystal_system": "cubic"}},
            {"elements": ["Na", "Cl"], "nelements": 2, "bandgap": 0.0, "symmetry": {"crystal_system": "cubic"}},
        ]
        material_filter = Filter(data)
        filtered_data = material_filter.filter(n_elements=2, contains_elements=["Fe"], bandgap_min=1.0)
    """

    # Element groups used by _filter_by_material_type.
    _TRANSITION_METALS = frozenset({
        'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn',
        'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd',
        'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg',
    })
    _HALOGENS = frozenset({"F", "Cl", "Br", "I"})
    # NOTE(review): noble gases and metalloids (Si, Ge, As, ...) are not listed here, so
    # records containing them still count as "intermetallic" — confirm this is intended.
    _NON_METALS = frozenset({"H", "B", "C", "N", "O", "F", "P", "S", "Cl", "Se", "Br", "I", "At", "Ts"})

    def __init__(self, data: List[dict]):
        self.data = data

    @staticmethod
    def _elements_of(record: dict) -> set:
        """Return the record's "elements" as a set; empty when the key is missing or None."""
        return set(record.get("elements") or [])

    @staticmethod
    def _bandgap_of(record: dict) -> float:
        """Return the record's "bandgap", treating a missing key or None as 0.0."""
        value = record.get("bandgap")
        return 0.0 if value is None else value

    @staticmethod
    def _crystal_system_of(record: dict) -> Optional[str]:
        """Return the record's crystal system lower-cased, or None when it is not available."""
        symmetry = record.get("symmetry") or {}
        value = symmetry.get("crystal_system")
        return value.lower() if isinstance(value, str) else value

    def filter(
        self,
        n_elements: Optional[int] = None,
        contains_elements: Optional[List[str]] = None,
        excludes_elements: Optional[List[str]] = None,
        crystal_system: Optional[str] = None,
        bandgap_min: Optional[float] = None,
        bandgap_max: Optional[float] = None,
        material_type: Optional[str] = None  # e.g., 'oxide', 'intermetallic', etc.
    ) -> List[dict]:
        """
        Return the records that satisfy every criterion that is not None.

        Args:
            n_elements: Keep records whose "nelements" equals this value.
            contains_elements: Keep records whose elements include ALL of these symbols.
            excludes_elements: Drop records whose elements include ANY of these symbols.
            crystal_system: Keep records with this crystal system (compared case-insensitively).
            bandgap_min: Lower band-gap bound, inclusive. Defaults to 0.0 when only
                bandgap_max is given.
            bandgap_max: Upper band-gap bound, inclusive. Defaults to infinity when only
                bandgap_min is given.
            material_type: One of 'oxide', 'transition-metal-oxide', 'nitride', 'sulfide',
                'halide', 'intermetallic' or 'metal' (case-insensitive).

        Returns:
            The matching records, in their original order. When no criterion is given the
            stored list itself is returned, not a copy.

        Raises:
            ValueError: If material_type is not one of the supported names.
        """
        results = self.data

        if n_elements is not None:
            results = [d for d in results if d.get("nelements") == n_elements]

        if contains_elements is not None:
            required = set(contains_elements)
            results = [d for d in results if required.issubset(self._elements_of(d))]

        if excludes_elements is not None:
            forbidden = set(excludes_elements)
            results = [d for d in results if forbidden.isdisjoint(self._elements_of(d))]

        if crystal_system is not None:
            # Lower-case BOTH sides: sources differ in capitalisation ("Cubic" vs "cubic").
            wanted = crystal_system.lower()
            results = [d for d in results if self._crystal_system_of(d) == wanted]

        if bandgap_min is not None or bandgap_max is not None:
            lower = bandgap_min if bandgap_min is not None else 0.0
            upper = bandgap_max if bandgap_max is not None else float("inf")
            results = [d for d in results if lower <= self._bandgap_of(d) <= upper]

        if material_type is not None:
            results = self._filter_by_material_type(results, material_type.lower())

        return results

    def _filter_by_material_type(self, dataset: List[dict], mtype: str) -> List[dict]:
        """
        Keep the records of ``dataset`` that belong to the material class ``mtype``.

        ``mtype`` must already be lower-case. Classes are decided purely from the element
        list, except 'metal', which keeps records whose "bandgap" is exactly 0.0 (records
        with a missing or None band gap are not treated as metals).

        Raises:
            ValueError: If ``mtype`` is not a supported material class.
        """
        if mtype == "oxide":
            return [d for d in dataset if "O" in self._elements_of(d)]
        elif mtype == "transition-metal-oxide":
            selected = []
            for d in dataset:
                elements = self._elements_of(d)
                if "O" in elements and not self._TRANSITION_METALS.isdisjoint(elements):
                    selected.append(d)
            return selected
        elif mtype == "nitride":
            return [d for d in dataset if "N" in self._elements_of(d)]
        elif mtype == "sulfide":
            return [d for d in dataset if "S" in self._elements_of(d)]
        elif mtype == "halide":
            return [d for d in dataset if not self._HALOGENS.isdisjoint(self._elements_of(d))]
        elif mtype == "intermetallic":
            return [d for d in dataset if self._NON_METALS.isdisjoint(self._elements_of(d))]
        elif mtype == "metal":
            return [d for d in dataset if d.get("bandgap", None) == 0.0]
        else:
            raise ValueError(f"Unknown material type: {mtype}")
        
class DataExtracter:
    """
    Pulls out of a list of material records the quantities used to fine-tune CHGNet.

    Structures, energies and forces are required for fine-tuning; magnetic moments and
    stresses are optional. Every getter returns one entry per record, in record order, and
    reads its key with ``dict.get`` so a missing key yields a default instead of raising.

    Attributes:
        data (List[dict]): A list of dictionaries containing material data.

    Methods:
        get_energies_per_atom: Returns a list of energies per atom.
        get_forces: Returns a list of forces.
        get_strucs: Returns a list of structures.
        get_magmoms: Returns a list of magnetic moments.
        get_stresses: Returns a list of stresses.
    """

    def __init__(self, data: List[dict]):
        self.data = data

    def get_energies_per_atom(self) -> List[float]:
        """
        Divide each record's "energy" by its "nsites".

        A record without "energy" contributes 0.0 as its total energy, and a record whose
        "nsites" is missing or not positive yields 0.0 instead of dividing.
        """
        energies_per_atom = []
        for record in self.data:
            total_energy = record.get("energy", 0.0)
            site_count = record.get("nsites", 0)
            if site_count > 0:
                energies_per_atom.append(total_energy / site_count)
            else:
                # Sentinel value: avoids a ZeroDivisionError for empty/unknown site counts.
                energies_per_atom.append(0.0)
        return energies_per_atom

    def get_forces(self) -> List[List[float]]:
        """
        Return each record's "forces" entry, or an empty list when the key is absent.

        NOTE(review): presumably each entry is a per-site list of force vectors, which would
        be one nesting level deeper than the annotation — confirm against the dataset.
        """
        forces = []
        for record in self.data:
            forces.append(record.get("forces", []))
        return forces

    def get_strucs(self) -> List[Structure]:
        """
        Rebuild a pymatgen Structure from each record's "structure" dict.

        A record without "structure" passes an empty dict to ``Structure.from_dict``.
        """
        structures = []
        for record in self.data:
            structure_dict = record.get("structure", {})
            structures.append(Structure.from_dict(structure_dict))
        return structures

    def get_magmoms(self) -> List[float]:
        """
        Return each record's "magmoms" entry, or an empty list when the key is absent.

        NOTE(review): the default is a list, so the result looks like a list of per-record
        lists rather than List[float] — confirm and adjust the annotation if so.
        """
        magmoms = []
        for record in self.data:
            magmoms.append(record.get("magmoms", []))
        return magmoms

    def get_stresses(self) -> List[List[float]]:
        """Return each record's "stress" entry, or an empty list when the key is absent."""
        stresses = []
        for record in self.data:
            stresses.append(record.get("stress", []))
        return stresses

In [3]:
import sys
import os

# Add the '../src' directory to the Python path so the project's modules can be imported.
# The membership check keeps the cell idempotent: re-running it does not append duplicates.
SRC_DIR = os.path.abspath("../src")
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

# Now you can import your script
# import data_preprocessing

In [4]:
from pymatgen.core import Lattice, Structure
import numpy as np
import torch
import pandas as pd

import gzip
import json

# MatGL, a package that wraps all the ML-IAPs we want to use
import matgl

# MatPES, for downloading the datasets we care about
import matpes

# Record the package versions this notebook was run with.
for version_line in (
    f'MatPES ver: {matpes.__version__}',
    f'Torch ver: {torch.__version__}',
    f'MatGL ver: {matgl.__version__}',
):
    print(version_line)

MatPES ver: 0.0.3
Torch ver: 2.6.0+cpu
MatGL ver: 1.2.6


In [5]:
# Print, one per line, the pretrained models that MatGL reports as available.

pretrained_models = matgl.get_available_pretrained_models()
for model in pretrained_models:
    print(model)

CHGNet-MPtrj-2023.12.1-2.7M-PES
CHGNet-MPtrj-2024.2.13-11M-PES
CHGNet-MatPES-PBE-2025.2.10-2.7M-PES
CHGNet-MatPES-r2SCAN-2025.2.10-2.7M-PES
M3GNet-MP-2018.6.1-Eform
M3GNet-MP-2021.2.8-DIRECT-PES
M3GNet-MP-2021.2.8-PES
M3GNet-MatPES-PBE-v2025.1-PES
M3GNet-MatPES-r2SCAN-v2025.1-PES
MEGNet-MP-2018.6.1-Eform
MEGNet-MP-2019.4.1-BandGap-mfi
TensorNet-MatPES-PBE-v2025.1-PES
TensorNet-MatPES-r2SCAN-v2025.1-PES


In [6]:
# Assumes the dataset has already been downloaded into the sibling '../data' directory.

def load_PBE_data(full_dataset, target_system):
    """
    Load the raw PBE dataset from a gzip-compressed JSON file.

    Inputs:
    full_dataset (file path) : path to raw PBE dataset, in .json.gz format
    target_system (str): the chemical system of interest, e.g. 'Fe-O'. It is currently only
        type-checked; the records are NOT filtered by it yet.

    Returns:
    The parsed JSON content of the file, unfiltered.

    Raises:
    TypeError: if target_system is not a string.
    """

    # Validate before touching the file so a bad argument fails fast.
    if not isinstance(target_system, str):
        raise TypeError("Target system should be a string, e.g. 'Fe-O'!")

    # 'rt' makes gzip decompress and decode to text, which json.load can read directly.
    with gzip.open(full_dataset, 'rt', encoding='utf-8') as f:
        records = json.load(f)

    return records

In [7]:
# full_data = load_PBE_data("../data/MatPES-PBE-2025.1.json.gz", 'Fe-O')

# print(f'The complete dataset contains {len(full_data)} 300K MD simulations and Materials Project ground state calculations')

In [8]:
# for row in full_data[::50]:
#     print(row)