In [None]:
import os
import random
import numpy as np

import torch
import time
import os
import random
import numpy as np
import pandas as pd

import torch
import torch.utils.data as data
import csv
from torch.utils.data import DataLoader, SubsetRandomSampler

In [None]:
def set_seed(seed=7):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU, plus all CUDA
    devices when CUDA is available), exports ``PYTHONHASHSEED`` and puts
    cuDNN into deterministic, non-benchmarking mode.

    Args:
        seed: Integer seed shared by all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # also covers multi-GPU setups
    # Trade cuDNN autotuning for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [None]:
class TCGA_Survival(data.Dataset):
    """Survival dataset backed by a CSV table and one ``<ID>.pt`` feature file per case.

    The CSV must contain ``Event`` and ``Status`` columns (read by name in
    ``disc_label``); ``__getitem__`` additionally reads the first four columns
    positionally as ``ID, Event, Status, WSI``. A discrete ``Label`` column is
    appended as the last column when the table is loaded.

    Args:
        excel_file: Path of the CSV file to load (read with ``pandas.read_csv``).
        feature_dir: Directory holding the per-case ``<ID>.pt`` tensors. Defaults
            to the location that used to be hard-coded in ``__getitem__``.
    """

    # Class-level default so instances created without __init__ still resolve a directory.
    feature_dir = r"F:\lung_dl\data\survival\TCGA-LUSC-R50\pt"

    def __init__(self, excel_file, feature_dir=None):
        print('[dataset] loading dataset from %s' % (excel_file))
        if feature_dir is not None:
            self.feature_dir = feature_dir
        rows = pd.read_csv(excel_file)
        self.rows = self.disc_label(rows)
        label_dist = self.rows['Label'].value_counts().sort_index()
        print('[dataset] discrete label distribution: ')
        print(label_dist)
        print('[dataset] dataset from %s, number of cases=%d' % (excel_file, len(self.rows)))

    def get_split(self, fold=0):
        """Return ``(train_split, val_split)`` index lists for one of five folds.

        The shuffle is always seeded with 1, so every call produces the same
        permutation and the folds are consistent with each other. A private
        ``random.Random`` instance is used so the global ``random`` state set
        up by the caller is left untouched.

        Args:
            fold: Fold number in ``0..4``; the last fold takes all remaining cases.

        Returns:
            Tuple of two lists of row indices (training, validation).
        """
        ratio = 0.2
        assert 0 <= fold <= 4, 'fold should be in 0 ~ 4'
        n_cases = len(self.rows)
        # Same sequence as `random.seed(1); random.sample(...)`, without the global side effect.
        sample_index = random.Random(1).sample(range(n_cases), n_cases)
        num_split = round((n_cases - 1) * ratio)
        if fold < 1 / ratio - 1:
            val_split = sample_index[fold * num_split: (fold + 1) * num_split]
        else:
            val_split = sample_index[fold * num_split:]
        held_out = set(val_split)  # O(1) membership instead of scanning the list
        train_split = [i for i in sample_index if i not in held_out]
        print("[dataset] training split: {}, validation split: {}".format(len(train_split), len(val_split)))
        return train_split, val_split

    def __getitem__(self, index):
        """Load one case and return ``(ID, WSI, Event, Censorship, Label)``.

        ``Censorship`` is 1 when ``Status`` is 0 and 0 otherwise. The ``WSI``
        value stored in the table is replaced by the tensor loaded from
        ``<feature_dir>/<ID>.pt``.
        """
        case = self.rows.iloc[index, :].values.tolist()
        ID, Event, Status, WSI = case[:4]
        Label = case[-1]  # 'Label' is appended as the last column by disc_label
        Censorship = 1 if int(Status) == 0 else 0
        WSI = torch.load(f"{self.feature_dir}/{ID}.pt")
        return (ID, WSI, Event, Censorship, Label)

    def __len__(self):
        """Number of cases (rows) in the table."""
        return len(self.rows)

    def disc_label(self, rows):
        """Append a discrete ``Label`` column derived from ``Event``.

        Bin edges are the quartiles of ``Event`` over rows with ``Status == 1``;
        the outermost edges are then widened to the min/max of ``Event`` over all
        rows so every case falls inside a bin. ``rows`` is modified in place and
        also returned.
        """
        n_bins, eps = 4, 1e-6
        uncensored_df = rows[rows['Status'] == 1]
        # Only the bin edges are needed from the quantile cut.
        _, q_bins = pd.qcut(uncensored_df['Event'], q=n_bins, retbins=True, labels=False)
        q_bins[-1] = rows['Event'].max() + eps
        q_bins[0] = rows['Event'].min() - eps
        disc_labels, q_bins = pd.cut(rows['Event'], bins=q_bins, retbins=True, labels=False, right=False, include_lowest=True)
        disc_labels = disc_labels.values.astype(int)
        disc_labels[disc_labels < 0] = -1
        rows.insert(len(rows.columns), 'Label', disc_labels)
        return rows

