In [2]:
import torch
from e3nn.util import jit
from mace.tools import utils, to_one_hot, atomic_numbers_to_indices
from typing import Optional
import openmm as mm
import openmm.unit as unit
import openmm.app as app
from openmm.app import PDBFile
from openmmtorch import TorchForce

  _Jd, _W3j_flat, _W3j_indices = torch.load(os.path.join(os.path.dirname(__file__), 'constants.pt'))


cuequivariance or cuequivariance_torch is not available. Cuequivariance acceleration will be disabled.




In [3]:
cuda_available = torch.cuda.is_available()
print(f"CUDA Available: {cuda_available}")

if cuda_available:
    num_gpus = torch.cuda.device_count()
    print(f"Number of GPUs: {num_gpus}")
    current_gpu_name = torch.cuda.get_device_name(0)
    print(f"Current GPU Name: {current_gpu_name}")
    print("\nGPU Memory Usage:")
    print(f"  Allocated: {round(torch.cuda.memory_allocated(0) / 1024**3, 1)} GB")
    print(f"  Cached: {round(torch.cuda.memory_reserved(0) / 1024**3, 1)} GB")
else:
    print("PyTorch is running on CPU.")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"\nUsing device: {device}")

CUDA Available: True
Number of GPUs: 1
Current GPU Name: NVIDIA GeForce RTX 5070 Laptop GPU

GPU Memory Usage:
  Allocated: 0.0 GB
  Cached: 0.0 GB

Using device: cuda


NVIDIA GeForce RTX 5070 Laptop GPU with CUDA capability sm_120 is not compatible with the current PyTorch installation.
The current PyTorch install supports CUDA capabilities sm_50 sm_60 sm_61 sm_70 sm_75 sm_80 sm_86 sm_89 sm_90 compute_90.
If you want to use the NVIDIA GeForce RTX 5070 Laptop GPU GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/



