In [1]:
import torch
import torchvision
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Sampler
import torch.backends.cudnn as cudnn

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import albumentations as albu
import os, glob, sys, shutil, random
import cv2, itertools, random, pickle
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from collections import Counter
from sklearn.mixture import GaussianMixture
from sklearn.metrics import roc_auc_score

In [3]:
import data_process
import utils
# Root folder of the extracted patch JPEGs; datasets read "<IMAGE_FOLDER>/<svs_name>/<patch name>.jpg".
IMAGE_FOLDER = "/data/tcga/512dense/"

In [4]:
# Seed every RNG the notebook draws from. NumPy's global RNG was previously left
# unseeded although it drives the MixMatch Beta(alpha, alpha) draw in train() and
# the GaussianMixture fit in eval_train() (random_state=None -> np.random).
SEED = 123
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# data

In [5]:
import pickle
# Patient-level labels: dict keyed by the 12-character patient barcode (later looked up
# with `svs_name[:12]`), values 0/1 (presumably low/high class per the file name).
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open("./cohort_high_low.pkl", mode="rb") as fp:
    cohort_count_dict = pickle.load(fp)
print(f"# cohort: {len(cohort_count_dict)}")

# cohort: 389


In [6]:
# Parallel lists of patient ids and labels, plus an id -> label lookup table.
patient_ids = [*cohort_count_dict]
patient_cls = [cohort_count_dict[pid] for pid in patient_ids]
lookup = {pid: label for pid, label in zip(patient_ids, patient_cls)}
Counter(patient_cls)

Counter({0: 317, 1: 72})

In [7]:
# Split at the patient level (70/30, fixed seed); patches are assigned to a side by their patient later.
train_patient, valid_patient = train_test_split(patient_ids, test_size=0.3, random_state=42)
print("# train patient:{}\n# valid patient:{}".format(
    Counter(lookup[p] for p in train_patient),
    Counter(lookup[p] for p in valid_patient),
))

# train patient:Counter({0: 221, 1: 51})
# valid patient:Counter({0: 96, 1: 21})


In [8]:
# Build patch-level train/valid lists from the per-slide coordinate files.
# Each .npy holds the (x, y) patch coordinates of one slide; a patch file is named
# "<svs_name>_<x>_<y>.jpg" and inherits its patient's label from `lookup`.
TUMOR_NPY_DIR = "/data/tcga/512denseTumor/"
# Slide excluded from the experiment (kept from the original notebook; reason not recorded).
EXCLUDED_NPYS = {
    os.path.join(TUMOR_NPY_DIR, "TCGA-CM-6679-01A-01-TS1.6fbabb32-470f-4fc2-9a73-3a40e1237988.npy"),
}

# Sets give O(1) membership; the old list test ran once per patch (~1.4M times).
train_patient_set = set(train_patient)
valid_patient_set = set(valid_patient)

train_images, valid_images = [], []
train_lookup, valid_lookup = {}, {}
train_npys = []
train_p_n, valid_p_n = [0, 0], [0, 0]  # slides per class on each side

for npy in sorted(glob.glob(os.path.join(TUMOR_NPY_DIR, "*.npy"))):
    if npy in EXCLUDED_NPYS:
        continue
    svs_name = os.path.basename(npy)[:-4]
    patient = svs_name[:12]
    if patient in train_patient_set:
        dest_images, dest_lookup, dest_counts = train_images, train_lookup, train_p_n
        train_npys.append(npy)
    elif patient in valid_patient_set:
        dest_images, dest_lookup, dest_counts = valid_images, valid_lookup, valid_p_n
    else:
        # Patient outside the labelled cohort: there is no label, so skip the slide
        # (previously it was routed to validation and crashed on lookup[patient]).
        continue
    label = lookup[patient]
    dest_counts[label] += 1
    for x, y in np.load(npy):
        path = "{}_{}_{}.jpg".format(svs_name, x, y)
        dest_images.append(path)
        dest_lookup[path[:-4]] = label

train_images = np.array(train_images)
valid_images = np.array(valid_images)
print("# train images:{}\n# valid images:{}".format(len(train_images), len(valid_images)))
print("# train images:{}\n# valid images:{}".format(Counter(list(train_lookup.values())),
                                                    Counter(list(valid_lookup.values()))))
print("# train npys:(0: {}, 1: {})\n# valid npys:(0: {}, 1: {})".format(train_p_n[0], train_p_n[1],
                                                                       valid_p_n[0], valid_p_n[1]))
# deleted once only: the original second `del` of the same names raised NameError
del train_p_n, valid_p_n

# train images:1014641
# valid images:448232
# train images:Counter({0: 809418, 1: 205223})
# valid images:Counter({0: 345166, 1: 103066})
# train npys:Counter({396: 1, 97: 1})
# valid npys:Counter({181: 1, 41: 1})


# dataloader

In [9]:
def to_tensor(x, **kwargs):
    """Scale pixel values by 1/255 and move the last (channel) axis to the front.

    Assumes an HWC array with values in [0, 255]; returns CHW float32. Extra keyword
    arguments are accepted and ignored so it can be used inside ``albu.Lambda``.
    """
    scaled = x / 255.
    channels_first = np.transpose(scaled, (2, 0, 1))
    return channels_first.astype(np.float32)

def get_preprocessing():
    """Return the post-augmentation pipeline.

    It only converts the image (and a mask, if one is passed) to CHW float32 scaled
    to [0, 1] via `to_tensor`; no mean/std normalisation is applied.
    """
    return albu.Compose([albu.Lambda(image=to_tensor, mask=to_tensor)])

