In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals
import torch
import torchvision
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Sampler
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import albumentations as albu
import os, glob, sys, shutil
import cv2, itertools, random, pickle
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import sklearn
from collections import Counter
from numpy.random import choice
from utils.meter import AverageValueMeter

In [None]:
import data_process
import utils
# Root directory of the patch images; the dataset classes in this notebook read
# patches from "<IMAGE_FOLDER>/<slide name>/<patch name>.jpg".
IMAGE_FOLDER = "/data/tcga/512dense/"

# data

In [None]:
import pickle
# Load the cohort dictionary; the following cells use its keys as patient ids
# and its values as per-patient classes.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file --
# only load cohort files that come from a trusted source.
with open("./data/cohort_high_low.pkl", "rb") as fp:
    cohort_count_dict = pickle.load(fp)
print("# cohort: {}".format(len(cohort_count_dict)))

In [None]:
# Parallel lists of patient ids and their classes, plus `lookup`
# (patient id -> class), which later cells and the Epoch classes read as a global.
patient_ids = list(cohort_count_dict.keys())
patient_cls = list(cohort_count_dict.values())
lookup = dict(zip(patient_ids, patient_cls))
# Show the class distribution over all patients.
Counter(patient_cls)

In [None]:
# Split at patient level (80 % train / 20 % validation) with a fixed seed.
# NOTE(review): the split is not stratified by class -- confirm that is intended.
train_patient, valid_patient = train_test_split(patient_ids, test_size = 0.2, random_state = 42)
train_cls = [lookup[i] for i in train_patient]
valid_cls = [lookup[i] for i in valid_patient]
print("# train patient:{}\n# valid patient:{}".format(Counter(train_cls), Counter(valid_cls)))
# The per-split class lists were only needed for the printout above.
del train_cls
del valid_cls

In [None]:
# Build the patch-level inputs with the project helper.
# Judging from how the results are used further down: *_images are lists of patch
# names, *_lookup map patch name -> label, and train_npys is a list of .npy paths
# consumed by NewFilterEpoch.run -- TODO confirm against utils.data.load_data.
train_images, valid_images, train_lookup, valid_lookup, train_npys = utils.data.load_data(train_patient=train_patient, valid_patient=valid_patient, \
                                                        patient_label_dict = lookup)

# dataset

In [None]:
def to_tensor(x, **kwargs):
    """Move the last axis of a 3-D array to the front and cast it to float32.

    Extra keyword arguments are accepted and ignored so the function can be
    used as a callback that receives additional parameters.
    """
    channels_first = x.transpose((2, 0, 1))
    return channels_first.astype(dtype='float32')

def get_preprocessing():
    """Return the final pipeline stage: `to_tensor` applied to both image and mask."""
    channel_first = albu.Lambda(image=to_tensor, mask=to_tensor)
    return albu.Compose([channel_first])

def get_training_augmentation():
    """Return the training pipeline: resize to 224x224, random geometric and
    colour perturbations, noise/blur, then normalisation."""
    return albu.Compose([
        # fixed input size
        albu.Resize(224, 224),
        # geometric perturbations
        albu.Flip(),
        albu.ShiftScaleRotate(border_mode=0, value=0),
        # colour perturbations
        albu.HueSaturationValue(hue_shift_limit=40, p=0.8),
        albu.RandomBrightnessContrast(p=0.8),
        # noise and blur
        albu.IAAAdditiveGaussianNoise(),
        albu.GaussianBlur(),
        # normalisation is applied last
        albu.Normalize(),
    ])

def get_validation_augmentation():
    """Return the deterministic evaluation pipeline: resize to 224x224 and normalise."""
    resize = albu.Resize(224, 224)
    normalize = albu.Normalize()
    return albu.Compose([resize, normalize])

In [None]:
def split_labeled_unlabeled(train_filtered_label = None):
    """Partition patches into labeled and unlabeled sets.

    Args:
        train_filtered_label: mapping of patch name -> label, where ``-1`` marks
            a patch whose label was filtered out. ``None`` (the default) is
            treated like an empty mapping instead of raising.

    Returns:
        ``(labeled_patch, unlabeled_patch, num_p)``: the names (as strings) of
        patches with a label, the names of patches marked ``-1``, and the sum of
        the kept labels (the number of positives when labels are 0/1).
    """
    labeled_patch = []
    unlabeled_patch = []
    num_p = 0
    for key, value in (train_filtered_label or {}).items():
        name = "{}".format(key)
        if value == -1:
            unlabeled_patch.append(name)
        else:
            labeled_patch.append(name)
            num_p += value

    return labeled_patch, unlabeled_patch, num_p

In [None]:
class LabeledDataset(Dataset):
    """Patch dataset yielding ``(image, label, patch_name)`` triples.

    Args:
        patches: sequence of patch names; the part before the first ``_`` is
            used as the slide folder name.
        patch_cls_lookup: mapping of patch name -> label.
        augmentation: optional callable invoked as ``augmentation(image=...)``
            that returns a dict with an ``'image'`` entry.
        preprocessing: optional callable with the same calling convention,
            applied after the augmentation.
    """
    def __init__(self, patches, patch_cls_lookup, augmentation = None, preprocessing = None):
        self.patches = patches
        self.patch_cls_lookup = patch_cls_lookup
        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, i):
        image_name = self.patches[i]
        # Patches are stored as <IMAGE_FOLDER>/<slide name>/<patch name>.jpg.
        svs_name = image_name.split("_")[0]
        full_path = os.path.join(IMAGE_FOLDER, svs_name, image_name+".jpg")
        image = data_process.wsi_utils.vips_get_image(full_path)

        cls = self.patch_cls_lookup[image_name]

        # Start from the raw image so that either (or both) of the optional
        # stages can be omitted without leaving the result undefined.
        s_input = image
        if self.augmentation:
            s_input = self.augmentation(image = s_input)['image']

        if self.preprocessing:
            s_input = self.preprocessing(image = s_input)['image']

        return s_input, cls, image_name

    def __len__(self):
        return len(self.patches)

In [None]:
class UnlabeledDataset(Dataset):
    """Patch dataset yielding two independently augmented views of each patch.

    Args:
        patches: sequence of patch names; the part before the first ``_`` is
            used as the slide folder name.
        augmentation: optional callable invoked as ``augmentation(image=...)``
            that returns a dict with an ``'image'`` entry. It is called twice
            per item, once for each view.
        preprocessing: optional callable invoked as
            ``preprocessing(image=view1, mask=view2)`` that returns a dict with
            ``'image'`` and ``'mask'`` entries.
    """
    def __init__(self, patches, augmentation = None, preprocessing = None):
        self.patches = patches
        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, i):
        image_name = self.patches[i]
        # Patches are stored as <IMAGE_FOLDER>/<slide name>/<patch name>.jpg.
        svs_name = image_name.split("_")[0]
        full_path = os.path.join(IMAGE_FOLDER, svs_name, image_name+".jpg")
        image = data_process.wsi_utils.vips_get_image(full_path)

        # Without an augmentation both views fall back to the raw image instead
        # of being undefined.
        _input1 = image
        _input2 = image
        if self.augmentation:
            _input1 = self.augmentation(image = image)['image']
            _input2 = self.augmentation(image = image)['image']

        if self.preprocessing:
            # The second view is passed through the "mask" slot so that both
            # views are preprocessed in a single call.
            sampled = self.preprocessing(image = _input1, mask = _input2)
            _input1 = sampled['image']
            _input2 = sampled['mask']

        return _input1, _input2

    def __len__(self):
        return len(self.patches)

