In [1]:
import numpy as np
import pandas as pd
import joblib
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from scipy.integrate import solve_ivp           # only for the shortcut
from typing import List, Dict

# ---------------------- load artefacts -----------------------------
# joblib unpickles the file, which can execute code: only point this at a
# trusted, locally produced artefact.
with open("/content/drive/MyDrive/Capstone_Trial_Methodology/models/ensemble_model.pkl", "rb") as f:
    ensemble_model = joblib.load(f)


# PyTorch ODE -------------------------------------------------------
# Tensors and the ODE model below are placed on the GPU when one is visible.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Utils: per-class SHAP sign configuration

In [2]:
target_class = 2

# Feature order shared by every per-class sign row below.
_shap_feature_order = (
    'Pain Intensity', 'Pain Location_No Pain', 'Oxygen Saturation (%)',
    'Body Temperature (°C)', 'Pain Crises Frequency (past year)',
    'Respiratory Rate (bpm)', 'Heart Rate (bpm)', 'Acute Chest Syndrome',
    'Hospitalization History (past year)', 'Dizziness', 'Pain Location_Head',
    'Nausea', 'Fever', 'Jaundice', 'Shortness of Breath',
    'Swelling', 'Headache', 'Fatigue',
)

# Sign of each feature's SHAP contribution, one 18-entry row per class,
# aligned position-by-position with _shap_feature_order.
_shap_sign_rows = {
    0: ( 1, -1,  1,  1, -1,  1,  1,  1,  1, -1,  1, -1,  1, -1, -1, -1, -1, -1),
    1: ( 1,  1, -1, -1,  1, -1, -1, -1,  1, -1,  1, -1, -1, -1, -1, -1, -1, -1),
    2: (-1,  1,  1,  1, -1,  1,  1,  1, -1,  1, -1,  1,  1,  1,  1,  1,  1,  1),
}

# class id -> {feature name: +1 / -1}
sign_shap_all = {
    cls: dict(zip(_shap_feature_order, row))
    for cls, row in _shap_sign_rows.items()
}
del _shap_feature_order, _shap_sign_rows

# Pick the sign map for the class of interest (displayed as the cell output).
sign_shap = sign_shap_all[target_class]
sign_shap

{'Pain Intensity': -1,
 'Pain Location_No Pain': 1,
 'Oxygen Saturation (%)': 1,
 'Body Temperature (°C)': 1,
 'Pain Crises Frequency (past year)': -1,
 'Respiratory Rate (bpm)': 1,
 'Heart Rate (bpm)': 1,
 'Acute Chest Syndrome': 1,
 'Hospitalization History (past year)': -1,
 'Dizziness': 1,
 'Pain Location_Head': -1,
 'Nausea': 1,
 'Fever': 1,
 'Jaundice': 1,
 'Shortness of Breath': 1,
 'Swelling': 1,
 'Headache': 1,
 'Fatigue': 1}

# ODE Model

In [3]:
# Column order of the feature matrix fed to the ODE model: column i of its
# input is multiplied by element i of each β vector.
FEATURES = [
    "Pain Intensity",
    "Pain Location_No Pain",
    "Oxygen Saturation (%)",
    "Body Temperature (°C)",
    "Pain Crises Frequency (past year)",
    "Respiratory Rate (bpm)",
    "Heart Rate (bpm)",
    "Acute Chest Syndrome",
    "Hospitalization History (past year)",
    "Dizziness",
    "Pain Location_Head",
    "Nausea",
    "Fever",
    "Jaundice",
    "Shortness of Breath",
    "Swelling",
    "Headache",
    "Fatigue",
]

# SHAP signs of the selected class laid out in FEATURES order (a feature
# missing from sign_shap raises KeyError); also kept as a float32 tensor.
sign_vector = list(map(sign_shap.__getitem__, FEATURES))
sign_vec = torch.tensor(sign_vector, dtype=torch.float32, device=device)

In [4]:
# Number of input features the β vectors are meant for.
# NOTE(review): not referenced in the visible cells — the β length actually
# comes from len(sign_vector).
N_FEAT = 18
# Initial state of every trajectory: all probability mass in C0.
STATE0 = torch.tensor([1., 0., 0.], device=device)   # initial [C0,C1,C2]