def get_training_augmentation():
    """Random augmentation pipeline for training patches (output 224x224).

    Order: resize, flips/90-degree rotations, one of elastic/grid/optical distortion
    (p=0.8), shift-scale-rotate, additive Gaussian noise, Gaussian blur, one colour
    shift (p=0.8) and colour jitter (p=0.5).

    NOTE(review): `IAAAdditiveGaussianNoise` and `Flip` only exist in older
    albumentations releases -- confirm the pinned version before upgrading.
    """
    geometric_distortion = albu.OneOf(
        [albu.ElasticTransform(p=1), albu.GridDistortion(p=1), albu.OpticalDistortion(p=1)],
        p=0.8,
    )
    colour_shift = albu.OneOf(
        [
            albu.RandomBrightnessContrast(p=1),
            albu.HueSaturationValue(p=1),
            albu.HueSaturationValue(hue_shift_limit=40, sat_shift_limit=60, val_shift_limit=40, p=1),
        ],
        p=0.8,
    )
    train_transforms = [
        albu.Resize(224, 224),
        albu.Flip(),
        albu.RandomRotate90(),
        geometric_distortion,
        albu.ShiftScaleRotate(border_mode=0, value=0),
        albu.IAAAdditiveGaussianNoise(),
        albu.GaussianBlur(),
        colour_shift,
        albu.ColorJitter(p=0.5),
    ]
    return albu.Compose(train_transforms)

def get_validation_augmentation():
    """Deterministic evaluation pipeline: only resize to the 224x224 network input."""
    return albu.Compose([albu.Resize(224, 224)])

In [11]:
class CustomDataset(Dataset):
    """Patch-level dataset for DivideMix training and evaluation.

    Patch names look like "<svs_name>_<x>_<y>.jpg" and are read from
    "<root_dir>/<svs_name>/<patch name>".

    ``__getitem__`` returns, per mode:
      * "all"       -> (img, target, index)        one augmented view of every image
      * "labeled"   -> (img1, img2, target, prob)  two views of images with pred == True
      * "unlabeled" -> (img1, img2)                two views of images with pred == False
      * "test"      -> (img, target, name)         one view; name is the patch name without ".jpg"

    Args:
        root_dir: folder containing one sub-folder per slide.
        mode: one of the modes above.
        images: sequence of patch file names.
        lookup_table: dict mapping patch name without ".jpg" -> label.
        augmentation, preprocessing: albumentations-style callables, ``f(image=...)["image"]``.
        pred: boolean clean-sample mask over `images` ("labeled"/"unlabeled" only).
        probability: per-image clean probability ("labeled" only).
        log: writable file-like object; "labeled" mode logs its sample count there.

    Raises:
        ValueError: for an unknown mode.

    NOTE(review): a later notebook cell redefines the name `CustomDataset` with a
    different constructor; that shadows this class once it runs.
    """

    def __init__(self, root_dir, mode, images = None, lookup_table = None,
                 augmentation = None, preprocessing = None, pred=[], probability=[], log=''):
        self.root = root_dir
        self.augmentation = augmentation
        self.preprocessing = preprocessing
        self.mode = mode
        self.labels = lookup_table

        if mode in ("all", "test"):
            self.images = images
        elif mode in ("labeled", "unlabeled"):
            mask = np.asarray(pred, dtype=bool)
            # keep the images selected (labeled) or rejected (unlabeled) by the clean mask
            keep = np.flatnonzero(mask if mode == "labeled" else ~mask)
            self.images = [images[i] for i in keep]
            if mode == "labeled":
                self.probability = [probability[i] for i in keep]
            print("%s data has a size of %d" % (self.mode, len(self.images)))
            if mode == "labeled":
                log.write('Numer of labeled samples:%d \n' % (mask.sum()))
                log.flush()
        else:
            # previously an unknown mode left `self.images` unset and failed later
            raise ValueError("unknown mode: {!r}".format(mode))

    def _load(self, img_path):
        """Read the raw patch image `<root>/<svs_name>/<img_path>`."""
        svs_name = img_path.split("_")[0]
        full_path = os.path.join(self.root, svs_name, img_path)
        return data_process.wsi_utils.vips_get_image(full_path)

    def _view(self, image):
        """Apply augmentation then preprocessing to one image."""
        img = self.augmentation(image=image)['image']
        return self.preprocessing(image=img)['image']

    def __getitem__(self, index):
        img_path = self.images[index]
        if self.mode == 'unlabeled':
            image = self._load(img_path)
            return self._view(image), self._view(image)

        target = self.labels[img_path[:-4]]
        image = self._load(img_path)
        if self.mode == 'labeled':
            return self._view(image), self._view(image), target, self.probability[index]
        if self.mode == 'all':
            return self._view(image), target, index
        # 'test': return the processed view (the raw image was returned before) and the
        # patch name; the old code referenced an undefined `cls` and ignored self.root.
        return self._view(image), target, img_path[:-4]

    def __len__(self):
        return len(self.images)
        