# epoch

In [None]:
class Epoch:
    """Base class for one pass over a dataloader with patch- and patient-level reporting.

    Subclasses provide ``batch_update`` (and usually ``on_epoch_start``).
    ``run`` also reads two module-level names that must exist before it is
    called: ``lookup`` (patient id -> class) and ``writer`` (an object with an
    ``add_scalar`` method).
    """
    def __init__(self, model, loss, metrics, stage_name, device='cpu', verbose=True):
        self.model = model
        self.loss = loss
        self.metrics = metrics
        self.stage_name = stage_name
        self.verbose = verbose
        self.device = device

        self._to_device()

    def _to_device(self):
        """Move the model, the loss and every metric to ``self.device``."""
        self.model.to(self.device)
        self.loss.to(self.device)
        for metric in self.metrics:
            metric.to(self.device)

    def _format_logs(self, logs):
        """Render ``logs`` as ``"name - value, ..."`` for the progress bar."""
        str_logs = ['{} - {:.4}'.format(k, v) for k, v in logs.items()]
        s = ', '.join(str_logs)
        return s

    def batch_update(self, x, y):
        """Process one batch; ``run`` expects a ``(loss, predictions)`` pair back."""
        raise NotImplementedError

    def on_epoch_start(self):
        """Hook called at the start of ``run``; the base implementation does nothing."""
        pass
    
    def run(self, dataloader, epoch):
        """Run one epoch and return the collected logs.

        Args:
            dataloader: iterable yielding ``(x, y, patch_name)`` batches.
            epoch: step index passed to ``writer.add_scalar``.

        Returns:
            dict holding the running means of the loss and of each metric plus
            the patient-level ``'auc'`` and ``'aupr'``.
        """
        self.on_epoch_start()

        logs = {}
        loss_meter = AverageValueMeter()
        metrics_meters = {metric.__name__: AverageValueMeter() for metric in self.metrics}
        # patient id -> [number of patches voted class 0, number voted class 1]
        patient_pred = {}
        with tqdm(dataloader, desc=self.stage_name, file=sys.stdout, disable=not (self.verbose)) as iterator:
            for x, y, patch_name in iterator:
                patch_name = list(patch_name)
                x, y = x.to(self.device), y.to(self.device)
                loss_value, preds = self.batch_update(x, y)
                # Abort as soon as the loss contains a NaN.
                assert torch.isnan(loss_value).any().detach().cpu().numpy() != True, 'loss value 有問題 {}'.format(loss_value)
                
                # update loss logs
                loss_value = loss_value.cpu().detach().numpy()
                loss_meter.add(loss_value)
                loss_logs = {self.loss.__name__: loss_meter.mean}
                logs.update(loss_logs)
                
                # update metrics logs
                for metric_fn in self.metrics:
                    metric_value = metric_fn(preds, y).cpu().detach().numpy()
                    metrics_meters[metric_fn.__name__].add(metric_value)
                metrics_logs = {k: v.mean for k, v in metrics_meters.items()}
                logs.update(metrics_logs)                
                
                if self.verbose:
                    s = self._format_logs(logs)
                    iterator.set_postfix_str(s)
                    
                # Softmax over dim 1; column 1 is used as the positive score of each patch.
                preds = torch.softmax(preds, dim=1)
                preds = preds[:, 1].detach().cpu().numpy()
                for p, p_n in zip(preds, patch_name):
                    # The first 12 characters of the patch name are used as the patient id.
                    p_id = p_n[:12]
                    if(p_id not in patient_pred):
                        patient_pred[p_id] = [0, 0]
                    if(p > 0.5):
                        patient_pred[p_id][1] += 1
                    else:
                        patient_pred[p_id][0] += 1
                    
        y_pred = []
        y_gt = []
        # voting: a patient's score is the fraction of its patches voted positive
        for key, values in patient_pred.items():
            y_gt.append(lookup[key])
            d = values[0] + values[1]
            y_pred.append(values[1] / d)
            
        auc = sklearn.metrics.roc_auc_score(y_gt, y_pred)
        precision, recall, _thresholds = sklearn.metrics.precision_recall_curve(y_gt, y_pred)
        aupr = sklearn.metrics.auc(recall, precision)
        
        print('patient AUROC : ', auc)
        print('patient AUPR : ', aupr)
        logs.update({'auc' : auc})
        logs.update({'aupr' : aupr})
        # Only the first metric in self.metrics is written to the writer.
        writer.add_scalar("{}/{}".format(self.loss.__name__, self.stage_name), loss_meter.mean, epoch)
        writer.add_scalar("{}/{}".format("Patient AUC", self.stage_name), auc, epoch)
        writer.add_scalar("{}/{}".format("Patient AUPR", self.stage_name), aupr, epoch)
        writer.add_scalar("{}/{}".format(self.metrics[0].__name__, self.stage_name), metrics_meters[self.metrics[0].__name__].mean, epoch)
        
        return logs

In [None]:
class InitialTrainEpoch(Epoch):
    """Supervised training epoch: one optimizer step per batch."""

    def __init__(self, model, loss, metrics, optimizer, device='cpu', verbose=True):
        super().__init__(model, loss, metrics, 'Initial train', device=device, verbose=verbose)
        self.optimizer = optimizer

    def on_epoch_start(self):
        # Training mode for the whole epoch.
        self.model.train()

    def batch_update(self, x, y):
        """Run forward/backward on one batch, step the optimizer and return ``(loss, outputs)``."""
        self.optimizer.zero_grad()
        outputs = self.model.forward(x)
        batch_loss = self.loss(outputs, y)
        batch_loss.backward()
        self.optimizer.step()
        return batch_loss, outputs
    
class ValidEpoch(Epoch):
    """Evaluation epoch: forward passes only, with gradients disabled."""

    def __init__(self, model, loss, metrics, device='cpu', verbose=True):
        super().__init__(model, loss, metrics, 'valid', device=device, verbose=verbose)

    def on_epoch_start(self):
        # Evaluation mode for the whole epoch.
        self.model.eval()

    def batch_update(self, x, y):
        """Return ``(loss, outputs)`` for one batch without tracking gradients."""
        with torch.no_grad():
            outputs = self.model.forward(x)
            batch_loss = self.loss(outputs, y)
        return batch_loss, outputs

