In [None]:
import os 
import pickle
from pathlib import Path
import yaml
from coxkan import CoxKAN
import sympy 
from sympy.printing.latex import latex

def format_cindex(input_string):
    r"""Insert a LaTeX ``\newline`` between a C-index estimate and its interval.

    C-index strings in this notebook look like ``'0.618308 (0.616, 0.620)'``.
    Splitting at the first space (rather than a fixed character offset) puts
    the point estimate on one line and the bracketed part on the next,
    whatever the number of digits in the estimate.

    Args:
        input_string: C-index text such as ``'0.618308 (0.616, 0.620)'`` or a
            placeholder like ``'N/A '``.

    Returns:
        ``input_string`` with ``"\newline "`` inserted right after the first
        space, or appended at the end when the string contains no space.
    """
    head, sep, tail = input_string.partition(' ')
    return head + sep + "\\newline " + tail

def ex_round(ex1, floating_digit=4):
    """Round every floating-point constant in a SymPy expression.

    Args:
        ex1: SymPy expression whose ``sympy.Float`` atoms should be rounded.
            Objects that are not SymPy expressions are returned unchanged.
        floating_digit: Number of decimal places to keep (passed to ``round``).

    Returns:
        A new expression in which each ``sympy.Float`` is replaced by its
        rounded value; symbols, integers and rationals are left untouched.
    """
    if not isinstance(ex1, sympy.Basic):
        # Non-SymPy inputs (plain strings, Python numbers) contain no Float atoms.
        return ex1
    # Collect each distinct Float once and substitute them all in a single
    # structural pass, instead of one full-tree ``subs`` per Float occurrence.
    replacements = {num: round(num, floating_digit) for num in ex1.atoms(sympy.Float)}
    return ex1.xreplace(replacements)

# Simulation studies (synthetic datasets)

In [None]:
import sympy
from sympy.printing.latex import latex

# Synthetic-data experiments; each one has a checkpoints/<name>/results.pkl.
experiments = [
    'sim_gaussian',
    'sim_depth_1',
    'sim_deep',
    'sim_difficult',
]

def format_cindex(input_string):
    r"""Insert a LaTeX ``\newline`` between a C-index estimate and its interval.

    C-index strings in this notebook look like ``'0.618308 (0.616, 0.620)'``.
    Splitting at the first space (rather than a fixed character offset) puts
    the point estimate on one line and the bracketed part on the next,
    whatever the number of digits in the estimate.

    Args:
        input_string: C-index text such as ``'0.618308 (0.616, 0.620)'`` or a
            placeholder like ``'N/A '``.

    Returns:
        ``input_string`` with ``"\newline "`` inserted right after the first
        space, or appended at the end when the string contains no space.
    """
    head, sep, tail = input_string.partition(' ')
    return head + sep + "\\newline " + tail

# For each simulation: print a LaTeX table row of C-indices
# (true model / CoxPH / symbolic CoxKAN), then the recovered formulas.
print('True CoxPH CoxKAN_Symbolic')
for exp_name in experiments:
    with open(f'checkpoints/{exp_name}/results.pkl', 'rb') as f:
        results = pickle.load(f)

    formula = ex_round(results['coxkan_formula'], floating_digit=1)

    print('-------------------------')
    print(f'Experiment: {exp_name}')
    # Double quotes inside the braces: reusing the f-string's own quote
    # character there is a SyntaxError before Python 3.12 (PEP 701).
    print(f'{format_cindex(results["cindex_true"])} & {format_cindex(results["cindex_cph"])} & {format_cindex(results["cindex_symbolic"])}')
    print(latex(formula))
    display(formula)
    print(results['coxph_formula'])

    if exp_name == 'sim_deep':
        # Also show a simplified version (simplify first, then round).
        simplified = ex_round(sympy.simplify(results['coxkan_formula']), floating_digit=1)
        print(latex(simplified))
        display(simplified)


# Real Clinical Data

In [None]:
import os 
import pickle
from pathlib import Path
import yaml
from coxkan import CoxKAN
import sympy

