# Biol 359A | Parameter Estimation and Regularization
### Spring 2025, Week 6
Objectives:
- Gain intuition for parameter estimation strategy (cross validation)
- Explore the impact of initial conditions on the overall trajectory of the epidemic and final outcomes.
- Identify the initial conditions that lead to a stable steady state in a one-variable example

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression, Lasso, Ridge, ElasticNet
from sklearn.model_selection import KFold, train_test_split
from sklearn.metrics import mean_squared_error
from scipy.integrate import solve_ivp
from ipywidgets import interact, FloatSlider, IntSlider, FloatText, IntText, widgets

## Cross-validation

### Generate synthetic data
Today we will start by working with in-silico data. The code below will generate the data.

In [None]:
def generate_data(n_samples=100, degree=3, noise_level=0.5, x_range=(-3, 3)):
    """
    Draw a reproducible noisy sample from a randomly chosen polynomial.

    The global NumPy seed is reset to 42 on every call, so identical
    arguments always yield identical data.

    Parameters:
    -----------
    n_samples : int
        How many (x, y) pairs to produce
    degree : int
        Degree of the underlying polynomial
    noise_level : float
        Standard deviation of the additive Gaussian noise
    x_range : tuple
        (min, max) bounds for the uniformly sampled x values

    Returns:
    --------
    X : ndarray of shape (n_samples, 1)
        Sampled x values as a single-feature column
    y : ndarray of shape (n_samples,)
        Polynomial values plus noise
    true_coef : ndarray of shape (degree + 1,)
        Coefficients in ascending power order (index 0 is the constant),
        rescaled so the largest magnitude is 3
    """
    # Re-seeding here is what makes the data set reproducible between calls
    np.random.seed(42)
    x_values = np.random.uniform(x_range[0], x_range[1], n_samples)

    # Random coefficients, rescaled so the largest one has magnitude 3
    raw_coef = np.random.randn(degree + 1)
    true_coef = raw_coef / np.max(np.abs(raw_coef)) * 3

    # Accumulate the noiseless polynomial one power at a time
    y_clean = np.zeros(n_samples)
    for power, coef in enumerate(true_coef):
        y_clean += coef * x_values**power

    # Noise is drawn last so the random stream order matches the x/coef draws above
    noise = noise_level * np.random.randn(n_samples)
    return x_values.reshape(-1, 1), y_clean + noise, true_coef

### Define models

In [None]:
import warnings
from sklearn.exceptions import ConvergenceWarning

# Suppress only ConvergenceWarning; every other warning category is still shown.
# The filter is registered at module level, so it remains active for later cells.
warnings.filterwarnings("ignore", category=ConvergenceWarning)
def get_model(model_type='None', alpha=1.0, l1_ratio=0.5):
    """
    Build an unfitted regression estimator for the requested regularization.

    'Lasso', 'Ridge' and 'Elastic' select the matching penalized model;
    any other value (including the default 'None') yields plain
    LinearRegression. `alpha` is passed to all penalized models and
    `l1_ratio` only to ElasticNet.
    """
    # Lasso and ElasticNet share the raised iteration cap
    iterative_kwargs = {'alpha': alpha, 'max_iter': 10000}

    if model_type == 'Lasso':
        return Lasso(**iterative_kwargs)
    if model_type == 'Elastic':
        return ElasticNet(l1_ratio=l1_ratio, **iterative_kwargs)
    if model_type == 'Ridge':
        return Ridge(alpha=alpha)
    # Fallback: unregularized least squares
    return LinearRegression()

def create_polynomial_features(X, degree):
    """Return the polynomial expansion of X up to `degree`, bias column included."""
    expander = PolynomialFeatures(degree=degree, include_bias=True)
    expander.fit(X)
    return expander.transform(X)

### Evaluate validation and test data

