Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# easily load the model from mace-mp
from mace.calculators.foundations_models import mace_mp
from torch_sim.models import MaceModel
from torch_sim.models.mace import MaceModel
mace = mace_mp(model="small", return_raw_model=True)
mace_model = MaceModel(model=mace, device=device)

Expand Down
4 changes: 2 additions & 2 deletions examples/scripts/1_Introduction/1.3_Batched_MACE.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ase.build import bulk
from mace.calculators.foundations_models import mace_mp

import torch_sim as ts
from torch_sim.models.mace import MaceModel


# Set device and data type
Expand All @@ -35,7 +35,7 @@
si_dc = bulk("Si", "diamond", a=5.43, cubic=True).repeat((2, 2, 2))
atoms_list = [si_dc, si_dc]

batched_model = ts.models.MaceModel(
batched_model = MaceModel(
# Pass the raw model
model=loaded_model,
# Or load from compiled model
Expand Down
3 changes: 2 additions & 1 deletion examples/scripts/1_Introduction/1.4_Fairchem.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ase.build import bulk

import torch_sim as ts
from torch_sim.models.fairchem import FairChemModel


device = "cuda" if torch.cuda.is_available() else "cpu"
Expand All @@ -31,7 +32,7 @@
# Create diamond cubic Silicon
si_dc = bulk("Si", "diamond", a=5.43).repeat((2, 2, 2))
atomic_numbers = si_dc.get_atomic_numbers()
model = ts.models.FairChemModel(
model = FairChemModel(
model=MODEL_PATH,
cpu=False,
seed=0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from mace.calculators.foundations_models import mace_mp

import torch_sim as ts
from torch_sim.models.mace import MaceModel
from torch_sim.units import UnitConversion


Expand Down Expand Up @@ -62,7 +63,7 @@
print(f"Total number of structures: {len(atoms_list)}")

# Create batched model
model = ts.models.MaceModel(
model = MaceModel(
model=loaded_model,
device=device,
compute_forces=True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from mace.calculators.foundations_models import mace_mp

import torch_sim as ts
from torch_sim.models.mace import MaceModel
from torch_sim.optimizers import fire


Expand Down Expand Up @@ -62,7 +63,7 @@
print(f"Total number of structures: {len(atoms_list)}")

# Create batched model
model = ts.models.MaceModel(
model = MaceModel(
model=loaded_model,
device=device,
compute_forces=True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from mace.calculators.foundations_models import mace_mp

import torch_sim as ts
from torch_sim.models.mace import MaceModel
from torch_sim.optimizers import unit_cell_gradient_descent
from torch_sim.units import UnitConversion

Expand Down Expand Up @@ -63,7 +64,7 @@
print(f"Total number of structures: {len(atoms_list)}")

# Create batched model
model = ts.models.MaceModel(
model = MaceModel(
model=loaded_model,
device=device,
compute_forces=True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from mace.calculators.foundations_models import mace_mp

import torch_sim as ts
from torch_sim.models.mace import MaceModel
from torch_sim.optimizers import unit_cell_fire
from torch_sim.units import UnitConversion

Expand Down Expand Up @@ -63,7 +64,7 @@
print(f"Total number of structures: {len(atoms_list)}")

# Create batched model
model = ts.models.MaceModel(
model = MaceModel(
model=loaded_model,
device=device,
compute_forces=True,
Expand Down
3 changes: 2 additions & 1 deletion examples/scripts/5_Workflow/5.2_a2c_silicon_batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from tqdm import tqdm

import torch_sim as ts
from torch_sim.models.mace import MaceModel
from torch_sim.unbatched.models.mace import UnbatchedMaceModel
from torch_sim.unbatched.unbatched_integrators import (
NVTNoseHooverState,
Expand Down Expand Up @@ -201,7 +202,7 @@ def step_fn(

start_time = time.perf_counter()
# Create a batched model
model = ts.models.MaceModel(
model = MaceModel(
model=raw_model,
device=device,
compute_forces=True,
Expand Down
3 changes: 2 additions & 1 deletion examples/scripts/5_Workflow/5.3_In_Flight_WBM.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from mace.calculators.foundations_models import mace_mp

import torch_sim as ts
from torch_sim.models.mace import MaceModel


# --- Setup and Configuration ---
Expand All @@ -27,7 +28,7 @@
print("Loading MACE model...")
mace_checkpoint_url = "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mpa_0/mace-mpa-0-medium.model"
mace = mace_mp(model=mace_checkpoint_url, return_raw_model=True)
mace_model = ts.models.MaceModel(
mace_model = MaceModel(
model=mace,
device=device,
dtype=dtype,
Expand Down
3 changes: 2 additions & 1 deletion examples/scripts/5_Workflow/5.4_Elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import torch_sim as ts
from torch_sim.elastic import get_bravais_type
from torch_sim.models.mace import MaceModel


# Calculator
Expand All @@ -31,7 +32,7 @@
# ASE structure
struct = bulk("Cu", "fcc", a=3.58, cubic=True).repeat((2, 2, 2))

model = ts.models.MaceModel(
model = MaceModel(
model=loaded_model,
device=device,
compute_forces=True,
Expand Down
3 changes: 2 additions & 1 deletion examples/scripts/6_Phonons/6.1_Phonons_MACE.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
)

import torch_sim as ts
from torch_sim.models.mace import MaceModel


def get_qpts_and_connections(
Expand Down Expand Up @@ -105,7 +106,7 @@ def get_labels_qpts(ph: Phonopy, n_points: int = 101) -> tuple[list[str], list[b
displ = 0.01 # atomic displacement for phonons (in Angstrom)

# Relax atomic positions
model = ts.models.MaceModel(
model = MaceModel(
model=loaded_model,
device=device,
compute_forces=True,
Expand Down
3 changes: 2 additions & 1 deletion examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from phonopy.structure.atoms import PhonopyAtoms

import torch_sim as ts
from torch_sim.models.mace import MaceModel


def get_relaxed_structure(
Expand Down Expand Up @@ -219,7 +220,7 @@ def get_qha_phonons(
default_dtype=dtype,
device=device,
)
model = ts.models.MaceModel(
model = MaceModel(
model=loaded_model,
device=device,
compute_forces=True,
Expand Down
3 changes: 2 additions & 1 deletion examples/scripts/6_Phonons/6.3_Conductivity_MACE.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from phono3py import Phono3py

import torch_sim as ts
from torch_sim.models.mace import MaceModel


def print_relax_info(trajectory_file: str, device: torch.device) -> None:
Expand Down Expand Up @@ -54,7 +55,7 @@ def print_relax_info(trajectory_file: str, device: torch.device) -> None:
loaded_model = mace_mp(
model=mace_checkpoint_url, return_raw_model=True, default_dtype=dtype, device=device
)
model = ts.models.MaceModel(
model = MaceModel(
model=loaded_model,
device=device,
compute_forces=True,
Expand Down
2 changes: 1 addition & 1 deletion examples/tutorials/autobatching_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def mock_determine_max_batch_size(*args, **kwargs):
# %%
from torch_sim.autobatching import estimate_max_memory_scaler
from mace.calculators.foundations_models import mace_mp
from torch_sim.models import MaceModel
from torch_sim.models.mace import MaceModel

# Initialize your model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand Down
2 changes: 1 addition & 1 deletion examples/tutorials/low_level_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@

# %%
from mace.calculators.foundations_models import mace_mp
from torch_sim.models import MaceModel
from torch_sim.models.mace import MaceModel

# load mace_mp using the mace package
mace_checkpoint_url = "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mpa_0/mace-mpa-0-medium.model"
Expand Down
2 changes: 1 addition & 1 deletion examples/tutorials/metatensor_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"""

# %%
from torch_sim.models import MetatensorModel
from torch_sim.models.metatensor import MetatensorModel

model = MetatensorModel("pet-mad")

Expand Down
2 changes: 1 addition & 1 deletion examples/tutorials/reporting_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@
"""

# %%
from torch_sim.models import LennardJonesModel
from torch_sim.models.lennard_jones import LennardJonesModel


# Define some property calculators
Expand Down
2 changes: 1 addition & 1 deletion examples/tutorials/using_graphpes_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
"""

# %%
from torch_sim.models import GraphPESWrapper
from torch_sim.models.graphpes import GraphPESWrapper

# wrap the model for use with TorchSim
ts_model = GraphPESWrapper(model, compute_stress=False)
Expand Down
36 changes: 17 additions & 19 deletions torch_sim/elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,13 @@ def regular_symmetry(strains: torch.Tensor) -> torch.Tensor:

Args:
strains: Tensor of shape (6,) containing strain components
[εxx, εyy, εzz, εyz, εxz, εxy]
where:
- εxx, εyy, εzz are normal strains
- εyz, εxz, εxy are shear strains
[εxx, εyy, εzz, εyz, εxz, εxy] where:
- εxx, εyy, εzz are normal strains
- εyz, εxz, εxy are shear strains

Returns:
torch.Tensor: Matrix of shape (6, 3) where columns correspond to
coefficients for C11, C12, and C44 respectively
coefficients for C11, C12, and C44 respectively

Notes:
The resulting matrix M has the form:
Expand Down Expand Up @@ -214,14 +213,13 @@ def tetragonal_symmetry(strains: torch.Tensor) -> torch.Tensor:

Args:
strains: Tensor of shape (6,) containing strain components
[εxx, εyy, εzz, εyz, εxz, εxy]
where:
- εxx, εyy, εzz are normal strains
- εyz, εxz, εxy are shear strains
[εxx, εyy, εzz, εyz, εxz, εxy] where:
- εxx, εyy, εzz are normal strains
- εyz, εxz, εxy are shear strains

Returns:
torch.Tensor: Matrix of shape (6, 7) where columns correspond to
coefficients for C11, C12, C13, C16, C33, C44, C66
coefficients for C11, C12, C13, C16, C33, C44, C66

Notes:
The resulting matrix M has the form:
Expand Down Expand Up @@ -280,11 +278,11 @@ def orthorhombic_symmetry(strains: torch.Tensor) -> torch.Tensor:

Args:
strains: Tensor of shape (6,) containing strain components
[εxx, εyy, εzz, εyz, εxz, εxy]
[εxx, εyy, εzz, εyz, εxz, εxy]

Returns:
torch.Tensor: Matrix of shape (6, 9) where columns correspond to
coefficients for C11, C12, C13, C22, C23, C33, C44, C55, C66
coefficients for C11, C12, C13, C22, C23, C33, C44, C55, C66

Notes:
The resulting matrix M has the form:
Expand Down Expand Up @@ -343,11 +341,11 @@ def trigonal_symmetry(strains: torch.Tensor) -> torch.Tensor:

Args:
strains: Tensor of shape (6,) containing strain components
[εxx, εyy, εzz, εyz, εxz, εxy]
[εxx, εyy, εzz, εyz, εxz, εxy]

Returns:
torch.Tensor: Matrix of shape (6, 7) where columns correspond to
coefficients for C11, C12, C13, C14, C15, C33, C44
coefficients for C11, C12, C13, C14, C15, C33, C44

Notes:
The resulting matrix M has the form:
Expand Down Expand Up @@ -416,11 +414,11 @@ def hexagonal_symmetry(strains: torch.Tensor) -> torch.Tensor:

Args:
strains: Tensor of shape (6,) containing strain components
[εxx, εyy, εzz, εyz, εxz, εxy]
[εxx, εyy, εzz, εyz, εxz, εxy]

Returns:
torch.Tensor: Matrix of shape (6, 5) where columns correspond to
coefficients for C11, C33, C12, C13, C44
coefficients for C11, C33, C12, C13, C44

Notes:
The resulting matrix M has the form:
Expand Down Expand Up @@ -477,12 +475,12 @@ def monoclinic_symmetry(strains: torch.Tensor) -> torch.Tensor:

Args:
strains: Tensor of shape (6,) containing strain components
[εxx, εyy, εzz, εyz, εxz, εxy]
[εxx, εyy, εzz, εyz, εxz, εxy]

Returns:
torch.Tensor: Matrix of shape (6, 13) where columns correspond to
coefficients for the 13 independent constants in order:
[C11, C12, C13, C15, C22, C23, C25, C33, C35, C44, C46, C55, C66]
coefficients for the 13 independent constants in order:
[C11, C12, C13, C15, C22, C23, C25, C33, C35, C44, C46, C55, C66]

Notes:
For monoclinic symmetry with unique axis b (y), the matrix has the form:
Expand Down
8 changes: 4 additions & 4 deletions torch_sim/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ def matrix_exp(A: torch.Tensor) -> torch.Tensor:
A: Input matrix

Returns:
Matrix exponential of A
torch.Tensor: Matrix exponential of A
"""
return torch.matrix_exp(A)

Expand All @@ -390,7 +390,7 @@ def vec(M: torch.Tensor) -> torch.Tensor:
M: Input matrix

Returns:
Output vector
torch.Tensor: Output vector
"""
return M.t().reshape(-1)

Expand All @@ -408,7 +408,7 @@ def expm_frechet_kronform(
(crashes, non-termination) if the inputs do contain infinities or NaNs.

Returns:
K: Kronecker form of the Frechet derivative of the matrix exponential
torch.Tensor: Kronecker form of the Frechet derivative of the matrix exponential
with shape (N*N, N*N)
"""
if check_finite and not torch.isfinite(A).all():
Expand Down Expand Up @@ -519,7 +519,7 @@ def _is_valid_matrix(T: torch.Tensor, n: int = 3) -> bool:
n: The expected dimension of the matrix, default=3

Returns:
True if T is a valid nxn tensor, False otherwise
bool: True if T is a valid nxn tensor, False otherwise
"""
return isinstance(T, torch.Tensor) and T.shape == (n, n)

Expand Down
Loading
Loading