diff --git a/pyproject.toml b/pyproject.toml index 723edafb..1f6a69cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -151,3 +151,16 @@ conflicts = [ { extra = "sevenn" }, ], ] + + +[tool.basedpyright] +reportImplicitStringConcatenation = "none" +reportPrivateUsage = "none" # since ruff will catch this +reportAny= "none" +reportExplicitAny= "none" +include = ["torch_sim", "tests", "examples", "docs"] +exclude = [".venv"] +reportUnknownMemberType = "none" +reportUnknownVariableType = "none" +reportUnknownArgumentType = "none" +reportMissingTypeStubs = "none" \ No newline at end of file diff --git a/torch_sim/state.py b/torch_sim/state.py index fa898e1c..8fd59aff 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -9,9 +9,9 @@ import typing import warnings from collections import defaultdict -from collections.abc import Generator +from collections.abc import Generator, Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, ClassVar, Literal, Self, cast +from typing import TYPE_CHECKING, Any, ClassVar, Literal, Self, cast, overload import torch @@ -194,11 +194,7 @@ def n_atoms(self) -> int: @property def n_atoms_per_system(self) -> torch.Tensor: """Number of atoms per system.""" - return ( - self.system_idx.bincount() - if self.system_idx is not None - else torch.tensor([self.n_atoms], device=self.device) - ) + return self.system_idx.bincount() @property def n_atoms_per_batch(self) -> torch.Tensor: @@ -412,7 +408,7 @@ def __getitem__(self, system_indices: int | list[int] | slice | torch.Tensor) -> return _slice_state(self, system_indices) - def __init_subclass__(cls, **kwargs) -> None: + def __init_subclass__(cls, **kwargs: Any) -> None: """Enforce that all derived states cannot have tensor attributes that can also be None. This is because torch.concatenate cannot concat between a tensor and a None. See https://github.com/Radical-AI/torch-sim/pull/219 for more details. @@ -431,7 +427,7 @@ def _assert_no_tensor_attributes_can_be_none(cls) -> None: for attr_name, attr_typehint in type_hints.items(): origin = typing.get_origin(attr_typehint) - is_union = origin is typing.Union + is_union = origin is typing.Union # pyright: ignore[reportDeprecated] if not is_union and origin is not None: # For Python 3.10+ `|` syntax, origin is types.UnionType # We check by name to be robust against module reloading/patching issues @@ -572,10 +568,10 @@ def _normalize_system_indices( if isinstance(system_indices, slice): # Let PyTorch handle the slice conversion with negative indices return torch.arange(n_systems, device=device)[system_indices] - if isinstance(system_indices, torch.Tensor): + if isinstance(system_indices, torch.Tensor): # pyright: ignore[reportUnnecessaryIsInstance] # Handle negative indices in tensors return torch.where(system_indices < 0, n_systems + system_indices, system_indices) - raise TypeError(f"Unsupported index type: {type(system_indices)}") + raise TypeError(f"Unsupported index type: {type(system_indices)}") # pyright: ignore[reportUnreachable] def state_to_device( @@ -606,14 +602,25 @@ def state_to_device( if isinstance(attr_value, torch.Tensor): attrs[attr_name] = attr_value.to(device=device) - if dtype is not None: - attrs["positions"] = attrs["positions"].to(dtype=dtype) - attrs["masses"] = attrs["masses"].to(dtype=dtype) - attrs["cell"] = attrs["cell"].to(dtype=dtype) - attrs["atomic_numbers"] = attrs["atomic_numbers"].to(dtype=torch.int) + attrs["positions"] = attrs["positions"].to(dtype=dtype) + attrs["masses"] = attrs["masses"].to(dtype=dtype) + attrs["cell"] = attrs["cell"].to(dtype=dtype) + attrs["atomic_numbers"] = attrs["atomic_numbers"].to(dtype=torch.int) return type(state)(**attrs) +@overload +def get_attrs_for_scope( + state: SimState, scope: Literal["per-atom", "per-system"] +) -> Generator[tuple[str, torch.Tensor], None, None]: ... + + +@overload +def get_attrs_for_scope( + state: SimState, scope: Literal["global"] +) -> Generator[tuple[str, Any], None, None]: ... + + def get_attrs_for_scope( state: SimState, scope: Literal["per-atom", "per-system", "global"] ) -> Generator[tuple[str, Any], None, None]: @@ -627,15 +634,11 @@ def get_attrs_for_scope( Returns: Generator[tuple[str, Any], None, None]: A generator of attribute names and values """ - match scope: - case "per-atom": - attr_names = state._atom_attributes # noqa: SLF001 - case "per-system": - attr_names = state._system_attributes # noqa: SLF001 - case "global": - attr_names = state._global_attributes # noqa: SLF001 - case _: - raise ValueError(f"Unknown scope: {scope!r}") + attr_names = { + "per-atom": state._atom_attributes, # noqa: SLF001 + "per-system": state._system_attributes, # noqa: SLF001 + "global": state._global_attributes, # noqa: SLF001 + }[scope] for attr_name in attr_names: yield attr_name, getattr(state, attr_name) @@ -644,7 +647,7 @@ def _filter_attrs_by_mask( state: SimState, atom_mask: torch.Tensor, system_mask: torch.Tensor, -) -> dict: +) -> dict[str, Any]: """Filter attributes by atom and system masks. Selects subsets of attributes based on boolean masks for atoms and systems. @@ -708,21 +711,21 @@ def _split_state( list[SimState]: A list of SimState objects, each containing a single system """ - system_sizes = torch.bincount(state.system_idx).tolist() + system_sizes: list[int] = torch.bincount(state.system_idx).tolist() - split_per_atom = {} + split_per_atom: dict[str, Sequence[torch.Tensor]] = {} for attr_name, attr_value in get_attrs_for_scope(state, "per-atom"): if attr_name != "system_idx": split_per_atom[attr_name] = torch.split(attr_value, system_sizes, dim=0) - split_per_system = {} + split_per_system: dict[str, Sequence[torch.Tensor]] = {} for attr_name, attr_value in get_attrs_for_scope(state, "per-system"): split_per_system[attr_name] = torch.split(attr_value, 1, dim=0) global_attrs = dict(get_attrs_for_scope(state, "global")) # Create a state for each system - states = [] + states: list[SimStateVar] = [] n_systems = len(system_sizes) for i in range(n_systems): system_attrs = { @@ -740,7 +743,7 @@ def _split_state( # Add the global attributes **global_attrs, } - states.append(type(state)(**system_attrs)) + states.append(type(state)(**system_attrs)) # pyright: ignore[reportArgumentType] return states @@ -872,9 +875,9 @@ def concatenate_states( concatenated = dict(get_attrs_for_scope(first_state, "global")) # Pre-allocate lists for tensors to concatenate - per_atom_tensors = defaultdict(list) - per_system_tensors = defaultdict(list) - new_system_indices = [] + per_atom_tensors = defaultdict[str, list[torch.Tensor]](list) + per_system_tensors = defaultdict[str, list[torch.Tensor]](list) + new_system_indices: list[torch.Tensor] = [] system_offset = 0 # Process all states in a single pass @@ -944,7 +947,7 @@ def initialize_state( return state_to_device(system, device, dtype) if isinstance(system, list) and all(isinstance(s, SimState) for s in system): - if not all(cast("SimState", state).n_systems == 1 for state in system): + if not all(state.n_systems == 1 for state in system): # pyright: ignore[reportAttributeAccessIssue] raise ValueError( "When providing a list of states, to the initialize_state function, " "all states must have n_systems == 1. To fix this, you can split the " diff --git a/torch_sim/typing.py b/torch_sim/typing.py index 94ec44ca..ea603dea 100644 --- a/torch_sim/typing.py +++ b/torch_sim/typing.py @@ -1,7 +1,7 @@ """Types used across torch-sim.""" from enum import Enum -from typing import TYPE_CHECKING, Literal, TypeVar, Union +from typing import TYPE_CHECKING, Literal, TypeVar import torch @@ -40,13 +40,13 @@ class BravaisType(Enum): TRICLINIC = "triclinic" -StateLike = Union[ - "Atoms", - "Structure", - "PhonopyAtoms", - list["Atoms"], - list["Structure"], - list["PhonopyAtoms"], - SimStateVar, - list[SimStateVar], -] +StateLike = ( + Atoms + | Structure + | PhonopyAtoms + | list[Atoms] + | list[Structure] + | list[PhonopyAtoms] + | SimState + | list[SimState] +)