# Simulation studies

(Move this notebook to the root directory to run it)

Note: For all these studies, the hyperparameter search has already been completed.

In [None]:
import sys
from pathlib import Path
import yaml
import pickle
import argparse
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from lifelines.utils import concordance_index
from lifelines import CoxPHFitter

from coxkan import CoxKAN
from coxkan.utils import bootstrap_metric, set_seed

SEED = set_seed(42)

In [None]:
### Reusable functions for the notebook

def true_cindex(df):
    """Concordance index obtained by scoring ``df`` with the ground-truth risk function.

    Reads the notebook-level ``sim_config``, ``duration_col``, ``event_col`` and
    ``covariates`` of whichever simulation is currently loaded.
    """
    global sim_config, duration_col, event_col, covariates
    true_fn = sim_config['log_partial_hazard']
    # Each covariate column is handed to the true function as a keyword argument.
    lph = true_fn(**df[covariates])
    # The score is negated before being passed to concordance_index.
    return concordance_index(df[duration_col], -lph, df[event_col])

def cph_cindex(df):
    """Concordance index of the notebook-level ``cph`` model on ``df``."""
    global cph
    score = cph.score(df, scoring_method='concordance_index')
    return score

def cph_reg_cindex(df):
    """Concordance index of the notebook-level regularised ``cph_reg`` model on ``df``."""
    global cph_reg
    score = cph_reg.score(df, scoring_method='concordance_index')
    return score

def cph_formula(cph):
    """Render a fitted model's coefficients as a linear expression string.

    ``cph.params_`` must expose ``.items()`` yielding ``(covariate, coefficient)``
    pairs. Each pair becomes ``"<coefficient to 4 d.p.> * <covariate>"`` and the
    terms are joined with ``" + "``.
    """
    return " + ".join(
        f"{weight:.4f} * {name}" for name, weight in cph.params_.items()
    )

def plot_residuals(test_lph_vals,pred_lph_vals,x1_vals,x2_vals):
    """Scatter the residuals (true minus predicted) against x1 and x2 in a 1x2 grid.

    Args:
        test_lph_vals: true log-partial hazard per subject.
        pred_lph_vals: predicted log-partial hazard per subject.
        x1_vals: values of the first covariate.
        x2_vals: values of the second covariate.

    Returns:
        The matplotlib figure. Each panel is annotated with the Pearson
        correlation between the covariate and the residuals.
    """
    # Flatten every input to 1-D first: an (n, 1) prediction column would
    # otherwise broadcast against an (n,) truth vector into an (n, n) matrix.
    residuals = np.asarray(test_lph_vals).ravel() - np.asarray(pred_lph_vals).ravel()
    panels = (
        (np.asarray(x1_vals).ravel(), r'$x_1$'),
        (np.asarray(x2_vals).ravel(), r'$x_2$'),
    )

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    for ax, (x_vals, label) in zip(axes, panels):
        ax.scatter(x_vals, residuals, alpha=0.5)
        # insert correlation as text
        corr = np.corrcoef(x_vals, residuals)[0, 1]
        ax.text(0.5, 0.9, f'Correlation: {corr:.2f}', transform=ax.transAxes, fontsize=14)
        ax.set_xlabel(label, fontsize=14)
        ax.set_ylabel('Residuals', fontsize=14)
        ax.set_title(f'Residuals vs {label}', fontsize=14)
    fig.tight_layout()
    plt.show()
    return fig
    

## Gaussian

For our first study, we set the log-partial hazard to be a Gaussian:

$$\theta(\mathbf{x}) = 5 \exp(-2 (x_1^2 + x_2^2))$$

In [None]:
# Select the Gaussian simulation; every path below is derived from these two names.
exp_name = "sim_gaussian"
sim_name = "gaussian"

### load configs
with open(f'./configs/simulation/{sim_name}.yml', 'r') as file:
    sim_config = yaml.safe_load(file)
    # Keep the text after the last ': ' as a printable expression
    # (presumably the body of a lambda string -- confirm against the YAML).
    sim_config['true_expr'] =  sim_config['log_partial_hazard'].split(': ')[-1] # log partial hazard expression
    # SECURITY: eval() executes whatever code the YAML string contains; only use trusted config files.
    sim_config['log_partial_hazard'] = eval(sim_config['log_partial_hazard']) # convert to function

# (config from hyperparameter search)
with open(f'configs/coxkan/{exp_name}.yml', 'r') as file:
    config = yaml.safe_load(file)

# data already generated (from sweep.py)
df_train = pd.read_csv(f'./data/{exp_name}_train.csv')
df_test = pd.read_csv(f'./data/{exp_name}_test.csv')
# Every column except the last two is treated as a covariate.
duration_col, event_col, covariates = 'duration', 'event', df_train.columns[:-2]

Evaluate performance of the 'true' expression. Clearly, the C-Index will not be perfect since survival time is randomly distributed.

In [None]:
# C-Index of true log partial hazard expression
# (bootstrapped over the test set with N=100; only the 'formatted' entry is kept)
cindex_true = bootstrap_metric(true_cindex, df_test, N=100)['formatted']

print(f"True log partial hazard: {sim_config['true_expr']}")
print(f"True C-Index: {cindex_true}")

Evaluate Cox Proportional Hazards Model:

In [None]:
# CoxPH (unregularised baseline)
cph = CoxPHFitter()
cph.fit(df_train, duration_col=duration_col, event_col=event_col)
cindex_cph = bootstrap_metric(cph_cindex, df_test, N=100)['formatted']
formula_cph = cph_formula(cph)
print(f"CoxPH Expression: {formula_cph}")
print(f"CoxPH C-Index: {cindex_cph}")

# Residuals are taken on the log-partial-hazard scale, so the prediction must be
# the *log* partial hazard. predict_partial_hazard returns the exponentiated value,
# which is not comparable with the true log-partial hazard.
x1_vals = df_test['x1'].values
x2_vals = df_test['x2'].values
test_lph_vals = sim_config['log_partial_hazard'](x1_vals, x2_vals, *df_test[covariates[2:]].values.T)
pred_lph_vals = np.ravel(cph.predict_log_partial_hazard(df_test))

fig = plot_residuals(test_lph_vals,pred_lph_vals,x1_vals,x2_vals)

# do same for CoxPH with regularization
cph_reg = CoxPHFitter(penalizer=0.5, l1_ratio=1)
cph_reg.fit(df_train, duration_col=duration_col, event_col=event_col)
cindex_cph_reg = bootstrap_metric(cph_reg_cindex, df_test, N=100)['formatted']
formula_cph_reg = cph_formula(cph_reg)
print(f"CoxPH (Reg) Expression: {formula_cph_reg}")
print(f"CoxPH (Reg) C-Index: {cindex_cph_reg}")

# Covariates and true values are unchanged; only the prediction differs.
pred_lph_vals = np.ravel(cph_reg.predict_log_partial_hazard(df_test))
fig = plot_residuals(test_lph_vals,pred_lph_vals,x1_vals,x2_vals)

In [None]:
# train DeepSurv on same task
import torchtuples as tt
from coxkan.utils import FastCoxLoss, count_parameters, bootstrap_metric, set_seed, SYMBOLIC_LIB
import torch
from sklearn.model_selection import train_test_split

# Hyperparameters for the MLP baseline on the Gaussian simulation.
with open(f'configs/mlp/sim_gaussian.yml', 'r') as f:
    mlp_config = yaml.safe_load(f)

# One input per covariate, a single bias-free output, trained with FastCoxLoss.
mlp = tt.practical.MLPVanilla(
    in_features=len(covariates), out_features=1, output_bias=False, **mlp_config['init_params']
)
optimizer = tt.optim.Adam(**mlp_config['optimizer_params'])
deepsurv = tt.Model(mlp, loss=FastCoxLoss, optimizer=optimizer)
deepsurv_params = count_parameters(mlp)

# Convert to PyTorch tensors
X_test = torch.tensor(df_test[covariates].values).double()
# Labels are stacked as [duration, event] columns.
y_test = torch.tensor(df_test[[duration_col, event_col]].values).double()

def mlp_cindex(df):
    """Concordance index of the notebook-level ``deepsurv`` model on ``df``.

    Uses whichever ``deepsurv``, ``covariates``, ``duration_col`` and
    ``event_col`` are currently defined in the notebook.
    """
    features = torch.tensor(df[covariates].values).double()
    lph = deepsurv.predict(features)
    # The model output is negated before being passed to concordance_index.
    return concordance_index(df[duration_col], -lph, df[event_col])

def mlp_cindex_metric_fn(lph, labels):
    """Training-time C-index metric.

    ``labels[:, 0]`` is used as the duration and ``labels[:, 1]`` as the event
    indicator; ``lph`` is the model output, negated before scoring.
    """
    durations = labels[:, 0].detach().numpy()
    events = labels[:, 1].detach().numpy()
    scores = -lph.detach().numpy()
    return concordance_index(durations, scores, events)

# Training
if mlp_config['early_stopping']:
    # Hold out 20% of the training data (stratified on the event flag) for early stopping.
    train, val = train_test_split(df_train, test_size=0.2, random_state=42, stratify=df_train['event'])
    X_val = torch.tensor(val[covariates].values).double()
    y_val = torch.tensor(val[[duration_col, event_col]].values).double()
    X_train = torch.tensor(train[covariates].values).double()
    y_train = torch.tensor(train[[duration_col, event_col]].values).double()
    # batch_size=len(X_train): the whole training set is one batch.
    log = deepsurv.fit(
        X_train, y_train, batch_size=len(X_train), val_data=(X_val, y_val), epochs=mlp_config['epochs'], verbose=False,
        metrics={'cindex': mlp_cindex_metric_fn}, callbacks=[tt.callbacks.EarlyStopping(patience=20)]
    )
