In [None]:
! pip install lifelines
! pip install scikit-survival

In [None]:
# Standard libraries
from collections import defaultdict

# Third-party libraries
import pandas as pd
import torch
import numpy as np
from omegaconf import DictConfig, OmegaConf
from sksurv.nonparametric import kaplan_meier_estimator
import matplotlib.pyplot as plt
from lifelines.statistics import logrank_test

# Local dependencies
from drim.helpers import get_datasets, get_targets, get_encoder
from drim.utils import log_transform, seed_everything, get_dataframes, prepare_data, interpolate_dataframe
from drim.multimodal import MultimodalDataset, DRIMSurv, MultimodalModel
from drim.datasets import SurvivalDataset
from drim.models import MultimodalWrapper


# Seed handed to `seed_everything` before each fold is evaluated.
seed = 1999
# NOTE(review): not referenced by the analysis cell, which reads
# `cfg.general.n_outs` instead — confirm whether this is still needed.
n_outs = 20
# `plt` is already imported with the other third-party libraries above,
# so it is not re-imported here.
plt.rcParams["font.family"] = "serif"

In [None]:
# Fusion layers are imported once here instead of repeatedly inside the fold loop.
from drim.fusion import MaskedAttentionFusion, ShallowFusion, TensorFusion

# Modalities fed to every model, in the order used to build datasets and encoders.
MODALITIES = ('DNAm', 'WSI', 'RNA', 'MRI')
# Fusion strategies compared in this cell.
METHODS = ('tensor', 'concat', 'max', 'drim')
# Number of cross-validation splits for which a checkpoint is loaded.
N_FOLDS = 5


def build_attention_fusion(dim):
    """Return a MaskedAttentionFusion with the hyper-parameters used in this notebook.

    Args:
        dim: embedding dimension passed to the fusion layer.
    """
    return MaskedAttentionFusion(dim=dim, depth=1, heads=16, dim_head=64, mlp_dim=128)


def build_baseline_fusion(method, dim):
    """Return the fusion layer of a non-DRIM baseline.

    Args:
        method: one of 'max', 'concat', 'tensor' or 'maf'.
        dim: embedding dimension of the unimodal encoders.

    Raises:
        ValueError: if `method` is not a known baseline, instead of silently
            reusing a fusion layer left over from a previous iteration.
    """
    if method == 'max':
        return ShallowFusion('max')
    if method == 'concat':
        return ShallowFusion('concat')
    if method == 'tensor':
        return TensorFusion(modalities=list(MODALITIES), input_dim=dim, projected_dim=dim, output_dim=dim, dropout=0.)
    if method == 'maf':
        return build_attention_fusion(dim)
    raise ValueError(f"Unknown fusion method: {method!r}")


def plot_km_group(ax, frame, label):
    """Draw the Kaplan-Meier curve and its log-log confidence band for one group.

    Args:
        ax: matplotlib Axes to draw on.
        frame: DataFrame with "event" and "time" columns.
        label: legend entry for the curve.
    """
    km_time, survival_prob, conf_int = kaplan_meier_estimator(
        frame["event"].astype(bool),
        frame["time"],
        conf_type="log-log",
    )
    ax.step(km_time, survival_prob, where="post", label=label)
    ax.fill_between(km_time, conf_int[0], conf_int[1], alpha=0.25, step="post")


