In [None]:
!pip install numpy scipy pandas scikit-learn matplotlib seaborn torch hyperopt  copulae shapely tqdm online-conformal

In [None]:
!pip install torchdiffeq
!pip install epiweeks

In [None]:
# Colab-only: mount Google Drive at /content/drive so files stored there can be read below.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os
import sys

# Directory on the mounted Drive that holds the project's Python modules.
project_path = '/content/drive/MyDrive/cp-trajectory-master/'

if os.path.exists(project_path):
    # Make the project importable, then list its contents as a quick sanity check.
    sys.path.append(project_path)
    print(f"added {project_path} to system path.")
    print("Files found:", os.listdir(project_path))
else:
    print(f"ERROR: Path not found: {project_path}")

In [None]:
from torch import optim
import pandas as pd
import math
import torch
import torch.nn as nn
import numpy as np
from torch.nn.parameter import Parameter
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import scipy.integrate
solve_ivp = scipy.integrate.solve_ivp
from torchdiffeq import odeint
from IPython.display import Image, display

import cp_ours
import score_func
import core.eval.eval_utils as eval
from cp_ours import run_aci_simulation
from score_func import score_function
from cp_ours import collect_scores

In [None]:
class SIR(nn.Module):
    """
    SIR compartmental model whose forward pass is the ODE right-hand side.

    The transmission rate ``beta`` and recovery rate ``gamma`` are stored as
    unconstrained real parameters and mapped into (0, 1) with
    ``0.5 * (tanh(x) + 1)``, so optimisation can never push a rate outside
    that interval.

    Args:
        N (int): Total population size.
        beta_init (float): Initial transmission rate, expected in (0, 1).
        gamma_init (float): Initial recovery rate, expected in (0, 1).
    """

    # Margin that keeps arctanh finite when an initial rate is exactly 0 or 1.
    _INIT_EPS = 1e-6

    def __init__(self, N, beta_init, gamma_init):
        super().__init__()

        self.N = N
        self.beta = Parameter(self._to_unconstrained(beta_init))
        self.gamma = Parameter(self._to_unconstrained(gamma_init))

    @classmethod
    def _to_unconstrained(cls, rate):
        """Invert ``0.5 * (tanh(x) + 1)`` for ``rate``, clipping so the result stays finite."""
        squashed = np.clip(2.0 * rate - 1.0, -1.0 + cls._INIT_EPS, 1.0 - cls._INIT_EPS)
        return torch.tensor(np.arctanh(squashed), dtype=torch.float32)

    def get_scaled_params(self, convert_cpu=False):
        """
        Map the unconstrained parameters to their scaled values in (0, 1).

        Args:
            convert_cpu (bool): If True, detach every value and return plain
                Python floats (useful for logging or plotting).

        Returns:
            dict: ``'beta'``, ``'gamma'`` and ``'Rt'`` (computed as
            ``beta / (gamma + 1e-8)``).
        """
        params = {
            'beta': 0.5 * (torch.tanh(self.beta) + 1),
            'gamma': 0.5 * (torch.tanh(self.gamma) + 1),
        }
        # The small constant guards against division by zero.
        params['Rt'] = params['beta'] / (params['gamma'] + 1e-8)

        if convert_cpu:
            for k, v in params.items():
                if torch.is_tensor(v):
                    params[k] = v.detach().cpu().item()
        return params

    def forward(self, t, state):
        """
        Compute the SIR derivatives at the given state.

        Args:
            t (float): Current time (unused; the system is autonomous).
            state (Tensor): Current state values, indexed as (S, I, R).

        Returns:
            Tensor: Stacked derivatives ``[dS/dt, dI/dt, dR/dt]``.
        """
        params = self.get_scaled_params()

        S = state[0]
        I = state[1]

        new_infections = params['beta'] * S * I / self.N
        recoveries = params['gamma'] * I

        dS_dt = -new_infections
        dI_dt = new_infections - recoveries
        dR_dt = recoveries

        return torch.stack([dS_dt, dI_dt, dR_dt], 0)

In [None]:
class InitialConditions(nn.Module):
    """
    Learnable initial state of the SIR system.

    ``I0`` and ``R0`` are trainable scalars; ``S0`` is derived from them so the
    three compartments add up to ``N``.

    Args:
        N (int): Total population size.
        I0_init (float): Initial infectious population.
        R0_init (float): Initial recovered population.
    """
    def __init__(self, N, I0_init, R0_init):
        super().__init__()
        self.N = N

        def as_parameter(value):
            return Parameter(torch.tensor(value, dtype=torch.float32))

        self.I0 = as_parameter(I0_init)
        self.R0 = as_parameter(R0_init)

    def forward(self):
        """
        Assemble the initial state vector.

        Returns:
            Tensor: Initial state values ``[S0, I0, R0]`` with
            ``S0 = N - I0 - R0``.
        """
        susceptible = self.N - self.I0 - self.R0
        return torch.stack((susceptible, self.I0, self.R0), dim=0)