In [None]:
def perform_cross_validation(X, y, test_degree, n_folds=5, model_type='None', alpha=1.0, l1_ratio=0.5, plot_huge_loss=False):
    """
    Perform k-fold cross-validation for polynomial regression.

    Parameters:
    -----------
    X : ndarray
        Input features (indexed by row, one sample per row)
    y : ndarray
        Target values
    test_degree : int
        Degree of polynomial to test
    n_folds : int
        Number of folds for cross-validation
    model_type : str
        Type of regularization to use (see get_model)
    alpha : float
        Regularization strength
    l1_ratio : float
        Mixing parameter for ElasticNet
    plot_huge_loss : bool
        When True, create a diagnostic figure (data plus fitted curve) for
        every fold whose validation MSE exceeds 10000. The figure is created
        but not shown here; it appears with the caller's next plt.show().

    Returns:
    --------
    cv_results : dict
        'fold_train_losses', 'fold_val_losses' : per-fold MSE lists
        'fold_coefs', 'fold_intercepts' : per-fold fitted parameters
            (None where the estimator does not expose the attribute)
        'avg_train_loss', 'avg_val_loss' : means over the folds
    """
    # Shuffled folds with a fixed seed so repeated calls use the same partition
    kf = KFold(n_splits=n_folds, shuffle=True, random_state=42)

    fold_train_losses = []
    fold_val_losses = []
    fold_coefs = []
    fold_intercepts = []

    for train_idx, val_idx in kf.split(X):
        # Split the raw features first, then expand each part
        X_train_original, X_val_original = X[train_idx], X[val_idx]
        y_train, y_val = y[train_idx], y[val_idx]

        # The transformer is fitted on the training fold only and then
        # reused, unchanged, on the validation fold
        poly_fold = PolynomialFeatures(degree=test_degree, include_bias=True)
        X_train = poly_fold.fit_transform(X_train_original)
        X_val = poly_fold.transform(X_val_original)

        model = get_model(model_type, alpha, l1_ratio)
        model.fit(X_train, y_train)

        train_loss = mean_squared_error(y_train, model.predict(X_train))
        val_loss = mean_squared_error(y_val, model.predict(X_val))

        fold_train_losses.append(train_loss)
        fold_val_losses.append(val_loss)

        if val_loss > 10000 and plot_huge_loss:
            # Diagnostic view of a fold whose fit blows up on the validation points
            fig, ax = plt.subplots(figsize=(14/3, 5))
            ax.scatter(X_train_original, y_train, label='train')
            ax.scatter(X_val_original, y_val, label='val')

            # Draw the fitted curve over the combined x-range of both parts
            x_low = min(X_train_original.min(), X_val_original.min())
            x_high = max(X_train_original.max(), X_val_original.max())
            X_line = np.linspace(x_low, x_high, 10000).reshape(-1, 1)
            y_pred = model.predict(poly_fold.transform(X_line))
            ax.plot(X_line, y_pred, 'g-', linewidth=2, label=f'Model fit (degree={test_degree})')
            ax.legend()
            ax.set_xlabel("X")
            ax.set_ylabel("Y")
            ax.set_title(f"Data and Model fit (degree = {test_degree})")

        # Store model parameters (None when the estimator lacks the attribute)
        fold_coefs.append(getattr(model, 'coef_', None))
        fold_intercepts.append(getattr(model, 'intercept_', None))

    cv_results = {
        'fold_train_losses': fold_train_losses,
        'fold_val_losses': fold_val_losses,
        'fold_coefs': fold_coefs,
        'fold_intercepts': fold_intercepts,
        'avg_train_loss': np.mean(fold_train_losses),
        'avg_val_loss': np.mean(fold_val_losses)
    }

    return cv_results

# DEPRECATED helper: cross_validate_and_visualize performs its own train/test
# split for the selected degree. The definition is retained for backward
# compatibility with code that still calls it.
def train_test_model(X, y, test_degree, test_size=0.2, model_type='None', alpha=1.0, l1_ratio=0.5):
    """
    DEPRECATED: Use direct train-test split in cross_validate_and_visualize instead

    Fit a polynomial model on a training split and score it on the held-out split.

    Parameters:
    -----------
    X : ndarray
        Input features
    y : ndarray
        Target values
    test_degree : int
        Polynomial degree used for the feature expansion
    test_size : float
        Fraction of the samples held out for testing
    model_type : str
        Regularization choice forwarded to get_model
    alpha : float
        Regularization strength
    l1_ratio : float
        ElasticNet mixing parameter

    Returns:
    --------
    test_results : dict
        'train_loss' and 'test_loss' (mean squared errors) and the fitted 'model'
    """
    # Split the raw features first (fixed seed => same split on every call)
    X_fit_raw, X_holdout_raw, y_fit, y_holdout = train_test_split(
        X, y, test_size=test_size, random_state=42
    )

    # Expand after splitting; the transformer is fitted on the training part only
    expander = PolynomialFeatures(degree=test_degree, include_bias=True)
    X_fit = expander.fit_transform(X_fit_raw)
    X_holdout = expander.transform(X_holdout_raw)

    estimator = get_model(model_type, alpha, l1_ratio)
    estimator.fit(X_fit, y_fit)

    fit_predictions = estimator.predict(X_fit)
    holdout_predictions = estimator.predict(X_holdout)

    return {
        'train_loss': mean_squared_error(y_fit, fit_predictions),
        'test_loss': mean_squared_error(y_holdout, holdout_predictions),
        'model': estimator
    }

### Visualization

