In [2]:
import os


def result_file_name(args):
    """Build the results directory name for a run configuration.

    The name is ``results_<fold_range>_<seed>``, followed by every entry of
    ``args.modalities`` when that sequence is non-empty; all pieces are joined
    with underscores.

    Args:
        args: Object exposing ``fold_range``, ``seed`` and ``modalities``
            (a sized sequence of strings).

    Returns:
        The assembled name as a ``str``.
    """
    pieces = [f"results_{args.fold_range}_{args.seed}"]
    if len(args.modalities) > 0:
        pieces.append("_".join(args.modalities))
    return "_".join(pieces)

class Args:
    """Plain configuration container for one training/evaluation run.

    Args:
        modalities: Sequence of modality names to use. When omitted (or
            ``None``) all four modalities are selected:
            ``clinical``, ``miRNA``, ``mRNA`` and ``WSI``.

    Side effects:
        Prints the resolved ``result_path`` on construction.
    """

    def __init__(self, modalities=None):
        # Build the default inside the call: a list literal used directly as
        # the default value is created once and shared by every Args()
        # instance, so mutating one instance's `modalities` would leak into
        # all later ones.
        if modalities is None:
            modalities = ["clinical", "miRNA", "mRNA", "WSI"]

        self.epochs = 20
        # Location of the preprocessed input for each modality.
        self.modality_data_path = {'clinical': './preprocess/preprocessed_data/clinical_kidney.csv',
                                    'mRNA': './preprocess/preprocessed_data/mrna_kidney.csv',
                                    'miRNA': './preprocess/preprocessed_data/mirna_kidney.csv',
                                    'WSI': './preprocess/preprocessed_data/UNI2_features/TCGA-Kidney.pt'}
        self.device = "cuda"
        self.modalities = modalities
        # Handed to EncoderModel together with `modality_fv_len`;
        # presumably the raw feature width of each modality -- confirm against utils.encoder.
        self.input_modality_dim = {'clinical': 4, 'mRNA': 2746, 'miRNA': 743, 'WSI': 1536}
        self.fold_range = 5   # number of folds iterated by train_eval_model
        self.fold = 1         # overwritten per fold by train_eval_model
        self.modality_fv_len = 128
        self.batch_size = 128
        self.learning_rate = 1e-4
        self.loss_trade_off = 0.3
        self.loss_mode = 'total'
        self.seed = 24
        # Depends on fold_range, seed and modalities, so it must be computed
        # after those attributes are set.
        self.result_path = f"./logs/graph_att/{result_file_name(self)}"
        print(f"Result Path: {self.result_path}")
        self.log_file_name = "temp.log"  # overwritten per fold by train_eval_model
        self.scheduler_patience = 5
        self.num_workers = 4
        self.num_modalities = len(self.modalities)
        self.split_path = "./splits/kidney_splits"
        self.remove_missing = True

# Default configuration (all four modalities); constructing it also prints
# the resolved result path.
args = Args()

Result Path: ./logs/graph_att/results_5_24_clinical_miRNA_mRNA_WSI


In [3]:
import torch
import torch.nn as nn
from torch_geometric.nn import GAT
from torch_geometric.utils import from_scipy_sparse_matrix, to_scipy_sparse_matrix
import scipy.sparse as sp
from torch_geometric.nn import GATv2Conv
import torch.nn.functional as F


class GraphAttention(nn.Module):
    """Single ``GATv2Conv`` layer applied over a fully connected modality graph.

    Each modality is one node. The adjacency matrix is all ones, so every node
    is connected to every node, itself included.

    Args:
        num_modalities: Number of graph nodes (one per modality).
        modality_dim: Input feature size of each node.
        hidden_dim: Output feature size of the attention layer.
        output_dim: Stored on the instance; not used by any layer in this class.
        device: Device on which the edge index is initially placed.
    """

    def __init__(self, num_modalities, modality_dim, hidden_dim, output_dim, device):
        super(GraphAttention, self).__init__()
        self.input_dim = modality_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.gat1 = GATv2Conv(self.input_dim, self.hidden_dim)
        self.num_modalities = num_modalities

        adjacency_matrix = torch.ones([self.num_modalities, self.num_modalities], dtype=torch.float)
        self.adjacency_matrix = sp.coo_matrix(adjacency_matrix.numpy())
        edge_index, _ = from_scipy_sparse_matrix(self.adjacency_matrix)
        # Register as a buffer so the edge index follows the module through
        # .to()/.cuda()/.cpu() instead of staying pinned to `device`.
        # persistent=False keeps it out of state_dict, so the checkpoint
        # layout is unchanged.
        self.register_buffer("edge_index", edge_index.to(device), persistent=False)

    def forward(self, x):
        """Apply the attention layer to every element of ``x`` independently.

        Args:
            x: Iterable of node-feature tensors, one per sample. ``AggModel``
                passes the per-modality representations stacked along dim 1,
                so iterating ``x`` yields one sample at a time.

        Returns:
            A tuple ``(features, attns)``: ``features`` is the stack of the
            per-sample layer outputs, and ``attns`` is a list with one dense
            ``num_modalities x num_modalities`` attention matrix per sample
            (detached, on CPU).
        """
        result = []
        attns = []
        for data in x:
            x_fv, att_w = self.gat1(data, self.edge_index, return_attention_weights=True)
            # att_w[0] is the edge index, att_w[1] the attention coefficients.
            adj1 = to_scipy_sparse_matrix(
                att_w[0].detach().cpu(), att_w[1].detach().cpu(), self.num_modalities
            )
            result.append(x_fv)
            attns.append(adj1.todense())
        return torch.stack(result), attns

