Skip to content

Commit

Permalink
changed an error name
Browse files Browse the repository at this point in the history
  • Loading branch information
goujou committed Aug 7, 2020
1 parent 487178d commit e8c1520
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 34 deletions.
159 changes: 126 additions & 33 deletions CompartmentalSystems/pwc_model_run_fd.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
###############################################################################


class Error(Exception):
class PWCModelRunFDError(Exception):
"""Generic error occurring in this module."""
pass

Expand All @@ -24,46 +24,101 @@ class Error(Exception):


class PWCModelRunFD(ModelRun):
def __init__(
self,
time_symbol,
data_times,
start_values,
gross_Us,
gross_Fs,
gross_Rs
):

self.data_times = data_times
cls = self.__class__
disc_times = data_times[1:-1]
# def __init__(
# self,
# time_symbol,
# data_times,
# start_values,
# gross_Us,
# gross_Fs,
# gross_Rs
# ):
#
# self.data_times = data_times
# cls = self.__class__
# disc_times = data_times[1:-1]
#
# print('reconstructing us')
## us = cls.reconstruct_us(data_times, gross_Us)
## self.dts = np.diff(self.data_times).astype(np.float64)
# us = gross_Us / self.dts.reshape(-1, 1)
#
# print('reconstructing Bs')
# Bs = cls.reconstruct_Bs(
# data_times,
# start_values,
# gross_Us,
# gross_Fs,
# gross_Rs
# )
#
#
#
# nr_pools = len(start_values)
# strlen = len(str(nr_pools))
#
# def pool_str(i): return ("{:0"+str(strlen)+"d}").format(i)
# par_dicts = [dict()] * len(us)
#
# srm_generic = cls.create_srm_generic(
# time_symbol,
# Bs,
# us
# )
# time_symbol = srm_generic.time_symbol
#
# par_dicts = []
# for k in range(len(us)):
# par_dict = dict()
# for j in range(nr_pools):
# for i in range(nr_pools):
# sym = Symbol('b_'+pool_str(i)+pool_str(j))
# if sym in srm_generic.F.free_symbols:
# par_dict[sym] = Bs[k, i, j]
#
# for i in range(nr_pools):
# sym = Symbol('u_'+pool_str(i))
# if sym in srm_generic.F.free_symbols:
# par_dict[sym] = us[k, i]
#
# par_dicts.append(par_dict)
#
# func_dicts = [dict()] * len(us)
#
# self.pwc_mr = PWCModelRun(
# srm_generic,
# par_dicts,
# start_values,
# data_times,
# func_dicts=func_dicts,
# disc_times=disc_times
# )
# self.us = us
# self.Bs = Bs

print('reconstructing us')
# us = cls.reconstruct_us(data_times, gross_Us)
# self.dts = np.diff(self.data_times).astype(np.float64)
us = gross_Us / self.dts.reshape(-1, 1)
def __init__(self, Bs, us, pwc_mr):
self.data_times = pwc_mr.times
self.pwc_mr = pwc_mr
self.Bs = Bs
self.us = us

print('reconstructing Bs')
Bs = cls.reconstruct_Bs(
data_times,
start_values,
gross_Us,
gross_Fs,
gross_Rs
)
@classmethod
def from_Bs_and_us(cls, time_symbol, data_times, start_values, Bs, us):
disc_times = data_times[1:-1]

nr_pools = len(start_values)
strlen = len(str(nr_pools))

def pool_str(i): return ("{:0"+str(strlen)+"d}").format(i)
def pool_str(i):
return ("{:0"+str(strlen)+"d}").format(i)

par_dicts = [dict()] * len(us)

srm_generic = cls.create_srm_generic(
time_symbol,
Bs,
us
)
time_symbol = srm_generic.time_symbol

par_dicts = []
for k in range(len(us)):
Expand All @@ -83,16 +138,48 @@ def pool_str(i): return ("{:0"+str(strlen)+"d}").format(i)

func_dicts = [dict()] * len(us)

self.pwc_mr = PWCModelRun(
pwc_mr = PWCModelRun(
srm_generic,
par_dicts,
start_values,
data_times,
func_dicts=func_dicts,
disc_times=disc_times
)
self.us = us
self.Bs = Bs

return cls(Bs, us, pwc_mr)

@classmethod
def from_gross_data(
cls,
time_symbol,
data_times,
start_values,
gross_Us,
gross_Fs,
gross_Rs
):

print('reconstructing us')
dts = np.diff(data_times).astype(np.float64)
us = gross_Us / dts.reshape(-1, 1)

print('reconstructing Bs')
Bs = cls.reconstruct_Bs(
data_times,
start_values,
gross_Us,
gross_Fs,
gross_Rs
)

return cls.from_Bs_and_us(
time_symbol,
data_times,
start_values,
Bs,
us
)

@property
def model(self):
Expand All @@ -106,6 +193,10 @@ def nr_pools(self):
def dts(self):
return np.diff(self.data_times).astype(np.float64)

@property
def start_values(self):
return self.pwc_mr.start_values

def solve(self, alternative_start_values=None):
return self.pwc_mr.solve(alternative_start_values)

Expand Down Expand Up @@ -141,6 +232,8 @@ def fake_discretized_Bs(self, data_times=None):
return self.pwc_mr.fake_discretized_Bs(data_times)

# def to_netcdf(self, mdo, file_path):
# """Return a netCDF dataset that contains stocks and fluxes."""
#
# data_vars = {}
# model_ds = xr.Dataset(
# data_vars=data_vars,
Expand All @@ -149,7 +242,7 @@ def fake_discretized_Bs(self, data_times=None):
# }
# model_ds.to_netcdf(file_path)
# model_ds.close()

#
# @classmethod
# def load_from_file(cls, filename):
# pwc_mr_fd_dict = picklegzip.load(filename)
Expand Down Expand Up @@ -310,7 +403,7 @@ def constrainedFunction(x, f, lower, upper, minIncr=0.001):
# )

if not y.success:
raise(Error(y.message))
raise(PWCModelRunFDError(y.message))

B = pars_to_matrix(y.x)

Expand Down
2 changes: 1 addition & 1 deletion tests/Test_pwc_model_run_fd.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_reconstruction_accuracy(self):
xs, gross_Us, gross_Fs, gross_Rs =\
smr.fake_gross_discretized_output(smr.times)

pwc_mr_fd = PWCModelRunFD(
pwc_mr_fd = PWCModelRunFD.from_gross_data(
smr.model.time_symbol,
smr.times,
xs[0, :],
Expand Down

0 comments on commit e8c1520

Please sign in to comment.