In [None]:

# Default epidemiological constants, initial compartment sizes and population size.
# NOTE(review): beta, Rt, sigma and E0 do not appear to be read by the SIR cells
# visible in this notebook — confirm before removing them.
beta = 0.33
Rt = 1.19
sigma = 0.48
E0, I0, R0 = 600, 800, 1e4
POP_SIZE = 372258 # population size


# Learnable SIR initial conditions built from I0 and R0 (E0 is not passed).
init_conditions = InitialConditions(POP_SIZE, I0, R0)

In [None]:
# Synthetic data: pick the compute device and define the ODE-based SIR simulator.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def simulate_sir(pop_size, beta, gamma, I0, R0, T):
    """
    Integrate the SIR model on the CPU with a fixed-step RK4 solver.

    Args:
        pop_size: Total population size.
        beta: Transmission rate used to initialise the model.
        gamma: Recovery rate used to initialise the model.
        I0: Initial infectious population.
        R0: Initial recovered population.
        T: Number of time points; the grid is ``torch.linspace(0, T, T)``.

    Returns:
        The ``odeint`` solution for the state ``[S, I, R]`` evaluated on the
        time grid (still attached to the autograd graph).
    """
    dynamics = SIR(pop_size, beta, gamma).cpu()
    start_state = InitialConditions(pop_size, I0, R0).cpu()()
    grid = torch.linspace(0, T, T)

    return odeint(dynamics, start_state, grid, method="rk4")


# build dataset
def build_dataset(n=1000, T=200, pop_size=POP_SIZE):
    """
    Simulate ``n`` SIR trajectories with uniformly sampled rates.

    NOTE(review): a later cell defines another ``build_dataset`` with a
    different return signature; when the notebook is run top to bottom that
    definition replaces this one.

    Args:
        n: Number of trajectories to generate.
        T: Number of time points per trajectory.
        pop_size: Population size used for simulation and normalisation.

    Returns:
        tuple: ``(X, Y)``, both of shape ``(n, T, 3)``. Each row of ``X`` is
        ``[beta, gamma, normalised time]``; ``Y`` holds the simulated states
        divided by ``pop_size``.
    """
    X = torch.empty(n, T, 3, dtype=torch.float32)
    Y = torch.empty(n, T, 3, dtype=torch.float32)

    # The normalised time column and the initial compartments are the same for every run.
    times = (torch.linspace(0, T, T) / T).float().view(T, 1)
    I0, R0 = 800, 1e4

    for i in range(n):
        if i % 100 == 0:
            print(f"Generated {i}/{n} trajectories")

        beta = torch.empty(1).uniform_(0.1, 0.7).item()
        gamma = torch.empty(1).uniform_(0.05, 0.5).item()

        # Simulate, then express the compartments as population fractions.
        trajectory = simulate_sir(pop_size, beta, gamma, I0, R0, T).detach() / pop_size
        rate_columns = torch.tensor([beta, gamma]).repeat(T, 1)

        X[i] = torch.cat([rate_columns, times], dim=1)
        Y[i] = trajectory

    return X, Y


In [None]:
import torch
import torch.nn as nn
import math

