In [2]:
import os


def result_file_name(args):
    file_name = f"results_{args.fold_range}_{args.seed}"
    if len(args.modalities) > 0:
        file_name += "_"
        file_name += "_".join(args.modalities)
    return file_name

class Args:
    def __init__(self, modalities=["clinical", "miRNA", "mRNA", "WSI"]):
        self.epochs = 20
        self.modality_data_path = {'clinical': './preprocess/preprocessed_data/clinical_kidney.csv',
                                    'mRNA': './preprocess/preprocessed_data/mrna_kidney.csv',
                                    'miRNA': './preprocess/preprocessed_data/mirna_kidney.csv',
                                    'WSI': './preprocess/preprocessed_data/UNI2_features/TCGA-Kidney.pt'}
        self.device = "cuda"
        self.modalities = modalities
        self.input_modality_dim = {'clinical':4, 'mRNA':2746, 'miRNA':743 , 'WSI': 1536}
        self.fold_range = 5
        self.fold = 1
        self.modality_fv_len = 128
        self.batch_size = 128
        self.learning_rate = 1e-4
        self.loss_trade_off = 0.3
        self.loss_mode = 'total'
        self.seed = 24
        self.result_path = f"./logs/attention/{result_file_name(self)}"
        print(f"Result Path: {self.result_path}")
        self.log_file_name = "temp.log"
        self.scheduler_patience = 5
        self.num_workers = 4
        self.num_modalities = len(self.modalities)
        self.split_path = "./splits/kidney_splits"
        self.remove_missing = True

args = Args()

Result Path: ./logs/attention/results_5_24_clinical_miRNA_mRNA_WSI


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
	def __init__(self, m_length, modality_count, device=None):
		"""Attention-based multimodal fusion Part
		Parameters
		----------
		m_length: int
			Weight vector length, corresponding to the modality representation
			length.

		modalities: list
			The list of used modality.

		"""
		super(Attention, self).__init__()
		self.m_length = m_length
		self.device = device
		# contrast a pipeline for different modality weight matrix
		self.pipeline = {}
		self.modality_count = modality_count
		for i in range(self.modality_count):
			self.pipeline[i] = nn.Linear(self.m_length, self.m_length, bias=False).to(self.device)

	def _scale_for_missing_modalities(self, x, out):
		batch_dim = x.shape[1]
		for i in range(batch_dim):
			patient = x[:, i, :]
			zero_dims = 0
			for modality in patient:
				if modality.sum().data == 0:
					zero_dims += 1

			if zero_dims > 0:
				scaler = zero_dims + 1
				out[i, :] = scaler * out[i, :]

		return out

	def to_device(self, device):
		self.to(device)
		for modality in self.pipeline:
			self.pipeline[modality].to(device)

	def to_train(self):
		self.train()
		for modality in self.pipeline:
			self.pipeline[modality].train()

	def to_eval(self):
		self.eval()
		for modality in self.pipeline:
			self.pipeline[modality].eval()

	def forward(self, multimodal_input):
		"""
		multimodal_input: dictionary
			A dictionary of used modality data, like:
			{'clinical':tensor(sample_size, m_length), 'mRNA':tensor(,),}
		"""
		attention_weight = tuple()
		multimodal_features = tuple()
		for i in range(self.modality_count):
			attention_weight += (torch.tanh(self.pipeline[i](multimodal_input[i])),)
			multimodal_features += (multimodal_input[i],)

		# Across feature
		attention_matrix = F.softmax(torch.stack(attention_weight), dim=0)
		fused_vec = torch.sum(torch.stack(multimodal_features) * attention_matrix, dim=0)

		fused_vec = self._scale_for_missing_modalities(torch.stack(multimodal_features), fused_vec)
		return fused_vec

In [4]:
from torch import nn


