In [1]:
import pandas as pd
import numpy as np
from pathlib import Path
import os
import sys
import torch
from torch.utils.data import DataLoader, Dataset
from properscoring import crps_ensemble
import random
import pymc as pm
from patsy import dmatrix




# Make the project's `src` folder importable for the project-local imports below.
# Guarded so that re-running this cell does not keep stacking duplicate sys.path entries.
src_dir = os.path.abspath('../src')
if src_dir not in sys.path:
    sys.path.append(src_dir)

# Project root: the parent of the directory the notebook is run from.
project_dir = Path.cwd().parent

from data_tools.data_utils import SimulateSero, PartialCountDataset, TrueCountDataset
from model_tools.train_utils import BaseTrain, SeroTrain, DirectTrain
from model_tools.models import SimSero, DirectSero
from model_tools.evaluation import eval_sero_pnn, plot_sero_pnn_preds, eval_pnn, eval_direct_sero

# --- Study period, split fractions and RNG seed ---
start_year, end_year = 2022, 2022   # first / last calendar year of the daily date range
data_split = [0.7, 0.15, 0.15]      # fractions multiplied by the number of dates to size the splits
seed = 123                          # value handed to set_seed()

# --- Dimensions handed to the dataset helpers ---
D, M = 40, 50                       # PartialCountDataset(D=..., M=...); M is also the window length in days
T, Q, N = 40, 40, 4                 # SimulateSero(T=..., Q=..., N=...)

# --- Sampler settings (read by sampler_kwargs) ---
S = 500                             # value used for the sampler's `draws` argument
chains, cores = 1, 1


# 123, 2019, 2023, 15

def set_seed(seed=0):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy's legacy global generator and PyTorch
    (CPU and all CUDA devices), then switches cuDNN to deterministic mode
    with benchmarking turned off.

    Parameters
    ----------
    seed : int, default 0
        Seed value applied to all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False




In [2]:
# Simulate Data
# True counts * Constant sero prop vec * constant delay vec from average across world samples
p_sero = np.array([0.9, 0.0, 0.05, 0.05])  # one proportion per entry of sero_all
sero_all = ["DENV-1", "DENV-2", "DENV-3", "DENV-4"]

delays_df = pd.read_csv(project_dir / "data" / "transformed" / "DENG_delays.csv")
delays_df = delays_df.set_index("Collection date")
delays_df.index = pd.to_datetime(delays_df.index)

# Row totals per collection date (last two rows dropped), then summed per calendar month.
y_true = delays_df.sum(1)[:-2]
y_true_df = y_true.groupby(y_true.index.to_period("M")).sum()
y_true = np.array(y_true_df)
# Only the shape of y_true is used below: every month gets the same constant total of 10.
const = 10*np.ones_like(y_true)

denv_df = pd.read_csv(project_dir / "data" / "transformed" / "denv_df.csv")

# .copy() detaches the filtered rows from the frame they were sliced from, so the column
# assignment below does not raise pandas' SettingWithCopyWarning.
denv_df = denv_df[denv_df['Delay'] < 60].copy()

# Parse the collection dates so they are datetimes rather than strings.
denv_df['Collection date'] = pd.to_datetime(denv_df['Collection date'])

start_month = pd.to_datetime(delays_df.index.min())
end_month = pd.to_datetime(delays_df.index.max())

# Create a DataFrame of month start dates as datetime (not Period)
dates = pd.DataFrame({
    "Collection date": pd.date_range(start=start_month, end=end_month, freq='MS')
})

# Count DENV-1 records per (collection date, delay) and spread the delays across columns.
df = denv_df[denv_df['Sero'] == "DENV-1"] \
    .groupby(['Sero', 'Collection date', 'Delay']) \
    .size() \
    .reset_index(name='count')

df = df.pivot(index='Collection date', columns='Delay', values='count')

# Ensure df.index is datetime as well
df.index = pd.to_datetime(df.index)

# Empirical delay distribution: mean count per delay column, normalised to sum to 1.
mean_by_delay = df.fillna(0).mean(0)
p_delay = np.array(mean_by_delay / mean_by_delay.sum())

# Reshape the three factors so that they broadcast against each other.
# NOTE(review): p_sero stays 3-D after this cell and is what the serotype-object cell passes
# as prop_vec - confirm that shape is what SimulateSero expects.
p_sero = p_sero[:, np.newaxis, np.newaxis]   # shape (n_serotypes, 1, 1)
const = const[np.newaxis, :, np.newaxis]  # shape (1, n_months, 1)
# NOTE(review): the slice keeps the first Q delay columns *after* normalisation, so the kept
# probabilities sum to less than 1 whenever more than Q delay columns exist - confirm intended.
p_delay = p_delay[np.newaxis, np.newaxis, :Q]  # shape (1, 1, <=Q)

# multiply to get (n_serotypes, n_months, n_delays)
sero_tensor = p_sero * const * p_delay
sero_tensor = sero_tensor.round()

