# Introduction to the first multi-task network

This notebook demonstrates how to use the first multi-task network on an example dataset (TCGA-BRCA).

## Import

In [14]:
# Import related packages
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torchtuples as tt
import torch.nn.functional as F

from sklearn.preprocessing import StandardScaler
from sklearn_pandas import DataFrameMapper
from pycox.models.utils import pad_col
from pycox.preprocessing.label_transforms import LabTransDiscreteTime
from pycox import models
from typing import Tuple
from torch import Tensor
from pycox.models import utils
from torchtuples import TupleTree
from scipy.interpolate import UnivariateSpline
 
import sys
sys.path.insert(0, '/')
from eval import EvalSurv


In [15]:
# Fix the NumPy and PyTorch seeds so that repeated runs are reproducible
SEED = 123456
np.random.seed(SEED)
_ = torch.manual_seed(SEED)

In [16]:
# Load the tab-separated TCGA-BRCA table; the first row holds the column names
file_path = 'BRCA.txt'
df_train = pd.read_csv(
    file_path,
    sep='\t',
    header=0,
)

# Preview the first rows
df_train.head()

Unnamed: 0,X100130426,X100133144,X100134869,X10357,X10431,X136542,X155060,X26823,X280660,X317712,...,ZXDC,ZYG11A,ZYG11B,ZYX,ZZEF1,ZZZ3,event2,T2,event1,T1
0,0.0,4.12,3.8,5.73,8.68,0,10.21,0.0,0.0,0.0,...,10.7,8.02,10.24,11.78,10.89,10.21,0,4047,1,1808
1,0.0,3.36,4.2,6.14,9.14,0,9.01,1.06,0.63,0.0,...,10.39,7.64,9.24,12.43,10.37,8.67,0,4005,0,4005
2,0.93,3.66,3.35,7.28,10.41,0,9.21,0.0,0.0,0.0,...,9.59,8.38,9.06,12.41,9.88,9.0,0,1474,0,1474
3,0.0,3.71,3.59,7.18,9.76,0,9.11,0.5,0.0,0.0,...,9.76,7.46,9.25,12.47,9.61,9.46,0,1448,0,1448
4,0.0,2.97,3.95,6.41,9.58,0,8.03,0.51,0.0,0.0,...,10.04,3.91,9.6,11.98,9.7,9.79,0,348,0,348


In [17]:
# Every column except the four outcome columns is treated as a covariate (gene);
# list their names
outcome_cols = ['event2', 'T2', 'event1', 'T1']
xx = df_train.drop(columns=outcome_cols)
xx.columns.tolist()