In [4]:
from torch import nn


class AggModel(nn.Module):
    """Encode each modality, fuse them with graph attention, predict hazard and label.

    Args:
        n_views: Number of modalities.
        n_in_feats: Feature length of each encoded modality.
        encoder_model: Module called on the raw modality input; iterating its
            result must yield one representation tensor per modality.
        device: Device forwarded to ``GraphAttention``.
    """

    def __init__(self, n_views, n_in_feats, encoder_model, device):
        super(AggModel, self).__init__()
        # The fused representation is flattened to n_views * n_in_feats per sample.
        self.output_dim = n_in_feats * n_views
        self.modality_count = n_views
        self.modality_fv_len = n_in_feats

        self.encoder_model = encoder_model
        self.fusion = GraphAttention(self.modality_count, self.modality_fv_len, self.modality_fv_len, self.modality_fv_len, device)
        self.hazard_layer = nn.Linear(self.output_dim, 1)
        self.label_layer = nn.Linear(self.output_dim, 2)

    def to_device(self, device):
        """Move the model to ``device``."""
        self.to(device)

    def to_train(self):
        """Switch the model to training mode."""
        self.train()

    def to_eval(self):
        """Switch the model to evaluation mode."""
        self.eval()

    def forward(self, x_modality, mask):
        """Run the full model on one batch.

        Args:
            x_modality: Input passed unchanged to ``encoder_model``.
            mask: 2-D tensor; column ``i`` is multiplied row-wise into the
                representation of modality ``i`` (presumably 1 = present,
                0 = missing -- confirm against the dataset).

        Returns:
            A tuple ``(outputs, representation_dict)``: ``outputs`` is
            ``{'hazard': ..., 'score': ...}``, where ``score`` is the
            log-softmax of the label layer, and ``representation_dict`` maps
            modality index to its masked representation.
        """
        representation = self.encoder_model(x_modality)
        # Scale each modality's representation by its per-sample mask column.
        missing_rep = [
            mask[:, i].reshape(-1, 1) * rep
            for i, rep in enumerate(representation)
        ]

        # Used for the self-supervised loss.
        representation_dict = dict(enumerate(missing_rep))

        final_representation, _ = self.fusion(torch.stack(missing_rep, 1))
        final_representation = final_representation.reshape(-1, self.output_dim)

        hazard = self.hazard_layer(final_representation)
        score = F.log_softmax(self.label_layer(final_representation), dim=1)
        return {'hazard': hazard, 'score': score}, representation_dict

In [5]:
from lifelines.utils import concordance_index


def print_log(msg, logger):
    """Print ``msg`` to stdout and also record it on ``logger`` at INFO level.

    Args:
        msg: Text to emit.
        logger: Object exposing ``info(msg)`` (e.g. a ``logging.Logger``).
    """
    print(msg)
    logger.info(msg)

def evaluation(model, dataloader, loss_fucnt, device):
    """Score ``model`` on ``dataloader`` without updating its weights.

    Args:
        model: Model exposing ``to_eval()`` and returning
            ``(outputs, representation)`` when called with
            ``(x_modal, mask)``; ``outputs['hazard']`` is the risk prediction.
        dataloader: Yields ``(x_modal, data_label, mask)`` batches, where
            column 0 of ``data_label`` is read as the event and column 1 as
            the time.
        loss_fucnt: Callable invoked as
            ``loss_fucnt(representation, keys, outputs, event, time)``.
        device: Device every batch tensor is moved to.

    Returns:
        ``(eval_loss, eval_c_index)``: the mean of the per-batch losses (a
        0-dim tensor) and the concordance index over all samples.
    """
    batch_losses = []
    all_times = torch.FloatTensor().to(device)
    all_events = torch.LongTensor().to(device)
    all_hazards = torch.FloatTensor().to(device)

    model.to_eval()
    with torch.no_grad():
        for x_modal, data_label, mask in dataloader:
            # Move every modality tensor of the batch onto the target device.
            for name in x_modal:
                x_modal[name] = x_modal[name].to(device)
            event = data_label[:, 0].to(device)
            time = data_label[:, 1].to(device)

            outputs, representation = model(x_modal, mask.to(device))
            batch_loss = loss_fucnt(
                representation, list(representation.keys()), outputs, event, time
            )
            batch_losses.append(batch_loss.item())

            all_times = torch.cat((all_times, time.data.float()))
            all_events = torch.cat((all_events, event.long().data))
            all_hazards = torch.cat((all_hazards, outputs['hazard'].detach()))

    eval_loss = torch.tensor(batch_losses).mean()
    # The hazard is negated before being passed as the predicted score.
    eval_c_index = concordance_index(
        all_times.cpu().numpy(),
        -all_hazards.cpu().numpy(),
        all_events.cpu().numpy(),
    )
    return eval_loss, eval_c_index

