In [1]:
from __future__ import annotations

import os
import sys
import time
import copy
import yaml
import logging
import argparse
from ase import Atoms
from ase.io import read

import numpy as np
import torch
import pytest

from ase import Atoms, units
from ase.md.langevin import Langevin
from ase.io import read, write
from ase.constraints import FixAtoms
from ase.calculators.calculator import Calculator, all_changes
from ase.stress import full_3x3_to_voigt_6_stress
from ase.calculators.singlepoint import SinglePointCalculator as sp

from tqdm import tqdm
import pytorch_lightning as pl

from matsciml.common.registry import registry
from matsciml.common.utils import radius_graph_pbc, setup_imports, setup_logging
from matsciml.datasets.transforms import (
    PeriodicPropertiesTransform,
    PointCloudToGraphTransform,
    FrameAveraging,
)
from matsciml.datasets.trajectory_lmdb import data_list_collater
from matsciml.lightning import MatSciMLDataModule
from matsciml.models.pyg import FAENet
from matsciml.models.base import ForceRegressionTask
from matsciml.models.utils.io import multitask_from_checkpoint
from matsciml.preprocessing.atoms_to_graphs import *
from ase import units
from ase.md.verlet import VelocityVerlet
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
from ase.io import read, write
import numpy as np
import time
import torch
import argparse
import sys
from tqdm import tqdm
from ase.io import read
from ase.neighborlist import neighbor_list


In [2]:
a2g=AtomsToGraphs(max_neigh=200,
            radius=6,
            r_energy=False,
            r_forces=False,
            r_distances=False,
            r_edges=True,
            r_fixed=True,)
f_avg=FrameAveraging(frame_averaging="3D", fa_method="stochastic")

def convAtomstoBatch(atoms):
    data_obj=a2g.convert(atoms)
    Reformatted_batch={
        'cell' : data_obj.cell,
        'natoms' :  torch.Tensor([data_obj.natoms]).unsqueeze(0),
        'edge_index' : [data_obj.edge_index.shape],
        'cell_offsets': data_obj.cell_offsets,
        'y' : None,
        'force' : None, 
        'fixed' : [data_obj.fixed],
        'tags' : None,
        'sid' :None,
        'fid' : None,
        'dataset' : 'S2EFDataset',
        'graph' : data_list_collater([data_obj]),
    }
    Reformatted_batch=f_avg(Reformatted_batch)
    return Reformatted_batch


In [3]:
def convBatchtoAtoms(batch):
    # data_obj=a2g.convert(atoms)
    curr_atoms = Atoms(
            positions=batch['graph'].pos,
            cell = batch['cell'][0],
            numbers=batch['graph'].atomic_numbers,
            pbc=True) # True or false
    
    return curr_atoms

In [4]:

class FAENet_ASEcalculator(Calculator):
    """Simulation ASE Calculator"""

    implemented_properties = ["energy", "free_energy", "forces", "stress"]

    def __init__(
        self,
        model,
        **kwargs
    ):
        Calculator.__init__(self, **kwargs)
        self.results = {}

        self.model = model
        
        
    # pylint: disable=dangerous-default-value
    def calculate(self, atoms=None, properties=None, system_changes=all_changes):
        """
        Calculate properties.
        :param atoms: ase.Atoms object
        :param properties: [str], properties to be computed, used by ASE internally
        :param system_changes: [str], system changes since last calculation, used by ASE internally
        :return:
        """
        # call to base-class to set atoms attribute
        Calculator.calculate(self, atoms)

        # prepare data
        batch=convAtomstoBatch(atoms)

        # predict + extract data
        out = self.model(batch)
        energy = out['force_regression0']["energy"].detach().cpu().item()
        forces = out['force_regression0']["force"].detach().cpu().numpy()
        stress = out['force_regression0']["stress"].squeeze(0).detach().cpu().numpy()
        # store results
        E = energy
        stress= np.array([stress[0, 0],
                                   stress[1, 1],
                                   stress[2, 2],
                                   stress[1, 2],
                                   stress[0, 2],
                                   stress[0, 1]])
        self.results = {
            "energy": E,
            # force has units eng / len:
            "forces": forces,
            "stress" : stress,
        }


In [5]:
from matsciml.models.utils.io import * 
checkpoint_path = "/home/m3rg2000/Simulation/checkpoints-2024/FAENet_250k.ckpt"
Loaded_model = multitask_from_checkpoint(checkpoint_path)

In [7]:
# class StabilityException(Exception):
#     pass

# start_time = time.time()
# traj = []
# calculator = FAENet_ASEcalculator(Loaded_model)
# cif_file_path = "/home/m3rg2000/Simulation/checkpoints-2024/Example.cif"
# atoms = read(cif_file_path)
# atoms.set_calculator(calculator)
# MaxwellBoltzmannDistribution(atoms, temperature_K=298)
# initial_energy = atoms.get_total_energy()
# dyn = VelocityVerlet(atoms, dt=1*units.fs)
# def write_frame(a=atoms):
#     a.write('md_FaeNet_nve_CIF.xyz', append=True)
#     traj.append(a.copy())
# dyn.attach(write_frame, interval=1)