else:
    X_train = torch.tensor(df_train[covariates].values).double()
    y_train = torch.tensor(df_train[[duration_col, event_col]].values).double()
    # No early stopping here: the test set is passed as val_data with no callbacks.
    log = deepsurv.fit(
        X_train, y_train, batch_size=len(X_train), val_data=(X_test, y_test), epochs=mlp_config['epochs'], verbose=False,
        metrics={'cindex': mlp_cindex_metric_fn}
    )

cindex_mlp = bootstrap_metric(mlp_cindex, df_test, N=100)['formatted']
print(f"MLP C-Index: {cindex_mlp}")

#plot residuals
# The first two covariate columns are taken as x1 and x2.
x1_vals = X_test[:, 0].numpy()
x2_vals = X_test[:, 1].numpy()
test_lph_vals = sim_config['log_partial_hazard'](x1_vals, x2_vals, *df_test[covariates[2:]].values.T)
# squeeze() drops singleton dimensions so the prediction lines up with the 1-D truth vector.
pred_lph_vals = deepsurv.predict(X_test).squeeze().detach().cpu().numpy()
fig = plot_residuals(test_lph_vals,pred_lph_vals,x1_vals,x2_vals)


In [None]:
#train SuMo-net
from sumo import sumo
from lifelines.utils import concordance_index

# SuMo-net reuses the optimiser and architecture settings found for the MLP baseline.
with open(f'configs/mlp/sim_gaussian.yml', 'r') as f:
    mlp_config = yaml.safe_load(f)

print(mlp_config)

lr = mlp_config['optimizer_params']['lr']
weight_decay = mlp_config['optimizer_params']['weight_decay']
num_nodes = mlp_config['init_params']['num_nodes']
dropout = mlp_config['init_params']['dropout']

model = sumo.SuMo(layers = num_nodes, dropout = dropout)
X= df_train[covariates].values
e = df_train['event'].values
t = df_train['duration'].values

# Sanity-check the label ranges before fitting.
print(f"Min event: {e.min()}, Max event: {e.max()}")
print(f"Min duration: {t.min()}, Max duration: {t.max()}")
model = model.fit(X,t,e,random_state=42, n_iter=1000, lr=lr, weight_decay=weight_decay)

# Test inputs as double-precision tensors.
X_test = torch.tensor(df_test[covariates].values).double()
e_test = torch.tensor(df_test['event'].values).double()
t_test = torch.tensor(df_test['duration'].values).double()
print(X_test.shape, t_test.shape, e_test.shape)
survival, intensity = model.forward(X_test, t_test, gradient = False)
survival = survival.detach().numpy()

# NOTE(review): `survival` is evaluated at each subject's own observed time
# (t_test is passed to forward), so the score being ranked depends on the label
# itself. Confirm that this, and the sign of the score, is the intended way to
# compute a C-index comparable with the other models.
cindex_sumo = concordance_index(t_test, -survival, e_test)

# Summarise the predictions rather than dumping the whole array.
print(survival.shape,survival.max(), survival.min())
print(f"SuMo C-Index: {cindex_sumo}")

Evaluate CoxKAN:

In [None]:
# CoxKAN
# (A leftover `assert False` used to abort this cell, which left `ckan`,
# `cindex_pre` and the checkpoint undefined for every later cell on a clean run.)
ckan = CoxKAN(seed=42, **config['init_params'])

log = ckan.train(df_train, df_test, duration_col, event_col, **config['train_params'])

cindex_pre = bootstrap_metric(ckan.cindex, df_test, N=100)['formatted']
print(f"Pre-symbolic: {cindex_pre}")

# Save
ckan.save_ckpt(f'checkpoints/{exp_name}/model.pt')
fig = log.plot()
fig.savefig(f'checkpoints/{exp_name}/evolution.png')
fig = ckan.plot(beta=40, in_vars=[r'$x_1$', r'$x_2$', r'$\epsilon_1$', r'$\epsilon_2$'])
fig.savefig(f'checkpoints/{exp_name}/coxkan_pre.png', dpi=600)

In [None]:
# Pruning
# prune_nodes returns a model that replaces `ckan`; prune_edges is called for its effect on it.
# Both use the threshold from the hyperparameter-search config.
ckan = ckan.prune_nodes(config['prune_threshold'])
ckan.prune_edges(config['prune_threshold'], verbose=True)
fig = ckan.plot(beta=40, in_vars=[r'$x_1$', r'$x_2$', r'$\epsilon_1$', r'$\epsilon_2$'])
fig.savefig(f'checkpoints/{exp_name}/coxkan_pruned.png', dpi=600)
cindex_pruned = bootstrap_metric(ckan.cindex, df_test, N=100)['formatted']
print(f"Pruned: {cindex_pruned}")

We can recognise these activation functions as `x^2, x^2, exp`. We try fitting these functions:

In [None]:
# Fix each remaining activation to the closed form recognised from the plot and
# report how well it fits. Tuples are the three index arguments of fix_symbolic
# followed by the function name.
for idx_a, idx_b, idx_c, fn_name in [(0, 0, 0, 'x^2'), (0, 1, 0, 'x^2'), (1, 0, 0, 'exp')]:
    r2 = ckan.fix_symbolic(idx_a, idx_b, idx_c, fn_name, verbose=False)
    print(f"Activation ({idx_a},{idx_b},{idx_c}): {fn_name} fits with R^2: {r2}")

formula = ckan.symbolic_formula()[0][0]
formula

The high $R^2$ values (coefficient of determination) verify that the true expression was learnt. However, the affine parameters are not quite correct (although a close approximation). We finish by training the affine params:

In [None]:
# With the activations now symbolic, run 50 further LBFGS steps and re-read the formula.
ckan.train(df_train, df_test, duration_col, event_col, opt="LBFGS", steps=50)
formula = ckan.symbolic_formula()[0][0]
print(formula)

#plot residuals in ckan
x1_vals = df_test['x1'].values
x2_vals = df_test['x2'].values
test_lph_vals = sim_config['log_partial_hazard'](x1_vals, x2_vals, *df_test[covariates[2:]].values.T)
# NOTE(review): assumes ckan.predict returns a 1-D array aligned with df_test -- confirm.
pred_lph_vals = ckan.predict(df_test)

fig = plot_residuals(test_lph_vals,pred_lph_vals,x1_vals,x2_vals)

We see that the result is near-perfect. Now we just save the results/visualisations:

In [None]:
# Plot the symbolic model, score it, and persist the headline numbers for this experiment.
fig = ckan.plot(beta=40, in_vars=[r'$x_1$', r'$x_2$', r'$\epsilon_1$', r'$\epsilon_2$'])
fig.savefig(f'checkpoints/{exp_name}/coxkan_symbolic.png', dpi=600)
cindex_symbolic = bootstrap_metric(ckan.cindex, df_test, N=100)['formatted']
print(f"Symbolic: {cindex_symbolic}")

# Only the true, CoxPH and CoxKAN results are stored; the regularised CoxPH, MLP and
# SuMo scores computed earlier are not part of this dictionary.
results = {
    'cindex_true': cindex_true,
    'cindex_cph': cindex_cph,
    'cindex_pre': cindex_pre,
    'cindex_pruned': cindex_pruned,
    'cindex_symbolic': cindex_symbolic,
    'coxkan_formula': formula,
    'coxph_formula': formula_cph,
}

with open(f'checkpoints/{exp_name}/results.pkl', 'wb') as f:
    pickle.dump(results, f)

## Shallow

In survival analysis, it is common that we encounter covariates that satisfy the linear Cox Proportional Hazards model after some non-linear transformation. In other words, they have non-linear relationships to the patient's risk but they do not interact. To simulate this situation we use the following expression for the log-partial hazard:

$$\theta(\mathbf{x}) = \tanh(5x_1) + \sin(2\pi x_2) + x_3^2$$

This can be captured by a shallow KAN (no hidden layers).

In [None]:
# Select the shallow (depth-1) simulation; every path below is derived from these two names.
exp_name = "sim_depth_1"
sim_name = "depth_1"

### load configs
with open(f'./configs/simulation/{sim_name}.yml', 'r') as file:
    sim_config = yaml.safe_load(file)
    # Keep the text after the last ': ' as a printable expression
    # (presumably the body of a lambda string -- confirm against the YAML).
    sim_config['true_expr'] =  sim_config['log_partial_hazard'].split(': ')[-1] # log partial hazard expression
    # SECURITY: eval() executes whatever code the YAML string contains; only use trusted config files.
    sim_config['log_partial_hazard'] = eval(sim_config['log_partial_hazard']) # convert to function

# (config from hyperparameter search)
with open(f'configs/coxkan/{exp_name}.yml', 'r') as file:
    config = yaml.safe_load(file)

# data already generated (from sweep.py)
df_train = pd.read_csv(f'./data/{exp_name}_train.csv')
df_test = pd.read_csv(f'./data/{exp_name}_test.csv')
# Every column except the last two is treated as a covariate.
duration_col, event_col, covariates = 'duration', 'event', df_train.columns[:-2]

Evaluate true expression:

In [None]:
# C-Index of true log partial hazard expression
# (bootstrapped over the test set with N=100; only the 'formatted' entry is kept)
cindex_true = bootstrap_metric(true_cindex, df_test, N=100)['formatted']

print(f"True log partial hazard: {sim_config['true_expr']}")
print(f"True C-Index: {cindex_true}")

Evaluate Cox proportional hazards model:

In [None]:
# CoxPH (unregularised baseline)
cph = CoxPHFitter()
cph.fit(df_train, duration_col=duration_col, event_col=event_col)
cindex_cph = bootstrap_metric(cph_cindex, df_test, N=100)['formatted']
formula_cph = cph_formula(cph)
print(f"CoxPH Expression: {formula_cph}")
print(f"CoxPH C-Index: {cindex_cph}")