In [None]:
class FilterEpoch(Epoch):
    """Label filter based on a running average of the model's patch predictions.

    A patch whose averaged prediction disagrees with its initial label is
    marked ``-1`` in the returned mapping. The running averages live in the
    module-level dict ``predict_running_average`` and are updated in place.
    """
    def __init__(self, model, init_lookup, loss, metrics, ensemble_momentum=0.5, device='cpu', verbose=True):
        super().__init__(model, loss, metrics, 'Filtered', device=device, verbose=verbose)
        self.init_lookup = init_lookup
        self.ensemble_momentum = ensemble_momentum

    def _to_device(self):
        # Only the model is moved here; loss and metrics are left untouched.
        self.model.to(self.device)

    def on_epoch_start(self):
        self.model.eval()

    def run(self, dataloader, epoch):
        """Return a copy of the initial labels with disagreeing patches set to ``-1``.

        Args:
            dataloader: yields ``(x, y, patch_name)`` batches (the original note
                asks for an unshuffled train dataloader).
            epoch: unused.
        """
        self.on_epoch_start()
        momentum = self.ensemble_momentum
        filtered_label = self.init_lookup.copy()
        progress = tqdm(dataloader, desc=self.stage_name, file=sys.stdout, disable=not (self.verbose))
        with progress as iterator:
            for x, _targets, patch_name in iterator:
                names = list(patch_name)
                x = x.to(self.device)

                with torch.no_grad():
                    logits = self.model(x)

                # class probabilities, one row per patch in the batch
                probs = torch.softmax(logits, dim = 1).detach().cpu().numpy()

                for p_id, pred in zip(names, probs):
                    averaged = momentum * predict_running_average[p_id] + (1 - momentum) * pred
                    predict_running_average[p_id] = averaged
                    # drop the label when the ensembled class differs from the initial one
                    if np.argmax(averaged) != self.init_lookup[p_id]:
                        filtered_label[p_id] = -1
        return filtered_label

In [None]:
class NewFilterEpoch(Epoch):
    """Per-slide label filter combining prediction ensembling (SELF) with a GMM on the loss.

    ``run`` processes one ``.npy`` file (one slide) at a time, predicts every
    patch of the slide and then picks a filtering strategy:

    1. patient class is 0                                   -> SELF
    2. patient is positive, positive-vote fraction > 0.3    -> SELF
    3. patient is positive, positive-vote fraction <= 0.3   -> GMM

    Relies on the module-level names ``predict_running_average`` (patch name ->
    running prediction, updated in place) and ``lookup`` (patient id -> class).
    ``loss`` must return one value per sample (``reduction="none"``).
    """
    def __init__(self, model, init_lookup, loss, metrics, ensemble_momentum=0.5, device='cpu', verbose=True):
        super().__init__(
            model=model,
            loss=loss,
            metrics=metrics,
            stage_name='Filtered',
            device=device,
            verbose=verbose,
        )
        self.init_lookup = init_lookup
        self.ensemble_momentum = ensemble_momentum

    def _to_device(self):
        # Metrics are not used by this epoch, so only model and loss are moved.
        self.model.to(self.device)
        self.loss.to(self.device)

    def on_epoch_start(self):
        self.model.eval()

    def _update_running_average(self, p_id, patch_pred):
        """Blend ``patch_pred`` into the running average of patch ``p_id`` and return the new average."""
        averaged = self.ensemble_momentum * predict_running_average[p_id] + \
                   (1 - self.ensemble_momentum) * patch_pred
        predict_running_average[p_id] = averaged
        return averaged

    def SELF_filter_noisy(self, patch_names, patch_preds, new_filtered_label):
        """Mark a patch ``-1`` when its ensembled class disagrees with its initial label.

        Args:
            patch_names: patch names, aligned with ``patch_preds``.
            patch_preds: per-patch class probabilities.
            new_filtered_label: patch name -> label mapping, modified in place.
        """
        for patch_pred, p_id in zip(patch_preds, patch_names):
            averaged = self._update_running_average(p_id, patch_pred)
            # agreement class != initial class -> drop the label
            if np.argmax(averaged) != self.init_lookup[p_id]:
                new_filtered_label[p_id] = -1

    def GMM_filter_noisy(self, patch_names, patch_losses, patch_preds, new_filtered_label):
        """Relabel patches with a two-component Gaussian mixture fitted to their losses.

        A patch that more likely belongs to the low-mean (small loss) component
        gets label 1, every other patch gets label 0. The running prediction
        averages are updated as well so the ensemble keeps evolving.

        Args:
            patch_names: patch names, aligned with the other sequences.
            patch_losses: 1-D array of per-patch loss values.
            patch_preds: per-patch class probabilities.
            new_filtered_label: patch name -> label mapping, modified in place.
        """
        # Imported locally: the notebook's import cell does not provide this name.
        from sklearn.mixture import GaussianMixture

        # Min-max normalise the losses; the epsilon guards against a zero range.
        patch_losses = (patch_losses - patch_losses.min())/(patch_losses.max() - patch_losses.min() + 1e-10)
        # Fit a two-component GMM to the normalised losses.
        gmm = GaussianMixture(n_components=2,max_iter=10,tol=1e-2,reg_covar=5e-4)
        input_losses = patch_losses.reshape(-1, 1)
        gmm.fit(input_losses)
        # Probability of belonging to the component with the smaller mean loss.
        gmm_preds = gmm.predict_proba(input_losses)
        gmm_preds = gmm_preds[:,gmm.means_.argmin()]
        gmm_preds = np.where(gmm_preds > 0.5, 1, 0)
        for gmm_pred, patch_pred, p_id in zip(gmm_preds, patch_preds, patch_names):
            new_filtered_label[p_id] = gmm_pred
            # keep updating the ensemble
            self._update_running_average(p_id, patch_pred)

    def run(self, all_npys, epoch):
        """Filter the labels of every slide listed in ``all_npys``.

        Args:
            all_npys: paths of ``.npy`` files, each holding the ``(x, y)``
                coordinate pairs of one slide's patches. The file name without
                its extension is the slide name and its first 12 characters are
                the patient id.
            epoch: unused.

        Returns:
            A copy of ``init_lookup`` with the filtered / relabelled entries.

        Raises:
            AssertionError: if a slide yields no patches, or if the loss does
                not return one value per sample.
        """
        self.on_epoch_start()
        new_filtered_label = self.init_lookup.copy()

        with tqdm(all_npys, desc=self.stage_name, file=sys.stdout, disable=not (self.verbose)) as iterator:
            for npy in iterator:
                x_y_pairs = np.load(npy)
                svs_name = npy.split("/")[-1][:-4]
                patient_name = npy.split("/")[-1][:12]

                image_names = ["{}_{}_{}".format(svs_name, x, y) for x, y in x_y_pairs]
                test_dataset = LabeledDataset(
                    image_names,
                    self.init_lookup,
                    augmentation = get_validation_augmentation(),
                    preprocessing = get_preprocessing(),
                )
                # shuffle=False keeps the predictions aligned with image_names.
                test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=10, pin_memory=True)

                # Collect per-batch predictions and losses; concatenate once at the end.
                pred_batches = []
                loss_batches = []
                for images, labels, patch_name in test_loader:
                    # .to() works for CPU and GPU devices alike (.cuda() fails on 'cpu').
                    images, labels = images.to(self.device), labels.to(self.device)
                    with torch.no_grad():
                        preds = self.model(images)
                    # The loss 'reduction' must be "none" so there is one value per patch.
                    loss_values = self.loss(preds, labels).detach().cpu().numpy()
                    preds = torch.softmax(preds, dim = 1).detach().cpu().numpy()
                    assert len(loss_values) == len(preds), "loss values 數量不同"
                    loss_batches.append(loss_values)
                    pred_batches.append(preds)

                if not pred_batches:
                    # Raised explicitly so the check survives `python -O` and the
                    # message never touches a missing array.
                    raise AssertionError("all_preds get error: no patches loaded from {}".format(npy))
                all_preds = np.concatenate(pred_batches)
                all_losses = np.concatenate(loss_batches)

                # Fraction of this slide's patches whose positive-class probability
                # exceeds 0.5.
                majority_voting = np.count_nonzero(all_preds[:, 1] > 0.5)/len(all_preds)

                if lookup[patient_name] == 0:
                    self.SELF_filter_noisy(image_names, all_preds, new_filtered_label)
                elif majority_voting > 0.3:
                    self.SELF_filter_noisy(image_names, all_preds, new_filtered_label)
                else:
                    self.GMM_filter_noisy(image_names, all_losses, all_preds, new_filtered_label)
        print("filter: {}".format(Counter(list(new_filtered_label.values()))))
        return new_filtered_label

