In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from torchvision import models as torch_models1
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import openpyxl
import xlsxwriter
from collections import namedtuple
import os
import random
import math
from PIL import Image
from scipy import interp
from sklearn.model_selection import KFold, StratifiedKFold
import pandas as pd
from torch.utils.data import Sampler, BatchSampler
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score, recall_score, precision_score, confusion_matrix, precision_recall_curve
from torch.autograd import Variable
import logging
from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix
from sklearn.metrics import auc as calc_auc
from topk.svm import SmoothTop1SVM
from torch_kmeans import KMeans


In [2]:
# Single seed shared by every random number generator used in this notebook.
SEED = 1234
os.environ['PYTHONHASHSEED'] = str(SEED)

# Seed Python, NumPy and PyTorch (CPU, current CUDA device, all CUDA devices).
for _seed_fn in (random.seed,
                 np.random.seed,
                 torch.manual_seed,
                 torch.cuda.manual_seed,
                 torch.cuda.manual_seed_all):
    _seed_fn(SEED)

# Ask cuDNN for deterministic kernels and switch off its benchmark mode.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
# Alternative training table (MSI cohort, sheet 'MSIHMSS'), kept for reference:
#train_dataset = pd.read_excel('/home/ldap_howard/script/MSI_CRC_DX_0307.xlsx',sheet_name = 'MSIHMSS')
# Active training table: sheet 'TP53' of the gene workbook.
train_dataset = pd.read_excel('/home/ldap_howard/script/Gene_CRC_DX_0307.xlsx',sheet_name='TP53')

In [4]:
# Patient identifiers (later turned into '<id>.npy' paths) and their labels, as NumPy arrays.
# NOTE(review): the label column is called 'isMSIH' even though the sheet loaded is 'TP53' -
# presumably the column name was reused for the TP53 label; confirm against the workbook.
train_x = train_dataset['Patients'].values
train_y = train_dataset['isMSIH'].values

In [5]:
# 4-fold stratified cross-validation splitter; shuffling is seeded with a fixed random_state.
kf = StratifiedKFold(n_splits=4, shuffle=True, random_state=22)

In [6]:
def npy_loader(path):
    """Load the 'features' array of a pickled-dict ``.npy`` file as a torch tensor.

    Args:
        path: Location of a ``.npy`` file that stores a dict with a 'features' key.

    Returns:
        torch.Tensor sharing memory with the loaded NumPy array.
    """
    # SECURITY NOTE: allow_pickle=True unpickles the file, so only load trusted files.
    payload = np.load(path, allow_pickle=True).item()
    return torch.from_numpy(payload['features'])

def npy_loader_count(path):
    """Load the 'counts' array of a pickled-dict ``.npy`` file as a torch tensor.

    Args:
        path: Location of a ``.npy`` file that stores a dict with a 'counts' key.

    Returns:
        torch.Tensor sharing memory with the loaded NumPy array (all columns kept).
    """
    # SECURITY NOTE: allow_pickle=True unpickles the file, so only load trusted files.
    payload = np.load(path, allow_pickle=True).item()
    return torch.from_numpy(payload['counts'])

In [7]:
class SimpleDataset(Dataset):
    """Map patient IDs to the paths of their feature / count ``.npy`` files plus a label.

    The dataset never opens the files; each item is
    ``(feature_path, count_path, label)`` and loading is left to the caller.

    Args:
        x: Indexable sequence of patient ID strings (file stems).
        y: Indexable sequence of labels, aligned with ``x``.
        path: Directory holding the feature files. Defaults to ``DEFAULT_FEATURE_DIR``.
        path_c: Directory holding the count files. Defaults to ``DEFAULT_COUNT_DIR``.
            Passing these two lets other cohorts reuse the class without redefining it.
    """

    DEFAULT_FEATURE_DIR = '/ORCA_lake/TCGA-COAD/feature/CRC_resnet0307/'
    DEFAULT_COUNT_DIR = '/ORCA_lake/TCGA-COAD/hovernet_kmeans/CRC_0307_MI2N/'

    def __init__(self, x, y, path=None, path_c=None):
        super().__init__()
        self.x = x
        self.y = y
        self.path = self.DEFAULT_FEATURE_DIR if path is None else path
        self.path_c = self.DEFAULT_COUNT_DIR if path_c is None else path_c

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        file_name = self.x[idx] + '.npy'
        image_path = os.path.join(self.path, file_name)
        count_path = os.path.join(self.path_c, file_name)

        return image_path, count_path, self.y[idx]

In [8]:
class FocalLoss(nn.Module):
    """Focal loss on raw logits: NLL of ``(1 - p)^gamma * log(p)`` with ``p = softmax(logits)``.

    Args:
        weight: Optional per-class weight tensor forwarded to ``F.nll_loss``.
        gamma: Focusing exponent; ``gamma=0`` reduces to plain cross-entropy.
        reduction: Reduction mode forwarded to ``F.nll_loss``.
    """

    def __init__(self, weight=None,
                 gamma=2., reduction='mean'):
        super().__init__()
        self.weight = weight
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, label):
        """Return the focal loss of ``logits`` (classes on the last dim) against ``label``."""
        log_p = F.log_softmax(logits, dim=-1)
        # Down-weight well-classified entries: the factor shrinks as p approaches 1.
        focal_factor = (1 - torch.exp(log_p)) ** self.gamma
        modulated_log_p = focal_factor * log_p
        return F.nll_loss(modulated_log_p,
                          label,
                          weight=self.weight,
                          reduction=self.reduction)