class AggModel(nn.Module):
    def __init__(self, n_views, n_in_feats, encoder_model):
        super(AggModel, self).__init__()
        self.modality_count = n_views
        self.modality_fv_len = n_in_feats
        
        self.encoder_model = encoder_model
        self.fusion = Attention(m_length=self.modality_fv_len, modality_count=self.modality_count)
        self.hazard_layer = nn.Linear(self.modality_fv_len, 1)
        self.label_layer = nn.Linear(self.modality_fv_len, 2)
	
    def to_device(self, device):
        self.to(device)
        self.fusion.to_device(device)

    def to_train(self):
        self.train()
        self.fusion.to_train()

    def to_eval(self):
        self.eval()
        self.fusion.to_eval()
			
    def forward(self, x_modality, mask):
        representation = self.encoder_model(x_modality)
        missing_rep = []
        for i, rep in enumerate(representation):
            index = torch.ones((rep.shape[0]), dtype=int) * i 
            index = index.to(rep.device)
            modality_mask = mask[:, i].reshape(-1, 1)
            rep =  modality_mask * rep
            missing_rep.append(rep)
        
        representation_dict = {} #used for self supervised loss
        for i, data in enumerate(missing_rep):
            representation_dict[i] = data

        final_representation = self.fusion(representation_dict)
        hazard = self.hazard_layer(final_representation)
        score = F.log_softmax(self.label_layer(final_representation), dim=1)
        return {'hazard':hazard, 'score':score}, representation_dict

In [5]:
from lifelines.utils import concordance_index


def print_log(msg, logger):
    print(msg)
    logger.info(msg)

def evaluation(model, dataloader, loss_fucnt, device):
    running_loss = []
    running_sample_time = torch.FloatTensor().to(device)
    running_sample_event = torch.LongTensor().to(device)
    running_hazard = torch.FloatTensor().to(device)
    
    model.to_eval()
    with torch.no_grad():
        for x_modal, data_label, mask in dataloader:
            for modality in x_modal:
                x_modal[modality] = x_modal[modality].to(device)
            event = data_label[:, 0].to(device)
            time = data_label[:, 1].to(device)
            mask = mask.to(device)
            hazard, representation = model(x_modal, mask)
            loss = loss_fucnt(representation, list(representation.keys()), hazard, event, time)
            hazard = hazard['hazard']
            
            running_loss.append(loss.item())
            running_sample_time = torch.cat((running_sample_time, time.data.float()))
            running_sample_event = torch.cat((running_sample_event, event.long().data))
            running_hazard = torch.cat((running_hazard, hazard.detach()))
            
    eval_loss = torch.mean(torch.tensor(running_loss))
    eval_c_index = concordance_index(running_sample_time.cpu().numpy(), -running_hazard.cpu().numpy(), running_sample_event.cpu().numpy())
    return eval_loss, eval_c_index

In [None]:
import copy
import random
import logging 
import numpy as np
from utils.loss import Loss
from torch.optim import Adam
from torch.utils.data import DataLoader
from utils.dataset import MultiModalDataset
from utils.encoder import EncoderModel
    

def save_model(model, path):
    torch.save(model.state_dict(), path)

