# Introduction to the second multi-task network

In this notebook we introduce the use of the second multi-task network through an example dataset (TCGA-BRCA).

## Import

In [15]:
# Import related packages
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torchtuples as tt
import torch.nn.functional as F
import torch.nn as nn
import math

from sklearn.preprocessing import StandardScaler
from sklearn_pandas import DataFrameMapper
from pycox.models.utils import pad_col
from pycox.preprocessing.label_transforms import LabTransDiscreteTime
from pycox import models
from typing import Tuple
from torch import Tensor
from pycox.models import utils
from torchtuples import TupleTree
from scipy.interpolate import UnivariateSpline

import sys
sys.path.insert(0, '/')
from eval import EvalSurv


In [16]:
# set some seeds to make this reproducable
np.random.seed(123456)
_ = torch.manual_seed(123456)

In [17]:
# Import data
file_path = 'BRCA.txt'

df = pd.read_csv(file_path, sep='\t', header=0)

# Leaving only the genetic data associated with the 186KEGG pathway
def parse_kegg_file(file_path):
    kegg_data = {}
    with open(file_path, 'r') as file:
        for line in file:
            if line.startswith('KEGG'):
                parts = line.strip().split('\t')
                pathway_name = parts[0]
                genes = parts[2:] if len(parts) > 2 else []
                kegg_data[pathway_name] = genes
    return kegg_data

file_path = 'kegg_legacy.txt'
kegg_pathways = parse_kegg_file(file_path)

all_genes = set()

for genes in kegg_pathways.values():
    all_genes.update(genes)

genes_columns = [gene for gene in all_genes if gene in df.columns]

filtered_df_train = df[genes_columns]

filtered_kegg_pathways = {}
for pathway, genes in kegg_pathways.items():
    filtered_genes = [gene for gene in genes if gene in genes_columns]
    filtered_kegg_pathways[pathway] = filtered_genes

concatenated_df = pd.concat([filtered_df_train, df.iloc[:,-4:]], axis=1)

df_train = concatenated_df

df_train.head()

Unnamed: 0,COX5B,GLA,COL4A1,DNAL1,EHD4,EFNA1,XYLT1,TRA2B,NDUFA3,FLT1,...,AGL,P2RY13,IFNK,CALR,SNRPE,PVR,event2,T2,event1,T1
0,9.85,10.3,13.51,8.84,9.43,10.75,8.42,11.32,7.79,9.84,...,11.05,5.23,0.0,15.15,9.45,9.86,0,4047,1,1808
1,11.87,8.59,13.6,7.53,9.79,10.31,7.11,11.44,10.27,9.63,...,9.02,5.85,0.0,14.13,9.4,8.22,0,4005,0,4005
2,11.36,9.26,14.11,7.55,10.34,11.67,8.44,11.27,9.74,9.83,...,11.96,6.24,0.0,14.6,10.27,9.52,0,1474,0,1474
3,10.9,13.71,13.19,8.65,10.73,11.42,8.83,11.25,10.05,9.68,...,9.09,6.04,0.0,14.27,10.13,7.84,0,1448,0,1448
4,10.24,8.47,12.56,8.64,10.49,10.53,8.37,11.11,9.57,9.39,...,11.02,6.27,0.0,14.37,9.48,7.92,0,348,0,348


In [18]:
# Sparse connection layer mask matrix
num_pathways = len(filtered_kegg_pathways)
num_genes = len(genes_columns)

mask = torch.zeros(num_pathways, num_genes, dtype=torch.bool)

gene_to_index = {gene: idx for idx, gene in enumerate(genes_columns)}

for pathway_idx, (pathway, genes) in enumerate(filtered_kegg_pathways.items()):
    for gene in genes:
        if gene in gene_to_index:
            gene_idx = gene_to_index[gene]
            mask[pathway_idx, gene_idx] = 1  


In [19]:
# View all gene names
xx = df_train.drop(columns=['event2','T2','event1','T1'])
xx.columns.tolist()