In [None]:
class TrainEpoch(utils.train.Epoch):
    """MixMatch-style semi-supervised training epoch.

    Each step takes one labeled batch and one unlabeled batch (two augmented
    views), guesses sharpened soft labels for the unlabeled views, mixes inputs
    and targets with ``mixup`` and optimises
    ``loss_xent_fn(labeled) + lambda_u * loss_mse_fn(unlabeled)``.

    Args:
        model: network to train.
        loss_xent_fn: loss for the labeled part; called with soft (N, C) targets.
        loss_mse_fn: loss for the unlabeled part; called with soft (N, C) targets.
        optimizer: optimizer stepped once per batch.
        loss, metrics: forwarded to the parent class; not used by this class.
        T: sharpening temperature for the guessed labels.
        lambda_u: weight of the unlabeled loss.
        alpha: Beta-distribution parameter passed to ``mixup``.
        device, verbose: forwarded to the parent class.

    NOTE(review): the parent is ``utils.train.Epoch``, not the ``Epoch`` class
    defined in this notebook; ``_format_logs`` is inherited from it -- confirm
    the two agree. ``run`` also needs the module-level ``writer``.
    """

    def __init__(self, model, loss_xent_fn, loss_mse_fn, optimizer, loss=None, metrics=None, \
                 T=0.5, lambda_u=1, alpha=0.75, device='cpu', verbose=True):
        super().__init__(
            model=model,
            loss=loss,
            metrics=metrics,
            stage_name='Iteration train',
            device=device,
            verbose=verbose,
        )
        self.loss_xent_fn = loss_xent_fn
        self.loss_mse_fn = loss_mse_fn
        self.optimizer = optimizer
        self.T = T
        self.lambda_u = lambda_u
        self.alpha = alpha

    def _to_device(self):
        # loss/metrics may be None for this class, so only the model is moved.
        self.model.to(self.device)

    def on_epoch_start(self):
        self.model.train()

    def batch_update(self, XandU, p, q):
        """Run one optimisation step.

        Args:
            XandU: mixed labeled inputs followed by mixed unlabeled inputs.
            p: soft targets for the first ``len(p)`` rows (labeled part).
            q: soft targets for the remaining rows (unlabeled part).

        Returns:
            ``(loss_xent_value, loss_mse_value, prediction)``.
        """
        self.optimizer.zero_grad()
        prediction = self.model(XandU)
        loss_xent_value = self.loss_xent_fn(prediction[:len(p)], p)
        # NOTE(review): loss_mse_fn receives the raw model outputs while q holds
        # probabilities; whether a softmax belongs here depends on the loss
        # object that is passed in -- confirm.
        loss_mse_value = self.loss_mse_fn(prediction[len(p):], q)
        loss = loss_xent_value + self.lambda_u * loss_mse_value
        loss.backward()
        self.optimizer.step()
        return loss_xent_value, loss_mse_value, prediction

    def run(self, labeled_dataloader, unlabeled_dataloader, epoch):
        """Train for one pass over ``labeled_dataloader``.

        The unlabeled loader is cycled: when it is exhausted it is restarted.

        Args:
            labeled_dataloader: yields ``(inputs, targets, names)`` batches.
            unlabeled_dataloader: yields ``(view1, view2)`` batches.
            epoch: step index passed to ``writer.add_scalar``.

        Returns:
            dict with the running means of both losses.

        Raises:
            RuntimeError: if ``unlabeled_dataloader`` yields no batch at all.
        """
        logs = {}
        loss_xent_meters = AverageValueMeter()
        loss_mse_meters = AverageValueMeter()

        with tqdm(labeled_dataloader, desc=self.stage_name, file=sys.stdout, disable=not (self.verbose)) as labeled_iter:
            unlabeled_iter = iter(unlabeled_dataloader)
            for (inputs_x, targets_x, names) in labeled_iter:
                # Restart the unlabeled loader only when it is exhausted; any
                # other error (e.g. a failing worker) propagates unchanged.
                unlabeled_batch = next(unlabeled_iter, None)
                if unlabeled_batch is None:
                    unlabeled_iter = iter(unlabeled_dataloader)
                    unlabeled_batch = next(unlabeled_iter, None)
                    if unlabeled_batch is None:
                        raise RuntimeError("unlabeled_dataloader yielded no batches")
                inputs_u, inputs_u2 = unlabeled_batch

                inputs_x, targets_x = inputs_x.to(self.device), targets_x.to(self.device)
                inputs_u, inputs_u2 = inputs_u.to(self.device), inputs_u2.to(self.device)

                # MixMatch label guessing, with N the total batch size:
                #   inputs_x:            (N//2, channel, H, W)
                #   targets_x:           (N//2,)
                #   inputs_u, inputs_u2: (N//2, channel, H, W)
                #   targets_u:           (N//2, C)
                self.model.eval()
                with torch.no_grad():
                    outputs_u = self.model(inputs_u)
                    outputs_u2 = self.model(inputs_u2)
                    # Compute average predictions
                    targets_u = (torch.softmax(outputs_u, dim=1) + torch.softmax(outputs_u2, dim=1)) / 2
                    # Apply temperature sharpen
                    targets_u = targets_u**(1/self.T)
                    targets_u = targets_u / targets_u.sum(dim=1, keepdim=True)
                    targets_u = targets_u.detach()

                # Convert targets_x to one-hot (N//2, C), then shuffle the
                # concatenation [labeled, unlabeled view 1, unlabeled view 2] to
                # obtain mixup partners:
                #   Wx: (3*N//2, channel, H, W) shuffled inputs
                #   Wy: (3*N//2, C)             shuffled targets
                targets_x = get_tensor_onehot(targets_x)
                total = len(targets_x) + 2 * len(targets_u)
                # A true permutation: every sample is used as a partner exactly
                # once (np.random.choice samples with replacement by default).
                indices = np.random.permutation(total)
                Wx = torch.cat([inputs_x, inputs_u, inputs_u2], dim=0)[indices]
                Wy = torch.cat([targets_x, targets_u, targets_u], dim=0)[indices]

                num_x = len(inputs_x)
                X, p = mixup(inputs_x, Wx[:num_x], targets_x, Wy[:num_x], self.alpha)
                U, q = mixup(inputs_u, Wx[num_x:len(inputs_u)+num_x], \
                             targets_u, Wy[num_x:len(targets_u)+num_x], self.alpha)

                # Model update on the concatenated batch:
                #   p: targets of the labeled part,   (N//2, C)
                #   q: targets of the unlabeled part, (N//2, C)
                self.model.train()
                loss_xent_value, loss_mse_value, preds = self.batch_update(torch.cat([X,U], dim=0), p, q)

                # Stop the epoch as soon as either loss becomes NaN.
                if bool(torch.isnan(loss_xent_value).any()):
                    print("xent ERROR")
                    break
                if bool(torch.isnan(loss_mse_value).any()):
                    print("mse ERROR")
                    break

                # update loss logs
                loss_xent_value = loss_xent_value.cpu().detach().numpy()
                loss_xent_meters.add(loss_xent_value)
                loss_mse_value = loss_mse_value.cpu().detach().numpy()
                loss_mse_meters.add(loss_mse_value)

                logs.update({self.loss_xent_fn.__name__: loss_xent_meters.mean})
                logs.update({self.loss_mse_fn.__name__: loss_mse_meters.mean})

                if self.verbose:
                    s = self._format_logs(logs)
                    labeled_iter.set_postfix_str(s)
            writer.add_scalar("{}/{}".format(self.loss_xent_fn.__name__, self.stage_name), loss_xent_meters.mean, epoch)
            writer.add_scalar("{}/{}".format(self.loss_mse_fn.__name__, self.stage_name), loss_mse_meters.mean, epoch)
        return logs
    
