diff --git a/README.md b/README.md index 80ab3f4b..598ff824 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # easily load the model from mace-mp from mace.calculators.foundations_models import mace_mp -from torch_sim.models import MaceModel +from torch_sim.models.mace import MaceModel mace = mace_mp(model="small", return_raw_model=True) mace_model = MaceModel(model=mace, device=device) diff --git a/examples/scripts/1_Introduction/1.3_Batched_MACE.py b/examples/scripts/1_Introduction/1.3_Batched_MACE.py index 89885b57..62e1b936 100644 --- a/examples/scripts/1_Introduction/1.3_Batched_MACE.py +++ b/examples/scripts/1_Introduction/1.3_Batched_MACE.py @@ -11,7 +11,7 @@ from ase.build import bulk from mace.calculators.foundations_models import mace_mp -import torch_sim as ts +from torch_sim.models.mace import MaceModel # Set device and data type @@ -35,7 +35,7 @@ si_dc = bulk("Si", "diamond", a=5.43, cubic=True).repeat((2, 2, 2)) atoms_list = [si_dc, si_dc] -batched_model = ts.models.MaceModel( +batched_model = MaceModel( # Pass the raw model model=loaded_model, # Or load from compiled model diff --git a/examples/scripts/1_Introduction/1.4_Fairchem.py b/examples/scripts/1_Introduction/1.4_Fairchem.py index 0ca5eeb9..c8655280 100644 --- a/examples/scripts/1_Introduction/1.4_Fairchem.py +++ b/examples/scripts/1_Introduction/1.4_Fairchem.py @@ -13,6 +13,7 @@ from ase.build import bulk import torch_sim as ts +from torch_sim.models.fairchem import FairChemModel device = "cuda" if torch.cuda.is_available() else "cpu" @@ -31,7 +32,7 @@ # Create diamond cubic Silicon si_dc = bulk("Si", "diamond", a=5.43).repeat((2, 2, 2)) atomic_numbers = si_dc.get_atomic_numbers() -model = ts.models.FairChemModel( +model = FairChemModel( model=MODEL_PATH, cpu=False, seed=0, diff --git a/examples/scripts/2_Structural_optimization/2.10_Batched_MACE_FrechetCellFilter_FIRE.py b/examples/scripts/2_Structural_optimization/2.10_Batched_MACE_FrechetCellFilter_FIRE.py index 3165083d..4050b7fc 100644 --- a/examples/scripts/2_Structural_optimization/2.10_Batched_MACE_FrechetCellFilter_FIRE.py +++ b/examples/scripts/2_Structural_optimization/2.10_Batched_MACE_FrechetCellFilter_FIRE.py @@ -14,6 +14,7 @@ from mace.calculators.foundations_models import mace_mp import torch_sim as ts +from torch_sim.models.mace import MaceModel from torch_sim.units import UnitConversion @@ -62,7 +63,7 @@ print(f"Total number of structures: {len(atoms_list)}") # Create batched model -model = ts.models.MaceModel( +model = MaceModel( model=loaded_model, device=device, compute_forces=True, diff --git a/examples/scripts/2_Structural_optimization/2.7_Batched_MACE_FIRE.py b/examples/scripts/2_Structural_optimization/2.7_Batched_MACE_FIRE.py index 8e2295cb..1ae52241 100644 --- a/examples/scripts/2_Structural_optimization/2.7_Batched_MACE_FIRE.py +++ b/examples/scripts/2_Structural_optimization/2.7_Batched_MACE_FIRE.py @@ -14,6 +14,7 @@ from mace.calculators.foundations_models import mace_mp import torch_sim as ts +from torch_sim.models.mace import MaceModel from torch_sim.optimizers import fire @@ -62,7 +63,7 @@ print(f"Total number of structures: {len(atoms_list)}") # Create batched model -model = ts.models.MaceModel( +model = MaceModel( model=loaded_model, device=device, compute_forces=True, diff --git a/examples/scripts/2_Structural_optimization/2.8_Batched_MACE_UnitCellFilter_Gradient_Descent.py b/examples/scripts/2_Structural_optimization/2.8_Batched_MACE_UnitCellFilter_Gradient_Descent.py index 0ad9f53d..23a8bfbe 100644 --- a/examples/scripts/2_Structural_optimization/2.8_Batched_MACE_UnitCellFilter_Gradient_Descent.py +++ b/examples/scripts/2_Structural_optimization/2.8_Batched_MACE_UnitCellFilter_Gradient_Descent.py @@ -14,6 +14,7 @@ from mace.calculators.foundations_models import mace_mp import torch_sim as ts +from torch_sim.models.mace import MaceModel from torch_sim.optimizers import unit_cell_gradient_descent from torch_sim.units import UnitConversion @@ -63,7 +64,7 @@ print(f"Total number of structures: {len(atoms_list)}") # Create batched model -model = ts.models.MaceModel( +model = MaceModel( model=loaded_model, device=device, compute_forces=True, diff --git a/examples/scripts/2_Structural_optimization/2.9_Batched_MACE_UnitCellFilter_FIRE.py b/examples/scripts/2_Structural_optimization/2.9_Batched_MACE_UnitCellFilter_FIRE.py index c6eaa038..af17bfca 100644 --- a/examples/scripts/2_Structural_optimization/2.9_Batched_MACE_UnitCellFilter_FIRE.py +++ b/examples/scripts/2_Structural_optimization/2.9_Batched_MACE_UnitCellFilter_FIRE.py @@ -14,6 +14,7 @@ from mace.calculators.foundations_models import mace_mp import torch_sim as ts +from torch_sim.models.mace import MaceModel from torch_sim.optimizers import unit_cell_fire from torch_sim.units import UnitConversion @@ -63,7 +64,7 @@ print(f"Total number of structures: {len(atoms_list)}") # Create batched model -model = ts.models.MaceModel( +model = MaceModel( model=loaded_model, device=device, compute_forces=True, diff --git a/examples/scripts/5_Workflow/5.2_a2c_silicon_batched.py b/examples/scripts/5_Workflow/5.2_a2c_silicon_batched.py index 055d86a9..f230bdd1 100644 --- a/examples/scripts/5_Workflow/5.2_a2c_silicon_batched.py +++ b/examples/scripts/5_Workflow/5.2_a2c_silicon_batched.py @@ -24,6 +24,7 @@ from tqdm import tqdm import torch_sim as ts +from torch_sim.models.mace import MaceModel from torch_sim.unbatched.models.mace import UnbatchedMaceModel from torch_sim.unbatched.unbatched_integrators import ( NVTNoseHooverState, @@ -201,7 +202,7 @@ def step_fn( start_time = time.perf_counter() # Create a batched model -model = ts.models.MaceModel( +model = MaceModel( model=raw_model, device=device, compute_forces=True, diff --git a/examples/scripts/5_Workflow/5.3_In_Flight_WBM.py b/examples/scripts/5_Workflow/5.3_In_Flight_WBM.py index 9617d5bb..14672e11 100644 --- a/examples/scripts/5_Workflow/5.3_In_Flight_WBM.py +++ b/examples/scripts/5_Workflow/5.3_In_Flight_WBM.py @@ -15,6 +15,7 @@ from mace.calculators.foundations_models import mace_mp import torch_sim as ts +from torch_sim.models.mace import MaceModel # --- Setup and Configuration --- @@ -27,7 +28,7 @@ print("Loading MACE model...") mace_checkpoint_url = "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mpa_0/mace-mpa-0-medium.model" mace = mace_mp(model=mace_checkpoint_url, return_raw_model=True) -mace_model = ts.models.MaceModel( +mace_model = MaceModel( model=mace, device=device, dtype=dtype, diff --git a/examples/scripts/5_Workflow/5.4_Elastic.py b/examples/scripts/5_Workflow/5.4_Elastic.py index 9a6f8e3b..ccf0e391 100644 --- a/examples/scripts/5_Workflow/5.4_Elastic.py +++ b/examples/scripts/5_Workflow/5.4_Elastic.py @@ -13,6 +13,7 @@ import torch_sim as ts from torch_sim.elastic import get_bravais_type +from torch_sim.models.mace import MaceModel # Calculator @@ -31,7 +32,7 @@ # ASE structure struct = bulk("Cu", "fcc", a=3.58, cubic=True).repeat((2, 2, 2)) -model = ts.models.MaceModel( +model = MaceModel( model=loaded_model, device=device, compute_forces=True, diff --git a/examples/scripts/6_Phonons/6.1_Phonons_MACE.py b/examples/scripts/6_Phonons/6.1_Phonons_MACE.py index abc738ff..69ad1c13 100644 --- a/examples/scripts/6_Phonons/6.1_Phonons_MACE.py +++ b/examples/scripts/6_Phonons/6.1_Phonons_MACE.py @@ -24,6 +24,7 @@ ) import torch_sim as ts +from torch_sim.models.mace import MaceModel def get_qpts_and_connections( @@ -105,7 +106,7 @@ def get_labels_qpts(ph: Phonopy, n_points: int = 101) -> tuple[list[str], list[b displ = 0.01 # atomic displacement for phonons (in Angstrom) # Relax atomic positions -model = ts.models.MaceModel( +model = MaceModel( model=loaded_model, device=device, compute_forces=True, diff --git a/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py b/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py index 04080ea8..7b5dccf7 100644 --- a/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py +++ b/examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py @@ -23,6 +23,7 @@ from phonopy.structure.atoms import PhonopyAtoms import torch_sim as ts +from torch_sim.models.mace import MaceModel def get_relaxed_structure( @@ -219,7 +220,7 @@ def get_qha_phonons( default_dtype=dtype, device=device, ) -model = ts.models.MaceModel( +model = MaceModel( model=loaded_model, device=device, compute_forces=True, diff --git a/examples/scripts/6_Phonons/6.3_Conductivity_MACE.py b/examples/scripts/6_Phonons/6.3_Conductivity_MACE.py index b8c80bc3..3c00b18c 100644 --- a/examples/scripts/6_Phonons/6.3_Conductivity_MACE.py +++ b/examples/scripts/6_Phonons/6.3_Conductivity_MACE.py @@ -22,6 +22,7 @@ from phono3py import Phono3py import torch_sim as ts +from torch_sim.models.mace import MaceModel def print_relax_info(trajectory_file: str, device: torch.device) -> None: @@ -54,7 +55,7 @@ def print_relax_info(trajectory_file: str, device: torch.device) -> None: loaded_model = mace_mp( model=mace_checkpoint_url, return_raw_model=True, default_dtype=dtype, device=device ) -model = ts.models.MaceModel( +model = MaceModel( model=loaded_model, device=device, compute_forces=True, diff --git a/examples/tutorials/autobatching_tutorial.py b/examples/tutorials/autobatching_tutorial.py index e3288fc5..c1647190 100644 --- a/examples/tutorials/autobatching_tutorial.py +++ b/examples/tutorials/autobatching_tutorial.py @@ -97,7 +97,7 @@ def mock_determine_max_batch_size(*args, **kwargs): # %% from torch_sim.autobatching import estimate_max_memory_scaler from mace.calculators.foundations_models import mace_mp -from torch_sim.models import MaceModel +from torch_sim.models.mace import MaceModel # Initialize your model device = torch.device("cuda" if torch.cuda.is_available() else "cpu") diff --git a/examples/tutorials/low_level_tutorial.py b/examples/tutorials/low_level_tutorial.py index e3cf06de..e13c5c3c 100644 --- a/examples/tutorials/low_level_tutorial.py +++ b/examples/tutorials/low_level_tutorial.py @@ -65,7 +65,7 @@ # %% from mace.calculators.foundations_models import mace_mp -from torch_sim.models import MaceModel +from torch_sim.models.mace import MaceModel # load mace_mp using the mace package mace_checkpoint_url = "https://github.com/ACEsuit/mace-foundations/releases/download/mace_mpa_0/mace-mpa-0-medium.model" diff --git a/examples/tutorials/metatensor_tutorial.py b/examples/tutorials/metatensor_tutorial.py index 944d4348..701e4159 100644 --- a/examples/tutorials/metatensor_tutorial.py +++ b/examples/tutorials/metatensor_tutorial.py @@ -26,7 +26,7 @@ """ # %% -from torch_sim.models import MetatensorModel +from torch_sim.models.metatensor import MetatensorModel model = MetatensorModel("pet-mad") diff --git a/examples/tutorials/reporting_tutorial.py b/examples/tutorials/reporting_tutorial.py index c387c74d..c47340fe 100644 --- a/examples/tutorials/reporting_tutorial.py +++ b/examples/tutorials/reporting_tutorial.py @@ -205,7 +205,7 @@ """ # %% -from torch_sim.models import LennardJonesModel +from torch_sim.models.lennard_jones import LennardJonesModel # Define some property calculators diff --git a/examples/tutorials/using_graphpes_tutorial.py b/examples/tutorials/using_graphpes_tutorial.py index bbc6b5db..08dc80f4 100644 --- a/examples/tutorials/using_graphpes_tutorial.py +++ b/examples/tutorials/using_graphpes_tutorial.py @@ -49,7 +49,7 @@ """ # %% -from torch_sim.models import GraphPESWrapper +from torch_sim.models.graphpes import GraphPESWrapper # wrap the model for use with TorchSim ts_model = GraphPESWrapper(model, compute_stress=False) diff --git a/torch_sim/elastic.py b/torch_sim/elastic.py index b407955f..c9408a76 100644 --- a/torch_sim/elastic.py +++ b/torch_sim/elastic.py @@ -150,14 +150,13 @@ def regular_symmetry(strains: torch.Tensor) -> torch.Tensor: Args: strains: Tensor of shape (6,) containing strain components - [εxx, εyy, εzz, εyz, εxz, εxy] - where: - - εxx, εyy, εzz are normal strains - - εyz, εxz, εxy are shear strains + [εxx, εyy, εzz, εyz, εxz, εxy] where: + - εxx, εyy, εzz are normal strains + - εyz, εxz, εxy are shear strains Returns: torch.Tensor: Matrix of shape (6, 3) where columns correspond to - coefficients for C11, C12, and C44 respectively + coefficients for C11, C12, and C44 respectively Notes: The resulting matrix M has the form: @@ -214,14 +213,13 @@ def tetragonal_symmetry(strains: torch.Tensor) -> torch.Tensor: Args: strains: Tensor of shape (6,) containing strain components - [εxx, εyy, εzz, εyz, εxz, εxy] - where: - - εxx, εyy, εzz are normal strains - - εyz, εxz, εxy are shear strains + [εxx, εyy, εzz, εyz, εxz, εxy] where: + - εxx, εyy, εzz are normal strains + - εyz, εxz, εxy are shear strains Returns: torch.Tensor: Matrix of shape (6, 7) where columns correspond to - coefficients for C11, C12, C13, C16, C33, C44, C66 + coefficients for C11, C12, C13, C16, C33, C44, C66 Notes: The resulting matrix M has the form: @@ -280,11 +278,11 @@ def orthorhombic_symmetry(strains: torch.Tensor) -> torch.Tensor: Args: strains: Tensor of shape (6,) containing strain components - [εxx, εyy, εzz, εyz, εxz, εxy] + [εxx, εyy, εzz, εyz, εxz, εxy] Returns: torch.Tensor: Matrix of shape (6, 9) where columns correspond to - coefficients for C11, C12, C13, C22, C23, C33, C44, C55, C66 + coefficients for C11, C12, C13, C22, C23, C33, C44, C55, C66 Notes: The resulting matrix M has the form: @@ -343,11 +341,11 @@ def trigonal_symmetry(strains: torch.Tensor) -> torch.Tensor: Args: strains: Tensor of shape (6,) containing strain components - [εxx, εyy, εzz, εyz, εxz, εxy] + [εxx, εyy, εzz, εyz, εxz, εxy] Returns: torch.Tensor: Matrix of shape (6, 7) where columns correspond to - coefficients for C11, C12, C13, C14, C15, C33, C44 + coefficients for C11, C12, C13, C14, C15, C33, C44 Notes: The resulting matrix M has the form: @@ -416,11 +414,11 @@ def hexagonal_symmetry(strains: torch.Tensor) -> torch.Tensor: Args: strains: Tensor of shape (6,) containing strain components - [εxx, εyy, εzz, εyz, εxz, εxy] + [εxx, εyy, εzz, εyz, εxz, εxy] Returns: torch.Tensor: Matrix of shape (6, 5) where columns correspond to - coefficients for C11, C33, C12, C13, C44 + coefficients for C11, C33, C12, C13, C44 Notes: The resulting matrix M has the form: @@ -477,12 +475,12 @@ def monoclinic_symmetry(strains: torch.Tensor) -> torch.Tensor: Args: strains: Tensor of shape (6,) containing strain components - [εxx, εyy, εzz, εyz, εxz, εxy] + [εxx, εyy, εzz, εyz, εxz, εxy] Returns: torch.Tensor: Matrix of shape (6, 13) where columns correspond to - coefficients for the 13 independent constants in order: - [C11, C12, C13, C15, C22, C23, C25, C33, C35, C44, C46, C55, C66] + coefficients for the 13 independent constants in order: + [C11, C12, C13, C15, C22, C23, C25, C33, C35, C44, C46, C55, C66] Notes: For monoclinic symmetry with unique axis b (y), the matrix has the form: diff --git a/torch_sim/math.py b/torch_sim/math.py index 667665a6..f272da31 100644 --- a/torch_sim/math.py +++ b/torch_sim/math.py @@ -376,7 +376,7 @@ def matrix_exp(A: torch.Tensor) -> torch.Tensor: A: Input matrix Returns: - Matrix exponential of A + torch.Tensor: Matrix exponential of A """ return torch.matrix_exp(A) @@ -390,7 +390,7 @@ def vec(M: torch.Tensor) -> torch.Tensor: M: Input matrix Returns: - Output vector + torch.Tensor: Output vector """ return M.t().reshape(-1) @@ -408,7 +408,7 @@ def expm_frechet_kronform( (crashes, non-termination) if the inputs do contain infinities or NaNs. Returns: - K: Kronecker form of the Frechet derivative of the matrix exponential + torch.Tensor: Kronecker form of the Frechet derivative of the matrix exponential with shape (N*N, N*N) """ if check_finite and not torch.isfinite(A).all(): @@ -519,7 +519,7 @@ def _is_valid_matrix(T: torch.Tensor, n: int = 3) -> bool: n: The expected dimension of the matrix, default=3 Returns: - True if T is a valid nxn tensor, False otherwise + bool: True if T is a valid nxn tensor, False otherwise """ return isinstance(T, torch.Tensor) and T.shape == (n, n) diff --git a/torch_sim/models/__init__.py b/torch_sim/models/__init__.py index 924284bc..894827fe 100644 --- a/torch_sim/models/__init__.py +++ b/torch_sim/models/__init__.py @@ -1,43 +1 @@ -"""Models for TorchSim.""" - -# ruff: noqa: F401 - -from torch_sim.models.lennard_jones import LennardJonesModel -from torch_sim.models.morse import MorseModel -from torch_sim.models.soft_sphere import SoftSphereModel - - -try: - from torch_sim.models.orb import OrbModel -except ImportError: - pass - -try: - from torch_sim.models.fairchem import FairChemModel -except ImportError: - pass - -try: - from torch_sim.models.mace import MaceModel -except ImportError: - pass - -try: - from torch_sim.models.sevennet import SevenNetModel -except ImportError: - pass - -try: - from torch_sim.models.mattersim import MatterSimModel -except ImportError: - pass - -try: - from torch_sim.models.graphpes import GraphPESWrapper -except ImportError: - pass - -try: - from torch_sim.models.metatensor import MetatensorModel -except ImportError: - pass +"""Models integrated with TorchSim.""" diff --git a/torch_sim/models/graphpes.py b/torch_sim/models/graphpes.py index 80fead9b..10718f99 100644 --- a/torch_sim/models/graphpes.py +++ b/torch_sim/models/graphpes.py @@ -113,7 +113,7 @@ class GraphPESWrapper(torch.nn.Module, ModelInterface): compute_stress: Whether to compute stress Example: - >>> from torch_sim.models import GraphPESWrapper + >>> from torch_sim.models.graphpes import GraphPESWrapper >>> from graph_pes.models import load_model >>> model = load_model("path/to/model.pt") >>> wrapper = GraphPESWrapper(model) diff --git a/torch_sim/unbatched/models/lennard_jones.py b/torch_sim/unbatched/models/lennard_jones.py index a3a1c42d..df755079 100644 --- a/torch_sim/unbatched/models/lennard_jones.py +++ b/torch_sim/unbatched/models/lennard_jones.py @@ -34,8 +34,9 @@ def lennard_jones_pair( or tensor of shape [n, m] for pair-specific interaction strengths. Returns: - Pairwise Lennard-Jones interaction energies between particles. Shape: [n, m]. - Each element [i,j] represents the interaction energy between particles i and j. + torch.Tensor: Pairwise Lennard-Jones interaction energies between particles. + Shape: [n, m]. Each element [i,j] represents the interaction energy between + particles i and j. """ # Calculate inverse dr and its powers idr = sigma / dr @@ -76,9 +77,9 @@ def lennard_jones_pair_force( of shape [n, m] for pair-specific interaction strengths. Returns: - Pairwise Lennard-Jones forces between particles. Shape: [n, m]. - Each element [i,j] represents the force magnitude between particles i and j. - Positive values indicate repulsion, negative values indicate attraction. + torch.Tensor: Pairwise Lennard-Jones forces between particles. Shape: [n, m]. + Each element [i,j] represents the force magnitude between particles i and j. + Positive values indicate repulsion, negative values indicate attraction. """ # Calculate inverse dr and its powers idr = sigma / dr diff --git a/torch_sim/unbatched/models/morse.py b/torch_sim/unbatched/models/morse.py index 354a3fad..9831f0f0 100644 --- a/torch_sim/unbatched/models/morse.py +++ b/torch_sim/unbatched/models/morse.py @@ -40,8 +40,9 @@ def morse_pair( well. Either a scalar float or tensor of shape [n, m]. Returns: - Pairwise Morse interaction energies between particles. Shape: [n, m]. - Each element [i,j] represents the interaction energy between particles i and j. + torch.Tensor: Pairwise Morse interaction energies between particles. + Shape: [n, m]. Each element [i,j] represents the interaction energy between + particles i and j. """ # Calculate potential energy energy = epsilon * (1.0 - torch.exp(-alpha * (dr - sigma))).pow(2) - epsilon @@ -77,8 +78,8 @@ def morse_pair_force( tensor of shape [n, m]. Returns: - Pairwise Morse forces between particles. Shape: [n, m]. - Positive values indicate repulsion, negative values indicate attraction. + torch.Tensor: Pairwise Morse forces between particles. Shape: [n, m]. + Positive values indicate repulsion, negative values indicate attraction. """ exp_term = torch.exp(-alpha * (dr - sigma)) force = -2.0 * alpha * epsilon * exp_term * (1.0 - exp_term) diff --git a/torch_sim/unbatched/models/particle_life.py b/torch_sim/unbatched/models/particle_life.py index d2caa5b7..98f08095 100644 --- a/torch_sim/unbatched/models/particle_life.py +++ b/torch_sim/unbatched/models/particle_life.py @@ -29,7 +29,7 @@ def asymmetric_particle_pair_force( shape [n, m]. Returns: - Tensor of energies with shape [n, m]. + torch.Tensor: Energies with shape [n, m]. """ inner_mask = dr < beta outer_mask = (dr < sigma) & (dr > beta) @@ -64,7 +64,7 @@ def asymmetric_particle_pair_force_jit( shape [n, m]. Returns: - Tensor of energies with shape [n, m]. + torch.Tensor: Energies with shape [n, m]. """ inner_mask = dr < beta outer_mask = (dr < sigma) & (dr > beta) diff --git a/torch_sim/unbatched/models/soft_sphere.py b/torch_sim/unbatched/models/soft_sphere.py index 166f844f..48355471 100644 --- a/torch_sim/unbatched/models/soft_sphere.py +++ b/torch_sim/unbatched/models/soft_sphere.py @@ -38,8 +38,8 @@ def soft_sphere_pair( float or tensor of shape [n, m]. Returns: - Pairwise interaction energies between particles. Shape: [n, m]. Each element - [i,j] represents the repulsive energy between particles i and j. + torch.Tensor: Pairwise interaction energies between particles. Shape: [n, m]. + Each element [i,j] represents the repulsive energy between particles i and j. """ def fn(dr: torch.Tensor) -> torch.Tensor: @@ -76,7 +76,7 @@ def soft_sphere_pair_force( Returns: torch.Tensor: Forces between particle pairs with shape [n, m]. Forces are zero - for distances greater than sigma. + for distances greater than sigma. """ def fn(dr: torch.Tensor) -> torch.Tensor: diff --git a/torch_sim/workflows/a2c.py b/torch_sim/workflows/a2c.py index 19420c36..d9783032 100644 --- a/torch_sim/workflows/a2c.py +++ b/torch_sim/workflows/a2c.py @@ -390,14 +390,13 @@ def random_packed_structure_multi( # Extract element information from composition into a robust dictionary format element_dict = composition.as_dict() element_symbols = list(element_dict) # Get unique elements - element_counts = [ - int(element_dict[el]) for el in element_symbols - ] # Get counts directly + element_counts = [int(element_dict[el]) for el in element_symbols] # Create species indices tensor mapping each atom to its species type # e.g. for Fe80B20: [0,0,...,0,1,1,...,1] where 0=Fe, 1=B species_idx = torch.tensor( - [i for i, count in enumerate(element_counts) for _ in range(count)], device=device + [i for i, count in enumerate(element_counts) for _ in range(count)], + device=device, ) # Calculate total atoms and number of unique species @@ -450,10 +449,8 @@ def random_packed_structure_multi( # Run FIRE optimization until convergence or max iterations for _step in range(max_iter): # Check if minimum distance criterion is met (95% of smallest target diameter) - if ( - min_distance(state.positions, cell, distance_tolerance) - > diameter_matrix.min() * 0.95 - ): + min_dist = min_distance(state.positions, cell, distance_tolerance) + if min_dist > diameter_matrix.min() * 0.95: break state = fire_update(state) print(f"Final energy: {state.energy.item():.4f}") @@ -666,24 +663,22 @@ def subcells_to_structures( - species: atomic species symbols """ list_subcells = [] - for ids, l, h in candidates: # noqa: E741 + for ids, lower_bound, upper_bound in candidates: # Get positions of atoms in this subcell pos = fractional_positions[ids] # Shift positions to start from origin - new_frac_pos = pos - l + new_frac_pos = pos - lower_bound # Scale positions to [0,1] range - new_frac_pos = new_frac_pos / (h - l) + new_frac_pos = new_frac_pos / (upper_bound - lower_bound) # Calculate new cell parameters - new_cell = cell * (h - l).unsqueeze(0) + new_cell = cell * (upper_bound - lower_bound).unsqueeze(0) - # Convert tensor indices to list/numpy array before indexing species list - species_indices = ids.cpu().numpy() # Convert to numpy array - subcell_species = [ - species[int(i)] for i in species_indices - ] # Get species for these atoms + # Get species for these atoms and convert tensor indices to list/numpy array + # before indexing species list + subcell_species = [species[int(i)] for i in ids.cpu().numpy()] list_subcells.append((new_frac_pos, new_cell, subcell_species))