In [1]:
import os

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pandas as pd
import scipy
import scipy.optimize  # `import scipy` alone does not load the optimize submodule on older SciPy
import seaborn as sns
import torch
from tqdm import tqdm

In [2]:
def softmax(x, T, b=0):
    """Temperature-scaled softmax taken along axis 0 (the class axis).

    Computes softmax(x / T + b) independently for every column; the column
    maximum is subtracted before exponentiating to avoid overflow.

    :param x: class scores with classes on axis 0, e.g. shape (C, N).
    :param T: temperature dividing the scores.
    :param b: offset added after scaling. NOTE(review): a scalar b shifts every
        class equally and therefore cancels out; only a per-class array
        (e.g. shape (C, 1)) changes the result.
    :return: array shaped like x whose columns sum to 1.
    """
    scaled = x / T + b
    shifted = scaled - scaled.max(axis=0)
    weights = np.exp(shifted)
    return weights / weights.sum(axis=0)

def ComputMetric(ACTUAL, PREDICTED):
    """Hard overlap metrics between a binary ground truth and a prediction.

    Both inputs are flattened and compared element by element. Elements where
    ACTUAL equals True/1 are positives and those equal to False/0 are
    negatives; any other label value is ignored.

    :param ACTUAL: ground-truth mask (array-like, any shape).
    :param PREDICTED: predicted mask with the same number of elements.
    :return: (dice, sensitivity, precision). All three are 0 when there is no
        true positive, which also avoids 0/0 when both masks are empty.
    """
    actual = np.asarray(ACTUAL).flatten()
    predicted = np.asarray(PREDICTED).flatten()
    idxp = actual == True   # noqa: E712 -- elementwise, also matches 1.0
    idxn = actual == False  # noqa: E712 -- elementwise, also matches 0.0

    tp = np.sum(actual[idxp] == predicted[idxp])
    tn = np.sum(actual[idxn] == predicted[idxn])
    fp = np.sum(idxn) - tn
    fn = np.sum(idxp) - tp

    if tp == 0:
        return 0, 0, 0
    dice = 2 * tp / (2 * tp + fp + fn)
    Sensitivity = tp / (tp + fn)
    Precision = tp / (tp + fp)
    return dice, Sensitivity, Precision

In [3]:
def sum_tensor(inp, axes, keepdim=False):
    """Sum `inp` over every axis listed in `axes` (duplicate axes are ignored).

    With keepdim=True the reduced axes stay as size-1 dimensions; otherwise
    they are dropped, highest axis first so the remaining indices stay valid.
    """
    ordered = [int(a) for a in np.unique(axes).astype(int)]
    if keepdim:
        for axis in ordered:
            inp = inp.sum(axis, keepdim=True)
        return inp
    for axis in reversed(ordered):
        inp = inp.sum(axis)
    return inp

def get_tp_fp_fn(net_output, gt, axes=None, mask=None, square=False):
    """
    Soft true-positive, false-positive and false-negative sums per class.

    net_output must be (b, c, x, y(, z))
    gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or one hot encoding (b, c, x, y(, z))
    if mask is provided it must have shape (b, 1, x, y(, z))
    :param net_output: per-class scores (typically probabilities), (b, c, x, y(, z))
    :param gt: label map or one-hot target, see above
    :param axes: axes to sum over; defaults to the spatial axes 2, 3, ...
    :param mask: mask must be 1 for valid pixels and 0 for invalid pixels
    :param square: if True then fp, tp and fn will be squared before summation
    :return: (tp, fp, fn), each net_output-shaped tensor with `axes` summed out
    """
    if axes is None:
        axes = tuple(range(2, net_output.dim()))

    shp_x = net_output.shape
    shp_y = gt.shape

    with torch.no_grad():
        if len(shp_x) != len(shp_y):
            # label map without a channel axis -> insert a singleton channel axis
            gt = gt.view((shp_y[0], 1, *shp_y[1:]))

        if all(i == j for i, j in zip(net_output.shape, gt.shape)):
            # if this is the case then gt is probably already a one hot encoding
            y_onehot = gt
        else:
            gt = gt.long()
            # allocate the one-hot target directly on net_output's device (CPU,
            # CUDA or any other backend) so scatter_ and the products below agree
            y_onehot = torch.zeros(shp_x, device=net_output.device)
            y_onehot.scatter_(1, gt, 1)

    tp = net_output * y_onehot
    fp = net_output * (1 - y_onehot)
    fn = (1 - net_output) * y_onehot

    if mask is not None:
        # apply the (b, 1, ...) mask to every channel
        tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1)
        fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1)
        fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1)

    if square:
        tp = tp ** 2
        fp = fp ** 2
        fn = fn ** 2

    tp = sum_tensor(tp, axes, keepdim=False)
    fp = sum_tensor(fp, axes, keepdim=False)
    fn = sum_tensor(fn, axes, keepdim=False)

    return tp, fp, fn


