In [1]:
#Mount my drive- run the code, go to the link, accept.
from google.colab import drive
drive.mount('/content/gdrive')

#Change working directory to make it easier to access the files
import os
os.chdir("/content/gdrive/My Drive/Colab Notebooks/dinn")
os.getcwd() 

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


'/content/gdrive/My Drive/Colab Notebooks/dinn'

In [2]:
import torch
from torch.autograd import grad
import torch.nn as nn
from numpy import genfromtxt
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F

hiv_data = genfromtxt('hiv.csv', delimiter=',') #in the form of [t, T, I, V]

torch.manual_seed(1234)

<torch._C.Generator at 0x7f46df7bbb10>

In [3]:
%%time

PATH = 'hiv' 

class DINN(nn.Module):
    def __init__(self, t, T_data, I_data, V_data): 
        super(DINN, self).__init__()
        self.t = torch.tensor(t, requires_grad=True)
        self.t_float = self.t.float()
        self.t_batch = torch.reshape(self.t_float, (len(self.t),1)) #reshape for batch 
        self.T = torch.tensor(T_data) 
        self.I = torch.tensor(I_data) 
        self.V = torch.tensor(V_data) 

        self.losses = [] #keep the losses
        self.save = 3 #which file to save to
 
        #learnable parameters
        self.s_tilda = torch.nn.Parameter(torch.rand(1, requires_grad=True))
        self.mu_T_tilda = torch.nn.Parameter(torch.rand(1, requires_grad=True))
        self.mu_I_tilda = torch.nn.Parameter(torch.rand(1, requires_grad=True))
        self.mu_b_tilda = torch.nn.Parameter(torch.rand(1, requires_grad=True))
        self.mu_V_tilda = torch.nn.Parameter(torch.rand(1, requires_grad=True))
        self.r_tilda = torch.nn.Parameter(torch.rand(1, requires_grad=True))
        self.N_tilda = torch.nn.Parameter(torch.rand(1, requires_grad=True))
        self.T_max_param_tilda = torch.nn.Parameter(torch.rand(1, requires_grad=True))
        self.k1_tilda = torch.nn.Parameter(torch.rand(1, requires_grad=True))
        self.k1_prime_tilda = torch.nn.Parameter(torch.rand(1, requires_grad=True))

        #matrices (x3 for T,I,V) for the gradients
        self.m1 = torch.zeros((len(self.t), 3)); self.m1[:, 0] = 1
        self.m2 = torch.zeros((len(self.t), 3)); self.m2[:, 1] = 1
        self.m3 = torch.zeros((len(self.t), 3)); self.m3[:, 2] = 1

        #values for norm
        self.T_max = max(self.T)
        self.I_max = max(self.I)
        self.V_max = max(self.V)        
        self.T_min = min(self.T)
        self.I_min = min(self.I)
        self.V_min = min(self.V)    

        #normalize 
        self.T_hat = (self.T - self.T_min) / (self.T_max - self.T_min)
        self.I_hat = (self.I - self.I_min) / (self.I_max - self.I_min)
        self.V_hat = (self.V - self.V_min) / (self.V_max - self.V_min)

        #NN
        self.net_hiv = self.Net_hiv()
        self.params = list(self.net_hiv.parameters())
        self.params.extend(list([self.s_tilda, self.mu_T_tilda, self.mu_I_tilda, self.mu_b_tilda, self.mu_V_tilda, self.r_tilda, self.N_tilda, self.T_max_param_tilda, self.k1_tilda, self.k1_prime_tilda]))

        
    #force parameters to be in a range
    @property
    def s(self):
        return torch.tanh(self.s_tilda) + 10

    @property
    def mu_T(self):
        return torch.tanh(self.mu_T_tilda) * 0.01 + 0.02

    @property
    def mu_I(self):
        return torch.tanh(self.mu_I_tilda) 

    @property
    def mu_b(self):
        return torch.tanh(self.mu_b_tilda) 

    @property
    def mu_V(self):
        return torch.tanh(self.mu_V_tilda) * 3

    @property
    def r(self):
        return torch.tanh(self.r_tilda) * 0.1 

    @property
    def N(self):
        return torch.tanh(self.N_tilda) + 250

    @property
    def T_max_param(self):
        return torch.tanh(self.T_max_param_tilda) + 1500

    @property
    def k1(self):
        return torch.tanh(self.k1_tilda) * 3*10e-5

    @property
    def k1_prime(self):
        return torch.tanh(self.k1_prime_tilda) * 0.001
    
    #nets
    class Net_hiv(nn.Module): # input = [t]
        def __init__(self):
            super(DINN.Net_hiv, self).__init__()
            self.fc1=nn.Linear(1, 20) #takes t's
            self.fc2=nn.Linear(20, 20)
            self.fc3=nn.Linear(20, 20)
            self.fc4=nn.Linear(20, 20)
            self.fc5=nn.Linear(20, 20)
            self.fc6=nn.Linear(20, 20)
            self.fc7=nn.Linear(20, 20)
            self.fc8=nn.Linear(20, 20)
            self.out=nn.Linear(20, 3) #outputs T, I, V

        def forward(self, t):
            tiv=F.relu(self.fc1(t))
            tiv=F.relu(self.fc2(tiv))
            tiv=F.relu(self.fc3(tiv))
            tiv=F.relu(self.fc4(tiv))
            tiv=F.relu(self.fc5(tiv))
            tiv=F.relu(self.fc6(tiv))
            tiv=F.relu(self.fc7(tiv))
            tiv=F.relu(self.fc8(tiv))
            tiv=self.out(tiv)
            return tiv    

    def net_f(self, t_batch):       
        tiv_hat = self.net_hiv(t_batch)

        T_hat, I_hat, V_hat = tiv_hat[:,0], tiv_hat[:,1], tiv_hat[:,2]

        #T_t
        tiv_hat.backward(self.m1, retain_graph=True)
        T_hat_t = self.t.grad.clone()
        self.t.grad.zero_()

        #I_t
        tiv_hat.backward(self.m2, retain_graph=True)
        I_hat_t = self.t.grad.clone()
        self.t.grad.zero_()
        
        #V_t
        tiv_hat.backward(self.m3, retain_graph=True)
        V_hat_t = self.t.grad.clone()
        self.t.grad.zero_()
        
        #unnormalize
        T = self.T_min + (self.T_max - self.T_min) * T_hat
        I = self.I_min + (self.I_max - self.I_min) * I_hat
        V = self.V_min + (self.V_max - self.V_min) * V_hat

        f1_hat = T_hat_t - (self.s - self.mu_T * T + self.r * T * (1 - ((T + I) / self.T_max_param) - self.k1 * V * T)) / (self.T_max_param - self.T_min) 
        f2_hat = I_hat_t - (self.k1_prime * V * T - self.mu_I * I) / (self.I_max - self.I_min)         
        f3_hat = V_hat_t - (self.N * self.mu_b * I - self.k1 * V * T - self.mu_V * V) / (self.V_max - self.V_min)         

        return f1_hat, f2_hat, f3_hat, T_hat, I_hat, V_hat
    
    def load(self):
      # Load checkpoint
      try:
        checkpoint = torch.load(PATH + str(self.save)+'.pt') 
        print('\nloading pre-trained model...')
        self.load_state_dict(checkpoint['model'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler'])
        epoch = checkpoint['epoch']
        loss = checkpoint['loss']
        self.losses = checkpoint['losses']
        print('loaded previous loss: ', loss)
      except RuntimeError :
          print('changed the architecture, ignore')
          pass
      except FileNotFoundError:
          pass

    def train(self, n_epochs):
      #try loading
      self.load()

      #train
      print('\nstarting training...\n')
      
      for epoch in range(n_epochs):
        #lists to hold the output (maintain only the final epoch)
        T_pred_list= []
        I_pred_list= []
        V_pred_list= []


        f1, f2, f3, T_pred, I_pred, V_pred = self.net_f(self.t_batch)
        self.optimizer.zero_grad()

        T_pred_list.append(self.T_min + (self.T_max - self.T_min) * T_pred) 
        I_pred_list.append(self.I_min + (self.I_max - self.I_min) * I_pred)
        V_pred_list.append(self.V_min + (self.V_max - self.V_min) * V_pred)

        loss = (torch.mean(torch.square(self.T_hat - T_pred)) + torch.mean(torch.square(self.I_hat - I_pred)) + torch.mean(torch.square(self.V_hat - V_pred)) + 
               torch.mean(torch.square(f1)) + torch.mean(torch.square(f2)) + torch.mean(torch.square(f3)))

        loss.backward()

        self.optimizer.step()
        #self.scheduler.step() 
        self.scheduler.step(loss) 

        self.losses.append(loss.item())

        if epoch % 1000 == 0:          
          print('\nEpoch ', epoch)

        #loss + model parameters update
        if epoch % 40000 == 0:
          #checkpoint save every 1000 epochs if the loss is lower
          print('\nSaving model... Loss is: ', loss)
          torch.save({
              'epoch': epoch,
              'model': self.state_dict(),
              'optimizer_state_dict': self.optimizer.state_dict(),
              'scheduler': self.scheduler.state_dict(),
              'loss': loss,
              'losses': self.losses,
              }, PATH + str(self.save)+'.pt')
          if self.save % 2 > 0: #its on 3
            self.save = 2 #change to 2
          else: #its on 2
            self.save = 3 #change to 3

          print('epoch: ', epoch)
          print('s: (goal 10)', self.s)
          print('\nmu_T: (goal 0.02)', self.mu_T)
          print('\nmu_I: (goal 0.26): ', self.mu_I)
          print('\nmu_b (goal 0.24): ', self.mu_b)
          print('\nmu_V: (goal 2.4): ', self.mu_V)
          print('\nr (goal 0.03): ', self.r)
          print('\nN (goal 250): ', self.N)
          print('\nT_max (goal 1500): ', self.T_max_param)
          print('\nk1 (goal 2.4*10e-5): ', self.k1)
          print('\nk1_prime (goal 2*10e-5): ', self.k1_prime)
          print('#################################')
        
      #plot
      plt.plot(self.losses, color = 'teal')
      plt.xlabel('Epochs')
      plt.ylabel('Loss')
      return T_pred_list, I_pred_list, V_pred_list

CPU times: user 49 µs, sys: 9 µs, total: 58 µs
Wall time: 61.5 µs


In [None]:
%%time

#this worked best
dinn = DINN(hiv_data[0], hiv_data[1], hiv_data[2], hiv_data[3]) #t, T_data, I_data, V_data

learning_rate = 1e-5
optimizer = optim.Adam(dinn.params, lr = learning_rate)
dinn.optimizer = optimizer

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(dinn.optimizer, factor=0.99, patience = 5000, verbose=True)
#scheduler = torch.optim.lr_scheduler.CyclicLR(dinn.optimizer, base_lr=1e-7, max_lr=1e-5, step_size_up=20000, mode="triangular2", cycle_momentum=False)

dinn.scheduler = scheduler

try: 
  T_pred_list, I_pred_list, V_pred_list = dinn.train(8000000) #train
except EOFError:
  if dinn.save == 2:
    dinn.save = 3
    T_pred_list, I_pred_list, V_pred_list = dinn.train(8000000) #train
  elif dinn.save == 3:
    dinn.save = 2
    T_pred_list, I_pred_list, V_pred_list = dinn.train(8000000) #train

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Epoch  1907000

Epoch  1908000

Epoch  1909000

Epoch  1910000

Epoch  1911000

Epoch  1912000

Epoch  1913000

Epoch  1914000

Epoch  1915000

Epoch  1916000

Epoch  1917000

Epoch  1918000

Epoch  1919000

Epoch  1920000

Saving model... Loss is:  tensor(0.0017, dtype=torch.float64, grad_fn=<AddBackward0>)
epoch:  1920000
s: (goal 10) tensor([10.6504], grad_fn=<AddBackward0>)

mu_T: (goal 0.02) tensor([0.0100], grad_fn=<AddBackward0>)

mu_I: (goal 0.26):  tensor([0.2468], grad_fn=<TanhBackward>)

mu_b (goal 0.24):  tensor([0.2336], grad_fn=<TanhBackward>)

mu_V: (goal 2.4):  tensor([2.3459], grad_fn=<MulBackward0>)

r (goal 0.03):  tensor([0.0784], grad_fn=<MulBackward0>)

N (goal 250):  tensor([249.9376], grad_fn=<AddBackward0>)

T_max (goal 1500):  tensor([1500.9678], grad_fn=<AddBackward0>)

k1 (goal 2.4*10e-5):  tensor([0.0001], grad_fn=<MulBackward0>)

k1_prime (goal 2*10e-5):  tensor([0.0002], grad_fn=<MulBackward

In [None]:
plt.plot(dinn.losses[5000000:], color = 'teal')
plt.xlabel('Epochs')
plt.ylabel('Loss')

In [None]:
fig = plt.figure(facecolor='w')
ax = fig.add_subplot(111, facecolor='#dddddd', axisbelow=True)

ax.plot(hiv_data[0], hiv_data[1], 'pink', alpha=0.5, lw=2, label='T')
ax.plot(hiv_data[0], T_pred_list[0].detach().numpy(), 'navy', alpha=0.9, lw=2, label='T Prediction', linestyle='dashed')

ax.plot(hiv_data[0], hiv_data[2], 'violet', alpha=0.5, lw=2, label='I')
ax.plot(hiv_data[0], I_pred_list[0].detach().numpy(), 'dodgerblue', alpha=0.9, lw=2, label='I Prediction', linestyle='dashed')

ax.plot(hiv_data[0], hiv_data[3], 'darkgreen', alpha=0.5, lw=2, label='V')
ax.plot(hiv_data[0], V_pred_list[0].detach().numpy(), 'gold', alpha=0.9, lw=2, label='V Prediction', linestyle='dashed')


ax.set_xlabel('Time /days')
ax.set_ylabel('Number')
ax.yaxis.set_tick_params(length=0)
ax.xaxis.set_tick_params(length=0)
ax.grid(b=True, which='major', c='w', lw=2, ls='-')
legend = ax.legend()
legend.get_frame().set_alpha(0.5)
for spine in ('top', 'right', 'bottom', 'left'):
    ax.spines[spine].set_visible(False)
plt.show()

In [None]:
print('alpha1: (goal 1)', round(dinn.alpha1.item(),2))
print('\nalpha2: (goal 0)', round(dinn.alpha2.item(),2))
print('\nbeta: (goal 0.0075): ', round(dinn.beta.item(),4))
print('\nmu (goal 5): ', round(dinn.mu.item(),2))
print('\nu: (goal 0.515151515): ', round(dinn.u.item(),2))
print('\ntao (goal 0.58): ', round(dinn.tao.item(),2))


print('\nerror:')
print('alpha1: ', round((1-round(dinn.alpha1.item(),2))/1,2)*100,'%')
print('alpha2: ', round((0-round(dinn.alpha2.item(),2))/1e-20,2)*100,'%')
print('beta: ', round((0.0075-round(dinn.beta.item(),4))/0.0075,2)*100,'%')
print('mu: ', round((5-round(dinn.mu.item(),2))/5,2)*100,'%')
print('u: ', round((0.515151515-round(dinn.u.item(),2))/0.515151515,2)*100,'%')
print('tao: ', round((0.58-round(dinn.tao.item(),2))/0.58,2)*100,'%')

In [None]:
#vaccination! 

import numpy as np
from scipy.integrate import odeint
import matplotlib.pyplot as plt

# Initial conditions
T0 = 1000
I0 = 10
V0 = 10e-3

# A grid of time points (in days)
t = np.linspace(0, 20, 50) 

s = dinn.s
mu_T = dinn.mu_T
mu_I = dinn.mu_I
mu_b = dinn.mu_b
mu_V = dinn.mu_V
r = dinn.r
N = dinn.N
T_max = dinn.T_max_param
k1 = dinn.k1
k1_prime = dinn.k1_prime


# The SIR model differential equations.
def deriv(y, t, s, mu_T, mu_V, mu_b, r, N, T_max, k1, k1_prime):
    T, I, V = y
    dTdt = s - mu_T * T + r * T * (1 - ((T + I) / T_max) - k1 * V * T)
    dIdt = k1_prime * V * T - mu_I * I
    dVdt = N * mu_b * I - k1 * V * T - mu_V * V

    return dTdt, dIdt, dVdt


# Initial conditions vector
y0 = T0, I0, V0
# Integrate the SIR equations over the time grid, t.
ret = odeint(deriv, y0, t, args=(s, mu_T, mu_V, mu_b, r, N, T_max, k1, k1_prime))
T, I, V = ret.T

# Plot the data on two separate curves for S(t), I(t)
fig = plt.figure(facecolor='w')
ax = fig.add_subplot(111, facecolor='#dddddd', axisbelow=True)
ax.plot(t, T, 'violet', alpha=0.5, lw=2, label='T', linestyle='dashed')
ax.plot(t, I, 'darkgreen', alpha=0.5, lw=2, label='I', linestyle='dashed')
ax.plot(t, V, 'blue', alpha=0.5, lw=2, label='V', linestyle='dashed')
ax.set_xlabel('Time /days')
ax.yaxis.set_tick_params(length=0)
ax.xaxis.set_tick_params(length=0)
ax.grid(b=True, which='major', c='w', lw=2, ls='-')
legend = ax.legend()
legend.get_frame().set_alpha(0.5)
for spine in ('top', 'right', 'bottom', 'left'):
    ax.spines[spine].set_visible(False)
plt.show()

In [None]:
#calculate relative MSE loss
import math

T_total_loss = 0
T_den = 0
I_total_loss = 0
I_den = 0
V_total_loss = 0
V_den = 0

for timestep in range(len(t)):
  T_value = hiv_data[1][timestep] - T[timestep]
  T_total_loss += T_value**2
  T_den += (hiv_data[1][timestep])**2
  I_value = hiv_data[2][timestep] - I[timestep]
  I_total_loss += I_value**2
  I_den += (hiv_data[2][timestep])**2
  V_value = hiv_data[3][timestep] - V[timestep]
  V_total_loss += V_value**2
  V_den += (hiv_data[3][timestep])**2

T_total_loss = math.sqrt(T_total_loss/T_den)
I_total_loss = math.sqrt(I_total_loss/I_den)
V_total_loss = math.sqrt(V_total_loss/V_den)

print('T_total_loss: ', T_total_loss)
print('I_total_loss: ', I_total_loss)
print('V_total_loss: ', V_total_loss)