def get_tensor_onehot(y, num_class = 2):
    """Return the float32 one-hot encoding of the 1-D class-index tensor ``y``.

    The result has shape ``(len(y), num_class)`` and lives on ``y``'s device.
    """
    encoded = torch.zeros(len(y), num_class, dtype=torch.float32, device=y.device)
    # Write a 1 into the column selected by each row's class index.
    encoded.scatter_(1, y[:, None], 1)
    return encoded

def mixup(x1, x2, y1, y2, alpha):
    """Blend ``(x1, y1)`` with ``(x2, y2)`` using one Beta(alpha, alpha) coefficient.

    The coefficient is clamped to at least 0.5, so the result always stays
    closer to the first pair. Returns the mixed ``(x, y)``.
    """
    lam = np.random.beta(alpha, alpha)
    if lam < 1 - lam:
        lam = 1 - lam
    weight = torch.tensor(lam, dtype=x1.dtype, device = x1.device)
    mixed_x = weight * x1 + (1 - weight) * x2
    mixed_y = weight * y1 + (1 - weight) * y2
    return mixed_x, mixed_y

# loss

In [None]:
class SoftLabelCrossEntropy(nn.Module):
    """Cross entropy between logits and soft (probability) targets, averaged over the batch."""
    def __init__(self):
        super().__init__()

    @property
    def __name__(self):
        return 'CrossEntropyLoss'

    def forward(self, y_pred, y_true):
        """
        Args:
            y_pred: logits of shape (N, C), C being the number of classes
            y_true: float targets of shape (N, C)
        Returns:
            scalar: sum over all entries of ``-log_softmax(y_pred) * y_true``, divided by N
        """
        neg_log_prob = -1 * F.log_softmax(y_pred, dim=1)
        batch_size, _num_classes = y_pred.size()
        weighted = neg_log_prob * y_true
        return torch.sum(weighted) / batch_size
    
class MSELoss(nn.Module):
    """Mean squared error between predictions and targets.

    ``max_value`` is stored on the instance but is not used by ``forward``.
    """
    def __init__(self, max_value = 1):
        super().__init__()
        self.max_value = max_value

    @property
    def __name__(self):
        return 'mseloss'

    def forward(self, y_pred, y_true):
        """
        Args:
            y_pred: predictions of shape (N, C)
            y_true: float targets of shape (N, C)
        Returns:
            scalar mean of the squared differences
        """
        return F.mse_loss(y_pred, y_true)
    
class SoftmaxMSELoss(nn.Module):
    """Mean squared error computed after a softmax over dim 1 of BOTH arguments."""
    def __init__(self):
        super().__init__()

    @property
    def __name__(self):
        return "SoftmaxMSELoss"

    def forward(self, y_pred, y_gt):
        """Return ``mse(softmax(y_pred), softmax(y_gt))`` as a scalar."""
        pred_prob = torch.softmax(y_pred, dim = 1)
        target_prob = torch.softmax(y_gt, dim = 1)
        return F.mse_loss(pred_prob, target_prob)
    
class CrossEntropy(nn.Module):
    """Wrapper around ``nn.CrossEntropyLoss`` that also exposes a ``__name__``.

    Keyword arguments given at construction are forwarded to
    ``nn.CrossEntropyLoss`` on every call.
    """
    def __init__(self, **kwargs):
        super().__init__()
        self.kwargs = kwargs

    @property
    def __name__(self):
        return 'CrossEntropyLoss'

    def forward(self, y_pred, y_true):
        """
        Args:
            y_pred: logits of shape (N, C), C being the number of classes
            y_true: class indices of shape (N,), each in [0, C-1], torch.long
        Returns:
            the cross-entropy loss, reduced according to the stored kwargs
        """
        criterion = nn.CrossEntropyLoss(**self.kwargs)
        return criterion(y_pred, y_true)

# train setting

In [None]:
# ImageNet-pretrained ResNet-18 whose classifier is replaced by a 2-class head.
model = models.resnet18(pretrained=True)
# model.load_state_dict(torch.load("_ckpt_epoch_9.h5"))
model.fc = nn.Linear(model.fc.in_features, 2)
# Move to the GPU only after the head has been swapped; otherwise the freshly
# created `fc` layer would stay on the CPU while the rest of the network is on
# the GPU.
model.cuda()

In [None]:
from datetime import datetime
from pytz import timezone   

# Initial predict running average: one entry per training patch, starting at 0.
# The filter epochs read and update this dict as a global.
predict_running_average = {}
for p in train_images:
    predict_running_average[p] = 0