# Create the output folder up front so that to_csv cannot fail on a fresh checkout.
sero_dir = project_dir / "data" / "model" / "sero_dfs"
sero_dir.mkdir(parents=True, exist_ok=True)
for s,sero in enumerate(sero_all):
    sero_df = pd.DataFrame(sero_tensor[s, :, :])
    sero_df.index = y_true_df.index
    sero_df.to_csv(sero_dir / f"{sero}.csv", index=True)
    print("Saving sero: ", sero)

Saving sero:  DENV-1
Saving sero:  DENV-2
Saving sero:  DENV-3
Saving sero:  DENV-4


  denv_df['Collection date'] = pd.to_datetime(denv_df['Collection date'])


In [3]:
# Sanity check on the reshaped delay distribution: two singleton axes followed by the delay axis.
print(p_delay.shape)

(1, 1, 40)


In [4]:
# Create Serotype obj
# Read the per-serotype CSVs back and stack them into one array with the serotype on axis 0.
base_folder = project_dir / "data" / "model" / "sero_dfs"

# Build the file list from sero_all instead of os.listdir(): a directory listing comes back in
# arbitrary order and would also pick up stale CSVs, whereas the serotype axis presumably has
# to line up with the entries of p_sero.
files = [f"{sero}.csv" for sero in sero_all]

sero_frames = []
for file in files:
    file_path = os.path.join(base_folder, file)
    df = pd.read_csv(file_path)
    dates = df['Collection date']  # same column in every file; the last one read is kept
    sero_df = df.drop(columns="Collection date")
    sero_frames.append(sero_df)
sero_tensor = np.array(sero_frames)

sero_dataset = SimulateSero(sero_tensor, dates, T=T, Q=Q, N=N, prop_vec=p_sero)

In [5]:
# Create count obj
# Reload the delay table so that "Collection date" is an ordinary datetime column.
date_col = 'Collection date'
delays_path = project_dir / "data" / "transformed" / "DENG_delays.csv"

delays_df = pd.read_csv(delays_path)
delays_df[date_col] = pd.to_datetime(delays_df[date_col])

# Both dataset wrappers are built from the same table.
partial_count_dataset = PartialCountDataset(delays_df, D=D, M=M)
true_count_dataset = TrueCountDataset(delays_df)

In [6]:
# End of 2023 appears to have some incomplete data, so a range ending in 2023 stops on 25 Dec.
last_day = "12-25" if end_year == 2023 else "12-31"
dates = list(pd.date_range(f"{start_year}-01-01", f"{end_year}-{last_day}", freq='D'))

# Integer split sizes (fractions are truncated); the final split takes whatever is left so
# that the sizes add up to len(dates) exactly.
n_dates = len(dates)
data_split_sizes = (n_dates * np.array(data_split)).astype(int)
data_split_sizes[-1] = n_dates - data_split_sizes[:-1].sum()




In [None]:

def silu(x):
    """SiLU activation: the input multiplied by its own sigmoid.

    The sigmoid comes from ``pm.math``, so ``x`` may be anything that
    function accepts.
    """
    gate = pm.math.sigmoid(x)
    return x * gate


def sampler_kwargs():
    """Return the sampler configuration as a keyword-argument dict.

    ``cores``, ``chains`` and ``S`` (number of draws) are read from the
    notebook-level configuration at call time; presumably the result is
    unpacked into the PyMC sampling call.
    """
    jax_backend = {"backend": "jax", "gradient_backend": "jax"}
    return {
        "nuts_sampler": "nutpie",
        "cores": cores,
        "init": "adapt_diag",
        "chains": chains,
        "draws": S,
        "tune": 500,
        "target_accept": 0.95,
        "max_treedepth": 10,
        "nuts_sampler_kwargs": jax_backend,
    }

def get_mask(D):
    """Build a ``(D, D)`` boolean mask that is True where ``i + j <= D - 1``.

    Entry ``(i, j)`` is False exactly when ``i + j > D - 1``, i.e. strictly
    below the anti-diagonal.

    Parameters
    ----------
    D : int
        Side length of the square mask.

    Returns
    -------
    numpy.ndarray
        Boolean array of shape ``(D, D)``.
    """
    idx = np.arange(D)
    # Broadcasting a column against a row gives the full table of i + j in one step.
    return (idx[:, None] + idx[None, :]) <= D - 1

def create_fourier_features(t, n, p=10.0):
    """Evaluate ``n`` cosine/sine harmonic pairs of period ``p`` at times ``t``.

    Parameters
    ----------
    t : numpy.ndarray
        1-D array of time points.
    n : int
        Number of harmonics (1, 2, ..., n).
    p : float, default 10.0
        Period of the first harmonic.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(t), 2 * n)``: the ``n`` cosine columns first,
        followed by the ``n`` sine columns.
    """
    harmonics = np.arange(1, n + 1)
    phase = 2 * np.pi * harmonics * t[:, None] / p
    cos_part = np.cos(phase)
    sin_part = np.sin(phase)
    return np.concatenate((cos_part, sin_part), axis=1)

