In [170]:
import os
import seaborn as sns
import nibabel as nib
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import scipy
import matplotlib.pyplot as plt

In [171]:
def softmax(x, T):
    """Temperature-scaled softmax taken along axis 0.

    The input is divided by the temperature ``T``; the per-column maximum is
    then subtracted before exponentiating so that large values do not overflow.
    The result has the same shape as ``x / T`` and sums to one along axis 0.
    """
    scaled = x / T
    shifted_exp = np.exp(scaled - np.max(scaled, axis = 0))  # max-shift for stability
    normaliser = shifted_exp.sum(axis = 0)
    return shifted_exp / normaliser

def ComputMetric(ACTUAL, PREDICTED):
    """Binary overlap metrics between a reference mask and a predicted mask.

    Both inputs are flattened. Elements of ``ACTUAL`` equal to ``True`` (1)
    are positives, elements equal to ``False`` (0) are negatives.

    :param ACTUAL: reference labels (array of 0/1 or booleans)
    :param PREDICTED: predicted labels, same number of elements as ``ACTUAL``
    :return: tuple ``(dice, Sensitivity, Precision)``; all three are 0 when
        there is no true positive
    """
    ACTUAL = ACTUAL.flatten()
    PREDICTED = PREDICTED.flatten()
    idxp = ACTUAL == True
    idxn = ACTUAL == False

    # Agreement on the positive / negative reference voxels.
    tp = np.sum(ACTUAL[idxp] == PREDICTED[idxp])
    tn = np.sum(ACTUAL[idxn] == PREDICTED[idxn])
    fp = np.sum(idxn) - tn
    fn = np.sum(idxp) - tp
    # The false-positive rate is intentionally not computed: it was never
    # returned, and fp / (fp + tn) is 0 / 0 when the reference has no negatives.
    if tp == 0:
        dice = 0
        Precision = 0
        Sensitivity = 0
    else:
        dice = 2 * tp / (2 * tp + fp + fn)
        Precision = tp / (tp + fp)
        Sensitivity = tp / (tp + fn)
    return dice, Sensitivity, Precision

def sum_tensor(inp, axes, keepdim=False):
    """Sum ``inp`` over several axes, one axis at a time.

    Duplicate entries in ``axes`` are ignored. With ``keepdim=True`` every
    reduced axis is kept with size 1; otherwise the axes are removed, reducing
    from the highest axis down so the remaining indices stay valid.
    """
    unique_axes = [int(dim) for dim in np.unique(axes).astype(int)]  # sorted ascending
    if not keepdim:
        for dim in reversed(unique_axes):
            inp = inp.sum(dim)
        return inp
    for dim in unique_axes:
        inp = inp.sum(dim, keepdim=True)
    return inp

def get_tp_fp_fn(net_output, gt, axes=None, mask=None, square=False):
    """
    Soft true-positive, false-positive and false-negative sums per class.

    net_output must be (b, c, x, y(, z))
    gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or a one hot encoding (b, c, x, y(, z))
    if mask is provided it must have shape (b, 1, x, y(, z))
    :param net_output: per-class scores that are multiplied with the one hot target
    :param gt: label map or one hot encoding, see above
    :param axes: axes that are summed out; defaults to every axis from index 2 on
    :param mask: mask must be 1 for valid pixels and 0 for invalid pixels
    :param square: if True then fp, tp and fn will be squared before summation
    :return: tuple (tp, fp, fn) of tensors summed over ``axes``
    """
    if axes is None:
        axes = tuple(range(2, len(net_output.size())))

    pred_shape = net_output.shape
    gt_shape = gt.shape

    with torch.no_grad():
        # A label map without channel axis gets a singleton channel inserted.
        if len(pred_shape) != len(gt_shape):
            gt = gt.view((gt_shape[0], 1, *gt_shape[1:]))

        shapes_match = all(a == b for a, b in zip(net_output.shape, gt.shape))
        if shapes_match:
            # same shape as the prediction: gt is probably already a one hot encoding
            y_onehot = gt
        else:
            gt = gt.long()
            y_onehot = torch.zeros(pred_shape)
            if net_output.device.type == "cuda":
                y_onehot = y_onehot.cuda(net_output.device.index)
            y_onehot.scatter_(1, gt, 1)

    tp = net_output * y_onehot
    fp = net_output * (1 - y_onehot)
    fn = (1 - net_output) * y_onehot

    if mask is not None:
        valid = mask[:, 0]

        def _masked(per_class):
            # multiply every class channel with the (channel-less) validity mask
            return torch.stack(tuple(chan * valid for chan in torch.unbind(per_class, dim=1)), dim=1)

        tp, fp, fn = _masked(tp), _masked(fp), _masked(fn)

    if square:
        tp, fp, fn = tp ** 2, fp ** 2, fn ** 2

    tp = sum_tensor(tp, axes, keepdim=False)
    fp = sum_tensor(fp, axes, keepdim=False)
    fn = sum_tensor(fn, axes, keepdim=False)

    return tp, fp, fn


def SoftDiceLoss(x, y, loss_mask = None, smooth = 1e-5, batch_dice = False):
    '''
    Soft Dice coefficient per class (despite the name, the coefficient itself
    is returned, not 1 - Dice).

    The tp/fp/fn terms are always summed over the batch axis and all spatial
    axes, so the result has one entry per class.
    ``batch_dice`` is accepted for interface compatibility but is not read.

    :param x: predictions, (b, c, x, y(, z))
    :param y: target label map or one hot encoding, forwarded to get_tp_fp_fn
    :param loss_mask: optional validity mask, forwarded to get_tp_fp_fn
    :param smooth: constant added to numerator and denominator
    :return: tensor of per-class soft Dice values
    '''
    reduce_axes = [0] + list(range(2, len(x.shape)))
    tp, fp, fn = get_tp_fp_fn(x, y, reduce_axes, loss_mask, False)

    numerator = 2 * tp + smooth
    denominator = 2 * tp + fp + fn + smooth
    return numerator / denominator

## Class-Specific Temperature-Scaling (CS TS)

In [None]:
# Validation split: prediction folder and the list of ground-truth
# segmentation paths (one path per line) used by the following cells.
resultpath = './data/Prostateresults/prostateval/'
prostatevalpath = './data/Prostateresults/'
DatafiletsImgc1 = prostatevalpath + 'seg-eval.txt'
# The context manager closes the list file as soon as it has been read.
with open(DatafiletsImgc1) as Imgfiletsc1:
    Imgreadc1 = Imgfiletsc1.read().splitlines()

### Temperature-Scaling

