# 3.0 - Survival analysis of METABRIC patients

In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pickle
import torch

import os
import sys
sys.path.append(f"../")
from sslcox.utils.model_evaluation import model_already_trained

import lifelines
from lifelines import CoxPHFitter
from sslcox.data.load_datasets import load_METABRIC


In [2]:
EXPRESSIONS = 'METABRIC'

# Name of the dataset-level results directory (used below under
# ../data/training-results/).
DS_DIR = f'{EXPRESSIONS}-optuna'


def CV_DIR(cv):
    """Return the directory name of cross-validation fold `cv`."""
    return f'CV-{cv}'


def MODEL_DIR(m):
    """Return the results directory name of the model called `m`."""
    return f'{m}-model-results'

In [3]:
# Root directory that contains one sub-directory per cross-validation fold.
results_root = f'../data/training-results/{DS_DIR}'


def _read_latent(path):
    """Read a tab-separated latent matrix, indexed by its 'index' column."""
    return pd.read_csv(path, sep='\t', index_col=['index'])


# One dict per fold (in sorted directory-name order): model name -> artifacts.
training_results = []

for fold_dir in sorted(os.listdir(results_root)):
    # Hidden entries are not result folders.
    if fold_dir.startswith('.'):
        continue

    fold_results = {}

    for model_dir in os.listdir(f'{results_root}/{fold_dir}'):
        if model_dir.startswith('.'):
            continue

        # '<name>-model-results' -> '<name>'
        name = model_dir.split('-model-results')[0]
        fold_id = fold_dir.split('-')[1]

        # Only models that the project helper reports as trained are loaded.
        if not model_already_trained(DS_DIR, CV_DIR(fold_id), MODEL_DIR(name)):
            continue

        base_path = f'{results_root}/{fold_dir}/{model_dir}'

        artifacts = {
            'X_train_latent': _read_latent(f'{base_path}/X_train_latent.tsv'),
            'X_test_latent': _read_latent(f'{base_path}/X_test_latent.tsv'),
        }

        # VAE runs additionally ship a pickled study and the torch weights.
        # NOTE: pickle.load / torch.load can execute arbitrary code while
        # deserialising; only point this at result files you trust.
        if 'vae' in name:
            with open(f'{base_path}/study.pickle', 'rb') as handle:
                artifacts['study'] = pickle.load(handle)
            artifacts['torch_weights'] = torch.load(
                f'{base_path}/torch_model.pt', map_location=torch.device('cpu')
            )

        fold_results[name] = artifacts

    training_results.append(fold_results)

In [4]:
# Load METABRIC through the project loader. `metadata` provides the survival
# columns (OS_* / RFS_*) read in the next cell.
# NOTE(review): `data` is presumably the expression matrix; confirm against
# sslcox.data.load_datasets.load_METABRIC.
data, metadata = load_METABRIC()

In [5]:
def _status_to_event(status):
    """Convert a status label into an integer event flag.

    The flag is the integer value of the label's first character. Any value
    that is not a string (e.g. a missing entry) maps to 0, i.e. "no event".
    NOTE(review): assumes labels start with the digit 0 or 1 - confirm against
    the METABRIC metadata.
    """
    return int(status[0]) if isinstance(status, str) else 0


# Overall survival: duration in months and the event flag parsed from the
# first character of OS_STATUS (non-string values raise here, as before).
os_time = metadata['OS_MONTHS']
os_event = metadata['OS_STATUS'].apply(lambda status: int(status[0]))

# Recurrence-free survival: rows without RFS information are kept and encoded
# as duration 0 with event flag 0.
rfs_time = metadata['RFS_MONTHS'].fillna(0)
rfs_event = metadata['RFS_STATUS'].apply(_status_to_event)

## Cox regression on the latent embeddings (recurrence-free survival)

In [6]:
from tqdm import tqdm

# Survival endpoints to evaluate; every entry must be a key of `survival_data`.
PROCESS = ['rfs']

# endpoint -> [durations, event flags]
survival_data = {
    'rfs': [rfs_time, rfs_event],
}

ridge_models = {}  # fold index -> model name -> endpoint -> fitted CoxPHFitter
performance = {}   # fold index -> model name -> endpoint -> test-set score
SCORING_METHOD = 'concordance_index'
COX_PENALIZER = 0.1  # penalty strength handed to CoxPHFitter


