In [1]:
#Mount my drive- run the code, go to the link, accept.
from google.colab import drive
drive.mount('/content/gdrive')

#Change working directory to make it easier to access the files
import os
os.chdir("/content/gdrive/My Drive/Colab Notebooks/dinn")
os.getcwd() 

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


'/content/gdrive/My Drive/Colab Notebooks/dinn'

In [2]:
import torch
from torch.autograd import grad
import torch.nn as nn
from numpy import genfromtxt
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F

sidr_data = genfromtxt('sidr.csv', delimiter=',') #in the form of [t,S,I,D,R]

torch.manual_seed(1234)

<torch._C.Generator at 0x7f35a0d0c390>

In [3]:
%%time

PATH = 'sidr' 

class DINN(nn.Module):
    def __init__(self, t, S_data, I_data, D_data, R_data): #[t,S,I,D,R]
        super(DINN, self).__init__()
        self.N = 100 #population size
        self.t = torch.tensor(t, requires_grad=True)
        self.t_float = self.t.float()
        self.t_batch = torch.reshape(self.t_float, (len(self.t),1)) #reshape for batch 
        self.S = torch.tensor(S_data)
        self.I = torch.tensor(I_data)
        self.D = torch.tensor(D_data)
        self.R = torch.tensor(R_data)

        self.losses = []
        self.save = 3 #which file to save to

        self.alpha_tilda = torch.tensor(0.191)#torch.nn.Parameter(torch.rand(1, requires_grad=True))
        self.beta_tilda = torch.tensor(0.05) #torch.nn.Parameter(torch.rand(1, requires_grad=True))
        self.gamma_tilda = torch.tensor (0.0294) #torch.nn.Parameter(torch.rand(1, requires_grad=True))

        #matrices (x4 for S,I,D,R) for the gradients
        self.m1 = torch.zeros((len(self.t), 4)); self.m1[:, 0] = 1
        self.m2 = torch.zeros((len(self.t), 4)); self.m2[:, 1] = 1
        self.m3 = torch.zeros((len(self.t), 4)); self.m3[:, 2] = 1
        self.m4 = torch.zeros((len(self.t), 4)); self.m4[:, 3] = 1

        #NN
        self.net_sidr = self.Net_sidr()
        self.params = list(self.net_sidr.parameters())
        self.params.extend(list([self.alpha_tilda, self.beta_tilda, self.gamma_tilda]))

    #force parameters to be in a range
    @property
    def alpha(self):
        return torch.tanh(self.alpha_tilda) #* 0.1 + 0.2

    @property
    def beta(self):
        return torch.tanh(self.beta_tilda) #* 0.01 + 0.05
    
    @property
    def gamma(self):
        return torch.tanh(self.gamma_tilda) #* 0.01 + 0.3


    #nets
    class Net_sidr(nn.Module): # input = [t]
        def __init__(self):
            super(DINN.Net_sidr, self).__init__()
            self.fc1=nn.Linear(1, 32) #takes 100 t's
            self.fc2=nn.Linear(32, 32)
            self.fc3=nn.Linear(32, 64)
            self.fc4=nn.Linear(64, 128)
            self.fc5=nn.Linear(128, 128)
            self.fc6=nn.Linear(128, 64)
            self.fc7=nn.Linear(64, 32)
            self.fc8=nn.Linear(32, 32)
            self.out=nn.Linear(32, 4) #outputs S, I, D, R


        def forward(self, t_batch):
            sidr=F.relu(self.fc1(t_batch))
            sidr=F.relu(self.fc2(sidr))
            sidr=F.relu(self.fc3(sidr))
            sidr=F.relu(self.fc4(sidr))
            sidr=F.relu(self.fc5(sidr))
            sidr=F.relu(self.fc6(sidr))
            sidr=F.relu(self.fc7(sidr))
            sidr=F.relu(self.fc8(sidr))
            sidr=self.out(sidr)
            return sidr
            

    def net_f(self, t_batch):
        sidr = self.net_sidr(t_batch)

        S,I,D,R = sidr[:,0], sidr[:,1], sidr[:,2], sidr[:,3]

        #S_t
        sidr.backward(self.m1, retain_graph=True)
        S_t = self.t.grad
        self.t.grad.zero_()

        #I_t
        sidr.backward(self.m2, retain_graph=True)
        I_t = self.t.grad
        self.t.grad.zero_()

        #D_t
        sidr.backward(self.m3, retain_graph=True)
        D_t = self.t.grad
        self.t.grad.zero_()

        #R_t
        sidr.backward(self.m4, retain_graph=True)
        R_t = self.t.grad
        self.t.grad.zero_()

        f1 = S_t + (self.alpha / self.N) * S * I
        f2 = I_t - (self.alpha / self.N) * S * I + self.beta * I + self.gamma * I 
        f3 = D_t - self.gamma * I
        f4 = R_t - self.beta * I 

        return f1, f2, f3, f4, S, I, D, R
    
    def load(self):
      # Load checkpoint
      try:
        checkpoint = torch.load(PATH + str(self.save)+'.pt') 
        print('\nloading pre-trained model...')
        self.load_state_dict(checkpoint['model'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler'])
        epoch = checkpoint['epoch']
        loss = checkpoint['loss']
        self.losses = checkpoint['losses']
        print('loaded previous loss: ', loss)

      except RuntimeError :
          print('changed the architecture, ignore')
          pass
      except FileNotFoundError:
          pass

    def train(self, n_epochs):
      #try loading
      self.load()

      #train
      print('\nstarting training...\n')
      
      for epoch in range(n_epochs):
        #lists to hold the output (maintain only the final epoch)
        S_pred_list = []
        I_pred_list = []
        D_pred_list = []
        R_pred_list = []

        f1, f2, f3, f4, S_pred, I_pred, D_pred, R_pred = self.net_f(self.t_batch)
        self.optimizer.zero_grad()

        S_pred_list.append(S_pred)
        I_pred_list.append(I_pred)
        D_pred_list.append(D_pred)
        R_pred_list.append(R_pred)

        loss = (torch.mean(torch.square(self.S - S_pred))+ 
                torch.mean(torch.square(self.I - I_pred))+
                torch.mean(torch.square(self.D - D_pred))+
                torch.mean(torch.square(self.R - R_pred))+
                torch.mean(torch.square(f1))+
                torch.mean(torch.square(f2))+
                torch.mean(torch.square(f3))+
                torch.mean(torch.square(f4))
                ) 
        
        loss.backward()
        self.optimizer.step()
        self.scheduler.step() #scheduler
        #self.scheduler.step(loss)

        self.losses.append(loss)

        #loss + model parameters update
        if epoch % 1000 == 0:
          #checkpoint save every 1000 epochs if the loss is lower
          print('\nSaving model... Loss is: ', loss)
          torch.save({
              'epoch': epoch,
              'model': self.state_dict(),
              'optimizer_state_dict': self.optimizer.state_dict(),
              'scheduler': self.scheduler.state_dict(),
              'loss': loss,
              'losses': self.losses,
              }, PATH + str(self.save)+'.pt')
          if self.save % 2 > 0: #its on 3
            self.save = 2 #change to 2
          else: #its on 2
            self.save = 3 #change to 3

          print('epoch: ', epoch)
          print('alpha: (goal 0.191 ', self.alpha)
          print('beta: (goal 0.05 ', self.beta)
          print('gamma: (goal 0.0294 ', self.gamma)

          # print('#################################')                

        
      #plot
      plt.plot(self.losses, color = 'teal')
      plt.xlabel('Epochs')
      plt.ylabel('Loss')
      return S_pred_list, I_pred_list, D_pred_list, R_pred_list

CPU times: user 52 µs, sys: 0 ns, total: 52 µs
Wall time: 56.5 µs


In [None]:
%%time

#this worked best
dinn = DINN(sidr_data[0], sidr_data[1], sidr_data[2], sidr_data[3], 
            sidr_data[4]) #in the form of [t,S,I,D,R]

learning_rate = 2e-2
optimizer = optim.Adam(dinn.params, lr = learning_rate)
dinn.optimizer = optimizer

#scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(dinn.optimizer, factor=0.8, patience = 500, verbose=True)
scheduler = torch.optim.lr_scheduler.CyclicLR(dinn.optimizer, base_lr=1e-6, max_lr=5e-4, step_size_up=3000, mode="triangular2", cycle_momentum=False)

dinn.scheduler = scheduler

S_pred_list, I_pred_list, D_pred_list, R_pred_list = dinn.train(100000) #train


loading pre-trained model...
loaded previous loss:  tensor(7.9855e+14, dtype=torch.float64, requires_grad=True)

starting training...


Saving model... Loss is:  tensor(7.9853e+14, dtype=torch.float64, grad_fn=<AddBackward0>)
epoch:  0
alpha: (goal 0.191  tensor(0.1887)
beta: (goal 0.05  tensor(0.0500)
gamma: (goal 0.0294  tensor(0.0294)

Saving model... Loss is:  tensor(7.9855e+14, dtype=torch.float64, grad_fn=<AddBackward0>)
epoch:  1000
alpha: (goal 0.191  tensor(0.1887)
beta: (goal 0.05  tensor(0.0500)
gamma: (goal 0.0294  tensor(0.0294)

Saving model... Loss is:  tensor(7.9854e+14, dtype=torch.float64, grad_fn=<AddBackward0>)
epoch:  2000
alpha: (goal 0.191  tensor(0.1887)
beta: (goal 0.05  tensor(0.0500)
gamma: (goal 0.0294  tensor(0.0294)

Saving model... Loss is:  tensor(7.9859e+14, dtype=torch.float64, grad_fn=<AddBackward0>)
epoch:  3000
alpha: (goal 0.191  tensor(0.1887)
beta: (goal 0.05  tensor(0.0500)
gamma: (goal 0.0294  tensor(0.0294)

Saving model... Loss is:  tensor(7

In [None]:
plt.plot(dinn.losses[150000:], color = 'teal')
plt.xlabel('Epochs')
plt.ylabel('Loss')

In [None]:
fig = plt.figure(facecolor='w', figsize=(12,12))
ax = fig.add_subplot(111, facecolor='#dddddd', axisbelow=True)

ax.plot(sidr_data[0], sidr_data[1], 'black', alpha=0.5, lw=2, label='Susceptible')
ax.plot(sidr_data[0], S_pred_list[0].detach().numpy(), 'red', alpha=0.9, lw=2, label='Susceptible Prediction', linestyle='dashed')

ax.plot(sidr_data[0], sidr_data[2], 'violet', alpha=0.5, lw=2, label='Infected')
ax.plot(sidr_data[0], I_pred_list[0].detach().numpy(), 'dodgerblue', alpha=0.9, lw=2, label='Infected Prediction', linestyle='dashed')

ax.plot(sidr_data[0], sidr_data[3], 'darkgreen', alpha=0.5, lw=2, label='Dead')
ax.plot(sidr_data[0], D_pred_list[0].detach().numpy(), 'green', alpha=0.9, lw=2, label='Dead Prediction', linestyle='dashed')

ax.plot(sidr_data[0], sidr_data[4], 'blue', alpha=0.5, lw=2, label='Recovered')
ax.plot(sidr_data[0], R_pred_list[0].detach().numpy(), 'teal', alpha=0.9, lw=2, label='Recovered Prediction', linestyle='dashed')


ax.set_xlabel('Time /days')
ax.set_ylabel('Number')
#ax.set_ylim([-1,50])
ax.yaxis.set_tick_params(length=0)
ax.xaxis.set_tick_params(length=0)
ax.grid(b=True, which='major', c='w', lw=2, ls='-')
legend = ax.legend()
legend.get_frame().set_alpha(0.5)
for spine in ('top', 'right', 'bottom', 'left'):
    ax.spines[spine].set_visible(False)
plt.show()

In [None]:
print(torch.tensor(sidr_data[1]) - S_pred_list[0])
print(torch.tensor(sidr_data[2]) - I_pred_list[0])
print(torch.tensor(sidr_data[3]) - D_pred_list[0])
print(torch.tensor(sidr_data[4]) - R_pred_list[0])

In [None]:
import numpy as np
from scipy.integrate import odeint
import matplotlib.pyplot as plt

# Initial conditions
N = 100

S0 = N - 1
I0 = 1
D0 = 0
R0 = 0
# A grid of time points (in days)
t = np.linspace(0, 500, 100) 

#parameters
alpha = dinn.alpha
beta = dinn.beta
gamma = dinn.gamma

# The SIR model differential equations.
def deriv(y, t, alpha, betta, gamma):
    S, I, D, R = y
    dSdt = - (alpha / N) * S * I
    dIdt = (alpha / N) * S * I - beta * I - gamma * I 
    dDdt = gamma * I
    dRdt = beta * I

    return dSdt, dIdt, dDdt, dRdt


# Initial conditions vector
y0 = S0, I0, D0, R0
# Integrate the SIR equations over the time grid, t.
ret = odeint(deriv, y0, t, args=(alpha, beta, gamma))
S, I, D, R = ret.T

# Plot the data on two separate curves for S(t), I(t)
fig = plt.figure(facecolor='w', figsize=(12,12))
ax = fig.add_subplot(111, facecolor='#dddddd', axisbelow=True)

ax.plot(t, S, 'violet', alpha=0.5, lw=2, label='Learnable Param Susceptible', linestyle='dashed')
ax.plot(t, sidr_data[1], 'dodgerblue', alpha=0.5, lw=2, label='Susceptible')

ax.plot(t, I, 'darkgreen', alpha=0.5, lw=2, label='Learnable Param Infected', linestyle='dashed')
ax.plot(t, sidr_data[2], 'gold', alpha=0.5, lw=2, label='Susceptible')

ax.plot(t, D, 'red', alpha=0.5, lw=2, label='Learnable Param Dead', linestyle='dashed')
ax.plot(t, sidr_data[3], 'salmon', alpha=0.5, lw=2, label='Dead')

ax.plot(t, R, 'blue', alpha=0.5, lw=2, label='Learnable Param Recovered', linestyle='dashed')
ax.plot(t, sidr_data[4], 'black', alpha=0.5, lw=2, label='Recovered')

ax.set_xlabel('Time /days')
ax.yaxis.set_tick_params(length=0)
ax.xaxis.set_tick_params(length=0)
ax.grid(b=True, which='major', c='w', lw=2, ls='-')
legend = ax.legend()
legend.get_frame().set_alpha(0.5)
for spine in ('top', 'right', 'bottom', 'left'):
    ax.spines[spine].set_visible(False)
plt.show()

In [None]:
#calculate relative MSE loss
import math

S_total_loss = 0
S_den = 0
I_total_loss = 0
I_den = 0
D_total_loss = 0
D_den = 0
R_total_loss = 0
R_den = 0

for timestep in range(len(t)):
  S_value = sidr_data[1][timestep] - S[timestep]
  S_total_loss += S_value**2
  S_den += (sidr_data[1][timestep])**2
  I_value = sidr_data[2][timestep] - I[timestep]
  I_total_loss += I_value**2
  I_den += (sidr_data[2][timestep])**2
  D_value = sidr_data[3][timestep] - D[timestep]
  D_total_loss += D_value**2
  D_den += (sidr_data[3][timestep])**2
  R_value = sidr_data[4][timestep] - R[timestep]
  R_total_loss += R_value**2
  R_den += (sidr_data[4][timestep])**2

S_total_loss = math.sqrt(S_total_loss/S_den)
I_total_loss = math.sqrt(I_total_loss/I_den)
D_total_loss = math.sqrt(D_total_loss/D_den)
R_total_loss = math.sqrt(R_total_loss/R_den)

print('S_total_loss: ', S_total_loss)
print('I_total_loss: ', I_total_loss)
print('D_total_loss: ', D_total_loss)
print('R_total_loss: ', R_total_loss)