class WrapperLoader():
    """Builds the DataLoaders for each phase of the DivideMix loop.

    Reads the notebook-level patch lists `train_images`/`train_lookup` (training side)
    and `valid_images`/`valid_lookup` (validation side).

    Args:
        batch_size: base batch size (warm-up uses 2x, eval/test use 5x).
        num_workers: DataLoader worker processes.
        log: file-like object handed to the "labeled" dataset, which logs its size.
    """
    def __init__(self, batch_size, num_workers, log):

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.root_dir = IMAGE_FOLDER
        self.log = log

    def _make_loader(self, dataset, batch_size, shuffle):
        """Wrap `dataset` in a DataLoader with the shared worker / pinned-memory settings."""
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True)

    def run(self,mode,pred=[],prob=[]):
        """Return the loader(s) for one phase.

        Args:
            mode: 'warmup' | 'train' | 'eval_train' | 'test'.
            pred: boolean clean-sample mask over `train_images` ('train' only).
            prob: clean probability of each training image ('train' only).

        Returns:
            'warmup'     -> shuffled, augmented loader over all training patches (batch 2*bs).
            'train'      -> (labeled_loader, unlabeled_loader) split by `pred`, augmented (batch bs).
            'eval_train' -> unshuffled, non-augmented loader over all training patches (batch 5*bs).
            'test'       -> unshuffled, non-augmented loader over validation patches (batch 5*bs),
                            yielding (img, target, index).

        Raises:
            ValueError: for an unknown mode (previously None was returned silently).
        """
        if mode=='warmup':
            all_dataset = CustomDataset(root_dir=self.root_dir, mode="all", images=train_images, lookup_table=train_lookup,
                                        augmentation=get_training_augmentation(), preprocessing=get_preprocessing())
            return self._make_loader(all_dataset, self.batch_size*2, shuffle=True)

        if mode=='train':
            labeled_dataset = CustomDataset(root_dir=self.root_dir, mode="labeled", pred=pred, probability=prob, log=self.log,
                                            images=train_images, lookup_table=train_lookup,
                                            augmentation=get_training_augmentation(), preprocessing=get_preprocessing())
            labeled_trainloader = self._make_loader(labeled_dataset, self.batch_size, shuffle=True)

            unlabeled_dataset = CustomDataset(root_dir=self.root_dir, mode="unlabeled", pred=pred, log=self.log,
                                              images=train_images, lookup_table=train_lookup,
                                              augmentation=get_training_augmentation(), preprocessing=get_preprocessing())
            unlabeled_trainloader = self._make_loader(unlabeled_dataset, self.batch_size, shuffle=True)
            return labeled_trainloader, unlabeled_trainloader

        if mode=='eval_train':
            # The per-sample losses feed the GMM clean/noisy split, so evaluate a
            # deterministic view; random training augmentation only adds loss noise.
            eval_dataset = CustomDataset(root_dir=self.root_dir, mode="all", images=train_images, lookup_table=train_lookup,
                                         augmentation=get_validation_augmentation(), preprocessing=get_preprocessing())
            return self._make_loader(eval_dataset, self.batch_size*5, shuffle=False)

        if mode=="test":
            # Validation must not use the random training augmentation.
            test_dataset = CustomDataset(root_dir=self.root_dir, mode="all", images=valid_images, lookup_table=valid_lookup,
                                         augmentation=get_validation_augmentation(), preprocessing=get_preprocessing())
            return self._make_loader(test_dataset, self.batch_size*5, shuffle=False)

        raise ValueError("unknown mode: {!r}".format(mode))

# DivideMix (co-divide clean/noisy samples with a GMM, then MixMatch training)

In [12]:
# DivideMix hyper-parameters:
#   num_class   - number of output classes
#   T           - sharpening temperature applied to co-guessed / refined labels
#   alpha       - Beta(alpha, alpha) parameter of the MixMatch interpolation
#   lambda_u    - maximum weight of the unlabeled loss (ramped up by linear_rampup)
#   p_threshold - clean-probability threshold tau for the labeled/unlabeled split
hyperParam = dict(
    num_class=2,
    T=0.5,
    alpha=4,
    lambda_u=20,
    p_threshold=0.5,
)