In [None]:
def interactive_polynomial_regression():
    """
    Build the interactive cross-validation explorer.

    Registers `cross_validate_and_visualize` with ipywidgets' `interact`, so
    every widget change regenerates the synthetic data, re-runs k-fold
    cross-validation, refits the model and redraws three panels:

    1. data, true function and fitted curve for the selected degree;
    2. per-fold train/validation MSE with averages and the test MSE;
    3. train/validation/test MSE for degrees 1..15.

    The fold-by-fold losses, the fitted parameters per fold and the true
    coefficients are printed below the figure.
    """
    @interact(
        true_degree=widgets.IntSlider(min=1, max=9, step=1, value=3, description='True Degree:'),
        noise_level=widgets.FloatSlider(min=0.1, max=20.0, step=0.1, value=0.1, description='Noise Level:'),
        n_samples=widgets.IntSlider(min=20, max=200, step=10, value=100, description='Sample Size:'),
        test_degree=widgets.IntSlider(min=1, max=15, step=1, value=3, description='Test Degree:'),
        regularization=widgets.RadioButtons(
            options=['None', 'Lasso', 'Ridge', 'Elastic'],
            value='None',
            description='Regularization:'
        ),
        alpha=widgets.FloatLogSlider(
            min=-5, max=1, step=0.1, value=0.1, base=10, description='Alpha (Reg. Strength):'
        ),
        l1_ratio=widgets.FloatSlider(
            min=0.0, max=1.0, step=0.05, value=0.5, description='L1 Ratio (Elastic):'
        ),
        n_folds=widgets.IntSlider(min=3, max=15, step=1, value=5, description='CV Folds:'),
        plot_huge_loss=widgets.RadioButtons(
            options=[False, True],
            value=False,
            description='Plot val loss'
        )
    )
    def cross_validate_and_visualize(true_degree, noise_level, n_samples, test_degree,
                                     regularization, alpha, l1_ratio, n_folds, plot_huge_loss):
        # Generate synthetic data
        X, y, true_coef = generate_data(n_samples, true_degree, noise_level)

        # Hold out 20% as the test set; it must never take part in cross-validation
        X_train_full, X_test, y_train_full, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42
        )

        # Perform cross-validation on training data only
        cv_results = perform_cross_validation(
            X_train_full, y_train_full, test_degree, n_folds, regularization, alpha, l1_ratio, plot_huge_loss
        )

        # Train the final model on all training data and evaluate on the test set
        poly = PolynomialFeatures(degree=test_degree, include_bias=True)
        X_train_poly = poly.fit_transform(X_train_full)
        X_test_poly = poly.transform(X_test)

        model = get_model(regularization, alpha, l1_ratio)
        model.fit(X_train_poly, y_train_full)

        y_train_pred = model.predict(X_train_poly)
        y_test_pred = model.predict(X_test_poly)

        test_results = {
            'train_loss': mean_squared_error(y_train_full, y_train_pred),
            'test_loss': mean_squared_error(y_test, y_test_pred),
            'model': model
        }

        # Create figure with 3 subplots
        fig, axes = plt.subplots(1, 3, figsize=(14, 5))

        # ---- Plot 1: Data and model fit ----
        ax1 = axes[0]

        X_flat = X_train_full.flatten()
        X_test_flat = X_test.flatten()
        ax1.scatter(X_flat, y_train_full, alpha=0.6, label='Data points (Train)')
        ax1.scatter(X_test_flat, y_test, alpha=0.6, label='Data points (Test)')

        # True function, evaluated on a grid spanning the training x-range
        X_line = np.linspace(min(X_flat), max(X_flat), 100).reshape(-1, 1)
        x_line_flat = X_line.flatten()
        y_true = np.zeros(len(x_line_flat))
        for i in range(true_degree + 1):
            y_true += true_coef[i] * x_line_flat**i
        ax1.plot(X_line, y_true, 'r-', linewidth=2, label='True function')

        # Fitted curve on the same grid
        X_poly_line = create_polynomial_features(X_line, test_degree)
        y_pred = test_results['model'].predict(X_poly_line)
        ax1.plot(X_line, y_pred, 'g-', linewidth=2, label=f'Model fit (degree={test_degree})')

        ax1.set_title(f'Data and Model Fit\nTrue degree: {true_degree}, Test degree: {test_degree}')
        ax1.set_xlabel('X')
        ax1.set_ylabel('y')
        ax1.legend()

        # ---- Plot 2: Train and validation loss for each fold ----
        ax2 = axes[1]

        folds = list(range(1, n_folds + 1))
        ax2.bar(
            [f - 0.2 for f in folds],
            cv_results['fold_train_losses'],
            width=0.4,
            color='blue',
            alpha=0.6,
            label='Train Loss'
        )
        ax2.bar(
            [f + 0.2 for f in folds],
            cv_results['fold_val_losses'],
            width=0.4,
            color='red',
            alpha=0.6,
            label='Validation Loss'
        )

        ax2.axhline(
            cv_results['avg_train_loss'],
            color='blue',
            linestyle='--',
            alpha=0.8,
            label=f'Avg Train Loss: {cv_results["avg_train_loss"]:.4f}'
        )
        ax2.axhline(
            cv_results['avg_val_loss'],
            color='red',
            linestyle='--',
            alpha=0.8,
            label=f'Avg Val Loss: {cv_results["avg_val_loss"]:.4f}'
        )
        ax2.axhline(
            test_results['test_loss'],
            color='green',
            linestyle='--',
            alpha=0.8,
            label=f'Test Loss: {test_results["test_loss"]:.4f}'
        )

        ax2.set_title(f'Train and Validation Loss for Each Fold\n({n_folds}-fold Cross-Validation)')
        ax2.set_xlabel('Fold')
        ax2.set_ylabel('Mean Squared Error')
        ax2.set_xticks(folds)
        ax2.legend()

        # ---- Plot 3: Train, validation, test loss across model complexities ----
        ax3 = axes[2]

        degrees = list(range(1, 16))
        train_losses = []
        val_losses = []
        test_losses = []

        for degree in degrees:
            # Cross-validate on the training split only, exactly as for the
            # selected degree above. Passing the full data set here would let
            # the held-out test points leak into the train/validation curves.
            cv_res = perform_cross_validation(
                X_train_full, y_train_full, degree, n_folds, regularization, alpha, l1_ratio
            )

            # train_test_model re-creates the same 80/20 split (same test_size
            # and random_state), fits on the 80% and scores on the 20%
            test_res = train_test_model(
                X, y, degree, 0.2, regularization, alpha, l1_ratio
            )

            train_losses.append(cv_res['avg_train_loss'])
            val_losses.append(cv_res['avg_val_loss'])
            test_losses.append(test_res['test_loss'])

        ax3.plot(degrees, train_losses, 'o-', color='blue', label='Train Loss')
        ax3.plot(degrees, val_losses, 'o-', color='red', label='Validation Loss')
        ax3.plot(degrees, test_losses, 'o-', color='green', label='Test Loss')

        ax3.axvline(
            true_degree,
            color='black',
            linestyle='--',
            alpha=0.5,
            label=f'True Degree: {true_degree}'
        )

        ax3.set_title('Model Performance vs. Complexity')
        ax3.set_xlabel('Polynomial Degree')
        ax3.set_ylabel('Mean Squared Error')
        ax3.legend()

        plt.tight_layout()
        plt.show()

        # ---- Printed report ----
        print("\n=== Cross-Validation Results ===\n")
        coefs_array = np.array(cv_results['fold_coefs'])
        fold_results = pd.DataFrame({
            'Fold': list(range(1, n_folds + 1)),
            'Train Loss': cv_results['fold_train_losses'],
            'Validation Loss': cv_results['fold_val_losses']
        })

        print(fold_results)

        print(f"\nAverage Train Loss: {cv_results['avg_train_loss']:.4f}")
        print(f"Average Validation Loss: {cv_results['avg_val_loss']:.4f}")
        print(f"Test Loss: {test_results['test_loss']:.4f}")

        print("\n=== Model Parameters for Each Fold ===\n")

        # Column 0 of coef_ belongs to the constant (bias) feature; it is
        # overwritten with the fitted intercept_ so the table's 'Intercept'
        # column shows the constant term the model actually uses.
        param_cols = [f'Coef {i}' if i > 0 else 'Intercept' for i in range(coefs_array.shape[1])]
        coefs_array[:, 0] = cv_results['fold_intercepts']
        param_data = pd.DataFrame(coefs_array, columns=param_cols)
        param_data.insert(0, 'Fold', list(range(1, n_folds + 1)))

        print(param_data)

        # Print true coefficients
        print("\n=== True Coefficients ===\n")
        true_coef_names = [f'Coef {i}' if i > 0 else 'Intercept' for i in range(len(true_coef))]
        true_coef_df = pd.DataFrame([true_coef], columns=true_coef_names)
        print(true_coef_df)