def SoftDiceLoss(x, y, loss_mask = None, smooth = 1e-5, batch_dice = False):
    '''
    Smoothed soft Dice coefficient per channel (returned as-is, not as 1 - Dice).

    The soft tp/fp/fn from get_tp_fp_fn are summed over the batch axis and all
    spatial axes, so the whole batch is pooled into one Dice value per channel.

    :param x: per-class probabilities, (b, c, x, y(, z))
    :param y: label map or one-hot target in a form get_tp_fp_fn accepts
    :param loss_mask: optional validity mask forwarded to get_tp_fp_fn
    :param smooth: added to numerator and denominator to avoid 0/0
    :param batch_dice: NOTE(review): ignored -- the batch axis is always pooled
        (batch_dice=True behaviour). Honouring False would make the result
        (b, c) instead of (c,) and break callers that index the result with [1].
    :return: tensor of shape (c,)
    '''
    pooled_axes = [0, *range(2, len(x.shape))]
    tp, fp, fn = get_tp_fp_fn(x, y, pooled_axes, loss_mask, False)
    return (2 * tp + smooth) / (2 * tp + fp + fn + smooth)

In [4]:
# Domain-B validation set: directory of network outputs (resultpath) and the
# directory holding the text file that lists ground-truth label paths.
DATA_ROOT = '/vol/biomedic3/zl9518/'
resultpath = DATA_ROOT + 'ModelEvaluation/output/prostate/prostateval/'
prostatevalpath = DATA_ROOT + 'Prostatedata/datafiletestB/'

In [5]:
# Ground-truth label files of the validation set, one path per line.
DatafiletsImgc1 = prostatevalpath + 'seg-eval.txt'
with open(DatafiletsImgc1) as Imgfiletsc1:  # closes the handle after reading
    Imgreadc1 = Imgfiletsc1.read().splitlines()

In [6]:
# Load every validation case and pool, over all voxels: the hard prediction and
# max-softmax confidence at T = 1.6, the ground-truth label, and the raw scores.
probs_parts = []
preds_class_parts = []
targets_parts = []
logits_parts = []
for Imgnamec1 in tqdm(Imgreadc1):
    kname = Imgnamec1.split("/")[-1][0:6]

    cls0filename = resultpath + '/results/pred_' + kname + 'cls0_prob.nii.gz'
    cls1filename = resultpath + '/results/pred_' + kname + 'cls1_prob.nii.gz'
    cls0logit = nib.load(cls0filename).get_fdata()
    cls1logit = nib.load(cls1filename).get_fdata()
    GTimg = nib.load(Imgnamec1).get_fdata()
    clsflatten = np.stack((cls0logit.flatten(), cls1logit.flatten()))  # (2, n_voxels)
    probflatten = softmax(clsflatten, T = 1.6)

    preds_class_parts.append(np.argmax(probflatten, axis = 0))
    targets_parts.append(GTimg.flatten())
    probs_parts.append(np.max(probflatten, axis = 0))
    logits_parts.append(clsflatten)

assert logits_parts, 'no validation cases were listed'
# Concatenate once at the end: avoids re-copying the pooled arrays on every
# iteration and the `array == []` emptiness test, which raises on shape
# mismatch in recent NumPy.
preds_class_all = np.concatenate(preds_class_parts)
targets_all = np.concatenate(targets_parts)
probs_all = np.concatenate(probs_parts)
logits_all = np.concatenate(logits_parts, axis=1)  # (2, total_voxels)



  if logits_all == []:
100%|█████████████████████████████████████████████████████████████████| 10/10 [00:12<00:00,  1.29s/it]


In [7]:
# Inputs for eval_func: one row of raw class scores per voxel, and its label.
preacts, labels = logits_all.T, targets_all  # (N, C) and (N,)