# Residuals are taken on the log-partial-hazard scale, so the prediction must be
# the *log* partial hazard. predict_partial_hazard returns the exponentiated value,
# which is not comparable with the true log-partial hazard.
x1_vals = df_test['x1'].values
x2_vals = df_test['x2'].values
test_lph_vals = sim_config['log_partial_hazard'](x1_vals, x2_vals, *df_test[covariates[2:]].values.T)
pred_lph_vals = np.ravel(cph.predict_log_partial_hazard(df_test))
fig = plot_residuals(test_lph_vals,pred_lph_vals,x1_vals,x2_vals)

# do same but with regularisation
# The regularised fit gets its own names (as in the other studies) so that
# `cph`, `cindex_cph` and `formula_cph`, which are saved to results.pkl later,
# keep describing the unregularised model.
cph_reg = CoxPHFitter(penalizer=0.5, l1_ratio=1)
cph_reg.fit(df_train, duration_col=duration_col, event_col=event_col)
cindex_cph_reg = bootstrap_metric(cph_reg_cindex, df_test, N=100)['formatted']
formula_cph_reg = cph_formula(cph_reg)
print(f"CoxPH (Reg) Expression: {formula_cph_reg}")
print(f"CoxPH (Reg) C-Index: {cindex_cph_reg}")

# Covariates and true values are unchanged; only the prediction differs.
pred_lph_vals = np.ravel(cph_reg.predict_log_partial_hazard(df_test))
fig = plot_residuals(test_lph_vals,pred_lph_vals,x1_vals,x2_vals)


In [None]:
# train DeepSurv on same task
import torchtuples as tt
from coxkan.utils import FastCoxLoss, count_parameters, bootstrap_metric, set_seed, SYMBOLIC_LIB
import torch

# Hyperparameters for the MLP baseline on the shallow simulation.
# NOTE: this cell relies on train_test_split, mlp_cindex and mlp_cindex_metric_fn
# from the Gaussian DeepSurv cell, so that cell must have been run first.
with open(f'configs/mlp/sim_depth_1.yml', 'r') as f:
    mlp_config = yaml.safe_load(f)
mlp = tt.practical.MLPVanilla(
    in_features=len(covariates), out_features=1, output_bias=False, **mlp_config['init_params']
)
optimizer = tt.optim.Adam(**mlp_config['optimizer_params'])
deepsurv = tt.Model(mlp, loss=FastCoxLoss, optimizer=optimizer)
deepsurv_params = count_parameters(mlp)
# Convert to PyTorch tensors
X_test = torch.tensor(df_test[covariates].values).double()
y_test = torch.tensor(df_test[[duration_col, event_col]].values).double()

# Training
if mlp_config['early_stopping']:
    # Hold out 20% of the training data (stratified on the event flag) for early stopping.
    train, val = train_test_split(df_train, test_size=0.2, random_state=42, stratify=df_train['event'])
    X_val = torch.tensor(val[covariates].values).double()
    y_val = torch.tensor(val[[duration_col, event_col]].values).double()
    X_train = torch.tensor(train[covariates].values).double()
    y_train = torch.tensor(train[[duration_col, event_col]].values).double()
    log = deepsurv.fit(
        X_train, y_train, batch_size=len(X_train), val_data=(X_val, y_val), epochs=mlp_config['epochs'], verbose=False,
        metrics={'cindex': mlp_cindex_metric_fn}, callbacks=[tt.callbacks.EarlyStopping(patience=20)]
    )
else:
    X_train = torch.tensor(df_train[covariates].values).double()
    y_train = torch.tensor(df_train[[duration_col, event_col]].values).double()
    # No early stopping here: the test set is passed as val_data with no callbacks.
    log = deepsurv.fit(
        X_train, y_train, batch_size=len(X_train), val_data=(X_test, y_test), epochs=mlp_config['epochs'], verbose=False,
        metrics={'cindex': mlp_cindex_metric_fn}
    )

cindex_mlp = bootstrap_metric(mlp_cindex, df_test, N=100)['formatted']
print(f"MLP C-Index: {cindex_mlp}")
#plot residuals
# The first two covariate columns are taken as x1 and x2.
x1_vals = X_test[:, 0].numpy()
x2_vals = X_test[:, 1].numpy()
test_lph_vals = sim_config['log_partial_hazard'](x1_vals, x2_vals, *df_test[covariates[2:]].values.T)
pred_lph_vals = deepsurv.predict(X_test).squeeze().detach().cpu().numpy()
fig = plot_residuals(test_lph_vals,pred_lph_vals,x1_vals,x2_vals)


CoxKAN:

In [None]:
# CoxKAN
ckan = CoxKAN(seed=42, **config['init_params'])

# The test set is passed as the second dataset to train.
log = ckan.train(df_train, df_test, duration_col, event_col, **config['train_params'])

cindex_pre = bootstrap_metric(ckan.cindex, df_test, N=100)['formatted']
print(f"Pre-symbolic: {cindex_pre}")

# Save
ckan.save_ckpt(f'checkpoints/{exp_name}/model.pt')
fig = log.plot()
fig.savefig(f'checkpoints/{exp_name}/evolution.png', bbox_inches='tight')
fig = ckan.plot(beta=10, in_vars=[r'$x_1$', r'$x_2$', r'$x_3$',r'$\epsilon_1$', r'$\epsilon_2$'])
fig.savefig(f'checkpoints/{exp_name}/coxkan_pre.png', bbox_inches='tight')

In [None]:
# Pruning
# prune_nodes returns a model that replaces `ckan`; prune_edges is called for its effect on it.
# Both use the threshold from the hyperparameter-search config.
ckan = ckan.prune_nodes(config['prune_threshold'])
ckan.prune_edges(config['prune_threshold'], verbose=True)
fig = ckan.plot(beta=10, in_vars=[r'$x_1$', r'$x_2$', r'$x_3$',r'$\epsilon_1$', r'$\epsilon_2$'])
fig.savefig(f'checkpoints/{exp_name}/coxkan_pruned.png', bbox_inches='tight')
cindex_pruned = bootstrap_metric(ckan.cindex, df_test, N=100)['formatted']
print(f"Pruned: {cindex_pruned}")

We can again recognise these activations: 
- some s-shaped function like `tanh` or `sigmoid`
- some oscillating function like `sin` (or `cos` but this is just a matter of affine parameters)
- a quadratic `x^2`

In the case of the s-shaped function we try out both options and see which is better:

In [None]:
# Forward pass before querying the activations (the result itself is discarded).
_ = ckan.predict(df_test)
# Let the model choose between the two s-shaped candidates for activation (0,0,0).
fn, _, r2 = ckan.suggest_symbolic(0, 0, 0, lib=['tanh', 'sigmoid'], verbose=False)
print(f"Best: {fn} fits with R^2: {r2}")

In [None]:
# Fix each remaining activation to its recognised closed form and report the fit.
# Tuples are the three index arguments of fix_symbolic followed by the function name.
for idx_a, idx_b, idx_c, fn_name in [(0, 0, 0, 'tanh'), (0, 1, 0, 'sin'), (0, 2, 0, 'x^2')]:
    r2 = ckan.fix_symbolic(idx_a, idx_b, idx_c, fn_name, verbose=False)
    print(f"Activation ({idx_a},{idx_b},{idx_c}): {fn_name} fits with R^2: {r2}")

formula = ckan.symbolic_formula()[0][0]
print(formula)

# Residuals of the symbolic CoxKAN against the true log-partial hazard
x1_vals = df_test['x1'].values
x2_vals = df_test['x2'].values
extra_covariates = df_test[covariates[2:]].values.T
test_lph_vals = sim_config['log_partial_hazard'](x1_vals, x2_vals, *extra_covariates)
pred_lph_vals = ckan.predict(df_test)
fig = plot_residuals(test_lph_vals,pred_lph_vals,x1_vals,x2_vals)


The high $R^2$ values (coefficient of determination) verify that the true expression was learnt. At first glance, the `sin` term appears incorrect, but note that: $ - 0.96 \sin(6.27x + 12.58) \approx -\sin(2\pi x - 3\pi) = \sin(2\pi x)$

In [None]:
# Forward pass before plotting (the result itself is discarded).
_ = ckan.predict(df_test)
fig = ckan.plot(beta=10, in_vars=[r'$x_1$', r'$x_2$', r'$x_3$',r'$\epsilon_1$', r'$\epsilon_2$'])
fig.savefig(f'checkpoints/{exp_name}/coxkan_symbolic.png', bbox_inches='tight')
cindex_symbolic = bootstrap_metric(ckan.cindex, df_test, N=100)['formatted']
print(f"Symbolic: {cindex_symbolic}")

# Results are saved here, before the extra affine-parameter training in the next cell.
results = {
    'cindex_true': cindex_true,
    'cindex_cph': cindex_cph,
    'cindex_pre': cindex_pre,
    'cindex_pruned': cindex_pruned,
    'cindex_symbolic': cindex_symbolic,
    'coxkan_formula': formula,
    'coxph_formula': formula_cph,
}

with open(f'checkpoints/{exp_name}/results.pkl', 'wb') as f:
    pickle.dump(results, f)

Note that in this case, training the affine parameters doesn't yield a better expression:

In [None]:
# Run 50 further LBFGS steps on the symbolic model and display the resulting formula.
# This overwrites `formula` after results.pkl has already been written.
ckan.train(df_train, df_test, duration_col, event_col, opt="LBFGS", steps=50)
formula = ckan.symbolic_formula()[0][0]
formula

This is likely due to the noise in the dataset - we can only get an approximation.

## Deep

To contrast with the previous example, we now try an expression for the log-partial hazard that requires a deep KAN (2 hidden layers) to capture:

$$\theta(\mathbf{x}) = 2\sqrt{(x_1-x_2)^2 + (x_3-x_4)^2}$$

In [None]:
# Select the deep simulation; every path below is derived from these two names.
exp_name = "sim_deep"
sim_name = "deep"