# Run the interactive application: the call registers the widgets through
# @interact inside the function, which also renders the first set of plots.
interactive_polynomial_regression()

## Write equations for diagrams

### One-variable model

In [None]:
# Define the ODE system for a single variable
def ode_system(t, x):
    """
    Right-hand side of dx/dt = -0.1(x-1)(x-2)(x-4).

    `t` is accepted for the solve_ivp calling convention but unused (the
    system is autonomous). Returns the derivative wrapped in a one-element
    list.
    """
    rate = -0.1
    # Multiply in one root factor at a time; the rate vanishes at x = 1, 2, 4
    for root in (1, 2, 4):
        rate = rate * (x - root)
    return [rate]

# Function to solve and plot the system
def plot_ode_solution(x_init=0.5, t_max=20):
    """
    Integrate the one-variable ODE from x(0) = x_init and draw two figures.

    Figure 1 shows x(t) on [0, t_max] together with the three equilibrium
    levels. Figure 2 shows dx/dt against x with arrows on the x-axis that
    indicate the direction of flow between equilibria.
    """
    # (level, colour) for each equilibrium, shared by both figures
    equilibria = ((1, 'r'), (2, 'g'), (4, 'm'))

    # ---- Figure 1: trajectory over time ----
    plt.figure(figsize=(10, 5))

    sample_times = np.linspace(0, t_max, 1000)
    result = solve_ivp(
        ode_system,
        (0, t_max),
        [x_init],
        method='RK45',
        t_eval=sample_times
    )

    plt.plot(result.t, result.y[0], 'b-', linewidth=2, label='x(t)')
    for level, colour in equilibria:
        plt.axhline(y=level, color=colour, linestyle='--', alpha=0.6,
                    label=f'x={level} (equilibrium)')

    plt.xlabel('Time (t)', fontsize=12)
    plt.ylabel('x', fontsize=12)
    plt.title(f'Solution vs Time (Initial: x={x_init})', fontsize=14)
    plt.grid(True, alpha=0.3)
    plt.legend(fontsize=10)

    plt.tight_layout()
    plt.show()

    # ---- Figure 2: phase line (dx/dt as a function of x) ----
    plt.figure(figsize=(10, 2))
    x_grid = np.linspace(-1, 6, 1000)
    slope_curve = -0.1 * (x_grid - 1) * (x_grid - 2) * (x_grid - 4)

    plt.plot(x_grid, slope_curve, 'k-', linewidth=2)
    plt.axhline(y=0, color='gray', linestyle='-', alpha=0.6)
    for level, colour in equilibria:
        plt.axvline(x=level, color=colour, linestyle='--', alpha=0.6)

    plt.xlabel('x', fontsize=12)
    plt.ylabel('dx/dt', fontsize=12)
    plt.title('Phase Line Diagram: dx/dt vs x', fontsize=14)
    plt.grid(True, alpha=0.3)

    # One arrow per interval between equilibria: it points right where
    # dx/dt is positive and left otherwise
    for x_val in [0, 1.5, 3, 5]:
        slope = -0.1 * (x_val - 1) * (x_val - 2) * (x_val - 4)
        if slope > 0:
            arrow_tip = x_val + 0.3
        else:
            arrow_tip = x_val - 0.3
        plt.annotate('', xy=(arrow_tip, 0),
                     xytext=(x_val, 0),
                     arrowprops=dict(arrowstyle="->", color='blue'))

    plt.tight_layout()
    plt.show()
  
# Create interactive widget: each slider feeds the like-named argument of
# plot_ode_solution, and the plots are redrawn whenever a slider moves.
# Note the slider's starting value for t_max (10.0) is what the first render
# uses, not the function's own default of 20.
interact(
    plot_ode_solution,
    x_init=FloatSlider(min=-1.0, max=6.0, step=0.1, value=0.5, description='x(0):'),
    t_max=FloatSlider(min=5.0, max=50.0, step=1.0, value=10.0, description='Max Time:'),
);

### SIR model

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
from ipywidgets import interact, FloatSlider, IntSlider, fixed
%matplotlib inline