def _standardize(train, test):
    """Z-score `train` and `test` with the column statistics of `train`.

    Columns whose training standard deviation is zero (or undefined) are
    divided by 1 instead, so a constant latent dimension becomes all-zero in
    the training set rather than NaN/inf, which would break the Cox fit.
    """
    mean, std = train.mean(axis=0), train.std(axis=0)
    safe_std = std.where(std > 0, 1.0)
    return (train - mean) / safe_std, (test - mean) / safe_std


def _with_survival_columns(features, y_time, y_event):
    """Inner-join features with both targets; name the targets 'time'/'event'."""
    joined = pd.concat((features, y_time, y_event), axis=1, join='inner')
    joined.columns = list(joined.columns[:-2]) + ['time', 'event']
    return joined


for cv, fold_results in enumerate(training_results):
    ridge_models[cv] = {}
    performance[cv] = {}

    for model, values in tqdm(fold_results.items()):
        ridge_models[cv][model] = {}
        performance[cv][model] = {}

        for process in PROCESS:
            y_time, y_event = survival_data[process]

            # Scale with training statistics only, so nothing leaks from test.
            X_train, X_test = _standardize(
                values['X_train_latent'], values['X_test_latent']
            )

            train_set = _with_survival_columns(X_train, y_time, y_event)
            test_set = _with_survival_columns(X_test, y_time, y_event)

            estimator = CoxPHFitter(penalizer=COX_PENALIZER)
            estimator.fit(train_set, duration_col='time', event_col='event')

            ridge_models[cv][model][process] = estimator

            # Evaluate on the held-out fold.
            score = estimator.score(test_set, scoring_method=SCORING_METHOD)
            performance[cv][model][process] = score
        

100%|██████████| 5/5 [01:41<00:00, 20.38s/it]
100%|██████████| 5/5 [01:39<00:00, 19.88s/it]
100%|██████████| 5/5 [01:39<00:00, 19.83s/it]
100%|██████████| 5/5 [00:27<00:00,  5.59s/it]
100%|██████████| 5/5 [00:27<00:00,  5.47s/it]


In [7]:
# Collect, for every model seen in fold 0, its per-fold scores per endpoint:
# model name -> endpoint -> [score in fold 0, score in fold 1, ...]
results = {}
for name in performance[0]:
    per_endpoint = {}
    for d in PROCESS:
        per_endpoint[d] = [fold_scores[name][d] for fold_scores in performance.values()]
    results[name] = per_endpoint

# Keep only the models to report, in presentation order.
model_order = ['vae-cox', 'vae-mse', 'vae-div', 'pca-emb', 'no-embedding']
results = {name: results[name] for name in model_order}

In [8]:
def _mean_and_std(scores):
    """Return (mean, standard deviation) of the per-fold scores as a tuple."""
    return (np.mean(scores), np.std(scores))


# One row per endpoint, one column per model; transposed so models are rows.
summary_rows = [
    [_mean_and_std(results[name][d]) for name in results]
    for d in PROCESS
]
mean_results = pd.DataFrame(summary_rows, index=PROCESS, columns=list(results)).T
mean_results.head()

Unnamed: 0,rfs
vae-cox,"(0.6370137001961169, 0.015321285656969798)"
vae-mse,"(0.6252743410333371, 0.011481513049029777)"
vae-div,"(0.6265273679943228, 0.014007834985210932)"
pca-emb,"(0.6157731593507318, 0.012197675010028311)"
no-embedding,"(0.6105721594683235, 0.02446224960510976)"


In [10]:
# Print "mean \pm standard error" per model as a LaTeX inline-math snippet.
# The standard error is np.std (population std, ddof=0) divided by sqrt(n folds).
for name, scores in results.items():
    rfs_scores = scores['rfs']
    mean = np.mean(rfs_scores)
    sem = np.std(rfs_scores) / np.sqrt(len(rfs_scores))
    # Raw f-string: keeps the LaTeX backslashes literal without relying on
    # invalid escape sequences such as "\(" or "\p".
    print(name, rf"\({mean:.3f} \pm {sem:.3f}\)")

vae-cox \(0.637 \pm 0.007\)
vae-mse \(0.625 \pm 0.005\)
vae-div \(0.627 \pm 0.006\)
pca-emb \(0.616 \pm 0.005\)
no-embedding \(0.611 \pm 0.011\)
