Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 31 additions & 7 deletions ensemble/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from tqdm.auto import tqdm
from p_tqdm import p_map
import logging
from geostat.decomp import Cholesky # Making realizations

# Internal imports
import pipt.misc_tools.analysis_tools as at
Expand All @@ -26,7 +27,6 @@
from misc.system_tools.environ_var import OpenBlasSingleThread # Single threaded OpenBLAS runs



class Ensemble:
"""
Class for organizing misc. variables and simulator for an ensemble-based inversion run. Here, the forecast step
Expand Down Expand Up @@ -139,15 +139,24 @@ def __init__(self, keys_en, sim, redund_sim=None):
# individually).
self.state = {key: val for key, val in tmp_load.items()}

# Find the number of ensemble members from state variable
# Find the number of ensemble members from loaded state variables
tmp_ne = []
for tmp_state in self.state.keys():
tmp_ne.extend([self.state[tmp_state].shape[1]])
if max(tmp_ne) != min(tmp_ne):
print('\033[1;33mInput states have different ensemble size\033[1;m')
sys.exit(1)
self.ne = min(tmp_ne)


if 'ne' not in self.keys_en: # NE not specified in input file
if max(tmp_ne) != min(tmp_ne): #Check loaded ensembles are the same size (if more than one state variable)
print('\033[1;33mInput states have different ensemble size\033[1;m')
sys.exit(1)
self.ne = min(tmp_ne) # Use the number of ensemble members in loaded ensemble
else:
# Use the number of ensemble members specified in input file (may be fewer than loaded)
self.ne = int(self.keys_en['ne'])
if self.ne <= min(tmp_ne):
# pick correct number of ensemble members
self.state = {key: val[:,:self.ne] for key, val in self.state.items()}
else:
print('\033[1;33mInput states are smaller than NE\033[1;m')
if 'multilevel' in self.keys_en:
ml_info = extract.extract_multilevel_info(self.keys_en)
self.multilevel, self.tot_level, self.ml_ne, self.ML_error_corr, self.error_comp_scheme, self.ML_corr_done = ml_info
Expand Down Expand Up @@ -338,6 +347,20 @@ def calc_prediction(self, input_state=None, save_prediction=None):
# Index list of ensemble members
list_member_index = list(range(self.ne))

# modified by xluo, for including the simulation of the mean reservoir model
# as used in the RLM-MAC algorithm
if 'daalg' in self.keys_en and self.keys_en['daalg'][1] == 'gies':
list_state.append({})
list_member_index.append(self.ne)

for key in self.state.keys():
tmp_state = np.zeros(list_state[0][key].shape[0])

for i in range(self.ne):
tmp_state += list_state[i][key]

list_state[self.ne][key] = tmp_state / self.ne

if no_tot_run==1: # if not in parallel we use regular loop
en_pred = [self.sim.run_fwd_sim(state, member_index) for state, member_index in
tqdm(zip(list_state, list_member_index), total=len(list_state))]
Expand Down Expand Up @@ -392,6 +415,7 @@ def calc_prediction(self, input_state=None, save_prediction=None):
else: # Run prediction in parallel using p_map
en_pred = p_map(self.sim.run_fwd_sim, list_state,
list_member_index, num_cpus=no_tot_run, disable=self.disable_tqdm)

# List successful runs and crashes
list_crash = [indx for indx, el in enumerate(en_pred) if el is False]
list_success = [indx for indx, el in enumerate(en_pred) if el is not False]
Expand Down
8 changes: 6 additions & 2 deletions pipt/loop/assimilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,8 +529,12 @@ def post_process_forecast(self):
for k in pred_data_tmp[i]: # DATATYPE
if vintage < len(self.ensemble.sparse_info['mask']) and \
len(pred_data_tmp[i][k]) == int(np.sum(self.ensemble.sparse_info['mask'][vintage])):
self.ensemble.pred_data[i][k] = np.zeros(
(len(self.ensemble.obs_data[i][k]), self.ensemble.ne))
if self.ensemble.keys_da['daalg'][1] == 'gies':
self.ensemble.pred_data[i][k] = np.zeros(
(len(self.ensemble.obs_data[i][k]), self.ensemble.ne+1))
else:
self.ensemble.pred_data[i][k] = np.zeros(
(len(self.ensemble.obs_data[i][k]), self.ensemble.ne))
for m in range(pred_data_tmp[i][k].shape[1]):
data_array = self.ensemble.compress(pred_data_tmp[i][k][:, m], vintage,
self.ensemble.sparse_info['use_ensemble'])
Expand Down
23 changes: 20 additions & 3 deletions pipt/loop/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,24 @@ def _ext_scaling(self):
self.state_scaling = at.calc_scaling(
self.prior_state, self.list_states, self.prior_info)

self.Am = None
delta_scaled_prior = self.state_scaling[:, None] * \
np.dot(at.aug_state(self.prior_state, self.list_states), self.proj)

u_d, s_d, v_d = np.linalg.svd(delta_scaled_prior, full_matrices=False)

# Remove the last singular value/vector. This is because numpy returns all ne values, while the last one is actually
# zero. This part is also a good place to include any additional truncation.
energy = 0
trunc_index = len(s_d) - 1 # initialize
for c, elem in enumerate(s_d):
energy += elem
if energy / sum(s_d) >= self.trunc_energy:
trunc_index = c # take the index where all energy is preserved
break
u_d, s_d, v_d = u_d[:, :trunc_index +
1], s_d[:trunc_index + 1], v_d[:trunc_index + 1, :]
self.Am = np.dot(u_d, np.eye(trunc_index+1) *
((s_d**(-1))[:, None])) # notation from paper

def save_temp_state_assim(self, ind_save):
"""
Expand Down Expand Up @@ -694,7 +711,7 @@ def compress(self, data=None, vintage=0, aug_coeff=None):