In [13]:
# Training
def train(epoch,net,net2,optimizer,labeled_trainloader,unlabeled_trainloader):
    """Train `net` for one DivideMix epoch (co-guessing + MixMatch); `net2` stays frozen.

    Args:
        epoch: current epoch; together with batch progress it drives the unlabeled-loss ramp-up.
        net: network being optimised (set to train mode).
        net2: peer network used only for label co-guessing (set to eval mode).
        optimizer: optimizer over `net`'s parameters.
        labeled_trainloader: yields (x, x_view2, label, clean_prob) batches.
        unlabeled_trainloader: yields (u, u_view2) batches; restarted whenever exhausted.

    Relies on the notebook globals `criterion`, `warm_up` and `hyperParam`.
    """
    net.train()
    net2.eval() #fix one network and train the other

    num_class = hyperParam["num_class"]
    temperature = hyperParam["T"]

    unlabeled_train_iter = iter(unlabeled_trainloader)
    # Approximate number of labeled batches per epoch (for the fractional ramp-up epoch),
    # taken from the loader itself instead of the global `bs`.
    num_iter = (len(labeled_trainloader.dataset)//labeled_trainloader.batch_size)+1

    labeled_loss_meter = utils.meter.AverageValueMeter()
    unlabeled_loss_meter = utils.meter.AverageValueMeter()
    total_loss_meter = utils.meter.AverageValueMeter()

    with tqdm(labeled_trainloader, desc="train", file=sys.stdout) as iterator:
        for batch_idx, (inputs_x, inputs_x2, labels_x, w_x) in enumerate(iterator):
            # Cycle the unlabeled loader. Use built-in next(): DataLoader iterators in
            # newer PyTorch have no `.next()` method, and the old bare `except:` turned
            # that AttributeError into a crash on the retry. Only exhaustion is expected here.
            try:
                inputs_u, inputs_u2 = next(unlabeled_train_iter)
            except StopIteration:
                unlabeled_train_iter = iter(unlabeled_trainloader)
                inputs_u, inputs_u2 = next(unlabeled_train_iter)
            batch_size = inputs_x.size(0)

            # Transform label to one-hot; w_x is each labeled sample's clean probability
            labels_x = F.one_hot(labels_x.view(-1).long(), num_class).float()
            w_x = w_x.view(-1,1).type(torch.FloatTensor)

            inputs_x, inputs_x2, labels_x, w_x = inputs_x.cuda(), inputs_x2.cuda(), labels_x.cuda(), w_x.cuda()
            inputs_u, inputs_u2 = inputs_u.cuda(), inputs_u2.cuda()

            with torch.no_grad():
                # label co-guessing of unlabeled samples: average both networks over both views
                outputs_u11 = net(inputs_u)
                outputs_u12 = net(inputs_u2)
                outputs_u21 = net2(inputs_u)
                outputs_u22 = net2(inputs_u2)
                pu = (torch.softmax(outputs_u11, dim=1) + torch.softmax(outputs_u12, dim=1)
                      + torch.softmax(outputs_u21, dim=1) + torch.softmax(outputs_u22, dim=1)) / 4
                ptu = pu**(1/temperature) # temperature sharpening
                targets_u = (ptu / ptu.sum(dim=1, keepdim=True)).detach() # normalize

                # label refinement of labeled samples: blend given label and prediction by w_x
                outputs_x = net(inputs_x)
                outputs_x2 = net(inputs_x2)
                px = (torch.softmax(outputs_x, dim=1) + torch.softmax(outputs_x2, dim=1)) / 2
                px = w_x*labels_x + (1-w_x)*px
                ptx = px**(1/temperature) # temperature sharpening
                targets_x = (ptx / ptx.sum(dim=1, keepdim=True)).detach() # normalize

            # mixmatch: lam >= 0.5 keeps every mixed sample closer to its first source
            lam = np.random.beta(hyperParam["alpha"], hyperParam["alpha"])
            lam = max(lam, 1-lam)

            all_inputs = torch.cat([inputs_x, inputs_x2, inputs_u, inputs_u2], dim=0)
            all_targets = torch.cat([targets_x, targets_x, targets_u, targets_u], dim=0)

            idx = torch.randperm(all_inputs.size(0))
            mixed_input = lam * all_inputs + (1 - lam) * all_inputs[idx]
            mixed_target = lam * all_targets + (1 - lam) * all_targets[idx]

            logits = net(mixed_input)
            # the first 2*batch_size rows come from the two labeled views
            logits_x = logits[:batch_size*2]
            logits_u = logits[batch_size*2:]

            Lx, Lu, lamb = criterion(logits_x, mixed_target[:batch_size*2],
                                     logits_u, mixed_target[batch_size*2:],
                                     epoch+batch_idx/num_iter, warm_up)

            # regularization: KL(uniform prior || mean prediction) discourages collapsing to one class
            prior = (torch.ones(num_class)/num_class).cuda()
            pred_mean = torch.softmax(logits, dim=1).mean(0)
            penalty = torch.sum(prior*torch.log(prior/pred_mean))

            loss = Lx + lamb * Lu  + penalty
            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            labeled_loss_meter.add(Lx.item())
            unlabeled_loss_meter.add(Lu.item())
            total_loss_meter.add(loss.item())

            s = 'Labeled loss: %.4f, Unlabeled loss:%.4f, Total loss: %.4f'\
                %(labeled_loss_meter.mean, unlabeled_loss_meter.mean, total_loss_meter.mean)
            iterator.set_postfix_str(s)

In [14]:
def warmup(epoch,net,optimizer,dataloader,noise_mode='sym'):
    """One epoch of cross-entropy warm-up before DivideMix co-division starts.

    Args:
        epoch: epoch index (unused; kept for a signature parallel to train()).
        net: network to train (set to train mode).
        optimizer: optimizer over `net`'s parameters.
        dataloader: yields (inputs, labels, index) batches.
        noise_mode: 'sym' -> plain CE (the previous hard-coded behaviour);
            'asym' -> CE minus the NegEntropy confidence penalty.

    Raises:
        ValueError: for any other `noise_mode` (previously the loss was left unbound).

    Relies on the notebook global `CEloss`.
    """
    if noise_mode not in ('sym', 'asym'):
        raise ValueError("noise_mode must be 'sym' or 'asym', got {!r}".format(noise_mode))
    net.train()

    loss_meter = utils.meter.AverageValueMeter()
    penalty_loss_meter = utils.meter.AverageValueMeter()
    ce_loss_meter = utils.meter.AverageValueMeter()
    conf_penalty_fn = NegEntropy()
    with tqdm(dataloader, desc="warmup", file=sys.stdout) as iterator:
        for inputs, labels, _ in iterator:
            inputs, labels = inputs.cuda(), labels.cuda()
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = CEloss(outputs, labels)
            if noise_mode == 'asym':
                # penalize confident predictions for asymmetric noise
                penalty = conf_penalty_fn(outputs)
                L = loss - penalty
            else:
                L = loss
            L.backward()
            optimizer.step()

            loss_meter.add(L.item())
            if noise_mode == 'asym':
                penalty_loss_meter.add(penalty.item())
                ce_loss_meter.add(loss.item())
                s = "Total-loss: {:.4f}, CE-loss: {:.4f}, Penalty: {:.4f}".format(loss_meter.mean, ce_loss_meter.mean, penalty_loss_meter.mean)
            else:
                s = 'CE-loss: %.4f'%(loss_meter.mean)
            iterator.set_postfix_str(s)

In [15]:
import sklearn
def test(epoch,net1,net2,loader=None):
    """Evaluate the net1+net2 ensemble and report patient-level AUC / AUPR.

    Patch predictions are the argmax of the summed logits; a patient's score is the
    fraction of its patches predicted as class 1.

    Args:
        epoch: epoch index, used only in the printed/logged message.
        net1, net2: the two co-trained networks (set to eval mode).
        loader: DataLoader yielding (inputs, targets, key) where key is either the patch
            name or its dataset index; defaults to the global `test_loader`.

    Returns:
        (auc, aupr) computed over patients.

    Relies on the notebook globals `CEloss`, `lookup` and `test_log`.
    """
    if loader is None:
        loader = test_loader
    net1.eval()
    net2.eval()
    # patch names of the evaluated dataset, used to map sample indices back to names
    dataset_images = getattr(loader.dataset, "images", None)
    patient_preds = {}

    loss_meter = utils.meter.AverageValueMeter()
    metric_meter = utils.meter.AverageValueMeter()
    with tqdm(loader, desc="valid", file=sys.stdout) as iterator:
        for inputs, targets, keys in iterator:
            inputs, targets = inputs.cuda(), targets.cuda()
            with torch.no_grad():
                outputs = net1(inputs) + net2(inputs)  # ensemble by summing logits

            loss_meter.add(CEloss(outputs, targets).detach().cpu().numpy())
            metric_meter.add(utils.metrics.Fscore()(outputs, targets).detach().cpu().numpy())

            # WrapperLoader.run('test') builds an "all"-mode dataset whose third item is the
            # sample index, not the name; slicing a 0-d index tensor with [:12] failed.
            names = [k if isinstance(k, str) else dataset_images[int(k)] for k in keys]
            patch_cls = outputs.argmax(dim=1).cpu().numpy()
            for name, p in zip(names, patch_cls):
                counts = patient_preds.setdefault(name[:12], [0, 0])
                counts[int(p)] += 1
            s = "{}-{:.2f}, {}-{:.2f}".format("CE-loss", loss_meter.mean,"fscore", metric_meter.mean)
            iterator.set_postfix_str(s)

    y_gt = [lookup[pid] for pid in patient_preds]
    y_pred = [votes[1] / (votes[0] + votes[1]) for votes in patient_preds.values()]
    auc = sklearn.metrics.roc_auc_score(y_gt, y_pred)
    precision, recall, _thresholds = sklearn.metrics.precision_recall_curve(y_gt, y_pred)
    aupr = sklearn.metrics.auc(recall, precision)
    # "%.2f" for AUPR (was "%2f", a width-2 / 6-decimal typo)
    print("\n| Test Epoch #%d\t AUC: %.2f AUPR: %.2f\n" %(epoch,auc,aupr))
    test_log.write('Epoch:%d\t AUC:%.2f AUPR: %.2f\n'%(epoch,auc,aupr))
    test_log.flush()
    return auc, aupr

In [16]:
def eval_train(model,all_loss,loader=None):
    """Per-sample loss -> 2-component GMM -> probability that each training sample is clean.

    Args:
        model: network to evaluate (set to eval mode).
        all_loss: list that receives this call's normalized loss vector (appended in place).
        loader: DataLoader yielding (inputs, targets, index); defaults to the global `eval_loader`.

    Returns:
        (prob, all_loss): prob[i] is the posterior of the lower-mean GMM component for
        dataset sample i; all_loss is the same list that was passed in.

    Relies on the notebook global `CE` (per-sample cross-entropy).
    """
    if loader is None:
        loader = eval_loader
    model.eval()
    losses = torch.zeros(len(loader.dataset))
    with torch.no_grad():
        with tqdm(loader, desc="eval train", file=sys.stdout) as iterator:
            for inputs, targets, index in iterator:
                outputs = model(inputs.cuda())
                loss = CE(outputs, targets.cuda())
                # scatter the batch losses into their dataset positions in one op
                losses[index] = loss.float().cpu()

    loss_min, loss_max = losses.min(), losses.max()
    span = loss_max - loss_min
    if span > 0:
        losses = (losses - loss_min) / span
    else:
        # all losses equal: avoid 0/0 = NaN, which would make the GMM fit fail
        losses = torch.zeros_like(losses)
    all_loss.append(losses)

    # fit a two-component GMM to the loss; the low-mean component models clean samples
    input_loss = losses.reshape(-1,1).numpy()
    gmm = GaussianMixture(n_components=2,max_iter=10,tol=1e-2,reg_covar=5e-4)
    gmm.fit(input_loss)
    prob = gmm.predict_proba(input_loss)
    prob = prob[:,gmm.means_.argmin()]
    return prob,all_loss

In [17]:
def linear_rampup(current, warm_up, rampup_length=16, lambda_u=None):
    """Unlabeled-loss weight: 0 until `warm_up`, then linear up to `lambda_u`, then constant.

    Args:
        current: current (possibly fractional) epoch.
        warm_up: number of warm-up epochs.
        rampup_length: epochs taken to ramp from 0 to the maximum weight.
        lambda_u: maximum weight; defaults to hyperParam["lambda_u"] (the former fixed value).

    Returns:
        float in [0, lambda_u].
    """
    if lambda_u is None:
        lambda_u = hyperParam["lambda_u"]
    current = np.clip((current-warm_up) / rampup_length, 0.0, 1.0)
    return lambda_u*float(current)

# Loss

In [18]:
class SemiLoss(object):
    """MixMatch objective used by DivideMix.

    Called with (labeled logits, labeled soft targets, unlabeled logits, unlabeled
    soft targets, epoch, warm_up) and returns (Lx, Lu, weight):
      Lx     - soft-target cross-entropy on the labeled logits,
      Lu     - mean squared error between unlabeled softmax probabilities and targets,
      weight - linear_rampup(epoch, warm_up), the current unlabeled-loss weight.
    """
    def __call__(self, outputs_x, targets_x, outputs_u, targets_u, epoch, warm_up):
        log_probs_x = F.log_softmax(outputs_x, dim=1)
        Lx = -(log_probs_x * targets_x).sum(dim=1).mean()
        Lu = (torch.softmax(outputs_u, dim=1) - targets_u).pow(2).mean()
        return Lx, Lu, linear_rampup(epoch, warm_up)
    
class NegEntropy(object):
    """Confidence penalty: batch-mean Shannon entropy of the softmax predictions.

    Despite the name, the returned value is the (non-negative) entropy
    H(p) = -sum_c p_c * log p_c averaged over the batch. warmup() subtracts it
    from the CE loss, so minimising the total discourages over-confident outputs.
    """
    def __call__(self,y_pred):
        log_probs = F.log_softmax(y_pred, dim=1)
        # weight log-probabilities by probabilities (previously by raw logits,
        # which is not an entropy and is 0 for uniform logits)
        return -(log_probs.exp() * log_probs).sum(dim=1).mean()

In [19]:
def create_model():
    """Build a ResNet-18 (no pretrained weights requested) with a `num_class`-way head.

    All modules are re-initialised with the project's `utils.weight_init.weight_init`,
    and the model is returned on the GPU.
    """
    from utils.weight_init import weight_init

    net = models.resnet18()
    net.fc = nn.Linear(net.fc.in_features, hyperParam["num_class"])
    net.apply(weight_init)
    return net.cuda()

# settings

In [20]:
bs = 32  # base batch size: warm-up loaders use 2*bs, eval/test loaders 5*bs
# Run logs: the labeled-sample counts written by the "labeled" dataset go to *_stats.txt,
# the per-epoch AUC/AUPR written by test() to *_acc.txt.
stats_log = open('./dividmix_test_stats.txt', 'w')
test_log = open('./dividmix_test_acc.txt', 'w')
loader = WrapperLoader(batch_size=bs, num_workers=16, log=stats_log)

In [21]:
print('| Building net')
net1, net2 = create_model(), create_model()  # the two co-trained DivideMix networks
# inputs are always resized to 224x224, so cuDNN can cache the fastest conv algorithms
cudnn.benchmark = True

| Building net


In [22]:
criterion = SemiLoss()  # MixMatch objective used inside train()

# Identical SGD settings for both networks; the training loop overwrites lr every epoch.
optimizer1, optimizer2 = (
    torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9, weight_decay=5e-4)
    for model in (net1, net2)
)