# load configs
with open(f'./configs/simulation/{sim_name}.yml', 'r') as file:
    sim_config = yaml.safe_load(file)
    # Keep the text after the last ': ' as a printable expression
    # (presumably the body of a lambda string -- confirm against the YAML).
    sim_config['true_expr'] =  sim_config['log_partial_hazard'].split(': ')[-1] # log partial hazard expression
    # SECURITY: eval() executes whatever code the YAML string contains; only use trusted config files.
    sim_config['log_partial_hazard'] = eval(sim_config['log_partial_hazard']) # convert to function

# CoxKAN configuration for this experiment
with open(f'configs/coxkan/{exp_name}.yml', 'r') as file:
    config = yaml.safe_load(file)

# data already generated (from sweep.py)
df_train = pd.read_csv(f'./data/{exp_name}_train.csv')
df_test = pd.read_csv(f'./data/{exp_name}_test.csv')
# Every column except the last two is treated as a covariate.
duration_col, event_col, covariates = 'duration', 'event', df_train.columns[:-2]

In [None]:
# C-Index of true log partial hazard expression
# (bootstrapped over the test set with N=100; only the 'formatted' entry is kept)
cindex_true = bootstrap_metric(true_cindex, df_test, N=100)['formatted']

print(f"True log partial hazard: {sim_config['true_expr']}")
print(f"True C-Index: {cindex_true}")

In [None]:
# CoxPH (unregularised baseline)
cph = CoxPHFitter()
cph.fit(df_train, duration_col=duration_col, event_col=event_col)
cindex_cph = bootstrap_metric(cph_cindex, df_test, N=100)['formatted']
formula_cph = cph_formula(cph)
print(f"CoxPH Expression: {formula_cph}")
print(f"CoxPH C-Index: {cindex_cph}")

# Residuals are taken on the log-partial-hazard scale, so the prediction must be
# the *log* partial hazard. predict_partial_hazard returns the exponentiated value,
# which is not comparable with the true log-partial hazard.
x1_vals = df_test['x1'].values
x2_vals = df_test['x2'].values
test_lph_vals = sim_config['log_partial_hazard'](x1_vals, x2_vals, *df_test[covariates[2:]].values.T)
pred_lph_vals = np.ravel(cph.predict_log_partial_hazard(df_test))
fig = plot_residuals(test_lph_vals,pred_lph_vals,x1_vals,x2_vals)

# do same for CoxPH with regularization
cph_reg = CoxPHFitter(penalizer=0.5, l1_ratio=1)
cph_reg.fit(df_train, duration_col=duration_col, event_col=event_col)
cindex_cph_reg = bootstrap_metric(cph_reg_cindex, df_test, N=100)['formatted']
formula_cph_reg = cph_formula(cph_reg)
print(f"CoxPH (Reg) Expression: {formula_cph_reg}")
print(f"CoxPH (Reg) C-Index: {cindex_cph_reg}")

# Covariates and true values are unchanged; only the prediction differs.
pred_lph_vals = np.ravel(cph_reg.predict_log_partial_hazard(df_test))
fig = plot_residuals(test_lph_vals,pred_lph_vals,x1_vals,x2_vals)


In [None]:
# train DeepSurv on same task
import torchtuples as tt
from coxkan.utils import FastCoxLoss, count_parameters, bootstrap_metric, set_seed, SYMBOLIC_LIB
import torch

# Hyperparameters for the MLP baseline on the deep simulation.
# NOTE: this cell relies on train_test_split, mlp_cindex and mlp_cindex_metric_fn
# from the Gaussian DeepSurv cell, so that cell must have been run first.
with open(f'configs/mlp/sim_deep.yml', 'r') as f:
    mlp_config = yaml.safe_load(f)
mlp = tt.practical.MLPVanilla(
    in_features=len(covariates), out_features=1, output_bias=False, **mlp_config['init_params']
)
optimizer = tt.optim.Adam(**mlp_config['optimizer_params'])
deepsurv = tt.Model(mlp, loss=FastCoxLoss, optimizer=optimizer)
deepsurv_params = count_parameters(mlp)
# Convert to PyTorch tensors
X_test = torch.tensor(df_test[covariates].values).double()
y_test = torch.tensor(df_test[[duration_col, event_col]].values).double()

# Training
if mlp_config['early_stopping']:
    # Hold out 20% of the training data (stratified on the event flag) for early stopping.
    train, val = train_test_split(df_train, test_size=0.2, random_state=42, stratify=df_train['event'])
    X_val = torch.tensor(val[covariates].values).double()
    y_val = torch.tensor(val[[duration_col, event_col]].values).double()
    X_train = torch.tensor(train[covariates].values).double()
    y_train = torch.tensor(train[[duration_col, event_col]].values).double()
    log = deepsurv.fit(
        X_train, y_train, batch_size=len(X_train), val_data=(X_val, y_val), epochs=mlp_config['epochs'], verbose=False,
        metrics={'cindex': mlp_cindex_metric_fn}, callbacks=[tt.callbacks.EarlyStopping(patience=20)]
    )
else:
    X_train = torch.tensor(df_train[covariates].values).double()
    y_train = torch.tensor(df_train[[duration_col, event_col]].values).double()
    # No early stopping here: the test set is passed as val_data with no callbacks.
    log = deepsurv.fit(
        X_train, y_train, batch_size=len(X_train), val_data=(X_test, y_test), epochs=mlp_config['epochs'], verbose=False,
        metrics={'cindex': mlp_cindex_metric_fn}
    )

cindex_mlp = bootstrap_metric(mlp_cindex, df_test, N=100)['formatted']
print(f"MLP C-Index: {cindex_mlp}")
#plot residuals
# The first two covariate columns are taken as x1 and x2.
x1_vals = X_test[:, 0].numpy()
x2_vals = X_test[:, 1].numpy()
test_lph_vals = sim_config['log_partial_hazard'](x1_vals, x2_vals, *df_test[covariates[2:]].values.T)
pred_lph_vals = deepsurv.predict(X_test).squeeze().detach().cpu().numpy()
fig = plot_residuals(test_lph_vals,pred_lph_vals,x1_vals,x2_vals)

For this study (perhaps due to its difficulty), the hyperparameter search yielded early_stopping=True - hence we need to split the training set into a train and a validation set.

In [None]:
from sklearn.model_selection import train_test_split

# CoxKAN
ckan = CoxKAN(seed=42, **config['init_params'])

# Train/Val split for early stopping
# (80/20, stratified on the event flag, fixed random_state for reproducibility)
train, val = train_test_split(df_train, test_size=0.2, random_state=42, stratify=df_train['event'])

# Unlike the earlier studies, the held-out validation split (not df_test) is passed to train.
log = ckan.train(train, val, duration_col, event_col, **config['train_params'])

cindex_pre = bootstrap_metric(ckan.cindex, df_test, N=100)['formatted']
print(f"Pre-symbolic: {cindex_pre}")

# Save
# (the pruning-threshold search reloads this checkpoint for every candidate)
ckan.save_ckpt(f'checkpoints/{exp_name}/model.pt')
fig = log.plot()
fig.savefig(f'checkpoints/{exp_name}/evolution.png')
fig = ckan.plot(beta=20, in_vars=[r'$x_1$', r'$x_2$', r'$x_3$', r'$x_4$', r'$\epsilon_1$', r'$\epsilon_2$'])
fig.savefig(f'checkpoints/{exp_name}/coxkan_pre.png')

Another benefit of using a validation set is that we can use it to select a better pruning threshold:

In [None]:
# search for pruning thresholds
pruning_thresholds = np.linspace(0, 0.05, 20)
# Slot 0 is replaced by the threshold from the hyperparameter search so it is always tried first.
pruning_thresholds[0] = config['prune_threshold']
cindices = []
for threshold in pruning_thresholds:
    # Start every trial from the saved, unpruned checkpoint.
    ckan_ = CoxKAN(seed=42, **config['init_params'])
    ckan_.load_ckpt(f'checkpoints/{exp_name}/model.pt', verbose=False)
    _ = ckan_.predict(df_test) # important forward pass after loading a model

    # A threshold is unusable if some layer has no activation scale above it.
    prunable = True
    for l in range(ckan_.depth):
        if not (ckan_.acts_scale[l] > threshold).any():
            prunable = False
            break

    ckan_ = ckan_.prune_nodes(threshold)
    if 0 in ckan_.width: prunable = False
    if not prunable:
        if threshold == config['prune_threshold']:
            # Placeholder score keeps cindices index-aligned with pruning_thresholds.
            cindices.append(0)
            continue
        # Any other unusable threshold ends the search.
        else: break

    _ = ckan_.predict(df_test) # important forward pass
    ckan_.prune_edges(threshold, verbose=False)
    cindices.append(ckan_.cindex(val))
    # Candidates are ~0.0026 apart, so two decimals would print several of them identically.
    print(f'Pruning threshold: {threshold:.4f}, C-Index (Val): {cindices[-1]:.6f}')
# cindices may be shorter than pruning_thresholds after a break, but its indices
# still map onto the leading thresholds.
best_threshold = pruning_thresholds[np.argmax(cindices)]
# Fall back to no pruning when the best validation C-index is below 0.51.
if np.max(cindices) < 0.51: best_threshold = 0

In [None]:
# Pruning
# A forward pass precedes each pruning step (the predictions themselves are discarded).
_ = ckan.predict(df_test)
# Uses the threshold selected on the validation split rather than the config value.
ckan = ckan.prune_nodes(best_threshold)
_ = ckan.predict(df_test)
ckan.prune_edges(best_threshold, verbose=True)
fig = ckan.plot(beta=40, in_vars=[r'$x_1$', r'$x_2$', r'$x_3$', r'$x_4$', r'$\epsilon_1$', r'$\epsilon_2$'])
fig.savefig(f'checkpoints/{exp_name}/coxkan_pruned.png')
cindex_pruned = bootstrap_metric(ckan.cindex, df_test, N=100)['formatted']
print(f"Pruned: {cindex_pruned}")