data_array = None

elif aug_coeff is None:
elif aug_coeff is None: # compress predicted data

data_array, wdec_rec = self.sparse_data[vintage].compress(data)
rec = self.sparse_data[vintage].reconstruct(
Expand All @@ -703,7 +720,7 @@ def compress(self, data=None, vintage=0, aug_coeff=None):
self.data_rec.append([])
self.data_rec[vintage].append(rec)

elif not aug_coeff:
elif not aug_coeff: # compress true data, aug_coeff = false

options = copy(self.sparse_info)
# find the correct mask for the vintage
Expand Down
5 changes: 3 additions & 2 deletions pipt/misc_tools/analysis_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,8 +542,9 @@ def calc_objectivefun(pert_obs, pred_data, Cd):
data_misfit : array-like
Nex1 array containing objective function values.
"""
ne = pred_data.shape[1]
r = (pred_data - pert_obs)
#ne = pred_data.shape[1]
ne = pert_obs.shape[1]
r = (pred_data[:, :ne] - pert_obs) # This is necessary in order to use the gies code that xilu has implemented
if len(Cd.shape) == 1:
precission = Cd**(-1)
data_misfit = np.diag(r.T.dot(r*precission[:, None]))
Expand Down
8 changes: 6 additions & 2 deletions pipt/misc_tools/extract_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,12 @@ def organize_sparse_representation(info: Union[dict,list]) -> dict:
sparse['dim'] = [dim[2], dim[1], dim[0]]

# Read mask_files
sparse['mask'] = []
for idx, filename in enumerate(info['mask'], start=1):
sparse['mask'] = []
m_info = info['mask']
# allow for one mask with filename given as string
if isinstance(m_info, str):
m_info = [m_info]
for idx, filename in enumerate(m_info, start=1):
if not os.path.exists(filename):
mask = np.ones(sparse['dim'], dtype=bool)
np.savez(f'mask_{idx}.npz', mask=mask)
Expand Down
Loading