class ODEModel(nn.Module):
    """
    Three-state continuous-time transition model integrated with fixed-step RK4.

    State columns are [C0, C1, C2]; predict_crisis_risk reads them as
    no / acute / chronic crisis. Transitions (one rate per patient row):
        lam01: C0 -> C1        lam02: C0 -> C2
        rho1 : C1 -> C0        rho2 : C2 -> C0
        kap  : C1 -> C2
    Each rate is exp(log_base[k] + sum_i x_i * beta_k[i]), with the exponent
    clipped to [-10, 10].

    Parameters are registered as ``log_base.<k>`` and ``beta_<k>``; these
    names must match the checkpoint restored with ``load_state_dict``.
    """

    def __init__(self, sign_vector):
        """
        sign_vector : per-feature signs (±1 in this notebook); its length sets
                      the length of every β vector and must equal the number
                      of columns of x passed to forward. Each β vector starts
                      at 0.01 * sign_vector.
        """
        super().__init__()
        # Initial base rates, stored as logs so exp() keeps them positive
        # (5 parameters; replaced by load_state_dict when a checkpoint is loaded).
        self.log_base = nn.ParameterDict({
            k: nn.Parameter(torch.log(torch.tensor(v, device=device)))
            for k, v in {
                'lam01': 0.02, 'lam02': 0.005,
                'rho1':  0.03, 'rho2':  0.01,
                'kap':   0.002
            }.items()
        })
        # One β vector per rate key (5 vectors of len(sign_vector)), registered
        # as plain parameters named beta_<key>.
        for k in self.log_base:
            init = 0.01 * torch.tensor(sign_vector, device=device, dtype=torch.float32)
            self.register_parameter(f"beta_{k}",
                                    nn.Parameter(init.clone()))

    # -----------------------------------------------
    # RK4 loop
    def forward(self, x, horizon_h):
        """
        Integrate from STATE0 for ``int(horizon_h)`` classic RK4 steps of dt = 1/24.

        x         : (B, n_features) tensor, one row per patient; the same x is
                    used at every step (features are held constant).
        horizon_h : number of steps; the fractional part is truncated, so
                    horizon_h < 1 returns STATE0 for every row.
        Returns a (B, 3) tensor whose rows are renormalised to sum to 1.

        NOTE(review): dt = 1/24 per step with one step per unit of horizon_h
        suggests rates are per day and horizon_h is in hours — confirm
        against how the model was trained.
        """
        dt = 1.0 / 24
        steps = int(horizon_h)
        C = STATE0.repeat(x.size(0), 1)

        for _ in range(steps):
            k1 = self._rhs(C, x)
            k2 = self._rhs(C + 0.5*dt*k1, x)
            k3 = self._rhs(C + 0.5*dt*k2, x)
            k4 = self._rhs(C + dt*k3, x)
            C = C + dt/6 * (k1 + 2*k2 + 2*k3 + k4)

            # ------- renormalise & clamp --------------------
            # Keep every entry in [1e-6, 1], then rescale each row to sum to 1.
            C = torch.clamp(C, 1e-6, 1.0)
            C = C / C.sum(dim=1, keepdim=True)

        return C

    # -----------------------------------------------
    def _rhs(self, C, x):
        """
        Right-hand side dC/dt for state C of shape (B, 3).

        Rates depend only on x and the parameters (not on C) but are
        recomputed on every call. The three derivatives sum to zero, so the
        system conserves total probability mass apart from the clamping.
        """
        rates = {}
        for k in self.log_base:
            beta = getattr(self, f"beta_{k}")                    # (18,)
            # ----- log-rate, clipped to ±10 so exp() cannot overflow ------
            log_rate = self.log_base[k] + (x * beta).sum(dim=1)
            log_rate = torch.clamp(log_rate, -10.0, 10.0)
            rates[k] = torch.exp(log_rate)                       # (B,)

        C = torch.clamp(C, 0.0, 1.0)                             # avoid drift <0
        C0, C1, C2 = C[:, 0], C[:, 1], C[:, 2]
        # Outflow from C0 to C1/C2, inflow back from C1 (rho1) and C2 (rho2).
        dC0 = -(rates['lam01']+rates['lam02'])*C0 + rates['rho1']*C1 + rates['rho2']*C2
        # C1 gains from C0, loses to C0 (rho1) and to C2 (kap).
        dC1 =  rates['lam01']*C0 - (rates['rho1']+rates['kap'])*C1
        # C2 gains from C0 and C1, loses back to C0.
        dC2 =  rates['lam02']*C0 + rates['kap']*C1 - rates['rho2']*C2
        return torch.stack([dC0, dC1, dC2], dim=1)

In [5]:
model_path = "/content/drive/MyDrive/Capstone_Trial_Methodology/models/ode_theta_torch.pt"

# Rebuild the architecture first (the checkpoint is a state_dict keyed by the
# parameter names ODEModel registers), then restore the weights onto `device`.
ode_model = ODEModel(sign_vector)
ode_model.to(device)
ode_model.load_state_dict(torch.load(model_path, map_location=device))
# Switch to inference mode; as the last expression the module repr is echoed.
ode_model.eval()

