In [23]:
# Import related packages
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torchtuples as tt
import torch.nn.functional as F
import torch.nn as nn
import math
import warnings
import random

from sklearn.preprocessing import StandardScaler
from sklearn_pandas import DataFrameMapper
from pycox.models.utils import pad_col
from pycox.preprocessing.label_transforms import LabTransDiscreteTime
from pycox import models
from typing import Tuple
from torch import Tensor
from pycox.models import utils
from torchtuples import TupleTree
from scipy.interpolate import UnivariateSpline

from sklearn.model_selection import StratifiedKFold
from sklearn.utils import shuffle
from itertools import product
from sklearn.model_selection import train_test_split

import sys
sys.path.insert(0, '/')
from eval import EvalSurv


In [24]:
# Discretization of survival times
class LabTransform(LabTransDiscreteTime):
    """Discretise durations while keeping the original (non-binary) event codes.

    The parent transform only sees a binary "any event" indicator
    (`events > 0`).  Observations the parent reports as censored after
    discretisation get event code 0; all other codes are passed through.
    """

    def transform(self, durations, events):
        """Return `(idx_durations, events)` with events cast to int64.

        The caller's `events` object is never modified: callers in this
        notebook pass `df[col].values`, which can be a view into the
        DataFrame, so writing into it would silently alter the source data.
        """
        durations, is_event = super().transform(durations, events > 0)
        # Work on a copy so the zeroing below does not leak back to the caller.
        events = events.copy()
        events[is_event == 0] = 0
        return durations, events.astype('int64')

def index_3d_to_1d(i, j, k, n):
    """
    Flatten a 3D index (i, j, k) into a single index using base n.

    Each component is expected to lie in 0..n-1; `i` is the most significant
    digit and `k` the least significant one.  Works element-wise on arrays.
    """
    high = i * n * n
    mid = j * n
    return high + mid + k

def index_2d_to_1d(i, j, n):
    """
    Flatten a 2D index (i, j) into a single index using base n.

    Each component is expected to lie in 0..n-1; `i` is the most significant
    digit.  Works element-wise on arrays.
    """
    high = i * n
    return high + j

# Neural network architecture
class MaskedLinearLayer(nn.Module):
    """Linear layer whose weight matrix is multiplied element-wise by a fixed mask.

    Parameters
    ----------
    in_features, out_features : int
        Input and output sizes; the weight has shape (out_features, in_features).
    bias : bool
        Whether to learn an additive bias.
    mask : tensor-like or None
        Multiplied element-wise with the weight on every forward pass, so it
        must broadcast against (out_features, in_features).  Defaults to all
        ones (an ordinary dense layer).
    """

    def __init__(self, in_features, out_features, bias=True, mask=None):
        super(MaskedLinearLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)

        # Same initialisation scheme as before (weight first, then bias).
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if bias:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

        if mask is None:
            mask = torch.ones(out_features, in_features)
        # Register the mask as a buffer so it follows the module across
        # devices (`.to(...)`, `.cuda()`); a plain attribute stays behind and
        # breaks the forward pass on GPU.  `persistent=False` keeps it out of
        # the state_dict, so state_dict keys are the same as before.
        self.register_buffer('mask', torch.as_tensor(mask), persistent=False)

    def forward(self, input):
        """Apply the linear map using only the weights selected by the mask."""
        masked_weight = self.weight * self.mask
        return nn.functional.linear(input, masked_weight, self.bias)

