# SCI ML Project
### Antonio Jimenez aoj268
### Ashton Cole 

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import random
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'

In [None]:
class PINN(nn.Module):
    """Fully connected tanh network mapping time ``t`` to SEIRD fractions.

    The network outputs 5 components ``(s, e, i, r, d)``. They pass through a
    softmax over the last dimension, so every component is strictly positive
    and the components sum to 1.

    Args:
        hidden_size: Width of every hidden layer.
        depth: Number of hidden ``Linear + Tanh`` layers (values below 1
            behave like 1).
    """

    def __init__(self, hidden_size, depth):
        super().__init__()
        # Input layer: a single scalar feature, the time t.
        layers = [nn.Linear(1, hidden_size), nn.Tanh()]
        for _ in range(depth - 1):
            layers += [nn.Linear(hidden_size, hidden_size), nn.Tanh()]
        # Output layer with 5 units for (s, e, i, r, d). Softmax over the last
        # dimension enforces positivity and that the compartments sum to 1.
        # dim=-1 (instead of dim=1) also works for unbatched (1,) inputs and
        # for extra leading batch dimensions.
        layers += [nn.Linear(hidden_size, 5), nn.Softmax(dim=-1)]

        self.net = nn.Sequential(*layers)

    def forward(self, t):
        """Map times ``t`` of shape ``(..., 1)`` to fractions of shape ``(..., 5)``."""
        return self.net(t)

In [None]:
def caputo_l1_diff(psi, alpha, dt):
    """Approximate the Caputo fractional derivative of order ``alpha`` (L1 scheme).

    For samples ``psi_j = psi(t_j)`` on a uniform grid with spacing ``dt``::

        D^alpha psi(t_i) ~= 1 / (dt^alpha * Gamma(2 - alpha))
                            * sum_{k=0}^{i-1} b_k * (psi_{i-k} - psi_{i-k-1}),
        b_k = (k + 1)^(1 - alpha) - k^(1 - alpha).

    For ``alpha == 1`` this reduces to the backward difference.

    Args:
        psi: Samples of the function, shape ``(n,)`` or ``(n, 1)``.
        alpha: Fractional order, a float or a scalar/one-element tensor,
            intended to lie in (0, 1]. Gradients flow to ``alpha`` when it is
            a tensor that requires grad.
        dt: Uniform grid spacing.

    Returns:
        Tensor of shape ``(n, 1)``. Entry 0 is set to 0 because the scheme
        needs at least one previous sample.
    """
    psi = psi.reshape(-1)
    n = psi.shape[0]
    alpha = torch.as_tensor(alpha, dtype=psi.dtype, device=psi.device)
    if n < 2:
        return psi.new_zeros(n, 1)

    # j^(1 - alpha) for j = 0..n-1. The j = 0 term is an explicit 0 rather than
    # 0**(1 - alpha). torch evaluates 0**0 as 1, which would zero every weight
    # at alpha == 1. An explicit 0 also avoids differentiating pow at a zero base.
    j = torch.arange(1, n, dtype=psi.dtype, device=psi.device)
    powers = torch.cat([psi.new_zeros(1), j ** (1.0 - alpha)])
    weights = powers[1:] - powers[:-1]  # b_k for k = 0..n-2 (computed once)

    deltas = psi[1:] - psi[:-1]  # deltas[m] = psi_{m+1} - psi_m

    # Lower-triangular Toeplitz matrix W[r, c] = b_{r-c}. Row r is time step
    # i = r + 1; column c is the increment psi_{c+1} - psi_c. One matmul
    # replaces the per-step loop (costs O(n^2) memory for W).
    steps = torch.arange(n - 1, device=psi.device)
    lag = steps.unsqueeze(1) - steps.unsqueeze(0)
    toeplitz = torch.tril(weights[lag.clamp(min=0)])
    summation = toeplitz @ deltas

    scale = 1.0 / (dt ** alpha * torch.exp(torch.lgamma(2.0 - alpha)))
    derivs = torch.cat([psi.new_zeros(1), (scale * summation).reshape(-1)])
    return derivs.unsqueeze(1)

