-
Notifications
You must be signed in to change notification settings - Fork 107
/
test_restart.py
143 lines (107 loc) · 4.05 KB
/
test_restart.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
__doc__ = """Test restart functionality """
import pytest
import numpy as np
from numpy.testing import assert_allclose
from elastica.utils import Tolerance
from elastica.modules import (
BaseSystemCollection,
Constraints,
Forcing,
Connections,
CallBacks,
)
from elastica.restart import save_state, load_state
class GenericSimulatorClass(
    BaseSystemCollection, Constraints, Forcing, Connections, CallBacks
):
    """System collection combining the Constraints, Forcing, Connections and
    CallBacks mixins; it adds no behavior of its own and serves only as the
    simulator object passed to ``save_state``/``load_state`` in these tests."""
class TestRestartFunctionsWithFeaturesUsingCosseratRod:
    """Round-trip ``save_state``/``load_state`` on a collection of Cosserat rods."""

    @pytest.fixture(scope="function")
    def load_collection(self):
        """Build a simulator holding five identical straight Cosserat rods.

        Returns
        -------
        tuple
            ``(simulator, rod_list)`` where ``rod_list`` holds the very rod
            objects appended to the simulator, in insertion order.
        """
        sc = GenericSimulatorClass()
        from elastica.rod.cosserat_rod import CosseratRod

        rod_list = []
        for _ in range(5):
            rod = CosseratRod.straight_rod(
                n_elements=10,
                start=np.zeros((3)),
                direction=np.array([0, 1, 0.0]),
                normal=np.array([1, 0, 0.0]),
                base_length=1,
                base_radius=1,
                density=1,
                nu=1,
                youngs_modulus=1,
            )
            # Append to the private list directly, bypassing the public
            # append check; acceptable for testing.
            sc._systems.append(rod)
            # Keep our own handle on each rod, in insertion order.
            rod_list.append(rod)
        return sc, rod_list

    def test_restart_save_load(self, load_collection, tmp_path):
        """State written by ``save_state`` must be restored by ``load_state``.

        ``simulator_class[idx]`` is the same object stored in ``rod_list``, so
        comparing those two directly can never fail. Instead, snapshot every
        attribute after saving, perturb the floating-point arrays in place so
        a no-op load would be detected, then load and compare against the
        snapshot.
        """
        simulator_class, rod_list = load_collection
        # Finalize simulator before snapshotting: attribute arrays may be
        # rebound during finalization.
        simulator_class.finalize()

        # Use pytest's per-test temporary directory rather than a fixed path
        # under the current working directory.
        directory = str(tmp_path / "restart_test_data")
        time = np.random.rand()

        # save state
        save_state(simulator_class, directory, time=time)

        # Independent copies of the state that was just saved.
        snapshots = [
            {key: np.copy(value) for key, value in rod.__dict__.items()}
            for rod in rod_list
        ]

        # Corrupt the in-memory float arrays so that only a real load can make
        # the comparison below pass.
        for rod in rod_list:
            for value in rod.__dict__.values():
                if (
                    isinstance(value, np.ndarray)
                    and value.ndim > 0
                    and value.dtype.kind == "f"
                ):
                    value += 1.0

        # load state
        restart_time = load_state(simulator_class, directory)

        # check if restart time loaded correctly
        assert_allclose(restart_time, time, atol=Tolerance.atol())

        # check that every rod attribute is back to its saved value
        for idx, expected_state in enumerate(snapshots):
            test_rod = simulator_class[idx]
            for key, correct_value in expected_state.items():
                assert_allclose(getattr(test_rod, key), correct_value)
class TestRestartFunctionsWithFeaturesUsingRigidBodies:
    """Round-trip ``save_state``/``load_state`` on a collection of cylinders."""

    @pytest.fixture(scope="function")
    def load_collection(self):
        """Build a simulator holding five identical rigid-body cylinders.

        Returns
        -------
        tuple
            ``(simulator, cylinder_list)`` where ``cylinder_list`` holds the
            very cylinder objects appended to the simulator, in insertion
            order.
        """
        sc = GenericSimulatorClass()
        from elastica.rigidbody import Cylinder

        cylinder_list = []
        for _ in range(5):
            cylinder = Cylinder(
                start=np.zeros((3)),
                direction=np.array([0, 1, 0.0]),
                normal=np.array([1, 0, 0.0]),
                base_length=1,
                base_radius=1,
                density=1,
            )
            # Append to the private list directly, bypassing the public
            # append check; acceptable for testing.
            sc._systems.append(cylinder)
            # Keep our own handle on each cylinder, in insertion order.
            cylinder_list.append(cylinder)
        return sc, cylinder_list

    def test_restart_save_load(self, load_collection, tmp_path):
        """State written by ``save_state`` must be restored by ``load_state``.

        ``simulator_class[idx]`` is the same object stored in
        ``cylinder_list``, so comparing those two directly can never fail.
        Instead, snapshot every attribute after saving, perturb the
        floating-point arrays in place so a no-op load would be detected,
        then load and compare against the snapshot.
        """
        simulator_class, cylinder_list = load_collection
        # Finalize simulator before snapshotting: attribute arrays may be
        # rebound during finalization.
        simulator_class.finalize()

        # Use pytest's per-test temporary directory rather than a fixed path
        # under the current working directory.
        directory = str(tmp_path / "restart_test_data")
        time = np.random.rand()

        # save state
        save_state(simulator_class, directory, time=time)

        # Independent copies of the state that was just saved.
        snapshots = [
            {key: np.copy(value) for key, value in cylinder.__dict__.items()}
            for cylinder in cylinder_list
        ]

        # Corrupt the in-memory float arrays so that only a real load can make
        # the comparison below pass.
        for cylinder in cylinder_list:
            for value in cylinder.__dict__.values():
                if (
                    isinstance(value, np.ndarray)
                    and value.ndim > 0
                    and value.dtype.kind == "f"
                ):
                    value += 1.0

        # load state
        restart_time = load_state(simulator_class, directory)

        # check if restart time loaded correctly
        assert_allclose(restart_time, time, atol=Tolerance.atol())

        # check that every cylinder attribute is back to its saved value
        for idx, expected_state in enumerate(snapshots):
            test_cylinder = simulator_class[idx]
            for key, correct_value in expected_state.items():
                assert_allclose(getattr(test_cylinder, key), correct_value)
if __name__ == "__main__":
    # Running this file directly hands it to pytest, using the module-level
    # pytest import.
    pytest.main([__file__])