In [None]:
# Fix all RNG seeds before any split or model is created.
set_seed()

# Every run writes into its own timestamped folder:
# ./results/<dataset>/[<model>]-[<date>]-[<time>]
results_dir = "./results/{dataset}/[{model}]-[{time}]".format(
    dataset="TCGA-LUSC-R50",  # feature set used for this run
    model="topomil",
    time=time.strftime("%Y-%m-%d]-[%H-%M-%S"),
)
print(results_dir)
# exist_ok removes the check-then-create race and keeps the cell safe to re-run.
os.makedirs(results_dir, exist_ok=True)

In [None]:
# Run configuration.
# NOTE(review): neither constant is referenced by the other visible cells — the training
# loop passes the literals 1024 and 4 directly; keep them in sync or wire them in.
num_classes = 4
n_features = 1024 # feature width per extractor: PLIP = 512, ResNet-50 = 1024

In [None]:
# Load the survival table; the constructor prints the label distribution and case count.
# NOTE(review): the CSV lives under TCGA-LUSC-PLIP while the features are read from a
# TCGA-LUSC-R50 directory — confirm this pairing is intended.
dataset = TCGA_Survival(excel_file=r"F:\lung_dl\data\survival\TCGA-LUSC-PLIP\LUSC.csv")

In [None]:
class CV_Meter():
    """Collects the best score of each cross-validation fold and writes a CSV summary.

    Args:
        fold: Number of folds; determines the CSV header and how many recorded
            scores enter the mean/std.
    """

    def __init__(self, fold=5):
        self.fold = fold
        # Header follows the fold count instead of being fixed to five columns.
        self.header = ["folds"] + ["fold {}".format(i) for i in range(fold)] + ["mean", "std"]
        self.epochs = ["epoch"]    # first cell is the row label
        self.cindex = ["cindex"]   # first cell is the row label

    def updata(self, score, epoch):
        """Record one fold's result: the epoch and the score rounded to 4 decimals."""
        self.epochs.append(epoch)
        self.cindex.append(round(score, 4))

    def save(self, path):
        """Write header, epochs and scores (followed by mean and std) to ``path``.

        Mean and std are computed from the same snapshot of the first ``fold``
        recorded scores, and the stored score list is not modified, so calling
        ``save`` more than once writes the same file each time.
        """
        fold_scores = self.cindex[1:self.fold + 1]
        summary = [round(np.mean(fold_scores), 4), round(np.std(fold_scores), 4)]
        print("save evaluation resluts to", path)
        with open(path, "w", encoding="utf-8", newline="") as fp:
            writer = csv.writer(fp)
            writer.writerow(self.header)
            writer.writerow(self.epochs)
            writer.writerow(self.cindex + summary)
meter = CV_Meter(fold=5)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
def initialize_weights(module):
    """Re-initialise every Conv2d, Linear and LayerNorm found inside ``module``.

    Conv2d and Linear weights get Xavier-normal initialisation with zeroed
    biases (conventions taken from huggingface and CLAM respectively, per the
    original notes); LayerNorm is reset to weight 1 and bias 0. All other
    module types are left untouched.

    Args:
        module: Root ``nn.Module``; all of its sub-modules are visited.
    """
    for layer in module.modules():
        if isinstance(layer, (nn.Conv2d, nn.Linear)):
            nn.init.xavier_normal_(layer.weight)
            if layer.bias is not None:
                layer.bias.data.zero_()
            continue
        if isinstance(layer, nn.LayerNorm):
            nn.init.constant_(layer.bias, 0)
            nn.init.constant_(layer.weight, 1.0)

import torch
# NOTE(review): anomaly detection is switched on globally here. PyTorch documents it as a
# debugging aid that slows down execution — consider disabling it for full training runs.
torch.autograd.set_detect_anomaly(True)
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool, global_max_pool, GlobalAttention