In [6]:
import copy
import random
import logging 
import numpy as np
from utils.loss import Loss
from torch.optim import Adam
from torch.utils.data import DataLoader
from utils.dataset import MultiModalDataset
from utils.encoder import EncoderModel
    

def save_model(model, path):
    """Write the ``state_dict`` of ``model`` to ``path`` with ``torch.save``."""
    weights = model.state_dict()
    torch.save(weights, path)

def set_seed(seed):
    """Seed every random number generator used here and pin cuDNN settings.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED``, and sets cuDNN to ``benchmark = False`` and
    ``deterministic = True``.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_rng in seeders:
        seed_rng(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
        
def _train_one_epoch(model, dataloader, loss_fucnt, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Returns:
        ``(train_loss, train_c_index)``: the mean of the per-batch losses (a
        0-dim tensor) and the concordance index over every sample seen in the
        epoch.
    """
    batch_losses, times, events, hazards = [], [], [], []

    model.to_train()
    for x_modal, data_label, mask in dataloader:
        for modality in x_modal:
            x_modal[modality] = x_modal[modality].to(device)
        event = data_label[:, 0].to(device)
        time = data_label[:, 1].to(device)
        mask = mask.to(device)

        outputs, representation = model(x_modal, mask)
        loss = loss_fucnt(representation, list(representation.keys()), outputs, event, time)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())
        times.append(time.detach().float())
        events.append(event.detach().long())
        hazards.append(outputs['hazard'].detach())

    train_loss = torch.mean(torch.tensor(batch_losses))
    # The hazard is negated before being passed as the predicted score.
    train_c_index = concordance_index(
        torch.cat(times).cpu().numpy(),
        -torch.cat(hazards).cpu().numpy(),
        torch.cat(events).cpu().numpy(),
    )
    return train_loss, train_c_index


def train_eval_model(args):
    """Train and evaluate one model per cross-validation fold.

    For every fold ``k`` in ``range(args.fold_range)`` this:

    1. builds the train and test datasets/loaders,
    2. trains a fresh ``AggModel`` for ``args.epochs`` epochs,
    3. restores the weights of the epoch with the highest training C-index,
    4. saves them to ``<args.result_path>/model_fold<k>.pt``,
    5. logs the final train and test loss / C-index to
       ``<args.result_path>/result_fold<k>.log`` and prints them.

    Args:
        args: Configuration object (see ``Args``).

    Side effects:
        Mutates ``args.fold`` and ``args.log_file_name``; creates
        ``args.result_path``; writes one log file and one checkpoint per fold.
    """
    set_seed(args.seed)
    os.makedirs(args.result_path, exist_ok=True)

    # Dedicated, non-propagating logger: print_log() already echoes each
    # message to stdout, so routing it through the root logger as well would
    # show every line twice on the console (and would force every library's
    # logger to DEBUG).
    logger = logging.getLogger("train_eval_model")
    logger.setLevel(logging.DEBUG)
    logger.propagate = False

    for k in range(args.fold_range):
        args.fold = k
        args.log_file_name = os.path.join(args.result_path, f"result_fold{args.fold}.log")

        # mode='w' truncates any log left over from a previous run of this fold.
        handler = logging.FileHandler(args.log_file_name, mode='w')
        logger.addHandler(handler)
        try:
            train_dataset = MultiModalDataset(args, 'train', args.modalities, args.modality_data_path, remove_missing=args.remove_missing)
            train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
            test_dataset = MultiModalDataset(args, 'test', args.modalities, args.modality_data_path)
            test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

            device = torch.device(args.device)
            encoder_model = EncoderModel(args.modalities, args.modality_fv_len, args.input_modality_dim)
            model = AggModel(args.num_modalities, args.modality_fv_len, encoder_model, device)
            # Move the model before building the optimizer over its parameters.
            model.to_device(device)
            optimizer = Adam(model.parameters(), lr=args.learning_rate)
            # The deprecated `verbose` flag is not passed; learning-rate
            # reductions are reported explicitly below.
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer=optimizer, mode='max', factor=0.1,
                patience=args.scheduler_patience, threshold=1e-3,
                threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-8)

            loss_fucnt = Loss(trade_off=args.loss_trade_off, mode=args.loss_mode)

            # Snapshot up front so `best_model` exists even when args.epochs == 0.
            best_c_index = float("-inf")
            best_model = copy.deepcopy(model.state_dict())

            for epoch in range(args.epochs):
                train_loss, train_c_index = _train_one_epoch(
                    model, train_dataloader, loss_fucnt, optimizer, device
                )
                print_log(f"Epoch {epoch}/{args.epochs}: Training Loss: {train_loss}, C-Index: {train_c_index}", logger)

                # Report a reduction only when the scheduler actually lowered
                # the learning rate on this step.
                lrs_before = [group['lr'] for group in optimizer.param_groups]
                scheduler.step(train_c_index)
                lrs_after = [group['lr'] for group in optimizer.param_groups]
                if any(after < before for before, after in zip(lrs_before, lrs_after)):
                    print_log("Reducing the learning rate", logger)

                # Keep the weights of the best epoch, not simply the last one.
                if train_c_index > best_c_index:
                    best_c_index = train_c_index
                    best_model = copy.deepcopy(model.state_dict())
            print_log("*" * 100, logger)

            model.load_state_dict(best_model)
            model_save_path = os.path.join(args.result_path, f"model_fold{args.fold}.pt")
            save_model(model, model_save_path)
            train_loss, train_c_index = evaluation(model, train_dataloader, loss_fucnt, device)
            print_log(f"Final Train Loss: {train_loss}, C-Index: {train_c_index}", logger)
            test_loss, test_c_index = evaluation(model, test_dataloader, loss_fucnt, device)
            print_log(f"Test Loss: {test_loss}, C-Index: {test_c_index}", logger)
        finally:
            # Always detach and close the per-fold file handler, even if the
            # fold raised, so later folds/runs do not write into this file and
            # the file descriptor is released.
            logger.removeHandler(handler)
            handler.close()