['COX5B',
 'GLA',
 'COL4A1',
 'DNAL1',
 'EHD4',
 'EFNA1',
 'XYLT1',
 'TRA2B',
 'NDUFA3',
 'FLT1',
 'L1CAM',
 'TREH',
 'F11R',
 'CDC7',
 'OR5M3',
 'UGT8',
 'TNFSF18',
 'ACACA',
 'ENTPD1',
 'ACIN1',
 'OR5P3',
 'GSTA3',
 'CLCA4',
 'PPT1',
 'CTLA4',
 'B3GALT2',
 'ADSL',
 'MTHFS',
 'BMP5',
 'MAN2B1',
 'OR5AK2',
 'STAM',
 'LTB4R',
 'HNRNPA1',
 'ACTN4',
 'TAF10',
 'MOS',
 'AP1S3',
 'ATP6V1B2',
 'TBP',
 'GPX3',
 'OR6M1',
 'FSHB',
 'FLT4',
 'MGAT4A',
 'IL13RA1',
 'GPX5',
 'SPAM1',
 'HCST',
 'TNFRSF10B',
 'FGF17',
 'RPS6KA2',
 'MMP2',
 'ITGB2',
 'NADK',
 'NDUFA4L2',
 'FBXW8',
 'ATP6V1E2',
 'C7',
 'MANBA',
 'RCOR1',
 'BMP4',
 'GNG7',
 'TAS2R31',
 'ACADM',
 'COX6A2',
 'FPGS',
 'DBI',
 'CREBBP',
 'RFK',
 'OSMR',
 'TAT',
 'TAS2R7',
 'OR6C65',
 'SGSH',
 'APH1A',
 'ARAP1',
 'PGP',
 'HPSE',
 'OR2M5',
 'OR52B4',
 'VAMP2',
 'DUSP16',
 'ITPA',
 'NCK2',
 'TLN1',
 'PSMF1',
 'RPS10',
 'SLC18A2',
 'ZFYVE16',
 'CANX',
 'OR56A1',
 'PDE4A',
 'CPT1B',
 'UBE2Q2',
 'GPI',
 'OR4F15',
 'UBE2J2',
 'E2F4',
 'ADCY7',
 '

In [20]:
# Train/test/validation split
df_test = df_train.sample(frac=0.2)
df_train = df_train.drop(df_test.index)
df_val = df_train.sample(frac=0.2)
df_train = df_train.drop(df_val.index)

df_train

Unnamed: 0,COX5B,GLA,COL4A1,DNAL1,EHD4,EFNA1,XYLT1,TRA2B,NDUFA3,FLT1,...,AGL,P2RY13,IFNK,CALR,SNRPE,PVR,event2,T2,event1,T1
0,9.85,10.30,13.51,8.84,9.43,10.75,8.42,11.32,7.79,9.84,...,11.05,5.23,0.00,15.15,9.45,9.86,0,4047,1,1808
1,11.87,8.59,13.60,7.53,9.79,10.31,7.11,11.44,10.27,9.63,...,9.02,5.85,0.00,14.13,9.40,8.22,0,4005,0,4005
2,11.36,9.26,14.11,7.55,10.34,11.67,8.44,11.27,9.74,9.83,...,11.96,6.24,0.00,14.60,10.27,9.52,0,1474,0,1474
5,10.84,9.76,12.59,8.63,10.35,11.83,8.96,11.33,9.73,9.33,...,12.05,7.32,0.00,14.22,10.20,8.03,0,1477,0,1477
7,10.57,8.41,11.27,8.20,9.46,10.06,6.11,11.38,9.96,8.26,...,9.07,2.22,0.00,14.54,10.56,10.26,0,303,0,303
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1074,10.75,9.18,11.59,8.35,9.90,11.95,7.59,11.05,9.91,9.24,...,9.56,6.27,0.56,13.94,10.13,8.69,0,347,0,347
1077,10.30,8.14,13.49,8.38,10.61,10.64,9.51,11.12,8.35,10.08,...,9.81,6.94,0.00,13.57,9.37,9.44,0,467,0,467
1078,10.26,8.47,13.67,8.86,10.80,11.37,9.81,11.12,9.24,10.19,...,10.21,7.53,0.00,14.24,9.29,8.63,0,488,0,488
1079,11.39,10.67,12.45,8.20,10.48,12.92,7.75,11.46,10.73,9.28,...,9.31,7.09,0.00,14.27,10.75,9.11,0,3287,1,181


In [21]:
# Covariate preprocessing
cols_standardize =  xx.columns.tolist()

standardize = [([col], StandardScaler()) for col in cols_standardize]

x_mapper = DataFrameMapper(standardize)

x_train = x_mapper.fit_transform(df_train).astype('float32')
x_val = x_mapper.transform(df_val).astype('float32')
x_test = x_mapper.transform(df_test).astype('float32')

In [22]:
# Discretization of survival times
class LabTransform(LabTransDiscreteTime):
    def transform(self, durations, events):
        durations, is_event = super().transform(durations, events > 0)
        events[is_event == 0] = 0
        return durations, events.astype('int64')
        
num_durations = 10

labtrans1 = LabTransform(num_durations, scheme='equidistant')
get_target1 = lambda df: (df['T1'].values, df['event1'].values)

T1_train = labtrans1.fit_transform(*get_target1(df_train))
T1_val = labtrans1.transform(*get_target1(df_val))
T1_test, event1_test = labtrans1.transform(*get_target1(df_test))

labtrans2 = LabTransform(num_durations, scheme='equidistant')
get_target2 = lambda df: (df['T2'].values, df['event2'].values)

T2_train = labtrans2.fit_transform(*get_target2(df_train))
T2_val = labtrans2.transform(*get_target2(df_val))
# Discretization is not required because the prediction time is already a continuous value after spline interpolation when evaluated on the test set
T2_test, event2_test = get_target2(df_test)

In [23]:
# Package the data into the input format required by the network later
def index_to_1d(i, j, n):
    """
    Converts a 2D index (i, j) from sets T1 and T2, each with elements ranging from 1 to n, to a 1D index.
    """
    return i * n + j

T_train = list(T1_train)
T1_train_list = list(T1_train)
T2_train_list = list(T2_train)
T_train[0] = index_to_1d(T1_train_list[0], T2_train_list[0], num_durations)
T_train[1] = index_to_1d(T1_train_list[1], T2_train_list[1], 2)
T_train = tuple(T_train)

T_val = list(T1_val)
T1_val_list = list(T1_val)
T2_val_list = list(T2_val)
T_val[0] = index_to_1d(T1_val_list[0], T2_val_list[0], num_durations)
T_val[1] = index_to_1d(T1_val_list[1], T2_val_list[1], 2)
T_val = tuple(T_val)
val = (x_val, T_val)

## Neural net

In [24]:
# Neural network architecture
class MaskedLinearLayer(nn.Module):
    def __init__(self, in_features, out_features, bias=True, mask=None):
        super(MaskedLinearLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if bias:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)
        self.mask = mask if mask is not None else torch.ones(out_features, in_features)

    def forward(self, input):
        masked_weight = self.weight * self.mask
        return nn.functional.linear(input, masked_weight, self.bias)

class Multivariate_survival(torch.nn.Module):

    def __init__(self, in_features, num_nodes_shared, num_nodes_indiv, num_T1,
                 out_features, batch_norm=True, dropout=None, mask=None):
        super().__init__()
        self.first_layer = MaskedLinearLayer(in_features, num_nodes_shared[0], mask=mask)
        self.shared_net = tt.practical.MLPVanilla(
            num_nodes_shared[0], num_nodes_shared[1:-1], num_nodes_shared[-1],
            batch_norm, dropout,
        )
        self.risk_nets = torch.nn.ModuleList()
        for _ in range(num_T1):
            net = tt.practical.MLPVanilla(
                num_nodes_shared[-1], num_nodes_indiv, out_features,
                batch_norm, dropout,
            )
            self.risk_nets.append(net)

    def forward(self, input):
        out = self.first_layer(input)
        out = self.shared_net(out)
        out = [net(out) for net in self.risk_nets]
        out = torch.stack(out, dim=1)
        return out



In [25]:
in_features = x_train.shape[1]
num_nodes_shared = [num_pathways, 32, 32]
num_nodes_indiv = [32, 32]
num_T1 = num_durations 
out_features = len(labtrans2.cuts)
batch_norm = True
dropout = 0.7

net = Multivariate_survival(in_features, num_nodes_shared, num_nodes_indiv, num_T1,
                       out_features, batch_norm, dropout, mask=mask)

## Training



In [26]:
class second_multi_task(tt.Model):

    def __init__(self, net, optimizer=None, device=None, alpha=0.2, sigma=0.1, duration_index=None, loss=None):
        self.duration_index = duration_index
        if loss is None:
            loss = Loss1(alpha, sigma)
        super().__init__(net, loss, optimizer, device)

    @property
    def duration_index(self):
        
        return self._duration_index

    @duration_index.setter
    def duration_index(self, val):
        self._duration_index = val

    def make_dataloader(self, data, batch_size, shuffle, num_workers=0):
        dataloader = super().make_dataloader(data, batch_size, shuffle, num_workers,
                                             make_dataset=models.data.DeepHitDataset)
        return dataloader
    
    def make_dataloader_predict(self, input, batch_size, shuffle=False, num_workers=0):
        dataloader = super().make_dataloader(input, batch_size, shuffle, num_workers)
        return dataloader

    def predict_surv_df(self, input, batch_size=8224, eval_=True, num_workers=0):

        surv = self.predict_pmf_1_cif_2(input, batch_size, True, eval_, True, num_workers)
        return pd.DataFrame(surv, self.duration_index)

    def predict_surv_2_condpmf_1(self, input, batch_size=8224, numpy=None, eval_=True,
                     to_cpu=False, num_workers=0):
 
        cif = self.predict_pmf_1_cif_2(input, batch_size, False, eval_, to_cpu, num_workers)
        pmf = self.predict_pmf_1(input, batch_size, False, eval_, to_cpu, num_workers)
        condsurv = 1. - cif/pmf
        return tt.utils.array_or_tensor(condsurv, numpy, input)
        
    def predict_pmf_1_cif_2(self, input, batch_size=8224, numpy=None, eval_=True,
                     to_cpu=False, num_workers=0):
 
        pmf = self.predict_pmf_12(input, batch_size, False, eval_, to_cpu, num_workers)
        cif = pmf.cumsum(1)
        return tt.utils.array_or_tensor(cif, numpy, input) 

    def predict_pmf_1(self, input, batch_size=8224, numpy=None, eval_=True,
                     to_cpu=False, num_workers=0):
 
        pmf12 = self.predict_pmf_12(input, batch_size, False, eval_, to_cpu, num_workers)
        pmf = pmf12.sum(1)
        return tt.utils.array_or_tensor(pmf, numpy, input)
    
    def predict_pmf_12(self, input, batch_size=8224, numpy=None, eval_=True,
                     to_cpu=False, num_workers=0):
 
        preds = self.predict(input, batch_size, False, eval_, False, to_cpu, num_workers)
        pmf = pad_col(preds.view(preds.size(0), -1)).softmax(1)[:, :-1]
        pmf = pmf.view(preds.shape).transpose(0, 1).transpose(1, 2)
        return tt.utils.array_or_tensor(pmf, numpy, input)

def _reduction(loss: Tensor, reduction: str = 'mean') -> Tensor:
    if reduction == 'none':
        return loss
    elif reduction == 'mean':
        return loss.mean()
    elif reduction == 'sum':
        return loss.sum()
    raise ValueError(f"`reduction` = {reduction} is not valid. Use 'none', 'mean' or 'sum'.")

def _diff_cdf_at_time_i(pmf: Tensor, y: Tensor) -> Tensor:

    n = pmf.shape[0]
    ones = torch.ones((n, 1), device=pmf.device)
    r = pmf.cumsum(1).matmul(y.transpose(0, 1))
    diag_r = r.diag().view(1, -1)
    r = ones.matmul(diag_r) - r
    return r.transpose(0, 1)

def _rank_loss_deephit(pmf: Tensor, y: Tensor, rank_mat: Tensor, sigma: float,
                       reduction: str = 'mean') -> Tensor:

    r = _diff_cdf_at_time_i(pmf, y)
    loss = rank_mat * torch.exp(-r/sigma)
    loss = loss.mean(1, keepdim=True)
    return _reduction(loss, reduction)

def index_from_1d(k, n):

    i = k // n
    j = k % n
    return (i, j)
    
def nll_pmf_cr(phi: Tensor, idx_durations: Tensor, events: Tensor, reduction: str = 'mean',
               epsilon: float = 1e-7) -> Tensor:

    events = events.view(-1)
    event_00 = (events == 0).float()
    event_01 = (events == 1).float()
    event_02 = (events == 2).float()
    event_03 = (events == 3).float()
    
    idx_durations1, idx_durations2 = index_from_1d(idx_durations.view(-1), num_durations)
    batch_size = phi.size(0)
    sm = utils.pad_col(phi.view(batch_size, -1)).softmax(1)[:, :-1].view(phi.shape)
    index = torch.arange(batch_size)
    part1 = sm[index, idx_durations1, idx_durations2].relu().add(epsilon).log().mul(event_03)
    part2 = (sm[index, idx_durations1, :].sum(1) - sm.cumsum(2)[index, idx_durations1, idx_durations2]).relu().add(epsilon).log().mul(event_02)
    part3 = (sm[index, :, idx_durations2].sum(1) - sm.cumsum(1)[index, idx_durations1, idx_durations2]).relu().add(epsilon).log().mul(event_01)
    part4 = (1 - sm.cumsum(2)[index, :, idx_durations2].sum(1) + sm.cumsum(1).cumsum(2)[index, idx_durations1, idx_durations2] - sm.cumsum(1)[index, idx_durations1, :].sum(1)).relu().add(epsilon).log().mul(event_00)     
     
    loss = - part1.add(part2).add(part3).add(part4)
    return _reduction(loss, reduction)



def rank_loss_deephit_cr(phi: Tensor, idx_durations: Tensor, events: Tensor, rank_mat: Tensor,
                         sigma: float, reduction: str = 'mean') -> Tensor:

    idx_durations = idx_durations.view(-1)
    events = events.view(-1)
    event_00 = (events == 0).float()
    event_01 = (events == 1).float()
    event_02 = (events == 2).float()
    event_03 = (events == 3).float()

    batch_size = phi.size(0)
    pmf = utils.pad_col(phi.view(batch_size, -1)).softmax(1)[:, :-1].view(phi.shape)
    y = torch.zeros_like(pmf)
    y[torch.arange(batch_size), :, idx_durations] = 1.

    loss = []
    for i in range(4):
        rank_loss_i = _rank_loss_deephit(pmf[:, i, :], y[:, i, :], rank_mat, sigma, 'none')
        loss.append(rank_loss_i.view(-1) * (events == i).float())

    if reduction == 'none':
        return sum(loss)
    elif reduction == 'mean':
        return sum([lo.mean() for lo in loss])
    elif reduction == 'sum':
        return sum([lo.sum() for lo in loss])
    return _reduction(loss, reduction)

class _Loss(torch.nn.Module):

    def __init__(self, reduction: str = 'mean') -> None:
        super().__init__()
        self.reduction = reduction

class _Loss1(_Loss):

    def __init__(self, alpha: float, sigma: float, reduction: str = 'mean') -> None:
        super().__init__(reduction)
        self.alpha = alpha
        self.sigma = sigma

    @property
    def alpha(self) -> float:
        return self._alpha

    @alpha.setter
    def alpha(self, alpha: float) -> None:
        if (alpha < 0) or (alpha > 1):
            raise ValueError(f"Need `alpha` to be in [0, 1]. Got {alpha}.")
        self._alpha = alpha

    @property
    def sigma(self) -> float:
        return self._sigma

    @sigma.setter
    def sigma(self, sigma: float) -> None:
        if sigma <= 0:
            raise ValueError(f"Need `sigma` to be positive. Got {sigma}.")
        self._sigma = sigma

class Loss1(_Loss1):

    def forward(self, phi: Tensor, idx_durations: Tensor, events: Tensor, rank_mat: Tensor) -> Tensor:
        nll =  nll_pmf_cr(phi, idx_durations, events, self.reduction)
        return nll

In [27]:
optimizer = tt.optim.AdamWR(lr=0.8, decoupled_weight_decay=0.01,
                                    cycle_eta_multiplier=0.9)
model = second_multi_task(net, optimizer, alpha=1,
                   duration_index=labtrans1.cuts)
batch_size = 128
lrfind = model.lr_finder(x_train, T_train, batch_size, tolerance=50)
model.optimizer.set_lr(lrfind.get_best_lr()) # The learning rates for the AdamWR optimizer were adjusted using the method proposed by Smith
epochs = 512
callbacks = [tt.callbacks.EarlyStoppingCycle()]
verbose = False
log = model.fit(x_train, T_train, batch_size, epochs, callbacks, verbose, val_data=val)


## Prediction and Evaluation

In [28]:
# Spline interpolation and evaluation of predictive performance
index = torch.arange(T1_test.size)
x = np.linspace(0, num_durations-1, num_durations)

xnew = np.linspace(0, num_durations-1, 10000)
ynew=[]    
for i in range(T2_test[event1_test==1].size):  
    spline_interp = UnivariateSpline(x, (1. - (model.predict_pmf_12(x_test)[T1_test[event1_test==1],:,index[event1_test==1]].T
                                               /model.predict_pmf_12(x_test)[T1_test[event1_test==1],:,index[event1_test==1]].sum(1)).cumsum(0))[:,i], s=0)
    y_spline = spline_interp(xnew)
    ynew.append(y_spline)
        
ynew_array = np.stack(ynew, axis=1)
    
surv1 = pd.DataFrame(ynew_array, np.linspace(0, labtrans2.cuts.max(), 10000))
    
ev1 = EvalSurv(surv1, np.array(T2_test[event1_test==1]), np.array(event2_test)[event1_test==1], censor_surv='km')

ynew=[]

for i in range(T2_test[event1_test==0].size):  
    spline_interp = UnivariateSpline(x, (1-((model.predict_pmf_12(x_test).sum(0)[:,event1_test==0]-model.predict_pmf_12(x_test).cumsum(0)[T1_test[event1_test==0],:,index[event1_test==0]].T)
                                            /(1-model.predict_pmf_12(x_test).cumsum(0)[T1_test[event1_test==0],:,index[event1_test==0]].sum(1))).cumsum(0))[:,i], s=0)
    y_spline = spline_interp(xnew)
    ynew.append(y_spline)

ynew_array = np.stack(ynew, axis=1)
    
surv2 = pd.DataFrame(ynew_array, np.linspace(0, labtrans2.cuts.max(), 10000))

ev2 = EvalSurv(surv2, np.array(T2_test[event1_test==0]), np.array(event2_test)[event1_test==0], censor_surv='km')

time_grid = np.linspace(T2_test.min(), T2_test.max(), 100) 

C_index=(ev1.concordance_td()*T2_test[event1_test==1].size + ev2.concordance_td()*T2_test[event1_test==0].size)/T2_test.size

IBS=(ev1.integrated_brier_score(time_grid)*T2_test[event1_test==1].size + ev2.integrated_brier_score(time_grid)*T2_test[event1_test==0].size)/T2_test.size

print("C-index:", C_index)
print("IBS:", IBS)

C-index: 0.8325406330264106
IBS: 0.09440673971964854
