Skip to content

Commit

Permalink
lots of playing around with DMR
Browse files Browse the repository at this point in the history
  • Loading branch information
goujou committed Aug 19, 2021
1 parent d48aef1 commit 5752f7e
Showing 1 changed file with 39 additions and 7 deletions.
46 changes: 39 additions & 7 deletions src/CompartmentalSystems/discrete_model_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,8 @@ def reconstruct_B(cls, x, F, R):
B[j, j] = 0
else:
pass
print(B[j, j])
print(x[j], R[j], F[:, j].sum(), F[j, :].sum())
# print(B[j, j])
# print(x[j], R[j], F[:, j].sum(), F[j, :].sum())
raise(DMRError('Diag. val < 0: pool %d, ' % j))
else:
B[j, j] = 1
Expand Down Expand Up @@ -293,17 +293,48 @@ def reconstruct_B_2(cls, x, F, R, U):
# construct diagonals
for j in range(nr_pools):
if x[j] + U[j] != 0:
B[j, j] = 1 - (sum(B[:, j]) - B[j, j] + R[j] / (x[j] + U[j]))
# B[j, j] = 1 - (sum(B[:, j]) - B[j, j] + R[j] / (x[j] + U[j]))
B[j, j] = ((x[j] + U[j]) * (1 - sum(B[:, j]) + B[j, j]) - R[j]) / (x[j] + U[j])
if B[j, j] < 0:
# B[j, j] = 0
# y = np.array([B[i, j] * (x[j] + U[j]) for i in range(nr_pools)])
# print(y)
# print()
# print(F[:, j])
# print(y - F[:, j])
# print(sum(B[:, j]))
# print((1-sum(B[:, j])) * (x[j] + U[j]), R[j])
# print(x[j] + U[j], (sum(F[:, j]) + R[j]) / 0.15)
# raise
if np.abs(B[j, j]) < 1e-08:
B[j, j] = 0.0
else:
# pass
print(B[j, j])
print(x[j], U[j], R[j], F[:, j].sum(), F[j, :].sum())
print(U[j] - R[j] - F[:, j].sum() + F[j, :].sum())
print(B[:, j])
raise(DMRError('Diag. val < 0: pool %d, ' % j))
else:
B[j, j] = 1

# # correct for negative diagonals
# neg_diag_idx = np.where(np.diag(B)<0)[0]
# for idx in neg_diag_idx:
## print("'repairing' neg diag in pool", idx)
# # scale outfluxes down to empty pool
# col = B[:, idx]
# d = col[idx].sum()
# s = 1-d
## print(s)
# B[:, idx] = B[:, idx] / s
# r = R[idx] / (x[idx] + U[idx]) / s
# B[idx, idx] = 1 - (sum(B[:, idx]) - B[idx, idx] + r)
# if np.abs(B[idx, idx]) < 1e-08:
# B[idx, idx] = 0
#
# print(B[idx, idx], (B @ (x + U)))

return B

# @classmethod
Expand Down Expand Up @@ -591,12 +622,13 @@ def _solve_age_moment_system(self, max_order, start_age_moments):
Id = np.identity(n)
ones = np.ones(n)
soln = self.solve()
dts = self.dts
soln[soln < 1e-12] = 0
# dts = self.dts

def diag_inv_with_zeros(A):
res = np.zeros_like(A)
for k in range(A.shape[0]):
if A[k, k] != 0:
if np.abs(A[k, k]) != 0:
res[k, k] = 1/A[k, k]
else:
# res[k, k] = np.nan
Expand All @@ -614,11 +646,11 @@ def diag_inv_with_zeros(A):
moment_sum = np.zeros(n)
for j in range(1, k+1):
moment_sum += age_moments[-1][j-1, :].reshape((n,)) \
* binom(k, j) * dts[i]**(k-j)
* binom(k, j) #* dts[i]**(k-j)

# vec[k-1, :] = inv(X_np1) @ B @\
vec[k-1, :] = diag_inv_with_zeros(X_np1) @ B @\
X_n @ (moment_sum + ones*dts[i]**k)
X_n @ (moment_sum + ones)#*dts[i]**k)

age_moments.append(vec)

Expand Down

0 comments on commit 5752f7e

Please sign in to comment.