# Clinical datasets reported below; each is read from checkpoints/<name>/.
datasets = ['gbsg', 'metabric', 'support', 'flchain', 'nwtco']

def ex_round(ex1, floating_digit=4):
    """Round every floating-point constant in a SymPy expression.

    Args:
        ex1: SymPy expression whose ``sympy.Float`` atoms should be rounded.
            Objects that are not SymPy expressions are returned unchanged.
        floating_digit: Number of decimal places to keep (passed to ``round``).

    Returns:
        A new expression in which each ``sympy.Float`` is replaced by its
        rounded value; symbols, integers and rationals are left untouched.
    """
    if not isinstance(ex1, sympy.Basic):
        # Non-SymPy inputs (plain strings, Python numbers) contain no Float atoms.
        return ex1
    # Collect each distinct Float once and substitute them all in a single
    # structural pass, instead of one full-tree ``subs`` per Float occurrence.
    replacements = {num: round(num, floating_digit) for num in ex1.atoms(sympy.Float)}
    return ex1.xreplace(replacements)

# Published DeepSurv (MLP-based survival model) test C-indices, each written as
# a point estimate followed by a bracketed interval. The loop below falls back
# to these when a dataset's results.pkl has no 'DeepSurv' entry.
deepsurv_results = dict(
    support='0.618308 (0.616, 0.620)',
    metabric='0.643375 (0.639, 0.647)',
    gbsg='0.668402 (0.665, 0.671)',
)

# Toggle the pruned-CoxKAN column in the printed C-index table.
incl_pruned = True

if incl_pruned:
    print('CoxPH DeepSurv CoxKAN_pre CoxKAN_pruned CoxKAN_Symbolic')
else:
    print('CoxPH DeepSurv CoxKAN_pre CoxKAN_Symbolic')

for dataset in datasets:
    directory = Path('checkpoints') / dataset
    with open(directory / 'results.pkl', 'rb') as f:
        results = pickle.load(f)
    with open(directory / 'config.yml', 'r') as f:
        config = yaml.safe_load(f)

    # The trained model is reloaded only for its normalizer: index 0 holds the
    # per-feature means and index 1 the per-feature stds used for CoxPH below.
    ckan = CoxKAN(seed=42, **config['init_params'])
    ckan.load_ckpt(directory / 'model.pt')
    means = ckan.normalizer[0]
    stds = ckan.normalizer[1]

    ### C-Index results
    cindex_cph = results['CoxPH']['test']
    if 'DeepSurv' in results:
        cindex_deepsurv = results['DeepSurv']['test']
    else:
        # Published numbers exist only for some datasets; for the others
        # (flchain, nwtco) show the 'N/A ' placeholder instead of a KeyError.
        cindex_deepsurv = deepsurv_results.get(dataset, 'N/A ')
    cindex_pre = results['CoxKAN']['Pre']['test']
    cindex_symbolic = results['CoxKAN']['Symbolic']['test']

    print('\n-------------------------')
    print(f'Dataset: {dataset}')
    print()

    # One LaTeX table row: formatted C-indices separated by ' & '.
    row = [cindex_cph, cindex_deepsurv, cindex_pre]
    if incl_pruned:
        row.append(results['CoxKAN']['Pruned']['test'])
    row.append(cindex_symbolic)
    print(' & '.join(format_cindex(cindex) for cindex in row))

    ### Formulas
    print('CoxKAN:')
    coxkan_formula = results['CoxKAN']['Symbolic']['formula']
    print('')
    if isinstance(coxkan_formula, str):
        print(coxkan_formula)
        print(latex(coxkan_formula))
    else:
        print(latex(coxkan_formula))
        display(coxkan_formula)

    print('CoxPH:')
    cph_coeffs = results['CoxPH']['summary']['coef']
    # Write the CoxPH linear predictor in terms of the raw feature symbols.
    # NOTE(review): assumes the CoxPH coefficients refer to features normalized
    # as (x - mean) / std with the CoxKAN normalizer — confirm.
    cph_terms = []
    for feature, coef in cph_coeffs.items():
        standardized = (sympy.Symbol(feature) - means[feature]) / stds[feature]
        cph_terms.append(ex_round(coef * standardized))
    cph_formula = sympy.Add(*cph_terms)
    display(cph_formula)
    print(latex(cph_formula))