class TopoAggregator(nn.Module):
    """Aggregates each instance with its top-k most similar instances.

    Similarity is the dot product between a query and a key projection of the
    input. The input must be 3-D ``(batch, instances, dim_in)``: the indexing and
    ``softmax(dim=2)`` calls in ``forward`` depend on that rank.

    Args:
        dim_in: Width of the incoming features.
        dim_hidden: Width of the query/key projections and of the output.
        topk: Number of neighbours per instance. Bags with fewer instances than
            ``topk`` use all of their instances instead of raising in ``torch.topk``.
    """

    def __init__(self, dim_in=512, dim_hidden=512, topk=6):
        super().__init__()
        self.proj_q = nn.Linear(dim_in, dim_hidden)
        self.proj_k = nn.Linear(dim_in, dim_hidden)
        self.topk = topk

    def forward(self, x):
        """Return the neighbour-aggregated features plus the query projection (residual)."""
        q = self.proj_q(x)  # Query
        k = self.proj_k(x)  # Key

        # Pairwise similarity between every query and every key.
        S = torch.matmul(q, k.transpose(-2, -1))
        # Never ask for more neighbours than there are instances in the bag.
        n_neighbors = min(self.topk, S.size(-1))
        S_topk, idx_topk = torch.topk(S, k=n_neighbors, dim=-1)

        # Gather the key vectors of the selected neighbours for every instance.
        batch_indices = torch.arange(k.size(0), device=idx_topk.device).view(-1, 1, 1)
        K_neighbors = k[batch_indices, idx_topk, :]

        # Blend each neighbour with the query, weighted by the softmaxed similarity.
        P_topk = F.softmax(S_topk, dim=2)
        X_agg = torch.mul(P_topk.unsqueeze(-1), K_neighbors) + torch.matmul((1 - P_topk).unsqueeze(-1), q.unsqueeze(2))

        G = torch.tanh(X_agg)
        # NOTE: with distinct subscripts l and m this einsum multiplies the feature-sum of
        # K_neighbors by the feature-sum of G for each neighbour (not an element-wise dot product).
        W_KA = torch.einsum('ijkl,ijkm->ijk', K_neighbors, G)
        P_KA = F.softmax(W_KA, dim=2).unsqueeze(2)
        X_topo = torch.matmul(P_KA, K_neighbors).squeeze(2)

        return X_topo + q


class DAttention(nn.Module):
    """Attention-pooling MIL head that outputs discrete hazards and survival values.

    A feature stack (Linear -> ReLU -> Dropout, optionally followed by the module
    passed as ``TopoAggregator``) embeds the instances, an attention branch scores
    them, and the attention-weighted sum is classified into ``n_classes`` logits.

    Args:
        input_dim: Width of the incoming instance features.
        n_classes: Number of output logits.
        TopoAggregator: Optional module appended to the feature stack.
    """

    def __init__(self,input_dim,n_classes,TopoAggregator=None):
        super().__init__()
        self.L, self.D, self.K = 512, 128, 1  # embedding width, attention width, attention heads

        stack = [nn.Linear(input_dim, 512), nn.ReLU(), nn.Dropout(0.25)]
        if TopoAggregator is not None:
            stack.append(TopoAggregator)
        self.feature = nn.Sequential(*stack)

        self.attention = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Tanh(),
            nn.Linear(self.D, self.K),
        )
        self.classifier = nn.Sequential(nn.Linear(self.L * self.K, n_classes))

        self.apply(initialize_weights)

    def forward(self, x):
        """Return ``(hazards, S)`` where ``S`` is the cumulative product of ``1 - hazards``."""
        bag = self.feature(x).squeeze(0)                 # 1 * N * 512 -> N * 512
        scores = self.attention(bag).transpose(-1, -2)   # N * 1 -> 1 * N
        weights = F.softmax(scores, dim=-1)              # 1 * N

        pooled = torch.mm(weights, bag)                  # K x L
        logits = self.classifier(pooled)
        hazards = torch.sigmoid(logits)
        survival = torch.cumprod(1 - hazards, dim=1)
        return hazards, survival

In [None]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F


def define_loss(args):
    """Build the loss (or list of losses) selected by the name ``args``.

    Supported names: ``ce_surv``, ``nll_surv``, ``nll_surv_l1``, ``nll_surv_mse``,
    ``nll_surv_kl`` and ``nll_surv_cos``. The combined variants return a
    two-element list ``[NLLSurvLoss, auxiliary loss]``.

    Raises:
        NotImplementedError: If ``args`` is not one of the supported names.
    """
    # Factories are lazy so only the requested loss objects are constructed.
    factories = {
        "ce_surv": lambda: CrossEntropySurvLoss(alpha=0.0),
        "nll_surv": lambda: NLLSurvLoss(alpha=0.0),
        "nll_surv_l1": lambda: [NLLSurvLoss(alpha=0.0), nn.L1Loss()],
        "nll_surv_mse": lambda: [NLLSurvLoss(alpha=0.0), nn.MSELoss()],
        "nll_surv_kl": lambda: [NLLSurvLoss(alpha=0.0), KLLoss()],
        "nll_surv_cos": lambda: [NLLSurvLoss(alpha=0.0), CosineLoss()],
    }
    try:
        factory = factories[args]
    except (KeyError, TypeError):
        raise NotImplementedError
    return factory()


