diff --git a/README.md b/README.md index 64ee92d7..90f1a5b8 100644 --- a/README.md +++ b/README.md @@ -125,8 +125,6 @@ To understand how TorchSim works, start with the [comprehensive tutorials](https TorchSim's structure is summarized in the [API reference](https://radical-ai.github.io/torch-sim/reference/index.html) documentation. -> `torch-sim` module graph. Each node represents a Python module. Arrows indicate imports between modules. Node color indicates connectedness: blue nodes have fewer dependents, red nodes have more (up to 16). The number in parentheses is the number of lines of code in the module. - ## License TorchSim is released under an [MIT license](LICENSE). diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 48239e56..f74ba567 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -3,7 +3,7 @@ API reference ============= -Overview of the torch_sim API. +Overview of the TorchSim API. .. currentmodule:: torch_sim @@ -28,6 +28,12 @@ Overview of the torch_sim API. transforms units + +TorchSim module graph. Each node represents a Python module. Arrows indicate +imports between modules. Node color indicates connectedness: blue nodes have fewer +dependents, red nodes have more (up to 16). The number in parentheses is the number of +lines of code in the module. Click on nodes to navigate to the file. + .. image:: /_static/torch-sim-module-graph.svg :alt: torch-sim Module Graph :width: 100% diff --git a/examples/scripts/4_High_level_api/4.2_auto_batching_api.py b/examples/scripts/4_High_level_api/4.2_auto_batching_api.py index 36988d2f..419c79fa 100644 --- a/examples/scripts/4_High_level_api/4.2_auto_batching_api.py +++ b/examples/scripts/4_High_level_api/4.2_auto_batching_api.py @@ -18,8 +18,8 @@ from mace.calculators.foundations_models import mace_mp from torch_sim.autobatching import ( - ChunkingAutoBatcher, - HotSwappingAutoBatcher, + BinningAutoBatcher, + InFlightAutoBatcher, calculate_memory_scaler, ) from torch_sim.integrators import nvt_langevin @@ -65,7 +65,7 @@ # %% TODO: add max steps converge_max_force = generate_force_convergence_fn(force_tol=1e-1) single_system_memory = calculate_memory_scaler(fire_states[0]) -batcher = HotSwappingAutoBatcher( +batcher = InFlightAutoBatcher( model=mace_model, memory_scales_with="n_atoms_x_density", max_memory_scaler=single_system_memory * 2.5 if os.getenv("CI") else None, @@ -86,7 +86,7 @@ print("Total number of completed states", len(all_completed_states)) -# %% run chunking autobatcher +# %% run binning autobatcher nvt_init, nvt_update = nvt_langevin( model=mace_model, dt=0.001, kT=300 * MetalUnits.temperature ) @@ -105,7 +105,7 @@ single_system_memory = calculate_memory_scaler(fire_states[0]) -batcher = ChunkingAutoBatcher( +batcher = BinningAutoBatcher( model=mace_model, memory_scales_with="n_atoms_x_density", max_memory_scaler=single_system_memory * 2.5 if os.getenv("CI") else None, diff --git a/examples/scripts/5_Workflow/5.3_Hot_Swap_WBM.py b/examples/scripts/5_Workflow/5.3_In_Flight_WBM.py similarity index 98% rename from examples/scripts/5_Workflow/5.3_Hot_Swap_WBM.py rename to examples/scripts/5_Workflow/5.3_In_Flight_WBM.py index 79f458fd..9ac1dd11 100644 --- a/examples/scripts/5_Workflow/5.3_Hot_Swap_WBM.py +++ b/examples/scripts/5_Workflow/5.3_In_Flight_WBM.py @@ -68,7 +68,7 @@ ts.io.atoms_to_state(atoms=ase_atoms_list, device=device, dtype=dtype) ) -batcher = ts.autobatching.HotSwappingAutoBatcher( +batcher = ts.autobatching.InFlightAutoBatcher( model=mace_model, 
memory_scales_with="n_atoms_x_density", max_memory_scaler=1000 if os.getenv("CI") else None, diff --git a/examples/tutorials/autobatching_tutorial.py b/examples/tutorials/autobatching_tutorial.py index b6e8324a..833d6193 100644 --- a/examples/tutorials/autobatching_tutorial.py +++ b/examples/tutorials/autobatching_tutorial.py @@ -29,7 +29,7 @@ atoms exceeds available GPU memory. The `torch_sim.autobatching` module solves this by: 1. Automatically determining optimal batch sizes based on GPU memory constraints -2. Providing two complementary strategies: chunking and hot-swapping +2. Providing two complementary strategies: binning and in-flight 3. Efficiently managing memory resources during large-scale simulations Let's explore how to use these powerful features! @@ -120,9 +120,9 @@ def mock_determine_max_batch_size(*args, **kwargs): This is a verbose way to determine the max memory metric, we'll see a simpler way shortly. -## ChunkingAutoBatcher: Fixed Batching Strategy +## BinningAutoBatcher: Fixed Batching Strategy -Now on to the exciting part, autobatching! The `ChunkingAutoBatcher` groups states into +Now on to the exciting part, autobatching! The `BinningAutoBatcher` groups states into batches with a binpacking algorithm, ensuring that we minimize the total number of batches while maximizing the GPU utilization of each batch. This approach is ideal for scenarios where all states need to be processed the same number of times, such as @@ -132,7 +132,7 @@ def mock_determine_max_batch_size(*args, **kwargs): """ # %% Initialize the batcher, the max memory scaler will be computed automatically -batcher = ts.ChunkingAutoBatcher( +batcher = ts.BinningAutoBatcher( model=mace_model, memory_scales_with="n_atoms", ) @@ -167,11 +167,11 @@ def process_batch(batch): maximum safe batch size through test runs on your GPU. However, the max memory scaler is typically fixed for a given model and simulation setup. To avoid calculating it every time, which is a bit slow, you can calculate it once and then include it in the -`ChunkingAutoBatcher` constructor. +`BinningAutoBatcher` constructor. """ # %% -batcher = ts.ChunkingAutoBatcher( +batcher = ts.BinningAutoBatcher( model=mace_model, memory_scales_with="n_atoms", max_memory_scaler=max_memory_scaler, @@ -192,7 +192,7 @@ def process_batch(batch): nvt_state = nvt_init(state) # Initialize the batcher -batcher = ts.ChunkingAutoBatcher( +batcher = ts.BinningAutoBatcher( model=mace_model, memory_scales_with="n_atoms", ) @@ -217,13 +217,13 @@ def process_batch(batch): # %% [markdown] """ -## HotSwappingAutoBatcher: Dynamic Batching Strategy +## InFlightAutoBatcher: Dynamic Batching Strategy -The `HotSwappingAutoBatcher` optimizes GPU utilization by dynamically removing +The `InFlightAutoBatcher` optimizes GPU utilization by dynamically removing converged states and adding new ones. This is ideal for processes like geometry optimization where different states may converge at different rates. -The `HotSwappingAutoBatcher` is more complex than the `ChunkingAutoBatcher` because +The `InFlightAutoBatcher` is more complex than the `BinningAutoBatcher` because it requires the batch to be dynamically updated. The swapping logic is handled internally, but the user must regularly provide a convergence tensor indicating which batches in the state have converged. 
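The loop that drives this convergence handshake appears later in this diff (in `5.3_In_Flight_WBM.py`); condensed, it looks like the sketch below, which assumes `mace_model`, `fire_states`, and `fire_update` are already set up as in that example:

```python
import torch_sim as ts

# Condensed in-flight loop; `mace_model`, `fire_states`, and `fire_update` are
# assumed to come from the setup in examples/scripts/5_Workflow/5.3_In_Flight_WBM.py.
batcher = ts.autobatching.InFlightAutoBatcher(
    model=mace_model, memory_scales_with="n_atoms_x_density"
)
converge_max_force = ts.runners.generate_force_convergence_fn(force_tol=0.05)

batcher.load_states(fire_states)
all_completed_states, convergence_tensor, state = [], None, None
while (result := batcher.next_batch(state, convergence_tensor))[0] is not None:
    state, completed_states = result  # converged systems were swapped out
    all_completed_states.extend(completed_states)
    for _step in range(10):  # advance a few steps between convergence checks
        state = fire_update(state)
    convergence_tensor = converge_max_force(state, last_energy=None)
all_completed_states.extend(result[1])  # states that finished on the last batch
```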
@@ -236,7 +236,7 @@ def process_batch(batch): fire_state = fire_init(state) # Initialize the batcher -batcher = ts.HotSwappingAutoBatcher( +batcher = ts.InFlightAutoBatcher( model=mace_model, memory_scales_with="n_atoms", max_memory_scaler=1000, @@ -296,7 +296,7 @@ def process_batch(batch): """ # %% Initialize with return_indices=True -batcher = ts.ChunkingAutoBatcher( +batcher = ts.BinningAutoBatcher( model=mace_model, memory_scales_with="n_atoms", max_memory_scaler=80, @@ -317,8 +317,8 @@ def process_batch(batch): TorchSim's autobatching provides powerful tools for GPU-efficient simulation of multiple systems: -1. Use `ChunkingAutoBatcher` for simpler workflows with fixed iteration counts -2. Use `HotSwappingAutoBatcher` for optimization problems with varying convergence +1. Use `BinningAutoBatcher` for simpler workflows with fixed iteration counts +2. Use `InFlightAutoBatcher` for optimization problems with varying convergence rates 3. Let the library handle memory management automatically, or specify limits manually
diff --git a/tests/models/test_graphpes.py b/tests/models/test_graphpes.py index cd32eff2..4764eaaa 100644 --- a/tests/models/test_graphpes.py +++ b/tests/models/test_graphpes.py @@ -144,6 +144,8 @@ def ase_nequip_calculator(device: torch.device, dtype: torch.dtype): model_fixture_name="ts_nequip_model", calculator_fixture_name="ase_nequip_calculator", sim_state_names=consistency_test_simstate_fixtures, + rtol=8e-5, + atol=8e-5, ) test_graphpes_nequip_model_outputs = make_validate_model_outputs_test( diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index e8b5817a..52ce95a0 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -4,8 +4,8 @@ import torch from torch_sim.autobatching import ( - ChunkingAutoBatcher, - HotSwappingAutoBatcher, + BinningAutoBatcher, + InFlightAutoBatcher, calculate_memory_scaler, determine_max_batch_size, to_constant_volume_bins, @@ -124,15 +124,15 @@ def test_split_state(si_double_sim_state: SimState) -> None: assert
state[1].cell.shape[0] == 1 -def test_chunking_auto_batcher( +def test_binning_auto_batcher( si_sim_state: SimState, fe_supercell_sim_state: SimState, lj_model: LennardJonesModel ) -> None: - """Test ChunkingAutoBatcher with different states.""" + """Test BinningAutoBatcher with different states.""" # Create a list of states with different sizes states = [si_sim_state, fe_supercell_sim_state] # Initialize the batcher with a fixed max_metric to avoid GPU memory testing - batcher = ChunkingAutoBatcher( + batcher = BinningAutoBatcher( model=lj_model, memory_scales_with="n_atoms", max_memory_scaler=260.0, # Set a small value to force multiple batches @@ -163,13 +163,13 @@ def test_chunking_auto_batcher( assert torch.all(restored_states[1].atomic_numbers == states[1].atomic_numbers) -def test_chunking_auto_batcher_with_indices( +def test_binning_auto_batcher_with_indices( si_sim_state: SimState, fe_supercell_sim_state: SimState, lj_model: LennardJonesModel ) -> None: - """Test ChunkingAutoBatcher with return_indices=True.""" + """Test BinningAutoBatcher with return_indices=True.""" states = [si_sim_state, fe_supercell_sim_state] - batcher = ChunkingAutoBatcher( + batcher = BinningAutoBatcher( model=lj_model, memory_scales_with="n_atoms", max_memory_scaler=260.0, @@ -190,15 +190,15 @@ def test_chunking_auto_batcher_with_indices( assert indices == batcher.index_bins[i] -def test_chunking_auto_batcher_restore_order_with_split_states( +def test_binning_auto_batcher_restore_order_with_split_states( si_sim_state: SimState, fe_supercell_sim_state: SimState, lj_model: LennardJonesModel ) -> None: - """Test ChunkingAutoBatcher's restore_original_order method with split states.""" + """Test BinningAutoBatcher's restore_original_order method with split states.""" # Create a list of states with different sizes states = [si_sim_state, fe_supercell_sim_state] # Initialize the batcher with a fixed max_metric to avoid GPU memory testing - batcher = ChunkingAutoBatcher( + batcher = BinningAutoBatcher( model=lj_model, memory_scales_with="n_atoms", max_memory_scaler=260.0, # Set a small value to force multiple batches @@ -231,15 +231,15 @@ def test_chunking_auto_batcher_restore_order_with_split_states( assert torch.all(restored_states[1].atomic_numbers == states[1].atomic_numbers) -def test_hot_swapping_max_metric_too_small( +def test_in_flight_max_metric_too_small( si_sim_state: SimState, fe_supercell_sim_state: SimState, lj_model: LennardJonesModel ) -> None: - """Test HotSwappingAutoBatcher with different states.""" + """Test InFlightAutoBatcher with different states.""" # Create a list of states states = [si_sim_state, fe_supercell_sim_state] # Initialize the batcher with a fixed max_metric - batcher = HotSwappingAutoBatcher( + batcher = InFlightAutoBatcher( model=lj_model, memory_scales_with="n_atoms", max_memory_scaler=1.0, # Set a small value to force multiple batches @@ -249,15 +249,15 @@ def test_hot_swapping_max_metric_too_small( batcher.load_states(states) -def test_hot_swapping_auto_batcher( +def test_in_flight_auto_batcher( si_sim_state: SimState, fe_supercell_sim_state: SimState, lj_model: LennardJonesModel ) -> None: - """Test HotSwappingAutoBatcher with different states.""" + """Test InFlightAutoBatcher with different states.""" # Create a list of states states = [si_sim_state, fe_supercell_sim_state] # Initialize the batcher with a fixed max_metric - batcher = HotSwappingAutoBatcher( + batcher = InFlightAutoBatcher( model=lj_model, memory_scales_with="n_atoms", max_memory_scaler=260, # Set a 
small value to force multiple batches @@ -317,13 +317,13 @@ def mock_measure(*_args: Any, **_kwargs: Any) -> float: assert max_size == 8 -def test_hot_swapping_auto_batcher_restore_order( +def test_in_flight_auto_batcher_restore_order( si_sim_state: SimState, fe_supercell_sim_state: SimState, lj_model: LennardJonesModel ) -> None: - """Test HotSwappingAutoBatcher's restore_original_order method.""" + """Test InFlightAutoBatcher's restore_original_order method.""" states = [si_sim_state, fe_supercell_sim_state] - batcher = HotSwappingAutoBatcher( + batcher = InFlightAutoBatcher( model=lj_model, memory_scales_with="n_atoms", max_memory_scaler=260.0 ) batcher.load_states(states) @@ -361,7 +361,7 @@ def test_hot_swapping_auto_batcher_restore_order( # batcher.restore_original_order([si_sim_state]) -def test_hot_swapping_with_fire( +def test_in_flight_with_fire( si_sim_state: SimState, fe_supercell_sim_state: SimState, lj_model: LennardJonesModel ) -> None: fire_init, fire_update = unit_cell_fire(lj_model) @@ -374,7 +374,7 @@ def test_hot_swapping_with_fire( for state in fire_states: state.positions += torch.randn_like(state.positions) * 0.01 - batcher = HotSwappingAutoBatcher( + batcher = InFlightAutoBatcher( model=lj_model, memory_scales_with="n_atoms", # max_metric=400_000, @@ -411,7 +411,7 @@ def convergence_fn(state: SimState) -> bool: assert len(all_completed_states) == len(fire_states) -def test_chunking_auto_batcher_with_fire( +def test_binning_auto_batcher_with_fire( si_sim_state: SimState, fe_supercell_sim_state: SimState, lj_model: LennardJonesModel ) -> None: fire_init, fire_update = unit_cell_fire(lj_model) @@ -428,7 +428,7 @@ def test_chunking_auto_batcher_with_fire( optimal_batches = to_constant_volume_bins(batch_lengths, 400) optimal_n_batches = len(optimal_batches) - batcher = ChunkingAutoBatcher( + batcher = BinningAutoBatcher( model=lj_model, memory_scales_with="n_atoms", max_memory_scaler=400 ) batcher.load_states(fire_states) @@ -450,18 +450,18 @@ def test_chunking_auto_batcher_with_fire( assert n_batches == optimal_n_batches -def test_hot_swapping_max_iterations( +def test_in_flight_max_iterations( si_sim_state: SimState, fe_supercell_sim_state: SimState, lj_model: LennardJonesModel, ) -> None: - """Test HotSwappingAutoBatcher with max_iterations limit.""" + """Test InFlightAutoBatcher with max_iterations limit.""" # Create states that won't naturally converge states = [si_sim_state.clone(), fe_supercell_sim_state.clone()] # Set max_attempts to a small value to ensure quick termination max_attempts = 3 - batcher = HotSwappingAutoBatcher( + batcher = InFlightAutoBatcher( model=lj_model, memory_scales_with="n_atoms", max_memory_scaler=800.0, diff --git a/tests/test_runners.py b/tests/test_runners.py index 9521b807..ef6b74d0 100644 --- a/tests/test_runners.py +++ b/tests/test_runners.py @@ -4,7 +4,7 @@ import numpy as np import torch -from torch_sim.autobatching import ChunkingAutoBatcher, HotSwappingAutoBatcher +from torch_sim.autobatching import BinningAutoBatcher, InFlightAutoBatcher from torch_sim.integrators import nve, nvt_langevin from torch_sim.models.lennard_jones import LennardJonesModel from torch_sim.optimizers import unit_cell_fire @@ -189,7 +189,7 @@ def test_integrate_with_autobatcher( lj_model.device, lj_model.dtype, ) - autobatcher = ChunkingAutoBatcher( + autobatcher = BinningAutoBatcher( model=lj_model, memory_scales_with="n_atoms", max_memory_scaler=260, @@ -224,7 +224,7 @@ def test_integrate_with_autobatcher_and_reporting( lj_model.device, lj_model.dtype, 
) - autobatcher = ChunkingAutoBatcher( + autobatcher = BinningAutoBatcher( model=lj_model, memory_scales_with="n_atoms", max_memory_scaler=260, @@ -382,7 +382,7 @@ def test_optimize_with_autobatcher( lj_model.device, lj_model.dtype, ) - autobatcher = HotSwappingAutoBatcher( + autobatcher = InFlightAutoBatcher( model=lj_model, memory_scales_with="n_atoms", max_memory_scaler=260, @@ -417,7 +417,7 @@ def test_optimize_with_autobatcher_and_reporting( ) triple_state.positions += torch.randn_like(triple_state.positions) * 0.1 - autobatcher = HotSwappingAutoBatcher( + autobatcher = InFlightAutoBatcher( model=lj_model, memory_scales_with="n_atoms", max_memory_scaler=260, @@ -626,7 +626,7 @@ def test_static_with_autobatcher( lj_model.device, lj_model.dtype, ) - autobatcher = ChunkingAutoBatcher( + autobatcher = BinningAutoBatcher( model=lj_model, memory_scales_with="n_atoms", max_memory_scaler=260, @@ -660,7 +660,7 @@ def test_static_with_autobatcher_and_reporting( lj_model.device, lj_model.dtype, ) - autobatcher = ChunkingAutoBatcher( + autobatcher = BinningAutoBatcher( model=lj_model, memory_scales_with="n_atoms", max_memory_scaler=260, diff --git a/torch_sim/__init__.py b/torch_sim/__init__.py index c784b259..0920e5c0 100644 --- a/torch_sim/__init__.py +++ b/torch_sim/__init__.py @@ -22,7 +22,7 @@ transforms, units, ) -from torch_sim.autobatching import ChunkingAutoBatcher, HotSwappingAutoBatcher +from torch_sim.autobatching import BinningAutoBatcher, InFlightAutoBatcher from torch_sim.integrators import npt_langevin, nve, nvt_langevin # state propagators diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index 799b884c..0db427a3 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -3,12 +3,12 @@ This module provides utilities for efficient batch processing of simulation states by dynamically determining optimal batch sizes based on GPU memory constraints. It includes tools for memory usage estimation, batch size determination, and -two complementary strategies for batching: chunking and hot-swapping. +two complementary strategies for batching: binning and in-flight swapping. Example: - Using ChunkingAutoBatcher with a model:: - - batcher = ChunkingAutoBatcher(model, memory_scales_with="n_atoms") + Using BinningAutoBatcher with a model:: + + batcher = BinningAutoBatcher(model, memory_scales_with="n_atoms") batcher.load_states(states) final_states = [] for batch in batcher: @@ -412,7 +412,7 @@ def estimate_max_memory_scaler( return min(min_state_max_batches * min_metric, max_state_max_batches * max_metric) -class ChunkingAutoBatcher: +class BinningAutoBatcher: """Batcher that groups states into bins of similar computational cost. Divides a collection of states into batches that can be processed efficiently @@ -440,7 +440,7 @@ class BinningAutoBatcher: Example:: # Create a batcher with a Lennard-Jones model - batcher = ChunkingAutoBatcher( + batcher = BinningAutoBatcher( model=lj_model, memory_scales_with="n_atoms", max_memory_scaler=1000.0 ) @@ -464,7 +464,7 @@ def __init__( max_atoms_to_try: int = 500_000, memory_scaling_factor: float = 1.6, ) -> None: - """Initialize the chunking auto-batcher. + """Initialize the binning auto-batcher.
Args: model (ModelInterface): Model to batch for, used to estimate memory @@ -694,7 +694,7 @@ def restore_original_order(self, batched_states: list[SimState]) -> list[SimStat return [state for _, state in sorted(indexed_states, key=lambda x: x[0])] -class HotSwappingAutoBatcher: +class InFlightAutoBatcher: """Batcher that dynamically swaps states based on convergence. Optimizes GPU utilization by removing converged states from the batch and @@ -725,7 +725,7 @@ class HotSwappingAutoBatcher: Example:: # Create a hot-swapping batcher - batcher = HotSwappingAutoBatcher( + batcher = InFlightAutoBatcher( model=lj_model, memory_scales_with="n_atoms", max_memory_scaler=1000.0 ) @@ -799,7 +799,7 @@ def load_states( Processes the input states, computes memory scaling metrics for each, and prepares them for dynamic batching based on convergence criteria. - Unlike ChunkingAutoBatcher, this doesn't create fixed batches upfront. + Unlike BinningAutoBatcher, this doesn't create fixed batches upfront. Args: states (list[SimState] | Iterator[SimState] | SimState): Collection of @@ -957,7 +957,7 @@ def next_batch( Removes converged states from the batch, adds new states if possible, and returns both the updated batch and the completed states. This method - implements the core dynamic batching strategy of the HotSwappingAutoBatcher. + implements the core dynamic batching strategy of the InFlightAutoBatcher. Args: updated_state (SimState | None): Current state after processing, or None diff --git a/torch_sim/runners.py b/torch_sim/runners.py index 1ff09235..75f4bce8 100644 --- a/torch_sim/runners.py +++ b/torch_sim/runners.py @@ -12,7 +12,7 @@ import torch -from torch_sim.autobatching import ChunkingAutoBatcher, HotSwappingAutoBatcher +from torch_sim.autobatching import BinningAutoBatcher, InFlightAutoBatcher from torch_sim.models.interface import ModelInterface from torch_sim.quantities import batchwise_max_force, calc_kinetic_energy, calc_kT from torch_sim.state import SimState, StateLike, concatenate_states, initialize_state @@ -60,27 +60,27 @@ def _configure_reporter( def _configure_batches_iterator( model: ModelInterface, state: SimState, - autobatcher: ChunkingAutoBatcher | bool, -) -> ChunkingAutoBatcher: + autobatcher: BinningAutoBatcher | bool, +) -> BinningAutoBatcher | list[tuple[SimState, list[int]]]: """Create a batches iterator for the integrate function. Args: model (ModelInterface): The model to use for the integration state (SimState): The state to use for the integration - autobatcher (ChunkingAutoBatcher | bool): The autobatcher to use for integration + autobatcher (BinningAutoBatcher | bool): The autobatcher to use for integration Returns: A batches iterator """ # load and properly configure the autobatcher if autobatcher is True: - autobatcher = ChunkingAutoBatcher( + autobatcher = BinningAutoBatcher( model=model, return_indices=True, ) autobatcher.load_states(state) batches = autobatcher - elif isinstance(autobatcher, ChunkingAutoBatcher): + elif isinstance(autobatcher, BinningAutoBatcher): autobatcher.load_states(state) autobatcher.return_indices = True batches = autobatcher @@ -89,7 +89,7 @@ def _configure_batches_iterator( else: raise ValueError( f"Invalid autobatcher type: {type(autobatcher).__name__}, " - "must be bool or ChunkingAutoBatcher." + "must be bool or BinningAutoBatcher." 
) return batches @@ -103,7 +103,7 @@ def integrate( temperature: float | list | torch.Tensor, timestep: float, trajectory_reporter: TrajectoryReporter | dict | None = None, - autobatcher: ChunkingAutoBatcher | bool = False, + autobatcher: BinningAutoBatcher | bool = False, **integrator_kwargs: dict, ) -> SimState: """Simulate a system using a model and integrator. @@ -120,7 +120,7 @@ def integrate( trajectory_reporter (TrajectoryReporter | dict | None): Optional reporter for tracking trajectory. If a dict, will be passed to the TrajectoryReporter constructor. - autobatcher (ChunkingAutoBatcher | bool): Optional autobatcher to use + autobatcher (BinningAutoBatcher | bool): Optional autobatcher to use **integrator_kwargs: Additional keyword arguments for integrator init function Returns: @@ -144,8 +144,8 @@ def integrate( dt=torch.tensor(timestep * unit_system.time, dtype=dtype, device=device), **integrator_kwargs, ) - # state = init_fn(state) + # batch_iterator will be a list if autobatcher is False batch_iterator = _configure_batches_iterator(model, state, autobatcher) trajectory_reporter = _configure_reporter( trajectory_reporter, @@ -177,25 +177,25 @@ def integrate( if trajectory_reporter: trajectory_reporter.finish() - if isinstance(batch_iterator, ChunkingAutoBatcher): + if isinstance(batch_iterator, BinningAutoBatcher): reordered_states = batch_iterator.restore_original_order(final_states) return concatenate_states(reordered_states) return state -def _configure_hot_swapping_autobatcher( +def _configure_in_flight_autobatcher( model: ModelInterface, state: SimState, - autobatcher: HotSwappingAutoBatcher | bool, + autobatcher: InFlightAutoBatcher | bool, max_attempts: int, # TODO: change name to max_iterations -) -> HotSwappingAutoBatcher: +) -> InFlightAutoBatcher: """Configure the hot swapping autobatcher for the optimize function. Args: model (ModelInterface): The model to use for the autobatcher state (SimState): The state to use for the autobatcher - autobatcher (HotSwappingAutoBatcher | bool): The autobatcher to use for the + autobatcher (InFlightAutoBatcher | bool): The autobatcher to use for the autobatcher max_attempts (int): The maximum number of attempts for the autobatcher @@ -203,7 +203,7 @@ def _configure_hot_swapping_autobatcher( A hot swapping autobatcher """ # load and properly configure the autobatcher - if isinstance(autobatcher, HotSwappingAutoBatcher): + if isinstance(autobatcher, InFlightAutoBatcher): autobatcher.return_indices = True autobatcher.max_attempts = max_attempts else: @@ -213,7 +213,7 @@ def _configure_hot_swapping_autobatcher( else: memory_scales_with = "n_atoms" max_memory_scaler = state.n_atoms + 1 - autobatcher = HotSwappingAutoBatcher( + autobatcher = InFlightAutoBatcher( model=model, return_indices=True, max_memory_scaler=max_memory_scaler, @@ -243,7 +243,7 @@ def _chunked_apply( Returns: A state with the function applied """ - autobatcher = ChunkingAutoBatcher( + autobatcher = BinningAutoBatcher( model=model, return_indices=False, **batcher_kwargs, @@ -308,7 +308,7 @@ def optimize( optimizer: Callable, convergence_fn: Callable | None = None, trajectory_reporter: TrajectoryReporter | dict | None = None, - autobatcher: HotSwappingAutoBatcher | bool = False, + autobatcher: InFlightAutoBatcher | bool = False, max_steps: int = 10_000, steps_between_swaps: int = 5, **optimizer_kwargs: dict, @@ -326,11 +326,11 @@ def optimize( trajectory_reporter (TrajectoryReporter | dict | None): Optional reporter for tracking optimization trajectory. 
If a dict, will be passed to the TrajectoryReporter constructor. - autobatcher (HotSwappingAutoBatcher | bool): Optional autobatcher to use. If + autobatcher (InFlightAutoBatcher | bool): Optional autobatcher to use. If False, the system will assume infinite memory and will not batch, but will still remove converged structures from the batch. If True, the system will estimate the memory - available and batch accordingly. If a HotSwappingAutoBatcher, the system + available and batch accordingly. If an InFlightAutoBatcher, the system will use the provided autobatcher, but will reset the max_attempts to max_steps // steps_between_swaps. max_steps (int): Maximum number of total optimization steps @@ -350,7 +350,7 @@ init_fn, update_fn = optimizer(model=model, **optimizer_kwargs) max_attempts = max_steps // steps_between_swaps - autobatcher = _configure_hot_swapping_autobatcher( + autobatcher = _configure_in_flight_autobatcher( model, state, autobatcher, max_attempts ) state = _chunked_apply( @@ -412,7 +412,7 @@ def static( model: ModelInterface, *, trajectory_reporter: TrajectoryReporter | dict | None = None, - autobatcher: ChunkingAutoBatcher | bool = False, + autobatcher: BinningAutoBatcher | bool = False, ) -> list[dict[str, torch.Tensor]]: """Run single point calculations on a batch of systems. @@ -433,7 +433,7 @@ def static( Make sure that if multiple unique states are used, that the `variable_atomic_numbers` and `variable_masses` are set to `True` in the `state_kwargs` argument. - autobatcher (ChunkingAutoBatcher | bool): Optional autobatcher to use for + autobatcher (BinningAutoBatcher | bool): Optional autobatcher to use for batching calculations Returns: @@ -491,7 +491,7 @@ class StaticState(type(state)): trajectory_reporter.finish() - if isinstance(batch_iterator, ChunkingAutoBatcher): + if isinstance(batch_iterator, BinningAutoBatcher): # reorder properties to match original order of states original_indices = list(chain.from_iterable(batch_iterator.index_bins)) return [all_props[idx] for idx in original_indices]
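After these renames, the two public entry points are `ts.BinningAutoBatcher` and `ts.InFlightAutoBatcher`, both re-exported from `torch_sim/__init__.py`. The sketch below shows the fixed-bin strategy, following the `torch_sim.autobatching` module docstring example above; here `model`, `states`, and `evolve_batch` are placeholders for a `ModelInterface`, a list of `SimState` objects, and whatever per-batch update is being run:

```python
import torch_sim as ts

# Minimal BinningAutoBatcher sketch, following the torch_sim.autobatching module
# docstring. `model`, `states`, and `evolve_batch` are placeholders, not library names.
batcher = ts.BinningAutoBatcher(model=model, memory_scales_with="n_atoms")
batcher.load_states(states)

final_states = []
for batch in batcher:  # each batch is a SimState binned under the memory limit
    batch = evolve_batch(batch)  # every state receives the same processing
    final_states.append(batch)

# Bin packing reorders states across batches; restore_original_order maps
# results back to the order in which the states were loaded.
ordered_states = batcher.restore_original_order(final_states)
```

Because nothing is swapped mid-run, `BinningAutoBatcher` suits workflows with fixed iteration counts, while `InFlightAutoBatcher` (see the loop sketched in the tutorial hunk above) suits optimizations where systems converge at different rates.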