In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from torchvision import models as torch_models1
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import openpyxl
import xlsxwriter
from collections import namedtuple
import os
import random
import math
from PIL import Image
from scipy import interp
from sklearn.model_selection import KFold, StratifiedKFold
import pandas as pd
from torch.utils.data import Sampler, BatchSampler
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score, recall_score, precision_score, confusion_matrix, precision_recall_curve, f1_score
from torch.autograd import Variable
import logging
from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix
from sklearn.metrics import auc as calc_auc
from topk.svm import SmoothTop1SVM
from torch_kmeans import KMeans
from topk.svm import SmoothTop1SVM
from itertools import cycle



In [2]:
SEED = 1234
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
train_dataset = pd.read_excel('/home/ldap_howard/script/CRC_CMS.xlsx',sheet_name='CMS_0604')
train_x = train_dataset['Patients'].values
train_y = train_dataset['status']
train_y = train_y.values

In [4]:
kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=14)

In [5]:
def npy_loader(path):
    x = np.load(path,allow_pickle=True).item()
    x_im = torch.from_numpy(x['features'])
    return x_im

def npy_loader_count(path):
    x = np.load(path,allow_pickle=True).item()
    x_im = torch.from_numpy(x['counts'])
    #x_im = x_im[:,1:3]
    return x_im

In [6]:
class SimpleDataset(Dataset):
    def __init__(self, x, y):
        super().__init__()
        self.x = x
        self.y = y
        self.path = '/LVM_data/ldap_howard/feature/CRC_resnet0307/'
        self.path_c = '/LVM_data/ldap_howard/feature/CRC_0307_MI2N/'
        
    def __len__(self):
        return len(self.x)   
    
    def __getitem__(self, idx):
        image_path = os.path.join(self.path, self.x[idx]+'.npy')
        count_path = os.path.join(self.path_c, self.x[idx]+'.npy')

        return image_path, count_path, self.y[idx]

In [7]:
class FocalLoss(nn.Module):
    
    def __init__(self, weight=None, 
                 gamma=2., reduction='mean'):
        nn.Module.__init__(self)
        self.weight = weight
        self.gamma = gamma
        self.reduction = reduction
        
    def forward(self, logits, label):
        log_prob = F.log_softmax(logits, dim=-1)
        prob = torch.exp(log_prob)
        return F.nll_loss(
            ((1 - prob) ** self.gamma) * log_prob, 
            label, 
            weight=self.weight,
            reduction = self.reduction
        )