In [7]:
import numpy as np


# Pick the modalities to train on with a bitmask over `all_modalities`:
# character k of `selection` (left to right) is '1' when entry k is kept.
all_modalities = np.array(["clinical", "miRNA", "mRNA", "WSI"])
i = (1 << len(all_modalities)) - 1  # every bit set -> every modality selected
selection = format(i, f"0{len(all_modalities)}b")
modality_select = all_modalities[np.array(list(selection)) == '1']
args = Args(modality_select)
train_eval_model(args)

Result Path: ./logs/graph_att/results_5_24_clinical_miRNA_mRNA_WSI


Epoch 0/20: Training Loss: 4.316524505615234, C-Index: 0.5681106941163775


Epoch 0/20: Training Loss: 4.316524505615234, C-Index: 0.5681106941163775


Epoch 1/20: Training Loss: 4.1388020515441895, C-Index: 0.6819702401649861


Epoch 1/20: Training Loss: 4.1388020515441895, C-Index: 0.6819702401649861


Epoch 2/20: Training Loss: 4.010704517364502, C-Index: 0.7266699440054676


Epoch 2/20: Training Loss: 4.010704517364502, C-Index: 0.7266699440054676


Epoch 3/20: Training Loss: 3.797489881515503, C-Index: 0.7818731190273498


Epoch 3/20: Training Loss: 3.797489881515503, C-Index: 0.7818731190273498


Epoch 4/20: Training Loss: 3.7805306911468506, C-Index: 0.7904941187755542


Epoch 4/20: Training Loss: 3.7805306911468506, C-Index: 0.7904941187755542


Epoch 5/20: Training Loss: 3.5916805267333984, C-Index: 0.8147144518650855


Epoch 5/20: Training Loss: 3.5916805267333984, C-Index: 0.8147144518650855


Epoch 6/20: Training Loss: 3.5453546047210693, C-Index: 0.8322322274313257


Epoch 6/20: Training Loss: 3.5453546047210693, C-Index: 0.8322322274313257


Epoch 7/20: Training Loss: 3.4979381561279297, C-Index: 0.8485509766069951


Epoch 7/20: Training Loss: 3.4979381561279297, C-Index: 0.8485509766069951


Epoch 8/20: Training Loss: 3.362492322921753, C-Index: 0.8624237119459


Epoch 8/20: Training Loss: 3.362492322921753, C-Index: 0.8624237119459


Epoch 9/20: Training Loss: 3.3396222591400146, C-Index: 0.8631311375163367


Epoch 9/20: Training Loss: 3.3396222591400146, C-Index: 0.8631311375163367


Epoch 10/20: Training Loss: 3.246286153793335, C-Index: 0.8768120286327502


Epoch 10/20: Training Loss: 3.246286153793335, C-Index: 0.8768120286327502


Epoch 11/20: Training Loss: 3.278798818588257, C-Index: 0.8769679020635244


