In [1]:
from DistMLIP.implementations.matgl import CHGNet_Dist, Potential_Dist, MolecularDynamics, Relaxer
import matgl
from pymatgen.core import Structure, Lattice
from pymatgen.io.ase import AseAtomsAdaptor

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Initialize a CHGNet_Dist model (because CHGNet_Dist inherits from matgl's CHGNet class, model finetuning, training, etc. all work the exact same as in MatGL)
model = CHGNet_Dist()

# Load a previous CHGNet model from matgl
model = matgl.load_model("CHGNet-MatPES-PBE-2025.2.10-2.7M-PES").model

# Make a distributed version of the model
dist_model = CHGNet_Dist.from_existing(model)

# Enable distributed mode for 2 GPUs
dist_model.enable_distributed_mode([6, 7]) 

ModuleList(
  (0-3): 4 x CHGNetBondGraphBlock(
    (conv_layer): CHGNetLineGraphConv(
      (node_update_func): GatedMLP_norm(
        (layers): MLP_norm(
          (layers): ModuleList(
            (0): Linear(in_features=512, out_features=128, bias=True)
            (1): Linear(in_features=128, out_features=128, bias=True)
          )
          (norm_layers): ModuleList(
            (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          )
          (activation): SiLU()
        )
        (gates): MLP_norm(
          (layers): ModuleList(
            (0): Linear(in_features=512, out_features=128, bias=True)
            (1): Linear(in_features=128, out_features=128, bias=True)
          )
          (norm_layers): ModuleList(
            (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          )
          (activation): SiLU()
        )
        (sigmoid): Sigmoid()
      )
      (node_out_func): Linear(in_features=128, out_features=128, bias=False)
      (edge_up

In [3]:
# Insert your atoms
struct = Structure.from_spacegroup("Pm-3m", Lattice.cubic(5.5), ["Li", "Mn"], [[0, 0, 0], [0.5, 0.5, 0.5]])
struct.perturb(0.5)
struct.make_supercell((10, 10, 10))
print(f"There are {len(struct)} atoms.")
atoms = AseAtomsAdaptor().get_atoms(struct)

There are 2000 atoms.


In [4]:
# Create Potential_Dist object, use 128 threads when creating graph structures
potential = Potential_Dist(model=dist_model, num_threads=128)

In [17]:
# Perform static point calculation
output = potential(atoms)



In [None]:
# Run structure relaxation
relaxer = Relaxer(
    potential=potential,
    optimizer="FIRE",
    relax_cell=True
)

results = relaxer.relax(atoms, verbose=True, steps=200)

In [None]:
# Run molecular dynamics
driver = MolecularDynamics(
    atoms,
    potential=potential,
    timestep=0.5,
    temperature=300,
    loginterval=200,
    logfile="logfile.log",
    trajectory="traj.trj"
)

driver.run(20)