CE = nn.CrossEntropyLoss(reduction='none')  # per-sample losses for the GMM in eval_train()
CEloss = nn.CrossEntropyLoss()             # batch-mean CE for warm-up and validation
conf_penalty = NegEntropy()

# history of normalized per-sample losses appended by eval_train(): [net1, net2]
all_loss = [[] for _ in range(2)]

In [None]:
num_epochs = 20
lr = 1e-4          # base learning rate
p_threshold = hyperParam["p_threshold"]  # clean-probability threshold for co-divide
warm_up = 5        # epochs of plain CE warm-up (train() also reads this for the ramp-up)
max_value = 0.8    # minimum validation AUC before checkpointing; then the best AUC so far
base_lr = lr
for epoch in range(0, num_epochs+1):
    print("Epoch [{}/{}]".format(epoch, num_epochs))
    # Step schedule: base LR, then base/10 from epoch 10. The old `lr /= 10` ran on
    # every epoch >= 10 and compounded the LR down to 1e-15 by epoch 20.
    lr = base_lr / 10 if epoch >= 10 else base_lr
    for opt in (optimizer1, optimizer2):
        for param_group in opt.param_groups:
            param_group['lr'] = lr

    # test() and eval_train() read these module-level loaders
    test_loader = loader.run('test')
    eval_loader = loader.run('eval_train')

    if epoch<warm_up:
        warmup_trainloader = loader.run('warmup')
        print('Warmup Net1')
        warmup(epoch,net1,optimizer1,warmup_trainloader)
        print('Warmup Net2')
        warmup(epoch,net2,optimizer2,warmup_trainloader)
    else:
        prob1,all_loss[0]=eval_train(net1,all_loss[0])
        prob2,all_loss[1]=eval_train(net2,all_loss[1])

        pred1 = (prob1 > p_threshold)
        pred2 = (prob2 > p_threshold)

        print('Train Net1')
        labeled_trainloader, unlabeled_trainloader = loader.run('train',pred2,prob2) # co-divide
        train(epoch,net1,net2,optimizer1,labeled_trainloader, unlabeled_trainloader) # train net1

        print('Train Net2')
        labeled_trainloader, unlabeled_trainloader = loader.run('train',pred1,prob1) # co-divide
        train(epoch,net2,net1,optimizer2,labeled_trainloader, unlabeled_trainloader) # train net2

    print('Valid')
    auc, aupr = test(epoch,net1,net2)
    if max_value < auc:
        # track the best AUC so only improvements overwrite the checkpoint
        # (previously any epoch with AUC > 0.8 did)
        max_value = auc
        model_name1 = "/data/weight/DivideMix_ResNet18_1"
        model_name2 = "/data/weight/DivideMix_ResNet18_2"
        torch.save(net1.state_dict(), model_name1)
        torch.save(net2.state_dict(), model_name2)