class SeroBNNDataset(Dataset):
    """Torch ``Dataset`` that returns one sample per date.

    Parameters
    ----------
    partial_count_obj
        Object exposing ``get_obs(date)``.
    true_count_obj
        Object exposing ``get_y_prop(date)``.
    sero_obj
        Object exposing ``get_obs(date)``.
    dates
        Indexable sequence of dates; each element must support subtracting
        a ``pd.Timedelta``.
    S, T, Q : optional
        Stored on the instance but not read by this class. They default to
        ``None`` so the dataset can be built from the four data arguments alone.
    """

    def __init__(self, partial_count_obj, true_count_obj, sero_obj, dates, S=None, T=None, Q=None):
        self.partial_count_obj = partial_count_obj
        self.true_count_obj = true_count_obj
        self.sero_obj = sero_obj
        self.dates = dates
        self.S = S
        self.T = T
        self.Q = Q

    def __len__(self):
        """Number of samples, one per entry of ``dates``."""
        return len(self.dates)

    def __getitem__(self, index):
        """Return ``(Z_obs, sero_obs, y_sero_true, window_dates)`` for ``dates[index]``.

        ``window_dates`` lists the ``M`` consecutive days that end on the
        sample date, in ascending order; ``M`` is read from module scope.
        """
        date = self.dates[index]
        # Oldest day first, so the sample date itself is the last entry.
        window_dates = [date - pd.Timedelta(days=offset) for offset in reversed(range(M))]

        Z_obs = self.partial_count_obj.get_obs(date)

        # The target must come from this sample's own lookup, not from a
        # notebook-level variable that happens to share a name.
        y_sero_true = np.array(self.true_count_obj.get_y_prop(date))

        sero_obs = self.sero_obj.get_obs(date)
        return Z_obs, sero_obs, y_sero_true, window_dates
    
set_seed(seed)
sero_props = pd.read_csv(project_dir / "data" / "transformed" / "sero_props.csv")

# SeroBNNDataset's constructor also takes S, T and Q, so pass them explicitly by keyword.
sero_pnn_dataset = SeroBNNDataset(
    partial_count_dataset, true_count_dataset, sero_dataset, dates, S=S, T=T, Q=Q
)



In [13]:
# Each dataset item is a 4-tuple: (Z_obs, sero_obs, y_sero_true, window_dates).
Z_obs, sero_obs, y_sero_true, window_dates = sero_pnn_dataset[0]
Z_obs

array([[[1.18360376e-02, 1.52177626e-02, 1.53714774e-02, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [8.36891546e-04, 2.20324509e-03, 2.08368915e-03, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        ...,
        [2.22032451e-03, 2.57045260e-03, 1.38343296e-03, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [2.38257899e-02, 3.53543980e-02, 0.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00]],

       [[6.57557643e-04, 8.45431255e-04, 8.53970965e-04, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [5.97779675e-04, 1.16140051e-03, 1.87019641e-03, ...,
         0.00000000e+00, 0.00000000e+00, 8.53970965e-06],
        [1.50640478e-02, 

In [None]:
# Constants
# Time index over the M-step window, its position within a 7-step cycle, and a rescaling to [0, 1).
t = np.arange(M)
t_week = t % 7
t_norm = t / M

# Fourier bases with periods 3.5 and 7 (n harmonics each), centred column by column.
# NOTE(review): t is integer-valued, so harmonics k and k + 7 of the period-7 basis coincide and
# the period-3.5 basis repeats the even harmonics of the period-7 one - confirm the redundancy is intended.
n = 14
fourier_basis_biweek = create_fourier_features(t, n=n, p=3.5)
fourier_basis_biweek -= fourier_basis_biweek.mean(0, keepdims=True)

fourier_basis_week = create_fourier_features(t, n=n, p=7)
fourier_basis_week -= fourier_basis_week.mean(0, keepdims=True)

# Spline design matrices built with patsy: a cubic B-spline over the normalised time axis ...
spline_trend = dmatrix(
    "bs(t, df=14, degree=3, include_intercept=False)",
    {"t": t_norm},
    return_type='dataframe',
)
X_trend = np.asarray(spline_trend)

# ... and a cyclic spline over the 7-step cycle position.
spline_week = dmatrix("cc(t_week, df=7)", {"t_week": t_week}, return_type='dataframe')
X_week = np.asarray(spline_week)

# Combined time features: normalised time as a column, followed by both Fourier bases.
t_input = t[:, None] / M
time_input = np.concatenate(
    [t_input, fourier_basis_biweek, fourier_basis_week],
    axis=1,
)

# (M, D) mask: all True except the last D rows, which take the pattern from get_mask(D).
mask = np.ones((M, D), dtype=bool)
mask[-D:, :] = get_mask(D)