In [8]:
# prepare:
# -> preacts. N x C
# -> labels. N
def eval_func(x):
    """Calibration objective: |average confidence - accuracy| on (preacts, labels).

    Reads the notebook globals `preacts` (N x C raw class scores) and `labels` (N,).

    :param x: parameter vector of length 2*C laid out as [w_0..w_{C-1}, b_0..b_{C-1}].
        NOTE(review): only x[0] is used -- it divides the class-1 score while the
        class-0 score is left unscaled; the other weights and all biases are
        ignored, so an optimiser cannot move them.
    :return: absolute miscalibration |AC - acc| (float).
    """
    ws = np.asarray(x[:len(x) // 2], dtype=float)
    ws = np.concatenate((np.ones(1), ws[None, 0]))
    vs_logits = preacts / ws[None, :]
    # Subtract the row maximum before exponentiating: a no-op for softmax, but it
    # prevents exp() overflow (inf / inf -> nan) for large or strongly scaled scores.
    vs_logits = vs_logits - np.max(vs_logits, axis=1, keepdims=True)
    exp_vs_logits = np.exp(vs_logits)
    probs = exp_vs_logits / np.sum(exp_vs_logits, axis=1, keepdims=True)
    AC = np.mean(np.max(probs, axis=1))
    # accuracy of the raw (uncalibrated) argmax prediction
    preds = np.argmax(preacts, axis = 1)
    acc = np.sum(labels == preds) / len(labels)
    return np.abs(AC - acc)

In [9]:
# Fit the calibration parameters by minimising eval_func, i.e.
# |average confidence - accuracy| on (preacts, labels).
# x = [w_0 .. w_{C-1}, b_0 .. b_{C-1}]; NOTE(review): eval_func only reads x[0],
# so the remaining entries do not move.
n_classes = preacts.shape[1]
optimization_result = scipy.optimize.minimize(
    fun=eval_func,
    x0=np.concatenate((np.ones(n_classes), np.zeros(n_classes))),
    # The weights divide the scores, so keep them strictly positive; a lower
    # bound of 0 would let L-BFGS-B evaluate a division by zero.
    bounds=[(1e-6, None)] * n_classes + [(None, None)] * n_classes,
    method='L-BFGS-B',
    tol=1e-07)

In [10]:
# Fitted parameter vector x = [w..., b...] (as consumed by eval_func) and the
# final objective value |average confidence - accuracy|.
optimization_result

      fun: 4.918188079017227e-10
 hess_inv: <4x4 LbfgsInvHessProduct with dtype=float64>
      jac: array([-0.00046709,  0.        ,  0.        ,  0.        ])
  message: 'CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH'
     nfev: 125
      nit: 6
     njev: 25
   status: 0
  success: True
        x: array([3.87936091, 1.        , 0.        , 0.        ])

In [50]:
# Objective at a hand-picked parameter vector [w0, w1, b0, b1]; eval_func only
# reads w0. NOTE(review): provenance of these values is not recorded -- TODO.
eval_func([ 1.59172994,  1.60807655, -0.02340451,  0.02324877 ])

5.7114424301119016e-11

In [57]:
# find optimal parameters with training data
# (this cell only loads and pools the training-split outputs; no fitting here)
resultpath = '/vol/biomedic3/zl9518/ModelEvaluation/output/prostate/prostatetrain/'
prostatevalpath = '/vol/biomedic3/zl9518/Prostatedata/datafiletrainingB/'
DatafiletsImgc1 = prostatevalpath + 'seg-train.txt'
with open(DatafiletsImgc1) as Imgfiletsc1:
    Imgreadc1 = Imgfiletsc1.read().splitlines()

probs_parts = []
preds_class_parts = []
targets_parts = []
logits_parts = []
for Imgnamec1 in tqdm(Imgreadc1):
    kname = Imgnamec1.split("/")[-1][0:6]

    cls0filename = resultpath + '/results/pred_' + kname + 'cls0_prob.nii.gz'
    cls1filename = resultpath + '/results/pred_' + kname + 'cls1_prob.nii.gz'
    cls0logit = nib.load(cls0filename).get_fdata()
    cls1logit = nib.load(cls1filename).get_fdata()
    GTimg = nib.load(Imgnamec1).get_fdata()
    clsflatten = np.stack((cls0logit.flatten(), cls1logit.flatten()))  # (2, n_voxels)
    # NOTE(review): T = 1.2 here, while the validation loader uses T = 1.6 -- confirm intended
    probflatten = softmax(clsflatten, T = 1.2)

    preds_class_parts.append(np.argmax(probflatten, axis = 0))
    targets_parts.append(GTimg.flatten())
    probs_parts.append(np.max(probflatten, axis = 0))
    logits_parts.append(clsflatten)

assert logits_parts, 'no cases listed in ' + DatafiletsImgc1
# Concatenate once at the end: avoids re-copying the pooled arrays on every
# iteration and the `array == []` emptiness test, which raises on shape
# mismatch in recent NumPy.
preds_class_all = np.concatenate(preds_class_parts)
targets_all = np.concatenate(targets_parts)
probs_all = np.concatenate(probs_parts)
logits_all = np.concatenate(logits_parts, axis=1)  # (2, total_voxels)

  if logits_all == []:
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:08<00:00,  2.26it/s]


In [58]:
# Training-set pixel accuracy vs. mean max-probability (average confidence);
# for a well-calibrated model the two numbers are close.
train_pixel_accuracy = np.sum(preds_class_all == targets_all) / len(targets_all)
train_avg_confidence = np.mean(probs_all)
print(train_pixel_accuracy)
print(train_avg_confidence)

0.9987457493223437
0.9987651842426232


In [17]:
# 5 conditions
# Per test domain: mean foreground soft Dice of the calibrated probabilities under
# AC (T = 1.0), TS (T = 1.6) and CSTS, plus the mean real Dice against ground truth.
domainlist = ['A', 'C', 'D', 'E', 'F']
# CSTS parameters keyed by predicted class: ((T, b) for row 0, (T, b) for row 1).
# NOTE(review): both predicted classes share the same parameters, and a scalar b is
# added to every class score so it cancels inside softmax -- confirm intended values.
csts_params = {
    0: ((1.47256267, -0.17790477), (1.48775543, 0.1779001)),
    1: ((1.47256267, -0.17790477), (1.48775543, 0.1779001)),
}
softDSCs_AC = []
softDSCs_TS = []
softDSCs_CSTS = []
realDSCs = []
for kcon in tqdm(domainlist):
    softDSC_FG_AC = []
    softDSC_FG_TS = []
    softDSC_FG_CSTS = []
    realDSC_FG = []
    resultpath = '/vol/biomedic3/zl9518/ModelEvaluation/output/prostate/prostattestcondition_' + kcon + '/'

    prostatevalpath = '/vol/biomedic3/zl9518/Prostatedata/datafiletest' + kcon + '/'
    DatafiletsImgc1 = prostatevalpath + 'seg-eval.txt'
    with open(DatafiletsImgc1) as Imgfiletsc1:
        Imgreadc1 = Imgfiletsc1.read().splitlines()

    for Imgnamec1 in Imgreadc1:
        kname = Imgnamec1.split("/")[-1][0:6]

        cls0filename = resultpath + '/results/pred_' + kname + 'cls0_prob.nii.gz'
        cls1filename = resultpath + '/results/pred_' + kname + 'cls1_prob.nii.gz'
        cls0logit = nib.load(cls0filename).get_fdata()
        cls1logit = nib.load(cls1filename).get_fdata()
        GTimg = nib.load(Imgnamec1).get_fdata()
        imgshape = GTimg.shape
        clsflatten = np.stack((cls0logit.flatten(), cls1logit.flatten()))  # (2, n_voxels)
        preds_all_argmax = np.argmax(clsflatten, axis = 0)  # the model's hard prediction

        # By AC
        prob_ac = softmax(clsflatten, T = 1.0)
        # By TS
        prob_ts = softmax(clsflatten, T = 1.6, b = 0)
        # By CSTS: every voxel is calibrated with the parameters of its predicted class
        prob_csts = np.empty_like(prob_ts)
        for cls, ((t0, b0), (t1, b1)) in csts_params.items():
            targets_y1 = np.where(preds_all_argmax == cls)[0]
            prob_csts[0, targets_y1] = softmax(clsflatten[:, targets_y1], T = t0, b = b0)[0, :]
            prob_csts[1, targets_y1] = softmax(clsflatten[:, targets_y1], T = t1, b = b1)[1, :]

        # Soft Dice of each calibrated map against its own argmax (no ground truth
        # needed); index 1 keeps the foreground class.
        for probflatten, bucket in ((prob_ac, softDSC_FG_AC),
                                    (prob_ts, softDSC_FG_TS),
                                    (prob_csts, softDSC_FG_CSTS)):
            probr = probflatten.reshape((2, *imgshape))
            GTimgf = np.argmax(probr, axis = 0)
            softDSC = SoftDiceLoss(torch.tensor(probr[np.newaxis, ...]),
                                   torch.tensor(GTimgf[np.newaxis, ...]))
            bucket.append(softDSC[1].numpy())

        # Real Dice of the hard prediction, taken from the raw scores so it does
        # not depend on the calibration method.
        realDSC, _, _ = ComputMetric(GTimg, preds_all_argmax.reshape(imgshape))
        realDSC_FG.append(realDSC)

    softDSCs_AC.append(np.mean(np.array(softDSC_FG_AC)))
    softDSCs_TS.append(np.mean(np.array(softDSC_FG_TS)))
    softDSCs_CSTS.append(np.mean(np.array(softDSC_FG_CSTS)))
    realDSCs.append(np.mean(np.array(realDSC_FG)))

100%|███████████████████████████████████████████████████████████████████| 5/5 [01:09<00:00, 13.81s/it]


In [18]:
# Mean absolute gap between the real Dice and each soft-Dice estimate, averaged
# over the evaluated domains/conditions (lower = better estimate).
real_dsc = np.array(realDSCs)
for method_name, soft_dscs in (('AC', softDSCs_AC), ('TS', softDSCs_TS), ('CSTS', softDSCs_CSTS)):
    print(method_name + '_results:')
    print(np.mean(np.abs(real_dsc - np.array(soft_dscs))))

AC_results:
1.8734415411981173e-01
TS_results:
0.09247324744821413
CSTS_results:
0.11249706978466718


In [19]:
# Mean foreground soft Dice under CSTS, one value per entry of domainlist.
print(softDSCs_CSTS)

[0.8235055430485072, 0.7694877732138141, 0.7868525282768104, 0.877403380802001, 0.83500830028285]


In [93]:
# CDA TS

In [11]:
# 5 conditions
# Per test domain: mean foreground soft Dice under AC (T = 1.0), TS (T = 1.6) and a
# class-wise TS where each voxel uses the temperature of its predicted class, plus
# the mean real Dice against ground truth.
domainlist = ['A', 'C', 'D', 'E', 'F']
# Temperature per predicted class. 3.87936091 is the fitted x[0] shown above.
# NOTE(review): in eval_func that value divides only the class-1 score, whereas here
# it divides both scores of FG-predicted voxels -- confirm this is intended.
csts_temperatures = {0: 1.0, 1: 3.87936091}
softDSCs_AC = []
softDSCs_TS = []
softDSCs_CSTS = []
realDSCs = []
for kcon in tqdm(domainlist):
    softDSC_FG_AC = []
    softDSC_FG_TS = []
    softDSC_FG_CSTS = []
    realDSC_FG = []
    resultpath = '/vol/biomedic3/zl9518/ModelEvaluation/output/prostate/prostattestcondition_' + kcon + '/'

    prostatevalpath = '/vol/biomedic3/zl9518/Prostatedata/datafiletest' + kcon + '/'
    DatafiletsImgc1 = prostatevalpath + 'seg-eval.txt'
    with open(DatafiletsImgc1) as Imgfiletsc1:
        Imgreadc1 = Imgfiletsc1.read().splitlines()

    for Imgnamec1 in Imgreadc1:
        kname = Imgnamec1.split("/")[-1][0:6]

        cls0filename = resultpath + '/results/pred_' + kname + 'cls0_prob.nii.gz'
        cls1filename = resultpath + '/results/pred_' + kname + 'cls1_prob.nii.gz'
        cls0logit = nib.load(cls0filename).get_fdata()
        cls1logit = nib.load(cls1filename).get_fdata()
        GTimg = nib.load(Imgnamec1).get_fdata()
        imgshape = GTimg.shape
        clsflatten = np.stack((cls0logit.flatten(), cls1logit.flatten()))  # (2, n_voxels)
        preds_all_argmax = np.argmax(clsflatten, axis = 0)  # the model's hard prediction

        # By AC
        prob_ac = softmax(clsflatten, T = 1.0)
        # By TS
        prob_ts = softmax(clsflatten, T = 1.6, b = 0)
        # By CSTS: every voxel is scaled with the temperature of its predicted class
        prob_csts = np.empty_like(prob_ts)
        for cls, temperature in csts_temperatures.items():
            targets_y1 = np.where(preds_all_argmax == cls)[0]
            prob_csts[:, targets_y1] = softmax(clsflatten[:, targets_y1], T = temperature)

        # Soft Dice of each calibrated map against its own argmax (no ground truth
        # needed); index 1 keeps the foreground class.
        for probflatten, bucket in ((prob_ac, softDSC_FG_AC),
                                    (prob_ts, softDSC_FG_TS),
                                    (prob_csts, softDSC_FG_CSTS)):
            probr = probflatten.reshape((2, *imgshape))
            GTimgf = np.argmax(probr, axis = 0)
            softDSC = SoftDiceLoss(torch.tensor(probr[np.newaxis, ...]),
                                   torch.tensor(GTimgf[np.newaxis, ...]))
            bucket.append(softDSC[1].numpy())

        # Real Dice of the hard prediction, taken from the raw scores so it does
        # not depend on the calibration method.
        realDSC, _, _ = ComputMetric(GTimg, preds_all_argmax.reshape(imgshape))
        realDSC_FG.append(realDSC)

    softDSCs_AC.append(np.mean(np.array(softDSC_FG_AC)))
    softDSCs_TS.append(np.mean(np.array(softDSC_FG_TS)))
    softDSCs_CSTS.append(np.mean(np.array(softDSC_FG_CSTS)))
    realDSCs.append(np.mean(np.array(realDSC_FG)))

100%|███████████████████████████████████████████████████████████████████| 5/5 [01:14<00:00, 14.81s/it]


In [12]:
# Mean absolute gap between the real Dice and each soft-Dice estimate, averaged
# over the evaluated domains/conditions (lower = better estimate).
real_dsc = np.array(realDSCs)
for method_name, soft_dscs in (('AC', softDSCs_AC), ('TS', softDSCs_TS), ('CSTS', softDSCs_CSTS)):
    print(method_name + '_results:')
    print(np.mean(np.abs(real_dsc - np.array(soft_dscs))))

AC_results:
1.8734415411981173e-01
TS_results:
0.09247324744821413
CSTS_results:
0.07273545111221746


In [13]:
# Mean foreground soft Dice under class-wise TS, one value per entry of domainlist.
print(softDSCs_CSTS)

[0.7561626372292171, 0.7401573516550323, 0.7523077598929022, 0.8135165625252605, 0.7880584423248818]


In [10]:
# Domain-B validation case list, evaluated under every test condition below.
prostatevalpath = '/vol/biomedic3/zl9518/Prostatedata/datafiletestB/'
DatafiletsImgc1 = prostatevalpath + 'seg-eval.txt'
with open(DatafiletsImgc1) as Imgfiletsc1:
    Imgreadc1 = Imgfiletsc1.read().splitlines()
# 83 conditions (prostattestcondition_1 .. prostattestcondition_83)
# CSTS parameters keyed by predicted class: ((T, b) for row 0, (T, b) for row 1).
# NOTE(review): both predicted classes share the same parameters, and a scalar b is
# added to every class score so it cancels inside softmax -- confirm intended values.
csts_params = {
    0: ((1.47256267, -0.17790477), (1.48775543, 0.1779001)),
    1: ((1.47256267, -0.17790477), (1.48775543, 0.1779001)),
}
softDSCs_AC = []
softDSCs_TS = []
softDSCs_CSTS = []
realDSCs = []
for kcon in tqdm(range(1, 84)):
    softDSC_FG_AC = []
    softDSC_FG_TS = []
    softDSC_FG_CSTS = []
    realDSC_FG = []
    resultpath = '/vol/biomedic3/zl9518/ModelEvaluation/output/prostate/prostattestcondition_' + str(kcon) + '/'
    for Imgnamec1 in Imgreadc1:
        kname = Imgnamec1.split("/")[-1][0:6]

        cls0filename = resultpath + '/results/pred_' + kname + 'cls0_prob.nii.gz'
        cls1filename = resultpath + '/results/pred_' + kname + 'cls1_prob.nii.gz'
        cls0logit = nib.load(cls0filename).get_fdata()
        cls1logit = nib.load(cls1filename).get_fdata()
        GTimg = nib.load(Imgnamec1).get_fdata()
        imgshape = GTimg.shape
        clsflatten = np.stack((cls0logit.flatten(), cls1logit.flatten()))  # (2, n_voxels)
        preds_all_argmax = np.argmax(clsflatten, axis = 0)  # the model's hard prediction

        # By AC
        prob_ac = softmax(clsflatten, T = 1.0)
        # By TS
        prob_ts = softmax(clsflatten, T = 1.6)
        # By CSTS: every voxel is calibrated with the parameters of its predicted class
        prob_csts = np.empty_like(prob_ts)
        for cls, ((t0, b0), (t1, b1)) in csts_params.items():
            targets_y1 = np.where(preds_all_argmax == cls)[0]
            prob_csts[0, targets_y1] = softmax(clsflatten[:, targets_y1], T = t0, b = b0)[0, :]
            prob_csts[1, targets_y1] = softmax(clsflatten[:, targets_y1], T = t1, b = b1)[1, :]

        # Soft Dice of each calibrated map against its own argmax (no ground truth
        # needed); index 1 keeps the foreground class.
        for probflatten, bucket in ((prob_ac, softDSC_FG_AC),
                                    (prob_ts, softDSC_FG_TS),
                                    (prob_csts, softDSC_FG_CSTS)):
            probr = probflatten.reshape((2, *imgshape))
            GTimgf = np.argmax(probr, axis = 0)
            softDSC = SoftDiceLoss(torch.tensor(probr[np.newaxis, ...]),
                                   torch.tensor(GTimgf[np.newaxis, ...]))
            bucket.append(softDSC[1].numpy())

        # Real Dice of the hard prediction, taken from the raw scores so it does
        # not depend on the calibration method.
        realDSC, _, _ = ComputMetric(GTimg, preds_all_argmax.reshape(imgshape))
        realDSC_FG.append(realDSC)

    softDSCs_AC.append(np.mean(np.array(softDSC_FG_AC)))
    softDSCs_TS.append(np.mean(np.array(softDSC_FG_TS)))
    softDSCs_CSTS.append(np.mean(np.array(softDSC_FG_CSTS)))
    realDSCs.append(np.mean(np.array(realDSC_FG)))

100%|████████████████████████████| 83/83 [38:27<00:00, 27.80s/it]


In [11]:
# Mean absolute gap between the real Dice and each soft-Dice estimate, averaged
# over the evaluated domains/conditions (lower = better estimate).
real_dsc = np.array(realDSCs)
for method_name, soft_dscs in (('AC', softDSCs_AC), ('TS', softDSCs_TS), ('CSTS', softDSCs_CSTS)):
    print(method_name + '_results:')
    print(np.mean(np.abs(real_dsc - np.array(soft_dscs))))

AC_results:
8.682389850061455e-02
TS_results:
0.03662467820391148
CSTS_results:
0.04832355614865498


In [12]:
# Mean foreground soft Dice under CSTS, one value per condition 1..83.
# NOTE(review): 83 full-precision floats overflow the output area; consider
# np.round(softDSCs_CSTS, 4) for a readable display.
print(softDSCs_CSTS)

[0.8979719741218796, 0.5334853374340107, 0.876805737282309, 0.6143788260382131, 0.8558768059042666, 0.6621393958104134, 0.8636254991082319, 0.6136797727211547, 0.8849229610293957, 0.646274860103667, 0.8805449527912634, 0.8859788128467587, 0.8823478638961311, 0.8730228861845711, 0.8722898800062516, 0.8659628725113595, 0.8863191127195422, 0.886931653229305, 0.857928412860983, 0.8638921064862444, 0.836079478744343, 0.8372183989447505, 0.7265564787966698, 0.7331672021690976, 0.7196580736210894, 0.7339636579585, 0.7181684589947378, 0.7267492321659054, 0.9120051805949567, 0.9004121159763233, 0.9095088020931421, 0.8391484419336928, 0.8333833171335845, 0.9018270257734077, 0.821809924546179, 0.820383034080845, 0.9099942733713897, 0.8713130294661982, 0.8792352650145169, 0.902539441424496, 0.9217076279699796, 0.909155050978755, 0.9243435756457513, 0.9022638638168745, 0.923801385288549, 0.9010796799615637, 0.9151687711690963, 0.9183929449964913, 0.9110347067564699, 0.9202563851774113, 0.9069044688

In [14]:
# Domain-B validation case list, evaluated under every test condition below.
prostatevalpath = '/vol/biomedic3/zl9518/Prostatedata/datafiletestB/'
DatafiletsImgc1 = prostatevalpath + 'seg-eval.txt'
with open(DatafiletsImgc1) as Imgfiletsc1:
    Imgreadc1 = Imgfiletsc1.read().splitlines()
# 83 conditions (prostattestcondition_1 .. prostattestcondition_83)
# Temperature per predicted class. 3.87936091 is the fitted x[0] shown above.
# NOTE(review): in eval_func that value divides only the class-1 score, whereas here
# it divides both scores of FG-predicted voxels -- confirm this is intended.
csts_temperatures = {0: 1.0, 1: 3.87936091}
softDSCs_AC = []
softDSCs_TS = []
softDSCs_CSTS = []
realDSCs = []
for kcon in tqdm(range(1, 84)):
    softDSC_FG_AC = []
    softDSC_FG_TS = []
    softDSC_FG_CSTS = []
    realDSC_FG = []
    resultpath = '/vol/biomedic3/zl9518/ModelEvaluation/output/prostate/prostattestcondition_' + str(kcon) + '/'
    for Imgnamec1 in Imgreadc1:
        kname = Imgnamec1.split("/")[-1][0:6]

        cls0filename = resultpath + '/results/pred_' + kname + 'cls0_prob.nii.gz'
        cls1filename = resultpath + '/results/pred_' + kname + 'cls1_prob.nii.gz'
        cls0logit = nib.load(cls0filename).get_fdata()
        cls1logit = nib.load(cls1filename).get_fdata()
        GTimg = nib.load(Imgnamec1).get_fdata()
        imgshape = GTimg.shape
        clsflatten = np.stack((cls0logit.flatten(), cls1logit.flatten()))  # (2, n_voxels)
        preds_all_argmax = np.argmax(clsflatten, axis = 0)  # the model's hard prediction

        # By AC
        prob_ac = softmax(clsflatten, T = 1.0)
        # By TS
        prob_ts = softmax(clsflatten, T = 1.6)
        # By CSTS: every voxel is scaled with the temperature of its predicted class
        prob_csts = np.empty_like(prob_ts)
        for cls, temperature in csts_temperatures.items():
            targets_y1 = np.where(preds_all_argmax == cls)[0]
            prob_csts[:, targets_y1] = softmax(clsflatten[:, targets_y1], T = temperature)

        # Soft Dice of each calibrated map against its own argmax (no ground truth
        # needed); index 1 keeps the foreground class.
        for probflatten, bucket in ((prob_ac, softDSC_FG_AC),
                                    (prob_ts, softDSC_FG_TS),
                                    (prob_csts, softDSC_FG_CSTS)):
            probr = probflatten.reshape((2, *imgshape))
            GTimgf = np.argmax(probr, axis = 0)
            softDSC = SoftDiceLoss(torch.tensor(probr[np.newaxis, ...]),
                                   torch.tensor(GTimgf[np.newaxis, ...]))
            bucket.append(softDSC[1].numpy())

        # Real Dice of the hard prediction, taken from the raw scores so it does
        # not depend on the calibration method.
        realDSC, _, _ = ComputMetric(GTimg, preds_all_argmax.reshape(imgshape))
        realDSC_FG.append(realDSC)

    softDSCs_AC.append(np.mean(np.array(softDSC_FG_AC)))
    softDSCs_TS.append(np.mean(np.array(softDSC_FG_TS)))
    softDSCs_CSTS.append(np.mean(np.array(softDSC_FG_CSTS)))
    realDSCs.append(np.mean(np.array(realDSC_FG)))

100%|█████████████████████████████████████████████████████████████████| 83/83 [26:12<00:00, 18.95s/it]


In [15]:
# Mean absolute gap between the real Dice and each soft-Dice estimate, averaged
# over the evaluated domains/conditions (lower = better estimate).
real_dsc = np.array(realDSCs)
for method_name, soft_dscs in (('AC', softDSCs_AC), ('TS', softDSCs_TS), ('CSTS', softDSCs_CSTS)):
    print(method_name + '_results:')
    print(np.mean(np.abs(real_dsc - np.array(soft_dscs))))

AC_results:
8.682389850061455e-02
TS_results:
0.03662467820391148
CSTS_results:
0.05798078426489817


In [16]:
# Mean foreground soft Dice under class-wise TS, one value per condition 1..83.
# NOTE(review): 83 full-precision floats overflow the output area; consider
# np.round(softDSCs_CSTS, 4) for a readable display.
print(softDSCs_CSTS)

[0.8290540188556854, 0.4667891720587563, 0.822737999032659, 0.5450536600715924, 0.80699298207058, 0.5966263213761087, 0.8234653472758217, 0.5498173689794983, 0.8435476635516412, 0.5874103797886671, 0.8104057233044593, 0.8159315339629017, 0.8140605015928539, 0.8088820427169173, 0.8076924735390852, 0.8061755926823976, 0.8162575546164502, 0.8170770122437201, 0.7994269939623904, 0.8028014771171325, 0.7834172293364775, 0.7843344670679295, 0.6580804462635699, 0.6641356284865417, 0.6563051618744103, 0.6669606336173245, 0.6566102710046248, 0.6637143564021865, 0.8370712359112137, 0.826524077593864, 0.8353324854838793, 0.7821952140424704, 0.7817253334292853, 0.8238280694900102, 0.7666182992570958, 0.7650234305824631, 0.8356737131165215, 0.8082449805806359, 0.8156168523538072, 0.8256488998910412, 0.8444432455713591, 0.8359543513024612, 0.8460536101433933, 0.8286476830496123, 0.8459239790630726, 0.8267807258823927, 0.8405320120038201, 0.8425027517103082, 0.8376498083713871, 0.843611028649056, 0.83