In [174]:
# Gather, over all validation cases, the per-voxel predicted class, maximum
# confidence, ground-truth label and the two stacked class maps.
# Per-case arrays are collected in lists and joined once after the loop;
# concatenating inside the loop would re-copy the whole accumulated array
# for every case.
prob_chunks = []
pred_chunks = []
target_chunks = []
logit_chunks = []
for Imgnamec1 in tqdm(Imgreadc1):
    knamelist = Imgnamec1.split("/")
    kname = knamelist[-1][0:6]  # first six characters of the file name identify the case

    cls0filename = resultpath + '/results/pred_' + kname + 'cls0_prob.nii.gz'
    cls1filename = resultpath + '/results/pred_' + kname + 'cls1_prob.nii.gz'
    cls0read = nib.load(cls0filename)
    cls1read = nib.load(cls1filename)
    # NOTE(review): files are named *_prob but are fed to softmax as logits - confirm.
    cls0logit = cls0read.get_fdata()
    cls1logit = cls1read.get_fdata()
    GTread = nib.load(Imgnamec1)
    GTimg = GTread.get_fdata()
    imgshape = GTimg.shape
    cls0flatten = cls0logit.flatten()
    cls1flatten = cls1logit.flatten()
    clsflatten = np.stack((cls0flatten, cls1flatten))  # 2 x n_voxels
    GTflatten = GTimg.flatten()
    # NOTE(review): T = 1.6 is hard-coded here, before any temperature is
    # learned; it changes probs_all but not the argmax - confirm it is intended.
    probflatten = softmax(clsflatten, T = 1.6)

    pred_class = np.argmax(probflatten, axis = 0)
    probmax = np.max(probflatten, axis = 0)
    pred_chunks.append(pred_class)
    target_chunks.append(GTflatten)
    prob_chunks.append(probmax)
    logit_chunks.append(clsflatten)

# Seeding with an empty float array keeps the float64 dtype that the former
# incremental concatenation (starting from []) produced.
preds_class_all = np.concatenate([np.zeros(0)] + pred_chunks, axis=0)
targets_all = np.concatenate([np.zeros(0)] + target_chunks, axis=0)
probs_all = np.concatenate([np.zeros(0)] + prob_chunks, axis=0)
logits_all = np.concatenate(logit_chunks, axis=1) if logit_chunks else []

100%|█████████████████████████████████████████████████████████████████| 10/10 [00:14<00:00,  1.42s/it]


In [175]:
# Fit one global temperature so that the mean maximum confidence on the
# validation voxels matches the voxel-wise accuracy.
preacts = logits_all.T  # N x C
labels = targets_all
preds = np.argmax(preacts, axis = 1)
acc = np.sum(labels == preds) / len(labels)
def eval_func(x):
    """Absolute gap between mean max-confidence at temperature ``x`` and ``acc``."""
    # The shared, max-shifted softmax helper is used instead of a raw np.exp
    # so that large logits divided by a small temperature cannot overflow.
    probs = softmax(preacts.T, T = x)  # C x N, normalised over axis 0
    AC = np.mean(np.max(probs, axis=0))
    MC = np.abs(AC-acc)

    return MC
optimization_result = scipy.optimize.minimize(
                          fun=eval_func,
                          x0=np.array([1.0]),
                          method='Nelder-Mead',
                          tol=1e-07)

In [176]:
# Global temperature: first (only) entry of the solution vector of the most recent optimisation.
LearedTemp = optimization_result.x[0]
print(LearedTemp)

1.6034124374389669e+00


### Class-Specific Temperature-Scaling

#### the first step, align the background with acc

In [177]:
# -> preacts. N x C
# -> labels. N
preacts = logits_all.T
labels = targets_all
preds_all_argmax = np.argmax(preacts, axis = 1)
# Indices of the voxels predicted as class 0 (background).
targets_y1 = np.where(preds_all_argmax==0)[0]
pred_class = preds_all_argmax[targets_y1]
target_class = targets_all[targets_y1]

# Fraction of background-predicted voxels whose label agrees with the prediction.
acc = np.sum(pred_class == target_class) / len(target_class)
# Only the background-predicted columns enter the objective; slicing them
# once avoids a softmax over every voxel on each objective evaluation.
cs_ts_bg_logits = logits_all[:, targets_y1]
def eval_func(x):
    """Absolute gap between the mean confidence of background-predicted voxels at temperature ``x`` and ``acc``."""
    prob_Topt = softmax(cs_ts_bg_logits, T = x).transpose()
    AC = np.mean(np.max(prob_Topt, axis = 1))

    MC = np.abs(AC-acc)

    return MC

optimization_result = scipy.optimize.minimize(
                      fun=eval_func,
                      x0=np.array([1.0]),
                      method='Nelder-Mead',
                      tol=1e-07)

In [178]:
# Background temperature: first (only) entry of the solution vector of the most recent optimisation.
LearedTempBG = optimization_result.x[0]
print(LearedTempBG)

1.5503941535949726e+00


#### the second step, align the foreground with DSC

In [179]:
def _load_cs_ts_fg_cases():
    """Read the stacked class maps and ground truth of every case in Imgreadc1 once."""
    cases = []
    for Imgnamec1 in Imgreadc1:
        kname = Imgnamec1.split("/")[-1][0:6]
        cls0logit = nib.load(resultpath + '/results/pred_' + kname + 'cls0_prob.nii.gz').get_fdata()
        cls1logit = nib.load(resultpath + '/results/pred_' + kname + 'cls1_prob.nii.gz').get_fdata()
        GTimg = nib.load(Imgnamec1).get_fdata()
        clsflatten = np.stack((cls0logit.flatten(), cls1logit.flatten()))
        cases.append((clsflatten, GTimg))
    return cases

# The volumes do not depend on the temperature being optimised, so they are
# read from disk once instead of on every objective evaluation.
cs_ts_fg_cases = _load_cs_ts_fg_cases()

def eval_func(x):
    """Absolute gap between mean soft DSC and mean real DSC for foreground temperature ``x``.

    Background-predicted voxels are scaled with LearedTempBG, foreground-predicted
    voxels with ``x``. The soft DSC is computed against the model's own argmax
    prediction, the real DSC against the ground truth.
    """
    softDSCs = []
    realDSCs = []
    for clsflatten, GTimg in cs_ts_fg_cases:
        imgshape = GTimg.shape
        probflatten = softmax(clsflatten, T = 1.0)

        preds_all_argmax = np.argmax(clsflatten, axis = 0)

        # for cls 0, BG class
        targets_y1 = np.where(preds_all_argmax==0)[0]
        probflatten[:, targets_y1] = softmax(clsflatten[:, targets_y1], T = LearedTempBG)
        # for cls 1, FG class
        targets_y1 = np.where(preds_all_argmax==1)[0]
        probflatten[:, targets_y1] = softmax(clsflatten[:, targets_y1], T = x)
        probr = probflatten.reshape((2, imgshape[0], imgshape[1], imgshape[2]))
        probr_tensor = torch.tensor(probr[np.newaxis, ...])
        # The hard prediction serves as the target of the soft DSC.
        GTimgf = np.argmax(probr, axis = 0)
        GT_tensor = torch.tensor(GTimgf[np.newaxis, ...])

        softDSC = SoftDiceLoss(probr_tensor, GT_tensor)
        softDSCs.append(softDSC[1].numpy())  # index 1: foreground class

        realDSC, _, _ = ComputMetric(GTimg, np.argmax(probr, axis = 0))
        realDSCs.append(realDSC)
    MC = np.abs(np.mean(softDSCs) - np.mean(realDSCs))

    return MC

optimization_result = scipy.optimize.minimize(
                      fun=eval_func,
                      x0=np.array([1.0]),
                      method='Nelder-Mead',
                      bounds=[(0,None)],
                      tol=1e-07)

In [180]:
# Foreground temperature: first (only) entry of the solution vector of the most recent optimisation.
LearedTempFG = optimization_result.x[0]
print(LearedTempFG)

1.9221460342407244e+00


### Test performance