class SIRTransformer(nn.Module):
    """
    Transformer surrogate mapping ``[beta, gamma, t]`` rows to three fractions.

    Each input row is augmented with ``K`` sine/cosine harmonics of ``t``,
    projected to ``hidden`` features, passed through a Transformer encoder and
    turned into three values that a softmax makes sum to one.

    Args:
        hidden: Width of the token embedding.
        nhead: Number of attention heads.
        num_layers: Number of encoder layers.
        K: Number of time harmonics.
        seq_len: Sequence length used to regroup flat inputs (see ``forward``).
    """

    def __init__(self, hidden=128, nhead=4, num_layers=4, K=8, seq_len=200):
        super().__init__()
        self.K = K
        self.hidden = hidden
        self.seq_len = seq_len

        # Two rate columns, the raw time column, and a sin/cos pair per harmonic.
        n_features = 2 + (2 * K) + 1

        self.input_projection = nn.Sequential(
            nn.Linear(n_features, hidden),
            nn.GELU(),
            nn.Linear(hidden, hidden),
            nn.Dropout(0.1),
        )

        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=hidden,
                nhead=nhead,
                dim_feedforward=hidden * 4,
                batch_first=True,
                dropout=0.1,
                activation='gelu',
                norm_first=True,
            ),
            num_layers=num_layers,
        )

        self.out_head = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.GELU(),
            nn.Linear(hidden, 3),
        )

    def time_features(self, t):
        """Return ``sin``/``cos`` of ``2*pi*k*t`` for ``k = 1..K``, concatenated on the last axis."""
        harmonics = [
            trig(2 * torch.pi * k * t)
            for k in range(1, self.K + 1)
            for trig in (torch.sin, torch.cos)
        ]
        return torch.cat(harmonics, dim=-1)

    def forward(self, x):
        """
        Predict three fractions for every input row.

        Args:
            x: Flat tensor whose rows are ``[beta, gamma, t]``.

        Returns:
            Tensor of shape ``(x.size(0), 3)``; each row sums to one.
        """
        total_points = x.size(0)

        # Rows are regrouped into full sequences when their count is a multiple of
        # seq_len; otherwise every row is treated as its own length-1 sequence.
        steps = self.seq_len if total_points % self.seq_len == 0 else 1
        tokens = x.view(total_points // steps, steps, 3)

        rates, t = tokens[..., :2], tokens[..., 2:]
        embedded = self.input_projection(
            torch.cat([rates, t, self.time_features(t)], dim=-1)
        )

        logits = self.out_head(self.transformer(embedded))
        return logits.softmax(dim=-1).reshape(total_points, 3)

In [None]:
def _simulate_normalized(pop_size, beta, gamma, I0, R0, T, times):
    """Simulate one SIR run; return its (T, 3) model input and the states divided by pop_size."""
    states = simulate_sir(pop_size, beta, gamma, I0, R0, T).detach() / pop_size
    params = torch.tensor([beta, gamma]).repeat(T, 1)
    return torch.cat([params, times], dim=1), states


def _build_ensemble_split(n_samples, label, n_members, T, pop_size, I0, R0, times,
                          beta_range=(0.6, 0.7), gamma_range=(0.4, 0.5),
                          delta_b=0.05, delta_g=0.05):
    """Build one calibration/test split: per sample, n_members candidate runs plus a ground truth at index 0."""
    X_split = torch.empty(n_samples, n_members + 1, T, 3, dtype=torch.float32)
    Y_split = torch.empty(n_samples, n_members + 1, T, 3, dtype=torch.float32)

    for i in range(n_samples):
        if i % 100 == 0:
            print(f"Generated {i}/{n_samples}  {label} trajectories")

        # Centre of this sample's parameter window, clipped to the global range.
        base_beta = np.random.uniform(*beta_range)
        base_gamma = np.random.uniform(*gamma_range)

        lower_b = max(beta_range[0], base_beta - delta_b)
        upper_b = min(beta_range[1], base_beta + delta_b)
        lower_g = max(gamma_range[0], base_gamma - delta_g)
        upper_g = min(gamma_range[1], base_gamma + delta_g)

        candidate_betas = []
        candidate_gammas = []
        for j in range(n_members):
            gamma = np.random.uniform(lower_g, upper_g)
            beta = np.random.uniform(lower_b, upper_b)
            candidate_betas.append(beta)
            candidate_gammas.append(gamma)
            X_split[i, j + 1], Y_split[i, j + 1] = _simulate_normalized(
                pop_size, beta, gamma, I0, R0, T, times
            )

        # The ground truth is simulated from the mean of the candidate parameters.
        beta_true = float(np.mean(candidate_betas))
        gamma_true = float(np.mean(candidate_gammas))
        X_split[i, 0], Y_split[i, 0] = _simulate_normalized(
            pop_size, beta_true, gamma_true, I0, R0, T, times
        )

    return X_split, Y_split


def build_dataset(n=1000, T=200, pop_size=POP_SIZE, n_cal = 100, n_test = 100, n_members=10):
    """
    Generate the training set and the calibration/test ensemble sets.

    Training trajectories are split evenly between an outbreak regime
    (``beta > gamma``) and a decay regime (``beta < gamma``). Calibration and
    test samples each hold ``n_members`` candidate trajectories at indices
    ``1..n_members`` and a ground-truth trajectory at index 0, all drawn with
    the same window half-width of 0.05 around a random centre.

    Args:
        n: Number of training trajectories.
        T: Number of time points per trajectory.
        pop_size: Population size used for simulation and normalisation.
        n_cal: Number of calibration samples.
        n_test: Number of test samples.
        n_members: Candidate trajectories per calibration/test sample.

    Returns:
        tuple: ``(X, Y, X_cal, Y_cal, X_test, Y_test)``. ``X``/``Y`` have shape
        ``(n, T, 3)``; the calibration/test tensors have shape
        ``(n_split, n_members + 1, T, 3)``. Every set is shuffled along its
        first axis.

    Raises:
        ValueError: If ``n_members`` is smaller than 1.
    """
    if n_members < 1:
        raise ValueError("n_members must be at least 1")

    # Normalised time column shared by every trajectory; computed once instead of
    # relying on a variable leaked from an inner loop.
    times = (torch.linspace(0, T, T) / T).float().view(T, 1)
    I0, R0 = 800, 1e4

    X_cal, Y_cal = _build_ensemble_split(n_cal, "cal", n_members, T, pop_size, I0, R0, times)
    X_test, Y_test = _build_ensemble_split(n_test, "test", n_members, T, pop_size, I0, R0, times)

    X = torch.empty(n, T, 3, dtype=torch.float32)
    Y = torch.empty(n, T, 3, dtype=torch.float32)

    for i in range(n):
        if i % 100 == 0:
            print(f"Generated {i}/{n} trajectories")

        # First half: outbreak (beta above gamma); second half: decay (beta below gamma).
        if i < n // 2:
            gamma = np.random.uniform(0.1, 0.4)
            beta = np.random.uniform(gamma + 0.05, 0.7)
        else:
            gamma = np.random.uniform(0.2, 0.6)
            beta = np.random.uniform(0.1, gamma - 0.05)

        X[i], Y[i] = _simulate_normalized(pop_size, beta, gamma, I0, R0, T, times)

    indices = torch.randperm(n)
    indices_cal = torch.randperm(n_cal)
    indices_test = torch.randperm(n_test)
    return X[indices], Y[indices], X_cal[indices_cal], Y_cal[indices_cal], X_test[indices_test], Y_test[indices_test]

In [None]:
train_X, train_Y, X_cal, Y_cal, X_test, Y_test = build_dataset(n=2000, T=200, n_cal=100, n_test=10)

# Report the shape of every split; labels are left-padded to nine characters so the shapes line up.
print("Dataset made:")
for label, split in (
    ("train_X:", train_X),
    ("train_Y:", train_Y),
    ("X_cal:", X_cal),
    ("Y_cal:", Y_cal),
    ("X_test:", X_test),
    ("Y_test:", Y_test),
):
    print(f"  {label:<9}{split.shape}")

In [None]:
from torch.utils.data import TensorDataset, DataLoader

# Hold out the last 20% of the trajectories (already shuffled by build_dataset) for validation.
N = train_X.shape[0]
train_N = int(0.8 * N)
val_X, val_Y = train_X[train_N:], train_Y[train_N:]

# Train on the first 80% only: feeding the full tensors to the loader would put the
# validation trajectories into training and make the validation loss meaningless.
train_loader = DataLoader(
    TensorDataset(train_X[:train_N], train_Y[:train_N]),
    batch_size=64,
    shuffle=True,
    drop_last=False
)

In [None]:
# Sanity check: print the tensor shapes of the first mini-batch, then stop iterating.
for xb, yb in train_loader:
    print("xb shape:", xb.shape)
    print("yb shape:", yb.shape)
    break

In [None]:
class MSLELoss(nn.Module):
    """
    Mean squared error between the logarithms of prediction and target.

    A constant of 1e-6 is added to both inputs before taking the logarithm so
    that zero-valued entries stay finite.
    """

    def forward(self, pred, target):
        """Return the scalar mean of ``(log(pred + 1e-6) - log(target + 1e-6)) ** 2``."""
        log_gap = torch.log(pred + 1e-6) - torch.log(target + 1e-6)
        return log_gap.pow(2).mean()


# Surrogate model, optimiser and loss used for training.
sur = SIRTransformer(hidden=128, nhead=4, K=2, num_layers=4, seq_len=200).to(device)
opt = torch.optim.Adam(sur.parameters(), lr=1e-4)
loss_fn = MSLELoss()

for epoch in range(200):
    sur.train()
    total_loss = 0.0

    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)

        # The model takes flat (points, 3) rows and regroups them into sequences itself.
        pred = sur(xb.view(-1, 3))
        loss = loss_fn(pred, yb.view(-1, 3))

        opt.zero_grad()
        loss.backward()
        opt.step()

        # Weight each batch loss by its number of trajectories for the epoch average.
        total_loss += loss.item() * xb.size(0)

    # Validation is only computed and logged every tenth epoch.
    if epoch % 10 != 0:
        continue

    sur.eval()
    with torch.no_grad():
        val_pred = sur(val_X.to(device).view(-1, 3))
        val_loss = loss_fn(val_pred, val_Y.to(device).view(-1, 3)).item()

    print(f"[Epoch {epoch:03d}] Train {total_loss/len(train_loader.dataset):.6f} | Val {val_loss:.6f}")