Epoch 11/20: Training Loss: 3.278798818588257, C-Index: 0.8769679020635244


Epoch 12/20: Training Loss: 3.163625717163086, C-Index: 0.8964640711742066


Epoch 12/20: Training Loss: 3.163625717163086, C-Index: 0.8964640711742066


Epoch 13/20: Training Loss: 3.0641863346099854, C-Index: 0.8975551851896261


Epoch 13/20: Training Loss: 3.0641863346099854, C-Index: 0.8975551851896261


Epoch 14/20: Training Loss: 2.9544780254364014, C-Index: 0.9138499538374839


Epoch 14/20: Training Loss: 2.9544780254364014, C-Index: 0.9138499538374839


Epoch 15/20: Training Loss: 2.8956356048583984, C-Index: 0.9158403376458316


Epoch 15/20: Training Loss: 2.8956356048583984, C-Index: 0.9158403376458316


Epoch 16/20: Training Loss: 2.9169158935546875, C-Index: 0.9167156269109483


Epoch 16/20: Training Loss: 2.9169158935546875, C-Index: 0.9167156269109483


Epoch 17/20: Training Loss: 2.8380634784698486, C-Index: 0.9272550688840662


Epoch 17/20: Training Loss: 2.8380634784698486, C-Index: 0.9272550688840662


Epoch 18/20: Training Loss: 2.7406044006347656, C-Index: 0.9308521480557788


Epoch 18/20: Training Loss: 2.7406044006347656, C-Index: 0.9308521480557788


Epoch 19/20: Training Loss: 2.680767059326172, C-Index: 0.9385618877471493
****************************************************************************************************


Epoch 19/20: Training Loss: 2.680767059326172, C-Index: 0.9385618877471493
****************************************************************************************************


Final Train Loss: 2.6120760440826416, C-Index: 0.9497128331794583


Final Train Loss: 2.6120760440826416, C-Index: 0.9497128331794583


Test Loss: 4.006721496582031, C-Index: 0.7250584567420109


Test Loss: 4.006721496582031, C-Index: 0.7250584567420109


Epoch 0/20: Training Loss: 4.257290363311768, C-Index: 0.5745933684058111


Epoch 0/20: Training Loss: 4.257290363311768, C-Index: 0.5745933684058111


Epoch 1/20: Training Loss: 4.01749849319458, C-Index: 0.7326729089241759


Epoch 1/20: Training Loss: 4.01749849319458, C-Index: 0.7326729089241759


Epoch 2/20: Training Loss: 3.8389275074005127, C-Index: 0.7761474974660939


Epoch 2/20: Training Loss: 3.8389275074005127, C-Index: 0.7761474974660939


Epoch 3/20: Training Loss: 3.811734199523926, C-Index: 0.7819875476615666


Epoch 3/20: Training Loss: 3.811734199523926, C-Index: 0.7819875476615666


Epoch 4/20: Training Loss: 3.6715657711029053, C-Index: 0.7950793957237319


Epoch 4/20: Training Loss: 3.6715657711029053, C-Index: 0.7950793957237319


Epoch 5/20: Training Loss: 3.5873501300811768, C-Index: 0.8262464404652734


Epoch 5/20: Training Loss: 3.5873501300811768, C-Index: 0.8262464404652734


Epoch 6/20: Training Loss: 3.586930513381958, C-Index: 0.8307592065254115


Epoch 6/20: Training Loss: 3.586930513381958, C-Index: 0.8307592065254115


Epoch 7/20: Training Loss: 3.501821756362915, C-Index: 0.845335199575269


Epoch 7/20: Training Loss: 3.501821756362915, C-Index: 0.845335199575269


Epoch 8/20: Training Loss: 3.376060724258423, C-Index: 0.8564481876538443


Epoch 8/20: Training Loss: 3.376060724258423, C-Index: 0.8564481876538443


Epoch 9/20: Training Loss: 3.333177328109741, C-Index: 0.858849365316859


Epoch 9/20: Training Loss: 3.333177328109741, C-Index: 0.858849365316859


Epoch 10/20: Training Loss: 3.2423837184906006, C-Index: 0.8778295284521453


Epoch 10/20: Training Loss: 3.2423837184906006, C-Index: 0.8778295284521453


Epoch 11/20: Training Loss: 3.2370645999908447, C-Index: 0.8821371687822771


Epoch 11/20: Training Loss: 3.2370645999908447, C-Index: 0.8821371687822771


Epoch 12/20: Training Loss: 3.1183292865753174, C-Index: 0.9037115690911723


Epoch 12/20: Training Loss: 3.1183292865753174, C-Index: 0.9037115690911723


Epoch 13/20: Training Loss: 3.0957868099212646, C-Index: 0.9047733963994401


Epoch 13/20: Training Loss: 3.0957868099212646, C-Index: 0.9047733963994401