In [29]:
eval_loader = loader.run('eval_train')
# Restore BOTH networks from the best checkpoint pair (saved together in the training loop).
# Previously only net1 was reloaded, so the net1+net2 ensemble below mixed the best net1
# with net2's last-epoch weights. create_model() already moves the model to the GPU.
net1 = create_model()
net2 = create_model()
net1.eval()
net2.eval()
net1.load_state_dict(torch.load("/data/weight/DivideMix_ResNet18_1"))
net2.load_state_dict(torch.load("/data/weight/DivideMix_ResNet18_2"))

<All keys matched successfully>

# Visualization: patch-level predictions on validation slides

In [None]:
# Switch both networks to eval mode before the patch-level inference cells.
net1.eval()
net2.eval()

In [None]:
# Per-patient raw values (presumably raw TMB counts, per the name), shown only as
# "raw=..." in the stitched-figure titles.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open("/data/tcga/cohort_count.pkl", "rb") as fp:
    raw_TMB_dict = pickle.load(fp)

In [None]:
# Coordinate files to visualise (earlier runs used "/data/tcga/kmeans_cluster_32/"
# or "/data/tcga/512densenpy/").
using_npy = "/data/tcga/512denseTumor/"

# Bucket slides: normal-tissue slides (sample-type digit "1" at barcode position 13,
# per the TCGA barcode convention) from ANY patient, then tumour slides of
# validation patients split by patient label.
valid_npy_pos, valid_npy_neg, valid_npy_normal = [], [], []
for npy in sorted(glob.glob(os.path.join(using_npy, "*.npy"))):
    patient = npy.split("/")[-1][:12]
    if npy.split("/")[-1][13] == "1":
        valid_npy_normal.append(npy)
        continue
    if patient not in valid_patient:
        continue
    (valid_npy_pos if lookup[patient] == 1 else valid_npy_neg).append(npy)