In [None]:
class Transformer_DataGenerator:
    """
    Adapter that serves surrogate forecasts and ground truth to the CP routines.

    For a time index it returns the next ``T_horizon`` values of one state
    column (index 1) of the true trajectory, plus the surrogate's forecast of
    the same column for every candidate parameter pair, both scaled by
    ``pop_size``.

    Args:
        model: Surrogate called on flat ``[beta, gamma, t]`` rows; expected to be
            an ``nn.Module`` that regroups rows using its ``seq_len`` attribute.
        true_trajectory: Tensor of normalised true states, indexed ``[time, state]``.
        candidate_params: Tensor with one ``[beta, gamma]`` row per ensemble member.
        pop_size: Factor that converts normalised states back to counts.
        T_horizon: Maximum number of steps returned per query.
        device: Device on which the model input is built.
        train_T: Number of time points in the surrogate's training grid. Training
            inputs use ``linspace(0, T, T) / T``, i.e. ``i / (T - 1)`` for index
            ``i``; queries are normalised the same way.
    """

    def __init__(self, model, true_trajectory, candidate_params, pop_size, T_horizon, device='cuda', train_T=200):
        self.model = model
        self.truth = true_trajectory
        self.candidates = candidate_params
        self.pop_size = pop_size
        self.H = T_horizon
        self.device = device
        self.train_T = train_T
        self.state_idx = 1
        self.current_subset = "Transformer_Ensemble"

    def get_reference_time(self, t_idx):
        """Return the reference time for ``t_idx`` (the index itself)."""
        return t_idx

    def get_trajectory_samples(self, t_idx, random=False):
        """
        Return the truth and the ensemble forecasts starting at ``t_idx``.

        Args:
            t_idx: Index of the first forecast step in the true trajectory.
            random: Unused; kept for interface compatibility.

        Returns:
            tuple: ``(y_truth, samples)`` as NumPy arrays of shape ``(H_eff,)``
            and ``(N, H_eff)``, where ``H_eff`` is the horizon truncated at the
            end of the true trajectory and ``N`` the number of candidates.

        Raises:
            ValueError: If ``t_idx`` is at or beyond the end of the trajectory.
        """
        H_eff = min(self.H, len(self.truth) - t_idx)
        if H_eff <= 0:
            raise ValueError(f"t_idx={t_idx} is outside the true trajectory of length {len(self.truth)}")

        truth_seq = self.truth[t_idx:t_idx + H_eff, self.state_idx]
        y_truth = truth_seq.cpu().numpy() * self.pop_size

        N = len(self.candidates)

        # Query exactly the indices t_idx .. t_idx + H_eff - 1, normalised like the
        # training grid, so forecast step h lines up with truth index t_idx + h.
        steps = torch.arange(t_idx, t_idx + H_eff, dtype=torch.float32, device=self.device)
        t_grid = (steps / max(self.train_T - 1, 1)).view(-1, 1)

        params_expanded = self.candidates.to(self.device).unsqueeze(1).repeat(1, H_eff, 1)
        t_expanded = t_grid.unsqueeze(0).repeat(N, 1, 1)
        model_input = torch.cat([params_expanded, t_expanded], dim=2)

        # Forecast in eval mode (dropout off) with seq_len set to the effective horizon,
        # then put the shared model back exactly as it was found.
        had_seq_len = hasattr(self.model, "seq_len")
        prev_seq_len = getattr(self.model, "seq_len", None)
        was_training = self.model.training
        self.model.eval()
        self.model.seq_len = H_eff
        try:
            with torch.no_grad():
                pred_flat = self.model(model_input.view(-1, 3))
        finally:
            if had_seq_len:
                self.model.seq_len = prev_seq_len
            self.model.train(was_training)

        pred_seq = pred_flat.view(N, H_eff, 3)
        samples = pred_seq[:, :, self.state_idx].cpu().numpy() * self.pop_size

        return y_truth, samples