def nll_loss(hazards, S, Y, c, alpha=0.4, eps=1e-7):
    """Negative log-likelihood loss for discrete-time survival prediction.

    Args:
        hazards: Per-bin hazard values, one row per sample.
        S: Per-bin survival values, or ``None`` to derive them as the cumulative
            product of ``1 - hazards``.
        Y: Ground-truth bin index per sample.
        c: Censorship status per sample (0 or 1).
        alpha: Extra weight given to the uncensored term.
        eps: Lower clamp applied before every logarithm.

    Returns:
        Mean loss over the batch.
    """
    n_samples = len(Y)
    Y = Y.view(n_samples, 1)
    c = c.view(n_samples, 1).float()
    if S is None:
        S = torch.cumprod(1 - hazards, dim=1)
    # Prepend a column of ones so that S_padded[:, t] is the survival *before* bin t
    # and S_padded[:, t + 1] is the survival *after* bin t.
    S_padded = torch.cat([torch.ones_like(c), S], 1)

    def log_at(values, index):
        # Clamped log of the entry selected per row.
        return torch.log(torch.gather(values, 1, index).clamp(min=eps))

    uncensored_loss = -(1 - c) * (log_at(S_padded, Y) + log_at(hazards, Y))
    censored_loss = -c * log_at(S_padded, Y + 1)
    neg_l = censored_loss + uncensored_loss
    return ((1 - alpha) * neg_l + alpha * uncensored_loss).mean()


def ce_loss(hazards, S, Y, c, alpha=0.4, eps=1e-7):
    """Cross-entropy style survival loss with an uncensored-likelihood regulariser.

    Args:
        hazards: Per-bin hazard values, one row per sample.
        S: Per-bin survival values, or ``None`` to derive them as the cumulative
            product of ``1 - hazards``.
        Y: Ground-truth bin index per sample.
        c: Censorship status per sample (0 or 1).
        alpha: Weight of the regulariser; ``1 - alpha`` weights the cross-entropy term.
        eps: Numerical floor used before every logarithm.

    Returns:
        Mean loss over the batch.
    """
    batch_size = len(Y)
    Y = Y.view(batch_size, 1)
    c = c.view(batch_size, 1).float()
    if S is None:
        S = torch.cumprod(1 - hazards, dim=1)
    # Prepend a column of ones so S_padded[:, t] is the survival before bin t.
    S_padded = torch.cat([torch.ones_like(c), S], 1)
    reg = -(1 - c) * (torch.log(torch.gather(S_padded, 1, Y) + eps) + torch.log(torch.gather(hazards, 1, Y).clamp(min=eps)))
    surv_at_y = torch.gather(S, 1, Y).clamp(min=eps)
    # Clamp 1 - S as well: when S reaches 1 the log would be -inf, and multiplying
    # that by a zero weight yields NaN for the whole batch mean.
    death_at_y = (1 - surv_at_y).clamp(min=eps)
    ce_l = -c * torch.log(surv_at_y) - (1 - c) * torch.log(death_at_y)
    loss = (1 - alpha) * ce_l + alpha * reg
    loss = loss.mean()
    return loss


class CrossEntropySurvLoss(object):
    """Callable wrapper around ``ce_loss`` that carries a default ``alpha``."""

    def __init__(self, alpha=0.15):
        self.alpha = alpha

    def __call__(self, hazards, S, Y, c, alpha=None):
        """Evaluate ``ce_loss``; a per-call ``alpha`` overrides the stored default."""
        weight = self.alpha if alpha is None else alpha
        return ce_loss(hazards, S, Y, c, alpha=weight)


# loss_fn(hazards=hazards, S=S, Y=Y_hat, c=c, alpha=0)
class NLLSurvLoss(object):
    """Callable wrapper around ``nll_loss`` that carries a default ``alpha``."""

    def __init__(self, alpha=0.15):
        self.alpha = alpha

    def __call__(self, hazards, S, Y, c, alpha=None):
        """Evaluate ``nll_loss``; a per-call ``alpha`` overrides the stored default."""
        weight = self.alpha if alpha is None else alpha
        return nll_loss(hazards, S, Y, c, alpha=weight)


class KLLoss(object):
    """Summed KL divergence between the softmax of ``y`` (target) and of ``y_hat``."""

    def __call__(self, y, y_hat):
        log_pred = F.softmax(y_hat, dim=-1).log()
        target = F.softmax(y, dim=-1)
        return F.kl_div(log_pred, target, reduction="sum")


class CosineLoss(object):
    """Per-row cosine distance: one minus the cosine similarity along ``dim=1``."""

    def __call__(self, y, y_hat):
        similarity = F.cosine_similarity(y, y_hat, dim=1)
        return 1 - similarity


In [None]:
import torch
from collections import defaultdict
import math
from torch.optim.optimizer import Optimizer