#### METABRIC

$$
\begin{aligned}
\hat{\theta}_{KAN} = 
        & - 0.24 \cdot \text{PGR}
        + 0.7 e^{- 26 \left(1 - 0.06 \cdot \text{ERBB2}\right)^{2}} \\
        & + 0.2 \tanh(1.9 \text{MKI67} - 10)
        +\left\{ 
        \begin{array}{ll}
        0.1 & \text{if } \text{hormonal therapy} \\
        0.03 & \text{otherwise}
        \end{array}
        \right\} \\
        & +\left\{ 
        \begin{array}{ll}
        0.01 & \text{if } \text{radiotherapy} \\
        0.18 & \text{otherwise}
        \end{array}
        \right\}
        +\left\{ 
        \begin{array}{ll}
        0.6 & \text{if } \text{chemotherapy} \\
        -0.05 & \text{otherwise}
        \end{array}
        \right\} \\
        & +\left\{ 
        \begin{array}{ll}
        0.07 & \text{if } \text{ER positive} \\
        -0.04 & \text{otherwise}
        \end{array}
        \right\}
        - 1.7 \sin{\left(0.04 \cdot \text{Age} - 9.5 \right)} \\
\end{aligned}
$$

$$
\hat{\theta}_{CPH} = 0.0422 \text{EGFR} + 0.0693 \text{ER} + 0.1023 \text{ERBB2} + 0.3145 \text{MKI67} - 0.0698 \text{PGR} + 0.0433 \text{age} + 0.7712 \text{chemo} + 0.1849 \text{hormone} - 0.2118 \text{radio} 
$$

#### SUPPORT

$$
\begin{aligned}
\hat{\theta}_{KAN} = 
        & \, \phi_{interact} - 0.0002 \cdot \text{age}
        +\left\{ 
        \begin{array}{ll}
        0.003 & \text{if } \text{metastasis} \\
        -0.01 & \text{if } \text{no cancer} \\
        -0.0098 & \text{if } \text{cancer}
        \end{array}
        \right\} \\
        & +\left\{ 
        \begin{array}{ll}
        0.007 & \text{if } \text{male} \\
        -0.01 & \text{if } \text{female}
        \end{array}
        \right\}
        - 0.01 \cdot \text{race} + 0.04 \cdot \text{comorbidity} \\
        & +\left\{ 
        \begin{array}{ll}
        -0.03 & \text{if } \text{diabetes} \\
        0.0006 & \text{otherwise}
        \end{array}
        \right\} \\
        & +\left\{ 
        \begin{array}{ll}
        0.03 & \text{if } \text{dementia} \\
        -0.0008 & \text{otherwise}
        \end{array}
        \right\} \\
        & + 0.9 e^{- 0.06 \left(1 - 0.1 \cdot \text{meanbp}\right)^{2}}
        + 0.1 \tanh{\left(0.02 \cdot \text{hr} - 3 \right)} \\
        & - 0.06 \sin{\left(0.08 \cdot \text{rr} + 0.2 \right)}
        + 0.6 e^{- 572 \left(1 - 0.02 \cdot \text{temp}\right)^{2}} \\
        & + 0.0008 \cdot \text{sodium}
        + 0.03 \tan{\left(0.02 \cdot \text{wbc} - 4 \right)}
        + 0.003 \cdot \text{creatinine} \\
\end{aligned}
$$

$$
\begin{aligned}
\hat{\theta}_{CPH} = & \, 0.01 \cdot \text{age} - 0.09 \cdot \text{sex} + 0.02 \cdot \text{race} + 0.02 \cdot \text{comorbidity} - 0.05 \cdot \text{diabetes} \\ 
& + 0.1 \cdot \text{dementia} - 0.26 \cdot \text{cancer} - 0.003 \cdot \text{meanbp} + 0.002 \cdot \text{hr} \\
& + 0.002 \cdot \text{rr} + 0.01 \cdot \text{temp} - 0.004 \cdot \text{sodium} + 0.003 \cdot \text{wbc} + 0.03 \cdot \text{creatinine}
\end{aligned}
$$