In [None]:
# Run CP-Traj without calibration scores.
# Parameters
HORIZON = 10
start_t = 20
TRAIN_T = 200
ALPHA = 0.1

TRUTH_T = 200
TOTAL_DAYS = 130



# Craft an ensemble of random parameters drawn uniformly from fixed intervals
n_members = 20

candidate_betas = 0.3 + (0.6 - 0.3) * torch.rand(n_members)
candidate_gammas = 0.1 + (0.3 - 0.1) * torch.rand(n_members)

ensemble_params = torch.stack([candidate_betas, candidate_gammas], dim=1)


# Pick a ground truth (the ensemble mean for now)
beta_true = torch.mean(candidate_betas).item()
gamma_true = torch.mean(candidate_gammas).item()

# ".3f"/".4f" set the number of decimals; the previous "3f"/"4f" only set a minimum width.
print(f"\nGround Truth: Beta={beta_true:.3f}, Gamma={gamma_true:.4f}")

# Generate the true trajectory
truth_states = simulate_sir(
    POP_SIZE,
    beta_true,
    gamma_true,
    I0=800,
    R0=1e4,
    T=TRUTH_T
)
# Normalize to population fractions
truth_traj = truth_states.detach() / POP_SIZE

# Initialize the data generator
data_gen = Transformer_DataGenerator(
    model=sur,
    true_trajectory=truth_traj,
    candidate_params=ensemble_params,
    pop_size=POP_SIZE,
    T_horizon=HORIZON,
    device=device
)


# CP parameters
cp_params = {
    'gamma': 0.2,
    'power': 0.5,
    'B': 20,
    'var_a': 0.5,
    'score_window': 10,
    'optim_arg': {'e_coeff_init': 0.01}
}

print("Starting CP-Traj")
# Run CP-Traj
timestamps, ground_truths, prediction_intervals = run_aci_simulation(
    data_generator=data_gen,
    score_func_args={'type': 'pcp', 'optional_args': {}},
    S_max=POP_SIZE/4,
    alpha=ALPHA,
    T_obs=TOTAL_DAYS,
    H=HORIZON,
    N=n_members,
    start_t=start_t,
    params=cp_params,
    method_opt='cpt',
    plot=True,
)