ODEModel(
  (log_base): ParameterDict(
      (kap): Parameter containing: [torch.FloatTensor of size ]
      (lam01): Parameter containing: [torch.FloatTensor of size ]
      (lam02): Parameter containing: [torch.FloatTensor of size ]
      (rho1): Parameter containing: [torch.FloatTensor of size ]
      (rho2): Parameter containing: [torch.FloatTensor of size ]
  )
)

# Prediction Engines

In [6]:
# ------------------------------------------------------------------
# 1.  ENGINE A – survival shortcut
# ------------------------------------------------------------------
def survival_shortcut(prob_class, delta_t_hours, T_ref=24.0, eps=1e-12):
    """
    Rescale an event probability observed over ``T_ref`` hours to another
    horizon, assuming a constant hazard (exponential survival).

    prob_class    : float or ndarray, probability of the event within T_ref
                    (e.g. today's ensemble probability). Clipped to
                    [0, 1 - eps] so the logarithm stays finite.
    delta_t_hours : float or ndarray, target horizon in hours.
    T_ref         : reference horizon (hours) over which prob_class was
                    estimated; 24 h by default.
    eps           : margin that keeps 1 - prob_class strictly positive.

    Returns P(event within delta_t_hours) = 1 - (1 - p) ** (delta_t_hours / T_ref).
    """
    # Unclipped, p == 1 gives log(0) = -inf, and then inf * 0 = nan at
    # delta_t_hours == 0; p > 1 gives the log of a negative number (nan).
    p = np.clip(prob_class, 0.0, 1.0 - eps)
    # log1p / expm1 stay accurate for small probabilities and short horizons.
    hazard = -np.log1p(-p) / T_ref
    return -np.expm1(-hazard * delta_t_hours)

In [7]:
# ------------------------------------------------------------------
# 2.  ENGINE B – batched ODE on GPU
# ------------------------------------------------------------------
@torch.no_grad()
def ode_probabilities_torch(x_scaled: np.ndarray,
                            horizon_h: float,
                            batch: int = 256) -> np.ndarray:
    """
    Run the global ``ode_model`` on every row of ``x_scaled`` in mini-batches.

    x_scaled  : (N, n_features) array-like, already scaled; converted to
                float32. A single 1-D row is accepted and treated as N = 1.
    horizon_h : horizon handed unchanged to ode_model.
    batch     : number of rows per ode_model call.
    returns   : (N, 3) numpy array of state probabilities at horizon_h;
                an empty (0, 3) array when N == 0.
    """
    x = np.asarray(x_scaled, dtype=np.float32)
    if x.ndim == 1:
        x = x.reshape(1, -1)
    if x.shape[0] == 0:
        # torch.cat would raise on an empty list of batches.
        return np.empty((0, 3), dtype=np.float32)

    x_t = torch.as_tensor(x, device=device)
    # torch.split yields the same contiguous, in-order row chunks as a
    # non-shuffled DataLoader, without collating the tensor row by row.
    out = [ode_model(xb, horizon_h).cpu() for xb in torch.split(x_t, batch)]
    return torch.cat(out, dim=0).numpy()

In [8]:
# ------------------------------------------------------------------
# 3.  Public API
# ------------------------------------------------------------------