#### NWTCO

$$
\begin{aligned}
\hat{\theta}_{KAN} = 
        & \, \phi_{1,3,1} + \phi_{1,4,1} + 0.02 \cdot \text{age}  \\
        & +\left\{ 
        \begin{array}{ll}
        -0.047 & \text{if } \text{favourable histology (instit)} \\
        -0.136 & \text{if } \text{unfavourable histology (instit)}
        \end{array}
        \right\} \\
        & +\left\{ 
        \begin{array}{ll}
        -0.22 & \text{if } \text{favourable histology (histol)} \\
        0.62 & \text{if } \text{unfavourable histology (histol)}
        \end{array}
        \right\} \\
        & +\left\{ 
        \begin{array}{ll}
        -0.47 & \text{if } \text{stage } = 1 \\
         0.04 & \text{if } \text{stage } = 2 \\
         0.35 & \text{if } \text{stage } = 3 \\
         0.78 & \text{if } \text{stage } = 4 \\
        \end{array}
        \right\} \\
        & +\left\{ 
        \begin{array}{ll}
        0.02 & \text{if } 3^{rd} \text{ clinical study} \\
        0.01 & \text{if } 4^{th} \text{ clinical study}
        \end{array}
        \right\} \\
        & + \left\{
        \begin{array}{ll}
        0.2 & \text{if } \text{in subcohort} \\
        -0.07 & \text{otherwise}
        \end{array}
        \right\}
\end{aligned}
$$

$$
\begin{aligned}
\phi_{1,3,1} = - 2.5 \arctan \Bigg( 2 \Bigg[
        & +0.03 \cdot \text{age} \\
        & +\left\{ 
        \begin{array}{ll}
        -0.1 & \text{if } \text{favourable histology (instit)} \\
        -0.4 & \text{if } \text{unfavourable histology (instit)}
        \end{array}
        \right\} \\
        & +\left\{ 
        \begin{array}{ll}
        0.3 & \text{if } \text{favourable histology (histol)} \\
        -0.4 & \text{if } \text{unfavourable histology (histol)}
        \end{array}
        \right\} \\
        & +\left\{ 
        \begin{array}{ll}
        0.1 & \text{if } \text{stage } = 1 \\
        -0.07 & \text{if } \text{stage } = 2 \\
        -0.03 & \text{if } \text{stage } = 3 \\
        -0.17 & \text{if } \text{stage } = 4 \\
        \end{array}
        \right\} \\
        & +\left\{ 
        \begin{array}{ll}
        -0.2 & \text{if } 3^{rd} \text{ clinical study} \\
        0.09 & \text{if } 4^{th} \text{ clinical study}
        \end{array}
        \right\} \\
        & + \left\{
        \begin{array}{ll}
        1 & \text{if } \text{in subcohort} \\
        -0.3 & \text{otherwise}
        \end{array}
        \right\}
\Bigg] \Bigg)
\end{aligned}
$$

$$
\begin{aligned}
\phi_{1,4,1} = - 6 \tanh \Bigg( 0.7 \Bigg[
        & +0.006 \cdot \text{age} \\
        & +\left\{ 
        \begin{array}{ll}
        0.15 & \text{if } \text{favourable histology (instit)} \\
        -0.7 & \text{if } \text{unfavourable histology (instit)}
        \end{array}
        \right\} \\
        & +\left\{ 
        \begin{array}{ll}
        0.1 & \text{if } \text{favourable histology (histol)} \\
        0.01 & \text{if } \text{unfavourable histology (histol)}
        \end{array}
        \right\} \\
        & +\left\{ 
        \begin{array}{ll}
        -0.7 & \text{if } \text{stage } = 1 \\
        -0.2 & \text{if } \text{stage } = 2 \\
        0.7 & \text{if } \text{stage } = 3 \\
        1.45 & \text{if } \text{stage } = 4 \\
        \end{array}
        \right\} \\
        & +\left\{ 
        \begin{array}{ll}
        -0.6 & \text{if } 3^{rd} \text{ clinical study} \\
        0.5 & \text{if } 4^{th} \text{ clinical study}
        \end{array}
        \right\} \\
        & + \left\{
        \begin{array}{ll}
        1 & \text{if } \text{in subcohort} \\
        -0.2 & \text{otherwise}
        \end{array}
        \right\}
\Bigg] \Bigg)
\end{aligned}
$$