In [181]:
# Test conditions to evaluate; only condition 'A' is enabled in this list.
domainlist = ['A']
softDSCs_AC = []
softDSCs_TS = []
softDSCs_CSTS = []
realDSCs = []
for kcon in tqdm(domainlist):
    softDSC_FG_AC = []
    softDSC_FG_TS = []
    softDSC_FG_CSTS = []
    realDSC_FG = []
    # NOTE(review): absolute, machine-specific path; consider a configurable data root.
    resultpath = '/vol/biomedic3/zl9518/ModelEvaluation/output/prostate/prostattestcondition_' + kcon + '/'

    # NOTE(review): the other test cells read '<datafiletest + kcon>/seg-eval.txt'
    # instead of this list - confirm which one is intended.
    DatafiletsImgc1 = prostatevalpath + 'seg-test' + kcon + '.txt'
    # The context manager closes the list file as soon as it has been read.
    with open(DatafiletsImgc1) as Imgfiletsc1:
        Imgreadc1 = Imgfiletsc1.read().splitlines()

    for Imgnamec1 in Imgreadc1:
        knamelist = Imgnamec1.split("/")
        kname = knamelist[-1][0:6]

        cls0filename = resultpath + '/results/pred_' + kname + 'cls0_prob.nii.gz'
        cls1filename = resultpath + '/results/pred_' + kname + 'cls1_prob.nii.gz'
        cls0read = nib.load(cls0filename)
        cls1read = nib.load(cls1filename)
        cls0logit = cls0read.get_fdata()
        cls1logit = cls1read.get_fdata()
        GTread = nib.load(Imgnamec1)
        GTimg = GTread.get_fdata()
        imgshape = GTimg.shape
        cls0flatten = cls0logit.flatten()
        cls1flatten = cls1logit.flatten()
        clsflatten = np.stack((cls0flatten, cls1flatten))
        GTflatten = GTimg.flatten()
        # By AC: plain softmax (T = 1); soft DSC against the model's own argmax
        probflatten = softmax(clsflatten, T = 1.0)
        probr = probflatten.reshape((2, imgshape[0], imgshape[1], imgshape[2]))
        probr_tensor = torch.tensor(probr[np.newaxis, ...])
        GTimgf = np.argmax(probr, axis = 0)
        GT_tensor = torch.tensor(GTimgf[np.newaxis, ...])
        softDSC = SoftDiceLoss(probr_tensor, GT_tensor)
        softDSC_FG_AC.append(softDSC[1].numpy())
        # By TS: one global temperature
        probflatten = softmax(clsflatten, T = LearedTemp)
        probr = probflatten.reshape((2, imgshape[0], imgshape[1], imgshape[2]))
        probr_tensor = torch.tensor(probr[np.newaxis, ...])
        GTimgf = np.argmax(probr, axis = 0)
        GT_tensor = torch.tensor(GTimgf[np.newaxis, ...])
        softDSC = SoftDiceLoss(probr_tensor, GT_tensor)
        softDSC_FG_TS.append(softDSC[1].numpy())
        # By CSTS: every column is overwritten below, depending on its predicted class
        preds_all_argmax = np.argmax(clsflatten, axis = 0)
        # for cls 0, BG class
        targets_y1 = np.where(preds_all_argmax==0)[0]
        probflatten[:, targets_y1] = softmax(clsflatten[:, targets_y1], T = LearedTempBG)
        # for cls 1, FG class
        targets_y1 = np.where(preds_all_argmax==1)[0]
        probflatten[:, targets_y1] = softmax(clsflatten[:, targets_y1], T = LearedTempFG)
        probr = probflatten.reshape((2, imgshape[0], imgshape[1], imgshape[2]))
        probr_tensor = torch.tensor(probr[np.newaxis, ...])
        GTimgf = np.argmax(probr, axis = 0)
        GT_tensor = torch.tensor(GTimgf[np.newaxis, ...])
        softDSC = SoftDiceLoss(probr_tensor, GT_tensor)
        softDSC_FG_CSTS.append(softDSC[1].numpy())

        # Real DSC of the hard prediction against the ground truth.
        realDSC, _, _ = ComputMetric(GTimg, np.argmax(probr, axis = 0))
        realDSC_FG.append(realDSC)

    # One mean value per test condition.
    softDSCs_AC.append(np.mean(np.array(softDSC_FG_AC)))
    softDSCs_TS.append(np.mean(np.array(softDSC_FG_TS)))
    softDSCs_CSTS.append(np.mean(np.array(softDSC_FG_CSTS)))
    realDSCs.append(np.mean(np.array(realDSC_FG)))

100%|███████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.90s/it]


In [182]:
# Mean absolute gap between the real DSC and each soft-DSC estimate,
# averaged over the evaluated test conditions.
for heading, soft_scores in (
    ('AC_results:', softDSCs_AC),
    ('TS_results:', softDSCs_TS),
    ('CSTS_results:', softDSCs_CSTS),
):
    print(heading)
    print(np.mean(np.abs(np.array(realDSCs) - np.array(soft_scores))))

AC_results:
1.049066211781956e-01
TS_results:
0.015096730900921562
CSTS_results:
0.004232643061958363


## Class-Specific Difference of Confidences (CS DoC)

In [201]:
# Validation split: prediction folder and the list of ground-truth
# segmentation paths (one path per line) used by the following cells.
resultpath = './data/Prostateresults/prostateval/'
prostatevalpath = './data/Prostateresults/'
DatafiletsImgc1 = prostatevalpath + 'seg-eval.txt'
# The context manager closes the list file as soon as it has been read.
with open(DatafiletsImgc1) as Imgfiletsc1:
    Imgreadc1 = Imgfiletsc1.read().splitlines()

### Difference of Confidences

In [203]:
# Difference of Confidences on the validation split: mean maximum softmax
# confidence minus voxel-wise accuracy.
# Per-case arrays are collected in lists and joined once after the loop;
# concatenating inside the loop would re-copy the accumulated array per case.
prob_chunks = []
pred_chunks = []
target_chunks = []

softDSCs = []
realDSCs = []
for Imgnamec1 in tqdm(Imgreadc1):
    knamelist = Imgnamec1.split("/")
    kname = knamelist[-1][0:6]

    cls0filename = resultpath + '/results/pred_' + kname + 'cls0_prob.nii.gz'
    cls1filename = resultpath + '/results/pred_' + kname + 'cls1_prob.nii.gz'
    cls0read = nib.load(cls0filename)
    cls1read = nib.load(cls1filename)
    cls0logit = cls0read.get_fdata()
    cls1logit = cls1read.get_fdata()
    GTread = nib.load(Imgnamec1)
    GTimg = GTread.get_fdata()
    imgshape = GTimg.shape
    cls0flatten = cls0logit.flatten()
    cls1flatten = cls1logit.flatten()
    clsflatten = np.stack((cls0flatten, cls1flatten))
    GTflatten = GTimg.flatten()
    probflatten = softmax(clsflatten, T = 1.0)

    pred_class = np.argmax(probflatten, axis = 0)
    probmax = np.max(probflatten, axis = 0)
    pred_chunks.append(pred_class)
    target_chunks.append(GTflatten)
    prob_chunks.append(probmax)

# Seeding with an empty float array keeps the float64 dtype that the former
# incremental concatenation (starting from []) produced.
preds_class_all = np.concatenate([np.zeros(0)] + pred_chunks, axis=0)
targets_all = np.concatenate([np.zeros(0)] + target_chunks, axis=0)
probs_all = np.concatenate([np.zeros(0)] + prob_chunks, axis=0)

