/
sqlite_reader.py
1531 lines (1303 loc) · 55.7 KB
/
sqlite_reader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
Definition of the SqliteCaseReader.
"""
import sqlite3
from collections import OrderedDict
from io import StringIO
import sys
import numpy as np
from openmdao.recorders.base_case_reader import BaseCaseReader
from openmdao.recorders.case import Case
from openmdao.core.constants import _DEFAULT_OUT_STREAM
from openmdao.utils.general_utils import simple_warning
from openmdao.utils.variable_table import write_source_table
from openmdao.utils.record_util import check_valid_sqlite3_db, get_source_system
from openmdao.recorders.sqlite_recorder import format_version
import pickle
from json import loads as json_loads
class SqliteCaseReader(BaseCaseReader):
    """
    A CaseReader specific to files created with SqliteRecorder.

    Attributes
    ----------
    problem_metadata : dict
        Metadata about the problem, including the system hierarchy and connections.
    solver_metadata : dict
        The solver options for each solver in the recorded model.
    source_cases_table : dict
        Case coordinates found by the most recent flat case listing, grouped by table name.
    _system_options : dict
        Metadata about each system in the recorded model, including options and scaling factors.
    _format_version : int
        The version of the format assumed when loading the file.
    _filename : str
        The path to the filename containing the recorded data.
    _abs2meta : dict
        Dictionary mapping variables to their metadata
    _abs2prom : {'input': dict, 'output': dict}
        Dictionary mapping absolute names to promoted names.
    _prom2abs : {'input': dict, 'output': dict}
        Dictionary mapping promoted names to absolute names.
    _conns : dict
        Dictionary of all model connections (only read for format version 11 and later).
    _auto_ivc_map : dict
        Dictionary that maps each auto_ivc source to the text used to display it: the promoted
        name of one of its connected inputs, or the absolute name of its first connected input
        when no promoted name is available. This is for output display.
    _driver_cases : DriverCases
        Helper object for accessing cases from the driver_iterations table.
    _system_cases : SystemCases
        Helper object for accessing cases from the system_iterations table.
    _solver_cases : SolverCases
        Helper object for accessing cases from the solver_iterations table.
    _problem_cases : ProblemCases or None
        Helper object for accessing cases from the problem_cases table. None for format
        versions older than 2, which do not record problem cases.
    _global_iterations : list
        List of iteration cases and the table and row in which they are found.
    """

    def __init__(self, filename, pre_load=False):
        """
        Initialize.

        Parameters
        ----------
        filename : str
            The path to the filename containing the recorded data.
        pre_load : bool
            If True, load all the data into memory during initialization.
        """
        super(SqliteCaseReader, self).__init__(filename, pre_load)

        check_valid_sqlite3_db(filename)

        # initialize private attributes
        self._filename = filename
        self._abs2prom = None
        self._prom2abs = None
        self._abs2meta = None
        self._conns = None
        self._auto_ivc_map = {}
        self._global_iterations = None
        self._problem_cases = None

        # collect metadata from database. Using a sqlite3 connection as a context manager only
        # manages the transaction; it never closes the connection, so close it explicitly,
        # even when reading the metadata fails.
        con = sqlite3.connect(filename)
        try:
            con.row_factory = sqlite3.Row
            cur = con.cursor()

            # collect data from the metadata table. this includes:
            #   format_version
            #   VOI metadata, which is added to problem_metadata
            #   var name maps and metadata for all vars, which are saved as private attributes
            self._collect_metadata(cur)

            # collect data from the driver_metadata table. this includes:
            #   model viewer data, which is added to problem_metadata
            self._collect_driver_metadata(cur)

            # collect data from the system_metadata table. this includes:
            #   component metadata and scaling factors for each system,
            #   which is added to _system_options
            self._collect_system_metadata(cur)

            # collect data from the solver_metadata table. this includes:
            #   solver class and options for each solver, which is saved as an attribute
            self._collect_solver_metadata(cur)

            # get the global iterations table, and save it as an attribute
            self._global_iterations = self._get_global_iterations(cur)
        finally:
            con.close()

        # create helper objects for accessing cases from the three iteration tables and
        # the problem cases table; they all share the same name maps and metadata
        var_info = self.problem_metadata['variables']
        table_args = (filename, self._format_version, self._global_iterations,
                      self._prom2abs, self._abs2prom, self._abs2meta,
                      self._conns, self._auto_ivc_map, var_info)

        self._driver_cases = DriverCases(*table_args)
        self._system_cases = SystemCases(*table_args)
        self._solver_cases = SolverCases(*table_args)
        if self._format_version >= 2:
            self._problem_cases = ProblemCases(*table_args)

        # if requested, load all the iteration data into memory
        if pre_load:
            self._load_cases()

    def _collect_metadata(self, cur):
        """
        Load data from the metadata table.

        Populates the `format_version` attribute and the `variables` data in
        the `problem_metadata` attribute of this CaseReader. Also saves the
        variable name maps and variable metadata to private attributes.

        Parameters
        ----------
        cur : sqlite3.Cursor
            Database cursor to use for reading the data.

        Raises
        ------
        ValueError
            If the metadata table is empty or records an unsupported format version.
        """
        cur.execute('select * from metadata')
        row = cur.fetchone()
        if row is None:
            raise ValueError('SQliteCaseReader found no rows in the metadata table '
                             'of {0}.'.format(self._filename))

        # get format_version
        self._format_version = version = row['format_version']

        if version not in range(1, format_version + 1):
            raise ValueError('SQliteCaseReader encountered an unhandled '
                             'format version: {0}'.format(self._format_version))

        if version >= 11:
            # Auto-IVC
            self._conns = json_loads(row['conns'])

        # add metadata for VOIs (des vars, objective, constraints) to problem metadata
        if version >= 4:
            self.problem_metadata['variables'] = json_loads(row['var_settings'])
        else:
            self.problem_metadata['variables'] = None

        # get variable name maps and metadata for all variables
        if version >= 3:
            self._abs2prom = json_loads(row['abs2prom'])
            self._prom2abs = json_loads(row['prom2abs'])
            self._abs2meta = json_loads(row['abs2meta'])

            # JSON stores bounds as plain lists/scalars; expand them to arrays of the
            # variable's shape
            for meta in self._abs2meta.values():
                for bound in ('lower', 'upper'):
                    if meta.get(bound) is not None:
                        meta[bound] = np.resize(np.array(meta[bound]), meta['shape'])

            # Map ivc_source names to input display text.
            if version >= 11:
                self._auto_ivc_map = self._build_auto_ivc_map(self._conns,
                                                              self._abs2prom['input'])

        elif version in (1, 2):
            abs2prom = row['abs2prom']
            prom2abs = row['prom2abs']
            abs2meta = row['abs2meta']

            # NOTE(review): unpickling executes arbitrary code; only open trusted recordings.
            try:
                self._abs2prom = pickle.loads(abs2prom)
                self._prom2abs = pickle.loads(prom2abs)
                self._abs2meta = pickle.loads(abs2meta)
            except TypeError:
                # Reading in a python 2 pickle recorded pre-OpenMDAO 2.4.
                self._abs2prom = pickle.loads(abs2prom.encode())
                self._prom2abs = pickle.loads(prom2abs.encode())
                self._abs2meta = pickle.loads(abs2meta.encode())

        self.problem_metadata['abs2prom'] = self._abs2prom

    @staticmethod
    def _build_auto_ivc_map(conns, abs2prom_in):
        """
        Map each auto_ivc source to the input name used to display it.

        Parameters
        ----------
        conns : dict
            Dictionary mapping absolute target names to absolute source names.
        abs2prom_in : dict
            Dictionary mapping absolute input names to promoted input names. If this
            recorder is on a component, it may only contain a subset of the inputs.

        Returns
        -------
        dict
            Maps each '_auto_ivc.' source to the promoted name of the first of its targets
            found in abs2prom_in, or to the absolute name of its first target if none is.
        """
        targets_by_src = {}
        for target, src in conns.items():
            if src.startswith('_auto_ivc.'):
                targets_by_src.setdefault(src, []).append(target)

        auto_ivc_map = {}
        for src, targets in targets_by_src.items():
            for target in targets:
                if target in abs2prom_in:
                    auto_ivc_map[src] = abs2prom_in[target]
                    break
            else:
                # no promoted name is known for any target; display the absolute name instead
                auto_ivc_map[src] = targets[0]

        return auto_ivc_map

    def _collect_driver_metadata(self, cur):
        """
        Load data from the driver_metadata table.

        Populates the `problem_metadata` attribute of this CaseReader.

        Parameters
        ----------
        cur : sqlite3.Cursor
            Database cursor to use for reading the data.
        """
        cur.execute("SELECT model_viewer_data FROM driver_metadata")
        row = cur.fetchone()
        if row is None:
            return

        if self._format_version >= 3:
            driver_metadata = json_loads(row[0])
        else:
            # format versions 1 and 2 pickled the model viewer data
            driver_metadata = pickle.loads(row[0])

        self.problem_metadata.update(driver_metadata)

    def _collect_system_metadata(self, cur):
        """
        Load data from the system table.

        Populates the `_system_options` attribute of this CaseReader.

        Parameters
        ----------
        cur : sqlite3.Cursor
            Database cursor to use for reading the data.
        """
        cur.execute("SELECT id, scaling_factors, component_metadata FROM system_metadata")
        # NOTE(review): unpickling executes arbitrary code; only open trusted recordings.
        for sys_id, scaling_factors, component_metadata in cur:
            self._system_options[sys_id] = {
                'scaling_factors': pickle.loads(scaling_factors),
                'component_options': pickle.loads(component_metadata),
            }

    def _collect_solver_metadata(self, cur):
        """
        Load data from the solver_metadata table.

        Populates the `solver_metadata` attribute of this CaseReader.

        Parameters
        ----------
        cur : sqlite3.Cursor
            Database cursor to use for reading the data.
        """
        cur.execute("SELECT id, solver_options, solver_class FROM solver_metadata")
        for solver_id, solver_options, solver_class in cur:
            self.solver_metadata[solver_id] = {
                'solver_options': pickle.loads(solver_options),
                'solver_class': solver_class,
            }

    def _get_global_iterations(self, cur):
        """
        Get the global iterations table.

        Parameters
        ----------
        cur : sqlite3.Cursor
            Database cursor to use for reading the data.

        Returns
        -------
        list
            List of global iterations and the table and row where the associated case is found.
        """
        cur.execute('select * from global_iterations')
        return cur.fetchall()

    def _case_tables(self):
        """
        Return the case table helpers available for this recording's format version.

        Returns
        -------
        list
            Driver, system and solver case tables, plus the problem case table when the
            format version is 2 or later.
        """
        tables = [self._driver_cases, self._system_cases, self._solver_cases]
        if self._format_version >= 2:
            tables.append(self._problem_cases)
        return tables

    def _load_cases(self):
        """
        Load all driver, solver, and system cases into memory.
        """
        self._driver_cases._load_cases()
        self._solver_cases._load_cases()
        self._system_cases._load_cases()
        if self._format_version >= 2:
            self._problem_cases._load_cases()

    def list_sources(self, out_stream=_DEFAULT_OUT_STREAM):
        """
        List of all the different recording sources for which there is recorded data.

        Parameters
        ----------
        out_stream : file-like object
            Where to send human readable output. Default is sys.stdout.
            Set to None to suppress.

        Returns
        -------
        list
            One or more of: `problem`, `driver`, `<system hierarchy location>`,
            `<solver hierarchy location>`
        """
        sources = []

        if self._driver_cases.count() > 0:
            sources.extend(self._driver_cases.list_sources())
        if self._solver_cases.count() > 0:
            sources.extend(self._solver_cases.list_sources())
        if self._system_cases.count() > 0:
            sources.extend(self._system_cases.list_sources())
        if self._format_version >= 2 and self._problem_cases.count() > 0:
            sources.extend(self._problem_cases.list_sources())

        if out_stream:
            if out_stream is _DEFAULT_OUT_STREAM:
                out_stream = sys.stdout
            for source in sources:
                out_stream.write('{}\n'.format(source))

        return sources

    def list_source_vars(self, source, out_stream=_DEFAULT_OUT_STREAM):
        """
        List of all inputs and outputs recorded by the specified source.

        Parameters
        ----------
        source : {'problem', 'driver', <system hierarchy location>, <solver hierarchy location>}
            Identifies the source for which to return information.
        out_stream : file-like object
            Where to send human readable output. Default is sys.stdout.
            Set to None to suppress.

        Returns
        -------
        dict
            {'inputs':[key list], 'outputs':[key list], 'residuals':[key list]}. No recurse.

        Raises
        ------
        RuntimeError
            If the source is unknown or has no recorded cases.
        """
        dct = {
            'inputs': [],
            'outputs': [],
            'residuals': [],
        }

        case = None

        if source == 'problem':
            # problem cases only exist from format version 2 onwards
            if self._format_version >= 2 and self._problem_cases.count() > 0:
                case = self._problem_cases.get_case(0)
        elif source == 'driver':
            if self._driver_cases.count() > 0:
                case = self._driver_cases.get_case(0)
        elif source in self._system_cases.list_sources():
            source_cases = self._system_cases.list_cases(source)
            case = self._system_cases.get_case(source_cases[0])
        elif source in self._solver_cases.list_sources():
            source_cases = self._solver_cases.list_cases(source)
            case = self._solver_cases.get_case(source_cases[0])
        else:
            raise RuntimeError('Source not found: %s' % source)

        if case is None:
            raise RuntimeError('No cases recorded for %s' % source)

        if case.inputs:
            dct['inputs'] = list(case.inputs)
        if case.outputs:
            dct['outputs'] = list(case.outputs)
        if case.residuals:
            dct['residuals'] = list(case.residuals)

        if out_stream:
            if out_stream is _DEFAULT_OUT_STREAM:
                out_stream = sys.stdout
            write_source_table(dct, out_stream)

        return dct

    def list_model_options(self, run_counter=None, out_stream=_DEFAULT_OUT_STREAM):
        """
        List of all model options.

        Parameters
        ----------
        run_counter : int or None
            Run_driver or run_model iteration to inspect
        out_stream : file-like object
            Where to send human readable output. Default is sys.stdout.
            Set to None to suppress.

        Returns
        -------
        dict
            {'root':{key val}}
        """
        if out_stream is _DEFAULT_OUT_STREAM:
            out_stream = sys.stdout

        dct = {}

        for sys_key, sys_options in self._system_options.items():
            # keys look like '<system>_<run number>'; a key without '_' is run number 0
            if '_' in sys_key:
                subsys, num = sys_key.rsplit('_', 1)
            else:
                subsys = sys_key
                num = 0

            # only model ('root') options are listed; check the name before converting the
            # run number so non-numeric suffixes on other systems cannot raise
            if subsys != 'root':
                continue
            if run_counter is not None and run_counter != int(num):
                continue

            comp_options = sys_options['component_options']

            if out_stream:
                out_stream.write(
                    'Run Number: {}\n Subsystem: {}'.format(num, subsys))

            # start a fresh options dict once per run, not once per option
            dct[subsys] = {}
            for opt_name in comp_options:
                if out_stream:
                    option = "{0} : {1}".format(opt_name, comp_options[opt_name])
                    out_stream.write('\n {}\n'.format(option))
                dct[subsys][opt_name] = comp_options[opt_name]

        return dct

    def list_cases(self, source=None, recurse=True, flat=True, out_stream=_DEFAULT_OUT_STREAM):
        """
        Iterate over Driver, Solver and System cases in order.

        Parameters
        ----------
        source : {'problem', 'driver', <system hierarchy location>, <solver hierarchy location>,
            case name}
            If not None, only cases originating from the specified source or case are returned.
            A case name with recurse=False returns just that case.
        recurse : bool, optional
            If True, will enable iterating over all successors in case hierarchy.
        flat : bool, optional
            If False and there are child cases, then a nested ordered dictionary
            is returned rather than an iterator.
        out_stream : file-like object
            Where to send human readable output. Default is sys.stdout.
            Set to None to suppress.

        Returns
        -------
        iterator or dict
            An iterator or a nested dictionary of identified cases.

        Raises
        ------
        TypeError
            If source is not a string.
        RuntimeError
            If the source cannot be found.
        """
        # if source was not specified, return all cases
        if source is None:
            if flat:
                source = ''
            elif self._driver_cases.count() > 0:
                source = 'driver'
            elif 'root' in self._system_cases.list_sources():
                source = 'root'
            else:
                # if there are no driver or model cases, then we need
                # another starting point to build the nested dict.
                raise RuntimeError("A nested dictionary of all cases was requested, but "
                                   "neither the driver or the model was recorded. Please "
                                   "specify another source (system or solver) for the cases "
                                   "you want to see.")

        if not isinstance(source, str):
            raise TypeError("Source parameter must be a string, %s is type %s." %
                            (source, type(source).__name__))

        if not source:
            return self._list_cases_recurse_flat(out_stream=out_stream)

        if source == 'problem':
            if self._format_version >= 2:
                return self._problem_cases.list_cases()
            raise RuntimeError('No problem cases recorded (data format = %d).' %
                               self._format_version)

        # figure out which table has cases from the source
        if source == 'driver':
            case_table = self._driver_cases
        elif source in self._system_cases.list_sources():
            case_table = self._system_cases
        elif source in self._solver_cases.list_sources():
            case_table = self._solver_cases
        else:
            case_table = None

        if case_table is not None:
            if not recurse:
                # return list of cases from the source alone
                return case_table.list_cases(source)

            if flat:
                # return list of cases from the source plus child cases
                cases = []
                for case in case_table.get_cases(source):
                    cases += self._list_cases_recurse_flat(case.name, out_stream=out_stream)
                return cases

            # return nested dict of cases from the source and child cases
            cases = OrderedDict()
            for case in case_table.get_cases(source):
                cases.update(self._list_cases_recurse_nested(case.name))
            return cases

        if '|' in source:
            # source is a coordinate
            if recurse:
                if flat:
                    return self._list_cases_recurse_flat(source, out_stream=out_stream)
                return self._list_cases_recurse_nested(source)

            # without recursion, the only matching case is the one the coordinate names
            if any(source in table.list_cases() for table in self._case_tables()):
                return [source]

        raise RuntimeError('Source not found: %s' % source)

    def _list_cases_recurse_flat(self, coord=None, out_stream=_DEFAULT_OUT_STREAM):
        """
        Iterate recursively over Driver, Solver and System cases in order.

        Parameters
        ----------
        coord : an iteration coordinate
            Identifies the parent of the cases to return.
        out_stream : file-like object
            Where to send human readable output. Default is sys.stdout.
            Set to None to suppress.

        Returns
        -------
        list
            The coordinates of the identified cases, in global iteration order.
        """
        solver_cases = self._solver_cases.list_cases()
        system_cases = self._system_cases.list_cases()
        driver_cases = self._driver_cases.list_cases()
        # problem cases only exist from format version 2 onwards
        if self._format_version >= 2:
            problem_cases = self._problem_cases.list_cases()
        else:
            problem_cases = []

        global_iters = self._global_iterations

        if not coord:
            # will return all cases
            coord = ''
            parent_case_counter = len(global_iters)
        elif coord in driver_cases:
            parent_case_counter = self._driver_cases.get_case(coord).counter
        elif coord in system_cases:
            parent_case_counter = self._system_cases.get_case(coord).counter
        elif coord in solver_cases:
            parent_case_counter = self._solver_cases.get_case(coord).counter
        elif coord in problem_cases:
            parent_case_counter = self._problem_cases.get_case(coord).counter
        else:
            raise RuntimeError('Case not found for coordinate:', coord)

        coords_by_table = {
            'solver': solver_cases,
            'system': system_cases,
            'driver': driver_cases,
            'problem': problem_cases,
        }

        cases = []
        self.source_cases_table = {'solver': [], 'system': [], 'driver': [], 'problem': []}

        # return all cases in the global iteration table that precede the given case
        # and whose coordinate is prefixed by the given coordinate
        for i in range(0, parent_case_counter):
            global_iter = global_iters[i]
            table, row = global_iter[1], global_iter[2]
            if table not in coords_by_table:
                raise RuntimeError('Unexpected table name in global iterations:', table)

            # rows in the global iterations table are 1-based
            case_coord = coords_by_table[table][row - 1]

            if case_coord.startswith(coord):
                self.source_cases_table[table].append(case_coord)
                cases.append(case_coord)

        if out_stream:
            if out_stream is _DEFAULT_OUT_STREAM:
                out_stream = sys.stdout
            write_source_table(self.source_cases_table, out_stream)

        return cases

    def _list_cases_recurse_nested(self, coord=None):
        """
        Iterate recursively over Driver, Solver and System cases in order.

        Parameters
        ----------
        coord : an iteration coordinate
            Identifies the parent of the cases to return.

        Returns
        -------
        dict
            A nested dictionary of identified cases.
        """
        solver_cases = self._solver_cases.list_cases()
        system_cases = self._system_cases.list_cases()
        driver_cases = self._driver_cases.list_cases()
        global_iters = self._global_iterations

        if coord in driver_cases:
            parent_case = self._driver_cases.get_case(coord)
        elif coord in system_cases:
            parent_case = self._system_cases.get_case(coord)
        elif coord in solver_cases:
            parent_case = self._solver_cases.get_case(coord)
        else:
            raise RuntimeError('Case not found for coordinate:', coord)

        cases = OrderedDict()
        children = OrderedDict()
        cases[parent_case.name] = children

        # only solver and system cases can be children of another case
        child_coords_by_table = {'solver': solver_cases, 'system': system_cases}

        # return all cases in the global iteration table that precede the given case
        # and whose coordinate is prefixed by the given coordinate
        for i in range(0, parent_case.counter - 1):
            global_iter = global_iters[i]
            table, row = global_iter[1], global_iter[2]
            if table not in child_coords_by_table:
                continue

            case_coord = child_coords_by_table[table][row - 1]
            if case_coord.startswith(coord):
                # a direct child's coordinate is the parent's plus one '|name|counter' pair
                parent_coord = '|'.join(case_coord.split('|')[:-2])
                if parent_coord == coord:
                    children.update(self._list_cases_recurse_nested(case_coord))

        return cases

    def get_cases(self, source=None, recurse=True, flat=True):
        """
        Iterate over the cases.

        Parameters
        ----------
        source : {'problem', 'driver', component pathname, solver pathname, case_name}
            Identifies which cases to return.
        recurse : bool, optional
            If True, will enable iterating over all successors in case hierarchy
        flat : bool, optional
            If False and there are child cases, then a nested ordered dictionary
            is returned rather than an iterator.

        Returns
        -------
        list or dict
            The cases identified by source
        """
        case_ids = self.list_cases(source, recurse, flat, out_stream=None)
        if isinstance(case_ids, list):
            return [self.get_case(case_id) for case_id in case_ids]
        return self._get_cases_nested(case_ids, OrderedDict())

    def _get_cases_nested(self, case_ids, cases):
        """
        Populate a nested dictionary of cases matching the provided dictionary of case IDs.

        Parameters
        ----------
        case_ids : OrderedDict
            The nested dictionary of case IDs.
        cases : OrderedDict
            The nested dictionary of cases.

        Returns
        -------
        OrderedDict
            The nested dictionary of cases with cases added from case_ids.
        """
        for case_id, children in case_ids.items():
            # an empty children dict yields an empty OrderedDict, so no special case is needed
            cases[self.get_case(case_id)] = self._get_cases_nested(children, OrderedDict())
        return cases

    def get_case(self, case_id, recurse=False):
        """
        Get case identified by case_id.

        Parameters
        ----------
        case_id : str or int
            The unique identifier of the case to return or an index into all cases.
        recurse : bool, optional
            If True, will return all successors to the case as well.

        Returns
        -------
        dict
            The case identified by case_id

        Raises
        ------
        IndexError
            If an integer case_id is beyond the end of the global iterations table.
        RuntimeError
            If no case matches case_id.
        """
        if isinstance(case_id, int):
            # it's a global index rather than a coordinate
            global_iters = self._global_iterations
            if case_id > len(global_iters) - 1:
                raise IndexError("Invalid index into available cases:", case_id)

            global_iter = global_iters[case_id]
            table, row = global_iter[1], global_iter[2]

            tables_by_name = {
                'solver': self._solver_cases,
                'system': self._system_cases,
                'driver': self._driver_cases,
            }
            if self._format_version >= 2:
                tables_by_name['problem'] = self._problem_cases

            if table in tables_by_name:
                # rows in the global iterations table are 1-based
                case_id = tables_by_name[table].list_cases()[row - 1]

        if recurse:
            return self.get_cases(case_id, recurse=True)

        for table in self._case_tables():
            case = table.get_case(case_id)
            if case:
                return case

        raise RuntimeError('Case not found:', case_id)
class CaseTable(object):
"""
Base class for wrapping case tables in a recording database.
Attributes
----------
_filename : str
The name of the recording file from which to instantiate the case reader.
_format_version : int
The version of the format assumed when loading the file.
_table_name : str
The name of the table in the database.
_index_name : str
The name of the case index column in the table.
_global_iterations : list
List of iteration cases and the table and row in which they are found.
_abs2prom : {'input': dict, 'output': dict}
Dictionary mapping absolute names to promoted names.
_abs2meta : dict
Dictionary mapping absolute variable names to variable metadata.
_prom2abs : {'input': dict, 'output': dict}
Dictionary mapping promoted names to absolute names.
_conns : dict
Dictionary of all model connections.
_var_info : dict
Dictionary with information about variables (scaling, indices, execution order).
_sources : list
List of sources of cases in the table.
_keys : list
List of keys of cases in the table.
_cases : dict
Dictionary mapping keys to cases that have already been loaded.
_auto_ivc_map : dict
Dictionary that maps all auto_ivc sources to either an absolute input name for single
connections or a promoted input name for multiple connections. This is for output display.
_global_iterations : list
List of iteration cases and the table and row in which they are found.
"""
def __init__(self, fname, ver, table, index, giter, prom2abs, abs2prom, abs2meta, conns,
             auto_ivc_map, var_info):
    """
    Store where the cases live and the name/metadata maps needed to build them.

    Parameters
    ----------
    fname : str
        The name of the recording file from which to instantiate the case reader.
    ver : int
        The version of the format assumed when loading the file.
    table : str
        The name of the table in the database.
    index : str
        The name of the case index column in the table.
    giter : list of tuple
        The global iterations table.
    prom2abs : {'input': dict, 'output': dict}
        Dictionary mapping promoted names to absolute names.
    abs2prom : {'input': dict, 'output': dict}
        Dictionary mapping absolute names to promoted names.
    abs2meta : dict
        Dictionary mapping absolute variable names to variable metadata.
    conns : dict
        Dictionary of all model connections.
    auto_ivc_map : dict
        Dictionary that maps all auto_ivc sources to the text used to display them.
    var_info : dict
        Dictionary with information about variables (scaling, indices, execution order).
    """
    # location of the cases inside the recording database
    self._filename, self._format_version = fname, ver
    self._table_name, self._index_name = table, index
    self._global_iterations = giter

    # name maps and metadata shared with the owning case reader
    self._prom2abs, self._abs2prom, self._abs2meta = prom2abs, abs2prom, abs2meta
    self._conns, self._auto_ivc_map = conns, auto_ivc_map
    self._var_info = var_info

    # lazily-populated caches: nothing has been read from the table yet
    self._sources = self._keys = None
    self._cases = {}
def count(self):
    """
    Get the number of cases recorded in the table.

    Returns
    -------
    int
        The number of cases recorded in the table.
    """
    # a sqlite3 connection's context manager does not close it, so close explicitly,
    # including when the query raises
    con = sqlite3.connect(self._filename)
    try:
        cur = con.cursor()
        # table names cannot be bound as SQL parameters; _table_name is set internally
        cur.execute("SELECT count(*) FROM %s" % self._table_name)
        (num_rows,) = cur.fetchone()
    finally:
        con.close()
    return num_rows
def list_cases(self, source=None):
    """
    Get list of case IDs for cases in the table.

    Parameters
    ----------
    source : str, optional
        A source of cases or the iteration coordinate of a case.
        If not None, only cases originating from the specified source or case are returned.

    Returns
    -------
    list
        The cases from the table from the specified source or parent case.
    """
    # test against None (not truthiness) so an empty table is cached too, instead of being
    # re-queried on every call; get_cases uses the same check
    if self._keys is None:
        con = sqlite3.connect(self._filename)
        try:
            cur = con.cursor()
            # identifiers cannot be bound as SQL parameters; both names are set internally
            cur.execute("SELECT %s FROM %s ORDER BY id ASC" %
                        (self._index_name, self._table_name))
            rows = cur.fetchall()
        finally:
            con.close()

        # cache case list for future use
        self._keys = [row[0] for row in rows]

    if not source:
        # return all cases
        return self._keys

    if '|' in source:
        # source is a coordinate
        return [key for key in self._keys if key.startswith(source)]

    # source is a system or solver
    return [key for key in self._keys if self._get_source(key) == source]
def get_cases(self, source=None, recurse=False, flat=False):
    """
    Get the cases in this table that match a source or iteration coordinate.

    Parameters
    ----------
    source : str, optional
        If not None, only cases that have the specified source will be returned
    recurse : bool, optional
        If True, will enable iterating over all successors in case hierarchy
    flat : bool, optional
        If False and there are child cases, then a nested ordered dictionary
        is returned rather than an iterator.

    Returns
    -------
    list or dict
        The cases from the table that have the specified source.
    """
    if self._keys is None:
        # populate the key cache
        self.list_cases()

    keys = self._keys
    fetch = self.get_case

    if not source:
        # no filter: every case in the table
        return [fetch(key) for key in keys]

    if '|' in source:
        # source is an iteration coordinate
        if not recurse or flat:
            return [fetch(key) for key in keys if key.startswith(source)]
        nested = OrderedDict()
        for key in keys:
            # strict descendants only: the key must be longer than the coordinate itself
            if key.startswith(source) and len(key) > len(source):
                nested[key] = self.get_cases(key, recurse, flat)
        return nested

    # source is a system or solver
    if not recurse:
        return [fetch(key) for key in keys if self._get_source(key) == source]

    if flat:
        # every case recorded anywhere under the source system
        source_sys = source.replace('.nonlinear_solver', '')
        return [fetch(key) for key in keys
                if get_source_system(key).startswith(source_sys)]

    nested = OrderedDict()
    for key in keys:
        if self._get_source(key) == source:
            nested[key] = self.get_cases(key, recurse, flat)
    return nested
def get_case(self, case_id, cache=False):
"""
Get a case from the database.
Parameters
----------
case_id : str or int
The string-identifier of the case to be retrieved or the index of the case.
cache : bool
If True, case will be cached for faster access by key.
Returns
-------
Case
The specified case from the table.
"""
# check to see if we've already cached this case
if isinstance(case_id, int):
case_id = self._get_iteration_coordinate(case_id)
# if we've already cached this case, return the cached instance
if case_id in self._cases:
return self._cases[case_id]
# we don't have it, so fetch it
with sqlite3.connect(self._filename) as con:
con.row_factory = sqlite3.Row
cur = con.cursor()
cur.execute("SELECT * FROM %s WHERE %s='%s'" %
(self._table_name, self._index_name, case_id))
row = cur.fetchone()