# def energy_criterion(atoms, initial_energy, tolerance=0.10):
#     current_energy = atoms.get_total_energy()
#     lower_bound = initial_energy * (1 - tolerance)
#     upper_bound = initial_energy * (1 + tolerance)
#     print('energy:',(abs(current_energy-initial_energy)/initial_energy))
#     return lower_bound <= current_energy <= upper_bound

# # Stability check function to stop the simulation if the criterion is violated
# def check_stability(a=atoms):
#     if not energy_criterion(a, initial_energy, tolerance=0.10):
#         # print(total_energy-initial_energy)
#         # print()
#         raise StabilityException("Energy_criterion violated. Stopping the simulation.")

# # Attach the stability check function to the dynamics
# dyn.attach(check_stability, interval=100)




# def calculate_rmsd(traj):
#     initial_positions = traj[0].get_positions()
#     N = len(traj[0])
#     T = len(traj)
#     displacements = np.zeros((N, T, 3))
    
#     for t in range(T):
#         current_positions = traj[t].get_positions()
#         displacements[:, t, :] = current_positions - initial_positions

#     msd = np.mean(np.sum(displacements**2, axis=2), axis=1)
#     rmsd = np.sqrt(msd)
#     return rmsd
    
# def calculate_average_nn_distance(atoms):
#     i, j, _ = neighbor_list('ijd', atoms, cutoff=5.0)  ### What cutoff can be used?
#     distances = atoms.get_distances(i, j, mic=True)
#     return np.mean(distances)



# def lindemann_stability():
#     rmsd = calculate_rmsd(traj[-1000:])
#     avg_nn_distance = calculate_average_nn_distance(traj[0])
#     lindemann_coefficient = np.mean(rmsd) / avg_nn_distance
#     print('lindemann_stability:',lindemann_coefficient)
#     if lindemann_coefficient>0.1:
#         # print("lindemann_stability criterion violated. Stopping the simulation.")
#         raise StabilityException("lindemann_stability criterion violated. Stopping the simulation.")


# dyn.attach(lindemann_stability, interval=1000)

# try:
#     # Run the simulation for 100 steps or until the stability criterion is violated
#     dyn.run(2000)
# except StabilityException as e:
#     print(e)

# print("MD finished!")


  cell = torch.Tensor(atoms.get_cell()).view(1, 3, 3)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


energy: 0.0
lindemann_stability: 0.0
energy: 0.009427899751892893
energy: 0.006626798425159848
energy: 0.009961618666532161
energy: 0.020098465986984105
energy: 0.021908535909705592
energy: 0.018126575660275136
energy: 0.011143645029550961
energy: 0.022423224050134798
energy: 0.024727013296317834
energy: 0.023154431255706844
lindemann_stability: 0.8795049631698294
lindemann_stability criterion violated. Stopping the simulation.
MD finished!


In [9]:
import os
import time
import numpy as np
from ase.io import read
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
from ase.md.verlet import VelocityVerlet
from ase import units
import matplotlib.pyplot as plt
from ase.io.formats import UnknownFileTypeError

class StabilityException(Exception):
    pass

def run_simulation(atoms):
    try:
        start_time = time.time()
        traj = []
        calculator = FAENet_ASEcalculator(Loaded_model)
        atoms.set_calculator(calculator)
        MaxwellBoltzmannDistribution(atoms, temperature_K=298)
        initial_energy = atoms.get_total_energy()
        dyn = VelocityVerlet(atoms, dt=1*units.fs)

        def write_frame(a=atoms):
            # a.write('md_FaeNet_nve_CIF.xyz', append=True)
            traj.append(a.copy())

        dyn.attach(write_frame, interval=1)

        def energy_criterion(atoms, initial_energy, tolerance=0.10):
            current_energy = atoms.get_total_energy()
            lower_bound = initial_energy * (1 - tolerance)
            upper_bound = initial_energy * (1 + tolerance)
            print('energy:', (abs(current_energy-initial_energy)/initial_energy))
            return lower_bound <= current_energy <= upper_bound

        def check_stability(a=atoms):
            if not energy_criterion(a, initial_energy, tolerance=0.10):
                raise StabilityException("Energy_criterion violated. Stopping the simulation.")

        dyn.attach(check_stability, interval=20)   ### energy dumping interval

        def calculate_rmsd(traj):
            initial_positions = traj[0].get_positions()
            N = len(traj[0])
            T = len(traj)
            displacements = np.zeros((N, T, 3))
            
            for t in range(T):
                current_positions = traj[t].get_positions()
                displacements[:, t, :] = current_positions - initial_positions
            msd = np.mean(np.sum(displacements**2, axis=2), axis=1)
            rmsd = np.sqrt(msd)
            return rmsd

        def calculate_average_nn_distance(atoms):
            i, j, _ = neighbor_list('ijd', atoms, cutoff=5.0)  ## Cutoff change
            distances = atoms.get_distances(i, j, mic=True)
            return np.mean(distances)

        def lindemann_stability():
            if len(traj) >= 50:
                rmsd = calculate_rmsd(traj[-50:])
                avg_nn_distance = calculate_average_nn_distance(traj[0])
                lindemann_coefficient = np.mean(rmsd) / avg_nn_distance
                print('lindemann_stability:', lindemann_coefficient)
                if lindemann_coefficient > 0.1:
                    raise StabilityException("lindemann_stability criterion violated. Stopping the simulation.")

        dyn.attach(lindemann_stability, interval=50)   ### last 1000 frames msd

        try:
            dyn.run(200)    ### No. of steps running simulation
            return 200  # Simulation completed successfully
        except StabilityException as e:
            # print(f"File: {os.path.basename(cif_file_path)} - {e}")
            return len(traj)  # Return the number of steps completed before failure

    except (UnknownFileTypeError, IOError, ValueError) as e:
        # print(f"Error reading file {os.path.basename(cif_file_path)}: {str(e)}")
        return None  # Indicate that the file couldn't be processed