# Define the SIR model with births and deaths
def sir_model(t, y, r_B, r_S, r_I, r_R, r_D):
    """
    Right-hand side of the SIR system for state y = (S, I, R).

    dS/dt = r_B + r_S*R - r_I*S*I
    dI/dt = r_I*S*I - r_R*I - r_D*I
    dR/dt = r_R*I - r_S*R

    `t` is unused (autonomous system). Returns [dS/dt, dI/dt, dR/dt].
    """
    S, I, R = y

    # Flows shared between compartments, each computed once
    new_infections = r_I*S*I   # S -> I
    recoveries = r_R*I         # I -> R
    waning = r_S*R             # R -> S

    return [
        r_B + waning - new_infections,
        new_infections - recoveries - r_D*I,
        recoveries - waning,
    ]

# Function to solve and visualize the SIR model
def plot_sir_model(S_init=0.9, I_init=0.1, R_init=0.0, 
                  r_B=0.05, r_S=0.1, r_I=0.3, r_R=0.1, r_D=0.01, 
                  t_max=100, total_population=10000):
    """
    Solve the SIR model and plot S, I and R as head counts over time.

    Parameters:
    -----------
    S_init, I_init, R_init : float
        Initial compartment sizes; they are rescaled so they sum to 1
    r_B, r_S, r_I, r_R, r_D : float
        Rate constants passed through to sir_model
    t_max : float
        End of the integration interval (start is 0)
    total_population : int
        Factor that converts compartment fractions to numbers of individuals

    Prints a short summary of initial and final compartment sizes after the
    plot. If the initial fractions do not sum to a positive value, or t_max
    is not positive, an error message is printed and nothing is plotted.
    """
    # Validate before normalizing: a zero sum would divide by zero
    total_frac = S_init + I_init + R_init
    if total_frac <= 0:
        print("Error: Initial fractions must sum to a positive value.")
        return
    if t_max <= 0:
        print("Error: Max time must be positive.")
        return

    # Ensure initial populations sum to 1 (normalized)
    S_init, I_init, R_init = S_init/total_frac, I_init/total_frac, R_init/total_frac

    # Roughly one sample per time unit, but never fewer than two: with a
    # single sample the "final" state would just be the initial state, and
    # with none there would be no final state at all.
    t_span = (0, t_max)
    t_eval = np.linspace(0, t_max, max(int(t_max), 2))
    initial_conditions = [S_init, I_init, R_init]
    solution = solve_ivp(
        lambda t, y: sir_model(t, y, r_B, r_S, r_I, r_R, r_D), 
        t_span, 
        initial_conditions, 
        method='RK45', 
        t_eval=t_eval
    )

    t = solution.t
    S = solution.y[0]
    I = solution.y[1]
    R = solution.y[2]
    N = S + I + R  # Total population (as fraction)

    # Plot the compartments as numbers of individuals
    plt.figure(figsize=(10, 5))
    plt.plot(t, S * total_population, 'b-', linewidth=2, label='Susceptible (S)')
    plt.plot(t, I * total_population, 'r-', linewidth=2, label='Infected (I)')
    plt.plot(t, R * total_population, 'g-', linewidth=2, label='Recovered (R)')

    plt.xlabel('Time', fontsize=12)
    plt.ylabel('Number of Individuals', fontsize=12)
    plt.title('SIR Model with Births, Deaths and Reinfection', fontsize=14)
    plt.grid(True, alpha=0.3)
    plt.legend(fontsize=10)

    # Display equations on the plot
    eqn_text = "SIR Model Equations:\n" \
              "dS/dt = r_B + r_S*R - r_I*S*I\n" \
              "dI/dt = r_I*S*I - r_R*I - r_D*I\n" \
              "dR/dt = r_R*I - r_S*R"
    plt.figtext(0.02, 0.02, eqn_text, fontsize=10, bbox=dict(facecolor='wheat', alpha=0.5))

    plt.tight_layout()
    plt.show()

    # Analysis
    print("Model Analysis:")
    print(f"Initial conditions: S(0)={int(S_init * total_population)} individuals, " +
          f"I(0)={int(I_init * total_population)} individuals, " +
          f"R(0)={int(R_init * total_population)} individuals")
    print(f"Final values: S({t_max})={int(S[-1] * total_population)} individuals, " +
          f"I({t_max})={int(I[-1] * total_population)} individuals, " +
          f"R({t_max})={int(R[-1] * total_population)} individuals")
    print(f"Total final population: {int(N[-1] * total_population)} individuals")

# Create text input widgets for initial conditions and parameters
w_S_init = FloatText(value=0.99, description='S(0) fraction:')
w_I_init = FloatText(value=0.01, description='I(0) fraction:')
w_R_init = FloatText(value=0.0, description='R(0) fraction:')

w_r_B = FloatText(value=0.03, description='r_B (birth):')
w_r_S = FloatText(value=0.01, description='r_S (suscept):')
w_r_I = FloatText(value=0.5, description='r_I (infection):')
w_r_R = FloatText(value=0.05, description='r_R (recovery):')
w_r_D = FloatText(value=0.02, description='r_D (death):')
w_t_max = FloatText(value=100, description='Max time:')
w_population = IntText(value=1000, description='Population:')

# Map each plot_sir_model argument to the widget that supplies it.
# The dict is deliberately NOT called `widgets`: that name is the ipywidgets
# module imported at the top of the notebook, and rebinding it to a dict makes
# any cell that uses `widgets.IntSlider(...)` etc. fail when it is re-run.
sir_widgets = {
    'S_init': w_S_init,
    'I_init': w_I_init,
    'R_init': w_R_init,
    'r_B': w_r_B,
    'r_S': w_r_S,
    'r_I': w_r_I,
    'r_R': w_r_R,
    'r_D': w_r_D,
    't_max': w_t_max,
    'total_population': w_population
}

