In [1]:
#Mount my drive- run the code, go to the link, accept.
from google.colab import drive
drive.mount('/content/gdrive')

#Change working directory to make it easier to access the files
import os
os.chdir("/content/gdrive/My Drive/Colab Notebooks/dinn")
os.getcwd() 

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


'/content/gdrive/My Drive/Colab Notebooks/dinn'

In [2]:
import torch
from torch.autograd import grad
import torch.nn as nn
from numpy import genfromtxt
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F

tSIDARTHE_data = genfromtxt('tSIDARTHE_data.csv', delimiter=',') #in the form of [t,S,I,D,A,R,T,H,E]

torch.manual_seed(1234)

<torch._C.Generator at 0x7fd37ccb2ab0>

In [3]:
#[<4, <12, <22, <28, <38, <50, >50] #i.e 7 bins
parameters_dict = {}
parameters_dict['alpha'] = [0.570, 0.422,0.422,0.360,0.210,0.210,0.105]
parameters_dict['beta'] = [0.011, 0.0057,0.0057, 0.005,0.005,0.005,0.005]
parameters_dict['gamma'] = [0.456, 0.285, 0.285,0.2,0.110,0.110,0.110]
parameters_dict['delta'] = [0.011, 0.0057, 0.0057, 0.005,0.005,0.005,0.005]
parameters_dict['epsilon'] = [0.171, 0.171,0.143,0.143,0.143,0.200,0.200]
parameters_dict['zeta'] = [0.125,0.125,0.125,0.034,0.034,0.025,0.025]
parameters_dict['lambdda'] = [0.034,0.034,0.034,0.08,0.08,0.08,0.08]
parameters_dict['eta'] = [0.125,0.125,0.125,0.034,0.034,0.025,0.025]
parameters_dict['rho'] = [0.034,0.034,0.034,0.017,0.017,0.020,0.020]
parameters_dict['mu'] = [0.017,0.017,0.017,0.008,0.008,0.008,0.008]
parameters_dict['kappa'] = [0.017,0.017,0.017,0.017,0.017,0.020,0.020]
parameters_dict['theta'] = [0.371,0.371,0.371, 0.371,0.371,0.371,0.371]
parameters_dict['nu'] = [0.027,0.027,0.027,0.015,0.015,0.015,0.015]
parameters_dict['xi'] = [0.017,0.017,0.017,0.017,0.017,0.020,0.020]
parameters_dict['sigma'] = [0.017,0.017,0.017,0.017,0.017,0.010,0.010]
parameters_dict['tao'] = [0.01,0.01,0.01,0.01,0.01,0.01,]
parameters_dict

{'alpha': [0.57, 0.422, 0.422, 0.36, 0.21, 0.21, 0.105],
 'beta': [0.011, 0.0057, 0.0057, 0.005, 0.005, 0.005, 0.005],
 'delta': [0.011, 0.0057, 0.0057, 0.005, 0.005, 0.005, 0.005],
 'epsilon': [0.171, 0.171, 0.143, 0.143, 0.143, 0.2, 0.2],
 'eta': [0.125, 0.125, 0.125, 0.034, 0.034, 0.025, 0.025],
 'gamma': [0.456, 0.285, 0.285, 0.2, 0.11, 0.11, 0.11],
 'kappa': [0.017, 0.017, 0.017, 0.017, 0.017, 0.02, 0.02],
 'lambdda': [0.034, 0.034, 0.034, 0.08, 0.08, 0.08, 0.08],
 'mu': [0.017, 0.017, 0.017, 0.008, 0.008, 0.008, 0.008],
 'nu': [0.027, 0.027, 0.027, 0.015, 0.015, 0.015, 0.015],
 'rho': [0.034, 0.034, 0.034, 0.017, 0.017, 0.02, 0.02],
 'sigma': [0.017, 0.017, 0.017, 0.017, 0.017, 0.01, 0.01],
 'tao': [0.01, 0.01, 0.01, 0.01, 0.01, 0.01],
 'theta': [0.371, 0.371, 0.371, 0.371, 0.371, 0.371, 0.371],
 'xi': [0.017, 0.017, 0.017, 0.017, 0.017, 0.02, 0.02],
 'zeta': [0.125, 0.125, 0.125, 0.034, 0.034, 0.025, 0.025]}

In [4]:
%%time

PATH = 'sidarthe_tensor_norm' 

