# Discrepancy Modeling for Epidemiology

This part of the homework focuses on discrepancy modeling: correcting an ODE-based model that only partially describes real-world COVID-19 dynamics at the population level. You will implement a simplified version of the APHYNITY framework we studied in class and compare different augmentation models. We will model the spread of COVID-19 in Washtenaw County, MI, for Winter 2023.

The following libraries will be necessary for your implementation. Ensure you have them installed before proceeding.

In [None]:
from torch import optim
import pandas as pd
import math
import torch
import torch.nn as nn
import numpy as np
from torch.nn.parameter import Parameter
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import scipy.integrate # SciPy module for integrating ODEs
solve_ivp = scipy.integrate.solve_ivp
from torchdiffeq import odeint  # Solver for ODEs integrated with PyTorch

## Part 1: Discrepancy Modeling with Vanilla Neural ODE

The SEIR (Susceptible - Exposed - Infectious - Recovered) model is a compartmental model used in epidemiology to describe the progression of infectious diseases within a population. It divides the population into four compartments, each representing a stage of the disease. The transitions between compartments are governed by differential equations, which depend on parameters such as infection rates and recovery rates.

$$
\begin{aligned}
\frac{dS}{dt} & = -\frac{\beta SI}{N}\\
\frac{dE}{dt} & = \frac{\beta SI}{N} - \sigma E\\
\frac{dI}{dt} & = \sigma E - \gamma I\\
\frac{dR}{dt} & = \gamma I
\end{aligned}
$$
where $N = S + E + I + R$ is the total population.
<div align="center">
   <img src="seir.png" width="900">
</div>
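
A useful property follows directly from these equations: when the four right-hand sides are added, every term cancels, so the total population stays constant over time. The short derivation below uses only the equations above.

$$
\begin{aligned}
\frac{dN}{dt} &= \frac{dS}{dt} + \frac{dE}{dt} + \frac{dI}{dt} + \frac{dR}{dt} \\
&= -\frac{\beta SI}{N} + \left(\frac{\beta SI}{N} - \sigma E\right) + \left(\sigma E - \gamma I\right) + \gamma I \\
&= 0
\end{aligned}
$$

Hence $S(t) + E(t) + I(t) + R(t) = N$ for all $t$ under the pure SEIR dynamics, and any one compartment can be recovered from the other three.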

**Compartments in the SEIR Model:**
* Susceptible (S): Represents individuals who are not yet infected but are at risk of infection.
* Exposed (E): Represents individuals who have been exposed to the infection but are not yet infectious.
* Infectious (I): Represents individuals who are infected and can transmit the disease to susceptible individuals.
* Recovered (R): Represents individuals who have recovered from the disease and are assumed to have immunity. In some variations of the model, this compartment can also include individuals who died from the disease (R = Recovered/Removed).

**Parameters:**
* $\beta$ : Transmission rate (how many people one infected individual infects per unit time).
* $\sigma$ : Rate at which exposed individuals become infectious ( 1 / incubation period ).
* $\gamma$ : Recovery rate ( 1 / infectious period ).

In our experiments, we parameterize the model with the reproduction number, defined as $R_0 = \beta / \gamma$; it represents the average number of secondary infections caused by a single infectious individual in a fully susceptible population. In the code it is passed as `Rt`, and $\gamma$ is recovered as $\beta$ divided by `Rt`.
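
As a quick sanity check on how these rates relate, the sketch below plugs in the values that appear later in this notebook ($\beta = 0.33$, reproduction number $1.19$, $\sigma = 0.48$) and recovers $\gamma$ together with the implied mean incubation and infectious periods. The variable names are illustrative, and the time unit is assumed to be days to match the daily case data.

```python
# Values copied from the model-initialization cell later in this notebook.
beta = 0.33                  # transmission rate (per day; the unit is an assumption)
reproduction_number = 1.19   # R_0 = beta / gamma
sigma = 0.48                 # rate of leaving the exposed compartment (per day)

# Rearranging R_0 = beta / gamma gives gamma = beta / R_0.
gamma = beta / reproduction_number

incubation_period = 1.0 / sigma   # mean time spent in E
infectious_period = 1.0 / gamma   # mean time spent in I

print(f"gamma             = {gamma:.4f} per day")
print(f"incubation period = {incubation_period:.2f} days")
print(f"infectious period = {infectious_period:.2f} days")
```

Because the reproduction number is only slightly above 1, $\gamma$ is only slightly smaller than $\beta$: recoveries almost keep pace with new infections.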

**Task 1: Complete the SEIR model implementation**
* Implement the differential equations of the SEIR model
* Use the ODE parameters in `params`

In [None]:
class SEIR(nn.Module):
    """
    SEIR model for epidemiological modeling, commonly used for diseases like COVID-19.

    Args:
        N (int): Total population size.
        Rt (float): Reproduction number.
        beta_init (float): Initial value for the transmission rate parameter.
        sigma_init (float): Initial value for the transition rate parameter.
        reporting_rate (float): Fraction of cases reported, default is 0.025.
    """
    def __init__(self, N, Rt, beta_init, sigma_init, reporting_rate=0.025):
        super().__init__()

        self.Rt = Rt
        self.N = N
        self.reporting_rate = reporting_rate

        # Parameter transformation for bounded parameters in the range (0, 1)
        EPS = -1e-12
        self.beta = Parameter(torch.tensor(np.arctanh(1/.5*beta_init - 1 + EPS), dtype=torch.float32))
        self.sigma = Parameter(torch.tensor(np.arctanh(1/.5*sigma_init - 1 + EPS), dtype=torch.float32))

    def get_scaled_params(self, convert_cpu=False):
        """
        Converts the unconstrained (real-valued) parameters to scaled values in the range (0, 1).

        Args:
            convert_cpu (bool): If True, detach and convert parameters to CPU for visualization.

        Returns:
            dict: Scaled model parameters ('beta', 'sigma', 'gamma').
        """
        params = {}
        # these take values in the domain 0-1
        params['beta'] = .5 * (torch.tanh(self.beta) + 1)
        params['sigma']  = .5 * (torch.tanh(self.sigma) + 1)
        params['gamma'] = params['beta'] / self.Rt

        # for printing and saving results, detach and send to cpu
        if convert_cpu:
            for k, v in params.items():
                if torch.is_tensor(v):
                    params[k] = v.detach().cpu().item()
        return params
    

    def forward(self, t, state):
        """
        Computes the ODE derivatives for the SEIR model.

        Args:
            t (float): Current time.
            state (Tensor): Current state values (S, E, I, R).

        Returns:
            Tensor: Derivatives of the state values.
        """
        params = self.get_scaled_params()
        
        # TODO: Implement SEIR equations 
        dS_dt  = ...
        dE_dt  = ...
        dI_dt = ...
        dR_dt  = ...

        dstate_dt = torch.stack([dS_dt, dE_dt, dI_dt, dR_dt], 0)

        return dstate_dt
    
    def new_reported_cases(self, E):
        """
        Computes the number of newly reported cases, accounting for underreporting.

        Args:
            E (Tensor): Number of exposed individuals.

        Returns:
            Tensor: Newly reported cases.
        """
        new_cases = self.get_scaled_params()['sigma'] * E
        return self.reporting_rate * new_cases
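
The `SEIR` class above stores `beta` and `sigma` as unconstrained real numbers and maps them back into the unit interval with a `tanh` transform. The sketch below reproduces that round trip with the standard library only; the helper names `to_unconstrained` and `to_unit_interval` are hypothetical and do not appear in the class.

```python
import math

# Tiny offset copied from the SEIR class; it keeps arctanh finite when p is exactly 1.
EPS = -1e-12

def to_unconstrained(p):
    """Map a value p in (0, 1] to the real line via arctanh(2p - 1 + EPS)."""
    return math.atanh(2.0 * p - 1.0 + EPS)

def to_unit_interval(raw):
    """Map any real number back into the unit interval via 0.5 * (tanh(raw) + 1)."""
    return 0.5 * (math.tanh(raw) + 1.0)

for p in (0.33, 0.48):
    raw = to_unconstrained(p)
    print(f"p = {p:.2f} -> raw = {raw:+.4f} -> back = {to_unit_interval(raw):.4f}")

# However far an optimizer step pushes the raw value, the scaled value stays within [0, 1].
print(to_unit_interval(-50.0), to_unit_interval(50.0))
```

The design choice is that the optimizer can take unconstrained steps on the raw value while the rate used in the ODE never leaves its valid range.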


**Task 2: Complete the initial conditions for the SEIR model**
* Implement the initial conditions for the SEIR model

In [None]:
class InitialConditions(nn.Module):
    """
    Learnable initial conditions for the SEIR model states (S0, E0, I0, R0).

    Args:
        N (int): Total population size.
        E0_init (float): Initial exposed population.
        I0_init (float): Initial infectious population.
        R0_init (float): Initial recovered population.
    """
    def __init__(self, N, E0_init, I0_init, R0_init):
        super().__init__()
        self.N = N
        self.E0 = Parameter(torch.tensor(E0_init, dtype=torch.float32))
        self.I0 = Parameter(torch.tensor(I0_init, dtype=torch.float32))
        self.R0 = Parameter(torch.tensor(R0_init, dtype=torch.float32))

    def forward(self):
        """
        Computes the initial susceptible population (S0) based on total population and initial conditions.

        Returns:
            Tensor: Initial state values [S0, E0, I0, R0].
        """
        # TODO: Complete construction of initial conditions
        S0 = ...
        return ...

**Task 3: Instantiate the SEIR model and its initial conditions**
* We provide real-world data on the spread of COVID-19 in Washtenaw County, MI, for Winter 2023.
* We also provide previously learned parameters and initial conditions; use them in your instantiations.

In [None]:
# Import data and preprocess
df = pd.read_csv('Washtenaw_data.csv', header=0)
# Smooth data with a 7-day moving average
cases = df['Cases'].rolling(7, min_periods=1).mean().to_numpy()
cases = torch.tensor(cases, dtype=torch.float32)
y_len = len(cases)

# Model initialization
beta = 0.33
Rt = 1.19
sigma = 0.48
E0, I0, R0 = 600, 800, 1e4
POP_SIZE = 372258 # population size

# TODO: Instantiate the SEIR model and initial conditions
model_phy = SEIR(...)
init_conditions = InitialConditions(...)
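
The smoothing step in the cell above uses a trailing 7-day window that is allowed to shrink at the start of the series (`min_periods=1`). The sketch below shows the same idea in plain Python; `rolling_mean` is a hypothetical helper, the sample counts are made up, and it assumes the data contain no missing values.

```python
def rolling_mean(values, window=7):
    """Trailing moving average that uses fewer points at the start of the series,
    mirroring rolling(window, min_periods=1).mean() for data without NaNs."""
    smoothed = []
    for i in range(len(values)):
        start = max(0, i - window + 1)   # the window shrinks near the beginning
        chunk = values[start:i + 1]
        smoothed.append(sum(chunk) / len(chunk))
    return smoothed

# Made-up daily counts with a two-day dip, for illustration only.
daily = [10, 12, 11, 3, 2, 13, 14, 15]
print(rolling_mean(daily))
```

A 7-day window is a natural choice for daily case counts because it averages over one full weekly reporting cycle.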

Now let's visualize how well our SEIR model fits the real-world data.

In [None]:
# Generate data using your fitted SEIR model
time_points = torch.linspace(0, y_len, y_len)
initial_conditions = init_conditions.forward()  # Initial state [S0, E0, I0, R0]
states = odeint(model_phy, initial_conditions, time_points, method='rk4')  # Solve ODE

# Extract each compartment, calculate proportions
susceptible = states[:, 0].detach().numpy() / POP_SIZE
exposed = states[:, 1].detach().numpy() / POP_SIZE
infectious = states[:, 2].detach().numpy() / POP_SIZE
recovered = states[:, 3].detach().numpy() / POP_SIZE
reported_cases = model_phy.new_reported_cases(E = states[:, 1])

# Create subplots
fig, ax = plt.subplots(1, 2, figsize=(10, 4.3))  # 1 row, 2 columns of subplots

# First subplot: SEIR compartments
ax[0].plot(time_points.numpy(), susceptible, label='Susceptible Population', color='blue')
ax[0].plot(time_points.numpy(), exposed, label='Exposed Population', color='orange')
ax[0].plot(time_points.numpy(), infectious, label='Infectious Population', color='green')
ax[0].plot(time_points.numpy(), recovered, label='Recovered Population', color='red')
ax[0].set_title('SEIR Model')
ax[0].set_xlabel('Time')
ax[0].set_ylabel('Population Fraction')
ax[0].legend(loc='best')
ax[0].grid()

# Second subplot: Reported vs Predicted Cases
ax[1].plot(cases, label='Actual Cases', color='blue')
ax[1].plot(reported_cases.detach().numpy(), label='Predicted Cases', color='green')
ax[1].set_title('Reported vs Predicted Cases')
ax[1].set_xlabel('Time')
ax[1].set_ylabel('Cases')
ax[1].legend(loc='best')
ax[1].grid()

# Adjust layout to prevent overlap
plt.tight_layout()

# Show the plots
plt.show()

The fit is reasonable but not yet great. Let's implement the APHYNITY framework with a vanilla Neural ODE.

In [None]:
class NeuralODE(nn.Module):
    """
    Neural ODE for modeling the time derivative of states (dx/dt).

    Args:
        input_dim (int): Input dimensionality.
        hidden_dim (int): Hidden layer dimensionality.
        output_dim (int): Output dimensionality.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, output_dim, bias=False)
        )
        self._initialize_weights(self.net)

    def _initialize_weights(self, module):
        """
        Initializes the weights of the neural network using Xavier initialization.

        Args:
            module (nn.Module): Neural network module.
        """
        for layer in module:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)

    def forward(self, t, x):
        """
        Computes the time derivative of the input state.

        Args:
            t (float): Current time step.
            x (Tensor): Current state.

        Returns:
            Tensor: Time derivative (dx/dt).
        """
        dx_dt = self.net(x)
        return dx_dt
    
    def get_derivative_dataset(self, t_set, x_set):
        """
        Handles derivatives for multiple data points.

        Args:
            t_set (Tensor): Set of time steps.
            x_set (Tensor): Set of states.

        Returns:
            Tensor: Derivatives for all data points.
        """
        res = []
        for i in range(t_set.size(0)):
            res.append(self.forward(t_set[i], x_set[i, :]))
        return torch.vstack(res)

class DerivativeEstimator(nn.Module):
    """
    Combines a physics-based model and a data-driven augmentation component.

    Args:
        model_phy (nn.Module): Incomplete physics-based model.
        model_aug (nn.Module): Data-driven augmentation component.
    """

    def __init__(self, model_phy, model_aug):
        super().__init__()
        self.model_phy = model_phy
        self.model_aug = model_aug

    def forward(self, t, state):
        """
        Combines predictions from the physics-based model and augmentation model.

        Args:
            t (float): Current time step.
            state (Tensor): Current state.

        Returns:
            Tensor: Combined derivatives.
        """
        res_phy = self.model_phy(t, state)
        res_aug = self.model_aug(t, state)
        return res_phy + res_aug

**Task 4: Instantiate the NeuralODE and DerivativeEstimator**
* The `DerivativeEstimator` class combines the physics-based model and the data-driven model
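
The additive structure of the combined model can be illustrated without any neural network. In the sketch below, both right-hand sides are hypothetical toy functions (an exponential decay and a constant forcing) chosen only to show that the combined derivative is the component-wise sum of the two terms.

```python
def physics_term(t, state):
    """Hypothetical 'known' dynamics: exponential decay of each component."""
    return [-0.5 * x for x in state]

def augmentation_term(t, state):
    """Hypothetical stand-in for the learned correction (a constant forcing here)."""
    return [0.1 for _ in state]

def combined_term(t, state):
    """Sum the two right-hand sides component by component."""
    phy = physics_term(t, state)
    aug = augmentation_term(t, state)
    return [p + a for p, a in zip(phy, aug)]

print(combined_term(0.0, [1.0, 2.0]))
```

When the augmentation term is zero, the combined dynamics reduce to the physics model alone, which is why APHYNITY penalizes the size of the correction.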

In [None]:
# We will only train the data-driven model
for param in model_phy.parameters():
    param.requires_grad = False

for param in init_conditions.parameters():
    param.requires_grad = False

# Data-driven augmentation model
model_aug = NeuralODE(...)
combined_model = DerivativeEstimator(...)


**Task 5: Implement APHYNITY training for our combined model**
* Solve the ODE for the combined model
* Compute the derivatives with the augmentation model
* Compute the predicted reported cases
* Complete the update of lambda
* Implement the APHYNITY loss
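
The `update_lambda` helper in the next cell applies a simple additive rule: the multiplier increases by `TAU` times the loss value it receives. The sketch below traces that rule on made-up loss values; `update_multiplier` is a hypothetical plain-Python stand-in, and which loss to pass in is part of the task and is left open here.

```python
TAU = 1e-3  # step size, same value as in the training cell

def update_multiplier(lambda_value, loss_value):
    """Additive update: the multiplier grows in proportion to the loss passed in."""
    return lambda_value + TAU * loss_value

lambda_param = 0.1                     # same starting value as LAMBDA
for loss_value in (25.0, 9.0, 4.0):    # made-up, shrinking loss values
    lambda_param = update_multiplier(lambda_param, loss_value)
    print(f"loss = {loss_value:5.1f} -> lambda = {lambda_param:.3f}")
```

Since a squared-error loss is never negative, the multiplier never decreases, and it grows fastest while the loss is still large.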

In [None]:
# Training loop constants
LAMBDA = 1e-1
TAU = 1e-3
EPSILON = 1e-5

def update_lambda(lambda_value, loss):
    """
    Updates the lambda parameter based on the loss.

    Args:
        lambda_value (float): Current lambda parameter.
        loss (Tensor): Current loss value.

    Returns:
        float: Updated lambda parameter.
    """
    return lambda_value + TAU * loss.detach().cpu().item()

def train_model(model):
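    """
    Trains the data-driven component of a combined model.

    Args:
        model (DerivativeEstimator): Combined physics + augmentation model.

    Returns:
        Tensor: Predicted reported cases from the last training iteration.
    """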
    
    # Optimizer and scheduler
    params = list(model.model_aug.parameters()) # train only data-driven component
    optimizer = optim.RMSprop(params, lr=1e-3)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=250)
    loss_fn = nn.MSELoss()

    # Set maximum training iterations and print frequency for logging
    max_epochs = 1000
    print_every = 50
    lambda_param = LAMBDA

    # Initial conditions
    y0 = init_conditions.forward()

    # Training loop
    for epoch in range(max_epochs+1):
        optimizer.zero_grad()

        # TODO: Solve the ODE
        states = ...

        # TODO: Compute derivatives with the augmentation model
        aug_derivatives = ...
        loss_aug = ((aug_derivatives.norm(p=2, dim=1) / (states.norm(p=2, dim=1) + EPSILON)) ** 2).mean()
        
        # Predicted reported cases and trajectory loss
        # TODO: Use the right argument to the function
        pred_traj = model.model_phy.new_reported_cases(...)
        traj_loss = loss_fn(pred_traj, cases)

        # Total loss
        # TODO: Implement APHYNITY loss
        total_loss = ...

        if epoch % print_every == 0:
            print(
                f"Epoch {epoch}: Loss Aug: {loss_aug.item():.4e}, "
                f"Trajectory Loss (sqrt): {traj_loss.sqrt().item():.4f}, "
                f"Learning Rate: {scheduler.get_last_lr()[0]:.1e}, "
                f"Lambda: {lambda_param:.3f}"
            )
            # TODO: Complete update on lambda
            lambda_param = update_lambda(...)

            # Break if learning rate is too small
            if scheduler.get_last_lr()[0] < 1e-4:
                break

        # Backpropagation
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(params, 1e-1)
        optimizer.step()
        scheduler.step(total_loss)

    return pred_traj
        
pred_traj_node = train_model(combined_model)
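
The `loss_aug` term in the training loop measures how large the learned correction is relative to the state itself, averaged over time steps. The sketch below recomputes that quantity in plain Python on made-up numbers; `relative_aug_penalty` is a hypothetical helper name.

```python
import math

EPSILON = 1e-5  # same constant as in the training cell

def relative_aug_penalty(aug_derivatives, states):
    """Mean over time steps of (||F_a(x_t)||_2 / (||x_t||_2 + EPSILON)) ** 2."""
    ratios = []
    for f_a, x in zip(aug_derivatives, states):
        norm_f = math.sqrt(sum(v * v for v in f_a))
        norm_x = math.sqrt(sum(v * v for v in x))
        ratios.append((norm_f / (norm_x + EPSILON)) ** 2)
    return sum(ratios) / len(ratios)

# Made-up two-step example: the correction is half the size of the state at the
# first step and zero at the second.
states = [[6.0, 8.0], [3.0, 4.0]]
aug = [[3.0, 4.0], [0.0, 0.0]]
print(relative_aug_penalty(aug, states))
```

Dividing by the state norm makes the penalty scale-free, which matters here because the compartments are measured in individuals and range up to hundreds of thousands.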

Let's visualize our new fit.

In [None]:
# Create subplots
fig, ax = plt.subplots(1, 1, figsize=(5, 4.3))  

# Reported vs Predicted Cases
ax.plot(cases, label='Actual Cases', color='blue')
ax.plot(pred_traj_node.detach().numpy(), label='Predicted Cases', color='green')
ax.set_title('Reported vs Predicted Cases')
ax.set_xlabel('Time')
ax.set_ylabel('Cases')
ax.legend(loc='best')
ax.grid()

# Adjust layout to prevent overlap
plt.tight_layout()

# Show the plots
plt.show()

## Part 2: Improving combined model with time-aware Neural ODE

In the previous section, we saw that the fit of the vanilla Neural ODE still left room for improvement. One way to improve it is to make the augmentation more expressive by giving it access to time through time embeddings, which we will implement as positional embeddings. If you are not familiar with positional embeddings, see https://kazemnejad.com/blog/transformer_architecture_positional_encoding/.
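
To make the idea concrete, the sketch below computes a sinusoidal encoding for a single position in plain Python, following the same formula as the `PositionalEmbedding` module provided for this task. The function name `sinusoidal_encoding` is hypothetical, and the sketch assumes `d_model` is even.

```python
import math

def sinusoidal_encoding(position, d_model):
    """Return the d_model-dimensional sinusoidal encoding of one integer position.

    Even slots hold sin(position * w_k) and odd slots hold cos(position * w_k),
    with w_k = exp(-k * ln(10000) / d_model) for k = 0, 2, 4, ...
    """
    encoding = [0.0] * d_model
    for k in range(0, d_model, 2):
        freq = math.exp(-k * math.log(10000.0) / d_model)
        encoding[k] = math.sin(position * freq)
        encoding[k + 1] = math.cos(position * freq)
    return encoding

print(sinusoidal_encoding(0, 4))
print(sinusoidal_encoding(1, 4))

# A lookup table indexed by int(t) truncates the time, so every solver time
# inside the same day shares one encoding.
assert sinusoidal_encoding(int(2.7), 4) == sinusoidal_encoding(2, 4)
```

Each pair of slots oscillates at a different frequency, so nearby days receive similar vectors while distant days remain distinguishable.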

**Task 6: Use the `PositionalEmbedding` module in your Neural ODE**
* Create any modules you may need
* Update your forward function
* Aim for a trajectory loss (sqrt) of less than 1.8

In [None]:
class PositionalEmbedding(nn.Module):
    """
    Encodes positional information for time steps using sinusoidal embeddings.

    Args:
        d_model (int): Dimensionality of embeddings.
        max_len (int): Maximum length of time series. Default is 1000.
    """

    def __init__(self, d_model, max_len=1000):
        super().__init__()
        
        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()

        # Initialize sinusoidal encoding
        pe = torch.zeros(max_len, d_model, dtype=torch.float32)
        # pe.require_grad = False
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        # Register buffer to avoid tracking gradients
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Returns the positional encoding for the input position.

        Args:
            x (int or float): Position (time) index; non-integer values are truncated by `int(x)`.

        Returns:
            Tensor: Positional encoding vector.
        """
        return self.pe[int(x), :]


class TimeNeuralODE(NeuralODE):
    """
    Neural ODE with positional encoding for time.

    Args:
        input_dim (int): Input dimensionality.
        hidden_dim (int): Hidden layer dimensionality.
        output_dim (int): Output dimensionality.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__(input_dim, hidden_dim, output_dim)

        # TODO: Instantiate the positional embedding and other modules you may need
        ...

        # TODO: Initialize weights for new modules if needed
        ...

    def forward(self, t, x):
        """
        Computes the time derivative of the input state.

        Args:
            t (float): Current time step.
            x (Tensor): Current state.

        Returns:
            Tensor: Time derivative (dx/dt).
        """
        # TODO: Incorporate your position embedding
        dx_dt = ...
        return dx_dt

In [None]:
# TODO: Instantiate the updated NeuralODE and DerivativeEstimator, as done previously, to allow for a direct comparison
model_aug = TimeNeuralODE(...)
combined_model = DerivativeEstimator(...)
pred_traj_time_node = train_model(combined_model)

Visualize the fit for this improved Neural ODE.

In [None]:
# Create subplots
fig, ax = plt.subplots(1, 1, figsize=(5, 4.3))  

# Reported vs Predicted Cases
ax.plot(cases, label='Actual Cases', color='blue')
ax.plot(pred_traj_time_node.detach().numpy(), label='Predicted Cases', color='green')
ax.set_title('Reported vs Predicted Cases')
ax.set_xlabel('Time')
ax.set_ylabel('Cases')
ax.legend(loc='best')
ax.grid()

# Adjust layout to prevent overlap
plt.tight_layout()

# Show the plots
plt.show()

**Task 7: Comment on the results**

Address the following points:
* Reflect on the trade-offs between interpretability (from the physics-based component) and flexibility (from the neural network component). In which scenarios would you prefer each of the models we tested?
* Describe the role of the positional embeddings and how they enhance the expressiveness of the Neural ODE. Could other types of embeddings or feature transformations be explored?
* In this exercise, we fitted COVID-19 data. What changes would you make if this model were to be used for forecasting (predicting future cases)?