# Timestamp (Asia/Taipei) that can be used to name a run.
taipei = timezone('Asia/Taipei')
taipei_time = datetime.now(taipei)
current_time = taipei_time.strftime('%Y-%m-%d_%H-%M')

# loss_fn: supervised loss for the initial training and for validation.
# loss_xent_fn / loss_mse_fn: labeled / unlabeled terms used by TrainEpoch.
# filter_loss_fn: per-sample loss (reduction="none") used by NewFilterEpoch.
loss_fn = utils.global_objective.AUCPRHingeLoss()
loss_xent_fn = SoftLabelCrossEntropy()
# loss_xent_fn = utils.global_objective.AUCPRHingeLoss()
loss_mse_fn = MSELoss()
filter_loss_fn = CrossEntropy(reduction="none")

metrics = [utils.metrics.Fscore()]

# Optimizer for the initial training stage.
optimizer = torch.optim.SGD([ 
    dict(params=model.parameters(), lr=1e-4),
], weight_decay=1e-3,  momentum=0.9)

In [None]:
# Training hyper-parameters.
# warmup: start index of the initial-training loop (`range(warmup, 5)`).
warmup = 5
# batch_size: used as-is for initial training / validation; the filter-and-train
# loop gives half of it to each of the labeled and unlabeled loaders.
batch_size = 128
# epochs_in_iteration: training epochs run after every filtering pass.
epochs_in_iteration = 2
# max_score: validation patient AUC a model must exceed to be saved as "best".
max_score = 0.8
# NOTE(review): max_iteration is not referenced by the training loop in this
# notebook, which hard-codes 20 iterations -- confirm which value is intended.
max_iteration = 10
DEVICE = 'cuda:0' if torch.cuda.is_available() else "cpu"

### tensorboard

In [None]:
from torch.utils.tensorboard import SummaryWriter
# The run name is pinned to a fixed timestamp here instead of `current_time`,
# so logs and checkpoints go under that existing run name.
# model_name = current_time + "_Retrain_MixMatch_ResNet18_ensemble_GMM"
model_name = "2021-02-02_10-59_Retrain_MixMatch_ResNet18_ensemble_GMM"
log_folder_name = os.path.join('/data/log_folder/mixmatch/',model_name)

# Writer: the Epoch classes call `writer.add_scalar` on this global.
writer = SummaryWriter(log_dir=log_folder_name, flush_secs=3)
print(log_folder_name)

# initial train

In [None]:
# Datasets for the initial (fully supervised) stage: the training set uses the
# random augmentation pipeline, the validation set only resize + normalise.
train_dataset = LabeledDataset(
    patches = train_images,
    patch_cls_lookup = train_lookup,
    augmentation = get_training_augmentation(),
    preprocessing = get_preprocessing(),
)

valid_dataset = LabeledDataset(
    patches = valid_images,
    patch_cls_lookup = valid_lookup,
    augmentation = get_validation_augmentation(),
    preprocessing = get_preprocessing(),
)

# NOTE(review): drop_last=True on the validation loader discards the final
# partial batch, so those patches never contribute to the validation metrics --
# confirm this is intended.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle = True, num_workers=16, pin_memory=True, drop_last=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=16, pin_memory=True, drop_last=True)

In [None]:
print('\n#=====================#')
print('∥   Initial training  ∥')
print('#=====================#')
# Epoch runners for the supervised warm-up stage.
train_epoch = InitialTrainEpoch(
    model, 
    loss_fn,
    metrics=metrics, 
    optimizer=optimizer,
    device=DEVICE,
    verbose=True,
)
valid_epoch = ValidEpoch(
    model,
    loss_fn,
    metrics=metrics, 
    device=DEVICE,
    verbose=True,
)

# The stage name appears in the progress bar and in the writer tags.
train_epoch.stage_name = 'Initial train'
valid_epoch.stage_name = 'Initial valid'

# NOTE(review): with `warmup = 5` (settings cell) `range(warmup, 5)` is empty, so
# no epoch runs and `train_logs` / `valid_logs` are never assigned here --
# presumably left this way to resume from a saved checkpoint; confirm.
for i in range(warmup, 5):
    train_logs = train_epoch.run(train_loader, i)
    valid_logs = valid_epoch.run(valid_loader, i)

In [None]:
# Save the weights after the initial stage; the file name embeds the run name,
# the literal epoch number 5 and the validation patient AUC.
# NOTE(review): `valid_logs` exists only if the initial-training loop ran at
# least once in this kernel.
output_path = '/data/weight/noisy_label/{}_epoch_{}_auc_{:.4f}.h5'.format(
    model_name, 5, valid_logs['auc'])
torch.save(model.state_dict(), output_path)

In [None]:
# Restore a previously saved initial-stage checkpoint (hard-coded path) into the model.
model.load_state_dict(torch.load("/data/weight/noisy_label/2021-02-02_10-59_Retrain_MixMatch_ResNet18_ensemble_GMM_epoch_5_auc_0.8582.h5"))

# filter and train

In [None]:
# Fresh optimizer for the filter-and-train stage (same settings as the initial one).
optimizer = torch.optim.SGD([ 
    dict(params=model.parameters(), lr=1e-4),
], weight_decay=1e-3,  momentum=0.9)
# NOTE(review): the training loop in this notebook never calls `scheduler.step()`,
# so the learning rate stays at its initial value -- confirm whether the decay
# (x0.1 every 6 steps) was meant to be applied.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.1)

In [None]:
# Epoch runners for the iterative stage: label filtering, MixMatch-style
# training and validation. They all share the same model instance.
filter_epoch = NewFilterEpoch(
        model, 
        init_lookup=train_lookup,
        loss=filter_loss_fn,
        metrics=metrics,
        ensemble_momentum=0.5,
        device=DEVICE,
        verbose=True,
    )
train_epoch = TrainEpoch(
        model,
        loss_xent_fn, 
        loss_mse_fn, 
        optimizer, 
        device=DEVICE,
        verbose=True,
    )
valid_epoch = ValidEpoch(
    model,
    loss_fn,
    metrics=metrics, 
    device=DEVICE,
    verbose=True,
)