def set_seed(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
        
def train_eval_model(args):
    set_seed(args.seed)

    os.makedirs(args.result_path, exist_ok=True)

    logging.basicConfig(format='%(message)s')
    for k in range(args.fold_range):
        args.fold = k
        args.log_file_name = f"result_fold{args.fold}.log"
        args.log_file_name = os.path.join(args.result_path, args.log_file_name)
        if os.path.exists(args.log_file_name):
            os.remove(args.log_file_name)
            
        handler = logging.FileHandler(args.log_file_name, mode='w')
        logger=logging.getLogger() 
        logger.setLevel(logging.DEBUG) 
        logger.addHandler(handler)
        
        train_dataset = MultiModalDataset(args, 'train', args.modalities, args.modality_data_path, remove_missing=args.remove_missing)
        train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
        test_dataset = MultiModalDataset(args, 'test', args.modalities, args.modality_data_path)
        test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

        device = torch.device(args.device)
        encoder_model = EncoderModel(args.modalities, args.modality_fv_len, args.input_modality_dim)
        model = AggModel(args.num_modalities, args.modality_fv_len, encoder_model)
        optimizer = Adam(model.parameters(), lr=args.learning_rate)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer=optimizer, mode='max', factor=0.1,
            patience = args.scheduler_patience, verbose=True, threshold=1e-3,
            threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-8)
        model.to_device(device)
        
        loss_fucnt = Loss(trade_off=args.loss_trade_off, mode=args.loss_mode)

        for epoch in range(args.epochs):
            running_loss = []
            running_sample_time = torch.FloatTensor().to(device)
            running_sample_event = torch.LongTensor().to(device)
            running_hazard = torch.FloatTensor().to(device)
            
            model.to_train()
            for x_modal, data_label, mask in train_dataloader:
                for modality in x_modal:
                    x_modal[modality] = x_modal[modality].to(device)
                event = data_label[:, 0].to(device)
                time = data_label[:, 1].to(device)
                mask = mask.to(device)
                hazard, representation = model(x_modal, mask)
                loss = loss_fucnt(representation, list(representation.keys()), hazard, event, time)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                hazard = hazard['hazard']
                
                running_loss.append(loss.item())
                running_sample_time = torch.cat((running_sample_time, time.data.float()))
                running_sample_event = torch.cat((running_sample_event, event.long().data))
                running_hazard = torch.cat((running_hazard, hazard.detach()))
                
            train_loss = torch.mean(torch.tensor(running_loss))
            train_c_index = concordance_index(running_sample_time.cpu().numpy(), -running_hazard.cpu().numpy(), running_sample_event.cpu().numpy())
            print_log(f"Epoch {epoch}/{args.epochs}: Training Loss: {train_loss}, C-Index: {train_c_index}", logger)
            scheduler.step(train_c_index)
            if scheduler.patience == scheduler.num_bad_epochs:
                print_log("Reducing the learning rate", logger)
            best_model = copy.deepcopy(model.state_dict())
        print_log("*" * 100, logger)
        
        model.load_state_dict(best_model)
        model_save_path = os.path.join(args.result_path, f"model_fold{args.fold}.pt")
        save_model(model, model_save_path)
        train_loss, train_c_index = evaluation(model, train_dataloader, loss_fucnt, device)
        print_log(f"Final Train Loss: {train_loss}, C-Index: {train_c_index}", logger)
        test_loss, test_c_index = evaluation(model, test_dataloader, loss_fucnt, device)
        print_log(f"Test Loss: {test_loss}, C-Index: {test_c_index}", logger)
        
        logger.removeHandler(handler)

In [7]:
import numpy as np


all_modalities = ["clinical", "miRNA", "mRNA", "WSI"]
all_modalities = np.array(all_modalities)
i = 2 ** len(all_modalities) - 1
selection = F"{i:0{len(all_modalities)}b}"
modality_select = all_modalities[np.array(list(selection)) == '1']
args = Args(modality_select)
train_eval_model(args)

Result Path: ./logs/attention/results_5_24_clinical_miRNA_mRNA_WSI


Epoch 0/20: Training Loss: 4.3014445304870605, C-Index: 0.6043093008477116


Epoch 0/20: Training Loss: 4.3014445304870605, C-Index: 0.6043093008477116


Epoch 1/20: Training Loss: 4.029287815093994, C-Index: 0.7526048848335152


Epoch 1/20: Training Loss: 4.029287815093994, C-Index: 0.7526048848335152


Epoch 2/20: Training Loss: 3.875070571899414, C-Index: 0.7949784774762892


Epoch 2/20: Training Loss: 3.875070571899414, C-Index: 0.7949784774762892


Epoch 3/20: Training Loss: 3.7783586978912354, C-Index: 0.8152540137408424


Epoch 3/20: Training Loss: 3.7783586978912354, C-Index: 0.8152540137408424


Epoch 4/20: Training Loss: 3.631094217300415, C-Index: 0.8322442176952315


Epoch 4/20: Training Loss: 3.631094217300415, C-Index: 0.8322442176952315


Epoch 5/20: Training Loss: 3.6076104640960693, C-Index: 0.8406254121653217


Epoch 5/20: Training Loss: 3.6076104640960693, C-Index: 0.8406254121653217


Epoch 6/20: Training Loss: 3.512700080871582, C-Index: 0.8547499430462464


