In [1]:
import sys
sys.path.append('../')

In [2]:
import argparse
from dataset.get_dataset import getDataset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import torch

In [8]:
class args_struct:
    def __init__(self):
        self.dataroot = '../DiscRegion/'
        self.rater_num = 6
        self.img_width = 256
        self.img_height = 256
        self.dataset = 'RIGA'
        self.validate = True
        self.standardize = False
        self.random_seed = 27

In [9]:
args = args_struct()
train_set, valid_set, test_set = getDataset(args, validate=args.validate)

In [12]:
train_loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=4, pin_memory=True)

In [11]:
batch_data = next(iter(train_loader))

plt.figure(figsize=(16, 16))
for i,(img,mask) in enumerate(zip(batch_data['image'], batch_data['mask'])):
    img = ((img.permute(1,2,0))*255.0).numpy().astype(np.uint8)
    plt.subplot(2, 4, i+1)
    plt.imshow(img,vmin=0,vmax=255)
    plt.imshow(get_intersection_mask(mask[0], mask[2])[1].squeeze().numpy(), alpha=0.2, cmap='gray')
    plt.axis('off')
    plt.subplots_adjust(wspace=None, hspace=0.1)

plt.tight_layout()

TypeError: list indices must be integers or slices, not str

<Figure size 1600x1600 with 0 Axes>

In [24]:
len(batch_data['mask'])

6

### Check intersection, unions

In [53]:
def get_intersection_mask(mask1, mask2):
    intersection = (mask1 * mask2)
    return intersection


def get_union_mask(mask1, mask2):
    intersection = (mask1 * mask2)
    total = (mask1 + mask2)
    union = total - intersection
    return union


def get_intersection_mask_flat(mask1, mask2):
    original_shape = mask1.shape
    mask1_flat = mask1.view(-1)
    mask2_flat = mask2.view(-1)

    intersection = (mask1_flat * mask2_flat).view(original_shape)
    return intersection


def get_union_mask_flat(mask1, mask2):
    original_shape = mask1.shape
    mask1_flat = mask1.view(-1)
    mask2_flat = mask2.view(-1)

    intersection = (mask1_flat * mask2_flat)
    total = (mask1_flat + mask2_flat)

    return (total - intersection).view(original_shape)


In [39]:
mask1 = batch_data['mask'][0]
mask2 = batch_data['mask'][1]

In [54]:
torch.unique(mask1, return_counts=True)

(tensor([0., 1.], dtype=torch.float64), tensor([459422,  64866]))

In [55]:
torch.unique(mask2, return_counts=True)

(tensor([0., 1.], dtype=torch.float64), tensor([460099,  64189]))

In [56]:
torch.unique(get_intersection_mask(mask1, mask2), return_counts=True)

(tensor([0., 1.], dtype=torch.float64), tensor([463577,  60711]))

In [57]:
torch.unique(get_intersection_mask_flat(mask1, mask2), return_counts=True)

(tensor([0., 1.], dtype=torch.float64), tensor([463577,  60711]))

In [58]:
torch.unique(get_union_mask(mask1, mask2), return_counts=True)

(tensor([0., 1.], dtype=torch.float64), tensor([455944,  68344]))

In [59]:
torch.unique(get_union_mask_flat(mask1, mask2), return_counts=True)

(tensor([0., 1.], dtype=torch.float64), tensor([455944,  68344]))

In [30]:
mask1.shape, \
mask2.shape

(torch.Size([4, 2, 256, 256]), torch.Size([4, 2, 256, 256]))

In [31]:
mask1.min(), mask1.max()

(tensor(0., dtype=torch.float64), tensor(1., dtype=torch.float64))

In [None]:
intersection = (output * mask).sum()
total = (output + mask).sum()
union = total - intersection

In [6]:
import cv2
import numpy as np

In [2]:
img = cv2.imread('/media/kudaibergen/TS512/projects_ts/grader_soup/data/DiscRegion/DiscRegion/Rater1/BinRushed/BinRushed1/image1.tif')

In [7]:
np.unique(img)

array([  0, 150, 255], dtype=uint8)

### InterObserverAgreement

In [33]:
def get_dice_threshold(output, mask, threshold):
    """
    :param output: output shape per image, float, (0,1)
    :param mask: mask shape per image, float, (0,1)
    :param threshold: the threshold to binarize output and feature (0,1)
    :return: dice of threshold t
    """
    smooth = 1e-6

    zero = torch.zeros_like(output)
    one = torch.ones_like(output)
    output = torch.where(output > threshold, one, zero)
    mask = torch.where(mask > threshold, one, zero)
    output = output.view(-1)
    mask = mask.view(-1)
    intersection = (output * mask).sum()
    dice = (2. * intersection + smooth) / (output.sum() + mask.sum() + smooth)

    return dice

In [34]:
def get_soft_dice(outputs, masks):
    """
    :param outputs: B * output shape per image
    :param masks: B * mask shape per image
    :return: average dice of B items
    """
    dice_list = []
    for this_item in range(outputs.size(0)):
        output = outputs[this_item]
        mask = masks[this_item]
        dice_item_thres_list = []
        for thres in [0.1, 0.3, 0.5, 0.7, 0.9]:
            dice_item_thres = get_dice_threshold(output, mask, thres)
            dice_item_thres_list.append(dice_item_thres.data)
        dice_item_thres_mean = np.mean(dice_item_thres_list)
        dice_list.append(dice_item_thres_mean)

    return np.mean(dice_list)


In [51]:
from tqdm.notebook import tqdm

In [61]:
train_loader = DataLoader(valid_set, batch_size=1, shuffle=True, num_workers=4, pin_memory=True)

In [62]:
disc_dices = []
cup_dices =  []
for batch_data in train_loader:
    imgs, masks, file_names = batch_data
    
    patient_disc_dices = []
    patient_cup_dices =  []
    for expert_i in range(len(masks)):
        for expert_j in range(expert_i+1, len(masks)):
            patient_disc_dices.append(
                get_soft_dice(masks[expert_i][:, 0, :, :], masks[expert_j][:, 0, :, :])
            )
            patient_cup_dices.append(
                get_soft_dice(masks[expert_i][:, 1, :, :], masks[expert_j][:, 1, :, :])
            )
    
    disc_dices.append(np.mean(patient_disc_dices))
    cup_dices.append(np.mean(patient_cup_dices))

In [50]:
len(train_loader)

524

### Train

In [55]:
np.mean(disc_dices)

0.9570894409706318

In [56]:
np.mean(cup_dices)

0.8028233087855987

### Valid Set

In [64]:
np.mean(disc_dices)

0.9544239015240016

In [65]:
np.mean(cup_dices)

0.7897554742991928

### Test

In [59]:
np.mean(disc_dices)

0.9566070548498949

In [60]:
np.mean(cup_dices)

0.7960036175563726