def define_optimizer(args, model,lr=0.001,weight_decay=0.0001):
    """Create the optimizer named by ``args`` over the trainable parameters of ``model``.

    Supported names: ``SGD`` (momentum 0.9), ``AdamW``, ``Adam``, ``RAdam``,
    ``PlainRAdam`` and ``Lookahead`` (wrapping Adam).

    Args:
        args: Optimizer name.
        model: Module whose parameters with ``requires_grad`` set are optimised.
        lr: Learning rate.
        weight_decay: Weight-decay coefficient.

    Raises:
        NotImplementedError: If ``args`` is not a supported name.
    """
    def trainable():
        # Fresh iterator per call; frozen parameters are skipped.
        return filter(lambda p: p.requires_grad, model.parameters())

    if args == 'SGD':
        return torch.optim.SGD(trainable(), lr=lr, momentum=0.9, weight_decay=weight_decay)
    if args == 'AdamW':
        return torch.optim.AdamW(trainable(), lr=lr, weight_decay=weight_decay)
    if args == 'Adam':
        return torch.optim.Adam(trainable(), lr=lr, weight_decay=weight_decay)
    if args == 'RAdam':
        return RAdam(trainable(), lr=lr, weight_decay=weight_decay)
    if args == 'PlainRAdam':
        return PlainRAdam(trainable(), lr=lr, weight_decay=weight_decay)
    if args == 'Lookahead':
        return Lookahead(torch.optim.Adam(trainable(), lr=lr, weight_decay=weight_decay))
    raise NotImplementedError('Optimizer [{}] is not implemented'.format(args))