Epoch 6/20: Training Loss: 3.512700080871582, C-Index: 0.8547499430462464


Epoch 7/20: Training Loss: 3.4668772220611572, C-Index: 0.8627714295991655


Epoch 7/20: Training Loss: 3.4668772220611572, C-Index: 0.8627714295991655


Epoch 8/20: Training Loss: 3.403042793273926, C-Index: 0.8670039927578806


Epoch 8/20: Training Loss: 3.403042793273926, C-Index: 0.8670039927578806


Epoch 9/20: Training Loss: 3.4338958263397217, C-Index: 0.8713084975000299


Epoch 9/20: Training Loss: 3.4338958263397217, C-Index: 0.8713084975000299


Epoch 10/20: Training Loss: 3.3068933486938477, C-Index: 0.8922554885433028


Epoch 10/20: Training Loss: 3.3068933486938477, C-Index: 0.8922554885433028


Epoch 11/20: Training Loss: 3.267671585083008, C-Index: 0.8944017457824247


Epoch 11/20: Training Loss: 3.267671585083008, C-Index: 0.8944017457824247


Epoch 12/20: Training Loss: 3.177345037460327, C-Index: 0.9025910960300236


Epoch 12/20: Training Loss: 3.177345037460327, C-Index: 0.9025910960300236


Epoch 13/20: Training Loss: 3.1471846103668213, C-Index: 0.9062121557295476


Epoch 13/20: Training Loss: 3.1471846103668213, C-Index: 0.9062121557295476


Epoch 14/20: Training Loss: 3.0933921337127686, C-Index: 0.9091737509142577


Epoch 14/20: Training Loss: 3.0933921337127686, C-Index: 0.9091737509142577


Epoch 15/20: Training Loss: 3.082700729370117, C-Index: 0.9130346158918958


Epoch 15/20: Training Loss: 3.082700729370117, C-Index: 0.9130346158918958


Epoch 16/20: Training Loss: 3.003016710281372, C-Index: 0.9202287742353209


Epoch 16/20: Training Loss: 3.003016710281372, C-Index: 0.9202287742353209


Epoch 17/20: Training Loss: 2.9585087299346924, C-Index: 0.9216316351122888


Epoch 17/20: Training Loss: 2.9585087299346924, C-Index: 0.9216316351122888


Epoch 18/20: Training Loss: 2.9251842498779297, C-Index: 0.9275668157456146


Epoch 18/20: Training Loss: 2.9251842498779297, C-Index: 0.9275668157456146


Epoch 19/20: Training Loss: 2.843846082687378, C-Index: 0.9273749715231232
****************************************************************************************************


Epoch 19/20: Training Loss: 2.843846082687378, C-Index: 0.9273749715231232
****************************************************************************************************


Final Train Loss: 2.868039846420288, C-Index: 0.9382501408856009


Final Train Loss: 2.868039846420288, C-Index: 0.9382501408856009


Test Loss: 3.8208250999450684, C-Index: 0.7511691348402182


Test Loss: 3.8208250999450684, C-Index: 0.7511691348402182


Epoch 0/20: Training Loss: 4.13225793838501, C-Index: 0.6526376755634924


Epoch 0/20: Training Loss: 4.13225793838501, C-Index: 0.6526376755634924


Epoch 1/20: Training Loss: 3.951441764831543, C-Index: 0.7614749746609393


Epoch 1/20: Training Loss: 3.951441764831543, C-Index: 0.7614749746609393


Epoch 2/20: Training Loss: 3.865009069442749, C-Index: 0.7805154688932864


Epoch 2/20: Training Loss: 3.865009069442749, C-Index: 0.7805154688932864


Epoch 3/20: Training Loss: 3.726421356201172, C-Index: 0.8073869395241083


Epoch 3/20: Training Loss: 3.726421356201172, C-Index: 0.8073869395241083


Epoch 4/20: Training Loss: 3.662212371826172, C-Index: 0.8213958202615956


Epoch 4/20: Training Loss: 3.662212371826172, C-Index: 0.8213958202615956