# Display the interactive model
interact(plot_sir_model, **sir_widgets);

### Sensitivity analysis for SIR model

In [None]:
from IPython import display
# Load the image at the relative path images/SIR.jpg (presumably the SIR model
# diagram); as the last expression in the cell it is rendered as the output.
display.Image("images/SIR.jpg")

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
from ipywidgets import interact, widgets, HBox, VBox, Layout
%matplotlib inline

# Define the SIR model with births and deaths
def sir_model(t, y, r_B, r_S, r_I, r_R, r_D):
    """
    Evaluate the SIR derivatives for the state y = (S, I, R).

    Returns [dS/dt, dI/dt, dR/dt] with
        dS/dt = r_B + r_S*R - r_I*S*I
        dI/dt = r_I*S*I - r_R*I - r_D*I
        dR/dt = r_R*I - r_S*R
    The time argument `t` is accepted for solve_ivp but not used.
    """
    susceptible, infected, recovered = y

    infection_flow = r_I*susceptible*infected    # leaves S, enters I
    recovery_flow = r_R*infected                 # leaves I, enters R
    resusceptible_flow = r_S*recovered           # leaves R, re-enters S

    dSdt = r_B + resusceptible_flow - infection_flow
    dIdt = infection_flow - recovery_flow - r_D*infected
    dRdt = recovery_flow - resusceptible_flow

    return [dSdt, dIdt, dRdt]

# Function to solve the SIR model for a specific set of parameters
def solve_sir_model(S_init, I_init, R_init, r_B, r_S, r_I, r_R, r_D, t_max):
    """
    Integrate the SIR model on [0, t_max] for one parameter set.

    Parameters:
    -----------
    S_init, I_init, R_init : float
        Initial compartment sizes; rescaled so they sum to 1
    r_B, r_S, r_I, r_R, r_D : float
        Rate constants passed through to sir_model
    t_max : float
        End of the integration interval

    Returns:
    --------
    t : ndarray
        Sample times
    y : ndarray
        Solution array with rows S, I, R (one column per sample time)

    Raises:
    -------
    ValueError
        If the initial fractions do not sum to a positive value or t_max is
        not positive
    """
    total_frac = S_init + I_init + R_init
    if total_frac <= 0:
        raise ValueError("Initial fractions S_init + I_init + R_init must sum to a positive value.")
    if t_max <= 0:
        raise ValueError("t_max must be positive.")

    # Ensure initial populations sum to 1 (normalized)
    initial_conditions = [S_init/total_frac, I_init/total_frac, R_init/total_frac]

    # Roughly one sample per time unit, but never fewer than two, so the
    # returned arrays always contain both the start and the end of the run
    # (int(t_max) alone gives 0 or 1 samples for t_max < 2).
    t_eval = np.linspace(0, t_max, max(int(t_max), 2))

    solution = solve_ivp(
        lambda t, y: sir_model(t, y, r_B, r_S, r_I, r_R, r_D), 
        (0, t_max), 
        initial_conditions, 
        method='RK45', 
        t_eval=t_eval
    )

    return solution.t, solution.y

# Function to plot sensitivity analysis
def plot_sensitivity_analysis(S_init=0.99, I_init=0.01, R_init=0.0,
                             r_B=0.03, r_S=0.01, r_R=0.05, r_D=0.02,
                             t_max=100, total_population=1000,
                             infection_rates="0.3, 0.5, 0.7",
                             plot_susceptible=True, plot_infected=True, plot_recovered=True):
    """
    Compare SIR trajectories for several infection rates on one plot.

    Parameters:
    -----------
    S_init, I_init, R_init : float
        Initial compartment sizes (normalized by solve_sir_model)
    r_B, r_S, r_R, r_D : float
        Rate constants held fixed across runs
    t_max : float
        End of the integration interval
    total_population : int
        Factor that converts fractions to numbers of individuals
    infection_rates : str
        Comma-separated r_I values, one model run per value
    plot_susceptible, plot_infected, plot_recovered : bool
        Which compartments to draw

    After the plot, a table is printed with one row per infection rate: the
    peak of I, the time of that peak, and R at the final time point (all as
    head counts). Invalid input prints an error message and returns without
    plotting.
    """
    # Parse the infection rates from the input string
    try:
        r_I_values = [float(rate.strip()) for rate in infection_rates.split(',')]
    except ValueError:
        print("Error: Please enter valid infection rates as comma-separated numbers.")
        return

    # Reject inputs the solver cannot handle before any figure is created
    if S_init + I_init + R_init <= 0:
        print("Error: Initial fractions must sum to a positive value.")
        return
    if t_max <= 0:
        print("Error: Max time must be positive.")
        return

    plt.figure(figsize=(12, 4))

    # Colour encodes the compartment, line style encodes the infection rate
    compartment_colors = {'S': 'blue', 'I': 'red', 'R': 'green'}
    rate_linestyles = ['-', '--', ':', '-.']

    # One summary row per run; a list keeps repeated rates as separate rows
    summary_rows = []

    for i, r_I in enumerate(r_I_values):
        t, solution = solve_sir_model(S_init, I_init, R_init, r_B, r_S, r_I, r_R, r_D, t_max)
        S, I, R = solution[0], solution[1], solution[2]

        peak_index = np.argmax(I)
        summary_rows.append((r_I, I[peak_index], t[peak_index], R[-1]))

        ls_mod = rate_linestyles[i % len(rate_linestyles)]

        # Plot selected compartments as numbers of individuals
        if plot_susceptible:
            plt.plot(t, S * total_population, color=compartment_colors['S'], linestyle=ls_mod,
                     label=f'S (r_I={r_I})')
        if plot_infected:
            plt.plot(t, I * total_population, color=compartment_colors['I'], linestyle=ls_mod,
                     label=f'I (r_I={r_I})')
        if plot_recovered:
            plt.plot(t, R * total_population, color=compartment_colors['R'], linestyle=ls_mod,
                     label=f'R (r_I={r_I})')

    plt.xlabel('Time', fontsize=12)
    plt.ylabel('Number of Individuals', fontsize=12)
    plt.title('SIR Model Sensitivity Analysis: Effect of Infection Rate', fontsize=14)
    plt.grid(True, alpha=0.3)
    plt.legend(fontsize=10)

    plt.tight_layout()
    plt.show()

    # Print analysis. The last column is R at the final time point, so it is
    # labelled 'Final Recovered' (it is not a cumulative infection count:
    # recovered individuals flow back to S at rate r_S).
    header = f"{'Infection Rate':<15} {'Peak Infected':<20} {'Time to Peak':<15} {'Final Recovered':<15}"
    separator = "-" * len(header)
    print("\nSensitivity Analysis Results:")
    print(separator)
    print(header)
    print(separator)

    for r_I, peak_infection, peak_time, final_recovered in summary_rows:
        print(f"{r_I:<15.3f} {peak_infection*total_population:<20.1f} {peak_time:<15.1f} {final_recovered*total_population:<15.1f}")