Epoch 14/20: Training Loss: 2.993769884109497, C-Index: 0.9119407307302476


Epoch 14/20: Training Loss: 2.993769884109497, C-Index: 0.9119407307302476


Epoch 15/20: Training Loss: 2.9648077487945557, C-Index: 0.9174670592210049


Epoch 15/20: Training Loss: 2.9648077487945557, C-Index: 0.9174670592210049


Epoch 16/20: Training Loss: 2.858250617980957, C-Index: 0.929472947536078


Epoch 16/20: Training Loss: 2.858250617980957, C-Index: 0.929472947536078


Epoch 17/20: Training Loss: 2.7491683959960938, C-Index: 0.9317655292243834


Epoch 17/20: Training Loss: 2.7491683959960938, C-Index: 0.9317655292243834


Epoch 18/20: Training Loss: 2.754342794418335, C-Index: 0.9274337564554274


Epoch 18/20: Training Loss: 2.754342794418335, C-Index: 0.9274337564554274


Epoch 19/20: Training Loss: 2.6551244258880615, C-Index: 0.9423958685264733
****************************************************************************************************


Epoch 19/20: Training Loss: 2.6551244258880615, C-Index: 0.9423958685264733
****************************************************************************************************


Final Train Loss: 2.578139543533325, C-Index: 0.950214778705536


Final Train Loss: 2.578139543533325, C-Index: 0.950214778705536


Test Loss: 3.388868570327759, C-Index: 0.7385052034058657


Test Loss: 3.388868570327759, C-Index: 0.7385052034058657


Epoch 0/20: Training Loss: 4.317249298095703, C-Index: 0.551804284276609


Epoch 0/20: Training Loss: 4.317249298095703, C-Index: 0.551804284276609


Epoch 1/20: Training Loss: 4.117033004760742, C-Index: 0.6885872430152873


Epoch 1/20: Training Loss: 4.117033004760742, C-Index: 0.6885872430152873


Epoch 2/20: Training Loss: 4.017776966094971, C-Index: 0.7256193990511334


Epoch 2/20: Training Loss: 4.017776966094971, C-Index: 0.7256193990511334


Epoch 3/20: Training Loss: 3.8795764446258545, C-Index: 0.7765371160205109


Epoch 3/20: Training Loss: 3.8795764446258545, C-Index: 0.7765371160205109


Epoch 4/20: Training Loss: 3.767591714859009, C-Index: 0.8045837925911726


Epoch 4/20: Training Loss: 3.767591714859009, C-Index: 0.8045837925911726


Epoch 5/20: Training Loss: 3.689465284347534, C-Index: 0.8174390185460296


Epoch 5/20: Training Loss: 3.689465284347534, C-Index: 0.8174390185460296


Epoch 6/20: Training Loss: 3.6011579036712646, C-Index: 0.8333493075190492


Epoch 6/20: Training Loss: 3.6011579036712646, C-Index: 0.8333493075190492


Epoch 7/20: Training Loss: 3.4957094192504883, C-Index: 0.8474625005990319


Epoch 7/20: Training Loss: 3.4957094192504883, C-Index: 0.8474625005990319


Epoch 8/20: Training Loss: 3.383037805557251, C-Index: 0.8662840849187713


Epoch 8/20: Training Loss: 3.383037805557251, C-Index: 0.8662840849187713


Epoch 9/20: Training Loss: 3.35684871673584, C-Index: 0.8732448363444674


Epoch 9/20: Training Loss: 3.35684871673584, C-Index: 0.8732448363444674


Epoch 10/20: Training Loss: 3.2213470935821533, C-Index: 0.8941630325394163


Epoch 10/20: Training Loss: 3.2213470935821533, C-Index: 0.8941630325394163


Epoch 11/20: Training Loss: 3.125552177429199, C-Index: 0.9027052283509848


Epoch 11/20: Training Loss: 3.125552177429199, C-Index: 0.9027052283509848


Epoch 12/20: Training Loss: 3.0706565380096436, C-Index: 0.9085997028801457


Epoch 12/20: Training Loss: 3.0706565380096436, C-Index: 0.9085997028801457


Epoch 13/20: Training Loss: 2.983626127243042, C-Index: 0.9105525470839124


Epoch 13/20: Training Loss: 2.983626127243042, C-Index: 0.9105525470839124


Epoch 14/20: Training Loss: 2.931974172592163, C-Index: 0.9195739684669574


Epoch 14/20: Training Loss: 2.931974172592163, C-Index: 0.9195739684669574


Epoch 15/20: Training Loss: 2.935678243637085, C-Index: 0.9184837302918484


Epoch 15/20: Training Loss: 2.935678243637085, C-Index: 0.9184837302918484


Epoch 16/20: Training Loss: 2.7195374965667725, C-Index: 0.9340945032826952