Epoch 5/20: Training Loss: 3.6083076000213623, C-Index: 0.8327139340701771


Epoch 5/20: Training Loss: 3.6083076000213623, C-Index: 0.8327139340701771


Epoch 6/20: Training Loss: 3.5348050594329834, C-Index: 0.8416912013128047


Epoch 6/20: Training Loss: 3.5348050594329834, C-Index: 0.8416912013128047


Epoch 7/20: Training Loss: 3.4705350399017334, C-Index: 0.8555070225396979


Epoch 7/20: Training Loss: 3.4705350399017334, C-Index: 0.8555070225396979


Epoch 8/20: Training Loss: 3.428760528564453, C-Index: 0.8637844490564216


Epoch 8/20: Training Loss: 3.428760528564453, C-Index: 0.8637844490564216


Epoch 9/20: Training Loss: 3.4069528579711914, C-Index: 0.8693831748636517


Epoch 9/20: Training Loss: 3.4069528579711914, C-Index: 0.8693831748636517


Epoch 10/20: Training Loss: 3.330921173095703, C-Index: 0.8800014479463295


Epoch 10/20: Training Loss: 3.330921173095703, C-Index: 0.8800014479463295


Epoch 11/20: Training Loss: 3.283829689025879, C-Index: 0.8917056807760992


Epoch 11/20: Training Loss: 3.283829689025879, C-Index: 0.8917056807760992


Epoch 12/20: Training Loss: 3.2486572265625, C-Index: 0.8916453496790385


Epoch 12/20: Training Loss: 3.2486572265625, C-Index: 0.8916453496790385


Epoch 13/20: Training Loss: 3.1632649898529053, C-Index: 0.9016603117911096


Epoch 13/20: Training Loss: 3.1632649898529053, C-Index: 0.9016603117911096


Epoch 14/20: Training Loss: 3.092172622680664, C-Index: 0.9104807181813794


Epoch 14/20: Training Loss: 3.092172622680664, C-Index: 0.9104807181813794


Epoch 15/20: Training Loss: 3.026886224746704, C-Index: 0.9157053911868334


Epoch 15/20: Training Loss: 3.026886224746704, C-Index: 0.9157053911868334


Epoch 16/20: Training Loss: 3.0500519275665283, C-Index: 0.9164293643515614


Epoch 16/20: Training Loss: 3.0500519275665283, C-Index: 0.9164293643515614


Epoch 17/20: Training Loss: 2.9822397232055664, C-Index: 0.9233191756358897


Epoch 17/20: Training Loss: 2.9822397232055664, C-Index: 0.9233191756358897


Epoch 18/20: Training Loss: 2.940094232559204, C-Index: 0.9236690959988416


Epoch 18/20: Training Loss: 2.940094232559204, C-Index: 0.9236690959988416


Epoch 19/20: Training Loss: 2.846381902694702, C-Index: 0.930546841063758
****************************************************************************************************


Epoch 19/20: Training Loss: 2.846381902694702, C-Index: 0.930546841063758
****************************************************************************************************


Final Train Loss: 2.855374336242676, C-Index: 0.9360249046768666


Final Train Loss: 2.855374336242676, C-Index: 0.9360249046768666


Test Loss: 3.278015613555908, C-Index: 0.7666982024597918


Test Loss: 3.278015613555908, C-Index: 0.7666982024597918


Epoch 0/20: Training Loss: 4.260795593261719, C-Index: 0.5996190156706762


Epoch 0/20: Training Loss: 4.260795593261719, C-Index: 0.5996190156706762


Epoch 1/20: Training Loss: 3.989954948425293, C-Index: 0.7665572434945128


Epoch 1/20: Training Loss: 3.989954948425293, C-Index: 0.7665572434945128


Epoch 2/20: Training Loss: 3.869812250137329, C-Index: 0.7903028705611731


Epoch 2/20: Training Loss: 3.869812250137329, C-Index: 0.7903028705611731


Epoch 3/20: Training Loss: 3.7640035152435303, C-Index: 0.8222552355393684


Epoch 3/20: Training Loss: 3.7640035152435303, C-Index: 0.8222552355393684


