<a href="https://colab.research.google.com/github/Yanbing-Judy/190DD-Project/blob/main/Project_Part1b_updated.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [107]:
# B1
from itertools import product
from sklearn.metrics import mean_squared_error
import numpy as np
import pandas as pd
from scipy.integrate import solve_ivp
import matplotlib.pyplot as plt

# Open the project archive once and pull each array out of it by key.
data = np.load('part1b.npz')

N, Lc, Ic_0, gamma = (data[key] for key in ('N', 'Lc', 'Ic_0', 'gamma'))
L_validation, L_test, betas_validation = (
    data[key] for key in ('L_validation', 'L_test', 'betas_validation')
)

# Nodes per batch (every batch slice taken further down is 5 rows long).
nb_nodes = 5


In [108]:
def SIR_model_batch_beta(t, y, b1, b2, b3, b4):
    """Right-hand side of the four-class SIR system, in the form solve_ivp expects.

    State layout of ``y``: entries 0-3 are the susceptibles of each class,
    entries 4-7 the infected of each class, entry 8 the (shared) recovered.
    ``b1``..``b4`` are the per-class transmission rates. ``t`` is accepted
    because the ODE solver passes it, but the system does not depend on it.

    NOTE: the population ``N`` and recovery rate ``gamma`` are read from
    module scope, not from the arguments.

    Returns a length-9 array of derivatives in the same layout as ``y``.
    """
    susceptible, infected, _recovered = y[:4], y[4:8], y[8]
    rates = np.array([b1, b2, b3, b4])

    # New infections per class: every class mixes with the total infected pool.
    new_infections = rates * susceptible * infected.sum() / N

    susceptible_rate = -new_infections
    infected_rate = new_infections - gamma * infected
    recovered_rate = gamma * infected.sum()
    return np.concatenate((susceptible_rate, infected_rate, [recovered_rate]))

In [109]:
def GS_beta_batch(N, Lc, Ic_0, gamma, observed_data, pmf):
    """Grid-search the four transmission rates that best reproduce a batch of nodes.

    Every ordered candidate ``b1 <= b2 <= b3 <= b4`` drawn from a 21-point
    grid on [0, 1] is scored by summing, over all nodes of the batch, the
    mean squared error between the node's observed series and the model's
    ``sum(I_c * Lc)`` evaluated at t = 0, 1, ..., 19.

    Parameters
    ----------
    N : total population; ``N`` minus a node's initial infected is split
        according to ``pmf`` to obtain the initial susceptibles per class.
    Lc : per-class weights applied to the infected compartments.
    Ic_0 : one row of initial infected counts (four classes) per node.
    gamma : accepted for interface compatibility but not used here --
        SIR_model_batch_beta reads ``gamma`` (and ``N``) from module scope.
    observed_data : one observed series per node (20 values, matching the
        evaluation times), aligned row-by-row with ``Ic_0``.
    pmf : four fractions used to split the susceptible population.

    Returns
    -------
    (mmse, beta) : the smallest cumulative MSE and the corresponding
        ``[b1, b2, b3, b4]``. ``beta`` is ``None`` if no candidate scored
        below infinity.

    The batch size is taken from the data itself (rows of ``Ic_0`` paired
    with rows of ``observed_data``) rather than from a module-level constant.
    """
    grid = np.linspace(0, 1, 21)
    t = np.linspace(0, 19, 20)
    mmse = np.inf
    beta = None  # stays None only if nothing ever improves on inf
    for b1, b2, b3, b4 in product(grid, repeat=4):
        # Only non-decreasing rate tuples are admissible candidates.
        if not (b1 <= b2 <= b3 <= b4):
            continue
        cumulative_mse = 0.
        for Ic_node, observed in zip(Ic_0, observed_data):
            # Initial condition: susceptibles split by pmf, no recovered yet.
            Sc_0 = pmf * (N - Ic_node.sum())
            y0 = np.concatenate((Sc_0, Ic_node, [0]))
            sol_object = solve_ivp(fun=SIR_model_batch_beta, t_span=(0, 20), y0=y0,
                                   args=(b1, b2, b3, b4), dense_output=True)
            infected = sol_object.sol(t)[4:8]            # rows = classes, columns = times
            L_predicted = (infected.T * Lc).sum(axis=1)  # sum_c I_c * Lc at each time
            cumulative_mse += mean_squared_error(L_predicted, observed)
        if cumulative_mse < mmse:
            mmse = cumulative_mse
            beta = [b1, b2, b3, b4]
    print(f"MMSE is {mmse:.2f}")
    return mmse, beta

