Skip to content

Commit

Permalink
Merge pull request #249 from armantekinalp/239_fix_restart_functionality
Browse files Browse the repository at this point in the history
239 fix restart functionality
  • Loading branch information
bhosale2 committed May 11, 2023
2 parents b271ad4 + 2b92732 commit 3521ce1
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 4 deletions.
1 change: 1 addition & 0 deletions elastica/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,4 @@
)
from elastica.memory_block.memory_block_rigid_body import MemoryBlockRigidBody
from elastica.memory_block.memory_block_rod import MemoryBlockCosseratRod
from elastica.restart import save_state, load_state
13 changes: 9 additions & 4 deletions elastica/restart.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import os
from itertools import groupby
from .memory_block import MemoryBlockCosseratRod, MemoryBlockRigidBody


def all_equal(iterable):
Expand Down Expand Up @@ -41,6 +42,10 @@ def save_state(simulator, directory: str = "", time=0.0, verbose: bool = False):
"""
os.makedirs(directory, exist_ok=True)
for idx, rod in enumerate(simulator):
if isinstance(rod, MemoryBlockCosseratRod) or isinstance(
rod, MemoryBlockRigidBody
):
continue
path = os.path.join(directory, "system_{}.npz".format(idx))
np.savez(path, time=time, **rod.__dict__)

Expand Down Expand Up @@ -69,6 +74,10 @@ def load_state(simulator, directory: str = "", verbose: bool = False):
"""
time_list = [] # Simulation time of rods when they are saved.
for idx, rod in enumerate(simulator):
if isinstance(rod, MemoryBlockCosseratRod) or isinstance(
rod, MemoryBlockRigidBody
):
continue
path = os.path.join(directory, "system_{}.npz".format(idx))
data = np.load(path, allow_pickle=True)
for key, value in data.items():
Expand All @@ -88,10 +97,6 @@ def load_state(simulator, directory: str = "", verbose: bool = False):
"Restart time of loaded rods are different, check your inputs!"
)

# Apply boundary conditions, after loading the systems.
simulator.constrain_values(0.0)
simulator.constrain_rates(0.0)

if verbose:
print("Load complete: {}".format(directory))

Expand Down
93 changes: 93 additions & 0 deletions tests/test_restart.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
CallBacks,
)
from elastica.restart import save_state, load_state
import elastica as ea


class GenericSimulatorClass(
Expand Down Expand Up @@ -78,6 +79,98 @@ def test_restart_save_load(self, load_collection):

assert_allclose(test_value, correct_value)

def run_sim(self, final_time, load_from_restart, save_data_restart):
class BaseSimulatorClass(
BaseSystemCollection, Constraints, Forcing, Connections, CallBacks
):
pass

simulator_class = BaseSimulatorClass()

rod_list = []
for _ in range(5):
rod = ea.CosseratRod.straight_rod(
n_elements=10,
start=np.zeros((3)),
direction=np.array([0, 1, 0.0]),
normal=np.array([1, 0, 0.0]),
base_length=1,
base_radius=1,
density=1,
youngs_modulus=1,
)
# Bypass check, but its fine for testing
simulator_class._systems.append(rod)

# Also add rods to a separate list
rod_list.append(rod)

for rod in rod_list:
simulator_class.add_forcing_to(rod).using(
ea.EndpointForces,
start_force=np.zeros(
3,
),
end_force=np.array([0, 0.1, 0]),
ramp_up_time=0.1,
)

# Finalize simulator
simulator_class.finalize()

directory = "restart_test_data/"

time_step = 1e-4
total_steps = int(final_time / time_step)

if load_from_restart:
restart_time = ea.load_state(simulator_class, directory, True)

else:
restart_time = np.float64(0.0)

timestepper = ea.PositionVerlet()
time = ea.integrate(
timestepper,
simulator_class,
final_time,
total_steps,
restart_time=restart_time,
)

if save_data_restart:
ea.save_state(simulator_class, directory, time, True)

# Compute final time accelerations
recorded_list = np.zeros((len(rod_list), rod_list[0].n_elems + 1))
for i, rod in enumerate(rod_list):
recorded_list[i, :] = rod.acceleration_collection[1, :]

return recorded_list

@pytest.mark.parametrize("final_time", [0.2, 1.0])
def test_save_restart_run_sim(self, final_time):

# First half of simulation
_ = self.run_sim(
final_time / 2, load_from_restart=False, save_data_restart=True
)

# Second half of simulation
recorded_list = self.run_sim(
final_time / 2, load_from_restart=True, save_data_restart=False
)
recorded_list_second_half = recorded_list.copy()

# Full simulation
recorded_list = self.run_sim(
final_time, load_from_restart=False, save_data_restart=False
)
recorded_list_full_sim = recorded_list.copy()

# Compare final accelerations of rods
assert_allclose(recorded_list_second_half, recorded_list_full_sim)


class TestRestartFunctionsWithFeaturesUsingRigidBodies:
@pytest.fixture(scope="function")
Expand Down

0 comments on commit 3521ce1

Please sign in to comment.