Epoch 4/20: Training Loss: 3.6837337017059326, C-Index: 0.8376743183016245


Epoch 4/20: Training Loss: 3.6837337017059326, C-Index: 0.8376743183016245


Epoch 5/20: Training Loss: 3.589669942855835, C-Index: 0.8433291800450472


Epoch 5/20: Training Loss: 3.589669942855835, C-Index: 0.8433291800450472


Epoch 6/20: Training Loss: 3.537142515182495, C-Index: 0.8583049791536876


Epoch 6/20: Training Loss: 3.537142515182495, C-Index: 0.8583049791536876


Epoch 7/20: Training Loss: 3.5090131759643555, C-Index: 0.866463794508075


Epoch 7/20: Training Loss: 3.5090131759643555, C-Index: 0.866463794508075


Epoch 8/20: Training Loss: 3.44023060798645, C-Index: 0.8717232951550294


Epoch 8/20: Training Loss: 3.44023060798645, C-Index: 0.8717232951550294


Epoch 9/20: Training Loss: 3.3251190185546875, C-Index: 0.884638424306321


Epoch 9/20: Training Loss: 3.3251190185546875, C-Index: 0.884638424306321


Epoch 10/20: Training Loss: 3.3254482746124268, C-Index: 0.8885560933531413


Epoch 10/20: Training Loss: 3.3254482746124268, C-Index: 0.8885560933531413


Epoch 11/20: Training Loss: 3.2226123809814453, C-Index: 0.8926654526285522


Epoch 11/20: Training Loss: 3.2226123809814453, C-Index: 0.8926654526285522


Epoch 12/20: Training Loss: 3.1889097690582275, C-Index: 0.9113672305554225


Epoch 12/20: Training Loss: 3.1889097690582275, C-Index: 0.9113672305554225


Epoch 13/20: Training Loss: 3.1355011463165283, C-Index: 0.9111875209661188


Epoch 13/20: Training Loss: 3.1355011463165283, C-Index: 0.9111875209661188


Epoch 14/20: Training Loss: 3.0598933696746826, C-Index: 0.9194062395169407


Epoch 14/20: Training Loss: 3.0598933696746826, C-Index: 0.9194062395169407


Epoch 15/20: Training Loss: 3.0465774536132812, C-Index: 0.9185556141275698


Epoch 15/20: Training Loss: 3.0465774536132812, C-Index: 0.9185556141275698


Epoch 16/20: Training Loss: 2.9773664474487305, C-Index: 0.9201610197920161


Epoch 16/20: Training Loss: 2.9773664474487305, C-Index: 0.9201610197920161


Epoch 17/20: Training Loss: 2.9429800510406494, C-Index: 0.9258997460104471


Epoch 17/20: Training Loss: 2.9429800510406494, C-Index: 0.9258997460104471


Epoch 18/20: Training Loss: 2.9185569286346436, C-Index: 0.9284995447357071


Epoch 18/20: Training Loss: 2.9185569286346436, C-Index: 0.9284995447357071


Epoch 19/20: Training Loss: 2.8555781841278076, C-Index: 0.9331719940576029
****************************************************************************************************


Epoch 19/20: Training Loss: 2.8555781841278076, C-Index: 0.9331719940576029
****************************************************************************************************


Final Train Loss: 2.844204902648926, C-Index: 0.9416902285905976


Final Train Loss: 2.844204902648926, C-Index: 0.9416902285905976


Test Loss: 3.456570863723755, C-Index: 0.7471106758080314


Test Loss: 3.456570863723755, C-Index: 0.7471106758080314


Epoch 0/20: Training Loss: 4.275538921356201, C-Index: 0.5945907422584898


Epoch 0/20: Training Loss: 4.275538921356201, C-Index: 0.5945907422584898


Epoch 1/20: Training Loss: 4.058098793029785, C-Index: 0.7344308638895818


Epoch 1/20: Training Loss: 4.058098793029785, C-Index: 0.7344308638895818


Epoch 2/20: Training Loss: 3.9341344833374023, C-Index: 0.7635082968083716