In [None]:
# Patch-level inference on each validation tumour slide of a positive patient: stitch a
# heat map per slide and record the slide score (fraction of patches whose ensemble
# P(class 1) > 0.5) under its patient in TT_positive_pred.
# NOTE(review): `CustomDataset(patches, augmentation=..., preprocessing=...)` is the
# patch-level class redefined in the LAST cell of this notebook -- run that cell first.
# TODO move it above and give it its own name.
TT_positive_pred = {}  # patient barcode -> list of slide scores
with tqdm(valid_npy_pos, desc="test", file=sys.stdout) as iterator:
    for npy in iterator:
        svs_name = npy.split("/")[-1][:-4]
        patient_name = svs_name[:12]
        x_y_pairs = np.load(npy)
        image_names = ["{}_{}_{}.jpg".format(svs_name, x, y) for x, y in x_y_pairs]
        if not image_names:
            # no patches -> no slide score (previously a ZeroDivisionError)
            print("skip {}: no patches".format(svs_name))
            continue

        slide_dataset = CustomDataset(
            image_names,
            augmentation=get_validation_augmentation(),
            preprocessing=get_preprocessing()
        )
        # local name: assigning `test_loader` here overwrote the global used by test()
        slide_loader = DataLoader(slide_dataset, batch_size=128, shuffle=False, num_workers=16, pin_memory=True)
        batch_probs = []
        for images, _, _ in slide_loader:
            with torch.no_grad():
                images = images.cuda()  # one host->GPU copy per batch, shared by both nets
                outputs = net1(images) + net2(images)
            batch_probs.append(torch.softmax(outputs, dim=1)[:, 1].cpu().numpy())
        patch_predictions = np.concatenate(batch_probs)  # one concatenation, not one per batch
        y_pred = np.count_nonzero(patch_predictions > 0.5) / len(patch_predictions)
        TT_positive_pred.setdefault(patient_name, []).append(y_pred)

        title = "{}_gt={}_pred={:.4f}_raw={}".format(svs_name, lookup[patient_name], y_pred, raw_TMB_dict[patient_name])
        data_process.stitch.stitch(wsi_name = svs_name, x_y_pairs = x_y_pairs, preds = patch_predictions, title=title)

