# Tutorial

## Build a Si diamond crystal

Here we build a 64-atom Si diamond crystal with ASE's `bulk` function: the 8-atom conventional cubic cell (lattice constant 5.43 Å) repeated 2 × 2 × 2.

```python
from ase.build import bulk

# 8-atom conventional cubic diamond cell, repeated 2x2x2 -> 64 atoms
si_dc = bulk("Si", "diamond", a=5.43, cubic=True).repeat((2, 2, 2))
```
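The count of 64 follows from the cell. The conventional cubic diamond cell holds 8 atoms (four FCC sites, each paired with a partner shifted by a quarter of the body diagonal), and repeating it twice along each axis gives 8 × 8 = 64. The pure-Python sketch below generates the same Cartesian positions. It is an illustration only, not how ASE implements `bulk`. The helper `diamond_supercell` is hypothetical, and its loop order (x outermost, z innermost) was chosen to match the ordering visible in the printed state further down.

```python
from itertools import product

A = 5.43  # Si lattice constant in angstrom, as passed to bulk(...)

# Fractional coordinates of the 8-atom conventional diamond cell:
# four FCC sites, each followed by a partner shifted by (1/4, 1/4, 1/4).
FCC_SITES = [(0.0, 0.0, 0.0), (0.0, 0.5, 0.5), (0.5, 0.0, 0.5), (0.5, 0.5, 0.0)]
BASIS = []
for fx, fy, fz in FCC_SITES:
    BASIS.append((fx, fy, fz))
    BASIS.append((fx + 0.25, fy + 0.25, fz + 0.25))


def diamond_supercell(a, reps):
    """Cartesian positions of a cubic diamond cell repeated reps = (nx, ny, nz) times."""
    positions = []
    # product() varies the last index fastest, so z is the innermost repeat.
    for i, j, k in product(range(reps[0]), range(reps[1]), range(reps[2])):
        for fx, fy, fz in BASIS:
            positions.append((a * (fx + i), a * (fy + j), a * (fz + k)))
    return positions


positions = diamond_supercell(A, (2, 2, 2))
print(len(positions))  # 64
print(positions[:4])
```

The first rows of `positions` line up with the first rows of the `SimState` printout below: (0, 0, 0), (1.3575, 1.3575, 1.3575), (0, 2.715, 2.715), and so on.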

Since `torch-sim` needs all arrays to be torch tensors, we use its built-in `atoms_to_state` function to convert the ASE `Atoms` object into a `SimState`. This is a custom object that holds all the information about the system that `torch-sim` uses in simulations. The `torch_sim.io` module also supports converting back and forth between `SimState` and Pymatgen `Structure` or Phonopy `PhonopyAtoms` objects.

The `atoms_to_state` function also takes `dtype` and `device` arguments. These ensure that the tensors in the resulting state are created on the correct device with the correct data type.
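The choice of `dtype` sets the precision of every quantity in the simulation. `torch.float32` has a 24-bit significand, so near the total energy of this 64-atom cell (about -347 eV, as printed later in this tutorial) adjacent float32 values are about 3e-5 eV apart. The stdlib-only sketch below emulates float32 rounding with `struct` to make that concrete. The helper names are hypothetical. If small energy or force differences matter for your workflow, `torch.float64` is the safer choice, at some cost in speed and memory.

```python
import math
import struct


def to_float32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value and back."""
    return struct.unpack("f", struct.pack("f", x))[0]


def float32_spacing(x: float) -> float:
    """Gap between adjacent float32 values near a normal, nonzero x (24-bit significand)."""
    _, exponent = math.frexp(x)  # x = m * 2**exponent with 0.5 <= |m| < 1
    return 2.0 ** (exponent - 24)


energy = -347.3734  # eV, roughly the total energy of the 64-atom cell
print(float32_spacing(energy))  # 3.0517578125e-05
print(to_float32(energy) - energy)  # rounding error, at most half the spacing
```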

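`SimState` is a dataclass that holds the positions, cell, atomic numbers, periodic boundary conditions, masses, and batch index of the atoms. The batch index records which system each atom belongs to, so that several structures can be stored in a single state.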
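The role of the batch index is easiest to see when two systems are merged into one state. The sketch below uses a hypothetical `ToyState` dataclass built from plain Python lists. It only loosely mirrors the fields listed above and is not `torch-sim`'s `SimState`. The hypothetical `concatenate` helper stacks the per-atom arrays and offsets each system's batch indices so that atoms from different systems remain distinguishable. The two systems are made-up toy inputs, not physical structures.

```python
from dataclasses import dataclass
from typing import List, Tuple

Vec3 = Tuple[float, float, float]


@dataclass
class ToyState:
    """Hypothetical stand-in for SimState, using plain lists instead of tensors."""

    positions: List[Vec3]  # one (x, y, z) per atom, in angstrom
    atomic_numbers: List[int]  # one per atom
    masses: List[float]  # one per atom, in atomic mass units
    cells: List[List[Vec3]]  # one 3x3 cell matrix per system
    pbc: bool  # periodic boundary conditions
    batch: List[int]  # index of the system each atom belongs to


def concatenate(states: List[ToyState]) -> ToyState:
    """Merge states into one, offsetting batch indices so systems stay distinct."""
    out = ToyState([], [], [], [], states[0].pbc, [])
    offset = 0
    for s in states:
        out.positions += s.positions
        out.atomic_numbers += s.atomic_numbers
        out.masses += s.masses
        out.cells += s.cells
        out.batch += [b + offset for b in s.batch]
        offset += len(s.cells)  # each input contributes len(s.cells) systems
    return out


cubic = [(5.43, 0.0, 0.0), (0.0, 5.43, 0.0), (0.0, 0.0, 5.43)]
two_atoms = ToyState(
    [(0.0, 0.0, 0.0), (1.3575, 1.3575, 1.3575)], [14, 14], [28.085] * 2, [cubic], True, [0, 0]
)
three_atoms = ToyState(
    [(0.0, 0.0, 0.0), (2.715, 2.715, 0.0), (2.715, 0.0, 2.715)],
    [14, 14, 14],
    [28.085] * 3,
    [cubic],
    True,
    [0, 0, 0],
)
batched = concatenate([two_atoms, three_atoms])
print(batched.batch)  # [0, 0, 1, 1, 1]
```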

```python
import torch

from torch_sim.io import atoms_to_state

# Set device and data type
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float32

# Convert the ASE Atoms object into a SimState on the chosen device and dtype
state = atoms_to_state(si_dc, dtype=dtype, device=device)
print(state)
```

```
SimState(positions=tensor([[0.0000, 0.0000, 0.0000],
        [1.3575, 1.3575, 1.3575],
        [0.0000, 2.7150, 2.7150],
        [1.3575, 4.0725, 4.0725],
        [2.7150, 0.0000, 2.7150],
        [4.0725, 1.3575, 4.0725],
        [2.7150, 2.7150, 0.0000],
        [4.0725, 4.0725, 1.3575],
        [0.0000, 0.0000, 5.4300],
        [1.3575, 1.3575, 6.7875],
        [0.0000, 2.7150, 8.1450],
        [1.3575, 4.0725, 9.5025],
        [2.7150, 0.0000, 8.1450],
        [4.0725, 1.3575, 9.5025],
        [2.7150, 2.7150, 5.4300],
        [4.0725, 4.0725, 6.7875],
        [0.0000, 5.4300, 0.0000],
        [1.3575, 6.7875, 1.3575],
        [0.0000, 8.1450, 2.7150],
        [1.3575, 9.5025, 4.0725],
        [2.7150, 5.4300, 2.7150],
        [4.0725, 6.7875, 4.0725],
        [2.7150, 8.1450, 0.0000],
        [4.0725, 9.5025, 1.3575],
        [0.0000, 5.4300, 5.4300],
        [1.3575, 6.7875, 6.7875],
        [0.0000, 8.1450, 8.1450],
        [1.3575, 9.5025, 9.5025],
        [2.7150, 5.4300, 8.14
```

You can also construct a `SimState` directly by passing it these attributes, instead of building the structure with ASE.
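When you build the state yourself, you supply the same quantities listed above: positions, cell, atomic numbers, masses, and periodic boundary conditions. Structures are often specified in fractional coordinates, whereas the positions in the printed state above are Cartesian (in Å). Fractional coordinates therefore have to be converted through the cell matrix. The plain-Python sketch below does this for the 2-atom primitive diamond cell of Si. The primitive cell is chosen here for brevity, and the helper names are hypothetical. The sketch also checks that the cell volume equals a³/4, a quarter of the conventional cubic cell. These lists would still need to become torch tensors with the desired dtype and device before being passed to `SimState`. That call is not shown here because its exact keyword names may depend on the `torch-sim` version.

```python
A = 5.43  # Si lattice constant in angstrom

# Primitive (2-atom) diamond cell: FCC lattice vectors as rows, in angstrom.
cell = [
    [0.0, A / 2, A / 2],
    [A / 2, 0.0, A / 2],
    [A / 2, A / 2, 0.0],
]
# Fractional coordinates of the two Si atoms in the primitive cell.
frac = [[0.0, 0.0, 0.0], [0.25, 0.25, 0.25]]


def frac_to_cart(f, lattice):
    """Cartesian position r = f0*a1 + f1*a2 + f2*a3 (lattice vectors as rows)."""
    return [sum(f[i] * lattice[i][k] for i in range(3)) for k in range(3)]


def det3(m):
    """Determinant of a 3x3 matrix; its absolute value is the cell volume."""
    return (
        m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
        - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
        + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0])
    )


positions = [frac_to_cart(f, cell) for f in frac]
atomic_numbers = [14, 14]
masses = [28.085, 28.085]  # approximate atomic mass of Si, in u
pbc = True

print(positions)
print(abs(det3(cell)), A**3 / 4)  # cell volume vs. a^3 / 4, in cubic angstrom
```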

## Define a `torch-sim` model

For this tutorial we use the medium-sized MACE foundation model trained on the OMAT dataset (`mace-omat-0-medium`).

```python
# Load the raw model from the checkpoint
from mace.calculators.foundations_models import mace_mp

mace_checkpoint_url = "https://github.com/ACEsuit/mace-mp/releases/download/mace_omat_0/mace-omat-0-medium.model"
loaded_model = mace_mp(
    model=mace_checkpoint_url,
    return_raw_model=True,  # return the underlying model rather than an ASE calculator
    default_dtype=dtype,
    device=device,
)
```

```
Using medium OMAT-0 model under Academic Software License (ASL) license, see https://github.com/gabor1/ASL 
 To use this model you accept the terms of the license.
Using Materials Project MACE for MACECalculator with /home/abhijeet/.cache/mace/maceomat0mediummodel
```

We can then wrap the raw model in a `MaceModel` from the `torch_sim.models.mace` module.

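```python
from torch_sim.models.mace import MaceModel

mace_model = MaceModel(
    model=loaded_model,
    device=device,
    compute_force=True,  # also return per-atom forces
    compute_stress=True,  # also return the stress
    dtype=dtype,
    enable_cueq=False,  # do not use cuEquivariance-accelerated kernels
)
```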

```
Running BatchedMACEForce on device: cuda with dtype: torch.float32
```

## Model call

Calling the model on the `SimState` returns a dictionary of predicted properties, such as the total energy and the per-atom forces (and the stress, since `compute_stress=True`).
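For a conservative potential such as MACE, the forces and stress are not predicted independently; both are derived from the predicted energy $E$. The force on atom $i$ is the negative gradient of the energy with respect to its position $\mathbf{r}_i$. The stress is the derivative of the energy with respect to a homogeneous strain $\varepsilon$ of the cell, divided by the cell volume $V$. Sign conventions for stress differ between codes, so it is worth checking which one is in use before comparing numbers.

$$
\mathbf{F}_i = -\frac{\partial E}{\partial \mathbf{r}_i},
\qquad
\sigma_{\alpha\beta} = \frac{1}{V}\,\frac{\partial E}{\partial \varepsilon_{\alpha\beta}}
$$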

```python
results = mace_model(state)
print(results)
```

```
{'energy': tensor([-347.3734], device='cuda:0'), 'forces': tensor([[ 1.0860e-05,  1.8540e-05,  9.8736e-06],
        [-1.0921e-05, -8.2361e-06, -9.5972e-06],
        [ 2.9740e-06,  8.4398e-07, -1.4720e-07],
        [ 9.0303e-07, -1.0987e-05, -1.1049e-05],
        [-6.2719e-07,  2.6526e-06, -9.4413e-08],
        [-9.3343e-06,  7.5492e-06, -9.8401e-06],
        [ 2.3516e-07,  3.6540e-06,  3.3305e-06],
        [-1.0521e-05, -9.8684e-06,  3.4072e-07],
        [ 2.7923e-06,  1.1069e-05,  6.7602e-06],
        [-1.1050e-05, -9.1003e-06, -1.0825e-05],
        [ 1.9614e-05, -2.6211e-06, -2.8371e-05],
        [ 1.0609e-05, -5.6052e-06,  2.2416e-05],
        [-1.6350e-06,  1.9337e-05, -2.4647e-05],
        [ 2.1930e-06,  1.8240e-05,  2.7243e-05],
        [ 3.8424e-06,  6.9391e-06,  1.0155e-05],
        [ 2.6202e-06, -2.7603e-06,  3.5580e-06],
        [ 8.0383e-06,  1.1599e-05,  8.8000e-06],
        [-6.1961e-06, -9.6887e-06, -6.9254e-06],
        [ 1.6080e-05, -2.7732e-05, -1.7172e-06],
        [ 
```
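Two quick sanity checks on this output are the energy per atom and the largest force magnitude. By symmetry, the net force on every atom of a perfect diamond lattice vanishes. Forces of order 1e-5 eV/Å are therefore effectively zero here, consistent with numerical noise at float32 precision. The stdlib sketch below applies both checks to values copied from the printout (only the first three force rows, since the output above is truncated). The helper `max_force_norm` is hypothetical. With the actual tensors, one would use something like `results["energy"].item() / len(state.positions)` and `torch.linalg.norm(results["forces"], dim=1).max()`.

```python
import math

# Values copied from the printed output above (only the first three force rows).
energy = -347.3734  # eV, total energy of the 64-atom cell
n_atoms = 64
forces = [
    [1.0860e-05, 1.8540e-05, 9.8736e-06],
    [-1.0921e-05, -8.2361e-06, -9.5972e-06],
    [2.9740e-06, 8.4398e-07, -1.4720e-07],
]


def max_force_norm(rows):
    """Largest per-atom force magnitude |F_i| among the given rows."""
    return max(math.sqrt(fx * fx + fy * fy + fz * fz) for fx, fy, fz in rows)


print(f"energy per atom: {energy / n_atoms:.4f} eV/atom")  # energy per atom: -5.4277 eV/atom
print(f"max |F| (rows shown): {max_force_norm(forces):.2e} eV/Å")
```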