Epoch 2/20: Training Loss: 3.9341344833374023, C-Index: 0.7635082968083716


Epoch 3/20: Training Loss: 3.8427417278289795, C-Index: 0.7985603819976481


Epoch 3/20: Training Loss: 3.8427417278289795, C-Index: 0.7985603819976481


Epoch 4/20: Training Loss: 3.7272441387176514, C-Index: 0.8214018458468446


Epoch 4/20: Training Loss: 3.7272441387176514, C-Index: 0.8214018458468446


Epoch 5/20: Training Loss: 3.6577365398406982, C-Index: 0.8309399090142418


Epoch 5/20: Training Loss: 3.6577365398406982, C-Index: 0.8309399090142418


Epoch 6/20: Training Loss: 3.5728695392608643, C-Index: 0.8482462079369039


Epoch 6/20: Training Loss: 3.5728695392608643, C-Index: 0.8482462079369039


Epoch 7/20: Training Loss: 3.4850895404815674, C-Index: 0.860504341422276


Epoch 7/20: Training Loss: 3.4850895404815674, C-Index: 0.860504341422276


Epoch 8/20: Training Loss: 3.470163106918335, C-Index: 0.8685220159403247


Epoch 8/20: Training Loss: 3.470163106918335, C-Index: 0.8685220159403247


Epoch 9/20: Training Loss: 3.406080961227417, C-Index: 0.8764327881314661


Epoch 9/20: Training Loss: 3.406080961227417, C-Index: 0.8764327881314661


Epoch 10/20: Training Loss: 3.317143201828003, C-Index: 0.8888334580527147


Epoch 10/20: Training Loss: 3.317143201828003, C-Index: 0.8888334580527147


Epoch 11/20: Training Loss: 3.2726314067840576, C-Index: 0.898383399256435


Epoch 11/20: Training Loss: 3.2726314067840576, C-Index: 0.898383399256435


Epoch 12/20: Training Loss: 3.1999595165252686, C-Index: 0.9092161683830429


Epoch 12/20: Training Loss: 3.1999595165252686, C-Index: 0.9092161683830429


Epoch 13/20: Training Loss: 3.191422700881958, C-Index: 0.909584387509057


Epoch 13/20: Training Loss: 3.191422700881958, C-Index: 0.909584387509057


Epoch 14/20: Training Loss: 3.126690149307251, C-Index: 0.9127795792799535


Epoch 14/20: Training Loss: 3.126690149307251, C-Index: 0.9127795792799535


Epoch 15/20: Training Loss: 3.1163415908813477, C-Index: 0.9160816733777571


Epoch 15/20: Training Loss: 3.1163415908813477, C-Index: 0.9160816733777571


Epoch 16/20: Training Loss: 3.030937910079956, C-Index: 0.9249189324020953


Epoch 16/20: Training Loss: 3.030937910079956, C-Index: 0.9249189324020953


Epoch 17/20: Training Loss: 3.031996726989746, C-Index: 0.9219494233213366


Epoch 17/20: Training Loss: 3.031996726989746, C-Index: 0.9219494233213366


Epoch 18/20: Training Loss: 2.98557448387146, C-Index: 0.9291118792241266


Epoch 18/20: Training Loss: 2.98557448387146, C-Index: 0.9291118792241266


Epoch 19/20: Training Loss: 2.9041826725006104, C-Index: 0.9368088467614534
****************************************************************************************************


Epoch 19/20: Training Loss: 2.9041826725006104, C-Index: 0.9368088467614534
****************************************************************************************************


Final Train Loss: 2.9003746509552, C-Index: 0.9465250804736961


Final Train Loss: 2.9003746509552, C-Index: 0.9465250804736961


Test Loss: 3.65384578704834, C-Index: 0.7943736085812588


Test Loss: 3.65384578704834, C-Index: 0.7943736085812588


Epoch 0/20: Training Loss: 4.23887300491333, C-Index: 0.5789870584801012


Epoch 0/20: Training Loss: 4.23887300491333, C-Index: 0.5789870584801012


Epoch 1/20: Training Loss: 4.000083923339844, C-Index: 0.7275712756641043