$$
\begin{aligned}
\hat{\theta}_{CPH} = & \, + 0.2163 \cdot \text{instit} + 1.4956 \cdot \text{histol} + 0.3413 \cdot \text{stage} \\
& - 0.1407 \cdot \text{study} + 0.0064 \cdot \text{age} - 0.2192 \cdot \text{in.subcohort}
\end{aligned}
$$

#### FLCHAIN

$$
\begin{aligned}
\hat{\theta}_{KAN} =  & \, 0.09 \cdot \text{age}         
        +\left\{ 
        \begin{array}{ll}
        -0.047 & \text{if } \text{female} \\
        0.118 & \text{if } \text{male}
        \end{array}
        \right\} \\
        & + 0.4 \arctan(0.4 \cdot \text{year} - 737)
        + 0.04 \cdot \text{FLC}_{kappa} \\
        & + 0.3 \cdot \text{FLC}_{lambda} 
        + 0.009 \cdot \text{FLC}_\text{group}
        + 2\arctan(0.5 \cdot \text{creatinine} - 0.9) \\
\end{aligned}
$$

$$
\hat{\theta}_{CPH} = 0.1 \cdot \text{age} + 0.3 \cdot \text{sex} + 0.06 \cdot \text{year} + 0.01 \cdot \text{FLC}_{kappa} + 0.2 \cdot \text{FLC}_{lambda} + 0.06 \cdot \text{FLC}_\text{group} + 0.03 \cdot \text{creatinine} + 0.3 \cdot \text{mgus}
$$

#### GBSG

$$
\begin{aligned}
\hat{\theta}_{KAN} = 
    & + \left\{ 
        \begin{array}{ll}
        -0.21 & \text{if } \text{hormonal therapy} \\
        0.28 & \text{otherwise}
        \end{array}
        \right\} \\
    & +\left\{
        \begin{array}{ll}
        -0.07 & \text{if } \text{tumor size} \leq 20 \, \text{mm} \\
        0.21 & \text{if } 20 < \text{tumor size} < 50 \, \text{mm} \\
        0.48 & \text{if } \text{tumor size} \geq 50 \, \text{mm}
        \end{array}
        \right\} \\
    & +\left\{ 
        \begin{array}{ll}
        -0.12 & \text{if } \text{pre-menopausal} \\
        0.23 & \text{if } \text{post-menopausal}
        \end{array}
        \right\} \\
    & + 1.8 \left(1 - 0.02 \cdot \text{age}\right)^{2} \\ 
    & - 1.2 e^{- 0.02 \left(\text{nodes} + 0.4\right)^{2}} \\
    & + 0.1 \cosh{\left(0.002 \cdot \text{PGR} - 1.6 \right)} \\
    & - 0.0007 \cdot \text{ER}
\end{aligned}
$$

$$
\hat{\theta}_{CPH} = 0.003 \cdot \text{age} - 0.0003 \cdot \text{er} - 0.3 \cdot \text{hormon} + 0.26 \cdot \text{meno} + 0.06 \cdot \text{nodes} - 0.0003 \cdot \text{pgr} + 0.3 \cdot \text{size}
$$


# TCGA Genomics

In [None]:
import pickle
from sympy.printing.latex import latex

experiments = ['TCGA-STAD', 'TCGA-BRCA', 'TCGA-GBMLGG', 'TCGA-KIRC', 'TCGA-LUAD']