In [8]:
class Attention(nn.Module):
    def __init__(self, L=2048, D=1024, dropout=True, n_classes=2, top_k=1):
        super(Attention, self).__init__()
        self.L = L
        self.D = D
        self.K = 1

        self.layer1 = nn.Linear(self.L, self.D)
        if dropout:
            self.attention_V = nn.Sequential(nn.Linear(self.D, 512), nn.Tanh(), nn.Dropout(0.25))
            self.attention_U = nn.Sequential(nn.Linear(self.D, 512), nn.Sigmoid(), nn.Dropout(0.25))
        else:
            self.attention_V = nn.Sequential(nn.Linear(self.D, 512), nn.Tanh())
            self.attention_U = nn.Sequential(nn.Linear(self.D, 512), nn.Sigmoid())

        self.attention_weights = nn.Linear(512, self.K)

        self.classifier = nn.Sequential(nn.Linear(1028,1024),
                                        nn.ReLU(),
                                        nn.Dropout(0.25),
                                        nn.Linear(1024, 512),
                                        nn.ReLU(), 
                                        nn.Linear(512,4),
                                        nn.Sigmoid())
        self.top_k = top_k
        self.instance_loss = nn.CrossEntropyLoss()
        self.fc_c1 = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
        self.fc_X = nn.Sequential(nn.Linear(1, 4), nn.Sigmoid())

    def relocate(self):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.classifier.to(device)
        self.attention_V.to(device)
        self.attention_U.to(device)
        self.attention_weights.to(device)
        self.instance_loss.to(device)
        self.fc_c.to(device)
        self.fc_X.to(device)
    
    def inst_eval(self, A_T, count):
        logits_c = self.fc_c1(count)
        hover_logits = torch.mm(A_T, logits_c)
        y_probs_c = F.softmax(logits_c, dim=1)
        #k = math.ceil(logits_c.size()[0] / 20)
        _, predicted_class = torch.max(y_probs_c, dim=1)
        predicted_prob = y_probs_c[torch.arange(y_probs_c.size(0)), predicted_class]
        top_instance_idx = torch.topk(predicted_prob, 5, largest=True)[1]
        top_instance = torch.index_select(y_probs_c, dim=0, index=top_instance_idx)
        _, pseudo_targets = torch.max(top_instance, dim=1)

        A_T = torch.transpose(A_T, 1, 0)  # KxN
        logits_x = self.fc_X(A_T)
        y_probs_x = F.softmax(logits_x, dim=1)
        top_instance_x = torch.index_select(logits_x, dim=0, index=top_instance_idx)
        pseudo_logits = top_instance_x

        return pseudo_logits, pseudo_targets,  hover_logits

    def forward(self, x, count, eval=False):
        x = self.layer1(x)
        A_V = self.attention_V(x)  # NxD
        A_U = self.attention_U(x)
        A = self.attention_weights(A_V*A_U)
        A_T = torch.transpose(A, 1, 0)  # KxN
        A_T = F.softmax(A_T, dim=1)  # softmax over N
        M = torch.mm(A_T, x)  # KxL

        pseudo_logits, pseudo_targets, hover_logits = self.inst_eval(A_T, count)
        instance_loss = self.instance_loss(pseudo_logits, pseudo_targets)
        M  = torch.cat((M, hover_logits), dim=1)
       
        logits = self.classifier(M)
        y_probs = F.softmax(logits, dim=1)
        max_scores, max_indice = torch.max(y_probs, dim=1)
        Y_hat = max_indice
        Y_prob = y_probs
        results_dict = {}
        if eval == False:
            results_dict.update({'logits': logits, 'Y_prob': Y_prob, 'Y_hat': Y_hat, 'Instance_loss': instance_loss, 'hover_logits':hover_logits})
        elif eval == True:
            results_dict.update({'logits': logits, 'Y_prob': Y_prob, 'Y_hat': Y_hat,})


        return results_dict

In [9]:
class Accuracy_Logger(object):
    """Accuracy logger"""

    def __init__(self, n_classes):
        super(Accuracy_Logger, self).__init__()
        self.n_classes = n_classes
        self.initialize()

    def initialize(self):
        self.data = [{"count": 0, "correct": 0} for i in range(self.n_classes)]

    def log(self, Y_hat, Y):
        Y_hat = int(Y_hat)
        Y = int(Y)
        self.data[Y]["count"] += 1
        self.data[Y]["correct"] += (Y_hat == Y)

    def log_batch(self, count, correct, c):
        self.data[c]["count"] += count
        self.data[c]["correct"] += correct

    def log_batch_rnn(self, Y_hat, Y):
        Y_hat = np.array(Y_hat).astype(int)
        Y = np.array(Y).astype(int)
        for label_class in np.unique(Y):
            cls_mask = Y == label_class
            self.data[label_class]["count"] += cls_mask.sum()
            self.data[label_class]["correct"] += (Y_hat[cls_mask] == Y[cls_mask]).sum()

    def get_summary(self, c):
        count = self.data[c]["count"]
        correct = self.data[c]["correct"]

        if count == 0:
            acc = None
        else:
            acc = float(correct) / count

        return acc, correct, count

In [10]:
def calculate_error(Y_hat, Y):
    error = 1. - Y_hat.float().eq(Y.float()).float().mean().item()
    return error