Epoch 1/20: Training Loss: 4.000083923339844, C-Index: 0.7275712756641043


Epoch 2/20: Training Loss: 3.887723922729492, C-Index: 0.767417534299893


Epoch 2/20: Training Loss: 3.887723922729492, C-Index: 0.767417534299893


Epoch 3/20: Training Loss: 3.7435786724090576, C-Index: 0.7924369952320716


Epoch 3/20: Training Loss: 3.7435786724090576, C-Index: 0.7924369952320716


Epoch 4/20: Training Loss: 3.682823896408081, C-Index: 0.8084192857837891


Epoch 4/20: Training Loss: 3.682823896408081, C-Index: 0.8084192857837891


Epoch 5/20: Training Loss: 3.6254701614379883, C-Index: 0.822856864843826


Epoch 5/20: Training Loss: 3.6254701614379883, C-Index: 0.822856864843826


Epoch 6/20: Training Loss: 3.5590362548828125, C-Index: 0.8424394278485939


Epoch 6/20: Training Loss: 3.5590362548828125, C-Index: 0.8424394278485939


Epoch 7/20: Training Loss: 3.4767444133758545, C-Index: 0.8496399727546949


Epoch 7/20: Training Loss: 3.4767444133758545, C-Index: 0.8496399727546949


Epoch 8/20: Training Loss: 3.4670979976654053, C-Index: 0.8593826019266323


Epoch 8/20: Training Loss: 3.4670979976654053, C-Index: 0.8593826019266323


Epoch 9/20: Training Loss: 3.3895041942596436, C-Index: 0.8700252992118322


Epoch 9/20: Training Loss: 3.3895041942596436, C-Index: 0.8700252992118322


Epoch 10/20: Training Loss: 3.3113410472869873, C-Index: 0.8766298530699621


Epoch 10/20: Training Loss: 3.3113410472869873, C-Index: 0.8766298530699621


Epoch 11/20: Training Loss: 3.2960147857666016, C-Index: 0.880388245596964


Epoch 11/20: Training Loss: 3.2960147857666016, C-Index: 0.880388245596964


Epoch 12/20: Training Loss: 3.2484045028686523, C-Index: 0.8921864357302715


Epoch 12/20: Training Loss: 3.2484045028686523, C-Index: 0.8921864357302715


Epoch 13/20: Training Loss: 3.1844842433929443, C-Index: 0.895397489539749


Epoch 13/20: Training Loss: 3.1844842433929443, C-Index: 0.895397489539749


Epoch 14/20: Training Loss: 3.1444082260131836, C-Index: 0.8973800720054491


Epoch 14/20: Training Loss: 3.1444082260131836, C-Index: 0.8973800720054491


Epoch 15/20: Training Loss: 3.090390920639038, C-Index: 0.9073537997470079


Epoch 15/20: Training Loss: 3.090390920639038, C-Index: 0.9073537997470079


Epoch 16/20: Training Loss: 3.0528109073638916, C-Index: 0.916147708475236


Epoch 16/20: Training Loss: 3.0528109073638916, C-Index: 0.916147708475236


Epoch 17/20: Training Loss: 2.9991447925567627, C-Index: 0.9181302909409361


Epoch 17/20: Training Loss: 2.9991447925567627, C-Index: 0.9181302909409361


Epoch 18/20: Training Loss: 2.9648277759552, C-Index: 0.9216819110635399


Epoch 18/20: Training Loss: 2.9648277759552, C-Index: 0.9216819110635399


Epoch 19/20: Training Loss: 2.945108652114868, C-Index: 0.9289919237131459
****************************************************************************************************


Epoch 19/20: Training Loss: 2.945108652114868, C-Index: 0.9289919237131459
****************************************************************************************************


Final Train Loss: 2.892979383468628, C-Index: 0.9331516979663326


Final Train Loss: 2.892979383468628, C-Index: 0.9331516979663326


Test Loss: 3.5360405445098877, C-Index: 0.7963509030593439


Test Loss: 3.5360405445098877, C-Index: 0.7963509030593439
