# Introduction to the second multi-task network

In this notebook we introduce the use of the second multi-task network through an example dataset (TCGA-BRCA).

## Import

In [1]:
# Import related packages
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torchtuples as tt
import torch.nn.functional as F
import torch.nn as nn
import math

from sklearn.preprocessing import StandardScaler
from sklearn_pandas import DataFrameMapper
from pycox.models.utils import pad_col
from pycox.preprocessing.label_transforms import LabTransDiscreteTime
from pycox import models
from typing import Tuple
from torch import Tensor
from pycox.models import utils
from torchtuples import TupleTree
from scipy.interpolate import UnivariateSpline

import sys
sys.path.insert(0, '/')
from eval import EvalSurv


In [2]:
# set some seeds to make this reproducable
np.random.seed(123456)
_ = torch.manual_seed(123456)

In [3]:
# Import data
file1 = 'BRCA1.txt'
file2 = 'BRCA2.txt'

with open(file1, 'r', encoding='utf-8') as f1:
    lines1 = f1.readlines()

with open(file2, 'r', encoding='utf-8') as f2:
    lines2 = f2.readlines()

combined_lines = lines1 + lines2

output_file = 'BRCA.txt'

with open(output_file, 'w', encoding='utf-8') as outfile:
    outfile.writelines(combined_lines)
    
file_path = 'BRCA.txt'

df = pd.read_csv(file_path, sep='\t', header=0)

# Leaving only the genetic data associated with the 186KEGG pathway
def parse_kegg_file(file_path):
    kegg_data = {}
    with open(file_path, 'r') as file:
        for line in file:
            if line.startswith('KEGG'):
                parts = line.strip().split('\t')
                pathway_name = parts[0]
                genes = parts[2:] if len(parts) > 2 else []
                kegg_data[pathway_name] = genes
    return kegg_data

file_path = 'kegg_legacy.txt'
kegg_pathways = parse_kegg_file(file_path)

all_genes = set()

for genes in kegg_pathways.values():
    all_genes.update(genes)

genes_columns = [gene for gene in all_genes if gene in df.columns]

filtered_df_train = df[genes_columns]

filtered_kegg_pathways = {}
for pathway, genes in kegg_pathways.items():
    filtered_genes = [gene for gene in genes if gene in genes_columns]
    filtered_kegg_pathways[pathway] = filtered_genes

concatenated_df = pd.concat([filtered_df_train, df.iloc[:,-4:]], axis=1)

df_train = concatenated_df

df_train.head()

Unnamed: 0,MGST2,SGK1,PARD6B,DRD3,GNG10,TAAR2,FZR1,CNTFR,E2F5,IFNA1,...,SMC1B,IL1B,L2HGDH,PPARGC1A,VCAM1,MADCAM1,event2,T2,event1,T1
0,8.64,8.29,8.74,0.0,8.01,0.0,10.76,2.36,6.99,0.0,...,0.76,5.43,8.17,2.04,8.43,3.37,0,4047,1,1808
1,9.36,9.24,8.23,0.0,8.57,0.0,10.87,4.16,6.45,0.0,...,7.42,4.64,8.19,2.09,9.61,3.36,0,4005,0,4005
2,9.89,9.58,7.82,0.0,9.13,0.0,10.74,5.18,7.5,0.0,...,0.93,3.2,8.03,2.47,9.01,0.0,0,1474,0,1474
3,9.59,9.82,8.05,0.0,8.43,0.0,10.36,5.16,7.17,0.0,...,1.16,7.06,7.32,6.28,9.78,1.16,0,1448,0,1448
4,8.07,10.4,8.43,0.0,8.62,0.0,9.76,4.75,8.12,0.0,...,1.99,5.17,8.04,4.98,7.48,0.0,0,348,0,348


In [4]:
# Sparse connection layer mask matrix
num_pathways = len(filtered_kegg_pathways)
num_genes = len(genes_columns)

mask = torch.zeros(num_pathways, num_genes, dtype=torch.bool)

gene_to_index = {gene: idx for idx, gene in enumerate(genes_columns)}