We examine these activations and fix them accordingly:

In [None]:
# Closed form chosen for every activation that survived pruning. Each tuple holds
# the three index arguments of fix_symbolic followed by the function name.
# Driving both the call and the message from one table keeps them in sync: the
# hand-written message for (1,1,0) used to report 'x^2' although 'x' was fitted.
symbolic_fits = [
    (0, 0, 1, 'x^2'),
    (0, 1, 1, 'x^2'),
    (0, 2, 1, 'x^2'),
    (0, 3, 1, 'x^2'),
    (0, 0, 2, 'x'),
    (0, 1, 2, 'x'),
    (0, 2, 0, 'x'),
    (0, 3, 0, 'x'),
    (1, 0, 0, 'x^2'),
    (1, 1, 0, 'x'),
    (1, 2, 0, 'x^2'),
    (2, 0, 0, 'sqrt'),
]

for idx_a, idx_b, idx_c, fn_name in symbolic_fits:
    r2 = ckan.fix_symbolic(idx_a, idx_b, idx_c, fn_name, verbose=False)
    print(f"Activation ({idx_a},{idx_b},{idx_c}): {fn_name} fits with R^2: {r2.item()}")

The high $R^2$ in each case verifies that these functions have been learned.

In [None]:
# Read back the symbolic expression now that every activation is fixed.
formula = ckan.symbolic_formula()[0][0]
print(formula)

#plot residuals in ckan
x1_vals = df_test['x1'].values
x2_vals = df_test['x2'].values
test_lph_vals = sim_config['log_partial_hazard'](x1_vals, x2_vals, *df_test[covariates[2:]].values.T)
# NOTE(review): assumes ckan.predict returns a 1-D array aligned with df_test -- confirm.
pred_lph_vals = ckan.predict(df_test)
fig = plot_residuals(test_lph_vals,pred_lph_vals,x1_vals,x2_vals)


fig = ckan.plot(beta=40, in_vars=[r'$x_1$', r'$x_2$', r'$x_3$', r'$x_4$', r'$\epsilon_1$', r'$\epsilon_2$'])
fig.savefig(f'checkpoints/{exp_name}/coxkan_symbolic.png')
cindex_symbolic = bootstrap_metric(ckan.cindex, df_test, N=100)['formatted']
print(f"Symbolic: {cindex_symbolic}")

# Only the true, CoxPH and CoxKAN results are stored; the regularised CoxPH and MLP
# scores computed earlier are not part of this dictionary.
results = {
    'cindex_true': cindex_true,
    'cindex_cph': cindex_cph,
    'cindex_pre': cindex_pre,
    'cindex_pruned': cindex_pruned,
    'cindex_symbolic': cindex_symbolic,
    'coxkan_formula': formula,
    'coxph_formula': formula_cph,
}

with open(f'checkpoints/{exp_name}/results.pkl', 'wb') as f:
    pickle.dump(results, f)

In [None]:
# Display the symbolic expression via the notebook's rich output.
formula

The affine parameters are not quite right. However, if we take some liberal approximations, we see that it has captured the true expression:

$ 3.97 \sqrt{0.96*(0.06 - x_2)^2 + 0.94*(0.09 - x_1)^2 + (-x_3 - 0.04)^2 + 0.79*(-x_4 - 0.04)^2 - 0.5*(-0.97*x_1 - x_2 + 0.14)^2 - 0.67*(-x_3 - 0.71*x_4 - 0.06)^2 + 0.56} $

$ \approx 4 \sqrt{x_2^2 + x_1^2 + x_3^2 + x_4^2 - \frac{1}{2}(x_1 + x_2)^2 - \frac{1}{2}(x_3 + x_4)^2} $

$ = 2 \sqrt{2x_1^2 + 2x_2^2 + 2x_3^2 + 2x_4^2 - (x_1 + x_2)^2 - (x_3 + x_4)^2} $

$ = 2 \sqrt{x_1^2 + x_2^2 - 2x_1x_2 + x_3^2 + x_4^2 - 2x_3x_4}$

$ = 2 \sqrt{(x_1 - x_2)^2 + (x_3 - x_4)^2}$

Obviously, in the case of a real dataset, we would not be able to make these reckless approximations, nor would we necessarily go through spotting all these activations as their 'true' symbolic counterparts. In the next example, we will see an example where we instead just call `auto_symbolic` rather than recognising activations by eye. 

## Intentionally difficult dataset

Next we use an expression for the log-partial hazard which is intentionally difficult to capture.

$$\theta(\mathbf{x}) = \tanh(5(\log(x_1) + |x_2|))$$

The intuition behind this choice is that
- $\tanh(5z)$ has a very shallow gradient in much of its domain, hence there will not be much of a training signal when comparing subjects in these regions.
- $|x_2|$ is non-smooth. KAN activations use B-Splines which are necessarily smooth - thus this activation is likely to be difficult to learn. 

In [None]:
# Select the intentionally difficult simulation; every path below is derived from these two names.
exp_name = "sim_difficult"
sim_name = "difficult"

# load configs
with open(f'./configs/simulation/{sim_name}.yml', 'r') as file:
    sim_config = yaml.safe_load(file)
    # Keep the text after the last ': ' as a printable expression
    # (presumably the body of a lambda string -- confirm against the YAML).
    sim_config['true_expr'] =  sim_config['log_partial_hazard'].split(': ')[-1] # log partial hazard expression
    # SECURITY: eval() executes whatever code the YAML string contains; only use trusted config files.
    sim_config['log_partial_hazard'] = eval(sim_config['log_partial_hazard']) # convert to function

# CoxKAN configuration for this experiment
with open(f'configs/coxkan/{exp_name}.yml', 'r') as file:
    config = yaml.safe_load(file)

# data already generated (from sweep.py)
df_train = pd.read_csv(f'./data/{exp_name}_train.csv')
df_test = pd.read_csv(f'./data/{exp_name}_test.csv')
# Every column except the last two is treated as a covariate.
duration_col, event_col, covariates = 'duration', 'event', df_train.columns[:-2]

In [None]:
# C-Index of true log partial hazard expression
# (bootstrapped over the test set with N=100; only the 'formatted' entry is kept)
cindex_true = bootstrap_metric(true_cindex, df_test, N=100)['formatted']

print(f"True log partial hazard: {sim_config['true_expr']}")
print(f"True C-Index: {cindex_true}")

In [None]:
# CoxPH (unregularised baseline)
cph = CoxPHFitter()
cph.fit(df_train, duration_col=duration_col, event_col=event_col)
cindex_cph = bootstrap_metric(cph_cindex, df_test, N=100)['formatted']
formula_cph = cph_formula(cph)
print(f"CoxPH Expression: {formula_cph}")
print(f"CoxPH C-Index: {cindex_cph}")

# Residuals are taken on the log-partial-hazard scale, so the prediction must be
# the *log* partial hazard. predict_partial_hazard returns the exponentiated value,
# which is not comparable with the true log-partial hazard.
x1_vals = df_test['x1'].values
x2_vals = df_test['x2'].values
test_lph_vals = sim_config['log_partial_hazard'](x1_vals, x2_vals, *df_test[covariates[2:]].values.T)
pred_lph_vals = np.ravel(cph.predict_log_partial_hazard(df_test))
fig = plot_residuals(test_lph_vals,pred_lph_vals,x1_vals,x2_vals)

# do same for CoxPH with regularization
cph_reg = CoxPHFitter(penalizer=0.5, l1_ratio=1)
cph_reg.fit(df_train, duration_col=duration_col, event_col=event_col)
cindex_cph_reg = bootstrap_metric(cph_reg_cindex, df_test, N=100)['formatted']
formula_cph_reg = cph_formula(cph_reg)
print(f"CoxPH (Reg) Expression: {formula_cph_reg}")
print(f"CoxPH (Reg) C-Index: {cindex_cph_reg}")

# Covariates and true values are unchanged; only the prediction differs.
pred_lph_vals = np.ravel(cph_reg.predict_log_partial_hazard(df_test))
fig = plot_residuals(test_lph_vals,pred_lph_vals,x1_vals,x2_vals)


In [None]:
# train DeepSurv on same task
import torchtuples as tt
from coxkan.utils import FastCoxLoss, count_parameters, bootstrap_metric, set_seed, SYMBOLIC_LIB
import torch
# Required by the early-stopping branch below. Elsewhere in this notebook it is
# only imported by later cells, so import it here to keep this cell runnable in order.
from sklearn.model_selection import train_test_split

# NOTE(review): `mlp_cindex_metric_fn` and `mlp_cindex` are not defined in this
# cell; presumably an earlier cell defines them. Confirm on a fresh kernel.

with open('configs/mlp/sim_difficult.yml', 'r') as f:
    mlp_config = yaml.safe_load(f)
mlp = tt.practical.MLPVanilla(
    in_features=len(covariates), out_features=1, output_bias=False, **mlp_config['init_params']
)
optimizer = tt.optim.Adam(**mlp_config['optimizer_params'])
deepsurv = tt.Model(mlp, loss=FastCoxLoss, optimizer=optimizer)
deepsurv_params = count_parameters(mlp)

# Test tensors: covariates as X, (duration, event) pairs as y
X_test = torch.tensor(df_test[covariates].values).double()
y_test = torch.tensor(df_test[[duration_col, event_col]].values).double()

# Training (full-batch: batch_size equals the number of training rows)
if mlp_config['early_stopping']:
    # hold out 20% of the training data, stratified on the event indicator
    train, val = train_test_split(df_train, test_size=0.2, random_state=42, stratify=df_train[event_col])
    X_val = torch.tensor(val[covariates].values).double()
    y_val = torch.tensor(val[[duration_col, event_col]].values).double()
    X_train = torch.tensor(train[covariates].values).double()
    y_train = torch.tensor(train[[duration_col, event_col]].values).double()
    log = deepsurv.fit(
        X_train, y_train, batch_size=len(X_train), val_data=(X_val, y_val), epochs=mlp_config['epochs'], verbose=False,
        metrics={'cindex': mlp_cindex_metric_fn}, callbacks=[tt.callbacks.EarlyStopping(patience=20)]
    )