acc = np.sum(targets_all == preds_class_all) / len(targets_all)
mean_prob = np.mean(probs_all)
DoC = mean_prob - acc

100%|█████████████████████████████████████████████████████████████████| 10/10 [00:04<00:00,  2.45it/s]


In [204]:
# Mean maximum confidence minus accuracy on the validation voxels.
print(DoC)

0.0022837051173588696


### Class-Specific Difference of Confidences

In [205]:
# Class-specific Difference of Confidences on the validation split:
# mean foreground soft DSC (against the model's own argmax) minus mean real DSC.
# Per-case arrays are collected in lists and joined once after the loop;
# concatenating inside the loop would re-copy the accumulated array per case.
prob_chunks = []
pred_chunks = []
target_chunks = []
softDSCs = []
realDSCs = []
for Imgnamec1 in tqdm(Imgreadc1):
    knamelist = Imgnamec1.split("/")
    kname = knamelist[-1][0:6]

    cls0filename = resultpath + '/results/pred_' + kname + 'cls0_prob.nii.gz'
    cls1filename = resultpath + '/results/pred_' + kname + 'cls1_prob.nii.gz'
    cls0read = nib.load(cls0filename)
    cls1read = nib.load(cls1filename)
    cls0logit = cls0read.get_fdata()
    cls1logit = cls1read.get_fdata()
    GTread = nib.load(Imgnamec1)
    GTimg = GTread.get_fdata()
    imgshape = GTimg.shape
    cls0flatten = cls0logit.flatten()
    cls1flatten = cls1logit.flatten()
    clsflatten = np.stack((cls0flatten, cls1flatten))
    GTflatten = GTimg.flatten()
    probflatten = softmax(clsflatten, T = 1.0)

    pred_class = np.argmax(probflatten, axis = 0)
    probmax = np.max(probflatten, axis = 0)
    pred_chunks.append(pred_class)
    target_chunks.append(GTflatten)
    prob_chunks.append(probmax)

    # The former per-class re-assignment used T = 1.0 for both classes and
    # therefore reproduced the values already in probflatten; it is omitted.
    probr = probflatten.reshape((2, imgshape[0], imgshape[1], imgshape[2]))
    probr_tensor = torch.tensor(probr[np.newaxis, ...])
    GTimgf = np.argmax(probr, axis = 0)
    GT_tensor = torch.tensor(GTimgf[np.newaxis, ...])

    softDSC = SoftDiceLoss(probr_tensor, GT_tensor)
    softDSCs.append(softDSC[1].numpy())  # index 1: foreground class

    realDSC, _, _ = ComputMetric(GTimg, np.argmax(probr, axis = 0))
    realDSCs.append(realDSC)

# Seeding with an empty float array keeps the float64 dtype that the former
# incremental concatenation (starting from []) produced.
preds_class_all = np.concatenate([np.zeros(0)] + pred_chunks, axis=0)
targets_all = np.concatenate([np.zeros(0)] + target_chunks, axis=0)
probs_all = np.concatenate([np.zeros(0)] + prob_chunks, axis=0)

msoftDSC = np.mean(softDSCs)
mrealDSC = np.mean(realDSCs)
CS_DoC = msoftDSC - mrealDSC

100%|█████████████████████████████████████████████████████████████████| 10/10 [00:07<00:00,  1.28it/s]


In [206]:
# Mean foreground soft DSC minus mean real DSC on the validation cases.
print(CS_DoC)

6.645691983623936e-02


### Test performance

In [208]:
# Test conditions to evaluate; only condition 'A' is enabled in this list.
domainlist = ['A']
softDSCs_AC = []
softDSCs_DoC = []
softDSCs_CSDoC = []
realDSCs = []
for kcon in tqdm(domainlist):
    softDSC_FG_AC = []
    softDSC_FG_DoC = []
    softDSC_FG_CSDoC = []
    realDSC_FG = []
    # NOTE(review): absolute, machine-specific paths; consider a configurable data root.
    resultpath = '/vol/biomedic3/zl9518/ModelEvaluation/output/prostate/prostattestcondition_' + kcon + '/'

    prostatevalpath = '/vol/biomedic3/zl9518/Prostatedata/datafiletest' + kcon + '/'
    DatafiletsImgc1 = prostatevalpath + 'seg-eval.txt'
    # The context manager closes the list file as soon as it has been read.
    with open(DatafiletsImgc1) as Imgfiletsc1:
        Imgreadc1 = Imgfiletsc1.read().splitlines()
    for Imgnamec1 in Imgreadc1:
        knamelist = Imgnamec1.split("/")
        kname = knamelist[-1][0:6]

        cls0filename = resultpath + '/results/pred_' + kname + 'cls0_prob.nii.gz'
        cls1filename = resultpath + '/results/pred_' + kname + 'cls1_prob.nii.gz'
        cls0read = nib.load(cls0filename)
        cls1read = nib.load(cls1filename)
        cls0logit = cls0read.get_fdata()
        cls1logit = cls1read.get_fdata()
        GTread = nib.load(Imgnamec1)
        GTimg = GTread.get_fdata()
        imgshape = GTimg.shape
        cls0flatten = cls0logit.flatten()
        cls1flatten = cls1logit.flatten()
        clsflatten = np.stack((cls0flatten, cls1flatten))
        GTflatten = GTimg.flatten()
        # By AC: plain softmax (T = 1); soft DSC against the model's own argmax
        probflatten = softmax(clsflatten, T = 1.0)
        probr = probflatten.reshape((2, imgshape[0], imgshape[1], imgshape[2]))
        probr_tensor = torch.tensor(probr[np.newaxis, ...])
        GTimgf = np.argmax(probr, axis = 0)
        GT_tensor = torch.tensor(GTimgf[np.newaxis, ...])
        softDSC = SoftDiceLoss(probr_tensor, GT_tensor)
        softDSC_FG_AC.append(softDSC[1].numpy())
        # By DoC: shift DoC from the predicted class to the other class, then clip to [0, 1]
        probflattens = softmax(clsflatten, T = 1.0)
        # calculate the diff

        preds_all_argmax = np.argmax(clsflatten, axis = 0)

        probflattens[0, np.where(preds_all_argmax == 0)] = probflattens[0, np.where(preds_all_argmax == 0)] - DoC
        probflattens[1, np.where(preds_all_argmax == 0)] = probflattens[1, np.where(preds_all_argmax == 0)] + DoC
        probflattens[0, np.where((1-preds_all_argmax) == 0)] = probflattens[0, np.where((1-preds_all_argmax) == 0)] + DoC
        probflattens[1, np.where((1-preds_all_argmax) == 0)] = probflattens[1, np.where((1-preds_all_argmax) == 0)] - DoC
        probflattens = np.clip(probflattens, 0, 1)


        probrs = probflattens.reshape((2, imgshape[0], imgshape[1], imgshape[2]))
        probrs_tensor = torch.tensor(probrs[np.newaxis, ...])
        # NOTE(review): the target is the argmax of the unshifted probr (from
        # the AC step), not of probrs - confirm this is intended.
        GTimgf = np.argmax(probr, axis = 0)
        GT_tensor = torch.tensor(GTimgf[np.newaxis, ...])
        softDSC = SoftDiceLoss(probrs_tensor, GT_tensor)
        softDSC_FG_DoC.append(softDSC[1].numpy())
        # By CSDoC: probabilities stay at T = 1; CS_DoC is subtracted from the soft DSC below
        preds_all_argmax = np.argmax(clsflatten, axis = 0)
        # for cls 0, BG class
        targets_y1 = np.where(preds_all_argmax==0)[0]
        probflatten[:, targets_y1] = softmax(clsflatten[:, targets_y1], T = 1.0)

        # for cls 1, FG class
        targets_y1 = np.where(preds_all_argmax==1)[0]
        probflatten[:, targets_y1] = softmax(clsflatten[:, targets_y1], T = 1.0)

        probflatten = np.clip(probflatten, 0, 1)

        probr = probflatten.reshape((2, imgshape[0], imgshape[1], imgshape[2]))
        probr_tensor = torch.tensor(probr[np.newaxis, ...])
        GTimgf = np.argmax(probr, axis = 0)
        GT_tensor = torch.tensor(GTimgf[np.newaxis, ...])
        softDSC = SoftDiceLoss(probr_tensor, GT_tensor)
        softDSC_FG_CSDoC.append(softDSC[1].numpy() - CS_DoC)

        # Real DSC of the hard prediction against the ground truth.
        realDSC, _, _ = ComputMetric(GTimg, np.argmax(probr, axis = 0))
        realDSC_FG.append(realDSC)

    # One mean value per test condition.
    softDSCs_AC.append(np.mean(np.array(softDSC_FG_AC)))
    softDSCs_DoC.append(np.mean(np.array(softDSC_FG_DoC)))
    softDSCs_CSDoC.append(np.mean(np.array(softDSC_FG_CSDoC)))
    realDSCs.append(np.mean(np.array(realDSC_FG)))