In [4]:
def simple_nl(
    positions: torch.Tensor,
    cell: torch.Tensor,
    pbc: bool,
    cutoff: float,
    sorti: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:

    num_atoms = positions.shape[0]
    device = positions.device
    uij = torch.triu_indices(num_atoms, num_atoms, 1, device=device)
    triu_deltas = positions[uij[0]] - positions[uij[1]]
    wrapped_triu_deltas = triu_deltas.clone()

    if pbc:
        wrapped_triu_deltas -= torch.outer(
            torch.round(wrapped_triu_deltas[:, 2] / cell[2, 2]), cell[2]
        )
        wrapped_triu_deltas -= torch.outer(
            torch.round(wrapped_triu_deltas[:, 1] / cell[1, 1]), cell[1]
        )
        wrapped_triu_deltas -= torch.outer(
            torch.round(wrapped_triu_deltas[:, 0] / cell[0, 0]), cell[0]
        )
        shifts = torch.mm(triu_deltas - wrapped_triu_deltas, torch.linalg.inv(cell))
    else:
        shifts = torch.zeros(triu_deltas.shape, device=device)

    triu_distances = torch.linalg.norm(wrapped_triu_deltas, dim=1)

    mask = triu_distances > cutoff
    uij = uij[:, ~mask]
    shifts = shifts[~mask, :]

    lij = torch.stack((uij[1], uij[0]))
    neighbors = torch.hstack((uij, lij))
    shifts = torch.vstack((shifts, -shifts))

    if sorti:
        idx = torch.argsort(neighbors[0])
        neighbors = neighbors[:, idx]
        shifts = shifts[idx, :]

    return neighbors, shifts


In [5]:
class MACEForce(torch.nn.Module):
    def __init__(self, model_path, atomic_numbers, indices, periodic, device, dtype=torch.float64):
        super().__init__()
        self.device=device
        self.default_dtype = dtype
        print("Running MACEForce on device: ", self.device, " with dtype: ", self.default_dtype)
        self.nm_to_distance = 10.0 # nm->A
        self.energy_to_kJ = 96.49  # eV->kJ
        self.model = torch.load(model_path, map_location=device)
        self.r_max = self.model.r_max
        self.z_table = utils.AtomicNumberTable([int(z) for z in self.model.atomic_numbers])
        self.model = jit.compile(self.model).to(self.default_dtype)

        N=len(atomic_numbers)
        self.ptr = torch.tensor([0,N],dtype=torch.long, device=self.device)
        self.batch = torch.zeros(N, dtype=torch.long, device=self.device)

        self.node_attrs = to_one_hot(
                torch.tensor(atomic_numbers_to_indices(atomic_numbers, z_table=self.z_table), dtype=torch.long, device=self.device).unsqueeze(-1),
                num_classes=len(self.z_table),
            ).to(self.default_dtype)

        if periodic:
            self.pbc = torch.tensor([True, True, True], device=self.device)
        else:
            self.pbc = torch.tensor([False, False, False], device=self.device)

        if indices is None:
            self.indices = None
        else:
            self.indices = torch.tensor(indices, dtype=torch.int64)

    def forward(self, positions, boxvectors: Optional[torch.Tensor] = None):
        positions = positions.to(device=self.device, dtype=self.default_dtype)
        if self.indices is not None:
            positions = positions[self.indices]
        positions = positions*self.nm_to_distance

        if boxvectors is not None:
            cell = boxvectors.to(device=self.device,dtype=self.default_dtype) * self.nm_to_distance
            pbc = True
        else:
            cell = torch.eye(3, device=self.device)
            pbc = False

        mapping, shifts_idx = simple_nl(positions, cell, pbc, self.r_max)
        edge_index = torch.stack((mapping[0], mapping[1])).to(torch.int64)
        shifts = torch.mm(shifts_idx, cell).to(self.default_dtype)

        input_dict = {
            "ptr" : self.ptr,
            "node_attrs": self.node_attrs,
            "batch": self.batch,
            "pbc": self.pbc,
            "positions": positions,
            "edge_index": edge_index,
            "shifts": shifts,
            "cell": cell,
        }

        energy = self.model(input_dict, compute_force=False)["interaction_energy"]
        assert energy is not None, "The model did not return any energy. Please check the input."
        energy = energy*self.energy_to_kJ

        return energy

In [6]:
with open('gaff_system.xml', 'r') as f:
    system = mm.XmlSerializer.deserialize(f.read())
pdb = PDBFile('gaff_ligand_in_solvent.pdb')
system.addForce(mm.MonteCarloBarostat(1*unit.atmosphere, 300*unit.kelvin))
#
chains = list(pdb.topology.chains())
print(chains)
ml_atoms = [atom.index for atom in chains[0].atoms()]
print(ml_atoms)
atomic_numbers = [atom.element.atomic_number for atom in chains[0].atoms()]
print(atomic_numbers)

#-----------------------------------------------------------------------------------------------------------------#

def removeBonds(system, atoms, removeInSet=True, removeConstraints=True):

    atomSet = set(atoms)
    import xml.etree.ElementTree as ET
    xml = mm.XmlSerializer.serialize(system)
    root = ET.fromstring(xml)

    def shouldRemove(termAtoms):
        return all(a in atomSet for a in termAtoms) == removeInSet

    for bonds in root.findall("./Forces/Force/Bonds"):
        for bond in bonds.findall("Bond"):
            bondAtoms = [int(bond.attrib[p]) for p in ("p1", "p2")]
            if shouldRemove(bondAtoms):
                bonds.remove(bond)
    for angles in root.findall("./Forces/Force/Angles"):
        for angle in angles.findall("Angle"):
            angleAtoms = [int(angle.attrib[p]) for p in ("p1", "p2", "p3")]
            if shouldRemove(angleAtoms):
                angles.remove(angle)
    for torsions in root.findall("./Forces/Force/Torsions"):
        for torsion in torsions.findall("Torsion"):
            torsionAtoms = [int(torsion.attrib[p]) for p in ("p1", "p2", "p3", "p4")]
            if shouldRemove(torsionAtoms):
                torsions.remove(torsion)

    if removeConstraints:
        for constraints in root.findall("./Constraints"):
            for constraint in constraints.findall("Constraint"):
                constraintAtoms = [int(constraint.attrib[p]) for p in ("p1", "p2")]
                if shouldRemove(constraintAtoms):
                    constraints.remove(constraint)

    return mm.XmlSerializer.deserialize(ET.tostring(root, encoding="unicode"))


def removeMMInteraction(system, ml_atoms):
    newSystem = removeBonds(system, ml_atoms)
    for force in newSystem.getForces():
        if isinstance(force, mm.NonbondedForce):
            for i in range(len(ml_atoms)):
                for j in range(i):
                    force.addException(ml_atoms[i], ml_atoms[j], 0, 1, 0, True)
        elif isinstance(force, mm.CustomNonbondedForce):
            existing = set(tuple(force.getExclusionParticles(i)) for i in range(force.getNumExclusions()))
            for i in range(len(ml_atoms)):
                a1 = ml_atoms[i]
                for j in range(i):
                    a2 = ml_atoms[j]
                    if (a1, a2) not in existing and (a2, a1) not in existing:
                        force.addExclusion(a1, a2, True)
    return newSystem

mixed_system = removeMMInteraction(system, ml_atoms)
print("number of forces (before nnp)  = ", mixed_system.getNumForces())
print(mixed_system.getForces())

[<Chain 0>, <Chain 1>]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150]
[6, 6, 8, 6, 6, 7, 6, 6, 6, 6, 6, 7, 6, 6, 6, 6, 6, 6, 6, 6, 8, 6, 6, 7, 7, 6, 6, 7, 6, 6, 6, 6, 6, 7, 6, 6, 9, 6, 6, 7, 6, 6, 6, 6, 7, 6, 6, 6, 6, 8, 6, 6, 8, 6, 8, 6, 6, 6, 7, 6, 6, 8, 8, 6, 6, 6, 8, 6, 6, 6, 6, 6, 6, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 

In [7]:
model_path = "MACE-OFF23_small.model"
pbc = True 
indices = ml_atoms
mace_mlp = MACEForce(model_path, atomic_numbers, indices, pbc, device)

Running MACEForce on device:  cuda  with dtype:  torch.float64


  self.model = torch.load(model_path, map_location=device)


In [8]:
torch_module = torch.jit.script(mace_mlp)
torchforce = TorchForce(torch_module)
mixed_system.addForce(torchforce)

6

In [9]:
# Set up simulation
temperature = 300 * unit.kelvin
frictionCoeff = 1 / unit.picosecond
timeStep = 0.5 * unit.femtosecond
integrator = mm.LangevinMiddleIntegrator(temperature, frictionCoeff, timeStep)
#
simulation = app.Simulation(pdb.topology, mixed_system, integrator)
print("Simulation is running on:", simulation.context.getPlatform().getName())
#
simulation.reporters.clear()
simulation.context.setPositions(pdb.positions)
simulation.minimizeEnergy()
with open(f'min.pdb', 'w') as f:
    app.PDBFile.writeFile(simulation.topology, simulation.context.getState(getPositions=True).getPositions(), f)
#
state = simulation.context.getState(getEnergy=True, getForces=True)
openmm_energy = state.getPotentialEnergy().value_in_unit(unit.kilojoule_per_mole)
print(openmm_energy)
openmm_force = state.getForces(asNumpy=True).value_in_unit(unit.kilojoule_per_mole/unit.nanometer)
print(openmm_force)
#
simulation.reporters.append(app.DCDReporter('test_mace_mixed.dcd', 100))
simulation.context.setVelocitiesToTemperature(temperature)
reporter = app.StateDataReporter(
    'data.csv', 
    100, 
    step=True, 
    time=True, 
    potentialEnergy=True,
    kineticEnergy=True,
    totalEnergy=True,  
    temperature=True,
    volume=True,
    density=True, 
    speed=True)
simulation.reporters.append(reporter)
simulation.step(200000)

Simulation is running on: CUDA
-147751.02618484327
[[  -23.86135208     6.22130033    44.58150493]
 [    4.24175619     1.61592418    28.01310759]
 [   27.62696429    32.96341682   -16.84038361]
 ...
 [-1551.16995137  -827.59710899   997.25199149]
 [  585.29151291   755.0421598   -597.83918476]
 [  964.95500438    75.85177563  -397.45638694]]