else:
    # no early-stopping callback here; the test set is passed as val_data
    X_train = torch.tensor(df_train[covariates].values).double()
    y_train = torch.tensor(df_train[[duration_col, event_col]].values).double()
    log = deepsurv.fit(
        X_train, y_train, batch_size=len(X_train), val_data=(X_test, y_test), epochs=mlp_config['epochs'], verbose=False,
        metrics={'cindex': mlp_cindex_metric_fn}
    )

cindex_mlp = bootstrap_metric(mlp_cindex, df_test, N=100)['formatted']
print(f"MLP C-Index: {cindex_mlp}")

# Residuals of the network output against the true log partial hazard
x1_vals = X_test[:, 0].numpy()
x2_vals = X_test[:, 1].numpy()
test_lph_vals = sim_config['log_partial_hazard'](x1_vals, x2_vals, *df_test[covariates[2:]].values.T)
pred_lph_vals = deepsurv.predict(X_test).squeeze().detach().cpu().numpy()
fig = plot_residuals(test_lph_vals,pred_lph_vals,x1_vals,x2_vals)


In [None]:
# CoxKAN: build the model from the tuned init params and train it.
ckan = CoxKAN(seed=42, **config['init_params'])

# NOTE(review): df_test is passed in the position where the later sections of this
# notebook pass a validation split "for early stopping". If train() uses it for
# model selection, the test metrics below are optimistic. Confirm.
log = ckan.train(df_train, df_test, duration_col, event_col, **config['train_params'])

# Test C-Index of the trained network, before pruning or symbolic fitting
cindex_pre = bootstrap_metric(ckan.cindex, df_test, N=100)['formatted']
print(f"Pre-symbolic: {cindex_pre}")

# Save the checkpoint, the training-log plot and the network plot
ckan.save_ckpt(f'checkpoints/{exp_name}/model.pt')
fig = log.plot()
fig.savefig(f'checkpoints/{exp_name}/evolution.png')
fig = ckan.plot(beta=10, in_vars=[r'$x_1$', r'$x_2$', r'$\epsilon_1$', r'$\epsilon_2$'])
fig.savefig(f'checkpoints/{exp_name}/coxkan_pre.png')

In [None]:
# Pruning: nodes first, then edges, both at the threshold from the config.
# `ckan` is rebound to the object returned by prune_nodes, so re-running this
# cell prunes the already-pruned model again.
ckan = ckan.prune_nodes(config['prune_threshold'])
ckan.prune_edges(config['prune_threshold'], verbose=True)
fig = ckan.plot(beta=10, in_vars=[r'$x_1$', r'$x_2$', r'$\epsilon_1$', r'$\epsilon_2$'])
fig.savefig(f'checkpoints/{exp_name}/coxkan_pruned.png')
# Test C-Index after pruning
cindex_pruned = bootstrap_metric(ckan.cindex, df_test, N=100)['formatted']
print(f"Pruned: {cindex_pruned}")

Although the activations do seem to be close to their 'true' counterparts, we would not naturally recognise them by eye. As we expected, the model appears to have had issues with the `tanh` (parts of the domain that should be flat are bumpy) and `abs` (too smooth at the bottom).

Hence, instead of recognising activations by eye, we simply call `auto_symbolic`. 

In [None]:
# Alternative (disabled): restrict the search with only_interpretable_funcs=True
# _ = ckan.auto_symbolic(only_interpretable_funcs=True, verbose=True)
_ = ckan.auto_symbolic(only_interpretable_funcs=False, verbose=True)

# Forward pass on the test set before plotting (result discarded)
_ = ckan.predict(df_test)
fig = ckan.plot(beta=10, in_vars=[r'$x_1$', r'$x_2$', r'$\epsilon_1$', r'$\epsilon_2$'])
fig.savefig(f'checkpoints/{exp_name}/coxkan_symbolic.png')
# Test C-Index of the symbolic model, and the first entry of its formula
cindex_symbolic = bootstrap_metric(ckan.cindex, df_test, N=100)['formatted']
print(f"Symbolic: {cindex_symbolic}")
formula = ckan.symbolic_formula()[0][0]
print(formula)

# Residuals of the symbolic CoxKAN prediction against the true log partial hazard
x1_vals = df_test['x1'].values
x2_vals = df_test['x2'].values
test_lph_vals = sim_config['log_partial_hazard'](x1_vals, x2_vals, *df_test[covariates[2:]].values.T)
pred_lph_vals = ckan.predict(df_test)
fig = plot_residuals(test_lph_vals,pred_lph_vals,x1_vals,x2_vals)


Surprisingly, the `tanh` activation was recovered, and the symbolic fitting actually smoothed out some of the noisy bumps, yielding a better result. Unfortunately, the other activations were not recovered.

However, the C-Index is very close to that of the true expression, suggesting that our expression is a close approximation in the relevant domain. We now plot the true and predicted expressions:

In [None]:
import torch
import numpy as np

x1_range = (0.1, 1)
x2_range = (-1, 1)

# 100 x 100 grid over the (x1, x2) plane, flattened to one row per grid point
x1 = np.linspace(*x1_range, 100)
x2 = np.linspace(*x2_range, 100)
X = np.meshgrid(x1, x2)
X = np.array(X).reshape(2, -1).T
X = torch.tensor(X, dtype=torch.float64)

# append two all-zero columns for the remaining inputs (the ones labelled
# epsilon_1 / epsilon_2 in the network plots), i.e. they are held at zero
X = torch.cat([X, torch.zeros_like(X)[:, :2]], dim=1)

# true log partial hazard on the grid
lph_true = sim_config['log_partial_hazard'](X[:,0], X[:,1], *X[:,2:].T)

# log partial hazard predicted by the CoxKAN model on the same grid
with torch.no_grad():
    lph_pred = ckan(X)
lph_pred = lph_pred.detach().numpy().flatten()

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# one filled-contour panel per surface, each with its own 100 colour levels
panels = (
    (lph_true, r'(a) True $\theta(\mathbf{x})$'),
    (lph_pred, r'(b) CoxKAN Symbolic $\hat{\theta}(\mathbf{x})$'),
)
for ax, (surface, caption) in zip(axes, panels):
    levels = np.linspace(surface.min(), surface.max(), 100)
    contour = ax.tricontourf(X[:,0], X[:,1], surface, levels=levels, cmap='coolwarm')
    ax.set_xlabel(r'$x_1$', fontsize=14)
    ax.set_ylabel(r'$x_2$', fontsize=14)
    ax.set_xticks(x1_range)
    ax.set_yticks(x2_range)
    ax.text(0.02, 1.05, caption, fontsize=17, transform=ax.transAxes)

fig.savefig(f'checkpoints/{exp_name}/lph_surfaces.png')

We see that indeed, the predicted expression is a very close approximation to the truth. 

I argue that CoxKAN still has the properties of high performance and interpretability in this case.

In [None]:
# Gather this experiment's C-Indices and formulas, then pickle them next to the checkpoints.
results = dict(
    cindex_true=cindex_true,
    cindex_cph=cindex_cph,
    cindex_pre=cindex_pre,
    cindex_pruned=cindex_pruned,
    cindex_symbolic=cindex_symbolic,
    coxkan_formula=formula,
    coxph_formula=formula_cph,
)

with open(f'checkpoints/{exp_name}/results.pkl', 'wb') as f:
    pickle.dump(results, f)

Linear

In [None]:
import copy
import sys
from pathlib import Path
import yaml
import pickle
import argparse
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from lifelines.utils import concordance_index
from lifelines import CoxPHFitter

from coxkan import CoxKAN
from coxkan.utils import bootstrap_metric, set_seed

from sklearn.model_selection import train_test_split

# Call set_seed(42) again at the start of this section and keep its return value
# (presumably this re-seeds the RNGs; see coxkan.utils.set_seed).
SEED = set_seed(42)

### Reusable functions for the notebook

def true_cindex(df):
    """Concordance index achieved on `df` by the true log partial hazard.

    Reads the module-level `sim_config`, `duration_col`, `event_col` and
    `covariates`. The covariate columns of `df` are passed to the true function
    as keyword arguments and its negated output is the score given to
    `concordance_index`.
    """
    global sim_config, duration_col, event_col, covariates
    true_fn = sim_config['log_partial_hazard']
    features = df[covariates]
    risk = true_fn(**features)
    return concordance_index(df[duration_col], -risk, df[event_col])

def cph_cindex(df):
    """Score the module-level `cph` model on `df` with the concordance index."""
    global cph
    model = cph
    return model.score(df, scoring_method='concordance_index')

def cph_reg_cindex(df):
    """Score the module-level `cph_reg` model on `df` with the concordance index."""
    global cph_reg
    model = cph_reg
    return model.score(df, scoring_method='concordance_index')

def cph_formula(cph):
    """Render a fitted model's coefficients as a linear expression string.

    Every (covariate, coefficient) pair in `cph.params_` becomes
    "<coefficient to 4 decimals> * <covariate>"; the terms are joined with " + ".
    """
    return " + ".join(
        f"{weight:.4f} * {name}" for name, weight in cph.params_.items()
    )

def cph_reg_formula(cph_reg):
    """Print the model's coefficients, then return them as a linear expression string.

    Every (covariate, coefficient) pair in `cph_reg.params_` becomes
    "<coefficient to 4 decimals> * <covariate>"; the terms are joined with " + ".
    """
    weights = cph_reg.params_
    print(weights)
    pieces = [f"{value:.4f} * {label}" for label, value in weights.items()]
    return " + ".join(pieces)
exp_name = "sim_reviewer_1_2"
sim_name = "reviewer_1_2"