In [None]:
# Alternate between label filtering and semi-supervised training.
# NOTE(review): the iteration count is hard-coded to 20 although the settings
# cell defines max_iteration = 10, and `scheduler` is never stepped here --
# confirm both.
iteration = 0
while(iteration < 20):
    # ---- Filter labels -----------------------------------------------------
    print('\n------------------------------------------------------------')
    print('Iteration: {}\n#=====================#'.format(iteration))
    print('∥   Start Filtering   ∥')
    print('#=====================#')

    filtered_label = filter_epoch.run(train_npys, iteration)
    # apply new label: patches marked -1 become the unlabeled set
    labeled_patches, unlabeled_patches, num_p= split_labeled_unlabeled(train_filtered_label = filtered_label)
    print('\n# of Labeled :{}, # of Unlabeled :{}'.format(len(labeled_patches), len(unlabeled_patches)))
    print('# of Positive Labeled :{}, # of Negative Labeled :{}'.format(num_p, len(labeled_patches) - num_p))

    labeled_dataset = LabeledDataset(
        patches = labeled_patches,
        patch_cls_lookup = filtered_label,
        augmentation = get_training_augmentation(),
        preprocessing = get_preprocessing(),
    )
    unlabeled_dataset = UnlabeledDataset(
        patches = unlabeled_patches,
        augmentation = get_training_augmentation(),
        preprocessing = get_preprocessing(),
    )
    # Each loader provides half of the total batch.
    labeled_loader = DataLoader(labeled_dataset, batch_size=batch_size//2, \
                                num_workers=10, shuffle = True, pin_memory=True, drop_last=True)
    unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=batch_size//2, \
                                  num_workers=10, shuffle = True, pin_memory=True, drop_last=True)

    # ---- Model training ----------------------------------------------------
    print('\n#=====================#')
    print('∥    Start training   ∥')
    print('#=====================#')

    train_epoch.stage_name = 'Iteration train'
    valid_epoch.stage_name = 'Iteration valid'

    #train with new label
    for i in range(epochs_in_iteration):
        # Global epoch index across all iterations.
        epoch = iteration*epochs_in_iteration+i
        print('\nEpoch: {}, batch: {}'.format(epoch, batch_size))
        print('Epoch-{0} lr: {1}'.format(epoch, optimizer.param_groups[0]['lr']))
        train_logs = train_epoch.run(labeled_loader, unlabeled_loader, epoch)
        valid_logs = valid_epoch.run(valid_loader, epoch)

        # Keep the best model according to the validation patient AUC.
        if max_score < valid_logs['auc']:
            max_score = valid_logs['auc']

            # The file name carries the actual epoch index (previously the
            # constant `epochs_in_iteration`, which labelled every file "epoch_2").
            output_path = '/data/weight/noisy_label/{}_epoch_{}_auc_{:.4f}.h5'.format(
                model_name, epoch, valid_logs['auc'])
            torch.save(model.state_dict(), output_path)
            print('Model saved! {}'.format(output_path))

    # Periodic snapshot every third iteration, named after the last epoch run.
    if iteration % 3 == 0:

        output_path = '/data/weight/noisy_label/{}_epoch_{}_auc_{:.4f}.h5'.format(
            model_name, epoch, valid_logs['auc'])
        torch.save(model.state_dict(), output_path)
        print('Model saved! {}'.format(output_path))

    iteration += 1

# valid 

In [None]:
# Folder scanned for the per-slide .npy files used during validation.
# using_npy = "/data/tcga/kmeans_cluster_32/"
using_npy = "/data/tcga/512denseTumor/"

# Build the membership set once so each file is checked in O(1) instead of
# scanning the whole validation patient collection for every .npy file.
valid_patient_set = set(valid_patient)

valid_npy_pos = []
valid_npy_neg = []
for npy in sorted(glob.glob(os.path.join(using_npy, "*.npy"))):
    # The first 12 characters of the file name are used as the patient id,
    # which is the key type of `lookup`.
    patient = os.path.basename(npy)[:12]
    if patient not in valid_patient_set:
        continue
    # Label 1 goes to the positive list, every other label to the negative one.
    if lookup[patient] == 1:
        valid_npy_pos.append(npy)
    else:
        valid_npy_neg.append(npy)

In [None]:
# Checkpoint whose state dict is loaded into the ResNet-18 in the next cell.
# Alternative checkpoint:
# weight = "/data/weight/noisy_label/2021-01-08_16-17_MixMatch_ResNet18_ensemble+GMM_epoch_1_auc_0.8006.h5"
weight = (
    "/data/weight/noisy_label/"
    "2020-12-29_20-57_MixMatch_ResNet18_ensemblePred_epoch_0_auc_0.8328042328042329.h5"
)

In [None]:
# ResNet-18 backbone with a two-class head; the checkpoint's state dict must
# match this layout for load_state_dict to succeed.
model = models.resnet18()
model.fc = nn.Linear(model.fc.in_features, 2)
# Deserialize onto the CPU first so loading does not depend on the GPU index
# the checkpoint was saved from; the model is moved to the GPU right after.
model.load_state_dict(torch.load(weight, map_location="cpu"))
model.cuda()
# Evaluation mode (fixed batch-norm statistics) for the prediction cells below.
model.eval()

In [None]:
# Slide-level predictions for the positive validation slides.
# TT_positive_pred maps patient id -> list of slide scores, where a slide score
# is the fraction of its patches whose class-1 probability exceeds 0.5.
# NOTE(review): `raw_TMB_dict` is loaded in a later cell of this notebook; that
# cell must be executed before this one or the title formatting raises NameError.
TT_positive_pred = {}
with tqdm(valid_npy_pos, desc="test", file=sys.stdout) as iterator:
    for npy in iterator:
        svs_name = os.path.basename(npy)[:-4]  # drop the ".npy" suffix
        patient_name = svs_name[:12]
        x_y_pairs = np.load(npy)
        if len(x_y_pairs) == 0:
            # A slide without patches has no score; skipping it also avoids
            # the division by zero the ratio below would hit.
            tqdm.write("skip {}: no patch coordinates".format(svs_name), file=sys.stdout)
            continue
        image_names = ["{}_{}_{}.jpg".format(svs_name, x, y) for x, y in x_y_pairs]

        test_dataset = CustomDataset(
            image_names,
            augmentation=get_validation_augmentation(),
            preprocessing=get_preprocessing()
        )
        test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=16, pin_memory=True)

        # Collect per-batch probabilities and concatenate once at the end
        # instead of re-allocating the full array on every batch.
        batch_probs = []
        with torch.no_grad():
            for images, _, _ in test_loader:
                logits = model(images.cuda())
                batch_probs.append(torch.softmax(logits, dim=1)[:, 1].cpu().numpy())
        # float64 keeps the dtype the previous np.array([])-seeded accumulator produced.
        patch_predictions = np.concatenate(batch_probs).astype("float64")

        n_positive = int((patch_predictions > 0.5).sum())
        y_pred = n_positive / len(patch_predictions)
        TT_positive_pred.setdefault(patient_name, []).append(y_pred)

        title = "{}_gt={}_pred={:.4f}_raw={}".format(svs_name, lookup[patient_name], y_pred, raw_TMB_dict[patient_name])
        data_process.stitch.stitch(wsi_name = svs_name, x_y_pairs = x_y_pairs, preds = patch_predictions, title=title)