print("CP-traj Complete")


In [None]:
# Copied into the notebook because importing it from the CP-Traj package failed in Colab.
def collect_scores(data_generator: Transformer_DataGenerator, score_func_args, T_obs=1000, H=15, N=20, start_t=10, twodim=False, subset_name=None):
    """
    Compute one score record per (time step, horizon) pair.

    Args:
        data_generator: Provides ``get_reference_time`` and ``get_trajectory_samples``.
        score_func_args: Dict with optional keys ``'type'`` (default ``'abs-r'``)
            and ``'optional_args'`` (default ``{}``) forwarded to ``score_function``.
        T_obs: Number of time steps scored; iteration always runs over ``range(T_obs)``.
        H: Number of horizons scored per time step.
        N: Maximum number of ensemble members kept from each sample set.
        start_t: Accepted for signature compatibility; not used in this function.
        twodim: If True, samples carry two coordinates and their means are stored
            as ``'lat'``/``'lon'``; otherwise ``'lat'`` is the member mean and
            ``'lon'`` is 0.0.
        subset_name: Value stored under ``'subset'`` in every record.

    Returns:
        list[dict]: Records with keys ``time``, ``time_idx``, ``horizon``,
        ``score``, ``ground_truth``, ``lat``, ``lon`` and ``subset``.
    """
    score_kind = score_func_args.get('type', 'abs-r')
    score_extra = score_func_args.get('optional_args', {})

    records = []
    for t in range(T_obs):
        reference_time = data_generator.get_reference_time(t)
        y_truth, members = data_generator.get_trajectory_samples(t, random=False)
        members = members[:N]
        records.extend(
            {
                'time': reference_time,
                'time_idx': t,
                'horizon': h,
                'score': score_function(y_truth, members, h, type=score_kind, optional_args=score_extra),
                'ground_truth': y_truth[h],
                'lat': np.nanmean(members[:, h, 0]) if twodim else np.nanmean(members[:, h]),
                'lon': np.nanmean(members[:, h, 1]) if twodim else 0.0,
                'subset': subset_name,
            }
            for h in range(H)
        )
    return records

In [None]:
import pickle
import json

cal_scores_list = []

print("Collecting Calibration Scores")
score_func = score_function

# Score the first 50 calibration samples: index 0 of each sample is the ground
# truth, indices 1.. are the ensemble members.
for i in range(50):
    truth_traj = Y_cal[i, 0]
    # beta and gamma are constant along time, so the first time step is enough.
    ensemble_params = X_cal[i, 1:, 0, :2]

    cal_gen = Transformer_DataGenerator(
        model=sur,
        true_trajectory=truth_traj.to(device),
        candidate_params=ensemble_params.to(device),
        pop_size=POP_SIZE,
        T_horizon=10,
        device=device,
    )

    cal_scores_list += collect_scores(
        data_generator=cal_gen,
        score_func_args={'type': 'pcp', 'optional_args': {}},
        T_obs=130,
        H=10,
        N=len(ensemble_params),
        start_t=20,
        subset_name='calibration',
    )

print(f"Collected {len(cal_scores_list)} calibration scores.")

# Persist the scores and the CP configuration next to the notebook.
with open('calibration_scores.pkl', 'wb') as f:
    pickle.dump(cal_scores_list, f)

cp_config = {'gamma': 0.1, 'score_window': 10, 'alpha': 0.1, 'power': 0.5}
with open('cp_config.json', 'w') as f:
    json.dump(cp_config, f)


In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_log_error, mean_absolute_error


def _predict_infected_curve(model, gt_input, seq_len, pop_size):
    """Run the model on one trajectory's inputs and return column 1 of the output times pop_size."""
    orig_seq_len = getattr(model, "seq_len", None)
    if orig_seq_len is not None:
        model.seq_len = seq_len
    try:
        with torch.no_grad():
            pred = model(gt_input.view(-1, 3))
    finally:
        # Restore the caller's sequence length even if the forward pass fails.
        if orig_seq_len is not None:
            model.seq_len = orig_seq_len
    return pred.view(seq_len, 3)[:, 1].cpu().numpy() * pop_size