### load configs
with open(f'./configs/simulation/{sim_name}.yml', 'r') as file:
    sim_config = yaml.safe_load(file)
    sim_config['true_expr'] =  sim_config['log_partial_hazard'].split(': ')[-1] # log partial hazard expression
    # NOTE(review): eval executes whatever code the YAML file contains; only load trusted configs.
    sim_config['log_partial_hazard'] = eval(sim_config['log_partial_hazard']) # convert to function

# (config from hyperparameter search)
with open(f'configs/coxkan/{exp_name}.yml', 'r') as file:
    config = yaml.safe_load(file)

# data already generated (from sweep.py)
df_train = pd.read_csv(f'./data/{exp_name}_train.csv')
df_test = pd.read_csv(f'./data/{exp_name}_test.csv')
# covariates are all columns except the last two
duration_col, event_col, covariates = 'duration', 'event', df_train.columns[:-2]

# C-Index of true log partial hazard expression
cindex_true = bootstrap_metric(true_cindex, df_test, N=100)['formatted']

print(f"True log partial hazard: {sim_config['true_expr']}")
print(f"True C-Index: {cindex_true}")

# CoxPH
cph = CoxPHFitter()
cph.fit(df_train, duration_col=duration_col, event_col=event_col)
cindex_cph = bootstrap_metric(cph_cindex, df_test, N=100)['formatted']
formula_cph = cph_formula(cph)
print(f"CoxPH Expression: {formula_cph}")
print(f"CoxPH C-Index: {cindex_cph}")

# CoxPH with Lasso (pure L1 penalty)
cph_reg = CoxPHFitter(penalizer=0.5,l1_ratio=1)
cph_reg.fit(df_train, duration_col=duration_col, event_col=event_col)
cindex_cph_reg = bootstrap_metric(cph_reg_cindex, df_test, N=100)['formatted']
formula_cph_reg = cph_reg_formula(cph_reg)
print(f"CoxPH with Lasso Expression: {formula_cph_reg}")
print(f"CoxPH with Lasso C-Index: {cindex_cph_reg}")

# Residuals of the Lasso model against the truth. The truth is a *log* partial
# hazard, so compare it with predict_log_partial_hazard rather than
# predict_partial_hazard (which is on the exponentiated scale). np.ravel gives a
# flat array whatever container lifelines returns.
test_lph = sim_config['log_partial_hazard'](df_test['x1'], df_test['x2'], df_test['x3'], df_test['x4'])
pred_lph = np.ravel(cph_reg.predict_log_partial_hazard(df_test))
x1_vals = df_test['x1'].values
x2_vals = df_test['x2'].values
fig = plot_residuals(test_lph,pred_lph,x1_vals,x2_vals)

In [None]:
# CoxKAN: build the model from the tuned init params.
ckan_trained = CoxKAN(seed=42, **config['init_params'])
# Train/Val split for early stopping (20% held out, stratified on the event column)
train, val = train_test_split(df_train, test_size=0.2, random_state=42, stratify=df_train['event'])
log = ckan_trained.train(train, val, duration_col, event_col, **config['train_params'])

# Test C-Index of the trained network, before pruning or symbolic fitting
cindex_pre = bootstrap_metric(ckan_trained.cindex, df_test, N=100)['formatted']
print(f"Pre-symbolic: {cindex_pre}")

# Save the checkpoint (reloaded by the pruning search), the training-log plot and the network plot
ckan_trained.save_ckpt(f'checkpoints/{exp_name}/model.pt')
fig = log.plot()
fig.savefig(f'checkpoints/{exp_name}/evolution.png')
fig = ckan_trained.plot(beta=10, in_vars=[r'$x_1$', r'$x_2$', r'$x_3$', r'$x_4$'])
fig.savefig(f'checkpoints/{exp_name}/coxkan_pre.png')

# Search for the pruning threshold with the best validation C-Index.
# Candidates: 20 evenly spaced values in [0, 0.05], with the first one replaced
# by the threshold from the hyperparameter config.
pruning_thresholds = np.linspace(0, 0.05, 20)
pruning_thresholds[0] = config['prune_threshold']
cindices = []
for threshold in pruning_thresholds:
    # every trial starts from the saved, unpruned checkpoint
    ckan_ = CoxKAN(seed=42, **config['init_params'])
    ckan_.load_ckpt(f'checkpoints/{exp_name}/model.pt', verbose=False)
    _ = ckan_.predict(df_test) # important forward pass after loading a model

    # each layer needs at least one activation whose scale exceeds the threshold
    prunable = True
    for layer in range(ckan_.depth):
        if not (ckan_.acts_scale[layer] > threshold).any():
            prunable = False
            break

    ckan_ = ckan_.prune_nodes(threshold)
    if 0 in ckan_.width: prunable = False
    if not prunable:
        if threshold == config['prune_threshold']:
            # placeholder score keeps cindices aligned with pruning_thresholds
            cindices.append(0)
            continue
        else: break

    _ = ckan_.predict(df_test) # important forward pass
    ckan_.prune_edges(threshold, verbose=False)
    cindices.append(ckan_.cindex(val))
    # 4 decimals: the candidates are ~0.0026 apart, so 2 decimals made
    # neighbouring thresholds print identically
    print(f'Pruning threshold: {threshold:.4f}, C-Index (Val): {cindices[-1]:.6f}')
best_threshold = pruning_thresholds[np.argmax(cindices)]
# fall back to a threshold of 0 when no candidate reaches a C-Index of 0.51
if np.max(cindices) < 0.51: best_threshold = 0

# Apply the selected threshold to the trained model: a forward pass precedes each
# pruning step, as in the threshold search.
_ = ckan_trained.predict(df_test)
ckan_trained = ckan_trained.prune_nodes(best_threshold)
_ = ckan_trained.predict(df_test)
ckan_trained.prune_edges(best_threshold, verbose=True)
# This simulation has four covariates (x1..x4) and no epsilon inputs, so label
# exactly those four, matching the pre-pruning plot.
fig = ckan_trained.plot(beta=40, in_vars=[r'$x_1$', r'$x_2$', r'$x_3$', r'$x_4$'])
fig.savefig(f'checkpoints/{exp_name}/coxkan_pruned.png')
cindex_pruned = bootstrap_metric(ckan_trained.cindex, df_test, N=100)['formatted']
print(f"Pruned: {cindex_pruned}")

# Fit symbolic functions to the model trained in this section. The call must
# target `ckan_trained`; `ckan` is the model of the earlier "difficult"
# experiment, so calling it there left this model without symbolic activations.
# Alternative (disabled): only_interpretable_funcs=True
_ = ckan_trained.auto_symbolic(only_interpretable_funcs=False, verbose=True)

# Forward pass on the test set before plotting, as done elsewhere in this notebook
_ = ckan_trained.predict(df_test)
# four covariates x1..x4 (no epsilon inputs in this simulation)
fig = ckan_trained.plot(beta=10, in_vars=[r'$x_1$', r'$x_2$', r'$x_3$', r'$x_4$'])
fig.savefig(f'checkpoints/{exp_name}/coxkan_symbolic.png')
cindex_symbolic = bootstrap_metric(ckan_trained.cindex, df_test, N=100)['formatted']
print(f"Symbolic: {cindex_symbolic}")
formula = ckan_trained.symbolic_formula()[0][0]
print(formula)
# Residuals of the CoxKAN prediction against the true log partial hazard on the
# test set, plotted against x1 and x2.
# (np is already in scope: this cell uses np.linspace above.)
test_x1 = df_test['x1'].values
test_x2 = df_test['x2'].values
test_x3 = df_test['x3'].values
test_x4 = df_test['x4'].values
test_surv = df_test['duration'].values
test_lph = sim_config['log_partial_hazard'](test_x1, test_x2, test_x3, test_x4)
print(test_lph[:5])  # preview only; avoid dumping the whole array into the output
pred_lph = ckan_trained.predict(df_test)
fig = plot_residuals(test_lph,pred_lph,test_x1,test_x2)

Quadratic

In [None]:
import copy
import sys
from pathlib import Path
import yaml
import pickle
import argparse
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from lifelines.utils import concordance_index
from lifelines import CoxPHFitter

from coxkan import CoxKAN
from coxkan.utils import bootstrap_metric, set_seed

from sklearn.model_selection import train_test_split

# Call set_seed(42) again at the start of this section and keep its return value
# (presumably this re-seeds the RNGs; see coxkan.utils.set_seed).
SEED = set_seed(42)

### Reusable functions for the notebook

def true_cindex(df):
    """Concordance index of the true log partial hazard evaluated on `df`.

    Uses the module-level `sim_config`, `duration_col`, `event_col` and
    `covariates`: the covariate columns are unpacked as keyword arguments into
    the true function, and the negated values are scored with
    `concordance_index` against the duration and event columns.
    """
    global sim_config, duration_col, event_col, covariates
    inputs = df[covariates]
    true_lph = sim_config['log_partial_hazard'](**inputs)
    durations, events = df[duration_col], df[event_col]
    return concordance_index(durations, -true_lph, events)

def cph_cindex(df):
    """Concordance index of the module-level `cph` model on `df`."""
    global cph
    cindex = cph.score(df, scoring_method='concordance_index')
    return cindex

def cph_reg_cindex(df):
    """Concordance index of the module-level `cph_reg` model on `df`."""
    global cph_reg
    cindex = cph_reg.score(df, scoring_method='concordance_index')
    return cindex

def cph_formula(cph):
    """Build the string "<coef> * <covariate> + ..." from `cph.params_`.

    Coefficients are formatted to four decimals and the terms are joined with
    " + " in the iteration order of `cph.params_.items()`.
    """
    pieces = [f"{beta:.4f} * {column}" for column, beta in cph.params_.items()]
    return " + ".join(pieces)