class pretrain_survival(torch.nn.Module):
    """Masked first layer + shared MLP + one small MLP head per (type, T1 bin).

    Parameters
    ----------
    in_features : int
        Number of input covariates.
    num_nodes_shared : sequence of int
        Layer sizes of the shared part; the first entry is the width of the
        masked first layer, the last entry the width fed to the heads.
    num_nodes_indiv : int or sequence of int
        Hidden sizes of every head.
    num_T1 : int
        Number of bins along the second output axis.
    out_features : int
        Output size of every head (last output axis).
    batch_norm, dropout :
        Forwarded to `tt.practical.MLPVanilla`.
    mask :
        Connectivity mask for the first layer.
    num_types : int
        Size of the first output axis (default 32, the previously
        hard-coded value).

    `forward` returns a tensor of shape (batch, num_types, num_T1, out_features).
    """

    def __init__(self, in_features, num_nodes_shared, num_nodes_indiv, num_T1,
                 out_features, batch_norm=True, dropout=None, mask=None, num_types=32):
        super().__init__()
        # Keep the output geometry on the instance: `forward` must not depend
        # on notebook-level globals that may be stale or undefined.
        self.num_types = num_types
        self.num_T1 = num_T1
        self.out_features = out_features

        self.first_layer = MaskedLinearLayer(in_features, num_nodes_shared[0], mask=mask)
        self.shared_net = tt.practical.MLPVanilla(
            num_nodes_shared[0], num_nodes_shared[1:-1], num_nodes_shared[-1],
            batch_norm, dropout,
        )
        self.risk_nets = torch.nn.ModuleList()
        for _ in range(num_types * num_T1):
            net = tt.practical.MLPVanilla(
                num_nodes_shared[-1], num_nodes_indiv, out_features,
                batch_norm, dropout,
            )
            self.risk_nets.append(net)

    def forward(self, input):
        """Return head outputs stacked as (batch, num_types, num_T1, out_features)."""
        shared = self.shared_net(self.first_layer(input))
        head_outputs = [net(shared) for net in self.risk_nets]
        stacked = torch.stack(head_outputs, dim=1)
        return stacked.view(shared.size(0), self.num_types, self.num_T1, self.out_features)

class finetune_survival(torch.nn.Module):
    """Network used for fine-tuning; same layout as `pretrain_survival`.

    The attribute names (`first_layer`, `shared_net`, `risk_nets`) match the
    pre-training network so that its state dicts can be loaded directly.

    Parameters
    ----------
    in_features : int
        Number of input covariates.
    num_nodes_shared : sequence of int
        Layer sizes of the shared part; the first entry is the width of the
        masked first layer, the last entry the width fed to the heads.
    num_nodes_indiv : int or sequence of int
        Hidden sizes of every head.
    num_T1 : int
        Number of bins along the second output axis.
    out_features : int
        Output size of every head (last output axis).
    batch_norm, dropout :
        Forwarded to `tt.practical.MLPVanilla`.
    mask :
        Connectivity mask for the first layer.
    num_types : int
        Size of the first output axis (default 32, the previously
        hard-coded value).

    `forward` returns a tensor of shape (batch, num_types, num_T1, out_features).
    """

    def __init__(self, in_features, num_nodes_shared, num_nodes_indiv, num_T1,
                 out_features, batch_norm=True, dropout=None, mask=None, num_types=32):
        super().__init__()
        # Keep the output geometry on the instance: `forward` must not depend
        # on notebook-level globals that may be stale or undefined.
        self.num_types = num_types
        self.num_T1 = num_T1
        self.out_features = out_features

        self.first_layer = MaskedLinearLayer(in_features, num_nodes_shared[0], mask=mask)
        self.shared_net = tt.practical.MLPVanilla(
            num_nodes_shared[0], num_nodes_shared[1:-1], num_nodes_shared[-1],
            batch_norm, dropout,
        )
        self.risk_nets = torch.nn.ModuleList()
        for _ in range(num_types * num_T1):
            net = tt.practical.MLPVanilla(
                num_nodes_shared[-1], num_nodes_indiv, out_features,
                batch_norm, dropout,
            )
            self.risk_nets.append(net)

    def forward(self, input):
        """Return head outputs stacked as (batch, num_types, num_T1, out_features)."""
        shared = self.shared_net(self.first_layer(input))
        head_outputs = [net(shared) for net in self.risk_nets]
        stacked = torch.stack(head_outputs, dim=1)
        return stacked.view(shared.size(0), self.num_types, self.num_T1, self.out_features)