100%|███████████████████████████████████████████████████████████████████| 1/1 [00:11<00:00, 11.43s/it]


In [209]:
# Mean absolute gap between the real DSC and each soft-DSC estimate,
# averaged over the evaluated test conditions.
for heading, soft_scores in (
    ('AC_results:', softDSCs_AC),
    ('DoC_results:', softDSCs_DoC),
    ('CSDoC_results:', softDSCs_CSDoC),
):
    print(heading)
    print(np.mean(np.abs(np.array(realDSCs) - np.array(soft_scores))))

AC_results:
0.20720681944147046
DoC_results:
0.1841104766406948
CSDoC_results:
0.1407498996052311


## Class-Specific Average Thresholded Confidence (CS ATC)

In [210]:
# Validation split: prediction folder and the list of ground-truth
# segmentation paths (one path per line) used by the following cells.
resultpath = './data/Prostateresults/prostateval/'
prostatevalpath = './data/Prostateresults/'
DatafiletsImgc1 = prostatevalpath + 'seg-eval.txt'
# The context manager closes the list file as soon as it has been read.
with open(DatafiletsImgc1) as Imgfiletsc1:
    Imgreadc1 = Imgfiletsc1.read().splitlines()

### Average Thresholded Confidence

In [211]:
# Gather, over all validation cases, the per-voxel predicted class, maximum
# softmax confidence (T = 1) and ground-truth label.
# Per-case arrays are collected in lists and joined once after the loop;
# concatenating inside the loop would re-copy the accumulated array per case.
prob_chunks = []
pred_chunks = []
target_chunks = []
for Imgnamec1 in tqdm(Imgreadc1):
    knamelist = Imgnamec1.split("/")
    kname = knamelist[-1][0:6]

    cls0filename = resultpath + '/results/pred_' + kname + 'cls0_prob.nii.gz'
    cls1filename = resultpath + '/results/pred_' + kname + 'cls1_prob.nii.gz'
    cls0read = nib.load(cls0filename)
    cls1read = nib.load(cls1filename)
    cls0logit = cls0read.get_fdata()
    cls1logit = cls1read.get_fdata()
    GTread = nib.load(Imgnamec1)
    GTimg = GTread.get_fdata()
    imgshape = GTimg.shape
    cls0flatten = cls0logit.flatten()
    cls1flatten = cls1logit.flatten()
    clsflatten = np.stack((cls0flatten, cls1flatten))
    GTflatten = GTimg.flatten()
    probflatten = softmax(clsflatten, T = 1.0)

    pred_class = np.argmax(probflatten, axis = 0)
    probmax = np.max(probflatten, axis = 0)
    pred_chunks.append(pred_class)
    target_chunks.append(GTflatten)
    prob_chunks.append(probmax)

# Seeding with an empty float array keeps the float64 dtype that the former
# incremental concatenation (starting from []) produced.
preds_class_all = np.concatenate([np.zeros(0)] + pred_chunks, axis=0)
targets_all = np.concatenate([np.zeros(0)] + target_chunks, axis=0)
probs_all = np.concatenate([np.zeros(0)] + prob_chunks, axis=0)

100%|█████████████████████████████████████████████████████████████████| 10/10 [00:04<00:00,  2.38it/s]


In [225]:
# Accuracy the thresholded-confidence estimate has to match; it does not
# depend on the threshold, so it is computed once.
atc_reference_acc = np.sum(targets_all == preds_class_all) / len(targets_all)

def eval_func(x):
    """Absolute gap between the fraction of voxels with confidence above ``x`` and the accuracy.

    The former body also evaluated a softmax on ``logit_all``, a name that is
    not defined anywhere in this notebook (NameError on a fresh kernel); its
    result was never used, so that dead code has been dropped.
    """
    acc_appr = np.sum(probs_all > x) / len(targets_all)

    MC = np.abs(acc_appr-atc_reference_acc)

    return MC
optimization_result = scipy.optimize.minimize(
                          fun=eval_func,
                          x0=np.array([1.0]),
                          method='Nelder-Mead',
                          tol=1e-07)

In [226]:
# Global confidence threshold: first (only) entry of the solution vector of the most recent optimisation.
LearedThreshold = optimization_result.x[0]
print(LearedThreshold)

0.8772399902343748


### Class-Specific Average Thresholded Confidence

#### the first step, align the background with acc

In [240]:
# -> preacts. N x C
# -> labels. N
preacts = logits_all.T
labels = targets_all
preds_all_argmax = np.argmax(preacts, axis = 1)
targets_y1 = np.where(preds_all_argmax==0)[0]
pred_class = np.argmax(preacts, axis = 1)[targets_y1]
target_class = targets_all[targets_y1]

acc = np.sum(pred_class == target_class) / len(target_class)
def eval_func(x):
    
    prob_Topt = softmax(logits_all, T = 1).transpose()[targets_y1]
    acc_appr = np.sum(prob_Topt > x) / len(targets_y1)

    MC = np.abs(acc_appr-acc)

    return MC

optimization_result = scipy.optimize.minimize(
                      fun=eval_func,
                      x0=np.array([1.0]),
                      method='Nelder-Mead',
                      tol=1e-07)

In [241]:
# Background confidence threshold: first (only) entry of the solution vector of the most recent optimisation.
LearedThresholdBG = optimization_result.x[0]
print(LearedThresholdBG)

8.566162109375e-01


#### the second step, align the foreground with DSC