Epoch 16/20: Training Loss: 2.7195374965667725, C-Index: 0.9340945032826952


Epoch 17/20: Training Loss: 2.713047981262207, C-Index: 0.9348492835577706


Epoch 17/20: Training Loss: 2.713047981262207, C-Index: 0.9348492835577706


Epoch 18/20: Training Loss: 2.6555168628692627, C-Index: 0.9410432740691044


Epoch 18/20: Training Loss: 2.6555168628692627, C-Index: 0.9410432740691044


Epoch 19/20: Training Loss: 2.5790631771087646, C-Index: 0.9446734077730388
****************************************************************************************************


Epoch 19/20: Training Loss: 2.5790631771087646, C-Index: 0.9446734077730388
****************************************************************************************************


Final Train Loss: 2.5234458446502686, C-Index: 0.9560430344563186


Final Train Loss: 2.5234458446502686, C-Index: 0.9560430344563186


Test Loss: 4.009307861328125, C-Index: 0.7561214495592556


Test Loss: 4.009307861328125, C-Index: 0.7561214495592556


Epoch 0/20: Training Loss: 4.2684245109558105, C-Index: 0.5804677570704011


Epoch 0/20: Training Loss: 4.2684245109558105, C-Index: 0.5804677570704011


Epoch 1/20: Training Loss: 4.121853351593018, C-Index: 0.6903633491311216


Epoch 1/20: Training Loss: 4.121853351593018, C-Index: 0.6903633491311216


Epoch 2/20: Training Loss: 3.9126789569854736, C-Index: 0.766335269453254


Epoch 2/20: Training Loss: 3.9126789569854736, C-Index: 0.766335269453254


Epoch 3/20: Training Loss: 3.8638553619384766, C-Index: 0.7664896839254535


Epoch 3/20: Training Loss: 3.8638553619384766, C-Index: 0.7664896839254535


Epoch 4/20: Training Loss: 3.764188766479492, C-Index: 0.7986791623608785


Epoch 4/20: Training Loss: 3.764188766479492, C-Index: 0.7986791623608785


Epoch 5/20: Training Loss: 3.651736259460449, C-Index: 0.8190737507275297


Epoch 5/20: Training Loss: 3.651736259460449, C-Index: 0.8190737507275297


Epoch 6/20: Training Loss: 3.564105272293091, C-Index: 0.8354891969259642


Epoch 6/20: Training Loss: 3.564105272293091, C-Index: 0.8354891969259642


Epoch 7/20: Training Loss: 3.531994104385376, C-Index: 0.8441364073691338


Epoch 7/20: Training Loss: 3.531994104385376, C-Index: 0.8441364073691338


Epoch 8/20: Training Loss: 3.444305419921875, C-Index: 0.8597916592428939


Epoch 8/20: Training Loss: 3.444305419921875, C-Index: 0.8597916592428939


Epoch 9/20: Training Loss: 3.335590362548828, C-Index: 0.8729881575977859


Epoch 9/20: Training Loss: 3.335590362548828, C-Index: 0.8729881575977859


Epoch 10/20: Training Loss: 3.2473490238189697, C-Index: 0.88079202746202


Epoch 10/20: Training Loss: 3.2473490238189697, C-Index: 0.88079202746202


Epoch 11/20: Training Loss: 3.2076311111450195, C-Index: 0.897017425079286


Epoch 11/20: Training Loss: 3.2076311111450195, C-Index: 0.897017425079286


Epoch 12/20: Training Loss: 3.1295278072357178, C-Index: 0.9038591740013541


Epoch 12/20: Training Loss: 3.1295278072357178, C-Index: 0.9038591740013541


Epoch 13/20: Training Loss: 3.0444834232330322, C-Index: 0.9128864816068607


Epoch 13/20: Training Loss: 3.0444834232330322, C-Index: 0.9128864816068607


Epoch 14/20: Training Loss: 2.942018508911133, C-Index: 0.9173169891553529


Epoch 14/20: Training Loss: 2.942018508911133, C-Index: 0.9173169891553529


Epoch 15/20: Training Loss: 2.8744771480560303, C-Index: 0.9253584197460476


Epoch 15/20: Training Loss: 2.8744771480560303, C-Index: 0.9253584197460476


Epoch 16/20: Training Loss: 2.8544235229492188, C-Index: 0.9284348311537136


Epoch 16/20: Training Loss: 2.8544235229492188, C-Index: 0.9284348311537136


Epoch 17/20: Training Loss: 2.725858688354492, C-Index: 0.9324852415398687


Epoch 17/20: Training Loss: 2.725858688354492, C-Index: 0.9324852415398687


Epoch 18/20: Training Loss: 2.696054220199585, C-Index: 0.9352409459668127


Epoch 18/20: Training Loss: 2.696054220199585, C-Index: 0.9352409459668127