class multi_task(tt.Model):
    """torchtuples model wrapper that trains with `Loss1` unless a loss is given.

    `duration_index` is stored on the instance and exposed as a property.
    """

    def __init__(self, net, optimizer=None, device=None, alpha=0.2, sigma=0.1, duration_index=None, loss=None):
        self.duration_index = duration_index
        criterion = Loss1(alpha, sigma) if loss is None else loss
        super().__init__(net, criterion, optimizer, device)

    @property
    def duration_index(self):
        """Value passed as `duration_index` at construction (or set later)."""
        return self._duration_index

    @duration_index.setter
    def duration_index(self, val):
        self._duration_index = val

    def make_dataloader(self, data, batch_size, shuffle, num_workers=0):
        """Training dataloader built on pycox's `DeepHitDataset`."""
        return super().make_dataloader(data, batch_size, shuffle, num_workers,
                                       make_dataset=models.data.DeepHitDataset)

    def make_dataloader_predict(self, input, batch_size, shuffle=False, num_workers=0):
        """Prediction dataloader using the parent's default dataset."""
        return super().make_dataloader(input, batch_size, shuffle, num_workers)

    def predict_pmf_12(self, input, batch_size=8224, numpy=None, eval_=True,
                     to_cpu=False, num_workers=0):
        """Predict the joint PMF, returned with the sample axis moved last.

        The network output is flattened per sample, padded with one extra
        column by `pad_col`, soft-maxed, and the padded column is dropped
        before restoring the original shape.  Axis 0 (samples) is then moved
        to the end, i.e. (d1, d2, d3, n_samples) for a 4-D network output.
        """
        raw = self.predict(input, batch_size, False, eval_, False, to_cpu, num_workers)
        n_obs = raw.size(0)
        padded = pad_col(raw.view(n_obs, -1))
        joint = padded.softmax(1)[:, :-1].view(raw.shape)
        # Equivalent to transpose(0, 1).transpose(1, 2).transpose(2, 3).
        sample_last = joint.permute(1, 2, 3, 0)
        return tt.utils.array_or_tensor(sample_last, numpy, input)

def _reduction(loss: Tensor, reduction: str = 'mean') -> Tensor:
    if reduction == 'none':
        return loss
    elif reduction == 'mean':
        return loss.mean()
    elif reduction == 'sum':
        return loss.sum()
    raise ValueError(f"`reduction` = {reduction} is not valid. Use 'none', 'mean' or 'sum'.")