class RAdam(Optimizer):
    '''
    RAdam Optimizer
    Implementation lifted from: https://github.com/LiyuanLucasLiu/RAdam
    Paper: `On the Variance of the Adaptive Learning Rate and Beyond` - https://arxiv.org/abs/1908.03265

    Args:
        params: Iterable of parameters or parameter-group dicts.
        lr: Learning rate.
        betas: Coefficients for the running averages of the gradient and its square.
        eps: Term added to the denominator for numerical stability.
        weight_decay: Weight-decay coefficient (0 or None disables it).
    '''

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        # Ring buffer caching the step-dependent rectification terms. Entries are
        # [cache key, N_sma, rectification factor]; the learning rate is deliberately
        # NOT cached so parameter groups with different rates each use their own.
        self.buffer = [[None, None, None] for _ in range(10)]
        super(RAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(RAdam, self).__setstate__(state)

    def step(self, closure=None):
        """Perform one optimisation step.

        Args:
            closure: Optional callable that re-evaluates the model and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None``.

        Raises:
            RuntimeError: If a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            beta1, beta2 = group['betas']
            weight_decay = group['weight_decay']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')

                # All arithmetic is done on an fp32 copy that is written back at the end.
                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                # Keyword forms replace the deprecated positional (Number, Tensor) overloads.
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state['step'] += 1
                step = state['step']
                cache_key = (step, beta2)
                buffered = self.buffer[int(step % 10)]
                if buffered[0] == cache_key:
                    N_sma, rect = buffered[1], buffered[2]
                else:
                    beta2_t = beta2 ** step
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * step * beta2_t / (1 - beta2_t)

                    # more conservative since it's an approximated value
                    if N_sma >= 5:
                        rect = math.sqrt(
                            (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
                                N_sma_max - 2))
                    else:
                        rect = 1.0
                    buffered[0] = cache_key
                    buffered[1] = N_sma
                    buffered[2] = rect

                # Learning rate and first-moment bias correction are applied per group.
                step_size = lr * rect / (1 - beta1 ** step)

                if weight_decay is not None and weight_decay != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-weight_decay * lr)

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
                else:
                    p_data_fp32.add_(exp_avg, alpha=-step_size)

                p.data.copy_(p_data_fp32)

        return loss


class PlainRAdam(Optimizer):
    """RAdam variant that recomputes the rectification term on every step (no cache).

    Args:
        params: Iterable of parameters or parameter-group dicts.
        lr: Learning rate.
        betas: Coefficients for the running averages of the gradient and its square.
        eps: Term added to the denominator for numerical stability.
        weight_decay: Weight-decay coefficient (0 or None disables it).
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)

        super(PlainRAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(PlainRAdam, self).__setstate__(state)

    def step(self, closure=None):
        """Perform one optimisation step.

        Args:
            closure: Optional callable that re-evaluates the model and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None``.

        Raises:
            RuntimeError: If a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            beta1, beta2 = group['betas']
            weight_decay = group['weight_decay']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')

                # All arithmetic is done on an fp32 copy that is written back at the end.
                p_data_fp32 = p.data.float()

                state = self.state[p]

                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                # Keyword forms replace the deprecated positional (Number, Tensor) overloads.
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state['step'] += 1
                beta2_t = beta2 ** state['step']
                N_sma_max = 2 / (1 - beta2) - 1
                N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)

                # A None weight decay is treated like 0, matching RAdam.
                if weight_decay is not None and weight_decay != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-weight_decay * lr)

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    step_size = lr * math.sqrt(
                        (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
                            N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
                else:
                    step_size = lr / (1 - beta1 ** state['step'])
                    p_data_fp32.add_(exp_avg, alpha=-step_size)

                p.data.copy_(p_data_fp32)

        return loss


class Lookahead(Optimizer):
    '''
    Lookahead Optimizer Wrapper
    Implementation modified from: https://github.com/alphadl/lookahead.pytorch
    Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610
    Hacked together by / Copyright 2020 Ross Wightman

    Args:
        base_optimizer: Optimizer performing the fast updates; its param groups are shared.
        alpha: Slow-weight interpolation rate in [0, 1].
        k: Number of fast steps between slow-weight updates (>= 1).
    '''

    def __init__(self, base_optimizer, alpha=0.5, k=6):
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'Invalid slow update rate: {alpha}')
        if not 1 <= k:
            raise ValueError(f'Invalid lookahead steps: {k}')
        defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
        self.base_optimizer = base_optimizer
        # Share (not copy) the param groups so both optimizers see the same settings.
        self.param_groups = self.base_optimizer.param_groups
        self.defaults = base_optimizer.defaults
        self.defaults.update(defaults)
        self.state = defaultdict(dict)
        # manually add our defaults to the param groups
        for name, default in defaults.items():
            for group in self.param_groups:
                group.setdefault(name, default)

    def update_slow(self, group):
        """Move the slow weights towards the fast weights and copy them back into the parameters."""
        for fast_p in group["params"]:
            if fast_p.grad is None:
                continue
            param_state = self.state[fast_p]
            if 'slow_buffer' not in param_state:
                # First sync: the slow weights start as a copy of the current fast weights.
                param_state['slow_buffer'] = torch.empty_like(fast_p.data)
                param_state['slow_buffer'].copy_(fast_p.data)
            slow = param_state['slow_buffer']
            # Keyword form replaces the deprecated positional add_(Number, Tensor) overload.
            slow.add_(fast_p.data - slow, alpha=group['lookahead_alpha'])
            fast_p.data.copy_(slow)

    def sync_lookahead(self):
        """Force a slow-weight update for every param group."""
        for group in self.param_groups:
            self.update_slow(group)

    def step(self, closure=None):
        """Run one step of the base optimizer and sync slow weights every ``lookahead_k`` steps."""
        # assert id(self.param_groups) == id(self.base_optimizer.param_groups)
        loss = self.base_optimizer.step(closure)
        for group in self.param_groups:
            group['lookahead_step'] += 1
            if group['lookahead_step'] % group['lookahead_k'] == 0:
                self.update_slow(group)
        return loss

    def state_dict(self):
        """Return the base optimizer's state plus the slow-weight state under ``slow_state``."""
        fast_state_dict = self.base_optimizer.state_dict()
        slow_state = {
            (id(k) if isinstance(k, torch.Tensor) else k): v
            for k, v in self.state.items()
        }
        fast_state = fast_state_dict['state']
        param_groups = fast_state_dict['param_groups']
        return {
            'state': fast_state,
            'slow_state': slow_state,
            'param_groups': param_groups,
        }

    def load_state_dict(self, state_dict):
        """Restore fast and slow state; accepts state dicts saved without Lookahead."""
        fast_state_dict = {
            'state': state_dict['state'],
            'param_groups': state_dict['param_groups'],
        }
        self.base_optimizer.load_state_dict(fast_state_dict)

        # We want to restore the slow state, but share param_groups reference
        # with base_optimizer. This is a bit redundant but least code
        slow_state_new = False
        if 'slow_state' not in state_dict:
            print('Loading state_dict from optimizer without Lookahead applied.')
            state_dict['slow_state'] = defaultdict(dict)
            slow_state_new = True
        slow_state_dict = {
            'state': state_dict['slow_state'],
            'param_groups': state_dict['param_groups'],  # this is pointless but saves code
        }
        super(Lookahead, self).load_state_dict(slow_state_dict)
        self.param_groups = self.base_optimizer.param_groups  # make both ref same container
        if slow_state_new:
            # reapply defaults to catch missing lookahead specific ones
            for name, default in self.defaults.items():
                for group in self.param_groups:
                    group.setdefault(name, default)

In [None]:
import torch.optim.lr_scheduler as lr_scheduler


def define_scheduler(args, optimizer,num_epoch=200):
    """Create the learning-rate scheduler named by ``args``.

    Supported names: ``exp``, ``step``, ``plateau``, ``cosine`` and the string
    ``'None'`` (no scheduler).

    Args:
        args: Scheduler name.
        optimizer: Optimizer whose learning rate is scheduled.
        num_epoch: Total number of epochs; sets the cosine period and the step interval.

    Returns:
        The scheduler instance, or ``None`` when ``args == 'None'``.

    Raises:
        NotImplementedError: If ``args`` is not a supported name. (This used to be
            returned instead of raised, so a typo silently produced an exception
            object in place of a scheduler.)
    """
    if args == 'exp':
        scheduler = lr_scheduler.ExponentialLR(optimizer, 0.1, last_epoch=-1)
    elif args == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=num_epoch / 2, gamma=0.1)
    elif args == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif args == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epoch, eta_min=0)
    elif args == 'None':
        scheduler = None
    else:
        raise NotImplementedError('Scheduler [{}] is not implemented'.format(args))
    return scheduler