def evaluate_on__test(model, X_test, Y_test, pop_size,device=None, seq_len=200, make_plots=True):
    """
    Compare the model with the ground-truth trajectory of every test sample.

    For each sample the entry at index 0 is used as ground truth; the model's
    output column 1 is compared with column 1 of ``Y_test`` after scaling both
    by ``pop_size``. Optionally plots the samples at the 10th, 50th and 90th
    percentile of peak height.

    Args:
        model: Surrogate model; put into eval mode by this function.
        X_test: Inputs indexed ``[sample, trajectory, time, feature]``.
        Y_test: Targets indexed ``[sample, trajectory, time, state]``.
        pop_size: Factor that converts normalised states to counts.
        device: Device for the model inputs; defaults to the model's device.
        seq_len: Number of time points per trajectory.
        make_plots: If True, draw truth vs. surrogate for three samples.

    Returns:
        tuple: ``(metrics, fig)``. ``metrics`` holds mean/std/median/p10/p90 of
        the per-sample MSLE and MAE plus the raw arrays under ``'msle_all'``
        and ``'mae_all'``; ``fig`` is the Matplotlib figure or ``None``.

    Raises:
        ValueError: If ``X_test`` is empty.
    """
    model.eval()

    if device is None:
        device = next(model.parameters()).device

    n = len(X_test)
    if n == 0:
        raise ValueError("evaluate_on__test needs at least one test sample")

    # Compute every curve once; the plotting section below reuses these results
    # instead of running the model a second time.
    gt_curves = []
    pred_curves = []
    for i in range(n):
        gt_curves.append(Y_test[i, 0, :, 1].cpu().numpy() * pop_size)
        pred_curves.append(
            _predict_infected_curve(model, X_test[i, 0].to(device), seq_len, pop_size)
        )

    msle_list = []
    mae_list = []
    for gt_target, pred_curve in zip(gt_curves, pred_curves):
        # MSLE is undefined for negative values, so clip both curves at zero.
        gt_c = np.maximum(gt_target, 0)
        pred_c = np.maximum(pred_curve, 0)
        msle_list.append(mean_squared_log_error(gt_c, pred_c))
        mae_list.append(mean_absolute_error(gt_c, pred_c))

    msle_arr = np.asarray(msle_list)
    mae_arr = np.asarray(mae_list)
    peaks = np.asarray([curve.max() for curve in gt_curves])

    fig = None
    if make_plots:
        # Pick representative samples by the height of the true peak.
        idx_sorted = np.argsort(peaks)
        pick_idxs = [idx_sorted[int(q * len(peaks))] for q in (0.1, 0.5, 0.9)]
        labels = ["Decay", "Moderate outbreak", "Severe outbreak"]

        fig, axes = plt.subplots(1, 3, figsize=(16, 4), sharey=True)

        for j, idx in enumerate(pick_idxs):
            beta = X_test[idx, 0, 0, 0].item()
            gamma = X_test[idx, 0, 0, 1].item()
            local_msle = msle_arr[idx]

            ax = axes[j]
            ax.plot(gt_curves[idx], label="ODE (truth)", linewidth=2)
            ax.plot(pred_curves[idx], "--", label="surrogate", linewidth=2)
            ax.set_title(f"{labels[j]}\nβ={beta:.2f}, γ={gamma:.2f}")
            ax.set_xlabel("Time (days)")
            if j == 0:
                ax.set_ylabel("Infected individuals")
                ax.legend()

            ax.grid(alpha=0.3)
            ax.text(
                0.5, 0.9,
                f"MSLE: {local_msle:.4f}",
                transform=ax.transAxes,
                ha="center",
                bbox=dict(facecolor="white", alpha=0.8, edgecolor="none"),
                fontsize=9,
            )

        plt.tight_layout()
        plt.show()

    metrics = {
        "msle_mean": float(msle_arr.mean()),
        "msle_std": float(msle_arr.std()),
        "msle_median": float(np.median(msle_arr)),
        "msle_p10": float(np.percentile(msle_arr, 10)),
        "msle_p90": float(np.percentile(msle_arr, 90)),
        "mae_mean": float(mae_arr.mean()),
        "mae_std": float(mae_arr.std()),
        "mae_median": float(np.median(mae_arr)),
        "mae_p10": float(np.percentile(mae_arr, 10)),
        "mae_p90": float(np.percentile(mae_arr, 90)),
        "msle_all": msle_arr,
        "mae_all": mae_arr,
    }

    print(
        f"MSLE: mean={metrics['msle_mean']:.5f}, "
        f"median={metrics['msle_median']:.5f}, "
        f"[p10, p90]=[{metrics['msle_p10']:.5f}, {metrics['msle_p90']:.5f}]"
    )
    print(
        f"MAE:  mean={metrics['mae_mean']:.1f}, "
        f"median={metrics['mae_median']:.1f}, "
        f"[p10, p90]=[{metrics['mae_p10']:.1f}, {metrics['mae_p90']:.1f}]"
    )

    return metrics, fig


# Usage: evaluate the trained surrogate on the ground-truth trajectory (index 0) of every test sample.
metrics, fig = evaluate_on__test(sur, X_test, Y_test, POP_SIZE)


In [None]:

# Reload the calibration scores written by the score-collection cell.
# NOTE: pickle.load can execute arbitrary code; only load files produced by this notebook.
with open('calibration_scores.pkl', 'rb') as f:
    loaded_scores = pickle.load(f)
loaded_scores = pd.DataFrame(loaded_scores)

# Per-test-sample results, keyed by test index.
coverage_by_test_idx = {}
all_test_outputs = {}