In [243]:
def _load_cs_atc_fg_cases():
    """Read the stacked class maps and ground truth of every case in Imgreadc1 once."""
    cases = []
    for Imgnamec1 in Imgreadc1:
        kname = Imgnamec1.split("/")[-1][0:6]
        cls0logit = nib.load(resultpath + '/results/pred_' + kname + 'cls0_prob.nii.gz').get_fdata()
        cls1logit = nib.load(resultpath + '/results/pred_' + kname + 'cls1_prob.nii.gz').get_fdata()
        GTimg = nib.load(Imgnamec1).get_fdata()
        clsflatten = np.stack((cls0logit.flatten(), cls1logit.flatten()))
        cases.append((clsflatten, GTimg))
    return cases

# The volumes do not depend on the threshold being optimised, so they are
# read from disk once instead of on every objective evaluation.
cs_atc_fg_cases = _load_cs_atc_fg_cases()

def eval_func(x):
    """Absolute gap between mean thresholded soft DSC and mean real DSC for foreground threshold ``x``.

    Confidences of background-predicted voxels are binarised with
    LearedThresholdBG, those of foreground-predicted voxels with ``x``.
    The soft DSC is computed against the model's own argmax prediction,
    the real DSC against the ground truth.
    """
    softDSCs = []
    realDSCs = []
    for clsflatten, GTimg in cs_atc_fg_cases:
        imgshape = GTimg.shape
        probflatten = softmax(clsflatten, T = 1.0)

        preds_all_argmax = np.argmax(clsflatten, axis = 0)

        # Work on a copy: a plain assignment would alias probflatten, and the
        # thresholding below would then overwrite the soft probabilities too.
        probflattens = probflatten.copy()
        # for cls 0, BG class
        targets_y1 = np.where(preds_all_argmax==0)[0]
        probflattens[:, targets_y1] = probflatten[:, targets_y1] > LearedThresholdBG
        # for cls 1, FG class
        targets_y1 = np.where(preds_all_argmax==1)[0]
        probflattens[:, targets_y1] = probflatten[:, targets_y1] > x
        probr = probflatten.reshape((2, imgshape[0], imgshape[1], imgshape[2]))
        probrs = probflattens.reshape((2, imgshape[0], imgshape[1], imgshape[2]))
        probr_tensor = torch.tensor(probrs[np.newaxis, ...])
        # The hard prediction serves as the target of the soft DSC.
        GTimgf = np.argmax(probr, axis = 0)
        GT_tensor = torch.tensor(GTimgf[np.newaxis, ...])

        softDSC = SoftDiceLoss(probr_tensor, GT_tensor)
        softDSCs.append(softDSC[1].numpy())  # index 1: foreground class

        realDSC, _, _ = ComputMetric(GTimg, np.argmax(probr, axis = 0))
        realDSCs.append(realDSC)
    MC = np.abs(np.mean(softDSCs) - np.mean(realDSCs))

    return MC

optimization_result = scipy.optimize.minimize(
                      fun=eval_func,
                      x0=np.array([1.0]),
                      method='Nelder-Mead',
                      bounds=[(0,None)],
                      tol=1e-07)

In [244]:
# Foreground confidence threshold: first (only) entry of the solution vector of the most recent optimisation.
LearedThresholdFG = optimization_result.x[0]
print(LearedThresholdFG)

0.9695383071899413


### Test performance

In [246]:
# 5 conditions
domainlist = ['A']
softDSCs_AC = []
softDSCs_ATC = []
softDSCs_CSATC = []
realDSCs = []
for kcon in tqdm(domainlist):
    softDSC_FG_AC = []
    softDSC_FG_ATC = []
    softDSC_FG_CSATC = []
    realDSC_FG = []
    resultpath = '/vol/biomedic3/zl9518/ModelEvaluation/output/prostate/prostattestcondition_' + kcon + '/'
    
    prostatevalpath = '/vol/biomedic3/zl9518/Prostatedata/datafiletest' + kcon + '/'
    DatafiletsImgc1 = prostatevalpath + 'seg-eval.txt'
    Imgfiletsc1 = open(DatafiletsImgc1)
    Imgreadc1 = Imgfiletsc1.read().splitlines()
    for Imgnamec1 in Imgreadc1:
        knamelist = Imgnamec1.split("/")
        kname = knamelist[-1][0:6]

        cls0filename = resultpath + '/results/pred_' + kname + 'cls0_prob.nii.gz'
        cls1filename = resultpath + '/results/pred_' + kname + 'cls1_prob.nii.gz'
        cls0read = nib.load(cls0filename)
        cls1read = nib.load(cls1filename)
        cls0logit = cls0read.get_fdata()
        cls1logit = cls1read.get_fdata()
        GTread = nib.load(Imgnamec1)
        GTimg = GTread.get_fdata()
        imgshape = GTimg.shape
        cls0flatten = cls0logit.flatten()
        cls1flatten = cls1logit.flatten()
        clsflatten = np.stack((cls0flatten, cls1flatten))
        GTflatten = GTimg.flatten()
        # By AC
        probflatten = softmax(clsflatten, T = 1.0)
        probr = probflatten.reshape((2, imgshape[0], imgshape[1], imgshape[2]))
        probr_tensor = torch.tensor(probr[np.newaxis, ...])
        GTimgf = np.argmax(probr, axis = 0)
        GT_tensor = torch.tensor(GTimgf[np.newaxis, ...])
        softDSC = SoftDiceLoss(probr_tensor, GT_tensor)
        softDSC_FG_AC.append(softDSC[1].numpy())
        # By ATC
        probflatten = softmax(clsflatten, T = 1.0)
        probflattens = probflatten > LearedThreshold
        probflattens = probflattens.astype(float)
        
        probr = probflatten.reshape((2, imgshape[0], imgshape[1], imgshape[2]))
        GTimgf = np.argmax(probr, axis = 0)
        GT_tensor = torch.tensor(GTimgf[np.newaxis, ...])
        probrs = probflattens.reshape((2, imgshape[0], imgshape[1], imgshape[2]))
        probr_tensor = torch.tensor(probrs[np.newaxis, ...])
        softDSC = SoftDiceLoss(probr_tensor, GT_tensor)
        softDSC_FG_ATC.append(softDSC[1].numpy())
        # By CSATC
        preds_all_argmax = np.argmax(clsflatten, axis = 0)
        # for cls 0, BG class 
        targets_y1 = np.where(preds_all_argmax==0)[0]
        probflatten[:, targets_y1] = softmax(clsflatten[:, targets_y1], T = 1.0)
        probflattens[:, targets_y1] = probflatten[:, targets_y1] > LearedThresholdBG
        # for cls 1, FG class
        targets_y1 = np.where(preds_all_argmax==1)[0]
        probflatten[:, targets_y1] = softmax(clsflatten[:, targets_y1], T = 1.0)
        probflattens[:, targets_y1] = probflatten[:, targets_y1] > LearedThresholdFG
        probflattens = probflattens.astype(float)
        probr = probflatten.reshape((2, imgshape[0], imgshape[1], imgshape[2]))
        probrs = probflattens.reshape((2, imgshape[0], imgshape[1], imgshape[2]))
        probr_tensor = torch.tensor(probrs[np.newaxis, ...])
        GTimgf = np.argmax(probr, axis = 0)
        GT_tensor = torch.tensor(GTimgf[np.newaxis, ...])
        softDSC = SoftDiceLoss(probr_tensor, GT_tensor)
        softDSC_FG_CSATC.append(softDSC[1].numpy())

        realDSC, _, _ = ComputMetric(GTimg, np.argmax(probr, axis = 0))
        realDSC_FG.append(realDSC)
        
    softDSCs_AC.append(np.mean(np.array(softDSC_FG_AC)))
    softDSCs_ATC.append(np.mean(np.array(softDSC_FG_ATC)))
    softDSCs_CSATC.append(np.mean(np.array(softDSC_FG_CSATC)))
    realDSCs.append(np.mean(np.array(realDSC_FG)))