In [9]:
class Attention(nn.Module):
    """Gated-attention bag classifier with an auxiliary per-instance "count" branch.

    A bag of N instance feature vectors (N x L) is projected to N x D, scored by a
    gated attention head, pooled into a single 1 x D vector and classified into two
    logits. During training an extra branch maps the per-instance count features
    (N x 4) to two logits, which are (a) attention-pooled into ``hover_logits`` and
    (b) used to derive pseudo-labels for an instance-level loss on the attention scores.

    Args:
        L: Input feature width.
        D: Width after the first linear projection.
        dropout: Whether the two attention branches end with ``Dropout(0.25)``.
        n_classes: Accepted for interface compatibility; the heads are fixed at 2 outputs.
        top_k: Number of bag-level rows selected in ``forward`` (the code reshapes the
            index to a single element, so only 1 is supported).
        instance_loss_fn: Loss module for the instance branch. ``None`` (default)
            creates a fresh ``FocalLoss()`` per model instead of sharing one default
            instance between every ``Attention`` object.

    NOTE(review): ``classifier``, ``fc_X`` and ``fc_c`` all end in ``Sigmoid`` and their
    outputs are then passed through softmax / log-softmax based losses, which limits the
    two-class softmax probabilities to roughly [0.27, 0.73]. Left unchanged because
    existing checkpoints were trained with this architecture - confirm before altering.
    """

    # How many of the most confident instances feed the pseudo-label instance loss.
    INSTANCE_TOP_K = 5

    def __init__(self, L=2048, D=1024, dropout=True, n_classes=2, top_k=1, instance_loss_fn=None):
        super(Attention, self).__init__()
        self.L = L
        self.D = D
        self.K = 1

        self.layer1 = nn.Linear(self.L, self.D)

        # Gated attention: tanh branch (V) multiplied element-wise by sigmoid branch (U).
        v_layers = [nn.Linear(self.D, 512), nn.Tanh()]
        u_layers = [nn.Linear(self.D, 512), nn.Sigmoid()]
        if dropout:
            v_layers.append(nn.Dropout(0.25))
            u_layers.append(nn.Dropout(0.25))
        self.attention_V = nn.Sequential(*v_layers)
        self.attention_U = nn.Sequential(*u_layers)

        self.attention_weights = nn.Linear(512, self.K)

        self.classifier = nn.Sequential(nn.Linear(self.D, 1024),
                                        nn.ReLU(),
                                        nn.Dropout(0.25),
                                        nn.Linear(1024, 512),
                                        nn.ReLU(),
                                        nn.Linear(512, 2),
                                        nn.Sigmoid())
        self.top_k = top_k
        # Build the default loss here (not in the signature) so models never share it.
        self.instance_loss = FocalLoss() if instance_loss_fn is None else instance_loss_fn
        self.fc_X = nn.Sequential(nn.Linear(1, 2), nn.Sigmoid())
        self.fc_c = nn.Sequential(nn.Linear(4, 2), nn.Sigmoid())

    def relocate(self):
        """Move every sub-module to CUDA when available, otherwise to CPU."""
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.classifier.to(device)
        self.attention_V.to(device)
        self.attention_U.to(device)
        self.attention_weights.to(device)
        self.instance_loss.to(device)
        self.fc_c.to(device)
        self.fc_X.to(device)

    def inst_eval(self, A_T, count):
        """Build the instance-level training signals from attention scores and counts.

        Args:
            A_T: Softmax-normalised attention scores, shape 1 x N.
            count: Per-instance count features, shape N x 4 (input width of ``fc_c``).
                Presumably a floating-point tensor matching the layer dtype - verify
                against the stored 'counts' arrays.

        Returns:
            ``(pseudo_logits, pseudo_targets, hover_logits)``: ``fc_X`` logits of the
            selected instances (k x 2), their argmax class from the count branch (k),
            and the attention-pooled count logits (1 x 2). ``k`` is
            ``min(INSTANCE_TOP_K, N)``.
        """
        logits_c = self.fc_c(count)                      # N x 2
        hover_logits = torch.mm(A_T, logits_c)           # 1 x 2
        y_probs_c = F.softmax(logits_c, dim=1)

        # Confidence of each instance = probability of its most likely class.
        predicted_prob, _ = torch.max(y_probs_c, dim=1)
        # Clamp k so bags with fewer than INSTANCE_TOP_K instances do not crash topk.
        n_select = min(self.INSTANCE_TOP_K, predicted_prob.size(0))
        top_instance_idx = torch.topk(predicted_prob, n_select, largest=True)[1]
        top_instance = torch.index_select(y_probs_c, dim=0, index=top_instance_idx)
        _, pseudo_targets = torch.max(top_instance, dim=1)

        # Score every instance from its scalar attention weight (N x 1 -> N x 2).
        logits_x = self.fc_X(torch.transpose(A_T, 1, 0))
        pseudo_logits = torch.index_select(logits_x, dim=0, index=top_instance_idx)

        return pseudo_logits, pseudo_targets, hover_logits

    def forward(self, x, count, eval=False):
        """Classify one bag.

        Args:
            x: Instance features, shape N x L.
            count: Per-instance count features, shape N x 4; only used when ``eval``
                is falsy.
            eval: When truthy, skip the instance branch (``count`` is ignored).

        Returns:
            dict with 'logits' (1 x 2), 'Y_prob' (1 x 2) and 'Y_hat' (1 x 1); in
            training mode also 'Instance_loss' (scalar) and 'hover_logits' (1 x 2).
        """
        x = self.layer1(x)                               # N x D
        A_V = self.attention_V(x)                        # N x 512
        A_U = self.attention_U(x)                        # N x 512
        A = self.attention_weights(A_V * A_U)            # N x K
        A_T = torch.transpose(A, 1, 0)                   # K x N
        A_T = F.softmax(A_T, dim=1)                      # softmax over the N instances
        M = torch.mm(A_T, x)                             # K x D pooled bag embedding

        training_outputs = {}
        if not eval:
            pseudo_logits, pseudo_targets, hover_logits = self.inst_eval(A_T, count)
            instance_loss = self.instance_loss(pseudo_logits, pseudo_targets)
            training_outputs = {'Instance_loss': instance_loss, 'hover_logits': hover_logits}

        logits = self.classifier(M)
        y_probs = F.softmax(logits, dim=1)
        top_instance_idx = torch.topk(y_probs[:, 1], self.top_k, dim=0)[1].view(1, )
        top_instance = torch.index_select(logits, dim=0, index=top_instance_idx)
        Y_hat = torch.topk(top_instance, 1, dim=1)[1]
        Y_prob = F.softmax(top_instance, dim=1)

        results_dict = {'logits': logits, 'Y_prob': Y_prob, 'Y_hat': Y_hat}
        results_dict.update(training_outputs)
        return results_dict

