In [7]:
from __future__ import annotations

import argparse
import copy
import logging
import numpy as np
import os
import sys
import time
import torch
import yaml
from tqdm import tqdm

import pytorch_lightning as pl
from torch import nn
from torch_geometric.data import Batch

from ase import Atoms, units
from ase.calculators.calculator import Calculator, all_changes
from ase.calculators.singlepoint import SinglePointCalculator as sp
from ase.constraints import FixAtoms
from ase.io import read, write
from ase.md.langevin import Langevin
from ase.md.nptberendsen import NPTBerendsen

from matsciml.common.registry import registry
from matsciml.common.utils import radius_graph_pbc, setup_imports, setup_logging
from matsciml.datasets.transforms import (
    PeriodicPropertiesTransform,
    PointCloudToGraphTransform,
    FrameAveraging,
)
from matsciml.datasets.trajectory_lmdb import data_list_collater
from matsciml.lightning.data_utils import MatSciMLDataModule
from matsciml.models.base import ScalarRegressionTask
from matsciml.models.utils.io import *
from matsciml.preprocessing.atoms_to_graphs import *


In [8]:
checkpoint_path = "/home/m3rg2000/Simulation/checkpoints-2024/mace_fr.ckpt"
Loaded_model = multitask_from_checkpoint(checkpoint_path)

No ``atomic_energies`` provided, defaulting to ones.
/home/m3rg2000/miniconda3/envs/matsciml/lib/python3.9/site-packages/pytorch_lightning/utilities/parsing.py:199: Attribute 'gate' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['gate'])`.
No ``atomic_energies`` provided, defaulting to ones.
/home/m3rg2000/miniconda3/envs/matsciml/lib/python3.9/site-packages/pytorch_lightning/utilities/parsing.py:199: Attribute 'gate' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['gate'])`.


In [9]:

a2g=AtomsToGraphs(max_neigh=200,
            radius=6,
            r_energy=False,
            r_forces=False,
            r_distances=False,
            r_edges=True,
            r_fixed=True,)


PBCTransform=PeriodicPropertiesTransform(cutoff_radius=6.0, adaptive_cutoff=True)

def convAtomstoBatch(atoms):
    data_obj=a2g.convert(atoms)
    Reformatted_batch={
        'cell' : data_obj.cell,
        'natoms' :  torch.Tensor([data_obj.natoms]).unsqueeze(0),
        'edge_index' : data_obj.edge_index,
        'cell_offsets': data_obj.cell_offsets,
        'atomic_numbers': data_obj.atomic_numbers,
        'y' : None,
        'pos' : data_obj.pos,
        'force' : None, 
        'fixed' : [data_obj.fixed],
        'tags' : None,
        'sid' :None,
        'fid' : None,
        'dataset' : 'S2EFDataset',
        'graph' : data_list_collater([data_obj]),
    }
    Reformatted_batch=PBCTransform(Reformatted_batch)
    return Reformatted_batch

In [10]:
class MACE_ASEcalculator(Calculator):
    """Simulation ASE Calculator"""

    implemented_properties = ["energy" , "forces", "stress"]

    def __init__(
        self,
        model,
        **kwargs
    ):
        Calculator.__init__(self, **kwargs)
        self.results = {}

        self.model = model
        
        
    # pylint: disable=dangerous-default-value
    def calculate(self, atoms=None, properties=None, system_changes=all_changes):
        """
        Calculate properties.
        :param atoms: ase.Atoms object
        :param properties: [str], properties to be computed, used by ASE internally
        :param system_changes: [str], system changes since last calculation, used by ASE internally
        :return:
        """
        # call to base-class to set atoms attribute
        Calculator.calculate(self, atoms)

        # prepare data
        batch=convAtomstoBatch(atoms)

        # predict + extract data
        out = self.model.forward(batch)
        energy = out['regression0']['corrected_total_energy'].detach().cpu().item()
        forces = out['force_regression0']["force"].detach().cpu().numpy()
        stress = out['force_regression0']["stress"].squeeze(0).detach().cpu().numpy()
        # store results
        E = energy
        stress= np.array([stress[0, 0],
                                   stress[1, 1],
                                   stress[2, 2],
                                   stress[1, 2],
                                   stress[0, 2],
                                   stress[0, 1]])
        self.results = {
            "energy": E,
            # force has units eng / len:
            "forces": forces,
            "stress" : stress,
        }


In [13]:
def run_simulation(calculator,atoms, pressure=1.01325, temperature=298, timestep=0.1, steps=10,writepath=None,writeTraj=False):
         
    # Define the temperature and pressure
    
    init_conf=atoms
    init_conf.set_calculator(calculator)
    # Initialize the NPT dynamics
    dyn = NPTBerendsen(init_conf, timestep=timestep * units.fs, temperature_K=temperature,
                       taut=100 * units.fs, pressure_au=pressure * units.bar,
                       taup=1000 * units.fs, compressibility_au=4.57e-5 / units.bar)
    density=[]
    angles=[]
    lattice_parameters=[]
    
    def write_frame():
        if(writepath!=None):
            dyn.atoms.write(writepath, append=True)
    Traj=[]
    if writeTraj:
        def recordTraj(a=atoms):
            Traj.append(a.copy())
        dyn.attach(recordTraj, interval=1)
    dyn.attach(write_frame, interval=1)
    dyn.run(steps)
    
    # Calculate average values
    
    return  Traj


    


In [14]:
start_time=time.time()

calculator = MACE_ASEcalculator(Loaded_model)
#calculator.model.double() # change model weights type to double precision(hack to avoid error)
# Path to the CIF file
#cif_file_path = "/home/m3rg2000/Simulation/checkpoints-2024/Example.cif"
cif_file_path = "/home/m3rg2000/Simulation/checkpoints-2024/S2EF.extxyz"

# Read CIF file using ASE
input_atoms = read(cif_file_path)

run_simulation(calculator,input_atoms)

print(f"MD finished!")


  (self.temperature / old_temperature - 1.0) *


MD finished!


In [None]:
batch=convAtomstoBatch(atoms)

In [None]:
Res=Loaded_model(batch)

In [None]:
Loaded_model