['X100130426',
 'X100133144',
 'X100134869',
 'X10357',
 'X10431',
 'X136542',
 'X155060',
 'X26823',
 'X280660',
 'X317712',
 'X340602',
 'X388795',
 'X390284',
 'X391343',
 'X391714',
 'X404770',
 'X441362',
 'X442388',
 'X553137',
 'X57714',
 'X645851',
 'X652919',
 'X653553',
 'X728045',
 'X728603',
 'X728788',
 'X729884',
 'X8225',
 'X90288',
 'A1BG',
 'A1CF',
 'A2BP1',
 'A2LD1',
 'A2M',
 'A2ML1',
 'A4GALT',
 'A4GNT',
 'AAA1',
 'AAAS',
 'AACS',
 'AACSL',
 'AADAC',
 'AADACL2',
 'AADACL3',
 'AADACL4',
 'AADAT',
 'AAGAB',
 'AAK1',
 'AAMP',
 'AANAT',
 'AARS',
 'AARS2',
 'AARSD1',
 'AASDH',
 'AASDHPPT',
 'AASS',
 'AATF',
 'AATK',
 'ABAT',
 'ABCA10',
 'ABCA1',
 'ABCA11P',
 'ABCA12',
 'ABCA13',
 'ABCA17P',
 'ABCA2',
 'ABCA3',
 'ABCA4',
 'ABCA5',
 'ABCA6',
 'ABCA7',
 'ABCA8',
 'ABCA9',
 'ABCB10',
 'ABCB11',
 'ABCB1',
 'ABCB4',
 'ABCB5',
 'ABCB6',
 'ABCB7',
 'ABCB8',
 'ABCB9',
 'ABCC10',
 'ABCC11',
 'ABCC12',
 'ABCC13',
 'ABCC1',
 'ABCC2',
 'ABCC3',
 'ABCC4',
 'ABCC5',
 'ABCC6',
 'ABCC6P1'

In [18]:
# Train/test/validation split
# Hold out 20% of all rows as the test set. No `random_state` is passed, so the
# draw presumably depends on the global NumPy seed set earlier -- TODO confirm.
df_test = df_train.sample(frac=0.2)
# NOTE: `df_train` is rebound here, so re-running this cell without reloading
# the data would split the already-reduced frame again.
df_train = df_train.drop(df_test.index)
# 20% of the remaining rows (16% of the full table) become the validation set;
# the training set keeps the rest.
df_val = df_train.sample(frac=0.2)
df_train = df_train.drop(df_val.index)

df_train

Unnamed: 0,X100130426,X100133144,X100134869,X10357,X10431,X136542,X155060,X26823,X280660,X317712,...,ZXDC,ZYG11A,ZYG11B,ZYX,ZZEF1,ZZZ3,event2,T2,event1,T1
0,0.00,4.12,3.80,5.73,8.68,0,10.21,0.00,0.00,0.0,...,10.70,8.02,10.24,11.78,10.89,10.21,0,4047,1,1808
1,0.00,3.36,4.20,6.14,9.14,0,9.01,1.06,0.63,0.0,...,10.39,7.64,9.24,12.43,10.37,8.67,0,4005,0,4005
2,0.93,3.66,3.35,7.28,10.41,0,9.21,0.00,0.00,0.0,...,9.59,8.38,9.06,12.41,9.88,9.00,0,1474,0,1474
5,0.00,2.32,3.87,6.85,9.66,0,8.12,0.00,0.00,0.0,...,10.13,4.07,9.29,12.01,9.85,10.05,0,1477,0,1477
7,0.00,1.30,3.32,6.76,10.47,0,6.27,0.61,0.00,0.0,...,9.66,1.63,8.33,11.65,10.12,8.79,0,303,0,303
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1074,0.00,2.71,3.94,7.01,9.45,0,9.63,0.56,0.00,0.0,...,9.89,6.83,9.06,12.44,9.92,9.77,0,347,0,347
1077,0.00,3.94,4.49,7.12,9.35,0,8.74,1.91,0.00,0.0,...,10.31,4.48,9.43,12.32,10.92,9.39,0,467,0,467
1078,0.00,4.54,4.82,6.03,9.50,0,8.56,0.56,0.00,0.0,...,10.43,7.37,9.55,12.42,10.49,9.90,0,488,0,488
1079,0.00,1.71,3.05,6.43,10.16,0,7.98,0.68,0.00,0.0,...,9.64,5.73,8.99,12.70,9.56,9.55,0,3287,1,181


In [19]:
# Covariate preprocessing: one StandardScaler per covariate column.
# The mapper is fitted on the training rows only and then re-used, unchanged,
# for the validation and test rows.
cols_standardize = xx.columns.tolist()
standardize = [([gene], StandardScaler()) for gene in cols_standardize]
x_mapper = DataFrameMapper(standardize)

x_train = x_mapper.fit_transform(df_train)
x_train = x_train.astype('float32')
x_val, x_test = (
    x_mapper.transform(frame).astype('float32') for frame in (df_val, df_test)
)

In [20]:
# Discretization of survival times
class LabTransform(LabTransDiscreteTime):
    """Discrete-time label transform that preserves the original event codes.

    The parent transform is fed a boolean indicator (`events > 0`); the event
    codes themselves are returned as int64, with every row that the parent
    reports as "no event" reset to 0.
    """

    def transform(self, durations, events):
        """Discretise `durations` and return them with the cleaned event codes.

        Args:
            durations: array of observed times.
            events: array of event codes (0 means censored).

        Returns:
            Tuple `(discretised durations, int64 event codes)`.
        """
        durations, is_event = super().transform(durations, events > 0)
        # Work on a private copy: `events` usually comes straight from
        # `df[...].values`, so writing into it in place could silently modify
        # the caller's DataFrame (or fail when that array is read-only).
        events = np.array(events, copy=True)
        events[is_event == 0] = 0
        return durations, events.astype('int64')
        
num_durations = 20


def get_target1(df):
    """Return the raw (duration, event indicator) arrays of the first outcome."""
    return df['T1'].values, df['event1'].values


def get_target2(df):
    """Return the raw (duration, event indicator) arrays of the second outcome."""
    return df['T2'].values, df['event2'].values


# First outcome: the grid is fitted on the training rows and re-used for val/test
labtrans1 = LabTransform(num_durations, scheme='equidistant')
T1_train = labtrans1.fit_transform(*get_target1(df_train))
T1_val = labtrans1.transform(*get_target1(df_val))
T1_test, event1_test = labtrans1.transform(*get_target1(df_test))

# Second outcome: same procedure with its own grid
labtrans2 = LabTransform(num_durations, scheme='equidistant')
T2_train = labtrans2.fit_transform(*get_target2(df_train))
T2_val = labtrans2.transform(*get_target2(df_val))
# The test targets of the second outcome are kept on the original time scale:
# at evaluation time the predicted curves are spline-interpolated to continuous
# time, so no discretisation is needed here.
T2_test, event2_test = get_target2(df_test)

In [21]:
# Package the data into the input format required by the network later
def index_to_1d(i, j, n):
    """
    Flatten the index pair (i, j) row-major into the single index i * n + j.

    For 0 <= j < n the pair can be recovered again as (k // n, k % n).
    """
    row_start = i * n
    return row_start + j

# Merge the two outcomes into one (durations, events) target tuple per split:
#   element 0: T1 index * num_durations + T2 index
#   element 1: event1 * 2 + event2 (a code in 0..3 for 0/1 indicators)
T1_train_list, T2_train_list = list(T1_train), list(T2_train)
T_train = (
    index_to_1d(T1_train_list[0], T2_train_list[0], num_durations),
    index_to_1d(T1_train_list[1], T2_train_list[1], 2),
)

T1_val_list, T2_val_list = list(T1_val), list(T2_val)
T_val = (
    index_to_1d(T1_val_list[0], T2_val_list[0], num_durations),
    index_to_1d(T1_val_list[1], T2_val_list[1], 2),
)
val = (x_val, T_val)

## Neural net

In [22]:
# Neural network architecture
class Multivariate_survival(torch.nn.Module):
    """Shared MLP trunk followed by `num_T1` separate MLP heads.

    Every head maps the trunk output to `out_features` values; the head outputs
    are stacked along dimension 1, so `forward` returns a tensor of shape
    (batch, num_T1, out_features).
    """

    def __init__(self, in_features, num_nodes_shared, num_nodes_indiv, num_T1,
                 out_features, batch_norm=True, dropout=None):
        super().__init__()
        # The last entry of `num_nodes_shared` is the trunk's output width,
        # the preceding entries are its hidden layers.
        hidden_sizes = num_nodes_shared[:-1]
        trunk_width = num_nodes_shared[-1]
        self.shared_net = tt.practical.MLPVanilla(
            in_features, hidden_sizes, trunk_width, batch_norm, dropout,
        )
        self.risk_nets = torch.nn.ModuleList([
            tt.practical.MLPVanilla(
                trunk_width, num_nodes_indiv, out_features, batch_norm, dropout,
            )
            for _ in range(num_T1)
        ])

    def forward(self, input):
        """Run the trunk once, then every head on the shared representation."""
        shared = self.shared_net(input)
        return torch.stack([head(shared) for head in self.risk_nets], dim=1)

In [23]:
# Network hyper-parameters
in_features = x_train.shape[1]        # one input per covariate column
num_nodes_shared = [64, 64]           # shared trunk; last entry is its output width
num_nodes_indiv = [32, 32]            # hidden layers of every head
num_T1 = num_durations                # one head per discrete T1 bin
out_features = len(labtrans2.cuts)    # one output per T2 grid point
batch_norm = True
dropout = 0.7

net = Multivariate_survival(
    in_features=in_features,
    num_nodes_shared=num_nodes_shared,
    num_nodes_indiv=num_nodes_indiv,
    num_T1=num_T1,
    out_features=out_features,
    batch_norm=batch_norm,
    dropout=dropout,
)

## Training



In [24]:
class first_multi_task(tt.Model):
    """torchtuples model wrapper around the bivariate discrete-time network.

    The wrapped network is expected to return logits of shape
    (batch, T1 bins, T2 bins). All `predict_pmf_*` helpers work on the joint
    pmf returned by `predict_pmf_12`, which is laid out as
    [T1 bin, T2 bin, sample].

    Args:
        net: the torch module to train.
        optimizer: torchtuples optimizer (passed on to `tt.Model`).
        device: device specification (passed on to `tt.Model`).
        alpha, sigma: forwarded to `Loss1` when `loss` is None.
        duration_index: time grid stored on the model and used as the index
            of the frame built by `predict_surv_df`.
        loss: custom loss; defaults to `Loss1(alpha, sigma)`.
    """

    def __init__(self, net, optimizer=None, device=None, alpha=0.2, sigma=0.1, duration_index=None, loss=None):
        self.duration_index = duration_index
        if loss is None:
            loss = Loss1(alpha, sigma)
        super().__init__(net, loss, optimizer, device)

    @property
    def duration_index(self):
        """Time grid associated with the model's predictions."""
        return self._duration_index

    @duration_index.setter
    def duration_index(self, val):
        self._duration_index = val

    def make_dataloader(self, data, batch_size, shuffle, num_workers=0):
        """Training/validation dataloader built on pycox's `DeepHitDataset`."""
        # presumably DeepHitDataset supplies the `rank_mat` argument that the
        # loss's `forward` receives -- confirm against pycox
        dataloader = super().make_dataloader(data, batch_size, shuffle, num_workers,
                                             make_dataset=models.data.DeepHitDataset)
        return dataloader

    def make_dataloader_predict(self, input, batch_size, shuffle=False, num_workers=0):
        """Prediction dataloader; uses the default dataset of `tt.Model`."""
        dataloader = super().make_dataloader(input, batch_size, shuffle, num_workers)
        return dataloader

    def predict_surv_df(self, input, batch_size=8224, eval_=True, num_workers=0):
        """Wrap the output of `predict_pmf_1_cif_2` in a DataFrame indexed by `duration_index`."""
        # NOTE(review): `predict_pmf_1_cif_2` returns a 3-D [T1, T2, sample]
        # array, which `pd.DataFrame` presumably cannot wrap. The intended 2-D
        # quantity is not clear from this file, so the call is left as is --
        # confirm the intent before relying on this method.
        surv = self.predict_pmf_1_cif_2(input, batch_size, True, eval_, True, num_workers)
        return pd.DataFrame(surv, self.duration_index)

    def predict_surv_2_condpmf_1(self, input, batch_size=8224, numpy=None, eval_=True,
                     to_cpu=False, num_workers=0):
        """Survival function of T2 conditional on the value of T1.

        Returns `1 - P(T1 = t1, T2 <= t2) / P(T1 = t1)` laid out as
        [T1 bin, T2 bin, sample].
        """
        # One forward pass is enough: both terms derive from the joint pmf.
        pmf12 = self.predict_pmf_12(input, batch_size, False, eval_, to_cpu, num_workers)
        cif = pmf12.cumsum(1)
        # keepdim=True keeps the marginal as (T1, 1, sample) so that it divides
        # along the T1 axis. A plain (T1, sample) tensor would be broadcast
        # against the T2 axis instead and give wrong values.
        pmf1 = pmf12.sum(1, keepdim=True)
        condsurv = 1. - cif / pmf1
        return tt.utils.array_or_tensor(condsurv, numpy, input)

    def predict_pmf_1_cif_2(self, input, batch_size=8224, numpy=None, eval_=True,
                     to_cpu=False, num_workers=0):
        """Joint pmf accumulated along the T2 axis: P(T1 = t1, T2 <= t2)."""
        pmf = self.predict_pmf_12(input, batch_size, False, eval_, to_cpu, num_workers)
        cif = pmf.cumsum(1)
        return tt.utils.array_or_tensor(cif, numpy, input)

    def predict_pmf_1(self, input, batch_size=8224, numpy=None, eval_=True,
                     to_cpu=False, num_workers=0):
        """Marginal pmf of T1 (joint pmf summed over the T2 axis), shape [T1 bin, sample]."""
        pmf12 = self.predict_pmf_12(input, batch_size, False, eval_, to_cpu, num_workers)
        pmf = pmf12.sum(1)
        return tt.utils.array_or_tensor(pmf, numpy, input)

    def predict_pmf_12(self, input, batch_size=8224, numpy=None, eval_=True,
                     to_cpu=False, num_workers=0):
        """Joint pmf of (T1, T2), laid out as [T1 bin, T2 bin, sample]."""
        preds = self.predict(input, batch_size, False, eval_, False, to_cpu, num_workers)
        # Softmax over all (T1, T2) cells plus one padded column, which is
        # dropped again afterwards.
        pmf = pad_col(preds.view(preds.size(0), -1)).softmax(1)[:, :-1]
        # (sample, T1, T2) -> (T1, sample, T2) -> (T1, T2, sample)
        pmf = pmf.view(preds.shape).transpose(0, 1).transpose(1, 2)
        return tt.utils.array_or_tensor(pmf, numpy, input)

def _reduction(loss: Tensor, reduction: str = 'mean') -> Tensor:
    if reduction == 'none':
        return loss
    elif reduction == 'mean':
        return loss.mean()
    elif reduction == 'sum':
        return loss.sum()
    raise ValueError(f"`reduction` = {reduction} is not valid. Use 'none', 'mean' or 'sum'.")

def _diff_cdf_at_time_i(pmf: Tensor, y: Tensor) -> Tensor:

    n = pmf.shape[0]
    ones = torch.ones((n, 1), device=pmf.device)
    r = pmf.cumsum(1).matmul(y.transpose(0, 1))
    diag_r = r.diag().view(1, -1)
    r = ones.matmul(diag_r) - r
    return r.transpose(0, 1)

def _rank_loss_deephit(pmf: Tensor, y: Tensor, rank_mat: Tensor, sigma: float,
                       reduction: str = 'mean') -> Tensor:
    """Ranking loss built from the pairwise CDF differences.

    Each difference r from `_diff_cdf_at_time_i` is turned into exp(-r / sigma)
    and weighted by `rank_mat`; the result is averaged over dimension 1 (kept
    as a column) and then reduced with `_reduction`.
    """
    cdf_diff = _diff_cdf_at_time_i(pmf, y)
    penalty = torch.exp(-cdf_diff / sigma)
    per_sample = (rank_mat * penalty).mean(1, keepdim=True)
    return _reduction(per_sample, reduction)

def index_from_1d(k, n):
    """Split the flat index k into the pair (k // n, k % n)."""
    return (k // n, k % n)
    
def nll_pmf_cr(phi: Tensor, idx_durations: Tensor, events: Tensor, reduction: str = 'mean',
               epsilon: float = 1e-7) -> Tensor:
    """Negative log-likelihood of the joint (T1, T2) pmf under censoring of either time.

    Args:
        phi: network output, indexed as [sample, T1 bin, T2 bin].
        idx_durations: flat time index per sample; it is decoded with
            `index_from_1d(..., num_durations)` into (T1 bin, T2 bin).
            NOTE: `num_durations` is read from the enclosing (notebook) scope,
            so it must match the value used when the index was encoded.
        events: event code per sample, 0..3. With the encoding
            `event1 * 2 + event2` used when the targets are packaged:
            3 = both events observed, 2 = only event 1 observed,
            1 = only event 2 observed, 0 = both censored.
        reduction: 'none', 'mean' or 'sum' (see `_reduction`).
        epsilon: added before every log for numerical stability.

    Returns:
        The reduced negative log-likelihood.
    """

    events = events.view(-1)
    # One 0/1 mask per event code; each mask switches on exactly one likelihood term.
    event_00 = (events == 0).float()
    event_01 = (events == 1).float()
    event_02 = (events == 2).float()
    event_03 = (events == 3).float()
    
    idx_durations1, idx_durations2 = index_from_1d(idx_durations.view(-1), num_durations)
    batch_size = phi.size(0)
    # Softmax over all (T1, T2) cells plus one extra padded column, which is
    # dropped again; presumably the pad reserves probability mass for outcomes
    # outside the grid -- confirm against pycox `pad_col`.
    sm = utils.pad_col(phi.view(batch_size, -1)).softmax(1)[:, :-1].view(phi.shape)
    index = torch.arange(batch_size)
    # part1 (code 3): mass of the observed cell, sm[t1, t2].
    part1 = sm[index, idx_durations1, idx_durations2].relu().add(epsilon).log().mul(event_03)
    # part2 (code 2): mass of row t1 strictly beyond t2, i.e. T1 = t1 and T2 > t2.
    part2 = (sm[index, idx_durations1, :].sum(1) - sm.cumsum(2)[index, idx_durations1, idx_durations2]).relu().add(epsilon).log().mul(event_02)
    # part3 (code 1): mass of column t2 strictly beyond t1, i.e. T2 = t2 and T1 > t1.
    part3 = (sm[index, :, idx_durations2].sum(1) - sm.cumsum(1)[index, idx_durations1, idx_durations2]).relu().add(epsilon).log().mul(event_01)
    # part4 (code 0): 1 - P(T2 <= t2) + P(T1 <= t1, T2 <= t2) - P(T1 <= t1),
    # i.e. the mass with both times beyond the censoring point (inclusion-exclusion).
    part4 = (1 - sm.cumsum(2)[index, :, idx_durations2].sum(1) + sm.cumsum(1).cumsum(2)[index, idx_durations1, idx_durations2] - sm.cumsum(1)[index, idx_durations1, :].sum(1)).relu().add(epsilon).log().mul(event_00)     
     
    # Exactly one part is non-zero per sample; negate the summed log-likelihood.
    loss = - part1.add(part2).add(part3).add(part4)
    return _reduction(loss, reduction)



def rank_loss_deephit_cr(phi: Tensor, idx_durations: Tensor, events: Tensor, rank_mat: Tensor,
                         sigma: float, reduction: str = 'mean') -> Tensor:
    """Ranking loss summed over four slices of axis 1 of `phi`.

    NOTE(review): this function is not called by `Loss1.forward`, which only
    uses `nll_pmf_cr`. It looks like an adaptation of a competing-risks ranking
    loss: axis 1 of `phi` is sliced with i in range(4) and paired with the
    event code i, and `idx_durations` is used directly as an index into the
    last axis. If callers pass the flattened (T1, T2) index that `nll_pmf_cr`
    expects, that index may be out of range here -- verify before using.

    Args:
        phi: network output with three dimensions (sample, axis 1, axis 2).
        idx_durations: index into the last axis of `phi`, one per sample.
        events: event code per sample.
        rank_mat: pairwise weight matrix passed on to `_rank_loss_deephit`.
        sigma: scale passed on to `_rank_loss_deephit`.
        reduction: 'none', 'mean' or 'sum'.
    """

    idx_durations = idx_durations.view(-1)
    events = events.view(-1)
    # NOTE(review): the four masks below are never read in this function.
    event_00 = (events == 0).float()
    event_01 = (events == 1).float()
    event_02 = (events == 2).float()
    event_03 = (events == 3).float()

    batch_size = phi.size(0)
    # Softmax over all cells plus one padded column, which is dropped again.
    pmf = utils.pad_col(phi.view(batch_size, -1)).softmax(1)[:, :-1].view(phi.shape)
    # y marks, for every sample and every slice of axis 1, the position `idx_durations`.
    y = torch.zeros_like(pmf)
    y[torch.arange(batch_size), :, idx_durations] = 1.

    loss = []
    for i in range(4):
        rank_loss_i = _rank_loss_deephit(pmf[:, i, :], y[:, i, :], rank_mat, sigma, 'none')
        # Keep the per-sample loss of slice i only for samples whose event code is i.
        loss.append(rank_loss_i.view(-1) * (events == i).float())

    if reduction == 'none':
        return sum(loss)
    elif reduction == 'mean':
        return sum([lo.mean() for lo in loss])
    elif reduction == 'sum':
        return sum([lo.sum() for lo in loss])
    # Only reached for an unsupported `reduction`; `_reduction` then raises ValueError.
    return _reduction(loss, reduction)

class _Loss(torch.nn.Module):

    def __init__(self, reduction: str = 'mean') -> None:
        super().__init__()
        self.reduction = reduction

class _Loss1(_Loss):
    """Loss base class holding the validated hyper-parameters `alpha` and `sigma`.

    `alpha` must lie in [0, 1] and `sigma` must be strictly positive; both are
    checked on every assignment.
    """

    def __init__(self, alpha: float, sigma: float, reduction: str = 'mean') -> None:
        super(_Loss1, self).__init__(reduction)
        # Both assignments go through the validating setters below.
        self.alpha = alpha
        self.sigma = sigma

    @property
    def alpha(self) -> float:
        return self._alpha

    @alpha.setter
    def alpha(self, value: float) -> None:
        out_of_range = (value < 0) or (value > 1)
        if out_of_range:
            raise ValueError(f"Need `alpha` to be in [0, 1]. Got {value}.")
        self._alpha = value

    @property
    def sigma(self) -> float:
        return self._sigma

    @sigma.setter
    def sigma(self, value: float) -> None:
        if value <= 0:
            raise ValueError(f"Need `sigma` to be positive. Got {value}.")
        self._sigma = value

class Loss1(_Loss1):
    """Training loss of the first multi-task network: the likelihood term only."""

    def forward(self, phi: Tensor, idx_durations: Tensor, events: Tensor, rank_mat: Tensor) -> Tensor:
        # `rank_mat` (and the stored `alpha` / `sigma`) are accepted but not
        # used: the loss is the negative log-likelihood alone.
        return nll_pmf_cr(phi, idx_durations, events, reduction=self.reduction)

In [25]:
# Training configuration
batch_size = 128
epochs = 512
verbose = False
callbacks = [tt.callbacks.EarlyStoppingCycle()]

optimizer = tt.optim.AdamWR(
    lr=0.8,
    decoupled_weight_decay=0.01,
    cycle_eta_multiplier=0.6,
)
model = first_multi_task(net, optimizer, alpha=1, duration_index=labtrans1.cuts)

# The learning rate of the AdamWR optimizer is chosen with the learning-rate
# finder (the method proposed by Smith) before the actual training starts.
lrfind = model.lr_finder(x_train, T_train, batch_size, tolerance=50)
model.optimizer.set_lr(lrfind.get_best_lr())

log = model.fit(x_train, T_train, batch_size, epochs, callbacks, verbose, val_data=val)

## Prediction and Evaluation

In [26]:
# Spline interpolation and evaluation of predictive performance.
# The joint pmf of the test set is predicted ONCE and re-used below; it was
# previously recomputed several times for every single test sample.
pmf12_test = model.predict_pmf_12(x_test)   # laid out as [T1 bin, T2 bin, test sample]
index = np.arange(T1_test.size)             # position of each test sample

x = np.linspace(0, num_durations-1, num_durations)   # discrete T2 grid (bin numbers)
xnew = np.linspace(0, num_durations-1, 10000)        # fine grid the splines are evaluated on
surv_times = np.linspace(0, labtrans2.cuts.max(), 10000)  # the same fine grid in original time units


def _interpolated_surv_df(surv_on_grid):
    """Interpolate discrete survival curves onto the fine time grid.

    `surv_on_grid` holds one curve per column, evaluated on `x`. Every curve is
    fitted with an interpolating spline (s=0) and evaluated on `xnew`; the
    result is a DataFrame indexed by `surv_times` with one column per sample.
    """
    curves = [UnivariateSpline(x, surv_on_grid[:, i], s=0)(xnew)
              for i in range(surv_on_grid.shape[1])]
    return pd.DataFrame(np.stack(curves, axis=1), surv_times)


observed1 = event1_test == 1
censored1 = event1_test == 0
n_observed1 = T2_test[observed1].size
n_censored1 = T2_test[censored1].size

# Samples whose first event was observed at T1 = t1:
#   S(t2) = 1 - sum_{s <= t2} P(T1 = t1, T2 = s) / P(T1 = t1)
joint_at_t1 = pmf12_test[T1_test[observed1], :, index[observed1]]       # (sample, T2 bin)
surv_observed1 = 1. - (joint_at_t1.T / joint_at_t1.sum(1)).cumsum(0)     # (T2 bin, sample)
surv1 = _interpolated_surv_df(surv_observed1)

ev1 = EvalSurv(surv1, np.array(T2_test[observed1]), np.array(event2_test)[observed1], censor_surv='km')

# Samples whose first event was censored at t1:
#   S(t2) = 1 - sum_{s <= t2} P(T1 > t1, T2 = s) / P(T1 > t1)
marginal_t2 = pmf12_test.sum(0)[:, censored1]                                  # (T2 bin, sample)
cum_upto_t1 = pmf12_test.cumsum(0)[T1_test[censored1], :, index[censored1]]    # (sample, T2 bin)
surv_censored1 = 1 - ((marginal_t2 - cum_upto_t1.T) / (1 - cum_upto_t1.sum(1))).cumsum(0)
surv2 = _interpolated_surv_df(surv_censored1)

ev2 = EvalSurv(surv2, np.array(T2_test[censored1]), np.array(event2_test)[censored1], censor_surv='km')

time_grid = np.linspace(T2_test.min(), T2_test.max(), 100)

# Both metrics are averaged over the two groups, weighted by group size.
C_index = (ev1.concordance_td()*n_observed1 + ev2.concordance_td()*n_censored1)/T2_test.size

IBS = (ev1.integrated_brier_score(time_grid)*n_observed1 + ev2.integrated_brier_score(time_grid)*n_censored1)/T2_test.size

print("C-index:", C_index)
print("IBS:", IBS)

C-index: 0.7523214654098196
IBS: 0.12050520406432358