In [10]:
class Accuracy_Logger(object):
    """Per-class tally of how many samples were seen and how many were predicted correctly."""

    def __init__(self, n_classes):
        super(Accuracy_Logger, self).__init__()
        self.n_classes = n_classes
        self.initialize()

    def initialize(self):
        """Reset the tallies: one ``{"count", "correct"}`` record per class."""
        self.data = [dict(count=0, correct=0) for _ in range(self.n_classes)]

    def log(self, Y_hat, Y):
        """Record one prediction ``Y_hat`` for a sample whose true class is ``Y``."""
        predicted, truth = int(Y_hat), int(Y)
        record = self.data[truth]
        record["count"] += 1
        record["correct"] += (predicted == truth)

    def log_batch(self, count, correct, c):
        """Add pre-aggregated ``count`` / ``correct`` totals to class ``c``."""
        record = self.data[c]
        record["count"] += count
        record["correct"] += correct

    def log_batch_rnn(self, Y_hat, Y):
        """Record a batch of predictions, grouped by each true class present in ``Y``."""
        predicted = np.array(Y_hat).astype(int)
        truth = np.array(Y).astype(int)
        for label_class in np.unique(truth):
            in_class = truth == label_class
            record = self.data[label_class]
            record["count"] += in_class.sum()
            record["correct"] += (predicted[in_class] == truth[in_class]).sum()

    def get_summary(self, c):
        """Return ``(accuracy, correct, count)`` for class ``c``; accuracy is None if unseen."""
        record = self.data[c]
        count, correct = record["count"], record["correct"]
        acc = None if count == 0 else float(correct) / count
        return acc, correct, count

In [11]:
def calculate_error(Y_hat, Y):
    """Return the fraction of entries where prediction ``Y_hat`` differs from label ``Y``.

    Both tensors are compared as floats; the result is a Python float in [0, 1].
    """
    matches = Y_hat.float().eq(Y.float()).float()
    return 1. - matches.mean().item()

In [12]:
class EarlyStopping:
    """Track validation loss, checkpoint every new best model and signal when to stop.

    Stopping is requested only when the loss has failed to improve for ``patience``
    consecutive calls AND the current epoch index is greater than ``stop_epoch``.
    """

    def __init__(self, patience=40, stop_epoch=100, verbose=False):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 40
            stop_epoch (int): Earliest epoch possible for stopping
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
        """
        self.patience = patience
        self.stop_epoch = stop_epoch
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.inf instead of np.Inf: the capitalised alias was removed in NumPy 2.0.
        self.val_loss_max = np.inf

    def __call__(self, epoch, val_loss, model, ckpt_name='checkpoint.pt'):
        """Update the tracker with this epoch's validation loss.

        Saves ``model.state_dict()`` to ``ckpt_name`` when the loss is at least as good
        as the best seen so far; otherwise increments the patience counter and sets
        ``self.early_stop`` once both stopping conditions hold.
        """
        score = -val_loss

        # A NaN loss is never an improvement. Without this guard NaN would be stored as
        # the best score and every later comparison against it would count as "improved".
        improved = (not np.isnan(score)) and (self.best_score is None or score >= self.best_score)

        if improved:
            self.best_score = score
            self.save_checkpoint(val_loss, model, ckpt_name)
            self.counter = 0
        else:
            self.counter += 1
            logging.info(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience and epoch > self.stop_epoch:
                self.early_stop = True

    def save_checkpoint(self, early_stopping, model, ckpt_name):
        '''Saves model when validation loss decrease.

        Note: despite its name, the ``early_stopping`` argument is the new validation
        loss value (the name is kept so existing callers are unaffected).
        '''
        if self.verbose:
            logging.info(
                f'Validation loss decreased ({self.val_loss_max:.6f} --> {early_stopping:.6f}).  Saving model ...')
        print("Save model!!!!!!!!!!!")
        torch.save(model.state_dict(), ckpt_name)
        self.val_loss_max = early_stopping

In [13]:
def summary_to_excel(fold, loader_name, label_list, probs_list, y_hat_list, accuracy, specificity, sensitivity, precision, f1_score, cls_auc, auprc,
                     out_path='./summary/0307/TP53/TP53_0307CV11_clam.xlsx'):
    """Write per-patient predictions and the fold-level metrics to one sheet of a workbook.

    Args:
        fold: Fold identifier. ``0`` starts a fresh workbook (sheet 'Sheet0'); any other
            value (int or str) is written to sheet ``'Sheet' + str(fold)``.
        loader_name: Per-patient identifiers (one per row).
        label_list, probs_list, y_hat_list: Per-patient labels, positive-class
            probabilities and predicted classes, aligned with ``loader_name``.
        accuracy, specificity, sensitivity, precision, f1_score, cls_auc, auprc:
            Scalar metrics; pandas repeats each of them on every row.
        out_path: Destination ``.xlsx`` file. Defaults to the path previously hard-coded.

    The parent directory is created when missing. If the workbook does not exist yet it
    is created even for a non-zero ``fold``; previously append mode failed in that case.
    """
    df = pd.DataFrame(
        {
            "Patients": loader_name,
            "labels": label_list,
            "probs": probs_list,
            "y_hat": y_hat_list,
            "Accuracy" : accuracy,
            "Specificity": specificity,
            "Sensitivity":sensitivity,
            "Precision": precision,
            "F1-score": f1_score,
            "Auc": cls_auc,
            "AUPRC": auprc,
        }
    )

    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    start_new_workbook = (fold == 0) or not os.path.exists(out_path)
    if start_new_workbook:
        writer_kwargs = {}
    else:
        # Append a new sheet to the existing workbook.
        writer_kwargs = {'mode': 'a'}
    sheet_name = 'Sheet0' if fold == 0 else 'Sheet' + str(fold)

    with pd.ExcelWriter(out_path, engine='openpyxl', **writer_kwargs) as writer:
        df.to_excel(writer, sheet_name=sheet_name, index=False)

    print("Patients data is successfully written into Excel File")

In [14]:
def summary(model, loader, n_classes, fold):
    """Evaluate ``model`` on ``loader`` (one bag per batch) and report bag-level metrics.

    For the binary case the metrics are printed and also written to Excel through
    ``summary_to_excel`` using ``fold`` as the sheet identifier.

    Args:
        model: Network called as ``model(features, counts, eval=True)`` and returning a
            dict with 'Y_prob' and 'Y_hat'.
        loader: Yields ``(feature_paths, count_paths, label)``; only element 0 of each
            path batch is read, so a batch size of 1 is assumed.
        n_classes: Number of classes.
        fold: Fold identifier forwarded to ``summary_to_excel``.

    Returns:
        ``(patient_results, cls_test_error, cls_auc, cls_logger, labels, positive_probs)``.
        ``patient_results`` is always an empty dict (kept for interface compatibility).
        ``cls_auc`` is ``float('nan')`` when it cannot be computed in the binary case.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    cls_logger = Accuracy_Logger(n_classes=n_classes)
    model.eval()
    cls_test_error = 0.

    n_bags = len(loader)
    all_cls_probs = np.zeros((n_bags, n_classes))
    all_cls_labels = np.zeros(n_bags)
    all_cls_y_hats = np.zeros(n_bags)

    patient_results = {}

    loader_name = []
    for batch_idx, (npy_dir, npy_dir_count, label) in enumerate(loader):
        data = npy_loader(npy_dir[0]).to(device)
        count = npy_loader_count(npy_dir_count[0]).to(device)
        label = label.to(device)
        with torch.no_grad():
            results_dict = model(data, count, eval=True)

        Y_prob, Y_hat = results_dict['Y_prob'], results_dict['Y_hat']
        loader_name.append(npy_dir[0])

        cls_logger.log(Y_hat, label)
        all_cls_probs[batch_idx] = Y_prob.cpu().numpy()
        all_cls_labels[batch_idx] = label.item()
        all_cls_y_hats[batch_idx] = Y_hat.cpu().numpy().item()

        cls_test_error += calculate_error(Y_hat, label)

    cls_test_error /= len(loader)

    if n_classes == 2:
        print(all_cls_labels)
        print(all_cls_y_hats)
        tn, fp, fn, tp = confusion_matrix(all_cls_labels, all_cls_y_hats, labels=[0, 1]).ravel()
        accuracy = (tp+tn)/(tp+tn+fp+fn)
        sensitivity = tp/(tp+fn)
        specificity = tn/(tn+fp)
        recall = tp/(tp+fn)
        precision = tp/(tp+fp)
        f1_score = 2 * precision * recall / (precision + recall)
        try:
            cls_auc = roc_auc_score(all_cls_labels, all_cls_probs[:, 1])
            precision1, recall1, _ = precision_recall_curve(all_cls_labels,all_cls_probs[:,1])
            auprc = calc_auc(recall1, precision1)
        except ValueError:
            # Raised e.g. when only one class is present. Use a real NaN (not the string
            # 'nan') so callers that format the value with '{:.4f}' keep working.
            cls_auc = float('nan')
            auprc = float('nan')
        print("Accuracy: "+str(accuracy))
        print("Specificity: "+str(specificity))
        print("Sensitivity: "+str(sensitivity))
        print("Recall: "+str(recall))
        print("Precision: "+str(precision))
        print("F1-score: "+str(f1_score))
        print("Auc: "+str(cls_auc))
        print("AUPRC: "+str(auprc))

        summary_to_excel(fold, loader_name, all_cls_labels, all_cls_probs[:,1], all_cls_y_hats, 
                         accuracy, specificity, sensitivity, precision, f1_score, cls_auc, auprc)
    else:
        # One-vs-rest AUC needs the full (n_samples, n_classes) probability matrix,
        # not only the column of class 1.
        cls_auc = roc_auc_score(all_cls_labels, all_cls_probs, multi_class='ovr')

    return patient_results, cls_test_error, cls_auc, cls_logger, all_cls_labels, all_cls_probs[:,1]

