# test_particle_collections.py
# Tests for particle-set collections: forward/backward iteration, `get`,
# `get_single_by_index`, `get_single_by_ID` and attribute access, each
# parametrized over the 'soa' and 'aos' particle-set modes.
import numpy as np
import pytest
from parcels import ( # noqa
FieldSet,
JITParticle,
KernelAOS,
KernelSOA,
ParticleFileAOS,
ParticleFileSOA,
ParticleSetAOS,
ParticleSetSOA,
ScipyParticle,
)
# Particle-set modes every parametrized test below runs under
# (presumably structure-of-arrays / array-of-structures — confirm).
pset_modes = ['soa', 'aos']

# Particle classes keyed by their short name.
ptype = dict(scipy=ScipyParticle, jit=JITParticle)

# Per-mode implementation classes: particle set, particle file and kernel.
pset_type = {
    'soa': dict(pset=ParticleSetSOA, pfile=ParticleFileSOA, kernel=KernelSOA),
    'aos': dict(pset=ParticleSetAOS, pfile=ParticleFileAOS, kernel=KernelAOS),
}
def fieldset(xdim=40, ydim=100):
    """Return a FieldSet with all-zero ``U`` and ``V`` on a regular grid.

    The grid has ``xdim`` longitudes evenly spaced over [0, 1], ``ydim``
    latitudes evenly spaced over [-60, 60] and a single depth level at 0.
    Every coordinate and data array is float32. ``U`` and ``V`` have shape
    ``(ydim, xdim)``.
    """
    data = {name: np.zeros((ydim, xdim), dtype=np.float32) for name in ('U', 'V')}
    dimensions = {
        'lat': np.linspace(-60, 60, ydim, dtype=np.float32),
        'lon': np.linspace(0, 1, xdim, dtype=np.float32),
        'depth': np.zeros(1, dtype=np.float32),
    }
    return FieldSet.from_data(data, dimensions)
@pytest.fixture(name="fieldset")
def fieldset_fixture(xdim=40, ydim=100):
    """Pytest fixture, requested as ``fieldset``, that wraps :func:`fieldset`."""
    fset = fieldset(xdim=xdim, ydim=ydim)
    return fset
@pytest.mark.parametrize('pset_mode', pset_modes)
def test_pset_iteration_forward(fieldset, pset_mode, npart=10):
    """Iterating a particle set yields its particles in creation order.

    The expected IDs are the contiguous run starting at the first particle's
    ID, so the test does not assume IDs start at 0.
    """
    pset = pset_type[pset_mode]['pset'](fieldset, lon=np.linspace(0, 1, npart), lat=np.zeros(npart), pclass=JITParticle)
    # np.arange, not range: ``range + int`` raises TypeError. The old form
    # only worked when ``id`` was a NumPy scalar that coerced the range.
    expected_ids = pset[0].id + np.arange(npart)
    # IDs are integers, so compare them exactly. np.isclose's relative
    # tolerance (rtol=1e-5) lets neighbouring IDs match once IDs are large.
    assert np.array_equal([p.id for p in pset], expected_ids)
@pytest.mark.parametrize('pset_mode', pset_modes)
def test_pset_iteration_backward(fieldset, pset_mode, npart=10):
    """``reversed(pset)`` yields the particles in reverse creation order.

    The expected IDs are the contiguous run starting at the first particle's
    ID, reversed, so the test does not assume IDs start at 0.
    """
    pset = pset_type[pset_mode]['pset'](fieldset, lon=np.linspace(0, 1, npart), lat=np.zeros(npart), pclass=JITParticle)
    expected_ids = pset[0].id + np.arange(npart)[::-1]
    # IDs are integers, so compare them exactly. np.isclose's relative
    # tolerance (rtol=1e-5) lets neighbouring IDs match once IDs are large.
    assert np.array_equal([p.id for p in reversed(pset)], expected_ids)
@pytest.mark.parametrize('pset_mode', pset_modes)
def test_pset_get(fieldset, pset_mode, npart=10):
    """``collection.get(i)`` returns the i-th particle, checked via its lon."""
    pset_cls = pset_type[pset_mode]['pset']
    pset = pset_cls(fieldset, lon=np.linspace(0, 1, npart), lat=np.zeros(npart), pclass=JITParticle)
    collection = pset.collection
    fetched_lons = [collection.get(idx).lon for idx in range(npart)]
    expected_lons = np.linspace(0, 1, npart)
    assert np.all(np.isclose(fetched_lons, expected_lons))
@pytest.mark.parametrize('pset_mode', pset_modes)
def test_pset_get_single_by_index(fieldset, pset_mode, npart=10):
    """``collection.get_single_by_index(i)`` returns the i-th particle (by lon)."""
    pset_cls = pset_type[pset_mode]['pset']
    pset = pset_cls(fieldset, lon=np.linspace(0, 1, npart), lat=np.zeros(npart), pclass=JITParticle)
    collection = pset.collection
    fetched_lons = [collection.get_single_by_index(idx).lon for idx in range(npart)]
    expected_lons = np.linspace(0, 1, npart)
    assert np.all(np.isclose(fetched_lons, expected_lons))
@pytest.mark.parametrize('pset_mode', pset_modes)
def test_pset_get_single_by_ID(fieldset, pset_mode, npart=10):
    """``collection.get_single_by_ID(id)`` finds each particle by its ID.

    The IDs are read straight from the collection's internal ``_data``. In
    'soa' mode that is indexed by field name; in 'aos' mode it is indexed
    per particle. Each looked-up particle's lon must match creation order.
    """
    pset_cls = pset_type[pset_mode]['pset']
    pset = pset_cls(fieldset, lon=np.linspace(0, 1, npart), lat=np.zeros(npart), pclass=JITParticle)
    collection = pset.collection
    ids = None
    if pset_mode == 'soa':
        ids = collection._data['id']
    elif pset_mode == 'aos':
        n_particles = len(pset)
        ids = np.array([collection._data[idx].id for idx in range(n_particles)], dtype=np.int64)
    found_lons = [collection.get_single_by_ID(np.int64(pid)).lon for pid in ids]
    assert np.all(np.isclose(found_lons, np.linspace(0, 1, npart)))
@pytest.mark.parametrize('pset_mode', pset_modes)
def test_pset_getattr(fieldset, pset_mode, npart=10):
    """Attribute access on the whole set (``pset.lat``) returns every particle's lat."""
    lats = np.random.random(npart)
    pset_cls = pset_type[pset_mode]['pset']
    pset = pset_cls(fieldset, lon=np.linspace(0, 1, npart), lat=lats, pclass=JITParticle)
    observed = pset.lat
    assert np.allclose(observed, lats)