100%|███████████████████████████████████████████████████████████████████| 1/1 [00:14<00:00, 14.15s/it]


In [247]:
# Mean absolute gap between the real DSC and each soft-DSC estimate,
# averaged over the evaluated test conditions.
for heading, soft_scores in (
    ('AC_results:', softDSCs_AC),
    ('ATC_results:', softDSCs_ATC),
    ('CSATC_results:', softDSCs_CSATC),
):
    print(heading)
    print(np.mean(np.abs(np.array(realDSCs) - np.array(soft_scores))))

AC_results:
0.20720681944147046
ATC_results:
0.16340066019200306
CSATC_results:
0.04222054066044989


## Class-Specific Temperature-Scaling Average Thresholded Confidence (CS TS-ATC)

In [251]:
# Validation split: prediction folder and the list of ground-truth
# segmentation paths (one path per line) used by the following cells.
resultpath = './data/Prostateresults/prostateval/'
prostatevalpath = './data/Prostateresults/'
DatafiletsImgc1 = prostatevalpath + 'seg-eval.txt'
# The context manager closes the list file as soon as it has been read.
with open(DatafiletsImgc1) as Imgfiletsc1:
    Imgreadc1 = Imgfiletsc1.read().splitlines()

### Temperature-Scaling Average Thresholded Confidence

In [252]:
# -> preacts. N x C
# -> labels. N
preacts = logits_all.T
labels = targets_all
preds_all_argmax = np.argmax(preacts, axis = 1)
targets_y1 = np.where(preds_all_argmax==0)[0]
pred_class = np.argmax(preacts, axis = 1)[targets_y1]
target_class = targets_all[targets_y1]

acc = np.sum(pred_class == target_class) / len(target_class)
def eval_func(x):
    
    prob_Topt = softmax(logits_all, T = LearedTemp).transpose()[targets_y1]
    acc_appr = np.sum(prob_Topt > x) / len(targets_y1)

    MC = np.abs(acc_appr-acc)

    return MC

optimization_result = scipy.optimize.minimize(
                      fun=eval_func,
                      x0=np.array([1.0]),
                      method='Nelder-Mead',
                      tol=1e-07)

In [253]:
# Threshold on temperature-scaled confidences; this rebinds LearedThreshold,
# replacing the value assigned in the ATC section.
LearedThreshold = optimization_result.x[0]
print(LearedThreshold)

7.530212402343748e-01


### Class-Specific Temperature-Scaling Average Thresholded Confidence

#### the first step, align the background with acc

In [258]:
# -> preacts. N x C
# -> labels. N
preacts = logits_all.T
labels = targets_all
preds_all_argmax = np.argmax(preacts, axis = 1)
targets_y1 = np.where(preds_all_argmax==0)[0]
pred_class = np.argmax(preacts, axis = 1)[targets_y1]
target_class = targets_all[targets_y1]

acc = np.sum(pred_class == target_class) / len(target_class)
def eval_func(x):
    
    prob_Topt = softmax(logits_all, T = LearedTempBG).transpose()[targets_y1]
    acc_appr = np.sum(prob_Topt > x) / len(targets_y1)

    MC = np.abs(acc_appr-acc)

    return MC

optimization_result = scipy.optimize.minimize(
                      fun=eval_func,
                      x0=np.array([1.0]),
                      method='Nelder-Mead',
                      tol=1e-07)

In [259]:
# The optimiser works on a one-element parameter vector; its single entry is
# the fitted confidence threshold for the background-predicted samples.
LearedThresholdBG = optimization_result.x[0]
print(LearedThresholdBG)

7.600433349609372e-01


#### Step 2: fit the foreground threshold so that the estimated DSC matches the real DSC

In [284]:
def _load_fg_calibration_cases():
    """Load every case listed in ``Imgreadc1`` and precompute everything that
    does not depend on the foreground threshold.

    Reads the module-level ``resultpath``, ``Imgreadc1``, ``LearedTempBG``,
    ``LearedTempFG`` and ``LearedThresholdBG``.

    NOTE(review): ``resultpath`` and ``Imgreadc1`` are module-level names that
    the "Test performance" cell also reassigns. If that cell was executed
    before this one, the threshold is fitted on the test cases -- make sure
    the calibration file list was loaded last.

    :return: list of tuples ``(hard, fg_idx, fg_probs, imgshape, GTimgf,
        realDSC)``, one per case. ``hard`` is the 2 x N array with the
        background-predicted columns already thresholded, ``fg_probs`` holds
        the scaled probabilities of the foreground-predicted columns,
        ``GTimgf`` is the argmax label volume and ``realDSC`` the Dice score
        from ``ComputMetric``.
    """
    cases = []
    for Imgnamec1 in Imgreadc1:
        kname = Imgnamec1.split("/")[-1][0:6]

        cls0filename = resultpath + '/results/pred_' + kname + 'cls0_prob.nii.gz'
        cls1filename = resultpath + '/results/pred_' + kname + 'cls1_prob.nii.gz'
        cls0logit = nib.load(cls0filename).get_fdata()
        cls1logit = nib.load(cls1filename).get_fdata()
        GTimg = nib.load(Imgnamec1).get_fdata()
        imgshape = GTimg.shape
        clsflatten = np.stack((cls0logit.flatten(), cls1logit.flatten()))

        preds_all_argmax = np.argmax(clsflatten, axis = 0)
        bg_idx = np.where(preds_all_argmax==0)[0]
        fg_idx = np.where(preds_all_argmax==1)[0]

        # Class-specific temperature scaling. With two rows every column is
        # predicted as class 0 or class 1, so both assignments together fill
        # the whole array (the former initial softmax with LearedTemp was
        # always overwritten).
        probflatten = np.empty(clsflatten.shape, dtype=float)
        probflatten[:, bg_idx] = softmax(clsflatten[:, bg_idx], T = LearedTempBG)
        probflatten[:, fg_idx] = softmax(clsflatten[:, fg_idx], T = LearedTempFG)

        # Work on a copy: the original aliased the two arrays, so thresholding
        # the background columns also overwrote the soft probabilities that
        # are used for the argmax labels below.
        hard = probflatten.copy()
        hard[:, bg_idx] = probflatten[:, bg_idx] > LearedThresholdBG

        probr = probflatten.reshape((2, imgshape[0], imgshape[1], imgshape[2]))
        GTimgf = np.argmax(probr, axis = 0)
        realDSC, _, _ = ComputMetric(GTimg, GTimgf)

        # Keep only what the objective needs; the foreground columns of
        # `hard` are rewritten on every evaluation.
        cases.append((hard, fg_idx, probflatten[:, fg_idx], imgshape, GTimgf, realDSC))
    return cases


# Loading the volumes once (instead of on every objective evaluation) trades
# memory for speed: all calibration cases stay resident while optimising.
_fg_calibration_cases = _load_fg_calibration_cases()


