In [6]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import requests

In [20]:
# Fetches the data from the COVID19India website
class DataFetcher:
    """Download and parse the national time series from the COVID19India API.

    After `fetch()`, `cases_time_series` maps every field of the API's
    `cases_time_series` records to its values across all records:
    `'date'` to a list of the raw values (presumably date strings), and every
    other field to a 1-D float32 tensor with one entry per record.
    """

    def __init__(self, url="https://api.covid19india.org/data.json", timeout=30):
        self.url = url
        # Seconds to wait for the server; without a timeout `requests` may block forever.
        self.timeout = timeout
        self.json_data = None  # decoded JSON payload from the last fetch()
        self.cases_time_series = None  # {field: values}, filled by fetch()

    def fetch(self):
        """Download `self.url` and populate `json_data` and `cases_time_series`.

        Raises requests.HTTPError on a non-2xx response instead of trying to
        decode an error page as JSON. An empty `cases_time_series` list yields
        an empty dict rather than an IndexError.
        """
        response = requests.get(url=self.url, timeout=self.timeout)
        response.raise_for_status()
        self.json_data = response.json()
        self.cases_time_series = self._parse_records(self.json_data['cases_time_series'])

    @staticmethod
    def _parse_records(records):
        """Convert a list of record dicts into {field: values}; fields come from the first record."""
        if not records:
            return {}
        parsed = {}
        for field in records[0]:
            values = [record[field] for record in records]
            if field == 'date':
                parsed[field] = values
            else:
                # Numeric fields may arrive as strings, so coerce each value to float.
                parsed[field] = torch.tensor([float(v) for v in values], dtype=torch.float32)
        return parsed

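<h2>A time-dependent SIR model</h2>

$$  \frac{dS}{dt} = -\frac{\beta IS}{N} \tag1 $$
$$    \frac{dI}{dt} = \frac{\beta IS}{N} - \gamma I \tag2 $$
$$    \frac{dR}{dt} = \gamma I \tag3 $$

Here $S$, $I$ and $R$ are the susceptible, infected and removed populations, $N = S + I + R$ is the total population, $\beta$ is the transmission rate and $\gamma$ is the removal rate. The model is *time dependent* because $\beta$ is piecewise constant: it may take a new value at each change time (e.g. when a lockdown is imposed or lifted).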


In [21]:
class SIR(nn.Module):
    """SIR model integrated with forward Euler steps and a piecewise-constant beta.

    Populations are fractions of the total population (N = 1), so the
    equations become dS/dt = -beta*I*S, dI/dt = beta*I*S - gamma*I and
    dR/dt = gamma*I.
    """

    def __init__(self, change_times, gamma, delta_t=0.1, initial_infected=0.0, beta_init=0.0):
        """
            change_times: increasing list of times (in days) at which beta could have
                          changed (eg. when lockdown lifted/put in place). One beta
                          parameter is created per entry; beta[k] is used from
                          change_times[k] until change_times[k + 1]. beta[0] is used
                          from t = 0 regardless of change_times[0].
            gamma: initial value of the removal rate (rate at which infected move to R).
            delta_t: Euler step size in days, 0 < delta_t <= 1. One simulated day is
                     round(1 / delta_t) steps, so 1 / delta_t should be an integer.
            initial_infected: infected fraction at t = 0 (S starts at 1 - initial_infected).
                              With 0 the solution stays at I = R = 0.
            beta_init: initial value of every beta parameter.

            Raises ValueError if change_times is empty or delta_t is out of range.
        """

        super(SIR, self).__init__()

        if len(change_times) == 0:
            raise ValueError("change_times must contain at least one entry")
        if not 0 < delta_t <= 1:
            raise ValueError("delta_t must satisfy 0 < delta_t <= 1")

        self.change_times = change_times
        self.delta_t = delta_t
        self.initial_infected = initial_infected

        self.beta = nn.ParameterList(
            [nn.Parameter(torch.ones(1) * float(beta_init)) for _ in change_times]
        )
        self.gamma = nn.Parameter(torch.ones(1) * float(gamma))

    def forward(self, n):
        """
            Get predictions for n days.

            Returns (I, R): two lists of n + 1 one-element tensors holding the
            infected and removed fractions at days 0, 1, ..., n.
        """

        dt = self.delta_t
        steps_per_day = round(1 / dt)
        # Work in integer step indices so beta switches at the right step instead
        # of depending on floating-point drift in an accumulated time variable.
        change_steps = [round(ct / dt) for ct in self.change_times]

        S = torch.ones(1) * (1.0 - self.initial_infected)
        I = [torch.ones(1) * self.initial_infected]
        R = [torch.zeros(1)]
        curr_beta_idx = 0  # index of the beta in effect for the current step

        for step in range(n * steps_per_day):
            # Advance past every change time reached by the start of this step.
            while (curr_beta_idx + 1 < len(change_steps)
                   and change_steps[curr_beta_idx + 1] <= step):
                curr_beta_idx += 1

            # Compute rates from the current state, then take an Euler step of size dt.
            infections = self.beta[curr_beta_idx] * I[-1] * S
            removals = self.gamma * I[-1]

            S = S - infections * dt
            I.append(I[-1] + (infections - removals) * dt)
            R.append(R[-1] + removals * dt)

        # Report values once per day (every steps_per_day Euler steps).
        return I[::steps_per_day], R[::steps_per_day]

In [25]:
fetcher = DataFetcher()
fetcher.fetch()
# Display the available fields; this is the output saved below the cell.
fetcher.cases_time_series.keys()

dict_keys(['dailyconfirmed', 'dailydeceased', 'dailyrecovered', 'date', 'totalconfirmed', 'totaldeceased', 'totalrecovered'])