In [15]:
def plot_roc_curve(tprs, mean_fpr):
    """Plot the mean ROC curve over folds together with the chance diagonal.

    Args:
        tprs: Sequence of TPR arrays, one per fold, each evaluated on ``mean_fpr``.
        mean_fpr: Common FPR grid shared by all entries of ``tprs``.
    """
    # Chance-level reference line.
    plt.plot([0,1],[0,1],linestyle = '--',lw = 2,color = 'black')

    # Average the fold curves point-wise and integrate for the mean AUC.
    averaged_tpr = np.mean(tprs, axis=0)
    averaged_auc = calc_auc(mean_fpr, averaged_tpr)
    legend_text = r'Mean ROC (AUC = %0.2f )' % (averaged_auc)
    plt.plot(mean_fpr, averaged_tpr, color='blue', label=legend_text, lw=2, alpha=1)

    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC')
    plt.legend(loc="lower right")
    plt.show()

In [16]:
def plot_loss_curve(train_loss, valid_loss, fold):
    """Plot training vs. validation loss per epoch for one fold, save it as PNG and show it.

    Args:
        train_loss: Per-epoch training losses.
        valid_loss: Per-epoch validation losses.
        fold: Zero-based fold index; the title shows ``fold + 1`` while the file name
            uses ``fold`` unchanged.
    """
    title = 'MSI HoverAtt fold' + str(fold+1) + ' loss curve'
    output_png = '/home/ldap_howard/script/summary/attention_hover/loss_fold'+str(fold)+'_clam.png'

    for losses, legend_label in ((train_loss, 'train loss'), (valid_loss, 'validation loss')):
        plt.plot(losses, label=legend_label)

    plt.xlabel('Epoch', fontsize=16)
    plt.ylabel('Loss', fontsize=16)
    plt.title(title, fontsize=18)
    plt.legend(fontsize=16)
    plt.savefig(output_png)
    plt.show()