class DINN(nn.Module):
    def __init__(self, t, S_data, I_data, D_data, A_data, R_data, T_data, H_data, E_data): #[t,S,I,D,A,R,T,H,E]
        super(DINN, self).__init__()
        self.t = torch.tensor(t).float()
        self.t = torch.reshape(self.t, (len(self.t),1)) #reshape for batch 
        self.S = torch.tensor(S_data)
        self.I = torch.tensor(I_data)
        self.D = torch.tensor(D_data)
        self.A = torch.tensor(A_data)
        self.R = torch.tensor(R_data)
        self.T = torch.tensor(T_data)
        self.H = torch.tensor(H_data)
        self.E = torch.tensor(E_data)

        #Unnormalize (out of 60mil population)
        # self.S = self.S * 60e6
        # self.I = self.I * 60e6
        # self.D = self.D * 60e6
        # self.A = self.A * 60e6
        # self.R = self.R * 60e6
        # self.T = self.T * 60e6
        # self.H = self.H * 60e6
        # self.E = self.E * 60e6

        #find values for normalization
        self.t_max = max(self.t)
        self.t_min = min(self.t)        
        self.S_max = max(self.S)
        self.I_max = max(self.I)
        self.D_max = max(self.D)
        self.A_max = max(self.A)
        self.R_max = max(self.R)
        self.T_max = max(self.T)
        self.H_max = max(self.H)
        self.E_max = max(self.E)
        self.S_min = min(self.S)
        self.I_min = min(self.I)
        self.D_min = min(self.D)
        self.A_min = min(self.A)
        self.R_min = min(self.R)
        self.T_min = min(self.T)
        self.H_min = min(self.H)
        self.E_min = min(self.E)

        #normalize 
        self.S_hat = (self.S - self.S_min) / (self.S_max - self.S_min)
        self.I_hat = (self.I - self.I_min) / (self.I_max - self.I_min)
        self.D_hat = (self.D - self.D_min) / (self.D_max - self.D_min)
        self.A_hat = (self.A - self.A_min) / (self.A_max - self.A_min)
        self.R_hat = (self.R - self.R_min) / (self.R_max - self.R_min)
        self.T_hat = (self.T - self.T_min) / (self.T_max - self.T_min)
        self.H_hat = (self.H - self.H_min) / (self.H_max - self.H_min)
        self.E_hat = (self.E - self.E_min) / (self.E_max - self.E_min)
        self.t_hat = (self.t - self.t_min) / (self.t_max - self.t_min)

        self.bin = 0 #a counter for which self.bin I'm on
        self.losses = []
        self.save = 2 #which file to save to
        self.SI_vals = [] #keep the intermediate SI values

        #learnable parameters (there are 7 self.bins (e.g <4 days, etc.))
        self.alpha_tilda = torch.nn.Parameter(torch.rand(7, requires_grad=True))
        self.beta_tilda = torch.nn.Parameter(torch.rand(7, requires_grad=True))
        self.gamma_tilda = torch.nn.Parameter(torch.rand(7, requires_grad=True))
        self.delta_tilda = torch.nn.Parameter(torch.rand(7, requires_grad=True))
        self.epsilon_tilda = torch.nn.Parameter(torch.rand(7, requires_grad=True))
        self.zeta_tilda = torch.nn.Parameter(torch.rand(7, requires_grad=True))
        self.lambdda_tilda = torch.nn.Parameter(torch.rand(7, requires_grad=True))
        self.eta_tilda = torch.nn.Parameter(torch.rand(7, requires_grad=True))
        self.rho_tilda = torch.nn.Parameter(torch.rand(7, requires_grad=True))
        self.mu_tilda = torch.nn.Parameter(torch.rand(7, requires_grad=True))
        self.kappa_tilda = torch.nn.Parameter(torch.rand(7, requires_grad=True))
        self.theta_tilda = torch.nn.Parameter(torch.rand(7, requires_grad=True))
        self.nu_tilda = torch.nn.Parameter(torch.rand(7, requires_grad=True))
        self.xi_tilda = torch.nn.Parameter(torch.rand(7, requires_grad=True))
        self.sigma_tilda = torch.nn.Parameter(torch.rand(7, requires_grad=True))
        self.tao_tilda = torch.nn.Parameter(torch.rand(7, requires_grad=True))
        
        #NN
        self.net_sidarthe = self.Net_sidarthe()
        self.params = list(self.net_sidarthe.parameters())
        self.params.extend(list([self.alpha_tilda, self.beta_tilda, self.gamma_tilda, self.delta_tilda, self.epsilon_tilda, self.zeta_tilda, self.lambdda_tilda, self.eta_tilda, self.rho_tilda, self.mu_tilda, self.kappa_tilda, self.theta_tilda, self.nu_tilda, self.xi_tilda, self.sigma_tilda, self.tao_tilda ]))

    #force parameters to be in a range
    @property
    def alpha(self):
        val = self.alpha_tilda[0]*(self.t<4) + self.alpha_tilda[1]*((4<=self.t)&(self.t<12)) + self.alpha_tilda[2]*((12<=self.t)&(self.t<22)) + self.alpha_tilda[3]*((22<=self.t)&(self.t<28)) + self.alpha_tilda[4]*((28<=self.t)&(self.t<38)) + self.alpha_tilda[5]*((38<=self.t)&(self.t<50)) + self.alpha_tilda[6]*(self.t>=50)
        return torch.tanh(val)# * 0.8

    @property
    def beta(self):
        val = self.beta_tilda[0]*(self.t<4) + self.beta_tilda[1]*((4<=self.t)&(self.t<12)) + self.beta_tilda[2]*((12<=self.t)&(self.t<22)) + self.beta_tilda[3]*((22<=self.t)&(self.t<28)) + self.beta_tilda[4]*((28<=self.t)&(self.t<38)) + self.beta_tilda[5]*((38<=self.t)&(self.t<50)) + self.beta_tilda[6]*(self.t>=50)
        return torch.tanh(val) #* 0.03
    
    @property
    def gamma(self):
        val = self.gamma_tilda[0]*(self.t<4) + self.gamma_tilda[1]*((4<=self.t)&(self.t<12)) + self.gamma_tilda[2]*((12<=self.t)&(self.t<22)) + self.gamma_tilda[3]*((22<=self.t)&(self.t<28)) + self.gamma_tilda[4]*((28<=self.t)&(self.t<38)) + self.gamma_tilda[5]*((38<=self.t)&(self.t<50)) + self.gamma_tilda[6]*(self.t>=50)
        return torch.tanh(val) #* 0.6

    @property
    def delta(self):
        val = self.delta_tilda[0]*(self.t<4) + self.delta_tilda[1]*((4<=self.t)&(self.t<12)) + self.delta_tilda[2]*((12<=self.t)&(self.t<22)) + self.delta_tilda[3]*((22<=self.t)&(self.t<28)) + self.delta_tilda[4]*((28<=self.t)&(self.t<38)) + self.delta_tilda[5]*((38<=self.t)&(self.t<50)) + self.delta_tilda[6]*(self.t>=50)
        return torch.tanh(val) #* 0.03

    @property
    def epsilon(self):
        val = self.epsilon_tilda[0]*(self.t<4) + self.epsilon_tilda[1]*((4<=self.t)&(self.t<12)) + self.epsilon_tilda[2]*((12<=self.t)&(self.t<22)) + self.epsilon_tilda[3]*((22<=self.t)&(self.t<28)) + self.epsilon_tilda[4]*((28<=self.t)&(self.t<38)) + self.epsilon_tilda[5]*((38<=self.t)&(self.t<50)) + self.epsilon_tilda[6]*(self.t>=50)
        return torch.tanh(val) * 0.5

    @property
    def zeta(self):
        val = self.zeta_tilda[0]*(self.t<4) + self.zeta_tilda[1]*((4<=self.t)&(self.t<12)) + self.zeta_tilda[2]*((12<=self.t)&(self.t<22)) + self.zeta_tilda[3]*((22<=self.t)&(self.t<28)) + self.zeta_tilda[4]*((28<=self.t)&(self.t<38)) + self.zeta_tilda[5]*((38<=self.t)&(self.t<50)) + self.zeta_tilda[6]*(self.t>=50)
        return torch.tanh(val) #* 0.4

    @property
    def lambdda(self):
        val = self.lambdda_tilda[0]*(self.t<4) + self.lambdda_tilda[1]*((4<=self.t)&(self.t<12)) + self.lambdda_tilda[2]*((12<=self.t)&(self.t<22)) + self.lambdda_tilda[3]*((22<=self.t)&(self.t<28)) + self.lambdda_tilda[4]*((28<=self.t)&(self.t<38)) + self.lambdda_tilda[5]*((38<=self.t)&(self.t<50)) + self.lambdda_tilda[6]*(self.t>=50)
        return torch.tanh(val)# * 0.3

    @property
    def eta(self):
        val = self.eta_tilda[0]*(self.t<4) + self.eta_tilda[1]*((4<=self.t)&(self.t<12)) + self.eta_tilda[2]*((12<=self.t)&(self.t<22)) + self.eta_tilda[3]*((22<=self.t)&(self.t<28)) + self.eta_tilda[4]*((28<=self.t)&(self.t<38)) + self.eta_tilda[5]*((38<=self.t)&(self.t<50)) + self.eta_tilda[6]*(self.t>=50)
        return torch.tanh(val) #* 0.3

    @property
    def rho(self):
        val = self.rho_tilda[0]*(self.t<4) + self.rho_tilda[1]*((4<=self.t)&(self.t<12)) + self.rho_tilda[2]*((12<=self.t)&(self.t<22)) + self.rho_tilda[3]*((22<=self.t)&(self.t<28)) + self.rho_tilda[4]*((28<=self.t)&(self.t<38)) + self.rho_tilda[5]*((38<=self.t)&(self.t<50)) + self.rho_tilda[6]*(self.t>=50)
        return torch.tanh(val) #* 0.1

    @property
    def mu(self):
        val = self.mu_tilda[0]*(self.t<4) + self.mu_tilda[1]*((4<=self.t)&(self.t<12)) + self.mu_tilda[2]*((12<=self.t)&(self.t<22)) + self.mu_tilda[3]*((22<=self.t)&(self.t<28)) + self.mu_tilda[4]*((28<=self.t)&(self.t<38)) + self.mu_tilda[5]*((38<=self.t)&(self.t<50)) + self.mu_tilda[6]*(self.t>=50)
        return torch.tanh(val) #* 0.05

    @property
    def kappa(self):
        val = self.kappa_tilda[0]*(self.t<4) + self.kappa_tilda[1]*((4<=self.t)&(self.t<12)) + self.kappa_tilda[2]*((12<=self.t)&(self.t<22)) + self.kappa_tilda[3]*((22<=self.t)&(self.t<28)) + self.kappa_tilda[4]*((28<=self.t)&(self.t<38)) + self.kappa_tilda[5]*((38<=self.t)&(self.t<50)) + self.kappa_tilda[6]*(self.t>=50)
        return torch.tanh(val)# * 0.05

    @property
    def theta(self):
        val = self.theta_tilda[0]*(self.t<4) + self.theta_tilda[1]*((4<=self.t)&(self.t<12)) + self.theta_tilda[2]*((12<=self.t)&(self.t<22)) + self.theta_tilda[3]*((22<=self.t)&(self.t<28)) + self.theta_tilda[4]*((28<=self.t)&(self.t<38)) + self.theta_tilda[5]*((38<=self.t)&(self.t<50)) + self.theta_tilda[6]*(self.t>=50)
        return torch.tanh(val) #* 0.8
                    
    @property
    def nu(self):
        val = self.nu_tilda[0]*(self.t<4) + self.nu_tilda[1]*((4<=self.t)&(self.t<12)) + self.nu_tilda[2]*((12<=self.t)&(self.t<22)) + self.nu_tilda[3]*((22<=self.t)&(self.t<28)) + self.nu_tilda[4]*((28<=self.t)&(self.t<38)) + self.nu_tilda[5]*((38<=self.t)&(self.t<50)) + self.nu_tilda[6]*(self.t>=50)
        return torch.tanh(val) #* 0.08

    @property
    def xi(self):
        val = self.xi_tilda[0]*(self.t<4) + self.xi_tilda[1]*((4<=self.t)&(self.t<12)) + self.xi_tilda[2]*((12<=self.t)&(self.t<22)) + self.xi_tilda[3]*((22<=self.t)&(self.t<28)) + self.xi_tilda[4]*((28<=self.t)&(self.t<38)) + self.xi_tilda[5]*((38<=self.t)&(self.t<50)) + self.xi_tilda[6]*(self.t>=50)
        return torch.tanh(val) #* 0.05

    @property
    def sigma(self):
        val = self.sigma_tilda[0]*(self.t<4) + self.sigma_tilda[1]*((4<=self.t)&(self.t<12)) + self.sigma_tilda[2]*((12<=self.t)&(self.t<22)) + self.sigma_tilda[3]*((22<=self.t)&(self.t<28)) + self.sigma_tilda[4]*((28<=self.t)&(self.t<38)) + self.sigma_tilda[5]*((38<=self.t)&(self.t<50)) + self.sigma_tilda[6]*(self.t>=50)
        return torch.tanh(val) #* 0.05

    @property
    def tao(self):
        val = self.tao_tilda[0]*(self.t<4) + self.tao_tilda[1]*((4<=self.t)&(self.t<12)) + self.tao_tilda[2]*((12<=self.t)&(self.t<22)) + self.tao_tilda[3]*((22<=self.t)&(self.t<28)) + self.tao_tilda[4]*((28<=self.t)&(self.t<38)) + self.tao_tilda[5]*((38<=self.t)&(self.t<50)) + self.tao_tilda[6]*(self.t>=50)
        return torch.tanh(val) #* 0.05

    #nets
    class Net_sidarthe(nn.Module): # input = [t]
        def __init__(self):
            super(DINN.Net_sidarthe, self).__init__()
            self.fc1=nn.Linear(1, 20) #takes t
            self.fc2=nn.Linear(20, 20)
            self.fc3=nn.Linear(20, 20)
            self.fc4=nn.Linear(20, 20)
            self.fc5=nn.Linear(20, 20)
            self.fc6=nn.Linear(20, 20)
            self.fc7=nn.Linear(20, 20)
            self.fc8=nn.Linear(20, 20)
            self.out=nn.Linear(20, 8) #outputs [S,I,D,A,R,T,H,E]

        def forward(self, t):
            sidarthe=F.relu(self.fc1(t))
            sidarthe=F.relu(self.fc2(sidarthe))
            sidarthe=F.relu(self.fc3(sidarthe))
            sidarthe=F.relu(self.fc4(sidarthe))
            sidarthe=F.relu(self.fc5(sidarthe))
            sidarthe=F.relu(self.fc6(sidarthe))
            sidarthe=F.relu(self.fc7(sidarthe))
            sidarthe=F.relu(self.fc8(sidarthe))
            sidarthe=self.out(sidarthe)
            return sidarthe    

    def get_SI(self, t_):
      net_vals = self.net_sidarthe(t_)
      self.SI_vals.append(net_vals)
      return net_vals

    def net_f(self, t_hat):
        self.SI_vals = [] #reset values list
        d_SI = torch.autograd.functional.jacobian(self.get_SI, t_hat) #calculate Jacobian

        S_hat,I_hat,D_hat,A_hat,R_hat,T_hat,H_hat,E_hat = self.SI_vals[0][:,0], self.SI_vals[0][:,1], self.SI_vals[0][:,2], self.SI_vals[0][:,3], self.SI_vals[0][:,4], self.SI_vals[0][:,5], self.SI_vals[0][:,6], self.SI_vals[0][:,7] 

        S = self.S_min + (self.S_max - self.S_min) * S_hat
        I = self.I_min + (self.I_max - self.I_min) * I_hat
        D = self.D_min + (self.D_max - self.D_min) * D_hat
        A = self.A_min + (self.A_max - self.A_min) * A_hat
        R = self.R_min + (self.R_max - self.R_min) * R_hat
        T = self.T_min + (self.T_max - self.T_min) * T_hat
        H = self.H_min + (self.H_max - self.H_min) * H_hat
        E = self.E_min + (self.E_max - self.E_min) * E_hat
        t = self.t_min + (self.t_max - self.t_min) * t_hat

        S_hat_t_hat, I_hat_t_hat, D_hat_t_hat, A_hat_t_hat, R_hat_t_hat, T_hat_t_hat, H_hat_t_hat, E_hat_t_hat = torch.diagonal(torch.diagonal(d_SI, 0, -1), 0)[0], torch.diagonal(torch.diagonal(d_SI, 1, -1), 0)[0], torch.diagonal(torch.diagonal(d_SI, 2, -1), 0)[0] , torch.diagonal(torch.diagonal(d_SI, 3, -1), 0)[0] , torch.diagonal(torch.diagonal(d_SI, 4, -1), 0)[0], torch.diagonal(torch.diagonal(d_SI, 5, -1), 0)[0] , torch.diagonal(torch.diagonal(d_SI, 6, -1), 0)[0] , torch.diagonal(torch.diagonal(d_SI, 7, -1), 0)[0]          

        f1 = S_hat_t_hat + (S  * (self.alpha * I + self.beta * D + self.gamma * A + self.delta * R)) / (self.S_max - self.S_min)
        f2 = I_hat_t_hat + (- S * (self.alpha * I + self.beta * D + self.gamma * A + self.delta * R) + (self.epsilon + self.eta + self.lambdda) * I ) / (self.I_max - self.I_min)
        f3 = D_hat_t_hat + (- self.epsilon * I + (self.eta + self.rho) * D) / (self.D_max - self.D_min)
        f4 = A_hat_t_hat + (- self.eta * I + (self.theta + self.mu + self.kappa) * A) / (self.A_max - self.A_min)
        f5 = R_hat_t_hat + (- self.eta * D - self.theta * A + (self.nu + self.xi) * R) / (self.R_max - self.R_min)
        f6 = T_hat_t_hat + (- self.mu * A - self.nu * R + (self.sigma + self.tao) * T) / (self.T_max - self.T_min)
        f7 = H_hat_t_hat + (- self.lambdda * I - self.rho * D - self.kappa * A - self.xi * R - self.sigma * T) / (self.H_max - self.H_min)
        f8 = E_hat_t_hat + (- self.tao * T) / (self.E_max - self.E_min)

        # print('\nf1: ', f1[0][0])
        # print('f2: ', f2[0][0])
        # print('f3: ', f3[0][0])
        # print('f4: ', f4[0][0])
        # print('f5: ', f5[0][0])
        # print('f6: ', f6[0][0])
        # print('f7: ', f7[0][0])
        # print('f8: ', f8[0][0])
        return f1, f2, f3, f4, f5, f6, f7, f8, S, I, D, A, R, T, H, E
    
    def load(self):
      # Load checkpoint
      try:
        checkpoint = torch.load(PATH + str(self.save)+'.pt') 
        print('\nloading pre-trained model...')
        self.load_state_dict(checkpoint['model'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler'])
        epoch = checkpoint['epoch']
        loss = checkpoint['loss']
        self.losses = checkpoint['losses']
        print('loaded previous loss: ', loss)
        self.scheduler._last_lr = 1e-10
      except RuntimeError :
          print('changed the architecture, ignore')
          pass
      except FileNotFoundError:
          pass

    def train(self, n_epochs):
      #try loading
      self.load()

      #train
      print('\nstarting training...\n')
      
      for epoch in range(n_epochs):
        #lists to hold the output (maintain only the final epoch)
        S_pred_list = []
        I_pred_list = []
        D_pred_list = []
        A_pred_list = []
        R_pred_list = []
        T_pred_list = []
        H_pred_list = []
        E_pred_list = []

        self.optimizer.zero_grad()
        
        f1, f2, f3, f4, f5, f6, f7, f8, S_pred, I_pred, D_pred, A_pred, R_pred, T_pred, H_pred, E_pred = self.net_f(self.t_hat)

        S_pred_list.append(self.S_min + (self.S_max - self.S_min) * S_pred) #unnormalize for graphing
        I_pred_list.append(self.I_min + (self.I_max - self.I_min) * I_pred)
        D_pred_list.append(self.D_min + (self.D_max - self.D_min) * D_pred)
        A_pred_list.append(self.A_min + (self.A_max - self.A_min) * A_pred)
        R_pred_list.append(self.R_min + (self.R_max - self.R_min) * R_pred)
        T_pred_list.append(self.T_min + (self.T_max - self.T_min) * T_pred)
        H_pred_list.append(self.H_min + (self.H_max - self.H_min) * H_pred)
        E_pred_list.append(self.E_min + (self.E_max - self.E_min) * E_pred)
        
        loss = (torch.mean(torch.square(self.S_hat - S_pred))+ 
                torch.mean(torch.square(self.I_hat - I_pred))+
                torch.mean(torch.square(self.D_hat - D_pred))+
                torch.mean(torch.square(self.A_hat - A_pred))+
                torch.mean(torch.square(self.R_hat - R_pred))+
                torch.mean(torch.square(self.T_hat - T_pred))+
                torch.mean(torch.square(self.H_hat - H_pred))+
                torch.mean(torch.square(self.E_hat - E_pred))+
                torch.mean(torch.square(f1))+
                torch.mean(torch.square(f2))+
                torch.mean(torch.square(f3))+
                torch.mean(torch.square(f4))+
                torch.mean(torch.square(f5))+
                torch.mean(torch.square(f6))+
                torch.mean(torch.square(f7))+
                torch.mean(torch.square(f8))
                ) 

        #loss.backward(retain_graph=True)
        loss.backward()
        self.optimizer.step()
        self.scheduler.step() #scheduler

        self.losses.append(loss)
        #print(self.zeta)
        #loss + model parameters update
        if epoch % 100 == 0:
          #checkpoint save every 1000 epochs if the loss is lower
          print('\nSaving model... Loss is: ', loss)
          torch.save({
              'epoch': epoch,
              'model': self.state_dict(),
              'optimizer_state_dict': self.optimizer.state_dict(),
              'scheduler': self.scheduler.state_dict(),
              'loss': loss,
              'losses': self.losses,
              }, PATH + str(self.save)+'.pt')
          if self.save % 2 > 0: #its on 3
            self.save = 2 #change to 2
          else: #its on 2
            self.save = 3 #change to 3

          print('epoch: ', epoch)
          #print(f'alpha: (goal {parameters_dict["alpha"]}', '\noutput: ', self.alpha)
          print(f'alpha: (goal {parameters_dict["alpha"]}', '\noutput: ', torch.tanh(self.alpha_tilda) * 0.8)
          # print(f'beta: (goal {parameters_dict["beta"]}', '\noutput: ',self.beta)
          # print(f'gamma: (goal {parameters_dict["gamma"]}', '\noutput: ',self.gamma)
          # print(f'delta: (goal {parameters_dict["delta"]}', '\noutput: ',self.delta)
          # print(f'epsilon: (goal {parameters_dict["epsilon"]}', '\noutput: ',self.epsilon)
          # print(f'eta: (goal {parameters_dict["eta"]}', '\noutput: ',self.eta)
          # print(f'lambdda: (goal {parameters_dict["lambdda"]}', '\noutput: ',self.lambdda)
          # print(f'eta: (goal {parameters_dict["eta"]}', '\noutput: ',self.eta)
          # print(f'rho: (goal {parameters_dict["rho"]}', '\noutput: ',self.rho)
          # print(f'mu: (goal {parameters_dict["mu"]}', '\noutput: ',self.mu)
          # print(f'kappa: (goal {parameters_dict["kappa"]}', '\noutput: ',self.kappa)
          # print(f'theta: (goal {parameters_dict["theta"]}', '\noutput: ',self.theta)
          # print(f'nu: (goal {parameters_dict["nu"]}', '\noutput: ',self.nu)
          # print(f'xi: (goal {parameters_dict["xi"]}', '\noutput: ',self.xi)
          # print(f'sigma: (goal {parameters_dict["sigma"]}', '\noutput: ',self.sigma)
          # print(f'tao: (goal {parameters_dict["tao"]}', '\noutput: ',self.tao)
          print('#################################')                

        
      #plot
      plt.plot(self.losses, color = 'teal')
      plt.xlabel('Epochs')
      plt.ylabel('Loss')
      return S_pred_list, I_pred_list, D_pred_list, A_pred_list, R_pred_list, T_pred_list, H_pred_list, E_pred_list

CPU times: user 129 µs, sys: 0 ns, total: 129 µs
Wall time: 140 µs


In [None]:
%%time

#this worked best
dinn = DINN(tSIDARTHE_data[0], tSIDARTHE_data[1], tSIDARTHE_data[2], tSIDARTHE_data[3], 
            tSIDARTHE_data[4], tSIDARTHE_data[5], tSIDARTHE_data[6], tSIDARTHE_data[7], tSIDARTHE_data[8]) #in the form of [t,S,I,D,A,R,T,H,E]

learning_rate = 0.02
optimizer = optim.Adam(dinn.params, lr = learning_rate)
dinn.optimizer = optimizer

#scheduler = torch.optim.lr_scheduler.CyclicLR(dinn.optimizer, base_lr=1e-9, max_lr=1e-3, step_size_up=4000, mode="triangular2", cycle_momentum=False)
scheduler = torch.optim.lr_scheduler.CyclicLR(dinn.optimizer, base_lr=1e-9, max_lr=1e-3, step_size_up=1000, mode="triangular2", cycle_momentum=False)
dinn.scheduler = scheduler

S_pred_list, I_pred_list, D_pred_list, A_pred_list, R_pred_list, T_pred_list, H_pred_list, E_pred_list = dinn.train(20000) #train
#S_pred_list, I_pred_list, D_pred_list, A_pred_list, R_pred_list, T_pred_list, H_pred_list, E_pred_list = dinn.test() #test


loading pre-trained model...
loaded previous loss:  tensor(11.7333, dtype=torch.float64, requires_grad=True)

starting training...


Saving model... Loss is:  tensor(11.7276, dtype=torch.float64, grad_fn=<AddBackward0>)


In [None]:
plt.plot(dinn.losses[12000:], color = 'teal')
plt.xlabel('Epochs')
plt.ylabel('Loss')

In [None]:
#(tSIDARTHE_data[0], tSIDARTHE_data[1], tSIDARTHE_data[2], tSIDARTHE_data[3], 
            #tSIDARTHE_data[4], tSIDARTHE_data[5], tSIDARTHE_data[6], tSIDARTHE_data[7], tSIDARTHE_data[8]) #in the form of [t,S,I,D,A,R,T,H,E]

fig = plt.figure(facecolor='w', figsize=(12,12))
ax = fig.add_subplot(111, facecolor='#dddddd', axisbelow=True)

# ax.plot(tSIDARTHE_data[0], tSIDARTHE_data[1]* 60e6, 'pink', alpha=0.5, lw=2, label='Susceptible')
# ax.plot(tSIDARTHE_data[0], S_pred_list[0].detach().numpy(), 'navy', alpha=0.9, lw=2, label='Susceptible Prediction', linestyle='dashed')

# ax.plot(tSIDARTHE_data[0], tSIDARTHE_data[2]* 60e6, 'violet', alpha=0.5, lw=2, label='Infected')
# ax.plot(tSIDARTHE_data[0], I_pred_list[0].detach().numpy(), 'dodgerblue', alpha=0.9, lw=2, label='Infected Prediction', linestyle='dashed')

# ax.plot(tSIDARTHE_data[0], tSIDARTHE_data[3]* 60e6, 'darkgreen', alpha=0.5, lw=2, label='Diagnosed')
# ax.plot(tSIDARTHE_data[0], D_pred_list[0].detach().numpy(), 'gold', alpha=0.9, lw=2, label='Diagnosed Prediction', linestyle='dashed')

ax.plot(tSIDARTHE_data[0], tSIDARTHE_data[4], 'red', alpha=0.5, lw=2, label='Ailling')
ax.plot(tSIDARTHE_data[0], A_pred_list[0].detach().numpy(), 'salmon', alpha=0.9, lw=2, label='Ailling Prediction', linestyle='dashed')

# ax.plot(tSIDARTHE_data[0], tSIDARTHE_data[5]* 60e6, 'blue', alpha=0.5, lw=2, label='Recognized')
# ax.plot(tSIDARTHE_data[0], R_pred_list[0].detach().numpy(), 'wheat', alpha=0.9, lw=2, label='Recognized Prediction', linestyle='dashed')

# ax.plot(tSIDARTHE_data[0], tSIDARTHE_data[6]* 60e6, 'purple', alpha=0.5, lw=2, label='Threatened')
# ax.plot(tSIDARTHE_data[0], T_pred_list[0].detach().numpy(), 'teal', alpha=0.9, lw=2, label='Threatened Prediction', linestyle='dashed')

# ax.plot(tSIDARTHE_data[0], tSIDARTHE_data[7]* 60e6, 'yellow', alpha=0.5, lw=2, label='Healed')
# ax.plot(tSIDARTHE_data[0], H_pred_list[0].detach().numpy(), 'slategrey', alpha=0.9, lw=2, label='Healed Prediction', linestyle='dashed')

# ax.plot(tSIDARTHE_data[0], tSIDARTHE_data[8]* 60e6, 'black', alpha=0.5, lw=2, label='Extinct')
# ax.plot(tSIDARTHE_data[0], E_pred_list[0].detach().numpy(), 'aqua', alpha=0.9, lw=2, label='Extinct Prediction', linestyle='dashed')


ax.set_xlabel('Time /days')
ax.set_ylabel('Number')
#ax.set_ylim([-1,50])
ax.yaxis.set_tick_params(length=0)
ax.xaxis.set_tick_params(length=0)
ax.grid(b=True, which='major', c='w', lw=2, ls='-')
legend = ax.legend()
legend.get_frame().set_alpha(0.5)
for spine in ('top', 'right', 'bottom', 'left'):
    ax.spines[spine].set_visible(False)
plt.show()

In [None]:
import numpy as np
from scipy.integrate import odeint
import matplotlib.pyplot as plt

#initial conditions
I0 = 200
D0 = 20
A0 = 1
R0 = 2
T0 = 0
H0 = 0
E0 = 0
S0 = (1 - I0 - D0 - A0 - R0 - T0 - H0 - E0) * 60e6

# A grid of time points (in days)
t = np.linspace(0, 350, 50) 

# The SIR model differential equations.
def deriv(y, t):
    if t < 4:
      idx = 0
    elif 4 <= t < 12:
      idx = 1
    elif 12 <= t < 22:
      idx = 2
    elif 22 <= t < 28:
      idx = 3
    elif 28 <= t < 38:
      idx = 4
    elif 38 <= t < 50:
      idx = 5
    elif t >= 50:
      idx = 6

    alpha=dinn.alpha[idx]
    beta=dinn.beta[idx]
    gamma=dinn.gamma[idx]
    delta=dinn.delta[idx]
    epsilon=dinn.epsilon[idx]
    zeta=dinn.zeta[idx]
    lambdda=dinn.lambdda[idx]
    eta=dinn.eta[idx]
    rho=dinn.rho[idx]
    mu=dinn.mu[idx]
    kappa=dinn.kappa[idx]
    theta=dinn.theta[idx]
    nu=dinn.nu[idx]
    xi=dinn.xi[idx]
    sigma=dinn.sigma[idx]
    tao=dinn.tao[idx]

    I, D, A, R, T, H, E, S = y
    dSdt = -S * (alpha * I + beta * D + gamma * A + delta * R)
    dIdt = S * (alpha * I + beta * D + gamma * A + delta * R) - (epsilon + zeta + lambdda) * I
    dDdt = epsilon * I - (eta + rho) * D
    dAdt = zeta * I - (theta + mu + kappa) * A 
    dRdt = eta * D + theta * A - (nu + xi) * R
    dTdt = mu * A + nu * R - (sigma + tao) * T
    dHdt = lambdda * I + rho * D + kappa * A + xi * R + sigma * T
    dEdt = tao * T

    return dIdt, dDdt, dAdt, dRdt, dTdt, dHdt, dEdt, dSdt

# Initial conditions vector
y0 = I0, D0, A0, R0, T0, H0, E0, S0

# Integrate the SIR equations over the time grid, t.
ret = odeint(deriv, y0, t)
I, D, A, R, T, H, E, S = ret.T

# Plot the data on two separate curves for S(t), I(t)
fig = plt.figure(facecolor='w')
ax = fig.add_subplot(111, facecolor='#dddddd', axisbelow=True)

#ax.plot(t, I, 'violet', alpha=0.5, lw=2, label='Infected', linestyle='dashed')
#ax.plot(t, D, 'darkgreen', alpha=0.5, lw=2, label='Diagnosed', linestyle='dashed')
#ax.plot(t, A, 'red', alpha=0.5, lw=2, label='Ailling', linestyle='dashed')
#ax.plot(t, R, 'blue', alpha=0.5, lw=2, label='Recognized', linestyle='dashed')
#ax.plot(t, T, 'purple', alpha=0.5, lw=2, label='Threatened', linestyle='dashed')
#ax.plot(t, H, 'yellow', alpha=0.5, lw=2, label='Healed', linestyle='dashed')
#ax.plot(t, E, 'black', alpha=0.5, lw=2, label='Extinct', linestyle='dashed')
ax.plot(t, S, 'pink', alpha=0.5, lw=2, label='Susceptible', linestyle='dashed')

ax.set_xlabel('Time /days')
ax.yaxis.set_tick_params(length=0)
ax.xaxis.set_tick_params(length=0)
ax.grid(b=True, which='major', c='w', lw=2, ls='-')
legend = ax.legend()
legend.get_frame().set_alpha(0.5)
for spine in ('top', 'right', 'bottom', 'left'):
    ax.spines[spine].set_visible(False)
plt.show()