# Create input widgets
# style={'description_width': 'initial'} is passed to every widget below so
# that the description labels are sized to their text.
w_S_init = widgets.FloatText(value=0.99, description='S(0) fraction:', style={'description_width': 'initial'})
w_I_init = widgets.FloatText(value=0.01, description='I(0) fraction:', style={'description_width': 'initial'})
w_R_init = widgets.FloatText(value=0.0, description='R(0) fraction:', style={'description_width': 'initial'})

# Fixed rate constants; the infection rate r_I is swept via w_infection_rates instead
w_r_B = widgets.FloatText(value=0.03, description='r_B (birth):', style={'description_width': 'initial'})
w_r_S = widgets.FloatText(value=0.01, description='r_S (suscept):', style={'description_width': 'initial'})
w_r_R = widgets.FloatText(value=0.05, description='r_R (recovery):', style={'description_width': 'initial'})
w_r_D = widgets.FloatText(value=0.02, description='r_D (death):', style={'description_width': 'initial'})

w_t_max = widgets.FloatText(value=100, description='Max time:', style={'description_width': 'initial'})
w_population = widgets.IntText(value=1000, description='Population:', style={'description_width': 'initial'})
# Free-text field: plot_sensitivity_analysis parses it as comma-separated floats
w_infection_rates = widgets.Text(
    value='0.3, 0.5, 0.7',
    description='Infection rates:',
    style={'description_width': 'initial'},
    layout=Layout(width='50%')
)

# Checkboxes for selecting which curves to plot
w_plot_susceptible = widgets.Checkbox(value=True, description='Plot Susceptible (S)')
w_plot_infected = widgets.Checkbox(value=True, description='Plot Infected (I)')
w_plot_recovered = widgets.Checkbox(value=True, description='Plot Recovered (R)')

# NOTE: despite the variable name, this calls `interact` (not ipywidgets'
# `interact_manual`); the name only holds interact's return value.
interact_manual = interact(
    plot_sensitivity_analysis,
    S_init=w_S_init,
    I_init=w_I_init,
    R_init=w_R_init,
    r_B=w_r_B,
    r_S=w_r_S,
    r_R=w_r_R,
    r_D=w_r_D,
    t_max=w_t_max,
    total_population=w_population,
    infection_rates=w_infection_rates,
    plot_susceptible=w_plot_susceptible,
    plot_infected=w_plot_infected,
    plot_recovered=w_plot_recovered
);

### SAIR model

In [None]:
from IPython import display
# Load the image at the relative path images/SIAR.jpg (presumably the SAIR
# model diagram); as the last expression in the cell it is rendered as the output.
display.Image("images/SIAR.jpg")

In [None]:
# Define the SAIR model with asymptomatic cases
def sair_model(t, y, r_B, r_S, r_I, r_A, r_RI, r_RA, r_DI, r_DA):
    """
    Right-hand side of the SAIR system for state y = (S, A, I, R).

    dS/dt = r_B + r_S*R - r_I*S*I - r_A*S*A
    dA/dt = r_A*S*A - r_RA*A - r_DA*A
    dI/dt = r_I*S*I - r_RI*I - r_DI*I
    dR/dt = r_RI*I + r_RA*A - r_S*R

    `t` is unused (autonomous system). Returns [dS/dt, dA/dt, dI/dt, dR/dt].
    """
    S, A, I, R = y

    # Flows that appear in more than one equation
    symptomatic_infections = r_I*S*I     # S -> I
    asymptomatic_infections = r_A*S*A    # S -> A
    waning = r_S*R                       # R -> S

    dSdt = r_B + waning - symptomatic_infections - asymptomatic_infections
    dAdt = asymptomatic_infections - r_RA*A - r_DA*A
    dIdt = symptomatic_infections - r_RI*I - r_DI*I
    dRdt = r_RI*I + r_RA*A - waning

    return [dSdt, dAdt, dIdt, dRdt]