In [17]:
def train_loop(epoch, model, loader, optimizer, n_classes, writer=None, loss_fn=None):
    """Train ``model`` for one epoch and return the mean bag-level classification loss.

    Each batch is one bag: features and counts are loaded from the ``.npy`` paths the
    loader yields (only element 0 is read, so a batch size of 1 is assumed). The
    optimised loss is ``0.86 * cls_loss + 0.11 * hover_loss + 0.03 * instance_loss``.

    Args:
        epoch: Epoch index, used for logging only.
        model: Network returning a dict with 'logits', 'Y_hat', 'Instance_loss' and
            'hover_logits' when called as ``model(features, counts)``.
        loader: Yields ``(feature_paths, count_paths, label)``.
        optimizer: Optimizer stepped once per bag.
        n_classes: Number of classes tracked by the accuracy logger.
        writer: Optional object with ``add_scalar(tag, value, step)``.
        loss_fn: Callable ``loss_fn(logits, label)`` used for both the bag and the
            hover losses.

    Returns:
        Mean classification loss (``cls_loss`` only) over the epoch.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.train()
    cls_logger = Accuracy_Logger(n_classes=n_classes)
    cls_train_error = 0.
    cls_train_loss = 0.

    for batch_idx, (npy_dir, npy_dir_count, label) in enumerate(loader):
        data = npy_loader(npy_dir[0]).to(device)
        count = npy_loader_count(npy_dir_count[0]).to(device)
        label = label.to(device)

        results_dict = model(data, count)
        logits = results_dict['logits']
        Y_hat = results_dict['Y_hat']
        instance_loss = results_dict['Instance_loss']
        hover_logits = results_dict['hover_logits']

        cls_logger.log(Y_hat, label)

        cls_loss = loss_fn(logits, label)
        hover_loss = loss_fn(hover_logits, label)

        # Weighted sum of the bag loss, the count-branch loss and the instance loss.
        total_loss = 0.86*cls_loss + 0.11*hover_loss + 0.03*instance_loss
        cls_train_loss += cls_loss.item()
        cls_train_error += calculate_error(Y_hat, label)

        # backward pass
        total_loss.backward()
        # step
        optimizer.step()
        optimizer.zero_grad()

    # calculate loss and error for epoch
    cls_train_loss /= len(loader)
    cls_train_error /= len(loader)

    logging.info(
        'Epoch: {}, cls train_loss: {:.4f}, cls train_error: {:.4f}'.format(epoch, cls_train_loss, cls_train_error))
    for i in range(n_classes):
        acc, correct, n_seen = cls_logger.get_summary(i)
        # acc is None when the class never appeared; '{:.4f}' would raise on None.
        acc_text = 'None' if acc is None else '{:.4f}'.format(acc)
        logging.info('class {}: tpr {}, correct {}/{}'.format(i, acc_text, correct, n_seen))
        if writer and acc is not None:
            writer.add_scalar('train/class_{}_tpr'.format(i), acc, epoch)

    if writer:
        writer.add_scalar('train/cls_loss', cls_train_loss, epoch)
        writer.add_scalar('train/cls_error', cls_train_error, epoch)
    return cls_train_loss

In [18]:
def validate(model_name, epoch, model, loader, n_classes, early_stopping=None, writer=None, loss_fn=None, results_dir=None):
    """Evaluate ``model`` on the validation loader and update early stopping.

    Args:
        model_name: Checkpoint file name; joined onto ``results_dir`` when saving.
        epoch: Epoch index, used for logging and early stopping.
        model: Network called as ``model(features, counts, eval=True)``.
        loader: Yields ``(feature_paths, count_paths, label)``; only element 0 of each
            path batch is read, so a batch size of 1 is assumed.
        n_classes: Number of classes.
        early_stopping: Optional callable ``(epoch, val_loss, model, ckpt_name=...)``
            exposing an ``early_stop`` attribute.
        writer: Optional object with ``add_scalar(tag, value, step)``.
        loss_fn: Callable ``loss_fn(logits, label)``.
        results_dir: Directory for checkpoints; required when ``early_stopping`` is set.

    Returns:
        ``(cls_val_loss, stop)`` where ``stop`` is True once early stopping triggered.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    cls_logger = Accuracy_Logger(n_classes=n_classes)
    cls_val_error = 0.
    cls_val_loss = 0.

    cls_probs = np.zeros((len(loader), n_classes))
    cls_labels = np.zeros(len(loader))

    with torch.no_grad():
        for batch_idx, (npy_dir, npy_dir_count, label) in enumerate(loader):
            data = npy_loader(npy_dir[0]).to(device)
            count = npy_loader_count(npy_dir_count[0]).to(device)
            label = label.to(device)

            results_dict = model(data, count, eval=True)
            logits, Y_prob, Y_hat = results_dict['logits'], results_dict['Y_prob'], results_dict['Y_hat']
            del results_dict

            cls_logger.log(Y_hat, label)

            cls_loss = loss_fn(logits, label)

            cls_probs[batch_idx] = Y_prob.cpu().numpy()
            cls_labels[batch_idx] = label.item()

            cls_val_loss += cls_loss.item()
            cls_val_error += calculate_error(Y_hat, label)

    cls_val_error /= len(loader)
    cls_val_loss /= len(loader)

    print("cls val loss: " + str(cls_val_loss))

    if n_classes == 2:
        cls_auc = roc_auc_score(cls_labels, cls_probs[:, 1])
        precision1, recall1, _ = precision_recall_curve(cls_labels,cls_probs[:, 1])
        cls_auprc = calc_auc(recall1, precision1)
        cls_aucs = []
    else:
        cls_aucs = []
        binary_labels = label_binarize(cls_labels, classes=[i for i in range(n_classes)])
        for class_idx in range(n_classes):
            if class_idx in cls_labels:
                fpr, tpr, _ = roc_curve(binary_labels[:, class_idx], cls_probs[:, class_idx])
                cls_aucs.append(calc_auc(fpr, tpr))
            else:
                cls_aucs.append(float('nan'))

        cls_auc = np.nanmean(np.array(cls_aucs))
        # AUPRC is only computed for the binary case; define it so the print below
        # does not raise NameError on the multi-class path.
        cls_auprc = float('nan')
    print("cls_auc: " + str(cls_auc))
    print("cls_auprc: " + str(cls_auprc))


    if writer:
        writer.add_scalar('val/cls_loss', cls_val_loss, epoch)
        writer.add_scalar('val/cls_auc', cls_auc, epoch)
        writer.add_scalar('val/cls_error', cls_val_error, epoch)

    logging.info(
        '\nVal Set, cls val_loss: {:.4f}, cls val_error: {:.4f}, cls auc: {:.4f}'.format(cls_val_loss, cls_val_error,
                                                                                         cls_auc))
    for i in range(n_classes):
        acc, correct, n_seen = cls_logger.get_summary(i)
        logging.info('class {}: tpr {}, correct {}/{}'.format(i, acc, correct, n_seen))
        # acc is None when the class is absent from the fold; skip the scalar then.
        if writer and acc is not None:
            writer.add_scalar('val/class_{}_tpr'.format(i), acc, epoch)
    print(model_name)

    if early_stopping:
        assert results_dir
        early_stopping(epoch, cls_val_loss, model,
                       ckpt_name=os.path.join(results_dir, model_name))

        if early_stopping.early_stop:
            logging.info("Early stopping")
            return cls_val_loss, True

    return cls_val_loss, False

In [19]:
# Path-only dataset over all training patients; the per-fold Subsets are cut from it later.
train_data = SimpleDataset(train_x, train_y)

In [20]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(device)

# Placeholders; both are reassigned inside the cross-validation loop.
early_stopping = writer = None

cuda