print('CoxPH  CoxPH Reg  DeepSurv  CoxKAN Pre  CoxKAN Pruned  CoxKAN Symbolic')
for exp_name in experiments:
    # Context manager so the checkpoint file handle is closed promptly.
    with open(f'checkpoints/{exp_name}/results.pkl', 'rb') as f:
        results = pickle.load(f)

    print('-------------------------')
    print(f'Experiment: {exp_name}')

    # Unregularized CoxPH results may be absent for some cohorts; show a
    # placeholder. Only lookup failures are caught, not every exception.
    try:
        cindex_cph = results['CoxPH']['test']
    except (KeyError, TypeError):
        cindex_cph = 'N/A '

    cindices = [
        cindex_cph,
        results['CoxPH Reg']['test'],
        results['DeepSurv']['test'],
        results['CoxKAN']['Pre']['test'],
        results['CoxKAN']['Pruned']['test'],
        results['CoxKAN']['Symbolic']['test'],
    ]
    # exp_name[5:] drops the 'TCGA-' prefix for the row label.
    print(' & '.join([exp_name[5:]] + [format_cindex(cindex) for cindex in cindices]))


In [None]:
exp_name = 'TCGA-STAD'
# Context manager so the checkpoint file handle is closed promptly.
with open(f'checkpoints/{exp_name}/results.pkl', 'rb') as f:
    results = pickle.load(f)

# Render the symbolic CoxKAN formula, then print its LaTeX source.
coxkan_formula = results['CoxKAN']['Symbolic']['formula']
display(coxkan_formula)

print(latex(coxkan_formula))

# Cell output: the first 10 entries of the per-term standard deviations.
dict(list(results['CoxKAN']['Symbolic']['terms_std'].items())[:10])

$$
\begin{split}
\hat{\theta}_{KAN} = & + 0.2 \tanh(\text{CALM2}_{RNA} - 0.4) \quad (\sigma=0.15) \, \\
& -0.1 \cdot \text{PRR15L}_{RNA} \quad (\sigma=0.1) \, \\
& + 0.2 \cdot \text{TOMM20}_{RNA} \quad (\sigma=0.09) \, \\
& - 0.09 \cdot \text{MUC16}_{mut} \quad (\sigma=0.09) \, \\
& + 0.8 \arctan(0.4 \cdot \text{C3}_{RNA} + 0.2) \quad (\sigma=0.08) \, \\
& -0.1 \cdot \text{HNRNPK}_{RNA} \quad (\sigma=0.08) \, \\
& -0.2 \cdot \text{MISP}_{RNA} \quad (\sigma=0.08) \, \\
& + \text{less significant terms},
\end{split}
$$

In [None]:
exp_name = 'TCGA-BRCA'
# Context manager so the checkpoint file handle is closed promptly.
with open(f'checkpoints/{exp_name}/results.pkl', 'rb') as f:
    results = pickle.load(f)

# Render the symbolic CoxKAN formula, then print its LaTeX source.
coxkan_formula = results['CoxKAN']['Symbolic']['formula']
display(coxkan_formula)

print(latex(coxkan_formula))

# Cell output: the first 10 entries of the per-term standard deviations.
dict(list(results['CoxKAN']['Symbolic']['terms_std'].items())[:10])

$$
\begin{equation}
\begin{split}
\hat{\theta}_{KAN} = & + 0.2 \cdot \text{KMT2C}_{mut} \quad (\sigma=0.24) \, \\
& + 0.6 \sin(0.5 \cdot \text{HSPA8}_{RNA} - 7) \quad (\sigma=0.18)\, \\
& -2 \exp{(-0.04 (0.9 \cdot \text{PLXNB2}_{RNA} + 1)^2)} \quad (\sigma=0.17) \, \\
& -2 \exp{(-0.05 (0.9 \cdot \text{PGK1}_{RNA} + 1)^2)} \quad (\sigma=0.15) \, \\
& -0.14 \cdot \text{RYR2}_{mut} \quad (\sigma=0.14) \, \\
& +0.1 \cdot \text{DMD}_{mut} \quad (\sigma=0.1) \, \\
& +0.01 \cdot \text{TTN}_{mut} \quad (\sigma=0.07) \, \\
& + \frac{0.4}{(1-0.1 \cdot \text{group\_46}_{CNV})^2} \quad (\sigma=0.06) \, \\
& + 0.9 \exp(-0.06(\text{H2BC5}_{RNA} - 0.5)^2) \quad (\sigma=0.05) \, \\
& - 0.3 \sin(0.5 \cdot \text{RPL14}_{RNA} + 5) \quad (\sigma=0.05) \, \\
& + \text{less significant terms},
\end{split}
\end{equation}
$$