Epoch 19/20: Training Loss: 2.656233549118042, C-Index: 0.9390775516991531
****************************************************************************************************


Epoch 19/20: Training Loss: 2.656233549118042, C-Index: 0.9390775516991531
****************************************************************************************************


Final Train Loss: 2.714604377746582, C-Index: 0.9437099858651368


Final Train Loss: 2.714604377746582, C-Index: 0.9437099858651368


Test Loss: 4.0100860595703125, C-Index: 0.7494434325035418


Test Loss: 4.0100860595703125, C-Index: 0.7494434325035418


Epoch 0/20: Training Loss: 4.257678031921387, C-Index: 0.5493334630728812


Epoch 0/20: Training Loss: 4.257678031921387, C-Index: 0.5493334630728812


Epoch 1/20: Training Loss: 4.048366069793701, C-Index: 0.715590639291622


Epoch 1/20: Training Loss: 4.048366069793701, C-Index: 0.715590639291622


Epoch 2/20: Training Loss: 3.883894920349121, C-Index: 0.7580762868541403


Epoch 2/20: Training Loss: 3.883894920349121, C-Index: 0.7580762868541403


Epoch 3/20: Training Loss: 3.7983787059783936, C-Index: 0.7841296098083098


Epoch 3/20: Training Loss: 3.7983787059783936, C-Index: 0.7841296098083098


Epoch 4/20: Training Loss: 3.682027816772461, C-Index: 0.7997834971295125


Epoch 4/20: Training Loss: 3.682027816772461, C-Index: 0.7997834971295125


Epoch 5/20: Training Loss: 3.5657472610473633, C-Index: 0.823319061983069


Epoch 5/20: Training Loss: 3.5657472610473633, C-Index: 0.823319061983069


Epoch 6/20: Training Loss: 3.4863250255584717, C-Index: 0.8483020336674126


Epoch 6/20: Training Loss: 3.4863250255584717, C-Index: 0.8483020336674126


Epoch 7/20: Training Loss: 3.4521429538726807, C-Index: 0.8361997664688139


Epoch 7/20: Training Loss: 3.4521429538726807, C-Index: 0.8361997664688139


Epoch 8/20: Training Loss: 3.4312448501586914, C-Index: 0.8533862021990853


Epoch 8/20: Training Loss: 3.4312448501586914, C-Index: 0.8533862021990853


Epoch 9/20: Training Loss: 3.2803618907928467, C-Index: 0.8685657292984333


Epoch 9/20: Training Loss: 3.2803618907928467, C-Index: 0.8685657292984333


Epoch 10/20: Training Loss: 3.2091023921966553, C-Index: 0.8792814050793033


Epoch 10/20: Training Loss: 3.2091023921966553, C-Index: 0.8792814050793033


Epoch 11/20: Training Loss: 3.122429609298706, C-Index: 0.8888050987642309


Epoch 11/20: Training Loss: 3.122429609298706, C-Index: 0.8888050987642309


Epoch 12/20: Training Loss: 3.091298818588257, C-Index: 0.8947528461613311


Epoch 12/20: Training Loss: 3.091298818588257, C-Index: 0.8947528461613311


Epoch 13/20: Training Loss: 3.059008836746216, C-Index: 0.8949231293178943


Epoch 13/20: Training Loss: 3.059008836746216, C-Index: 0.8949231293178943


Epoch 14/20: Training Loss: 2.925661325454712, C-Index: 0.9116352048263112


Epoch 14/20: Training Loss: 2.925661325454712, C-Index: 0.9116352048263112


Epoch 15/20: Training Loss: 2.9102985858917236, C-Index: 0.9110635399435633


Epoch 15/20: Training Loss: 2.9102985858917236, C-Index: 0.9110635399435633


Epoch 16/20: Training Loss: 2.8988659381866455, C-Index: 0.9192736207064318


Epoch 16/20: Training Loss: 2.8988659381866455, C-Index: 0.9192736207064318


Epoch 17/20: Training Loss: 2.7572946548461914, C-Index: 0.9247713340469008


Epoch 17/20: Training Loss: 2.7572946548461914, C-Index: 0.9247713340469008


Epoch 18/20: Training Loss: 2.7236146926879883, C-Index: 0.9344288216405566


Epoch 18/20: Training Loss: 2.7236146926879883, C-Index: 0.9344288216405566


Epoch 19/20: Training Loss: 2.705108880996704, C-Index: 0.9360708377931303
****************************************************************************************************


Epoch 19/20: Training Loss: 2.705108880996704, C-Index: 0.9360708377931303
****************************************************************************************************


Final Train Loss: 2.598567247390747, C-Index: 0.9457039992215627


Final Train Loss: 2.598567247390747, C-Index: 0.9457039992215627


Test Loss: 3.5136170387268066, C-Index: 0.7753409509767785


Test Loss: 3.5136170387268066, C-Index: 0.7753409509767785