In [21]:
# Fold counter used only for the legend labels (incremented at the end of each fold).
i = 0
# Per-fold results accumulated by the cross-validation loop.
tprs = []        # TPR curves interpolated onto mean_fpr
aucs = []        # ROC AUC per fold
auprcs = []      # PR AUC per fold
precisions = []  # precision curves interpolated onto mean_recall
y_real = []      # validation labels per fold
y_probs = []     # positive-class probabilities per fold
# Common 100-point grids on [0, 1] used for the interpolation above.
mean_fpr = np.linspace(0,1,100)
mean_recall = np.linspace(0,1,100)
# Two stacked panels: ax1 receives the ROC curves, ax2 the precision-recall curves.
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10,20))

# Stratified cross-validation: train a fresh model per fold, restore its best checkpoint,
# evaluate on the held-out fold and add that fold's ROC / PR curves to the shared figure.
for fold, (train_ids, valid_ids) in enumerate(kf.split(train_x, train_y)):
    print(f'FOLD {fold}')
    print('--------------------------------')

    # Relative checkpoint name; it is joined onto results_dir when saving / loading.
    model_name = './model/TP53_0307CV11_fold'+str(fold)+'_clam_checkpoint.pt'
    train_fold = torch.utils.data.Subset(train_data, train_ids)
    valid_fold = torch.utils.data.Subset(train_data, valid_ids)

    # batch_size=1: train_loop / validate / summary read only element 0 of each batch.
    train_loader = DataLoader(train_fold, batch_size=1, shuffle = True)
    valid_loader = DataLoader(valid_fold, batch_size=1)

    model = Attention()
    model.to(device)

    loss_fn = FocalLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.00006, betas=(0.9, 0.999), weight_decay=0.003)

    writer = None
    early_stopping = EarlyStopping(patience=30, stop_epoch=100, verbose=True)

    results_dir = '/home/ldap_howard/script/'
    cur = 'model'
    train_loss = []
    valid_loss = []

    # Upper bound on epochs; early stopping normally ends the loop much sooner.
    for epoch in range(2000):
        print(epoch)
        train_epoch_loss = train_loop(epoch, model, train_loader, optimizer, 2, writer, loss_fn)
        valid_epoch_loss, stop = validate(model_name, epoch, model, valid_loader, 2,
                                early_stopping, writer, loss_fn, results_dir)

        train_loss.append(train_epoch_loss)
        valid_loss.append(valid_epoch_loss)

        if stop:
            break

    if early_stopping:
        print("EarlyStopping")
        # Restore the best checkpoint written by EarlyStopping during training.
        model.load_state_dict(torch.load(os.path.join(results_dir, model_name)))
    else:
        torch.save(model.state_dict(), os.path.join(results_dir, model_name))

    _, cls_val_error, cls_val_auc, _, cls_labels, cls_probs = summary(model, valid_loader, 2, fold)

    # ROC curve of this fold, also interpolated onto the shared FPR grid.
    fpr, tpr, t = roc_curve(cls_labels, cls_probs)
    tprs.append(np.interp(mean_fpr, fpr, tpr))
    roc_auc = calc_auc(fpr, tpr)
    aucs.append(roc_auc)

    lab_fold = 'Fold %d AUROC=%.3f' % (i+1, roc_auc)
    ax1.plot(fpr, tpr, alpha=0.8, lw=4, label=lab_fold)

    # PR curve of this fold; reversed so recall is increasing, as np.interp requires.
    precision_fold, recall_fold, _ = precision_recall_curve(cls_labels, cls_probs)
    precision_fold, recall_fold = precision_fold[::-1], recall_fold[::-1]
    precisions.append(np.interp(mean_recall, recall_fold, precision_fold))
    auprc = calc_auc(recall_fold, precision_fold)
    auprcs.append(auprc)
    y_real.append(cls_labels)
    y_probs.append(cls_probs)

    lab_fold = 'Fold %d AUPRC=%.3f' % (i+1, auprc)
    ax2.plot(recall_fold, precision_fold, alpha=0.8, lw=4, label=lab_fold)

    i=i+1
    logging.info('Cls Val error: {:.4f}, Cls ROC AUC: {:.4f}'.format(cls_val_error, cls_val_auc))
    # cls_val_auc is the ROC AUC returned by summary(), so label it as AUC (was "acc").
    print("Validation AUC: " + str(cls_val_auc))
    print("validation error: " + str(cls_val_error))

    #plot_loss_curve(train_loss, valid_loss, fold)

# --- ROC panel ---------------------------------------------------------------
# Only the last title / axis-label call on an Axes is visible, so the earlier calls
# that were overwritten ('AUROC', 'ROC curves', unsized labels) have been dropped.
ax1.plot([0, 1], [0, 1], linestyle='--', lw=2, color='r', label='Chance', alpha=.8)
ax1.set_xticks(np.arange(0, 1.1, 0.2))
ax1.set_yticks(np.arange(0, 1.1, 0.2))
ax1.tick_params(axis='both', which='major', labelsize=20)
ax1.set_xlabel("False Positive Rate", fontsize=24)
ax1.set_ylabel("True Positive Rate", fontsize=24)
ax1.set_title("HoverAtt(TP53)", fontsize=32)
ax1.legend(loc="lower right", fontsize = 20)

# --- Precision-recall panel --------------------------------------------------
# NOTE(review): the 'Chance' line is drawn as the anti-diagonal; for PR curves the
# no-skill baseline is usually a horizontal line at the positive-class prevalence.
# Kept as in the original - confirm the intent.
ax2.plot([0, 1], [1, 0], linestyle='--', lw=2, color='r', label='Chance', alpha=.8)
ax2.set_xticks(np.arange(0, 1.1, 0.2))
ax2.set_yticks(np.arange(0, 1.1, 0.2))
ax2.tick_params(axis='both', which='major', labelsize=20)
ax2.set_xlabel("Recall", fontsize=24)
ax2.set_ylabel("Precision", fontsize=24)
ax2.set_title("HoverAtt(TP53)", fontsize=32)
ax2.legend(loc='lower left', fontsize=20)

# Save before showing: on non-inline backends the figure may be closed by plt.show(),
# which would leave an empty image on disk.
fig.savefig('/home/ldap_howard/script/summary/attention_hover/TP53_clam.png')
plt.show()


FOLD 0
--------------------------------


## MSIL ##

In [52]:
# External test table: sheet 'MSIL' of the NA workbook.
test_dataset = pd.read_excel('/home/ldap_howard/script/MSI_CRC_DX_0307_NA.xlsx', sheet_name='MSIL')

# Patient IDs and labels as NumPy arrays, in the same row order.
test_x, test_y = (test_dataset[column].values for column in ('Patients', 'isMSIH'))