In [None]:
# Four ways of splitting the susceptible population across the four classes;
# each one is paired with its own batch of five nodes.
pmf1 = np.array([0.5, 0.3, 0.1, 0.1])
pmf2 = np.array([0.4, 0.3, 0.2, 0.1])
pmf3 = np.array([0.3, 0.3, 0.2, 0.2])
pmf4 = np.array([0.1, 0.2, 0.3, 0.4])

# First batch (nodes 0-4) is fitted under pmf1 and compared with the reference betas.
mmse_1, beta_1 = GS_beta_batch(N, Lc, Ic_0[:5], gamma, L_validation[:5], pmf1)
print("Estimated betas are", beta_1, sep="\n")
print("Validation betas are", betas_validation[0], sep="\n")

In [None]:
# Second batch (nodes 5-9) is fitted under pmf2 and compared with the reference betas.
mmse_2, beta_2 = GS_beta_batch(N, Lc, Ic_0[5:10], gamma, L_validation[5:10], pmf2)
print("Estimated betas are", beta_2, sep="\n")
print("Validation betas are", betas_validation[1], sep="\n")

In [None]:
# Third batch (nodes 10-14) is fitted under pmf3 and compared with the reference betas.
mmse_3, beta_3 = GS_beta_batch(N, Lc, Ic_0[10:15], gamma, L_validation[10:15], pmf3)
print("Estimated betas are", beta_3, sep="\n")
print("Validation betas are", betas_validation[2], sep="\n")

In [None]:
# Fourth batch (nodes 15-19) is fitted under pmf4 and compared with the reference betas.
mmse_4, beta_4 = GS_beta_batch(N, Lc, Ic_0[15:20], gamma, L_validation[15:20], pmf4)
print("Estimated betas are", beta_4, sep="\n")
print("Validation betas are", betas_validation[3], sep="\n")

In [None]:
# B2
# Repeat the grid search against the test series, one batch of five nodes per pmf.
beta_test = []
pmf_array = [pmf1, pmf2, pmf3, pmf4]
beta_array = []
pointer = 0  # first row of the current batch in Ic_0 / L_test
for pmf in pmf_array:
    batch_rows = slice(pointer, pointer + 5)
    mmse, beta = GS_beta_batch(N, Lc, Ic_0[batch_rows], gamma, L_test[batch_rows], pmf)
    beta_array.append(beta)
    pointer += 5  # move on to the next batch of 5 nodes
print(beta_array)

In [None]:
# B3
# For the first node of each social-vulnerability batch, integrate the model
# with the fitted betas over 100 days and plot S, I, R and the predicted L,
# overlaid with the 20 observed test values.
pointer = 0  # row of the first node of the current batch
c = 0        # batch counter, shown in the figure title
R_0 = 0
t = np.linspace(0,99,100) #Solve the equations for 100 days
for pmf, beta in zip(pmf_array,beta_array):
    Sc_0 = pmf*(N-Ic_0[pointer].sum())
    Ic_first = Ic_0[pointer] # first node in each batch
    y0 = np.concatenate((Sc_0,Ic_first,[R_0]))
    sol_object = solve_ivp(fun=SIR_model_batch_beta,t_span=(0,100),y0=y0,
                           args=(beta[0],beta[1],beta[2],beta[3]),dense_output=True)
    I_beta = sol_object.sol(t).T[:,4:8]   # rows = days, columns = classes
    # Same quantity the grid search fits: sum over classes of I_c * Lc.
    # (Weighting every class by the single scalar Lc[c] would not match it.)
    L_predicted = (I_beta*Lc).sum(axis=1)

    fig, ax = plt.subplots()
    ax.set_xlabel("Days")
    ax.set_title("First node of sv {}".format(c))
    # observed values for the first 20 days ("o" is the circle marker)
    ax.plot(np.linspace(0,19,20), L_test[pointer], marker = "o", label = 'L test')
    # S, I, R totals and predicted L for the first node of this batch
    ax.plot(sol_object.t, sol_object.y[:4].sum(axis=0), label = 'S')
    ax.plot(sol_object.t, sol_object.y[4:8].sum(axis=0), label = 'I')
    ax.plot(sol_object.t, sol_object.y[8], label = 'R')
    ax.plot(t, L_predicted, label = 'L predicted')
    ax.legend()  # the labels above are only shown once a legend is drawn
    plt.show()

    pointer += 5  # advance to the first node of the next batch
    c += 1