In [None]:
import os
import numpy as np
from tqdm import tqdm
from sksurv.metrics import concordance_index_censored
import torch.optim
class Engine(object):
    """Runs training and validation for one cross-validation fold and keeps the best checkpoint.

    Both loops fill one scalar slot per batch and call ``.item()`` on the
    censorship tensor, so the data loaders must yield exactly one case per batch.

    Args:
        results_dir: Root directory for checkpoints.
        fold: Fold number; checkpoints go to ``<results_dir>/fold_<fold>``.
    """

    def __init__(self, results_dir, fold):
        self.results_dir = results_dir
        self.fold = fold
        self.best_scores = 0
        self.best_epoch = 0
        self.filename_best = None
        # validate() labels its progress bar with self.epoch, including in
        # evaluate-only mode where the training loop never assigns it.
        self.epoch = 0

    def learning(self, model, train_loader, val_loader, criterion, optimizer, scheduler,resume,evaluate,num_epoch):
        """Train for ``num_epoch`` epochs (or only evaluate) and track the best c-index.

        Args:
            model: Network returning ``(hazards, S)``.
            train_loader: Loader for the training split.
            val_loader: Loader for the validation split.
            criterion: Loss called as ``criterion(hazards=, S=, Y=, c=)``.
            optimizer: Optimizer stepped after every batch.
            scheduler: Scheduler stepped once per epoch, or ``None`` to keep the rate fixed.
            resume: Optional checkpoint path to restore weights and best score from.
            evaluate: When true, run a single validation pass and return.
            num_epoch: Number of training epochs.

        Returns:
            ``(best_score, best_epoch)``. In evaluate-only mode the first element
            is the validation c-index of that single pass.
        """
        if torch.cuda.is_available():
            model = model.cuda()
        if resume is not None:
            if os.path.isfile(resume):
                print("=> loading checkpoint '{}'".format(resume))
                # Load onto CPU first so GPU-saved checkpoints also open on CPU-only hosts;
                # load_state_dict copies the values onto the model's own device.
                checkpoint = torch.load(resume, map_location='cpu')
                self.best_scores = checkpoint['best_score']
                model.load_state_dict(checkpoint['state_dict'])
                print("=> loaded checkpoint (score: {})".format(checkpoint['best_score']))
            else:
                print("=> no checkpoint found at '{}'".format(resume))

        if evaluate:
            # Return a pair like the training path so callers can unpack either result.
            scores = self.validate(val_loader, model, criterion)
            return scores, self.best_epoch

        for epoch in range(num_epoch):
            self.epoch = epoch
            # train for one epoch
            self.train(train_loader, model, criterion, optimizer)
            # evaluate on validation set
            scores = self.validate(val_loader, model, criterion)
            # remember best c-index and save checkpoint
            is_best = scores > self.best_scores
            if is_best:
                self.best_scores = scores
                self.best_epoch = self.epoch
                self.save_checkpoint({
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'best_score': self.best_scores})
            print(' *** best score={:.4f} at epoch {}'.format(self.best_scores, self.best_epoch))
            # define_scheduler('None', ...) yields None; skip stepping in that case.
            if scheduler is not None:
                scheduler.step()
            print('>>>')
            print('>>>')
        return self.best_scores, self.best_epoch

    def train(self, data_loader, model, criterion, optimizer):
        """Run one training epoch and print the mean loss and training c-index."""
        model.train()

        total_loss = 0.0
        all_risk_scores = np.zeros((len(data_loader)))
        all_censorships = np.zeros((len(data_loader)))
        all_event_times = np.zeros((len(data_loader)))
        dataloader = tqdm(data_loader, desc='Train Epoch {}'.format(self.epoch))
        for batch_idx, (data_ID, data_WSI, data_Event, data_Censorship, data_Label) in enumerate(dataloader):
            if torch.cuda.is_available():
                data_WSI = data_WSI.cuda()
                data_Label = data_Label.type(torch.LongTensor).cuda()
                data_Censorship = data_Censorship.type(torch.FloatTensor).cuda()
            # prediction
            hazards, S = model(data_WSI)
            loss = criterion(hazards=hazards, S=S, Y=data_Label, c=data_Censorship)
            # results: risk is the negated sum of the survival values
            risk = -torch.sum(S, dim=1).detach().cpu().numpy()
            all_risk_scores[batch_idx] = risk
            all_censorships[batch_idx] = data_Censorship.item()
            all_event_times[batch_idx] = data_Event
            total_loss += loss.item()
            # backward to update parameters
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        # calculate loss and error for each epoch
        loss = total_loss / len(dataloader)
        c_index = concordance_index_censored((1 - all_censorships).astype(bool), all_event_times, all_risk_scores, tied_tol=1e-08)[0]
        print('loss: {:.4f}, c_index: {:.4f}'.format(loss, c_index))


    def validate(self, data_loader, model, criterion):
        """Evaluate on ``data_loader``; prints mean loss and c-index and returns the c-index."""
        model.eval()
        total_loss = 0.0
        all_risk_scores = np.zeros((len(data_loader)))
        all_censorships = np.zeros((len(data_loader)))
        all_event_times = np.zeros((len(data_loader)))
        dataloader = tqdm(data_loader, desc='Test Epoch {}'.format(self.epoch))

        for batch_idx, (data_ID, data_WSI, data_Event, data_Censorship, data_Label) in enumerate(dataloader):
            if torch.cuda.is_available():
                data_WSI = data_WSI.cuda()
                data_Label = data_Label.type(torch.LongTensor).cuda()
                data_Censorship = data_Censorship.type(torch.FloatTensor).cuda()
            # prediction
            with torch.no_grad():
                hazards, S = model(data_WSI)
            loss = criterion(hazards=hazards, S=S, Y=data_Label, c=data_Censorship)
            total_loss += loss.item()
            # results: risk is the negated sum of the survival values
            risk = -torch.sum(S, dim=1).detach().cpu().numpy()
            all_risk_scores[batch_idx] = risk
            all_censorships[batch_idx] = data_Censorship.item()
            all_event_times[batch_idx] = data_Event
        # calculate loss and error for each epoch
        loss = total_loss / len(dataloader)
        c_index = concordance_index_censored((1 - all_censorships).astype(bool), all_event_times, all_risk_scores, tied_tol=1e-08)[0]
        print('loss: {:.4f}, c_index: {:.4f}'.format(loss, c_index))
        return c_index

    def save_checkpoint(self, state):
        """Save ``state`` as the new best checkpoint, deleting the previous best file first.

        Args:
            state: Dict with at least ``best_score`` and ``epoch`` (used in the file name).
        """
        if self.filename_best is not None:
            os.remove(self.filename_best)
        fold_dir = os.path.join(self.results_dir, f'fold_{self.fold}')
        os.makedirs(fold_dir, exist_ok=True)
        self.filename_best = os.path.join(self.results_dir,
                                          'fold_' + str(self.fold),
                                          'model_best_{score:.4f}_{epoch}.pth.tar'.format(score=state['best_score'], epoch=state['epoch']))
        print('save best model {filename}'.format(filename=self.filename_best))
        torch.save(state, self.filename_best)