In [10]:
# Load Data
dm = MatSciMLDataModule(
    "MaterialsProjectDataset",
    train_path="/home/m3rg2000/matsciml/Scale_new_lmdb/10k_new",#TRAIN_PATH,
    # val_split=VAL_PATH,
    # test_split=VAL_PATH,
    dset_kwargs={
        "transforms": [
            PeriodicPropertiesTransform(cutoff_radius=6.0, adaptive_cutoff=True),
            PointCloudToGraphTransform(
                "pyg",
                node_keys=["pos", "atomic_numbers"],
            ),
            FrameAveraging(frame_averaging="3D", fa_method="stochastic"),
        ],
    },
    batch_size=1,
)

dm.setup()
train_loader = dm.train_dataloader()
dataset_iter = iter(train_loader)
batch = next(dataset_iter)


In [18]:
# Main execution
from tqdm import tqdm
# cif_folder = "/home/m3rg2000/Simulation/checkpoints-2024/"
# failed_simulations = []
time_steps = []
unreadable_files = []
counter=0
for batch in tqdm(train_loader):
    atoms=convBatchtoAtoms(batch)
    steps_completed = run_simulation(atoms)        
    time_steps.append(steps_completed)

    counter += 1
    del batch
    if counter >= 10:
        break



  0%|                                                                                                                                                                | 0/10000 [00:00<?, ?it/s]

energy: 0.0
energy: 3.8312869131013085e-05
energy: 0.016224272097884303
lindemann_stability: 0.20122057381271843


  0%|                                                                                                                                                      | 1/10000 [00:02<6:57:19,  2.50s/it]

energy: 0.0
energy: 0.0006580372807878133
energy: 0.00440004515034582
lindemann_stability: 0.23906766989374717


  0%|                                                                                                                                                      | 2/10000 [00:04<5:18:38,  1.91s/it]

energy: 0.0
energy: 0.02342622837337313
energy: 0.025184115087101343
lindemann_stability: 0.19060791484588194


  0%|                                                                                                                                                      | 3/10000 [00:07<7:19:46,  2.64s/it]

energy: 0.0
energy: 0.0003642627279554307
energy: 0.0011311741123877
lindemann_stability: 0.30103519806167933


  0%|                                                                                                                                                      | 4/10000 [00:09<6:33:28,  2.36s/it]

energy: 0.0
energy: 7.205166805977904e-05
energy: 0.00010403703894635033
lindemann_stability: 0.2852839766706362


  0%|                                                                                                                                                      | 5/10000 [00:11<6:00:47,  2.17s/it]

energy: 0.0
energy: 0.0035314626596697945
energy: 0.006194991871453482
lindemann_stability: 0.1964845848626984


  0%|                                                                                                                                                      | 6/10000 [00:13<5:36:50,  2.02s/it]

energy: 0.0
energy: 0.0002790206505035732
energy: 0.004024769395014759
lindemann_stability: 0.15551445206513237


  0%|                                                                                                                                                      | 7/10000 [00:15<6:17:10,  2.26s/it]

energy: 0.0
energy: 0.0003298909692147317
energy: 0.00021128346494727772
lindemann_stability: 0.21454059573505063


  0%|                                                                                                                                                      | 8/10000 [00:17<5:46:23,  2.08s/it]

energy: 0.0
energy: 0.00759761920381159
energy: 0.03604185385514449
lindemann_stability: 0.8996257194744471


  0%|▏                                                                                                                                                     | 9/10000 [00:19<6:00:40,  2.17s/it]

energy: 0.0
energy: 0.0010049890261231076
energy: 0.00021876088673153195
lindemann_stability: 0.1669406230288212


  0%|▏                                                                                                                                                     | 9/10000 [00:22<6:59:11,  2.52s/it]