cfg = OmegaConf.load('./configs/robustness.yaml')
for method in METHODS:
    # The tensor-fusion model uses a smaller embedding than the other methods.
    cfg.general.dim = 32 if method == 'tensor' else 128

    scores = []
    for fold in range(N_FOLDS):
        seed_everything(seed)
        dataframes = get_dataframes(fold)
        dataframes = {split: prepare_data(dataframe, list(MODALITIES)) for split, dataframe in dataframes.items()}

        test_datasets = {}
        encoders = {}
        encoders_u = {}  # second encoder set, only filled for DRIM
        for modality in MODALITIES:
            datasets = get_datasets(dataframes, modality, fold, return_mask=True)
            test_datasets[modality] = datasets['test']
            encoders[modality] = get_encoder(modality, cfg).cuda()
            if method == 'drim':
                encoders_u[modality] = get_encoder(modality, cfg).cuda()

        targets, cut = get_targets(dataframes, cfg.general.n_outs)
        dataset_test = MultimodalDataset(test_datasets, return_mask=True)
        test_data = SurvivalDataset(dataset_test, *targets['test'])
        loader = torch.utils.data.DataLoader(test_data, shuffle=False, batch_size=24)

        # Assemble the architecture matching the checkpoint of this method/fold.
        if method == 'drim':
            fusion = build_attention_fusion(cfg.general.dim)
            fusion_u = build_attention_fusion(cfg.general.dim)
            fusion.cuda()
            fusion_u.cuda()
            encoder = DRIMSurv(encoders_sh=encoders, encoders_u=encoders_u, fusion_s=fusion, fusion_u=fusion_u)
            embedding_dim = cfg.general.dim
            checkpoint = f'./models/drimsurv_split_{int(fold)}.pth'
        else:
            fusion = build_baseline_fusion(method, cfg.general.dim)
            fusion.cuda()
            encoder = MultimodalModel(encoders, fusion=fusion)
            # Concatenation stacks the four modality embeddings side by side.
            embedding_dim = cfg.general.dim * 4 if method == 'concat' else cfg.general.dim
            prefix = 'vanilla' if method == 'max' else 'aux_mmo'
            checkpoint = f'./models/{prefix}_{method}_split_{int(fold)}.pth'

        model = MultimodalWrapper(encoder, embedding_dim=embedding_dim, n_outs=cfg.general.n_outs)
        # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
        model.load_state_dict(torch.load(checkpoint))
        model.cuda()
        model.eval()

        # Collect the raw model outputs for the whole test set.
        outputs = []
        with torch.no_grad():
            for (data, mask), _, _ in loader:
                outputs.append(model(data, mask, return_embedding=False))

        # exp(cumsum(log(1 - sigmoid(x)))) is the running product of (1 - sigmoid(x))
        # along dim 1 — presumably the discrete-time survival curve of each sample.
        survival = (1 - torch.cat(outputs, dim=0).sigmoid()).add(1e-7).log().cumsum(1).exp()
        survival = interpolate_dataframe(pd.DataFrame(survival.cpu().numpy().transpose(), index=cut))
        # Risk score: negated column sum of the interpolated frame, one value per test sample.
        scores.append(-survival.sum(0).values)

    # Average the per-sample risk over folds, then split at the median.
    # NOTE(review): `dataframes['test']` is the frame of the last fold; this assumes
    # the test split is identical for every fold — confirm.
    scores = torch.from_numpy(np.stack(scores).mean(0))
    high_scores = scores > scores.median()
    low_scores = scores <= scores.median()
    data_high = dataframes['test'].iloc[high_scores.numpy()]
    data_low = dataframes['test'].iloc[low_scores.numpy()]
    results = logrank_test(data_high["time"], data_low["time"], data_high["event"], data_low["event"])
    p_value = f'{results.p_value:0.2}'

    fig, ax = plt.subplots(dpi=500)
    plot_km_group(ax, data_low, "Low risk")
    plot_km_group(ax, data_high, "High risk")
    ax.set_ylim(0, 1)
    ax.set_ylabel(r"Probability of survival $\hat{S}(t)$")
    ax.set_xlabel("Time $t$")
    ax.legend(loc="best")

    title = method.capitalize()
    if method in ('tensor', 'concat'):
        title += ' w/ MMO'
    ax.set_title(title)
    # Position of the p-value annotation, in data coordinates.
    ax.text(11, 0.8, r'$p_{value}$=' + p_value, fontsize=12.8)
    ax.grid()
    fig.savefig(f'{method}.pdf')
    plt.show()
    # Release the 500-dpi figure so figures do not accumulate across methods.
    plt.close(fig)