In [1]:
from DistMLIP.implementations.matgl import CHGNet_Dist, Potential_Dist, MolecularDynamics, Relaxer
import matgl
from pymatgen.core import Structure, Lattice
from pymatgen.io.ase import AseAtomsAdaptor

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Initialize a CHGNet_Dist model (because CHGNet_Dist inherits from matgl's CHGNet class, model finetuning, training, etc. all work the exact same as in MatGL)
model = CHGNet_Dist()

# Load a previous CHGNet model from matgl
model = matgl.load_model("CHGNet-MatPES-PBE-2025.2.10-2.7M-PES").model

# Make a distributed version of the model
dist_model = CHGNet_Dist.from_existing(model)

# Enable distributed mode for 2 GPUs
dist_model.enable_distributed_mode([6, 7]) 

In [3]:
# Insert your atoms
struct = Structure.from_spacegroup("Pm-3m", Lattice.cubic(5.5), ["Li", "Mn"], [[0, 0, 0], [0.5, 0.5, 0.5]])
struct.perturb(0.5)
struct.make_supercell((20, 20, 20))
print(f"There are {len(struct)} atoms.")
atoms = AseAtomsAdaptor().get_atoms(struct)


There are 16000 atoms.


In [5]:
# Create Potential_Dist object, use 128 threads when creating graph structures
potential = Potential_Dist(model=dist_model, num_threads=128)

In [None]:
# Run structure relaxation
relaxer = Relaxer(
    potential=potential,
    optimizer="FIRE",
    relax_cell=True
)

results = relaxer.relax(atoms, verbose=True, steps=200)

  atoms.set_calculator(self.calculator)


      Step     Time          Energy          fmax
FIRE:    0 15:47:49    16332.768555        0.244577
FIRE:    1 15:47:50    16332.768555        0.244577
FIRE:    2 15:47:50    16332.768555        0.244580
FIRE:    3 15:47:50    16332.768555        0.244583
FIRE:    4 15:47:50    16332.768555        0.244587
FIRE:    5 15:47:50    16332.768555        0.244591
FIRE:    6 15:47:51    16332.767578        0.244596
FIRE:    7 15:47:51    16332.767578        0.244604
FIRE:    8 15:47:51    16332.731445        0.244612
FIRE:    9 15:47:51    16332.728516        0.244625
FIRE:   10 15:47:51    16332.728516        0.244638
FIRE:   11 15:47:52    16332.726562        0.244656
FIRE:   12 15:47:52    16332.725586        0.244677
FIRE:   13 15:47:52    16332.714844        0.244704
FIRE:   14 15:47:52    16332.713867        0.244738
FIRE:   15 15:47:52    16332.709961        0.244779
FIRE:   16 15:47:53    16332.708984        0.244826
FIRE:   17 15:47:53    16332.705078        0.244873
FIRE:   18 15:

In [None]:
# Run molecular dynamics
driver = MolecularDynamics(
    atoms,
    potential=potential,
    timestep=0.5,
    temperature=300,
    loginterval=200,
    logfile="logfile.log",
    trajectory="traj.trj"
)

driver.run(20)