# Function to solve and visualize the SAIR model
def plot_sair_model(S_init=0.99, A_init=0.005, I_init=0.005, R_init=0.0, 
                   r_B=0.03, r_S=0.01, r_I=0.5, r_A=0.3, 
                   r_RI=0.05, r_RA=0.07, r_DI=0.02, r_DA=0.005, 
                   t_max=100, total_population=1000):
    """
    Solve the SAIR model and plot S, A, I, R and A+I as head counts over time.

    Parameters:
    -----------
    S_init, A_init, I_init, R_init : float
        Initial compartment sizes; they are rescaled so they sum to 1
    r_B, r_S, r_I, r_A, r_RI, r_RA, r_DI, r_DA : float
        Rate constants passed through to sair_model
    t_max : float
        End of the integration interval (start is 0)
    total_population : int
        Factor that converts compartment fractions to numbers of individuals

    If the initial fractions do not sum to a positive value, or t_max is not
    positive, an error message is printed and nothing is plotted.
    """
    # Validate before normalizing: a zero sum would divide by zero
    total_frac = S_init + A_init + I_init + R_init
    if total_frac <= 0:
        print("Error: Initial fractions must sum to a positive value.")
        return
    if t_max <= 0:
        print("Error: Max time must be positive.")
        return

    # Ensure initial populations sum to 1 (normalized)
    S_init, A_init, I_init, R_init = S_init/total_frac, A_init/total_frac, I_init/total_frac, R_init/total_frac

    # Roughly one sample per time unit, but never fewer than two, so the
    # curves always span the whole interval (int(t_max) alone gives 0 or 1
    # samples for t_max < 2).
    t_span = (0, t_max)
    t_eval = np.linspace(0, t_max, max(int(t_max), 2))
    initial_conditions = [S_init, A_init, I_init, R_init]

    solution = solve_ivp(
        lambda t, y: sair_model(t, y, r_B, r_S, r_I, r_A, r_RI, r_RA, r_DI, r_DA), 
        t_span, 
        initial_conditions, 
        method='RK45', 
        t_eval=t_eval
    )

    t = solution.t

    # Convert fractions to numbers of individuals
    S_count = solution.y[0] * total_population
    A_count = solution.y[1] * total_population
    I_count = solution.y[2] * total_population
    R_count = solution.y[3] * total_population

    plt.figure(figsize=(10, 6))
    plt.plot(t, S_count, 'b-', linewidth=2, label='Susceptible (S)')
    plt.plot(t, A_count, 'm-', linewidth=2, label='Asymptomatic (A)')
    plt.plot(t, I_count, 'r-', linewidth=2, label='Symptomatic (I)')
    plt.plot(t, R_count, 'g-', linewidth=2, label='Recovered (R)')
    plt.plot(t, A_count + I_count, 'y--', linewidth=1.5, label='Total Infected (A+I)')

    plt.xlabel('Time (days)', fontsize=12)
    plt.ylabel('Number of Individuals', fontsize=12)
    plt.title('SAIR Model with Asymptomatic Cases', fontsize=14)
    plt.grid(True, alpha=0.3)
    plt.legend(fontsize=10)

    # Display equations on the plot
    eqn_text = "SAIR Model Equations:\n" \
              "dS/dt = r_B + r_S*R - r_I*S*I - r_A*S*A\n" \
              "dA/dt = r_A*S*A - r_RA*A - r_DA*A\n" \
              "dI/dt = r_I*S*I - r_RI*I - r_DI*I\n" \
              "dR/dt = r_RI*I + r_RA*A - r_S*R"
    plt.figtext(0.02, 0.02, eqn_text, fontsize=10, bbox=dict(facecolor='wheat', alpha=0.5))

    plt.tight_layout()
    plt.show()
# Create text input widgets for initial conditions and parameters
w_S_init = FloatText(value=0.99, description='S(0) fraction:')
w_A_init = FloatText(value=0.005, description='A(0) fraction:')
w_I_init = FloatText(value=0.005, description='I(0) fraction:')
w_R_init = FloatText(value=0.0, description='R(0) fraction:')

w_r_B = FloatText(value=0.03, description='r_B (birth):')
w_r_S = FloatText(value=0.01, description='r_S (suscept):')
w_r_I = FloatText(value=0.5, description='r_I (sympt infect):')
w_r_A = FloatText(value=0.3, description='r_A (asympt infect):')
w_r_RI = FloatText(value=0.05, description='r_RI (sympt recov):')
w_r_RA = FloatText(value=0.07, description='r_RA (asympt recov):')
w_r_DI = FloatText(value=0.02, description='r_DI (sympt death):')
w_r_DA = FloatText(value=0.005, description='r_DA (asympt death):')
w_t_max = FloatText(value=100, description='Max time (days):')
w_population = IntText(value=1000, description='Population:')

# Map each plot_sair_model argument to the widget that supplies it.
# The dict is deliberately NOT called `widgets`: that name is the ipywidgets
# module imported earlier in the notebook, and rebinding it to a dict makes
# any cell that uses `widgets.IntSlider(...)` etc. fail when it is re-run.
sair_widgets = {
    'S_init': w_S_init,
    'A_init': w_A_init,
    'I_init': w_I_init,
    'R_init': w_R_init,
    'r_B': w_r_B,
    'r_S': w_r_S,
    'r_I': w_r_I,
    'r_A': w_r_A,
    'r_RI': w_r_RI,
    'r_RA': w_r_RA,
    'r_DI': w_r_DI,
    'r_DA': w_r_DA,
    't_max': w_t_max,
    'total_population': w_population
}

# Display the interactive model
interact(plot_sair_model, **sair_widgets);