for pathway_idx, (pathway, genes) in enumerate(filtered_kegg_pathways.items()):
    for gene in genes:
        if gene in gene_to_index:
            gene_idx = gene_to_index[gene]
            mask[pathway_idx, gene_idx] = 1  


In [5]:
# View all gene names
xx = df_train.drop(columns=['event2','T2','event1','T1'])
xx.columns.tolist()

['MGST2',
 'SGK1',
 'PARD6B',
 'DRD3',
 'GNG10',
 'TAAR2',
 'FZR1',
 'CNTFR',
 'E2F5',
 'IFNA1',
 'TAF4B',
 'DET1',
 'MYD88',
 'MAP3K11',
 'ABCD4',
 'NOG',
 'CRLS1',
 'CAMK2A',
 'TTK',
 'IL21R',
 'CHAT',
 'CBLC',
 'SARDH',
 'CHRM2',
 'PDHB',
 'RPL39',
 'ICAM2',
 'HMGCS2',
 'MAGI1',
 'COL2A1',
 'STAT4',
 'DHRS3',
 'TFDP1',
 'UGCG',
 'OR13C8',
 'CYP11B1',
 'BUB3',
 'CASP5',
 'GABRA6',
 'OR5AS1',
 'RPS6KA3',
 'RAB11FIP2',
 'ACOX1',
 'E2F1',
 'MYH2',
 'OXCT2',
 'CDKN2D',
 'CADM3',
 'GABRG3',
 'TAS2R5',
 'CACNB3',
 'PGK1',
 'DDC',
 'OR52K2',
 'CYP2A7',
 'ACVR1C',
 'DSC2',
 'CHIT1',
 'MOS',
 'SRD5A1',
 'GAMT',
 'MRAS',
 'LRAT',
 'NDUFB7',
 'ATP6V0A2',
 'SPCS1',
 'AACS',
 'ALS2',
 'MFNG',
 'OR8B12',
 'CCNB2',
 'MAP4K1',
 'JAG1',
 'SLC18A2',
 'OR6C65',
 'ATP6V0D2',
 'DHRS4L2',
 'OR8H3',
 'DHX8',
 'DAPK1',
 'GTF2B',
 'SNW1',
 'CCND1',
 'ELOVL6',
 'GNAO1',
 'PRPF19',
 'TNFRSF10C',
 'DUSP3',
 'VWF',
 'GPX1',
 'TUBA1C',
 'EI24',
 'AP3M1',
 'NCR3',
 'GOSR2',
 'TNR',
 'P2RX4',
 'IGF2R',
 'NR1H3',
 '

In [6]:
# Train/test/validation split
df_test = df_train.sample(frac=0.2)
df_train = df_train.drop(df_test.index)
df_val = df_train.sample(frac=0.2)
df_train = df_train.drop(df_val.index)

df_train

Unnamed: 0,MGST2,SGK1,PARD6B,DRD3,GNG10,TAAR2,FZR1,CNTFR,E2F5,IFNA1,...,SMC1B,IL1B,L2HGDH,PPARGC1A,VCAM1,MADCAM1,event2,T2,event1,T1
0,8.64,8.29,8.74,0.00,8.01,0.0,10.76,2.36,6.99,0.0,...,0.76,5.43,8.17,2.04,8.43,3.37,0,4047,1,1808
1,9.36,9.24,8.23,0.00,8.57,0.0,10.87,4.16,6.45,0.0,...,7.42,4.64,8.19,2.09,9.61,3.36,0,4005,0,4005
2,9.89,9.58,7.82,0.00,9.13,0.0,10.74,5.18,7.50,0.0,...,0.93,3.20,8.03,2.47,9.01,0.00,0,1474,0,1474
5,9.35,9.57,10.09,0.00,8.40,0.0,10.35,3.01,7.36,0.0,...,4.02,6.36,7.33,3.60,8.64,1.46,0,1477,0,1477
7,9.44,7.42,9.47,0.00,9.01,0.0,10.36,3.58,8.04,0.0,...,0.00,5.42,8.38,4.76,4.11,0.00,0,303,0,303
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1074,8.98,9.25,9.93,0.00,8.54,0.0,10.39,5.08,7.03,0.0,...,0.56,4.29,6.97,4.25,6.58,2.84,0,347,0,347
1077,8.86,10.93,8.50,0.48,9.27,0.0,10.12,8.63,7.29,0.0,...,0.48,5.80,7.69,4.99,10.37,2.52,0,467,0,467
1078,9.46,11.30,6.85,0.00,8.80,0.0,9.81,7.98,7.20,0.0,...,4.90,4.59,7.37,6.34,9.90,2.11,0,488,0,488
1079,10.33,10.89,7.64,0.00,9.51,0.0,10.22,2.37,6.68,0.0,...,3.03,6.69,6.32,4.95,9.04,0.68,0,3287,1,181


In [7]:
# Covariate preprocessing
cols_standardize =  xx.columns.tolist()

standardize = [([col], StandardScaler()) for col in cols_standardize]

x_mapper = DataFrameMapper(standardize)

x_train = x_mapper.fit_transform(df_train).astype('float32')
x_val = x_mapper.transform(df_val).astype('float32')
x_test = x_mapper.transform(df_test).astype('float32')

In [8]:
# Discretization of survival times
class LabTransform(LabTransDiscreteTime):
    def transform(self, durations, events):
        durations, is_event = super().transform(durations, events > 0)
        events[is_event == 0] = 0
        return durations, events.astype('int64')
        
num_durations = 10

labtrans1 = LabTransform(num_durations, scheme='equidistant')
get_target1 = lambda df: (df['T1'].values, df['event1'].values)

T1_train = labtrans1.fit_transform(*get_target1(df_train))
T1_val = labtrans1.transform(*get_target1(df_val))
T1_test, event1_test = labtrans1.transform(*get_target1(df_test))

labtrans2 = LabTransform(num_durations, scheme='equidistant')
get_target2 = lambda df: (df['T2'].values, df['event2'].values)

T2_train = labtrans2.fit_transform(*get_target2(df_train))
T2_val = labtrans2.transform(*get_target2(df_val))
# Discretization is not required because the prediction time is already a continuous value after spline interpolation when evaluated on the test set
T2_test, event2_test = get_target2(df_test)

In [9]:
# Package the data into the input format required by the network later
def index_to_1d(i, j, n):
    """
    Converts a 2D index (i, j) from sets T1 and T2, each with elements ranging from 1 to n, to a 1D index.
    """
    return i * n + j

T_train = list(T1_train)
T1_train_list = list(T1_train)
T2_train_list = list(T2_train)
T_train[0] = index_to_1d(T1_train_list[0], T2_train_list[0], num_durations)
T_train[1] = index_to_1d(T1_train_list[1], T2_train_list[1], 2)
T_train = tuple(T_train)

T_val = list(T1_val)
T1_val_list = list(T1_val)
T2_val_list = list(T2_val)
T_val[0] = index_to_1d(T1_val_list[0], T2_val_list[0], num_durations)
T_val[1] = index_to_1d(T1_val_list[1], T2_val_list[1], 2)
T_val = tuple(T_val)
val = (x_val, T_val)

## Neural net

In [10]:
# Neural network architecture
class MaskedLinearLayer(nn.Module):
    def __init__(self, in_features, out_features, bias=True, mask=None):
        super(MaskedLinearLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if bias:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)
        self.mask = mask if mask is not None else torch.ones(out_features, in_features)

    def forward(self, input):
        masked_weight = self.weight * self.mask
        return nn.functional.linear(input, masked_weight, self.bias)

class Multivariate_survival(torch.nn.Module):

    def __init__(self, in_features, num_nodes_shared, num_nodes_indiv, num_T1,
                 out_features, batch_norm=True, dropout=None, mask=None):
        super().__init__()
        self.first_layer = MaskedLinearLayer(in_features, num_nodes_shared[0], mask=mask)
        self.shared_net = tt.practical.MLPVanilla(
            num_nodes_shared[0], num_nodes_shared[1:-1], num_nodes_shared[-1],
            batch_norm, dropout,
        )
        self.risk_nets = torch.nn.ModuleList()
        for _ in range(num_T1):
            net = tt.practical.MLPVanilla(
                num_nodes_shared[-1], num_nodes_indiv, out_features,
                batch_norm, dropout,
            )
            self.risk_nets.append(net)

    def forward(self, input):
        out = self.first_layer(input)
        out = self.shared_net(out)
        out = [net(out) for net in self.risk_nets]
        out = torch.stack(out, dim=1)
        return out



In [11]:
in_features = x_train.shape[1]
num_nodes_shared = [num_pathways, 32, 32]
num_nodes_indiv = [32, 32]
num_T1 = num_durations 
out_features = len(labtrans2.cuts)
batch_norm = True
dropout = 0.7

net = Multivariate_survival(in_features, num_nodes_shared, num_nodes_indiv, num_T1,
                       out_features, batch_norm, dropout, mask=mask)

## Training



In [12]:
class second_multi_task(tt.Model):

    def __init__(self, net, optimizer=None, device=None, alpha=0.2, sigma=0.1, duration_index=None, loss=None):
        self.duration_index = duration_index
        if loss is None:
            loss = Loss1(alpha, sigma)
        super().__init__(net, loss, optimizer, device)

    @property
    def duration_index(self):
        
        return self._duration_index

    @duration_index.setter
    def duration_index(self, val):
        self._duration_index = val

    def make_dataloader(self, data, batch_size, shuffle, num_workers=0):
        dataloader = super().make_dataloader(data, batch_size, shuffle, num_workers,
                                             make_dataset=models.data.DeepHitDataset)
        return dataloader
    
    def make_dataloader_predict(self, input, batch_size, shuffle=False, num_workers=0):
        dataloader = super().make_dataloader(input, batch_size, shuffle, num_workers)
        return dataloader

    def predict_surv_df(self, input, batch_size=8224, eval_=True, num_workers=0):

        surv = self.predict_pmf_1_cif_2(input, batch_size, True, eval_, True, num_workers)
        return pd.DataFrame(surv, self.duration_index)

    def predict_surv_2_condpmf_1(self, input, batch_size=8224, numpy=None, eval_=True,
                     to_cpu=False, num_workers=0):
 
        cif = self.predict_pmf_1_cif_2(input, batch_size, False, eval_, to_cpu, num_workers)
        pmf = self.predict_pmf_1(input, batch_size, False, eval_, to_cpu, num_workers)
        condsurv = 1. - cif/pmf
        return tt.utils.array_or_tensor(condsurv, numpy, input)
        
    def predict_pmf_1_cif_2(self, input, batch_size=8224, numpy=None, eval_=True,
                     to_cpu=False, num_workers=0):
 
        pmf = self.predict_pmf_12(input, batch_size, False, eval_, to_cpu, num_workers)
        cif = pmf.cumsum(1)
        return tt.utils.array_or_tensor(cif, numpy, input) 

    def predict_pmf_1(self, input, batch_size=8224, numpy=None, eval_=True,
                     to_cpu=False, num_workers=0):
 
        pmf12 = self.predict_pmf_12(input, batch_size, False, eval_, to_cpu, num_workers)
        pmf = pmf12.sum(1)
        return tt.utils.array_or_tensor(pmf, numpy, input)
    
    def predict_pmf_12(self, input, batch_size=8224, numpy=None, eval_=True,
                     to_cpu=False, num_workers=0):
 
        preds = self.predict(input, batch_size, False, eval_, False, to_cpu, num_workers)
        pmf = pad_col(preds.view(preds.size(0), -1)).softmax(1)[:, :-1]
        pmf = pmf.view(preds.shape).transpose(0, 1).transpose(1, 2)
        return tt.utils.array_or_tensor(pmf, numpy, input)

def _reduction(loss: Tensor, reduction: str = 'mean') -> Tensor:
    if reduction == 'none':
        return loss
    elif reduction == 'mean':
        return loss.mean()
    elif reduction == 'sum':
        return loss.sum()
    raise ValueError(f"`reduction` = {reduction} is not valid. Use 'none', 'mean' or 'sum'.")

def _diff_cdf_at_time_i(pmf: Tensor, y: Tensor) -> Tensor:

    n = pmf.shape[0]
    ones = torch.ones((n, 1), device=pmf.device)
    r = pmf.cumsum(1).matmul(y.transpose(0, 1))
    diag_r = r.diag().view(1, -1)
    r = ones.matmul(diag_r) - r
    return r.transpose(0, 1)

def _rank_loss_deephit(pmf: Tensor, y: Tensor, rank_mat: Tensor, sigma: float,
                       reduction: str = 'mean') -> Tensor:

    r = _diff_cdf_at_time_i(pmf, y)
    loss = rank_mat * torch.exp(-r/sigma)
    loss = loss.mean(1, keepdim=True)
    return _reduction(loss, reduction)

def index_from_1d(k, n):

    i = k // n
    j = k % n
    return (i, j)
    
def nll_pmf_cr(phi: Tensor, idx_durations: Tensor, events: Tensor, reduction: str = 'mean',
               epsilon: float = 1e-7) -> Tensor:

    events = events.view(-1)
    event_00 = (events == 0).float()
    event_01 = (events == 1).float()
    event_02 = (events == 2).float()
    event_03 = (events == 3).float()
    
    idx_durations1, idx_durations2 = index_from_1d(idx_durations.view(-1), num_durations)
    batch_size = phi.size(0)
    sm = utils.pad_col(phi.view(batch_size, -1)).softmax(1)[:, :-1].view(phi.shape)
    index = torch.arange(batch_size)
    part1 = sm[index, idx_durations1, idx_durations2].relu().add(epsilon).log().mul(event_03)
    part2 = (sm[index, idx_durations1, :].sum(1) - sm.cumsum(2)[index, idx_durations1, idx_durations2]).relu().add(epsilon).log().mul(event_02)
    part3 = (sm[index, :, idx_durations2].sum(1) - sm.cumsum(1)[index, idx_durations1, idx_durations2]).relu().add(epsilon).log().mul(event_01)
    part4 = (1 - sm.cumsum(2)[index, :, idx_durations2].sum(1) + sm.cumsum(1).cumsum(2)[index, idx_durations1, idx_durations2] - sm.cumsum(1)[index, idx_durations1, :].sum(1)).relu().add(epsilon).log().mul(event_00)     
     
    loss = - part1.add(part2).add(part3).add(part4)
    return _reduction(loss, reduction)



def rank_loss_deephit_cr(phi: Tensor, idx_durations: Tensor, events: Tensor, rank_mat: Tensor,
                         sigma: float, reduction: str = 'mean') -> Tensor:

    idx_durations = idx_durations.view(-1)
    events = events.view(-1)
    event_00 = (events == 0).float()
    event_01 = (events == 1).float()
    event_02 = (events == 2).float()
    event_03 = (events == 3).float()

    batch_size = phi.size(0)
    pmf = utils.pad_col(phi.view(batch_size, -1)).softmax(1)[:, :-1].view(phi.shape)
    y = torch.zeros_like(pmf)
    y[torch.arange(batch_size), :, idx_durations] = 1.

    loss = []
    for i in range(4):
        rank_loss_i = _rank_loss_deephit(pmf[:, i, :], y[:, i, :], rank_mat, sigma, 'none')
        loss.append(rank_loss_i.view(-1) * (events == i).float())

    if reduction == 'none':
        return sum(loss)
    elif reduction == 'mean':
        return sum([lo.mean() for lo in loss])
    elif reduction == 'sum':
        return sum([lo.sum() for lo in loss])
    return _reduction(loss, reduction)

class _Loss(torch.nn.Module):

    def __init__(self, reduction: str = 'mean') -> None:
        super().__init__()
        self.reduction = reduction

class _Loss1(_Loss):

    def __init__(self, alpha: float, sigma: float, reduction: str = 'mean') -> None:
        super().__init__(reduction)
        self.alpha = alpha
        self.sigma = sigma

    @property
    def alpha(self) -> float:
        return self._alpha

    @alpha.setter
    def alpha(self, alpha: float) -> None:
        if (alpha < 0) or (alpha > 1):
            raise ValueError(f"Need `alpha` to be in [0, 1]. Got {alpha}.")
        self._alpha = alpha

    @property
    def sigma(self) -> float:
        return self._sigma

    @sigma.setter
    def sigma(self, sigma: float) -> None:
        if sigma <= 0:
            raise ValueError(f"Need `sigma` to be positive. Got {sigma}.")
        self._sigma = sigma

class Loss1(_Loss1):

    def forward(self, phi: Tensor, idx_durations: Tensor, events: Tensor, rank_mat: Tensor) -> Tensor:
        nll =  nll_pmf_cr(phi, idx_durations, events, self.reduction)
        return nll

In [13]:
optimizer = tt.optim.AdamWR(lr=0.8, decoupled_weight_decay=0.01,
                                    cycle_eta_multiplier=0.9)
model = second_multi_task(net, optimizer, alpha=1,
                   duration_index=labtrans1.cuts)
batch_size = 128
lrfind = model.lr_finder(x_train, T_train, batch_size, tolerance=50)
model.optimizer.set_lr(lrfind.get_best_lr()) # The learning rates for the AdamWR optimizer were adjusted using the method proposed by Smith
epochs = 512
callbacks = [tt.callbacks.EarlyStoppingCycle()]
verbose = False
log = model.fit(x_train, T_train, batch_size, epochs, callbacks, verbose, val_data=val)


	add(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add(Tensor other, *, Number alpha) (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/python_arg_parser.cpp:1630.)
  p.data = p.data.add(-weight_decay * eta, p.data)


## Prediction and Evaluation

In [14]:
# Spline interpolation and evaluation of predictive performance
index = torch.arange(T1_test.size)
x = np.linspace(0, num_durations-1, num_durations)

xnew = np.linspace(0, num_durations-1, 10000)
ynew=[]    
for i in range(T2_test[event1_test==1].size):  
    spline_interp = UnivariateSpline(x, (1. - (model.predict_pmf_12(x_test)[T1_test[event1_test==1],:,index[event1_test==1]].T
                                               /model.predict_pmf_12(x_test)[T1_test[event1_test==1],:,index[event1_test==1]].sum(1)).cumsum(0))[:,i], s=0)
    y_spline = spline_interp(xnew)
    ynew.append(y_spline)
        
ynew_array = np.stack(ynew, axis=1)
    
surv1 = pd.DataFrame(ynew_array, np.linspace(0, labtrans2.cuts.max(), 10000))
    
ev1 = EvalSurv(surv1, np.array(T2_test[event1_test==1]), np.array(event2_test)[event1_test==1], censor_surv='km')

ynew=[]

for i in range(T2_test[event1_test==0].size):  
    spline_interp = UnivariateSpline(x, (1-((model.predict_pmf_12(x_test).sum(0)[:,event1_test==0]-model.predict_pmf_12(x_test).cumsum(0)[T1_test[event1_test==0],:,index[event1_test==0]].T)
                                            /(1-model.predict_pmf_12(x_test).cumsum(0)[T1_test[event1_test==0],:,index[event1_test==0]].sum(1))).cumsum(0))[:,i], s=0)
    y_spline = spline_interp(xnew)
    ynew.append(y_spline)

ynew_array = np.stack(ynew, axis=1)
    
surv2 = pd.DataFrame(ynew_array, np.linspace(0, labtrans2.cuts.max(), 10000))

ev2 = EvalSurv(surv2, np.array(T2_test[event1_test==0]), np.array(event2_test)[event1_test==0], censor_surv='km')

time_grid = np.linspace(T2_test.min(), T2_test.max(), 100) 

C_index=(ev1.concordance_td()*T2_test[event1_test==1].size + ev2.concordance_td()*T2_test[event1_test==0].size)/T2_test.size

IBS=(ev1.integrated_brier_score(time_grid)*T2_test[event1_test==1].size + ev2.integrated_brier_score(time_grid)*T2_test[event1_test==0].size)/T2_test.size

print("C-index:", C_index)
print("IBS:", IBS)

C-index: 0.8702086189011654
IBS: 0.09567708604324636