In [None]:
# NOTE(review): this cell is a verbatim copy of the previous one. Assuming net1/net2 are
# unchanged and in eval mode (resize-only preprocessing), it rebuilds TT_positive_pred with
# the same scores and re-stitches the same figures at full inference cost.
# TODO delete one of the two copies.
# NOTE(review): `CustomDataset(image_names, augmentation=..., preprocessing=...)` is the
# patch-level class redefined in the last cell of the notebook, and assigning `test_loader`
# here overwrites the global that test() reads.
TT_positive_pred = {}
with tqdm(valid_npy_pos, desc="test", file=sys.stdout) as iterator:
    for npy in iterator:
        svs_name = npy.split("/")[-1][:-4]
        patient_name = svs_name[:12]
        x_y_pairs = np.load(npy)  # (x, y) patch coordinates of this slide
        image_names = ["{}_{}_{}.jpg".format(svs_name, x, y) for x, y in x_y_pairs]
        
        test_dataset = CustomDataset(
            image_names,
            augmentation=get_validation_augmentation(),
            preprocessing=get_preprocessing()
        )
        test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=16, pin_memory=True)
        patch_predictions = np.array([])
        svs_pred = [0,0]  # [patches with P(class 1) <= 0.5, patches with P(class 1) > 0.5]
        for images, labels, patch_names in test_loader:
            with torch.no_grad():
                outputs1 = net1.forward(images.cuda())
                outputs2 = net2.forward(images.cuda())
            outputs = outputs1 + outputs2  # ensemble by summing logits
            outputs = torch.softmax(outputs, dim=1)
            outputs = outputs[:, 1].detach().cpu().numpy()
            patch_predictions = np.concatenate((patch_predictions, outputs))
            for p in outputs:
                if p > 0.5:
                    svs_pred[1] += 1
                else:
                    svs_pred[0] += 1
        # slide score; raises ZeroDivisionError for a slide without patches
        y_pred = svs_pred[1]/(svs_pred[0]+svs_pred[1])
        if patient_name not in TT_positive_pred:
            TT_positive_pred[patient_name] = []
        TT_positive_pred[patient_name].append(y_pred)
        
        title = "{}_gt={}_pred={:.4f}_raw={}".format(svs_name, lookup[patient_name], y_pred, raw_TMB_dict[patient_name])
        data_process.stitch.stitch(wsi_name = svs_name, x_y_pairs = x_y_pairs, preds = patch_predictions, title=title)

In [None]:
# Patch-level inference on each validation tumour slide of a negative patient: stitch a
# heat map per slide and record the slide score (fraction of patches whose ensemble
# P(class 1) > 0.5) under its patient in TT_negative_pred.
# NOTE(review): `CustomDataset(patches, augmentation=..., preprocessing=...)` is the
# patch-level class redefined in the LAST cell of this notebook -- run that cell first.
# TODO move it above and give it its own name.
TT_negative_pred = {}  # patient barcode -> list of slide scores
with tqdm(valid_npy_neg, desc="test", file=sys.stdout) as iterator:
    for npy in iterator:
        svs_name = npy.split("/")[-1][:-4]
        patient_name = svs_name[:12]
        x_y_pairs = np.load(npy)
        image_names = ["{}_{}_{}.jpg".format(svs_name, x, y) for x, y in x_y_pairs]
        if not image_names:
            # no patches -> no slide score (previously a ZeroDivisionError)
            print("skip {}: no patches".format(svs_name))
            continue

        slide_dataset = CustomDataset(
            image_names,
            augmentation=get_validation_augmentation(),
            preprocessing=get_preprocessing()
        )
        # local name: assigning `test_loader` here overwrote the global used by test()
        slide_loader = DataLoader(slide_dataset, batch_size=128, shuffle=False, num_workers=16, pin_memory=True)
        batch_probs = []
        for images, _, _ in slide_loader:
            with torch.no_grad():
                images = images.cuda()  # one host->GPU copy per batch, shared by both nets
                outputs = net1(images) + net2(images)
            batch_probs.append(torch.softmax(outputs, dim=1)[:, 1].cpu().numpy())
        patch_predictions = np.concatenate(batch_probs)  # one concatenation, not one per batch
        y_pred = np.count_nonzero(patch_predictions > 0.5) / len(patch_predictions)
        TT_negative_pred.setdefault(patient_name, []).append(y_pred)

        title = "{}_gt={}_pred={:.4f}_raw={}".format(svs_name, lookup[patient_name], y_pred, raw_TMB_dict[patient_name])
        data_process.stitch.stitch(wsi_name = svs_name, x_y_pairs = x_y_pairs, preds = patch_predictions, title=title)

In [None]:
# Patient-level AUC on the validation split: each patient is scored by its highest slide score.
gt = [1] * len(TT_positive_pred) + [0] * len(TT_negative_pred)
preds = [np.amax(slide_scores)
         for by_patient in (TT_positive_pred, TT_negative_pred)
         for slide_scores in by_patient.values()]
print(roc_auc_score(gt, preds))

In [None]:
class CustomDataset(Dataset):
    """Patch-level inference dataset used by the visualisation cells.

    NOTE(review): this REDEFINES the training `CustomDataset` defined earlier, with a
    different constructor. Once this cell runs, WrapperLoader can no longer build
    training/eval loaders. TODO rename (e.g. `PatchDataset`) and update the callers.

    Args:
        patches: sequence of patch file names "<svs_name>_<x>_<y>.jpg".
        augmentation, preprocessing: optional albumentations-style callables,
            ``f(image=...)["image"]``.

    ``__getitem__`` returns (image, patient label from `lookup`, patch name without ".jpg").
    """
    def __init__(self, patches, augmentation = None, preprocessing = None):
        self.patches = patches
        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, i):
        image_name = self.patches[i]
        svs_name = image_name.split("_")[0]
        full_path = os.path.join(IMAGE_FOLDER, svs_name, image_name)
        # start from the raw image so either transform may be omitted
        # (previously `_input` was undefined when augmentation was None)
        _input = data_process.wsi_utils.vips_get_image(full_path)

        cls = lookup[image_name[:12]]  # patient-level label

        if self.augmentation:
            _input = self.augmentation(image = _input)['image']
        if self.preprocessing:
            _input = self.preprocessing(image = _input)['image']

        return _input, cls, image_name[:-4]

    def __len__(self):
        return len(self.patches)