/
test_restart.py
93 lines (81 loc) · 2.83 KB
/
test_restart.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import os

import chaospy as cp
import easyvvuq as uq
import pytest

from gauss.decoder_gauss import GaussDecoder
@pytest.fixture
def restart(tmpdir):
    """Build a small 'gauss' campaign, save its state, and return it reloaded.

    The campaign registers the ``gauss`` app (generic encoder, Gauss decoder,
    non-averaging aggregate collater), varies ``mu`` uniformly, draws
    2 samples x 2 replicas, populates the run directories and collates.
    Its state is then written to ``gauss_state.json`` inside ``tmpdir``.
    A new ``Campaign`` is constructed from that state file and returned.

    The original parameter definitions are attached to the returned
    campaign as ``params_`` so tests can rebuild objects (e.g. the decoder)
    from the same defaults.
    """
    my_campaign = uq.Campaign(name='gauss', work_dir=tmpdir, db_type='sql')
    params = {
        "sigma": {
            "type": "float",
            "min": 0.0,
            "max": 100000.0,
            "default": 0.25
        },
        "mu": {
            "type": "float",
            "min": 0.0,
            "max": 100000.0,
            "default": 1
        },
        "num_steps": {
            "type": "integer",
            "min": 0,
            "max": 100000,
            "default": 10
        },
        "out_file": {
            "type": "string",
            "default": "output.csv"
        },
        "bias": {
            "type": "fixture",
            "allowed": ["bias1", "bias2"],
            "default": "bias1"
        }
    }
    encoder = uq.encoders.GenericEncoder(template_fname='tests/gauss/gauss.template',
                                         target_filename='gauss_in.json')
    decoder = GaussDecoder(target_filename=params['out_file']['default'])
    collater = uq.collate.AggregateSamples(average=False)
    my_campaign.add_app(name='gauss',
                        params=params,
                        encoder=encoder,
                        decoder=decoder,
                        collater=collater)
    my_campaign.set_app('gauss')
    vary = {
        "mu": cp.Uniform(1.0, 100.0),
    }
    sampler = uq.sampling.RandomSampler(vary=vary)
    my_campaign.set_sampler(sampler)
    my_campaign.draw_samples(num_samples=2, replicas=2)
    my_campaign.populate_runs_dir()
    my_campaign.collate()

    # Join with an explicit separator so the state file is created inside
    # the per-test temporary directory; plain `+` concatenation onto the
    # tmpdir path inserts no separator.
    state_file = os.path.join(str(tmpdir), "{}_state.json".format('gauss'))
    my_campaign.save_state(state_file)

    # Drop the original campaign so the test exercises only the state
    # reloaded from disk.
    my_campaign = None
    reloaded_campaign = uq.Campaign(state_file=state_file, work_dir=tmpdir)
    reloaded_campaign.set_app('gauss')

    # Expose the parameter definitions to the tests (used by test_decoder).
    reloaded_campaign.params_ = params
    return reloaded_campaign
def test_restart(restart):
    """The campaign reloaded from the state file exposes a campaign database."""
    reloaded_db = restart.campaign_db
    assert reloaded_db is not None
def test_app(restart):
    """The 'gauss' app and its parameter limits survive the save/reload cycle."""
    campaign_db = restart.campaign_db
    assert campaign_db.app('gauss')['name'] == 'gauss'
    mu_spec = campaign_db.app('gauss')['params'].params_dict['mu']
    assert mu_spec['max'] == 100000.0
def test_runs(restart):
    """Run count is preserved on reload and grows when more samples are drawn.

    The fixture drew 2 samples x 2 replicas before saving; drawing the same
    amount again on the reloaded campaign is expected to double the count.
    """
    campaign_db = restart.campaign_db
    runs_before = campaign_db.get_num_runs()
    assert runs_before == 4

    restart.draw_samples(num_samples=2, replicas=2)
    runs_after = campaign_db.get_num_runs()
    assert runs_after == 8
def test_encoder(restart):
    """The stored input encoder equals a freshly built, identically configured one."""
    stored_app = restart.campaign_db.app('gauss')
    expected_encoder = uq.encoders.GenericEncoder(
        template_fname='tests/gauss/gauss.template',
        target_filename='gauss_in.json',
    )
    assert stored_app['input_encoder'] == expected_encoder.serialize()
def test_decoder(restart):
    """The stored output decoder equals one rebuilt from the fixture's params."""
    stored_app = restart.campaign_db.app('gauss')
    output_name = restart.params_['out_file']['default']
    expected_decoder = GaussDecoder(target_filename=output_name)
    assert stored_app['output_decoder'] == expected_decoder.serialize()