# Bayesian Piecewise Cox PH (Issue #8)
This notebook validates the piecewise constant hazard model implementation on the toy example described in Issue #8.

In [9]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import arviz as az

# Adjust this import to wherever the `cox` module lives in your project
from cox import BayesianPiecewiseCoxPH

plt.style.use('ggplot')

 ### 1. Toy Example Data
 Recreating the exact dataset from the issue description:
 * **Subject 1:** Event at $t=5$, Covariate $x=0$.
 * **Subject 2:** Censored at $t=8$, Covariate $x=1$.
 * **Subject 3:** Event at $t=10$, Covariate $x=1$.

The issue suggests a partition at $t=7$. We will enforce this manually.

In [10]:
df_toy = pd.DataFrame({
    'time': [5, 8, 10],
    'event': [1, 0, 1], # 1=Event, 0=Censored
    'x': [0, 1, 1]
}).reset_index(drop=True)

print("Toy Dataset:")
print(df_toy)

Toy Dataset:
   time  event  x
0     5      1  0
1     8      0  1
2    10      1  1


 ### 2. Model Initialization
 We manually set `time_intervals` to `[0, 7, 11]` so that $t=7$ is the cut point, as in the issue. The last cut point (11) only needs to exceed the largest observed time ($t=10$).
 * **Interval 1:** $0 \le t < 7$ (Hazard $\lambda_1$)
 * **Interval 2:** $7 \le t < 11$ (Hazard $\lambda_2$)
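For reference, a piecewise constant proportional-hazards model over these two intervals is usually written as below. This assumes the class uses the standard parameterization, with one baseline hazard $\lambda_k$ per interval and a log-linear covariate effect $e^{\beta x}$; check the implementation before relying on it.

$$
h(t \mid x) = \lambda_k \, e^{\beta x}, \qquad t \in [\tau_{k-1}, \tau_k), \qquad (\tau_0, \tau_1, \tau_2) = (0, 7, 11)
$$

$$
S(t \mid x) = \exp\!\Big(-e^{\beta x} \sum_{k=1}^{2} \lambda_k \, \Delta_k(t)\Big), \qquad \Delta_k(t) = \max\big(0,\ \min(t, \tau_k) - \tau_{k-1}\big)
$$

For example, Subject 3 ($t=10$, $x=1$) spends $\Delta_1 = 7$ time units in the first interval and $\Delta_2 = 3$ in the second, so $S(10 \mid x=1) = \exp\big(-e^{\beta}(7\lambda_1 + 3\lambda_2)\big)$.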

In [11]:
# We manually define the cut-points (including 0 and something > max time)
cut_points = [0, 7, 11]   # 11 > max observed time

model = BayesianPiecewiseCoxPH(time_intervals=cut_points)

# Fit the model
model.fit(
    data=df_toy,
    duration_col='time',
    event_col='event',
    draws=2000,
    tune=1000,
    chains=4,
    cores=1,
    target_accept=0.9
)

Found Intel OpenMP ('libiomp') and LLVM OpenMP ('libomp') loaded at
the same time. Both libraries are known to be incompatible and this
can cause random crashes or deadlocks on Linux when loaded in the
same Python program.
Using threadpoolctl may cause crashes or deadlocks. For more
information and possible workarounds, please see
    https://github.com/joblib/threadpoolctl/blob/master/multiple_openmp.md

Initializing NUTS using jitter+adapt_diag...
Sequential sampling (4 chains in 1 job)
NUTS: [beta, lambda0]


error: (len(y)-offy>(n-1)*abs(incy)) failed for 1st keyword n: daxpy:n=7

 ### 3. Diagnostics & Summary
 We check the posterior for $\beta$ and for the two baseline hazards, $\lambda_1$ (first interval) and $\lambda_2$ (second interval), which are sampled together as the vector `lambda0`.
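These parameters enter the model through the piecewise exponential log-likelihood, $\log L(\beta, \lambda) = \sum_i \delta_i (\log \lambda_{k(i)} + \beta x_i) - \sum_i e^{\beta x_i} \sum_k \lambda_k \Delta_{ik}$, where $\delta_i$ is the event indicator, $k(i)$ the interval in which subject $i$'s follow-up ends, and $\Delta_{ik}$ the time subject $i$ spends in interval $k$. A minimal sketch follows, assuming the class uses this standard form; `piecewise_loglik` is a hypothetical helper, not part of the `cox` module, and priors are left out.

```python
import numpy as np

def piecewise_loglik(beta, lam, time, event, x, cuts):
    """Log-likelihood of a piecewise constant hazard PH model (hypothetical helper).

    lam[k] is the baseline hazard on the k-th interval (0-based, so lam[0] is
    lambda_1). Assumes all times are > 0 and below the last cut point; a time
    falling exactly on a cut point is attributed to the earlier interval.
    """
    time = np.asarray(time, dtype=float)
    event = np.asarray(event)
    x = np.asarray(x, dtype=float)
    lam = np.asarray(lam, dtype=float)
    cuts = np.asarray(cuts, dtype=float)

    # Time each subject spends in each interval: shape (n_subjects, n_intervals)
    exposure = np.clip(np.minimum(time[:, None], cuts[1:]) - cuts[:-1], 0.0, None)

    # 0-based index of the interval in which each subject's follow-up ends
    k_end = np.searchsorted(cuts, time, side="left") - 1

    linpred = beta * x
    log_events = np.sum(event * (np.log(lam[k_end]) + linpred))
    cum_hazard = np.sum(np.exp(linpred) * (exposure @ lam))
    return log_events - cum_hazard

print(piecewise_loglik(0.0, [1.0, 1.0], time=[5, 8, 10], event=[1, 0, 1],
                       x=[0, 1, 1], cuts=[0, 7, 11]))
```

With $\beta = 0$ and $\lambda_1 = \lambda_2 = 1$ both event terms are $\log 1 = 0$ and the total exposure is $5 + 8 + 10 = 23$, so the call returns -23.0.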

In [None]:
# Summary of posterior
print(az.summary(model.idata, var_names=['beta', 'lambda0'], round_to=4))

# Trace plot (very useful with so few observations!)
az.plot_trace(model.idata, var_names=['beta', 'lambda0'], compact=True)
plt.tight_layout()
plt.show()

# Check R-hat and ESS
print("\nDiagnostics:")
print(az.summary(model.idata, var_names=['beta', 'lambda0'])[['r_hat', 'ess_bulk', 'ess_tail']])
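The R-hat and ESS table is easier to act on with explicit thresholds. A small sketch, assuming the `r_hat`, `ess_bulk` and `ess_tail` column names produced by `az.summary` above; the cutoffs (R-hat at most 1.01, bulk and tail ESS of at least 400) are common rules of thumb, not something ArviZ or this notebook enforces, and `flag_poor_convergence` is a hypothetical helper.

```python
import pandas as pd

def flag_poor_convergence(summary: pd.DataFrame,
                          rhat_max: float = 1.01,
                          ess_min: float = 400.0) -> pd.DataFrame:
    """Return the rows of an ArviZ-style summary that fail R-hat or ESS checks."""
    bad = (
        summary["r_hat"].isna()            # R-hat could not be computed
        | (summary["r_hat"] > rhat_max)    # chains disagree
        | (summary["ess_bulk"] < ess_min)  # too few effective draws in the bulk
        | (summary["ess_tail"] < ess_min)  # too few effective draws in the tails
    )
    return summary.loc[bad, ["r_hat", "ess_bulk", "ess_tail"]]
```

Calling `flag_poor_convergence(az.summary(model.idata, var_names=['beta', 'lambda0']))` returns an empty frame when every parameter passes.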

### 4. Verify Data Expansion Logic
Internally, the class splits each subject's follow-up at the cut points, turning the 3 rows into 5 interval-level rows:
* **Subj 1:** (0-5, Event) -> Interval 0 ($\lambda_1$)
* **Subj 2:** (0-7, No Event) -> Interval 0 ($\lambda_1$)
* **Subj 2:** (7-8, Censored) -> Interval 1 ($\lambda_2$)
* **Subj 3:** (0-7, No Event) -> Interval 0 ($\lambda_1$)
* **Subj 3:** (7-10, Event) -> Interval 1 ($\lambda_2$)
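The same expansion can be reproduced with plain pandas to check the row counts. A minimal sketch in the counting-process (start, stop] layout; `split_episodes` is a hypothetical helper, not the class's internal code, and it assigns a follow-up time that lands exactly on a cut point to the earlier interval, which may differ from the implementation.

```python
import pandas as pd

def split_episodes(df, cuts, duration_col="time", event_col="event", covariates=("x",)):
    """Split each subject's follow-up at the cut points (hypothetical helper).

    Returns one row per (subject, interval) with start/stop times, an event
    flag that is 1 only in the interval where an observed event occurs, and
    the 0-based interval index.
    """
    rows = []
    for subj, rec in df.iterrows():
        t_end, had_event = rec[duration_col], rec[event_col]
        for k, (lo, hi) in enumerate(zip(cuts[:-1], cuts[1:])):
            if t_end <= lo:
                break  # follow-up ended before this interval starts
            stop = min(t_end, hi)
            rows.append({
                "subject": subj,
                "start": lo,
                "stop": stop,
                # Event only where follow-up actually ends with an observed event
                "event": int(had_event == 1 and stop == t_end),
                "interval": k,
                **{c: rec[c] for c in covariates},
            })
    return pd.DataFrame(rows)

df_toy = pd.DataFrame({"time": [5, 8, 10], "event": [1, 0, 1], "x": [0, 1, 1]})
print(split_episodes(df_toy, cuts=[0, 7, 11]))
```

On the toy data this yields the 5 rows listed above, with exactly one event row per observed event.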

In [None]:
# Prepare prediction grid
X_new = pd.DataFrame({
    'x': [0, 1]           # one curve for x=0, one for x=1
})

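# Predict survival functions on a fine grid that also includes the observed
# times (5, 8, 10), so the exact .loc lookups below find matching rows
times = np.union1d(np.linspace(0, 10.5, 200), df_toy['time'].to_numpy())
surv_df = model.predict_survival_function(X_new, times=times)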

# Rename columns for clarity
surv_df.columns = ['x = 0', 'x = 1']

# Plot
plt.figure(figsize=(10, 6))

for col in surv_df.columns:
    # With piecewise constant hazards S(t) is continuous, so draw a line, not steps
    plt.plot(surv_df.index, surv_df[col], label=col, lw=2.2)

# Add observed events/censorings
plt.plot(5,  surv_df.loc[5, 'x = 0'],  'o', ms=10, color='C0', alpha=0.8, label='Event x=0')
plt.plot(8,  surv_df.loc[8, 'x = 1'],  'x', ms=12, mew=3, color='C1', label='Censored x=1')
plt.plot(10, surv_df.loc[10, 'x = 1'], 'o', ms=10, color='C1', alpha=0.8, label='Event x=1')

plt.title('Estimated Survival Functions – Piecewise Constant Cox PH', fontsize=13)
plt.xlabel('Time', fontsize=12)
plt.ylabel('Survival Probability S(t)', fontsize=12)
plt.legend(fontsize=11)
plt.ylim(0, 1.05)
plt.xlim(0, 11)
plt.grid(alpha=0.3)
plt.show()

### 5. Prediction
Predicting survival curves for $x=0$ and $x=1$.
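The curves can also be rebuilt directly from posterior draws, which gives a cross-check on `predict_survival_function` and a way to show uncertainty. A minimal sketch, assuming the standard piecewise exponential survival function $S(t \mid x) = \exp(-e^{\beta x} \sum_k \lambda_k \Delta_k(t))$; `survival_curve` is hypothetical, and it expects flattened draws, with `beta` of shape `(n_draws,)` and the interval hazards of shape `(n_draws, n_intervals)` (for example `model.idata.posterior['lambda0'].values.reshape(-1, 2)`, if `lambda0` is stored with a trailing interval dimension).

```python
import numpy as np

def survival_curve(times, x, beta, lam, cuts):
    """S(t | x) for a piecewise constant hazard PH model, one row per draw (hypothetical).

    beta: shape (n_draws,); lam: shape (n_draws, n_intervals).
    Returns an array of shape (n_draws, n_times).
    """
    times = np.asarray(times, dtype=float)
    cuts = np.asarray(cuts, dtype=float)
    beta = np.atleast_1d(np.asarray(beta, dtype=float))
    lam = np.atleast_2d(np.asarray(lam, dtype=float))

    # Time spent in each interval up to t: shape (n_times, n_intervals)
    exposure = np.clip(np.minimum(times[:, None], cuts[1:]) - cuts[:-1], 0.0, None)
    # Baseline cumulative hazard for each draw and time: shape (n_draws, n_times)
    cum_base = lam @ exposure.T
    # Proportional hazards: scale the cumulative hazard by exp(beta * x)
    return np.exp(-np.exp(beta * x)[:, None] * cum_base)
```

Averaging over draws (`S.mean(axis=0)`) gives a posterior mean curve, and `np.quantile(S, [0.03, 0.97], axis=0)` gives a pointwise 94% band.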