# Keep only scores below 100000 and drop the time_idx column before passing them on.
clean_scores = loaded_scores[loaded_scores['score'] < 100000].copy()
clean_scores = clean_scores.drop(columns=['time_idx'])

# Run CP-Traj on the first 10 test samples.
for i in range(10):
  TEST_IDX = i
  # Index 0 is the ground-truth trajectory; indices 1.. hold the ensemble members,
  # whose (beta, gamma) are read from the first time step.
  truth_traj = Y_test[TEST_IDX, 0]
  ensemble_params = X_test[TEST_IDX, 1:, 0, :2]

  test_gen = Transformer_DataGenerator(
      model=sur,
      true_trajectory=truth_traj.to(device),
      candidate_params=ensemble_params.to(device),
      pop_size=POP_SIZE,
      T_horizon=10,
      device=device
  )

  # CP-Traj hyper-parameters; a fresh dict is built on every iteration.
  cp_params = {
      'gamma': 0.1,
      'power': 0.5,
      'B': 20,
      'var_a': .5,
      'score_window': 10,
      'optim_arg': {'e_coeff_init': 0.01}
  }

  print("Running CP-Traj on Test")
  # Every iteration uses the same directory (passed as save_path below).
  save_plot_path = "/saveplot"
  os.makedirs(save_plot_path, exist_ok=True)

  timestamps, ground_truths, prediction_intervals = run_aci_simulation(
      data_generator=test_gen,
      score_func_args={'type': 'pcp', 'optional_args': {}},
      S_max=POP_SIZE ,
      alpha=0.1,
      T_obs=130,
      H=10,
      N=10,
      start_t=20,
      params=cp_params,
      method_opt='cpt',
      plot=True,
      learned_scores=clean_scores,
      save_path=save_plot_path
  )

  gt = np.array(ground_truths)
  ivals = np.array(prediction_intervals)

  # Last axis holds (lower, upper) bounds.
  # NOTE(review): assumes prediction_intervals is (time, horizon, member, 2) — TODO confirm.
  lower_samples = ivals[..., 0]
  upper_samples = ivals[..., 1]


  # Envelope over axis 2: the smallest lower bound and the largest upper bound.
  lower = lower_samples.min(axis=2)
  upper = upper_samples.max(axis=2)

  in_band = (gt >= lower) & (gt <= upper)

  # An entry counts as covered only if the truth is inside the band along the whole last axis.
  covered_timesteps = in_band.all(axis=-1)

  coverage_fraction = covered_timesteps.mean()

  print(f"Coverage for TEST_IDX {TEST_IDX}: {coverage_fraction * 100}%")

  all_test_outputs[TEST_IDX] = {
      "timestamps": timestamps,
      "ground_truths": ground_truths,
      "prediction_intervals": prediction_intervals,
      "cp_params": cp_params,
      "coverage_fraction": float(coverage_fraction),
  }

  coverage_by_test_idx[TEST_IDX] = float(coverage_fraction)

  print("Simulation Complete.")


In [None]:
from IPython.display import Image, display
import os

# File names looked up inside save_plot_path.
plots_to_display = [
    "horizon_coverage_expand_boundaries.png",
    "coverage_expand_boundaries.png",
    "alphas_h0.png",
]

for plot_filename in plots_to_display:
    full_path = os.path.join(save_plot_path, plot_filename)

    # Report missing figures instead of failing.
    if not os.path.exists(full_path):
        print(f"Plot file '{plot_filename}' not found at '{full_path}'.")
        continue

    print(f"Displaying: {plot_filename}")
    display(Image(filename=full_path))


In [None]:
# Debug helper to see how wide the prediction intervals are.
def check_interval_widths(day_idx, prediction_intervals, pop_size=372258, step=5, n_members=5):
    """
    Print the interval width of the first few ensemble members.

    Args:
        day_idx: Index along the first axis of ``prediction_intervals``.
        prediction_intervals: Array indexed as ``[day, step, member, bound]``
            where bound 0 is the lower and bound 1 the upper limit.
        pop_size: Unused; kept so existing calls keep working.
        step: Index along the second axis (forecast step) to inspect.
        n_members: Maximum number of members to print; fewer are printed when
            the array holds fewer members.
    """
    # Never index past the members that actually exist.
    available = len(prediction_intervals[day_idx, step])
    shown = min(n_members, available)

    print(f"Tube Widths at Day {day_idx} + {step} days ahead")
    for i in range(shown):
        lower = prediction_intervals[day_idx, step, i, 0]
        upper = prediction_intervals[day_idx, step, i, 1]
        width = upper - lower
        # The printed range is the interval itself, not a coverage value.
        print(f"Member {i}: Width = {width:.0f} (Interval: {lower:.0f} to {upper:.0f})")


# Inspect the widths at day index 90 for whatever `prediction_intervals` is currently bound to.
check_interval_widths(90, prediction_intervals, POP_SIZE)