In [None]:
# One value drives both the cosine schedule length and the number of training epochs.
NUM_EPOCH = 200

for fold in range(5):
    # get split
    train_split, val_split = dataset.get_split(fold)
    train_loader = DataLoader(dataset, batch_size=1, num_workers=0, pin_memory=True, sampler=SubsetRandomSampler(train_split))
    val_loader = DataLoader(dataset, batch_size=1, num_workers=0, pin_memory=True, sampler=SubsetRandomSampler(val_split))
    # build model, criterion, optimizer, scheduler (unimodal: WSI features only)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # The aggregator runs after the model's first Linear layer, which outputs 512 features.
    age = TopoAggregator(dim_in=512, dim_hidden=512, topk=6).to(device)
    # Use the notebook-level configuration instead of repeating the literals 1024 and 4.
    model = DAttention(input_dim=n_features, n_classes=num_classes, TopoAggregator=age).to(device)
    engine = Engine(results_dir, fold)
    criterion = define_loss("nll_surv")
    optimizer = define_optimizer("Adam", model)
    scheduler = define_scheduler('cosine', optimizer, num_epoch=NUM_EPOCH)
    # start training
    score, epoch = engine.learning(model, train_loader, val_loader, criterion, optimizer, scheduler,resume=None,evaluate=False,num_epoch=NUM_EPOCH)
    meter.updata(score, epoch)

In [None]:
# Write the per-fold epochs and scores collected in `meter` to a CSV inside this run's results folder.
csv_path = os.path.join(results_dir, "results_{}.csv".format("topomil"))
meter.save(csv_path)