# One bag per batch, original row order (no shuffling).
test_data = SimpleDataset(test_x, test_y)
test_loader = DataLoader(test_data, batch_size=1)

In [53]:
# Rebuild the architecture and restore trained weights for evaluation.
model = Attention().to(device)
# map_location lets a checkpoint saved on a GPU load on whatever `device` is in use
# (e.g. a CPU-only machine), instead of failing while deserialising CUDA tensors.
model.load_state_dict(torch.load('/home/ldap_howard/script/model/MSI_0307CV13_fold1_clam_checkpoint.pt',
                                 map_location=device))
model.eval()

Attention(
  (layer1): Linear(in_features=2048, out_features=1024, bias=True)
  (attention_V): Sequential(
    (0): Linear(in_features=1024, out_features=512, bias=True)
    (1): Tanh()
    (2): Dropout(p=0.25, inplace=False)
  )
  (attention_U): Sequential(
    (0): Linear(in_features=1024, out_features=512, bias=True)
    (1): Sigmoid()
    (2): Dropout(p=0.25, inplace=False)
  )
  (attention_weights): Linear(in_features=512, out_features=1, bias=True)
  (classifier): Sequential(
    (0): Linear(in_features=1024, out_features=1024, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.25, inplace=False)
    (3): Linear(in_features=1024, out_features=512, bias=True)
    (4): ReLU()
    (5): Linear(in_features=512, out_features=2, bias=True)
    (6): Sigmoid()
  )
  (instance_loss): FocalLoss()
  (fc_c): Sequential(
    (0): Linear(in_features=4, out_features=2, bias=True)
    (1): Sigmoid()
  )
  (fc_X): Sequential(
    (0): Linear(in_features=1, out_features=2, bias=True)
    (1): Sigmoid(

In [54]:
# Evaluate on the MSIL test loader. The string passed as `fold` is forwarded to
# summary_to_excel, which uses it in the sheet name ('Sheet' + str(fold)).
# As the last expression of the cell, the returned tuple is also displayed.
summary(model, test_loader, 2, 'CRC_NA_MSIL')

[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[1. 1. 0. 0. 1. 1. 0. 0. 0. 1. 0. 0. 1. 0. 1. 1. 1. 0. 0. 1. 1. 0. 0. 0.
 0. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 1. 0. 0. 0

({},
 0.15060240963855423,
 0.7099116161616161,
 <__main__.Accuracy_Logger at 0x7fa37611afb0>,
 array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 array([0.64317197, 0.66073507, 0.27932507, 0.31739086, 0.55644733,
        0.60599965, 0.29191434, 0.27950007, 0.32653153, 0.59652179,
        0.28533968, 0.2

### PAIP ###

In [104]:
class SimpleDataset(Dataset):
    """Map patient IDs to the paths of their feature / count ``.npy`` files plus a label.

    This cell redefines ``SimpleDataset`` with the PAIP directories as defaults and
    therefore shadows the definition given earlier in the notebook.

    The dataset never opens the files; each item is
    ``(feature_path, count_path, label)`` and loading is left to the caller.

    Args:
        x: Indexable sequence of patient ID strings (file stems).
        y: Indexable sequence of labels, aligned with ``x``.
        path: Directory holding the feature files. Defaults to ``DEFAULT_FEATURE_DIR``.
        path_c: Directory holding the count files. Defaults to ``DEFAULT_COUNT_DIR``.
    """

    DEFAULT_FEATURE_DIR = '/ORCA_lake/TCGA-COAD/feature/PAIP_0307/'
    DEFAULT_COUNT_DIR = '/ORCA_lake/TCGA-COAD/hovernet_kmeans/PAIP_0307_MI2N/'

    def __init__(self, x, y, path=None, path_c=None):
        super().__init__()
        self.x = x
        self.y = y
        self.path = self.DEFAULT_FEATURE_DIR if path is None else path
        self.path_c = self.DEFAULT_COUNT_DIR if path_c is None else path_c

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        file_name = self.x[idx] + '.npy'
        image_path = os.path.join(self.path, file_name)
        count_path = os.path.join(self.path_c, file_name)

        return image_path, count_path, self.y[idx]

In [105]:
# PAIP test table (default sheet of the workbook).
test_dataset = pd.read_excel('/home/ldap_howard/script/MSI_PAIP.xlsx')

# Patient IDs and labels as NumPy arrays, in the same row order.
test_x, test_y = (test_dataset[column].values for column in ('Patients', 'isMSIH'))

# One bag per batch, original row order (no shuffling).
test_data = SimpleDataset(test_x, test_y)
test_loader = DataLoader(test_data, batch_size=1)

In [107]:
# Rebuild the architecture and restore trained weights for evaluation.
# NOTE(review): the saved output of this cell shows fc_c with in_features=19, while the
# Attention class in this notebook builds fc_c as Linear(4, 2) - this checkpoint was
# presumably trained with a different class version; confirm it still loads.
model = Attention().to(device)
# map_location lets a checkpoint saved on a GPU load on whatever `device` is in use
# (e.g. a CPU-only machine), instead of failing while deserialising CUDA tensors.
model.load_state_dict(torch.load('/home/ldap_howard/script/model/MSI_0307CV26_fold3_clam_checkpoint.pt',
                                 map_location=device))
model.eval()

Attention(
  (layer1): Linear(in_features=2048, out_features=1024, bias=True)
  (attention_V): Sequential(
    (0): Linear(in_features=1024, out_features=512, bias=True)
    (1): Tanh()
    (2): Dropout(p=0.25, inplace=False)
  )
  (attention_U): Sequential(
    (0): Linear(in_features=1024, out_features=512, bias=True)
    (1): Sigmoid()
    (2): Dropout(p=0.25, inplace=False)
  )
  (attention_weights): Linear(in_features=512, out_features=1, bias=True)
  (classifier): Sequential(
    (0): Linear(in_features=1024, out_features=1024, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.25, inplace=False)
    (3): Linear(in_features=1024, out_features=512, bias=True)
    (4): ReLU()
    (5): Linear(in_features=512, out_features=2, bias=True)
    (6): Sigmoid()
  )
  (instance_loss): FocalLoss()
  (fc_c): Sequential(
    (0): Linear(in_features=19, out_features=2, bias=True)
    (1): Sigmoid()
  )
  (fc_X): Sequential(
    (0): Linear(in_features=1, out_features=2, bias=True)
    (1): Sigmoid

In [108]:
# Evaluate on the PAIP test loader. The string passed as `fold` is forwarded to
# summary_to_excel, which uses it in the sheet name ('Sheet' + str(fold)).
# As the last expression of the cell, the returned tuple is also displayed.
summary(model, test_loader, 2, 'PAIP3')

[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
Accuracy: 0.8085106382978723
Specificity: 1.0
Sensitivity: 0.25
Recall: 0.25
Precision: 1.0
F1-score: 0.4
Auc: 0.9523809523809523
AUPRC: 0.9081551732867522
Patients data is successfully written into Excel File


({},
 0.19148936170212766,
 0.9523809523809523,
 <__main__.Accuracy_Logger at 0x7fb28c28d8d0>,
 array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]),
 array([0.57398266, 0.45270523, 0.38745889, 0.65565842, 0.32585603,
        0.2735382 , 0.48353097, 0.35349894, 0.27467501, 0.32804441,
        0.45700932, 0.65908653, 0.27158386, 0.26972786, 0.27510515,
        0.2721498 , 0.27069366, 0.27129275, 0.27123681, 0.33996898,
        0.27982265, 0.2879568 , 0.27321532, 0.27591282, 0.27932218,
        0.27228701, 0.27022722, 0.27123752, 0.27205902, 0.27101848,
        0.27168006, 0.27867323, 0.27108306, 0.27102104, 0.26938   ,
        0.27347121, 0.27040312, 0.27233604, 0.27585796, 0.27051467,
        0.27100447, 0.27219978, 0.26967797, 0.26970273, 0.27327457,
        0.27806199, 0.2701208 ]))

## CPTAC ##

In [109]:
class SimpleDataset(Dataset):
    """Map patient IDs to the paths of their feature / count ``.npy`` files plus a label.

    This cell redefines ``SimpleDataset`` with the CPTAC directories as defaults and
    therefore shadows the definitions given earlier in the notebook.

    The dataset never opens the files; each item is
    ``(feature_path, count_path, label)`` and loading is left to the caller.

    Args:
        x: Indexable sequence of patient ID strings (file stems).
        y: Indexable sequence of labels, aligned with ``x``.
        path: Directory holding the feature files. Defaults to ``DEFAULT_FEATURE_DIR``.
        path_c: Directory holding the count files. Defaults to ``DEFAULT_COUNT_DIR``.
    """

    DEFAULT_FEATURE_DIR = '/ORCA_lake/TCGA-COAD/feature/CPTAC_0307/'
    DEFAULT_COUNT_DIR = '/ORCA_lake/TCGA-COAD/hovernet_kmeans/CPTAC_0307_MI2N/'

    def __init__(self, x, y, path=None, path_c=None):
        super().__init__()
        self.x = x
        self.y = y
        self.path = self.DEFAULT_FEATURE_DIR if path is None else path
        self.path_c = self.DEFAULT_COUNT_DIR if path_c is None else path_c

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        file_name = self.x[idx] + '.npy'
        image_path = os.path.join(self.path, file_name)
        count_path = os.path.join(self.path_c, file_name)

        return image_path, count_path, self.y[idx]

In [110]:
# CPTAC test table (default sheet of the workbook).
test_dataset = pd.read_excel('/home/ldap_howard/script/MSI_CPTAC_0307.xlsx')

# Patient IDs and labels as NumPy arrays, in the same row order.
test_x, test_y = (test_dataset[column].values for column in ('Patients', 'isMSIH'))

# One bag per batch, original row order (no shuffling).
test_data = SimpleDataset(test_x, test_y)
test_loader = DataLoader(test_data, batch_size=1)

In [28]:
# Rebuild the architecture and restore trained weights for evaluation.
# NOTE(review): this cell's execution count (28) is lower than that of the cells around
# it (109-111), so the CPTAC summary below may have run with a different `model` than
# the one loaded here. Its saved output also shows fc_X with in_features=1024, while the
# Attention class in this notebook builds fc_X as Linear(1, 2). Re-run top to bottom to
# confirm which checkpoint produced the reported numbers.
model = Attention().to(device)
# map_location lets a checkpoint saved on a GPU load on whatever `device` is in use
# (e.g. a CPU-only machine), instead of failing while deserialising CUDA tensors.
model.load_state_dict(torch.load('/home/ldap_howard/script/model/MSI_0307CV5_fold1_clam_checkpoint.pt',
                                 map_location=device))
model.eval()

Attention(
  (layer1): Linear(in_features=2048, out_features=1024, bias=True)
  (attention_V): Sequential(
    (0): Linear(in_features=1024, out_features=512, bias=True)
    (1): Tanh()
    (2): Dropout(p=0.25, inplace=False)
  )
  (attention_U): Sequential(
    (0): Linear(in_features=1024, out_features=512, bias=True)
    (1): Sigmoid()
    (2): Dropout(p=0.25, inplace=False)
  )
  (attention_weights): Linear(in_features=512, out_features=1, bias=True)
  (classifier): Sequential(
    (0): Linear(in_features=1024, out_features=1024, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.25, inplace=False)
    (3): Linear(in_features=1024, out_features=512, bias=True)
    (4): ReLU()
    (5): Linear(in_features=512, out_features=2, bias=True)
    (6): Sigmoid()
  )
  (instance_loss): FocalLoss()
  (fc_X): Sequential(
    (0): Linear(in_features=1024, out_features=2, bias=True)
    (1): Sigmoid()
  )
  (fc_c): Sequential(
    (0): Linear(in_features=4, out_features=2, bias=True)
    (1): Sigmo

In [111]:
# Evaluate on the CPTAC test loader. The string passed as `fold` is forwarded to
# summary_to_excel, which uses it in the sheet name ('Sheet' + str(fold)).
# As the last expression of the cell, the returned tuple is also displayed.
summary(model, test_loader, 2, 'CPTAC3')

[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0.]
[1. 1. 1. 0. 1. 0. 1. 0. 0. 1. 1. 1. 1. 0. 1. 0. 1. 1. 0. 0. 0. 1. 0. 1.
 0. 1. 0. 1. 0. 1. 0. 0. 1. 0. 0. 1. 0. 0. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 1. 0. 1. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 1. 0. 1. 1. 1. 0. 0. 1. 1. 

({},
 0.19090909090909092,
 0.8222298534798536,
 <__main__.Accuracy_Logger at 0x7fb28c28d2d0>,
 array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0