In [None]:
class FractionalSEIRD(nn.Module):
    """Physics-informed model of a fractional-order SEIRD epidemic system.

    A ``PINN`` maps time ``t`` to compartment fractions ``(s, e, i, r, d)``.
    The rates (beta, sigma, gamma, mu) and the fractional order alpha are
    trainable parameters learned jointly with the network.

    Args:
        hidden_size: Width of each hidden layer of the PINN.
        depth: Number of hidden layers of the PINN.
        initial_params: Dict with the following keys.
            ``'beta'``, ``'sigma'``, ``'gamma'``, ``'mu'``: stored as-is as
            the raw (pre-softplus) parameters, so each effective initial rate
            is ``softplus(value)``.
            ``'z_alpha'``: raw (pre-sigmoid) value for alpha.
            ``'min_alpha'``: lower bound for alpha.
            ``'dt'``: collocation grid spacing.
            Optional ``'weight_ic'``, ``'weight_data'``, ``'weight_phys'``:
            loss weights, each defaulting to 1.0.
    """

    def __init__(self, hidden_size, depth, initial_params):
        super().__init__()

        self.pinn = PINN(hidden_size, depth)
        # Trainable params, stored unconstrained. float() ensures a floating
        # Parameter: an integer input would produce an integer tensor, and
        # integer tensors cannot require gradients.
        self.raw_beta = nn.Parameter(torch.tensor([float(initial_params['beta'])]))
        self.raw_sigma = nn.Parameter(torch.tensor([float(initial_params['sigma'])]))
        self.raw_gamma = nn.Parameter(torch.tensor([float(initial_params['gamma'])]))
        self.raw_mu = nn.Parameter(torch.tensor([float(initial_params['mu'])]))
        # A large z_alpha starts alpha close to 1 (e.g. sigmoid(2.94) ~= 0.95).
        self.z_alpha = nn.Parameter(torch.tensor([float(initial_params['z_alpha'])]))

        # Loss weights are stored on the instance so compute_loss can read them.
        self.weight_ic = float(initial_params.get('weight_ic', 1.0))
        self.weight_data = float(initial_params.get('weight_data', 1.0))
        self.weight_phys = float(initial_params.get('weight_phys', 1.0))

        self.min_alpha = initial_params['min_alpha']
        self.dt = initial_params['dt']

    def beta(self):
        """Return the transmission rate ``softplus(raw_beta)`` (always > 0)."""
        return nn.functional.softplus(self.raw_beta)

    def sigma(self):
        """Return the incubation rate ``softplus(raw_sigma)`` (always > 0)."""
        return nn.functional.softplus(self.raw_sigma)

    def gamma(self):
        """Return the recovery rate ``softplus(raw_gamma)`` (always > 0)."""
        return nn.functional.softplus(self.raw_gamma)

    def mu(self):
        """Return the death rate ``softplus(raw_mu)`` (always > 0)."""
        return nn.functional.softplus(self.raw_mu)

    def alpha(self):
        """Return the fractional order, restricted to ``(min_alpha, 1.0]``."""
        return self.min_alpha + (1.0 - self.min_alpha) * torch.sigmoid(self.z_alpha)

    def forward(self, t):
        """Return the PINN prediction of ``(s, e, i, r, d)`` at times ``t``."""
        return self.pinn(t)

    def compute_loss(self, t_colloc, t_data, y_data, ic):
        """Return the weighted sum of the initial-condition, data and physics losses.

        Args:
            t_colloc: Collocation times, presumably shape ``(N, 1)``. The
                Caputo L1 scheme assumes they are sorted and evenly spaced by
                ``self.dt`` (TODO confirm at call site). ``t_colloc[0]`` is t_0.
            t_data: Observation times, same layout as ``t_colloc``.
            y_data: Observed compartments, same shape as ``self(t_data)``.
            ic: Initial state ``(s, e, i, r, d)`` with 5 elements. It is
                reshaped to match the prediction at t_0.

        Returns:
            Scalar loss tensor.
        """
        mse = nn.functional.mse_loss

        # IC loss
        t_initial = t_colloc[0].unsqueeze(0)  # get t_0
        y_initial_pred = self(t_initial)
        ic = torch.as_tensor(ic, dtype=y_initial_pred.dtype, device=y_initial_pred.device)
        loss_ic = mse(y_initial_pred, ic.reshape_as(y_initial_pred))

        # Data loss
        loss_data = mse(self(t_data), y_data)

        # Physics loss
        y_all_pred = self(t_colloc)
        # split (not unbind) keeps each compartment as an (N, 1) column. It then
        # lines up with the (N, 1) derivatives instead of broadcasting to (N, N).
        s, e, i, r, d = y_all_pred.split(1, dim=1)
        # Evaluate alpha (the original passed the bound method). Reshape to
        # 0-d so the derivative helper yields one scalar per time step.
        alpha = self.alpha().reshape(())
        ds_dt = caputo_l1_diff(s, alpha, self.dt)
        de_dt = caputo_l1_diff(e, alpha, self.dt)
        di_dt = caputo_l1_diff(i, alpha, self.dt)
        dr_dt = caputo_l1_diff(r, alpha, self.dt)
        dd_dt = caputo_l1_diff(d, alpha, self.dt)

        # RHS of the SEIRD system (equation 4 of the reference model)
        beta, sigma, gamma, mu = self.beta(), self.sigma(), self.gamma(), self.mu()
        num_living = 1 - d
        infection = beta * s * i / num_living
        f_s = -infection
        f_e = infection - sigma * e
        f_i = sigma * e - (gamma + mu) * i
        f_r = gamma * i
        f_d = mu * i

        # Residuals (LHS - RHS = 0), one column per compartment
        all_residuals = torch.cat(
            [ds_dt - f_s, de_dt - f_e, di_dt - f_i, dr_dt - f_r, dd_dt - f_d],
            dim=1,
        )
        loss_phys = torch.mean(all_residuals ** 2)

        # Conservation is enforced by the PINN's softmax; any regularization
        # is left to the optimizer.
        return (
            self.weight_ic * loss_ic
            + self.weight_data * loss_data
            + self.weight_phys * loss_phys
        )