def eval_func(x):
    """Objective for the foreground threshold search.

    :param x: candidate threshold applied to the foreground-predicted columns
        (the optimiser passes a one-element array).
    :return: absolute gap between the mean ``SoftDiceLoss(...)[1]`` of the
        thresholded confidences and the mean real Dice score over all cases.
    """
    softDSCs = []
    realDSCs = []
    for hard, fg_idx, fg_probs, imgshape, GTimgf, realDSC in _fg_calibration_cases:
        # Only the foreground-predicted columns depend on x; they are fully
        # overwritten here, so reusing the cached array across calls is safe.
        hard[:, fg_idx] = fg_probs > x
        probrs = hard.reshape((2, imgshape[0], imgshape[1], imgshape[2]))
        # torch.tensor copies, so every call hands fresh tensors to the loss.
        probr_tensor = torch.tensor(probrs[np.newaxis, ...])
        GT_tensor = torch.tensor(GTimgf[np.newaxis, ...])

        softDSC = SoftDiceLoss(probr_tensor, GT_tensor)
        softDSCs.append(softDSC[1].numpy())
        realDSCs.append(realDSC)
    MC = np.abs(np.mean(softDSCs) - np.mean(realDSCs))

    return MC

optimization_result = scipy.optimize.minimize(
                      fun=eval_func,
                      x0=np.array([1.0]),
                      method='Nelder-Mead',
                      bounds=[(0,None)],
                      tol=1e-07)

In [285]:
# The optimiser works on a one-element parameter vector; its single entry is
# the fitted confidence threshold for the foreground-predicted voxels.
LearedThresholdFG = optimization_result.x[0]
print(LearedThresholdFG)

0.8437387466430664


### Test performance

In [286]:
# Test conditions to evaluate (only condition 'A' is listed here).
domainlist = ['A']
# Per-condition means of the estimated Dice (AC / ATC / CS-ATC) and the real Dice.
softDSCs_AC = []
softDSCs_ATC = []
softDSCs_CSATC = []
realDSCs = []
for kcon in tqdm(domainlist):
    softDSC_FG_AC = []
    softDSC_FG_ATC = []
    softDSC_FG_CSATC = []
    realDSC_FG = []
    # NOTE(review): these assignments overwrite the module-level `resultpath`
    # and `Imgreadc1` that the threshold-fitting cells read. Re-run the
    # calibration file-list cell before refitting any threshold.
    resultpath = '/vol/biomedic3/zl9518/ModelEvaluation/output/prostate/prostattestcondition_' + kcon + '/'

    prostatevalpath = '/vol/biomedic3/zl9518/Prostatedata/datafiletest' + kcon + '/'
    DatafiletsImgc1 = prostatevalpath + 'seg-eval.txt'
    # Context manager closes the handle once the path list has been read.
    with open(DatafiletsImgc1) as Imgfiletsc1:
        Imgreadc1 = Imgfiletsc1.read().splitlines()
    for Imgnamec1 in Imgreadc1:
        knamelist = Imgnamec1.split("/")
        kname = knamelist[-1][0:6]

        cls0filename = resultpath + '/results/pred_' + kname + 'cls0_prob.nii.gz'
        cls1filename = resultpath + '/results/pred_' + kname + 'cls1_prob.nii.gz'
        cls0logit = nib.load(cls0filename).get_fdata()
        cls1logit = nib.load(cls1filename).get_fdata()
        GTimg = nib.load(Imgnamec1).get_fdata()
        imgshape = GTimg.shape
        volshape = (2, imgshape[0], imgshape[1], imgshape[2])
        clsflatten = np.stack((cls0logit.flatten(), cls1logit.flatten()))

        # By AC: soft Dice of the unscaled softmax against its own argmax labels.
        probflatten = softmax(clsflatten, T = 1.0)
        probr = probflatten.reshape(volshape)
        probr_tensor = torch.tensor(probr[np.newaxis, ...])
        GTimgf = np.argmax(probr, axis = 0)
        GT_tensor = torch.tensor(GTimgf[np.newaxis, ...])
        softDSC = SoftDiceLoss(probr_tensor, GT_tensor)
        softDSC_FG_AC.append(softDSC[1].numpy())

        # By ATC: one temperature and one threshold for all voxels.
        probflatten = softmax(clsflatten, T = LearedTemp)
        probflattens = (probflatten > LearedThreshold).astype(float)

        probr = probflatten.reshape(volshape)
        GTimgf = np.argmax(probr, axis = 0)
        GT_tensor = torch.tensor(GTimgf[np.newaxis, ...])
        probrs = probflattens.reshape(volshape)
        probr_tensor = torch.tensor(probrs[np.newaxis, ...])
        softDSC = SoftDiceLoss(probr_tensor, GT_tensor)
        softDSC_FG_ATC.append(softDSC[1].numpy())

        # By CSATC: temperature and threshold chosen per predicted class.
        # Fresh arrays are used instead of overwriting the ATC ones in place;
        # with two rows every column is predicted as class 0 or class 1, so
        # the two masked assignments fill each array completely.
        preds_all_argmax = np.argmax(clsflatten, axis = 0)
        probflatten = np.empty(clsflatten.shape, dtype=float)
        probflattens = np.empty(clsflatten.shape, dtype=float)
        # for cls 0, BG class
        targets_y1 = np.where(preds_all_argmax==0)[0]
        probflatten[:, targets_y1] = softmax(clsflatten[:, targets_y1], T = LearedTempBG)
        probflattens[:, targets_y1] = probflatten[:, targets_y1] > LearedThresholdBG
        # for cls 1, FG class
        targets_y1 = np.where(preds_all_argmax==1)[0]
        probflatten[:, targets_y1] = softmax(clsflatten[:, targets_y1], T = LearedTempFG)
        probflattens[:, targets_y1] = probflatten[:, targets_y1] > LearedThresholdFG
        probr = probflatten.reshape(volshape)
        probrs = probflattens.reshape(volshape)
        probr_tensor = torch.tensor(probrs[np.newaxis, ...])
        GTimgf = np.argmax(probr, axis = 0)
        GT_tensor = torch.tensor(GTimgf[np.newaxis, ...])
        softDSC = SoftDiceLoss(probr_tensor, GT_tensor)
        softDSC_FG_CSATC.append(softDSC[1].numpy())

        # Real Dice of the argmax segmentation against the ground truth
        # (reuses the label volume computed just above).
        realDSC, _, _ = ComputMetric(GTimg, GTimgf)
        realDSC_FG.append(realDSC)

    softDSCs_AC.append(np.mean(np.array(softDSC_FG_AC)))
    softDSCs_ATC.append(np.mean(np.array(softDSC_FG_ATC)))
    softDSCs_CSATC.append(np.mean(np.array(softDSC_FG_CSATC)))
    realDSCs.append(np.mean(np.array(realDSC_FG)))

100%|███████████████████████████████████████████████████████████████████| 1/1 [00:11<00:00, 11.72s/it]


In [287]:
# Mean absolute gap between the real Dice and each estimate, over all test
# conditions; each result is printed under its label.
real_dsc_per_condition = np.array(realDSCs)
for result_label, estimated_dsc in (
    ('AC_results:', softDSCs_AC),
    ('TS_ATC_results:', softDSCs_ATC),
    ('CSTS_ATC_results:', softDSCs_CSATC),
):
    print(result_label)
    print(np.mean(np.abs(real_dsc_per_condition - np.array(estimated_dsc))))

AC_results:
0.20720681944147046
TS_ATC_results:
0.17979765781510382
CSTS_ATC_results:
1.0671801997652608e-08