def predict_crisis_risk(patient_row: "pd.Series",
                        horizons: List[int] = [1, 6, 12, 24],
                        use_ode: bool = True,
                        scaler=None) -> Dict[int, Dict[str, float]]:
    """
    Return P(no / acute / chronic crisis) for one patient at each horizon.

    Parameters
    ----------
    patient_row : pd.Series or pd.DataFrame
        One patient containing every column in FEATURES. A Series is
        converted to a one-row DataFrame; for a DataFrame only the first
        row is used.
    horizons : list of int
        Forecast horizons in hours (read only, never mutated).
    use_ode : bool
        True  -> integrate the ODE model via ode_probabilities_torch.
        False -> extrapolate the ensemble probabilities with survival_shortcut.
    scaler : object with a ``transform`` method, optional
        Applied to the FEATURES matrix before it is given to the ODE, whose
        engine documents its input as already scaled. With the default None
        the raw feature values are passed, as before.

    Returns
    -------
    dict
        {horizon: {"no_crisis": p0, "acute_crisis": p1, "chronic_crisis": p2}},
        each probability rounded to 7 decimals.
    """
    # A Series has no row axis: make it a one-row frame so both the feature
    # slice and predict_proba receive 2-D input. infer_objects() restores the
    # numeric/bool dtypes that the transpose turns into object.
    if isinstance(patient_row, pd.Series):
        patient_df = patient_row.to_frame().T.infer_objects()
    else:
        patient_df = patient_row

    # 1. Feature matrix (1, n_features) in FEATURES order, float32 for torch.
    x_raw = patient_df[FEATURES].values.astype(np.float32)[:1]
    # NOTE(review): without a scaler the ODE receives raw values (e.g. heart
    # rate 110, SpO2 93); the near-saturated ODE forecasts this notebook
    # prints are consistent with unscaled input — confirm the training scaler.
    x_ode = x_raw if scaler is None else np.asarray(scaler.transform(x_raw), dtype=np.float32)

    # 2. Ensemble probabilities are needed only by the shortcut engine and
    #    do not depend on the horizon, so compute them once, on demand.
    p_static = None if use_ode else ensemble_model.predict_proba(patient_df)[0]   # (3,)

    results = {}
    for H in horizons:
        if use_ode:
            probs = ode_probabilities_torch(x_ode, H)[0]     # (3,)
        else:
            probs = np.zeros(3, dtype=np.float32)
            for cls in (1, 2):                                  # acute / chronic
                probs[cls] = survival_shortcut(p_static[cls], H)
            # Each class curve is extrapolated independently, so beyond the
            # shortcut's reference horizon their sum can exceed 1; rescale so
            # the three entries remain a probability distribution.
            crisis_total = probs[1] + probs[2]
            if crisis_total > 1.0:
                probs[1:] /= crisis_total
            probs[0] = 1.0 - probs[1] - probs[2]
            probs = np.clip(probs, 0.0, 1.0)

        results[H] = {
            "no_crisis":       round(float(probs[0]), 7),
            "acute_crisis":    round(float(probs[1]), 7),
            "chronic_crisis":  round(float(probs[2]), 7),
        }
    return results

In [9]:
# One illustrative patient, written as a single record (name -> value) and
# wrapped in a one-row DataFrame with the default RangeIndex.
sample_patient = pd.DataFrame([{
    "Pain Intensity":                      7.5,    # 0–10 VAS
    "Pain Location_No Pain":               False,  # boolean
    "Oxygen Saturation (%)":               93.0,   # %
    "Body Temperature (°C)":               38.2,   # febrile
    "Pain Crises Frequency (past year)":   4,      # count
    "Respiratory Rate (bpm)":              24.0,   # breaths/min
    "Heart Rate (bpm)":                    110.0,  # beats/min
    "Acute Chest Syndrome":                0.0,    # 0 = no history
    "Hospitalization History (past year)": 2.0,    # admissions
    "Dizziness":                           1.0,    # symptom score 0/1
    "Pain Location_Head":                  True,   # boolean
    "Nausea":                              1.0,
    "Fever":                               1.0,
    "Jaundice":                            0.0,
    "Shortness of Breath":                 1.0,
    "Swelling":                            0.0,
    "Headache":                            1.0,
    "Fatigue":                             1.0,
}])

In [10]:
# One row per horizon (hours), one column per outcome.
# NOTE(review): predict_crisis_risk hands the ODE unscaled features by
# default; a near-certain acute crisis already at 1 h suggests that — confirm.
print("ODE-based forecast:")
print(pd.DataFrame.from_dict(
    predict_crisis_risk(sample_patient, horizons=[1, 6, 12, 24], use_ode=True),
    orient="index",
))

ODE-based forecast:
    no_crisis  acute_crisis  chronic_crisis
1    0.000001      0.999997        0.000002
6    0.000001      0.999988        0.000011
12   0.000001      0.999976        0.000023
24   0.000001      0.999954        0.000045


In [11]:
# Same table from the constant-hazard shortcut applied to the ensemble output.
print("\nExponential shortcut forecast:")
print(pd.DataFrame.from_dict(
    predict_crisis_risk(sample_patient, horizons=[1, 6, 12, 24], use_ode=False),
    orient="index",
))


Exponential shortcut forecast:
    no_crisis  acute_crisis  chronic_crisis
1    0.917910      0.075187        0.006903
6    0.584926      0.374361        0.040713
12   0.311656      0.608576        0.079768
24   0.000039      0.846787        0.153173


In [12]:
# Snapshot the installed packages (pip freeze) into requirements.txt in the
# current working directory. `!` is an IPython shell escape, so this line only
# runs inside Jupyter/IPython.
# NOTE(review): `%pip freeze` targets the kernel's own environment, whereas
# `!pip` uses whichever pip is first on the shell PATH — confirm they match.
!pip freeze > requirements.txt