In [None]:
# Slide-level predictions for the negative validation slides.
# TT_negative_pred maps patient id -> list of slide scores, where a slide score
# is the fraction of its patches whose class-1 probability exceeds 0.5.
# NOTE(review): `raw_TMB_dict` is loaded in a later cell of this notebook; that
# cell must be executed before this one or the title formatting raises NameError.
TT_negative_pred = {}
with tqdm(valid_npy_neg, desc="test", file=sys.stdout) as iterator:
    for npy in iterator:
        svs_name = os.path.basename(npy)[:-4]  # drop the ".npy" suffix
        patient_name = svs_name[:12]
        x_y_pairs = np.load(npy)
        if len(x_y_pairs) == 0:
            # A slide without patches has no score; skipping it also avoids
            # the division by zero the ratio below would hit.
            tqdm.write("skip {}: no patch coordinates".format(svs_name), file=sys.stdout)
            continue
        image_names = ["{}_{}_{}.jpg".format(svs_name, x, y) for x, y in x_y_pairs]

        test_dataset = CustomDataset(
            image_names,
            augmentation=get_validation_augmentation(),
            preprocessing=get_preprocessing()
        )
        test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=16, pin_memory=True)

        # Collect per-batch probabilities and concatenate once at the end
        # instead of re-allocating the full array on every batch.
        batch_probs = []
        with torch.no_grad():
            for images, _, _ in test_loader:
                logits = model(images.cuda())
                batch_probs.append(torch.softmax(logits, dim=1)[:, 1].cpu().numpy())
        # float64 keeps the dtype the previous np.array([])-seeded accumulator produced.
        patch_predictions = np.concatenate(batch_probs).astype("float64")

        n_positive = int((patch_predictions > 0.5).sum())
        y_pred = n_positive / len(patch_predictions)
        TT_negative_pred.setdefault(patient_name, []).append(y_pred)

        title = "{}_gt={}_pred={:.4f}_raw={}".format(svs_name, lookup[patient_name], y_pred, raw_TMB_dict[patient_name])
        data_process.stitch.stitch(wsi_name = svs_name, x_y_pairs = x_y_pairs, preds = patch_predictions, title=title)

In [None]:
# Patient-level ROC AUC: each patient is scored by the maximum of their
# slide-level scores, positives labelled 1 and negatives 0.
# Import the submodule explicitly; a bare `import sklearn` does not guarantee
# that `sklearn.metrics` has been loaded.
import sklearn.metrics

positive_scores = [np.amax(scores) for scores in TT_positive_pred.values()]
negative_scores = [np.amax(scores) for scores in TT_negative_pred.values()]
preds = positive_scores + negative_scores
gt = [1] * len(positive_scores) + [0] * len(negative_scores)
print(sklearn.metrics.roc_auc_score(gt, preds))

In [None]:
# Per-patient lookup read from cohort_count.pkl; the prediction cells index it
# by patient id to fill the "raw=" part of each figure title.
tmb_pickle_path = "/data/tcga/cohort_count.pkl"
with open(tmb_pickle_path, mode="rb") as tmb_file:
    raw_TMB_dict = pickle.load(tmb_file)

In [None]:
def stitch(wsi_name = None, edge_resize_factor = 32, tile_size = 512, overlap = 256, x_y_pairs = None, preds = None, title = ""):
    """Show a patch-prediction heatmap of one slide next to its reference image.

    Args:
        wsi_name: slide identifier. It is the key into ./svs_x_y.pkl and, with
            its last four characters removed, the stem of the PNG read from
            ORIGINAL_FOLDER.
        edge_resize_factor: factor by which slide coordinates are divided to
            obtain heatmap cell indices.
        tile_size: patch edge length in slide coordinates.
        overlap: unused; kept so existing callers keep working.
        x_y_pairs: iterable of (x, y) patch origins in slide coordinates.
        preds: one score per entry of x_y_pairs, in the same order; drawn on a
            fixed 0-1 colour scale.
        title: text placed above the heatmap panel.

    Returns:
        None. The figure is displayed with plt.show().

    Raises:
        ValueError: if wsi_name is not given.
    """
    if wsi_name is None:
        raise ValueError("wsi_name is required")

    ORIGINAL_FOLDER = "/data/tcga/svs/masks"
    # Use the wsi_name argument rather than whatever `svs_name` happens to be
    # bound at module level when the function is called.
    original_image = data_process.wsi_utils.vips_get_image(os.path.join(ORIGINAL_FOLDER, wsi_name[:-4]+".png"))

    # NOTE: pickle.load can execute arbitrary code; only point this at trusted files.
    with open("./svs_x_y.pkl", "rb") as fp:
        svs_x_y_dict = pickle.load(fp)
    width, height = svs_x_y_dict[wsi_name]

    # Start from NaN so cells never covered by a patch stay blank in the plot,
    # while a genuine prediction of exactly 0.0 is still drawn.
    mask_shape = (int(height/edge_resize_factor), int(width/edge_resize_factor))
    tumor_mask = np.full(mask_shape, np.nan, dtype='float32')

    for (x, y), pred in zip(x_y_pairs, preds):
        start_x = int(x/edge_resize_factor)
        end_x = int((x+tile_size)/edge_resize_factor)
        start_y = int(y/edge_resize_factor)
        end_y = int((y+tile_size)/edge_resize_factor)
        # Where patches overlap, the later patch overwrites the earlier one.
        tumor_mask[start_y:end_y, start_x:end_x] = pred

    # Left panel: prediction heatmap with its own colour bar.
    fig, ax = plt.subplots(1, 2, figsize=(20, 10))
    ax[0].title.set_text(title)
    im = ax[0].imshow(tumor_mask, cmap = 'coolwarm', vmin=0, vmax=1)
    heatmap_box = ax[0].get_position()
    cax = fig.add_axes([heatmap_box.x1+0.01, heatmap_box.y0, 0.02, heatmap_box.height])
    plt.colorbar(im, cax=cax)

    # Right panel: reference image loaded from ORIGINAL_FOLDER.
    ax[1].imshow(original_image, vmin=0, vmax=255)

    plt.show()

    return

In [None]:
class CustomDataset(Dataset):
    """Patch dataset yielding ``(image, label, patch_file_name)`` triples.

    Each patch file name is expected to start with the slide name followed by
    ``_``; the image is read from ``IMAGE_FOLDER/<slide name>/<file name>``.
    The label comes from the module-level ``lookup`` keyed by the first 12
    characters of the file name.

    Args:
        patches: sequence of patch file names.
        augmentation: optional callable invoked as ``augmentation(image=...)``
            that returns a mapping with an ``'image'`` entry.
        preprocessing: optional callable with the same calling convention,
            applied after the augmentation.
    """

    def __init__(self, patches, augmentation = None, preprocessing = None):
        self.patches = patches
        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, i):
        image_name = self.patches[i]
        # Patches are stored in a sub-folder named by the text before the first "_".
        svs_name = image_name.split("_")[0]
        full_path = os.path.join(IMAGE_FOLDER, svs_name, image_name)
        image = data_process.wsi_utils.vips_get_image(full_path)

        cls = lookup[image_name[:12]]

        # Start from the loaded image so the item is still defined when no
        # augmentation (or no preprocessing) was supplied.
        _input = image
        if self.augmentation:
            _input = self.augmentation(image = _input)['image']

        if self.preprocessing:
            _input = self.preprocessing(image = _input)['image']

        return _input, cls, image_name

    def __len__(self):
        return len(self.patches)