In [None]:
exp_name = 'TCGA-GBMLGG'
# Context manager so the checkpoint file handle is closed promptly.
with open(f'checkpoints/{exp_name}/results.pkl', 'rb') as f:
    results = pickle.load(f)

# Render the symbolic CoxKAN formula, then print its LaTeX source.
coxkan_formula = results['CoxKAN']['Symbolic']['formula']
display(coxkan_formula)

print(latex(coxkan_formula))

# Cell output: the first 10 entries of the per-term standard deviations.
dict(list(results['CoxKAN']['Symbolic']['terms_std'].items())[:10])


$$
\begin{equation}
\begin{split}
\hat{\theta}_{KAN} = 
& \, - 0.2 \cdot \text{(1p19q arm codeletion)} \quad (\sigma=0.19) \, \\
& +e^{-0.2(-0.6 \cdot (\text{10q}_{CNV}) - 1)^2} \quad (\sigma=0.19) \, \\
& -0.2 \cdot \text{IDH}_{mut} \quad (\sigma=0.17) \, \\
& -0.06 \tan(0.4 \cdot \text{CARD11}_{CNV} + 8) \quad (\sigma=0.16) \, \\
& -0.08 (0.6 \cdot \text{PTEN}_{CNV} + 1)^4 \quad (\sigma=0.14) \, \\
& -0.3 \sin(3 \cdot \text{JAK2}_{CNV} - 5) \quad (\sigma=0.12) \, \\
& - 0.1 \cdot \text{CDKN2A}_{CNV} \quad (\sigma=0.12) \, \\
& -0.1 \sin(9 \cdot \text{CDKN2B}_{CNV} - 4) \quad (\sigma=0.1) \, \\
& - 0.3 \sin(9 \cdot \text{EGFR}_{CNV} + 0.8) \quad (\sigma=0.1) \, \\
& + \text{less significant terms},
\end{split}
\end{equation}
$$

In [None]:
exp_name = 'TCGA-KIRC'
# Context manager so the checkpoint file handle is closed promptly.
with open(f'checkpoints/{exp_name}/results.pkl', 'rb') as f:
    results = pickle.load(f)

# Render the symbolic CoxKAN formula, then print its LaTeX source.
coxkan_formula = results['CoxKAN']['Symbolic']['formula']
display(coxkan_formula)

print(latex(coxkan_formula))

# The top-10 `terms_std` listing shown for the other cohorts is skipped here:
# KIRC's terms were not auto-ranked, so its equation was assembled by hand.

(The KIRC terms were not automatically ranked by `terms_std`, so the equation below was assembled manually.)

$$
\begin{equation}
\begin{split}
\hat{\theta}_{KAN} = & + 0.43 \cdot \text{MT1X}_{RNA} \quad (\sigma=0.42) \, \\
& +0.34 \cdot \text{DDX43}_{RNA} \quad (\sigma=0.34) \\
& +0.23 \cdot \text{CWH43}_{RNA} \quad (\sigma=0.31) \\
& +0.22 \cdot \text{CILP}_{RNA} \quad (\sigma=0.31) \\
& -0.24 \cdot \text{LOC153328}_{RNA} \quad (\sigma=0.29) \\
& -0.21 \cdot \text{CYP3A7}_{RNA} \quad (\sigma=0.28) \\
& + \text{less significant terms},
\end{split}
\end{equation}
$$