In [11]:
class EarlyStopping:
    def __init__(self, patience=40, stop_epoch=100, verbose=False):  # 连续patience轮，并且总论此超过stop_epoch轮就会终止
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 20
            stop_epoch (int): Earliest epoch possible for stopping
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
        """
        self.patience = patience
        self.stop_epoch = stop_epoch
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_max = np.Inf

    def __call__(self, epoch, val_loss, model, ckpt_name='checkpoint.pt'):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model, ckpt_name)
        elif score < self.best_score:
            self.counter += 1
            logging.info(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience and epoch > self.stop_epoch:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, ckpt_name)
            self.counter = 0

    def save_checkpoint(self, early_stopping, model, ckpt_name):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            logging.info(
                f'Validation loss decreased ({self.val_loss_max:.6f} --> {early_stopping:.6f}).  Saving model ...')
        print("Save model!!!!!!!!!!!")
        torch.save(model.state_dict(), ckpt_name)
        self.val_loss_max = early_stopping

In [12]:
def summary_to_excel(fold, loader_name, label_list, probs_list, y_hat_list, accuracy, precision, recall, f1, cls_auc):
    df = pd.DataFrame(
        {
            "Patients": loader_name,
            "labels": label_list,
            "CMS1_probs": probs_list[:,0],
            "CMS2_probs": probs_list[:,1],
            "CMS3_probs": probs_list[:,2],
            "CMS4_probs": probs_list[:,3], 
            "y_hat": y_hat_list,
            "Accuracy" : accuracy,
            "Auc": cls_auc
        }
    )

    if fold == 0:
        with pd.ExcelWriter('./summary/0307/CMS/CMS_0307CV7_clam.xlsx', engine='openpyxl') as writer:
            df.to_excel(writer, sheet_name='Sheet0', index=False)
    else:
        with pd.ExcelWriter('./summary/0307/CMS/CMS_0307CV7_clam.xlsx', engine='openpyxl', mode='a') as writer:
            df.to_excel(writer, sheet_name='Sheet'+str(fold), index=False)

    print("Patients data is successfully written into Excel File")

In [13]:
def summary(model, loader, n_classes, fold):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    #cls_logger = Accuracy_Logger(n_classes=n_classes)
    model.eval()
    cls_test_error = 0.
    cls_test_loss = 0.

    all_cls_probs = np.zeros((len(loader), n_classes))
    all_cls_logits = np.zeros((len(loader), n_classes))
    all_cls_labels = np.zeros(len(loader))
    all_cls_y_hats = np.zeros(len(loader))
    all_cls_onehot = np.zeros((len(loader), n_classes))

    patient_results = {}
    
    loader_name = []
    for batch_idx, (npy_dir, npy_dir_count, label) in enumerate(loader):
        data = npy_loader(npy_dir[0]).to(device)
        count = npy_loader_count(npy_dir_count[0]).to(device)
        label_numpy = label.cpu().numpy()
        label = label.to(device)
        label_onehot = label_binarize(label_numpy, classes=np.arange(4))
        
        with torch.no_grad():
            results_dict = model(data, count, eval=True)

        logits, Y_prob, Y_hat = results_dict['logits'], results_dict['Y_prob'], results_dict['Y_hat']
        loader_name.append(npy_dir[0])

        #cls_logger.log(Y_hat, label)
        cls_logits = logits.cpu().numpy()
        cls_probs = Y_prob.cpu().numpy()
        cls_Yhats = Y_hat.cpu().numpy()
        all_cls_logits[batch_idx] = cls_logits
        all_cls_probs[batch_idx] = cls_probs
        all_cls_labels[batch_idx] = label.item()
        all_cls_y_hats[batch_idx] = cls_Yhats.item()
        all_cls_onehot[batch_idx] = label_onehot

        cls_error = calculate_error(Y_hat, label)
        cls_test_error += cls_error

    cls_test_error /= len(loader)

    if n_classes == 2:
        print(all_cls_labels)
        print(all_cls_y_hats)
        tn, fp, fn, tp = confusion_matrix(all_cls_labels, all_cls_y_hats, labels=[0, 1]).ravel()
        accuracy = (tp+tn)/(tp+tn+fp+fn)
        sensitivity = tp/(tp+fn)
        specificity = tn/(tn+fp)
        recall = tp/(tp+fn)
        precision = tp/(tp+fp)
        try:
            cls_auc = roc_auc_score(all_cls_labels, all_cls_probs[:, 1])
            precision1, recall1, _ = precision_recall_curve(all_cls_labels,all_cls_probs[:,1])
            auprc = calc_auc(recall1, precision1)
        except:
            cls_auc = 'nan'
            auprc = 'nan'
        print("Accuracy: "+str(accuracy))
        print("Specificity: "+str(specificity))
        print("Sensitivity: "+str(sensitivity))
        print("Recall: "+str(recall))
        print("Precision: "+str(precision))
        print("Auc: "+str(cls_auc))
        print("AUPRC: "+str(auprc))

        summary_to_excel(fold, loader_name, all_cls_labels, all_cls_probs[:,1], all_cls_y_hats, 
                         accuracy, specificity, sensitivity, precision, f1_score, cls_auc, auprc)
    else:
        auc_scores = {}
        for i in range(n_classes):
            try:
                # Binary labels: 1 for the current class, 0 for the rest
                binary_labels = (all_cls_labels == i).astype(int)
                cls_auc = roc_auc_score(binary_labels, all_cls_probs[:, i])
                auc_scores[f'Class {i}'] = cls_auc
            except ValueError as e:
                auc_scores[f'Class {i}'] = 'nan'
        
        # Print AUROC scores for each class
        for class_label, auc in auc_scores.items():
            print(f"{class_label} AUROC: {auc}")
            
        cls_auc = roc_auc_score(all_cls_labels, all_cls_probs, multi_class='ovr')
        correct = (all_cls_y_hats == all_cls_labels).sum().item()
        accuracy = float(correct / all_cls_labels.shape[0])
        precision = precision_score(all_cls_y_hats, all_cls_labels, average=None)
        recall = recall_score(all_cls_y_hats, all_cls_labels, average=None)
        f1 = f1_score(all_cls_y_hats, all_cls_labels, average=None)
        print("Accuracy: "+str(accuracy))
        print("Precision: "+str(precision))
        print("Recall: "+str(recall))
        print("f1: "+str(f1))
        print(all_cls_onehot)
        summary_to_excel(fold, loader_name, all_cls_labels, all_cls_probs, all_cls_y_hats, 
                         accuracy, precision, recall, f1, cls_auc)

    return patient_results, cls_test_error, cls_auc, all_cls_onehot, all_cls_probs

In [14]:
def plot_roc_curve(tprs, mean_fpr):
    plt.plot([0,1],[0,1],linestyle = '--',lw = 2,color = 'black')
    mean_tpr = np.mean(tprs, axis=0)
    mean_auc = calc_auc(mean_fpr, mean_tpr)
    plt.plot(mean_fpr, mean_tpr, color='blue',
            label=r'Mean ROC (AUC = %0.2f )' % (mean_auc),lw=2, alpha=1)

    plt.xlabel('False Positive Rate', fontsize=18)
    plt.ylabel('True Positive Rate', fontsize=18)
    plt.title('ROC', fontsize=18)
    plt.legend(loc="lower right")
    plt.show()

In [15]:
def plot_loss_curve(train_loss, valid_loss, fold):
    title = 'CMS HoverAtt-CMS fold' + str(fold+1) + ' loss curve'
    plt.plot(train_loss, label='train loss')
    plt.plot(valid_loss, label='validation loss')
    plt.xlabel('Epoch', fontsize=16)
    plt.ylabel('Loss', fontsize=16)
    plt.title(title, fontsize=18)
    plt.legend(fontsize=16)
    plt.savefig('/home/ldap_howard/script/summary/attention_hover/loss_fold'+str(fold)+'_clam.png')
    plt.show()

In [16]:
def train_loop(epoch, model, loader, optimizer, n_classes, writer=None, loss_fn=None):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.train()
    cls_logger = Accuracy_Logger(n_classes=n_classes)
    cls_train_error = 0.
    cls_train_loss = 0.
    train_inst_loss = 0.
    total_loss = 0.
    for batch_idx, (npy_dir, npy_dir_count, label) in enumerate(loader):
        data = npy_loader(npy_dir[0]).to(device)
        count = npy_loader_count(npy_dir_count[0]).to(device)
        label = label.to(device)

        results_dict = model(data, count)
        logits, Y_prob, Y_hat, instance_loss, hover_logits = results_dict['logits'], results_dict['Y_prob'], results_dict['Y_hat'], results_dict['Instance_loss'], results_dict['hover_logits']

        cls_loss = loss_fn(logits, label)
        cls_loss_value = cls_loss.item()

        hover_loss = loss_fn(hover_logits, label)
        hover_loss_value = hover_loss.item()

        total_loss = 0.95*cls_loss + 0.05*instance_loss
        cls_train_loss += cls_loss.item()

        cls_error = calculate_error(Y_hat, label)
        cls_train_error += cls_error

        # backward pass
        total_loss.backward()
        # step
        optimizer.step()
        optimizer.zero_grad()

    # calculate loss and error for epoch
    cls_train_loss /= len(loader)
    cls_train_error /= len(loader)

    return cls_train_loss

In [17]:
def validate(model_name, epoch, model, loader, n_classes, early_stopping=None, writer=None, loss_fn=None, results_dir=None):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    #cls_logger = Accuracy_Logger(n_classes=n_classes)
    cls_val_error = 0.
    cls_val_loss = 0.

    cls_probs = np.zeros((len(loader), n_classes))
    cls_labels = np.zeros(len(loader))
    all_cls_labels = np.zeros(len(loader))
    all_cls_y_hats = np.zeros(len(loader))


    with torch.no_grad():
        for batch_idx, (npy_dir, npy_dir_count, label) in enumerate(loader):
            data = npy_loader(npy_dir[0]).to(device)
            count = npy_loader_count(npy_dir_count[0]).to(device)
            label = label.to(device)

            results_dict = model(data, count)
            logits, Y_prob, Y_hat = results_dict['logits'], results_dict['Y_prob'], results_dict['Y_hat']
            del results_dict

            cls_loss = loss_fn(logits, label)
            cls_loss_value = cls_loss.item()

            cls_probs[batch_idx] = Y_prob.cpu().numpy()
            cls_labels[batch_idx] = label.item()

            cls_val_loss += cls_loss_value
            cls_error = calculate_error(Y_hat, label)
            cls_val_error += cls_error

            cls_Yhats = Y_hat.cpu().numpy()
            all_cls_labels[batch_idx] = label.item()
            all_cls_y_hats[batch_idx] = cls_Yhats.item()            

    cls_val_error /= len(loader)
    cls_val_loss /= len(loader)

    all_cls_labels[batch_idx] = label.item()
    all_cls_y_hats[batch_idx] = cls_Yhats.item()

    print("cls val loss: " + str(cls_val_loss))

    if n_classes == 2:
        cls_auc = roc_auc_score(cls_labels, cls_probs[:, 1])
        precision1, recall1, _ = precision_recall_curve(cls_labels,cls_probs[:, 1])
        cls_auprc = calc_auc(recall1, precision1)     
        cls_aucs = []
    else:
        cls_aucs = []
        binary_labels = label_binarize(cls_labels, classes=[i for i in range(n_classes)])
        correct = (all_cls_y_hats == all_cls_labels).sum().item()
        accuracy = float(correct / all_cls_labels.shape[0])
        for class_idx in range(n_classes):
            if class_idx in cls_labels:
                fpr, tpr, _ = roc_curve(binary_labels[:, class_idx], cls_probs[:, class_idx])
                cls_aucs.append(calc_auc(fpr, tpr))
            else:
                cls_aucs.append(float('nan'))

        cls_auc = np.nanmean(np.array(cls_aucs))
    print("Accuracy: " + str(accuracy))
    print("cls_auprc: " + str(cls_auc))

    if early_stopping:
        assert results_dir
        early_stopping(epoch, cls_val_loss, model,
                       ckpt_name=os.path.join(results_dir, model_name))

        if early_stopping.early_stop:
            logging.info("Early stopping")
            return cls_val_loss, True

    return cls_val_loss, False    

In [18]:
class FocalLossCustom(nn.Module):
    def __init__(self, alpha=1, gamma=2, reduction='none'):
        super(FocalLossCustom, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def forward(self, logits, label, pseudo_targets):
        ce_loss = F.cross_entropy(logits,label, reduction='none')
        pt = torch.exp(-ce_loss)

        N = len(pseudo_targets)
        pseudo_targets = torch.mean(pseudo_targets, dim=0).unsqueeze(0)
        ce_loss_pseudo = F.cross_entropy(pseudo_targets, label, reduction='none')
        pt_pseudo = torch.exp(-ce_loss_pseudo)
        
        focal_loss = 0.8*(self.alpha * (1 - pt) ** self.gamma * ce_loss) + 0.2*ce_loss_pseudo / math.log(N)

        if self.reduction == 'mean':
            return torch.mean(focal_loss)
        elif self.reduction == 'sum':
            return torch.sum(focal_loss)
        else:
            return focal_loss

In [19]:
train_data = SimpleDataset(train_x, train_y)

In [20]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
early_stopping = None
writer = None

cuda


In [21]:
i = 0
precisions = []
y_real = []
y_probs = []
mean_fpr = np.linspace(0,1,100)
mean_recall = np.linspace(0,1,100)
tprs = {i: [] for i in range(4)}
aucs = {i: [] for i in range(4)}
plt.figure(figsize=(10, 8))

for fold, (train_ids, valid_ids) in enumerate(kf.split(train_x, train_y)):
    print(f'FOLD {fold}')
    print('--------------------------------')
    
    model_name = './model/CMS_0307CV7_fold'+str(fold)+'_clam_checkpoint.pt'
    train_fold = torch.utils.data.Subset(train_data, train_ids)
    valid_fold = torch.utils.data.Subset(train_data, valid_ids)

    train_loader = DataLoader(train_fold, batch_size=1, shuffle = True)
    valid_loader = DataLoader(valid_fold, batch_size=1)
    
    model = Attention()
    model.to(device)

    loss_fn = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.00008, betas=(0.9, 0.999), weight_decay=0.001)

    writer = None
    early_stopping = EarlyStopping(patience=30, stop_epoch=100, verbose=True)

    results_dir = '/home/ldap_howard/script/'
    cur = 'model'
    train_loss = []
    valid_loss = []

    for epoch in range(1000):
        print(epoch)
        train_epoch_loss = train_loop(epoch, model, train_loader, optimizer, 4, writer, loss_fn)
        valid_epoch_loss, stop = validate(model_name, epoch, model, valid_loader, 4,
                                early_stopping, writer, loss_fn, results_dir)
        
        train_loss.append(train_epoch_loss)
        valid_loss.append(valid_epoch_loss)

        if stop:
            break
    
    if early_stopping:
        print("EarlyStopping")
        model.load_state_dict(torch.load(os.path.join(results_dir, model_name))) #"s_{}_checkpoint.pt".format(cur))))
    else:
        torch.save(model.state_dict(), os.path.join(results_dir, model_name)) #"s_{}_{}_checkpoint.pt".format(cur)))

    _, cls_val_error, cls_val_auc, cls_labels, cls_probs = summary(model, valid_loader, 4, fold)

    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    
    for i in range(4):
        fpr[i], tpr[i], _ = roc_curve(cls_labels[:, i], cls_probs[:, i])
        roc_auc[i] = calc_auc(fpr[i], tpr[i])
        tprs[i].append(np.interp(mean_fpr, fpr[i], tpr[i]))
        tprs[i][-1][0] = 0.0
        aucs[i].append(roc_auc[i])

    i=i+1
    logging.info('Cls Val error: {:.4f}, Cls ROC AUC: {:.4f}'.format(cls_val_error, cls_val_auc))
    print("Validation acc: " + str(cls_val_auc))
    print("validation error: " + str(cls_val_error))

    #plot_loss_curve(train_loss, valid_loss, fold)

#colors = cycle(["aqua", "darkorange", "cornflowerblue", "red"])
for i in range(4):
    mean_tpr = np.mean(tprs[i], axis=0)
    mean_tpr[-1] = 1.0
    mean_auc = calc_auc(mean_fpr, mean_tpr)
    std_auc = np.std(aucs[i])

    plt.plot(mean_fpr, mean_tpr,
             label=r'Mean ROC of class {0} (area = {1:0.2f} $\pm$ {2:0.2f})'
             ''.format(i, mean_auc, std_auc), lw=2, alpha=0.8)

    std_tpr = np.std(tprs[i], axis=0)
    tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
    tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
    plt.fill_between(mean_fpr, tprs_lower, tprs_upper, alpha=0.2)

plt.plot([0, 1], [0, 1], linestyle="--", lw=2, color="r", alpha=0.8)

plt.xlim([-0.05, 1.05])
plt.ylim([-0.05, 1.05])
plt.xlabel("False Positive Rate", fontsize = 20)
plt.ylabel("True Positive Rate", fontsize = 20)
plt.title("ROC curves", fontsize = 20)
plt.legend(loc="lower right", fontsize = 18)
plt.savefig('/home/ldap_howard/script/summary/attention_hover/auroc_clam.png')
plt.show()


FOLD 0
--------------------------------
0
cls val loss: 1.281624504498073
Accuracy: 0.44047619047619047
cls_auprc: 0.581200842710284
Save model!!!!!!!!!!!
1
cls val loss: 1.262003351535116
Accuracy: 0.44047619047619047
cls_auprc: 0.5920004810591597
Save model!!!!!!!!!!!
2
cls val loss: 1.2408851463170278
Accuracy: 0.4642857142857143
cls_auprc: 0.6558464158590915
Save model!!!!!!!!!!!
3
cls val loss: 1.2326319104149228
Accuracy: 0.47619047619047616
cls_auprc: 0.6647492648748878
Save model!!!!!!!!!!!
4
cls val loss: 1.2534320318982715
Accuracy: 0.4642857142857143
cls_auprc: 0.6432675683191648
5
cls val loss: 1.1971745888392131
Accuracy: 0.5119047619047619
cls_auprc: 0.6740196801806702
Save model!!!!!!!!!!!
6
cls val loss: 1.1842963901304064
Accuracy: 0.5119047619047619
cls_auprc: 0.6883973353961459
Save model!!!!!!!!!!!
7
cls val loss: 1.2012774738527479
Accuracy: 0.5
cls_auprc: 0.6890236300013318
8
cls val loss: 1.1753797445978438
Accuracy: 0.5476190476190477
cls_auprc: 0.71775537875871

In [21]:
class SimpleDataset(Dataset):
    def __init__(self, x, y):
        super().__init__()
        self.x = x
        self.y = y
        self.path = '/ORCA_lake/TCGA-COAD/feature/CRC_resnet0307/'
        self.path_c = '/ORCA_lake/TCGA-COAD/hovernet_kmeans/CRC_0307_MI2N/'
        
    def __len__(self):
        return len(self.x)   
    
    def __getitem__(self, idx):
        image_path = os.path.join(self.path, self.x[idx]+'.npy')
        count_path = os.path.join(self.path_c, self.x[idx]+'.npy')

        return image_path, count_path, self.y[idx]

In [23]:
test_dataset = pd.read_excel('/home/ldap_howard/script/CRC_CMS.xlsx',sheet_name = 'CMS_0604')
test_x = test_dataset['Patients'].values
test_y = test_dataset['status'].values
test_data = SimpleDataset(test_x, test_y)
test_loader = DataLoader(test_data, batch_size=1)

In [24]:
model = Attention().to(device)
model.load_state_dict(torch.load('/home/ldap_howard/script/model/CMS_0307CV6_fold2_clam_checkpoint.pt'))
model.eval()

Attention(
  (layer1): Linear(in_features=2048, out_features=1024, bias=True)
  (attention_V): Sequential(
    (0): Linear(in_features=1024, out_features=512, bias=True)
    (1): Tanh()
    (2): Dropout(p=0.25, inplace=False)
  )
  (attention_U): Sequential(
    (0): Linear(in_features=1024, out_features=512, bias=True)
    (1): Sigmoid()
    (2): Dropout(p=0.25, inplace=False)
  )
  (attention_weights): Linear(in_features=512, out_features=1, bias=True)
  (classifier): Sequential(
    (0): Linear(in_features=1028, out_features=1024, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.25, inplace=False)
    (3): Linear(in_features=1024, out_features=512, bias=True)
    (4): ReLU()
    (5): Linear(in_features=512, out_features=4, bias=True)
    (6): Sigmoid()
  )
  (instance_loss): CrossEntropyLoss()
  (fc_c1): Sequential(
    (0): Linear(in_features=4, out_features=4, bias=True)
    (1): ReLU()
  )
  (fc_X): Sequential(
    (0): Linear(in_features=1, out_features=4, bias=True)
    (1): Sig

In [25]:
summary(model, test_loader, 4, 'CRC')

Class 0 AUROC: 0.9300460223537146
Class 1 AUROC: 0.8932173504173034
Class 2 AUROC: 0.9301029962546817
Class 3 AUROC: 0.8773024361259656
Accuracy: 0.75
Precision: [0.83076923 0.79558011 0.83333333 0.58181818]
Recall: [0.84375    0.74611399 0.73529412 0.7032967 ]
f1: [0.8372093  0.77005348 0.78125    0.63681592]
Patients data is successfully written into Excel File


({},
 0.25,
 0.9076672012879163,
 array([0., 3., 2., 1., 1., 0., 1., 2., 1., 1., 1., 3., 2., 2., 1., 0., 1.,
        2., 1., 1., 0., 0., 2., 1., 3., 0., 0., 1., 0., 3., 1., 1., 2., 1.,
        0., 1., 3., 2., 1., 1., 0., 1., 1., 3., 1., 2., 1., 0., 0., 2., 2.,
        1., 1., 1., 1., 0., 0., 2., 1., 1., 3., 3., 3., 1., 3., 1., 1., 0.,
        0., 1., 0., 3., 2., 1., 1., 1., 1., 3., 1., 2., 2., 2., 1., 1., 2.,
        1., 1., 1., 1., 1., 1., 0., 2., 1., 2., 1., 3., 3., 1., 2., 1., 1.,
        1., 3., 1., 1., 1., 1., 1., 1., 1., 1., 3., 3., 3., 3., 1., 3., 0.,
        0., 2., 0., 3., 1., 1., 3., 0., 1., 3., 2., 1., 3., 3., 1., 1., 0.,
        3., 3., 2., 3., 3., 1., 3., 3., 2., 3., 1., 3., 3., 2., 3., 1., 1.,
        0., 1., 3., 1., 3., 1., 2., 3., 3., 3., 2., 1., 3., 1., 2., 1., 1.,
        3., 0., 3., 0., 1., 3., 1., 0., 1., 3., 2., 3., 0., 1., 3., 1., 2.,
        0., 1., 1., 2., 1., 1., 1., 1., 3., 3., 3., 1., 2., 3., 3., 1., 2.,
        3., 1., 0., 2., 1., 3., 2., 3., 1., 1., 3., 1.,