def index_3d_from_1d(k, n):
    """
    Split a flat base-n index k back into its three digits (i, j, m),
    most significant first; each digit lies in 0..n-1.
    Works element-wise on arrays and tensors.
    """
    plane = n * n
    within_plane = k % plane
    return (k // plane, within_plane // n, within_plane % n)

# Parameter optimization and loss function
def nll_pmf_cr(phi: Tensor, idx_durations: Tensor, events: Tensor, reduction: str = 'mean',
               epsilon: float = 1e-7) -> Tensor:
    """Negative log-likelihood for a discrete joint PMF over two time axes.

    Parameters
    ----------
    phi : Tensor
        Raw network output, indexed below as [sample, a, b, c] (4-D).  In this
        notebook the networks emit (batch, 32, num_T1, out_features).
    idx_durations : Tensor
        Flat base-32 index decoded into three components by
        `index_3d_from_1d`; the third component indexes axis 1 of `phi`, the
        first axis 2 and the second axis 3.  The training cells build it as
        `index_3d_to_1d(T1, T2, cancertype, 32)`.
    events : Tensor
        Codes 0..3; exactly one of the four likelihood terms is active per
        sample.  The training cells build it as `2 * event1 + event2`.
    reduction : str
        Passed to `_reduction` ('none', 'mean' or 'sum').
    epsilon : float
        Added before every `log` to avoid log(0).

    Returns
    -------
    Tensor
        The reduced negative log-likelihood.
    """

    events = events.view(-1)
    # One-hot style indicators selecting which likelihood term applies.
    event_00 = (events == 0).float()
    event_01 = (events == 1).float()
    event_02 = (events == 2).float()
    event_03 = (events == 3).float()
    
    # Decode the flat target index; 32 must match the base used when encoding.
    idx_durations1, idx_durations2, idx_durations3 = index_3d_from_1d(idx_durations.view(-1), 32)
    batch_size = phi.size(0)
    # Softmax over all cells of a sample plus the column added by pad_col;
    # that extra column is dropped again, so `sm` can sum to less than 1.
    sm = utils.pad_col(phi.view(batch_size, -1)).softmax(1)[:, :-1].view(phi.shape)
    index = torch.arange(batch_size)
    # events == 3: mass of the single cell (idx1, idx2).
    part1 = sm[index, idx_durations3, idx_durations1, idx_durations2].relu().add(epsilon).log().mul(event_03)
    # events == 2: row idx1, mass at last-axis positions strictly after idx2.
    part2 = (sm[index, idx_durations3, idx_durations1, :].sum(1) - sm.cumsum(3)[index, idx_durations3, idx_durations1, idx_durations2]).relu().add(epsilon).log().mul(event_02)
    # events == 1: column idx2, mass at axis-2 positions strictly after idx1.
    part3 = (sm[index, idx_durations3, :, idx_durations2].sum(1) - sm.cumsum(2)[index, idx_durations3, idx_durations1, idx_durations2]).relu().add(epsilon).log().mul(event_01)
    # events == 0: inclusion-exclusion,
    # 1 - mass(last axis <= idx2) + mass(both <= idx) - mass(axis 2 <= idx1).
    part4 = (1 - sm.cumsum(3)[index, idx_durations3, :, idx_durations2].sum(1) + sm.cumsum(2).cumsum(3)[index, idx_durations3, idx_durations1, idx_durations2] - sm.cumsum(2)[index, idx_durations3, idx_durations1, :].sum(1)).relu().add(epsilon).log().mul(event_00)    

    # The indicators zero out the inactive terms, so the sum is the
    # per-sample log-likelihood.
    loss = - part1.add(part2).add(part3).add(part4)
    return _reduction(loss, reduction)

class _Loss(torch.nn.Module):

    def __init__(self, reduction: str = 'mean') -> None:
        super().__init__()
        self.reduction = reduction

class _Loss1(_Loss):
    """Loss base class carrying two validated hyper-parameters.

    `alpha` must lie in [0, 1] and `sigma` must be strictly positive; both
    are checked on every assignment.
    """

    def __init__(self, alpha: float, sigma: float, reduction: str = 'mean') -> None:
        super().__init__(reduction)
        self.alpha = alpha
        self.sigma = sigma

    def _get_alpha(self) -> float:
        return self._alpha

    def _set_alpha(self, alpha: float) -> None:
        below = alpha < 0
        if below or (alpha > 1):
            raise ValueError(f"Need `alpha` to be in [0, 1]. Got {alpha}.")
        self._alpha = alpha

    alpha = property(_get_alpha, _set_alpha)

    def _get_sigma(self) -> float:
        return self._sigma

    def _set_sigma(self, sigma: float) -> None:
        if sigma <= 0:
            raise ValueError(f"Need `sigma` to be positive. Got {sigma}.")
        self._sigma = sigma

    sigma = property(_get_sigma, _set_sigma)

class Loss1(_Loss1):
    """Loss module that evaluates `nll_pmf_cr` with the configured reduction."""

    def forward(self, phi: Tensor, idx_durations: Tensor, events: Tensor, rank_mat: Tensor) -> Tensor:
        # `rank_mat` is part of the call signature but does not enter the
        # loss; neither do `alpha` and `sigma`.
        return nll_pmf_cr(phi, idx_durations, events, reduction=self.reduction)

In [25]:
# Seed NumPy and PyTorch to make this reproducible
np.random.seed(123)
_ = torch.manual_seed(123)

# Import data: two comma-separated tables whose first row is the header
relevantdata, targetdata = (
    pd.read_csv(path, header=0) for path in ('relevantdata.txt', 'targetdata.txt')
)

In [26]:
# Train/test split
np.random.seed(123)
_ = torch.manual_seed(123)


def _event_group(df, event1, event2):
    """Rows of `df` whose event1/event2 indicators equal the given pair."""
    return df[(df['event1'] == event1) & (df['event2'] == event2)]


def _head_tail_split(group, train_frac=0.8):
    """Split `group` in its existing row order: first int(train_frac * n) rows, then the rest."""
    cut = int(train_frac * len(group))
    return group[:cut], group[cut:]


# Relevant (pre-training) data: integer-encode the cancer type and drop the label column.
relevantdata = relevantdata.copy()
relevantdata['cancertype_num'] = pd.factorize(relevantdata['cancertype'])[0]
relevantdata_train = relevantdata.drop(columns=['cancertype'])

# Target data: every row gets type id 31, the last of the 32 type slots the
# networks are built with.
# NOTE(review): this assumes the factorized ids above stay below 31 - confirm.
targetdata = targetdata.copy()
targetdata['cancertype_num'] = 31

# Stratify on the four (event1, event2) combinations.  Within each stratum the
# first 80% of rows (file order) go to training and the rest to testing.
group_00 = _event_group(targetdata, 0, 0)
group_01 = _event_group(targetdata, 0, 1)
group_10 = _event_group(targetdata, 1, 0)
group_11 = _event_group(targetdata, 1, 1)

train_00, test_00 = _head_tail_split(group_00)
train_01, test_01 = _head_tail_split(group_01)
train_10, test_10 = _head_tail_split(group_10)
train_11, test_11 = _head_tail_split(group_11)

# Shuffle each partition with a fixed seed after concatenation.
train_targetdata = pd.concat([train_00, train_01, train_10, train_11]).sample(frac=1, random_state=123).reset_index(drop=True)

test_targetdata = pd.concat([test_00, test_01, test_10, test_11]).sample(frac=1, random_state=123).reset_index(drop=True)

# Calculate the proportion of events in the original data set, training set, and test set
total_event1_rate = targetdata['event1'].mean()
total_event2_rate = targetdata['event2'].mean()

train_event1_rate = train_targetdata['event1'].mean()
train_event2_rate = train_targetdata['event2'].mean()

test_event1_rate = test_targetdata['event1'].mean()
test_event2_rate = test_targetdata['event2'].mean()

print(f"Total event1 rate: {total_event1_rate}, Total event2 rate: {total_event2_rate}")
print(f"Train event1 rate: {train_event1_rate}, Train event2 rate: {train_event2_rate}")
print(f"Test event1 rate: {test_event1_rate}, Test event2 rate: {test_event2_rate}")

target_train = train_targetdata.drop(columns=['cancertype'])
target_train

Total event1 rate: 0.6674107142857143, Total event2 rate: 0.46875
Train event1 rate: 0.6675977653631285, Train event2 rate: 0.4692737430167598
Test event1 rate: 0.6666666666666666, Test event2 rate: 0.4666666666666667


Unnamed: 0,ABCA1,ABCA10,ABCA12,ABCA13,ABCA2,ABCA3,ABCA5,ABCA6,ABCA7,ABCA8,...,SFRP5,TBL1X,TBL1XR1,VANGL1,VANGL2,event2,T2,event1,T1,cancertype_num
0,7.38,2.66,1.67,7.03,7.52,6.14,6.69,5.18,8.21,9.49,...,0.70,6.39,10.85,10.18,9.66,1,1766,1,1441,31
1,8.65,2.04,11.32,4.00,12.06,9.74,7.96,6.31,9.97,8.34,...,5.44,9.96,9.34,9.92,10.68,0,1160,0,1160,31
2,8.91,1.11,0.00,4.24,10.96,10.57,7.73,0.66,9.95,1.73,...,0.00,9.69,10.43,10.00,8.15,0,490,0,490,31
3,9.06,2.71,7.97,3.01,10.76,9.63,6.59,4.80,9.67,7.25,...,5.87,9.26,9.09,10.13,10.67,1,472,0,472,31
4,10.10,3.94,0.00,8.44,10.67,10.82,9.59,7.78,8.87,13.13,...,0.73,9.31,11.02,10.18,5.21,0,2249,0,2249,31
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
353,7.63,2.42,0.00,4.75,9.42,5.68,7.30,3.63,6.24,4.78,...,1.11,9.55,11.64,7.82,9.54,0,3932,0,3932,31
354,9.88,3.04,6.62,3.37,11.89,8.92,9.11,4.08,8.51,4.36,...,3.16,9.05,10.56,9.87,10.40,0,405,0,405,31
355,13.10,1.89,0.00,5.11,10.07,11.60,8.52,3.27,8.85,0.00,...,0.00,9.43,10.73,10.33,8.84,1,468,1,335,31
356,10.29,7.52,0.49,0.85,11.58,9.13,9.23,11.74,10.16,11.03,...,1.93,8.71,10.94,8.69,9.03,0,3176,0,3176,31


In [27]:
# Covariate frame: everything except the outcome and cancer-type columns
# (its column names are the gene names standardized later on)
xx = relevantdata_train.drop(['event2', 'T2', 'event1', 'T1', 'cancertype_num'], axis=1)

In [28]:
# Sparse connection layer mask matrix
# The boolean mask is multiplied element-wise with the first layer's weights
# (see MaskedLinearLayer), so it presumably has shape
# (num_nodes_shared[0], number of covariates) - TODO confirm.
# NOTE(review): torch.load unpickles the file; only load mask.pt from a trusted source.
mask = torch.load("mask.pt")
print(mask)

tensor([[ True,  True,  True,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ...,  True,  True,  True]])


In [29]:
# Pre-training
np.random.seed(123)
_ = torch.manual_seed(123)

num_durations = 10
num_nodes_shared = (186, 100)
num_nodes_indiv = (50,)  # one hidden layer of 50 units per head; `(50)` was a bare int
dropout = 0.6
lr = 0.001
weight_decay = 0.01
batch_size = 100

df_train = relevantdata_train

# Covariate preprocessing: standardize every covariate column
cols_standardize = xx.columns.tolist()
standardize = [([col], StandardScaler()) for col in cols_standardize]
x_mapper = DataFrameMapper(standardize)
x_train = x_mapper.fit_transform(df_train).astype('float32')

# Discretize both time-to-event outcomes on their own equidistant grids
labtrans1 = LabTransform(num_durations, scheme='equidistant')
get_target1 = lambda df: (df['T1'].values, df['event1'].values)
T1_train = labtrans1.fit_transform(*get_target1(df_train))

labtrans2 = LabTransform(num_durations, scheme='equidistant')
get_target2 = lambda df: (df['T2'].values, df['event2'].values)
T2_train = labtrans2.fit_transform(*get_target2(df_train))

# Pack the targets into two flat codes:
#   durations: (T1 bin, T2 bin, cancer type) in base 32 - nll_pmf_cr decodes with the same base
#   events:    (event1, event2) in base 2, i.e. 2 * event1 + event2
T_train = (
    index_3d_to_1d(T1_train[0], T2_train[0], df_train['cancertype_num'].values, 32),
    index_2d_to_1d(T1_train[1], T2_train[1], 2),
)

in_features = x_train.shape[1]
num_T1 = num_durations
out_features = len(labtrans2.cuts)
batch_norm = True

net = pretrain_survival(in_features, num_nodes_shared, num_nodes_indiv, num_T1,
                        out_features, batch_norm, dropout, mask=mask)

optimizer = tt.optim.AdamWR(lr=lr, decoupled_weight_decay=weight_decay)
model = multi_task(net, optimizer, alpha=1, duration_index=labtrans1.cuts)
epochs = 500
callbacks = [tt.callbacks.EarlyStoppingCycle()]
verbose = False
warnings.simplefilter("ignore", UserWarning)
log = model.fit(x_train, T_train, batch_size, epochs, callbacks, verbose)



In [30]:
# Fine-tuning on small stratified samples of the target data, repeated n_repeats times
np.random.seed(123)
_ = torch.manual_seed(123)

target_train['event_combination'] = target_train['event1'].astype(str) + target_train['event2'].astype(str)

event_combination_rate_train = target_train['event_combination'].value_counts(normalize=True)

n_repeats = 25
n_samples = 20
concordance1 = []  # time-dependent concordance, one value per repeat
concordance2 = []  # integrated Brier score, one value per repeat (name kept for the cells below)

# Settings that do not change between repeats
lr = 0.001
batch_size = 20
epochs = 500
verbose = False
batch_norm = True
num_T1 = num_durations
cols_standardize = xx.columns.tolist()
get_target1 = lambda df: (df['T1'].values, df['event1'].values)
get_target2 = lambda df: (df['T2'].values, df['event2'].values)
pre_trained_model = model

# Grid of discrete T2 bins and the finer grid the survival curves are interpolated on
x = np.linspace(0, num_durations-1, num_durations)
xnew = np.linspace(0, num_durations-1, 1000)

for repeat in range(n_repeats):
    # --- draw n_samples rows, proportionally to the event-combination rates ---
    sampled_parts = []
    for combination, proportion in event_combination_rate_train.items():
        n_combination_samples = int(n_samples * proportion)
        if n_combination_samples > 0:
            combination_data = target_train[target_train['event_combination'] == combination]
            sampled_parts.append(combination_data.sample(n=n_combination_samples, random_state=repeat))
    # Concatenate once; avoids growing a frame from an empty DataFrame.
    sampled_data = pd.concat(sampled_parts) if sampled_parts else target_train.iloc[:0]

    # Rounding down can leave the sample short; top it up from the unused rows.
    if len(sampled_data) < n_samples:
        remaining_samples = target_train.drop(sampled_data.index)
        extra_samples = remaining_samples.sample(n=n_samples - len(sampled_data), random_state=repeat)
        sampled_data = pd.concat([sampled_data, extra_samples])

    df_train = sampled_data.drop(columns=['event_combination'])
    # Rebuilt every repeat so each repeat starts from an untouched copy of the test data.
    df_test = test_targetdata.drop(columns=['cancertype'])

    # Covariate preprocessing (scalers are re-fitted on the sampled rows)
    standardize = [([col], StandardScaler()) for col in cols_standardize]
    x_mapper = DataFrameMapper(standardize)
    x_train = x_mapper.fit_transform(df_train).astype('float32')
    x_test = x_mapper.transform(df_test).astype('float32')

    labtrans1 = LabTransform(num_durations, scheme='equidistant')
    T1_train = labtrans1.fit_transform(*get_target1(df_train))
    T1_test, event1_test = labtrans1.transform(*get_target1(df_test))

    labtrans2 = LabTransform(num_durations, scheme='equidistant')
    T2_train = labtrans2.fit_transform(*get_target2(df_train))
    # T2 of the test set stays on the original time scale for evaluation.
    T2_test, event2_test = get_target2(df_test)

    # Flat targets: base-32 (T1 bin, T2 bin, cancer type) and base-2 (event1, event2)
    T_train = (
        index_3d_to_1d(T1_train[0], T2_train[0], df_train['cancertype_num'].values, 32),
        index_2d_to_1d(T1_train[1], T2_train[1], 2),
    )

    in_features = x_train.shape[1]
    out_features = len(labtrans2.cuts)

    # Initialize the new model
    new_net = finetune_survival(in_features, num_nodes_shared, num_nodes_indiv, num_T1,
                                out_features, batch_norm, dropout, mask=mask)

    # Load pre-trained parameters of every part of the network
    new_net.first_layer.load_state_dict(pre_trained_model.net.first_layer.state_dict())
    new_net.shared_net.load_state_dict(pre_trained_model.net.shared_net.state_dict())
    new_net.risk_nets.load_state_dict(pre_trained_model.net.risk_nets.state_dict())

    new_optimizer = tt.optim.AdamWR(lr=lr, decoupled_weight_decay=weight_decay)
    new_model = multi_task(new_net, new_optimizer, alpha=1, duration_index=labtrans1.cuts)
    callbacks = [tt.callbacks.EarlyStoppingCycle()]
    new_log = new_model.fit(x_train, T_train, batch_size, epochs, callbacks, verbose)

    # --- evaluation on test rows with an observed first event ---
    index = np.arange(T1_test.size)
    cancertype_test = df_test['cancertype_num'].values
    has_event1 = event1_test == 1

    # Predict once per repeat: the prediction does not depend on the test
    # sample being interpolated, so it must not be recomputed inside the loop.
    pmf = new_model.predict_pmf_12(x_test)
    # One row per selected test sample: PMF over T2 bins at its own type and T1 bin.
    cond_pmf = pmf[cancertype_test[has_event1], T1_test[has_event1], :, index[has_event1]]
    # Normalize each row, accumulate over T2 bins and turn into survival; shape (T2 bins, samples).
    surv_steps = 1. - (cond_pmf.T / cond_pmf.sum(1)).cumsum(0)

    # Interpolating spline (s=0) through each sample's survival values
    ynew = [UnivariateSpline(x, surv_steps[:, k], s=0)(xnew)
            for k in range(T2_test[has_event1].size)]
    ynew_array = np.stack(ynew, axis=1)

    surv = pd.DataFrame(ynew_array, np.linspace(0, labtrans2.cuts.max(), 1000))

    ev = EvalSurv(surv, np.array(T2_test[has_event1]), np.array(event2_test)[has_event1], censor_surv='km')
    time_grid = np.linspace(T2_test[has_event1].min(), T2_test[has_event1].max(), 100)
    concordance1.append(ev.concordance_td())
    concordance2.append(ev.integrated_brier_score(time_grid))


In [31]:
# C-index for 25 trials: time-dependent concordance (ev.concordance_td()) of each fine-tuning repeat
print(concordance1)

[0.6393146979260595, 0.7029616724738676, 0.6140667267808837, 0.6519386834986475, 0.7006745362563238, 0.6829679595278246, 0.5871080139372822, 0.7125435540069687, 0.6236933797909407, 0.7343205574912892, 0.6672473867595818, 0.6019163763066202, 0.5566202090592335, 0.7015177065767285, 0.7047038327526133, 0.6114982578397212, 0.5915238954012624, 0.7334494773519163, 0.5725879170423805, 0.6681695220919748, 0.6652613827993255, 0.6358885017421603, 0.6123693379790941, 0.5942290351668169, 0.6677186654643823]


In [32]:
# IBS for 25 trials: integrated Brier score of each fine-tuning repeat
# (stored in `concordance2` despite the variable name)
print(concordance2)

[0.17206414063102243, 0.18644535354778613, 0.18791078112334406, 0.1708359194701949, 0.19545759708882432, 0.21502874986036014, 0.23174430792174508, 0.186680149982983, 0.23180193692770826, 0.19622272138130967, 0.18009815835956478, 0.1782852622505404, 0.17820039943306437, 0.1954066181951045, 0.24601041518563935, 0.1778422234390337, 0.1720017359886243, 0.19609022847665922, 0.23602510640481983, 0.16979478479856164, 0.21507292669549827, 0.17029463801976233, 0.17400340204490905, 0.17196736711775173, 0.17497449401196022]