def cph_reg_formula(cph_reg):
    """Print `cph_reg.params_`, then return it as "<coef> * <covariate> + ...".

    Coefficients are formatted to four decimals and the terms are joined with
    " + " in the iteration order of `cph_reg.params_.items()`.
    """
    print(cph_reg.params_)
    return " + ".join(
        f"{beta:.4f} * {column}" for column, beta in cph_reg.params_.items()
    )
# Names used below to locate this section's config, data and checkpoint files.
exp_name = "sim_reviewer_1_3"
sim_name = "reviewer_1_3"

def plot_residuals(test_lph_vals,pred_lph_vals,x1_vals,x2_vals):
    """Scatter the residuals (true minus predicted) against x1 and x2, side by side.

    Each panel is annotated with the correlation coefficient (np.corrcoef)
    between that covariate and the residuals. The figure is laid out, shown with
    plt.show() and returned.
    """
    residuals = test_lph_vals - pred_lph_vals
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    for idx, (ax, x_vals) in enumerate(zip(axes, (x1_vals, x2_vals)), start=1):
        ax.scatter(x_vals, residuals, alpha=0.5)
        # correlation shown as in-panel text
        corr = np.corrcoef(x_vals, residuals)[0, 1]
        ax.text(0.5, 0.9, f'Correlation: {corr:.2f}', transform=ax.transAxes, fontsize=14)
        ax.set_xlabel(rf'$x_{idx}$', fontsize=14)
        ax.set_ylabel('Residuals', fontsize=14)
        ax.set_title(f'Residuals vs $x_{idx}$', fontsize=14)
    fig.tight_layout()
    plt.show()
    return fig

### load configs
with open(f'./configs/simulation/{sim_name}.yml', 'r') as file:
    sim_config = yaml.safe_load(file)
    sim_config['true_expr'] =  sim_config['log_partial_hazard'].split(': ')[-1] # log partial hazard expression
    # NOTE(review): eval executes whatever code the YAML file contains; only load trusted configs.
    sim_config['log_partial_hazard'] = eval(sim_config['log_partial_hazard']) # convert to function

# (config from hyperparameter search)
with open(f'configs/coxkan/{exp_name}.yml', 'r') as file:
    config = yaml.safe_load(file)

# data already generated (from sweep.py)
df_train = pd.read_csv(f'./data/{exp_name}_train.csv')
df_test = pd.read_csv(f'./data/{exp_name}_test.csv')
# covariates are all columns except the last two
duration_col, event_col, covariates = 'duration', 'event', df_train.columns[:-2]

# C-Index of true log partial hazard expression
cindex_true = bootstrap_metric(true_cindex, df_test, N=100)['formatted']

print(f"True log partial hazard: {sim_config['true_expr']}")
print(f"True C-Index: {cindex_true}")

# CoxPH
cph = CoxPHFitter()
cph.fit(df_train, duration_col=duration_col, event_col=event_col)
cindex_cph = bootstrap_metric(cph_cindex, df_test, N=100)['formatted']
formula_cph = cph_formula(cph)
print(f"CoxPH Expression: {formula_cph}")
print(f"CoxPH C-Index: {cindex_cph}")

# True log partial hazard on THIS experiment's test set. The covariates are read
# from the df_test loaded above, not from test_x1..test_x4, which at this point
# would still hold the previous section's test data.
x1_vals = df_test['x1'].values
x2_vals = df_test['x2'].values
test_lph = sim_config['log_partial_hazard'](x1_vals, x2_vals, df_test['x3'].values, df_test['x4'].values)

# The truth is a *log* partial hazard, so compare it with
# predict_log_partial_hazard rather than predict_partial_hazard (which is on the
# exponentiated scale). np.ravel gives a flat array whatever container lifelines returns.
pred_lph = np.ravel(cph.predict_log_partial_hazard(df_test))

fig = plot_residuals(test_lph,pred_lph,x1_vals,x2_vals)

# CoxPH with Lasso (pure L1 penalty)
cph_reg = CoxPHFitter(penalizer=0.5,l1_ratio=1)
cph_reg.fit(df_train, duration_col=duration_col, event_col=event_col)
cindex_cph_reg = bootstrap_metric(cph_reg_cindex, df_test, N=100)['formatted']
formula_cph_reg = cph_reg_formula(cph_reg)
print(f"CoxPH with Lasso Expression: {formula_cph_reg}")
print(f"CoxPH with Lasso C-Index: {cindex_cph_reg}")

pred_lph = np.ravel(cph_reg.predict_log_partial_hazard(df_test))

fig = plot_residuals(test_lph,pred_lph,x1_vals,x2_vals)

In [None]:

# CoxKAN: build the model from the tuned init params.
ckan_trained = CoxKAN(seed=42, **config['init_params'])
# Train/Val split for early stopping (20% held out, stratified on the event column)
train, val = train_test_split(df_train, test_size=0.2, random_state=42, stratify=df_train['event'])
log = ckan_trained.train(train, val, duration_col, event_col, **config['train_params'])

# Test C-Index of the trained network, before pruning or symbolic fitting
cindex_pre = bootstrap_metric(ckan_trained.cindex, df_test, N=100)['formatted']
print(f"Pre-symbolic: {cindex_pre}")

# Save the checkpoint (reloaded by the pruning search), the training-log plot and the network plot
ckan_trained.save_ckpt(f'checkpoints/{exp_name}/model.pt')
fig = log.plot()
fig.savefig(f'checkpoints/{exp_name}/evolution.png')
fig = ckan_trained.plot(beta=10, in_vars=[r'$x_1$', r'$x_2$', r'$x_3$', r'$x_4$'])
fig.savefig(f'checkpoints/{exp_name}/coxkan_pre.png')

# Search for the pruning threshold with the best validation C-Index.
# Candidates: 20 evenly spaced values in [0, 0.05], with the first one replaced
# by the threshold from the hyperparameter config.
pruning_thresholds = np.linspace(0, 0.05, 20)
pruning_thresholds[0] = config['prune_threshold']
cindices = []
for threshold in pruning_thresholds:
    # every trial starts from the saved, unpruned checkpoint
    ckan_ = CoxKAN(seed=42, **config['init_params'])
    ckan_.load_ckpt(f'checkpoints/{exp_name}/model.pt', verbose=False)
    _ = ckan_.predict(df_test) # important forward pass after loading a model

    # each layer needs at least one activation whose scale exceeds the threshold
    prunable = True
    for layer in range(ckan_.depth):
        if not (ckan_.acts_scale[layer] > threshold).any():
            prunable = False
            break

    ckan_ = ckan_.prune_nodes(threshold)
    if 0 in ckan_.width: prunable = False
    if not prunable:
        if threshold == config['prune_threshold']:
            # placeholder score keeps cindices aligned with pruning_thresholds
            cindices.append(0)
            continue
        else: break

    _ = ckan_.predict(df_test) # important forward pass
    ckan_.prune_edges(threshold, verbose=False)
    cindices.append(ckan_.cindex(val))
    # 4 decimals: the candidates are ~0.0026 apart, so 2 decimals made
    # neighbouring thresholds print identically
    print(f'Pruning threshold: {threshold:.4f}, C-Index (Val): {cindices[-1]:.6f}')
best_threshold = pruning_thresholds[np.argmax(cindices)]
# fall back to a threshold of 0 when no candidate reaches a C-Index of 0.51
if np.max(cindices) < 0.51: best_threshold = 0

# Apply the selected threshold to the trained model: a forward pass precedes each
# pruning step, as in the threshold search.
_ = ckan_trained.predict(df_test)
ckan_trained = ckan_trained.prune_nodes(best_threshold)
_ = ckan_trained.predict(df_test)
ckan_trained.prune_edges(best_threshold, verbose=True)
# This simulation has four covariates (x1..x4) and no epsilon inputs, so label
# exactly those four, matching the pre-pruning plot.
fig = ckan_trained.plot(beta=40, in_vars=[r'$x_1$', r'$x_2$', r'$x_3$', r'$x_4$'])
fig.savefig(f'checkpoints/{exp_name}/coxkan_pruned.png')
cindex_pruned = bootstrap_metric(ckan_trained.cindex, df_test, N=100)['formatted']
print(f"Pruned: {cindex_pruned}")

# Fit symbolic functions to the model trained in this section. The call must
# target `ckan_trained`; `ckan` is the model of the earlier "difficult"
# experiment, so calling it there left this model without symbolic activations.
# Alternative (disabled): only_interpretable_funcs=True
_ = ckan_trained.auto_symbolic(only_interpretable_funcs=False, verbose=True)

# Forward pass on the test set before plotting, as done elsewhere in this notebook
_ = ckan_trained.predict(df_test)
# four covariates x1..x4 (no epsilon inputs in this simulation)
fig = ckan_trained.plot(beta=10, in_vars=[r'$x_1$', r'$x_2$', r'$x_3$', r'$x_4$'])
fig.savefig(f'checkpoints/{exp_name}/coxkan_symbolic.png')
cindex_symbolic = bootstrap_metric(ckan_trained.cindex, df_test, N=100)['formatted']
print(f"Symbolic: {cindex_symbolic}")
formula = ckan_trained.symbolic_formula()[0][0]
print(formula)


# Residuals of the CoxKAN prediction against the true log partial hazard on the
# test set, plotted against x1 and x2.
# (np is already in scope: this cell uses np.linspace above.)
test_x1 = df_test['x1'].values
test_x2 = df_test['x2'].values
test_x3 = df_test['x3'].values
test_x4 = df_test['x4'].values
test_surv = df_test['duration'].values
test_lph = sim_config['log_partial_hazard'](test_x1, test_x2, test_x3, test_x4)
print(test_lph[:5])  # preview only; avoid dumping the whole array into the output

pred_lph = ckan_trained.predict(df_test)
# same arrays as test_x1 / test_x2, kept under the names used by the other residual plots
x1_vals = test_x1
x2_vals = test_x2
fig = plot_residuals(test_lph,pred_lph,x1_vals,x2_vals)