# **Stochastic Segmentation Networks demo.**
This is an interactive demo for our paper [Stochastic Segmentation Networks: Modelling Spatially Correlated Aleatoric Uncertainty](https://arxiv.org/abs/2006.06015).
To run this demo:
1.   Make sure to run the cell that downloads the images first;
2.   Please be patient with the image slider; it takes a little while to read the necessary files whenever a new image is selected.

The effect of the temperature sliders depends on the sampled direction, so it changes with each new sample.
Increasing or decreasing the temperature does not necessarily increase or decrease the volume of a class.

The temperature slider controls the scaling of the entire covariance matrix; move it to 0 to get the mean.
The individual class temperature sliders control the scaling of only the part of the covariance matrix relating to that class. (If the main temperature slider is set to 0, these have no effect.)

In [None]:
#@title This cell downloads the images, it may take a few minutes since these are large 3D images. 

# SimpleITK is used by the demo cell below to read the .nii.gz volumes;
# gdown fetches the files from Google Drive.
!pip install SimpleITK
!pip install gdown
# Each case gets its own directory under data/ holding six volumes: the T1ce
# image, the ground-truth segmentation, the brain mask, and the logit_mean /
# cov_factor / cov_diag volumes that the demo feeds to InteractiveSampler.
# Case BraTS17_2013_10_1.
!mkdir -p data/BraTS17_2013_10_1
!gdown https://drive.google.com/uc?id=1PDQZROumSjbOBIg42yKzWlS0Z46DjOUd -O "data/BraTS17_2013_10_1/t1ce.nii.gz"
!gdown https://drive.google.com/uc?id=1kFnIAa4lOjRnyKPJ8hqjibf5Q5OWJ4IW -O "data/BraTS17_2013_10_1/seg.nii.gz"
!gdown https://drive.google.com/uc?id=1ZQ_nJyI7-fVXXpxABKp0opKhlzCKahQ2 -O "data/BraTS17_2013_10_1/logit_mean.nii.gz"
!gdown https://drive.google.com/uc?id=1mE5UwDMOK0e7nHtt2wn6Dgj0b-0kft2b -O "data/BraTS17_2013_10_1/cov_factor.nii.gz"
!gdown https://drive.google.com/uc?id=1GmmUfPyiFQVWktlwH5kwHLWSQkT-LtI3 -O "data/BraTS17_2013_10_1/cov_diag.nii.gz"
!gdown https://drive.google.com/uc?id=1-fAt59jgkAGTyT7gDidIm_tj7fw2I7NN -O "data/BraTS17_2013_10_1/brainmask.nii.gz"
# Case BraTS17_2013_12_1 (same six files).
!mkdir -p data/BraTS17_2013_12_1
!gdown https://drive.google.com/uc?id=17BPLnfbZvl7jng4Kbnlk1oTNLWVQTkWG -O "data/BraTS17_2013_12_1/t1ce.nii.gz"
!gdown https://drive.google.com/uc?id=1G5pMMUfuxOqLPeiVgNoJOmnazEmWau8p -O "data/BraTS17_2013_12_1/seg.nii.gz"
!gdown https://drive.google.com/uc?id=17FEOGuNFQp5Ypm6GPIEC8n-cpDzYSoCl -O "data/BraTS17_2013_12_1/logit_mean.nii.gz"
!gdown https://drive.google.com/uc?id=1NFkruCK4nUyveWhUJXJyjv8Z8TQzM0_Z -O "data/BraTS17_2013_12_1/cov_factor.nii.gz"
!gdown https://drive.google.com/uc?id=1APnglw2vwkp2gbfhKfJysZu08rYViWqS -O "data/BraTS17_2013_12_1/cov_diag.nii.gz"
!gdown https://drive.google.com/uc?id=1YWcb7g2kjJ1XS0jtrD43glJY3Bj3VVHL -O "data/BraTS17_2013_12_1/brainmask.nii.gz"
# Case BraTS17_2013_20_1 (same six files).
!mkdir -p data/BraTS17_2013_20_1
!gdown https://drive.google.com/uc?id=1j0jMay2EC1CHs5hRgFUBNOP2SmaUN0Qm -O "data/BraTS17_2013_20_1/t1ce.nii.gz"
!gdown https://drive.google.com/uc?id=17M1Gzrkbf3TQ7n6MxPVBJN0g6D8b2Hhr -O "data/BraTS17_2013_20_1/seg.nii.gz"
!gdown https://drive.google.com/uc?id=1VGJ-uWfRm69YqccHHzKvakkd5XYr9XoA -O "data/BraTS17_2013_20_1/logit_mean.nii.gz"
!gdown https://drive.google.com/uc?id=1wdI8pGfbHY8lMmdpSkdmnnEkE4DhXXgd -O "data/BraTS17_2013_20_1/cov_factor.nii.gz"
!gdown https://drive.google.com/uc?id=16exPfMcQ4Cly3zqgANwf21d2bIeiN4aI -O "data/BraTS17_2013_20_1/cov_diag.nii.gz"
!gdown https://drive.google.com/uc?id=1Aol2mHxc5GZGaOPX3F0TqvnagyJEzj2V -O "data/BraTS17_2013_20_1/brainmask.nii.gz"

In [None]:
# @title Brain Tumour Segmentation demo
#Disclaimer: do not use this code for research, it has been optimised for visualisation in 2D
import os
import torch
import torch.distributions as td
from torch.distributions.lowrank_multivariate_normal import _standard_normal, _batch_mv
import SimpleITK as sitk
import numpy as np
import matplotlib.pyplot as plt
import warnings
from ipywidgets import interactive, IntSlider, FloatSlider, HBox, VBox, Button, HTML, Layout
from matplotlib import cm
import matplotlib as mpl

# High-resolution figures so the four small panels stay legible.
mpl.rcParams['figure.dpi'] = 300
# Keep UserWarnings out of the widget output area.
warnings.filterwarnings("ignore", category=UserWarning)
# Inference only: no autograd graphs are needed anywhere in this demo.
torch.autograd.set_grad_enabled(False);
# RGB colour (components in {0, 1}) for each label index; index 0 is black.
# NOTE(review): index 6 repeats index 1's red -- confirm whether that is intended.
COLOR_MAP = ((0, 0, 0), (1, 0, 0), (0, 1, 0), (0, 0, 1), (1, 1, 0), (0, 1, 1), (1, 0, 0))
# Prefer the first GPU, but fall back to the CPU so the demo still runs on
# runtimes without CUDA (torch.device(0) alone fails there).
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


def cast_to_tensor(array, device, dtype=torch.float32):
    """Copy a NumPy array into a tensor whose first axis is the array's last axis.

    An array of shape ``(a, b, ..., z)`` becomes a tensor of shape
    ``(z, a, b, ...)`` with the given ``dtype`` on ``device``; gradients are
    not tracked.
    """
    last_axis = array.ndim - 1
    channels_first = array.transpose((last_axis, *range(last_axis)))
    return torch.tensor(channels_first, dtype=dtype, device=device, requires_grad=False)


class InteractiveSampler(object):
    """Draws and rescales samples from a low-rank-plus-diagonal Gaussian over logits.

    A sample is ``mean + cov_factor @ eps_W + sqrt(cov_diag) * eps_D``, where
    ``eps_W`` (one entry per covariance rank) and ``eps_D`` (one entry per logit)
    are standard normal. The noise is drawn once per :meth:`new_sample` call and
    then only rescaled, so moving a temperature slider changes the magnitude of
    the current sample direction rather than drawing a new sample.
    """

    def __init__(self, logit_mean, cov_diag, cov_factor, device, mask, seed=None):
        """Build the distribution tensors on ``device`` and draw the first noise sample.

        Args:
            logit_mean: array of mean logits whose LAST axis is the class axis.
            cov_diag: diagonal covariance entries, same layout as ``logit_mean``.
            cov_factor: covariance factors, last axis of size ``rank * num_classes``.
            device: torch device the tensors are placed on.
            mask: optional array over the spatial axes, or None. Voxels where it
                is zero are forced to class 0 and get no low-rank covariance.
            seed: optional seed passed to ``torch.manual_seed``.
        """
        self.seed = seed
        if seed is not None:
            torch.manual_seed(seed)
        self.device = device
        (self.logit_mean, self.cov_factor, self.cov_diag,
         self.shape, self.rank) = self.build_distribution(logit_mean, cov_diag, cov_factor, device, mask)
        self.num_classes = self.shape[0]
        self.eps_W, self.eps_D = None, None
        self.new_sample()

    def new_sample(self):
        """Redraw the standard-normal noise that defines the current sample direction."""
        shape = (1, ) + self.logit_mean.shape
        w_shape = shape[:-1] + self.cov_factor.shape[-1:]
        dtype = self.logit_mean.dtype
        # eps_W: one draw per covariance rank; eps_D: one draw per logit.
        self.eps_W = _standard_normal(w_shape, dtype=dtype, device=self.device)
        self.eps_D = _standard_normal(shape, dtype=dtype, device=self.device)

    @staticmethod
    def build_distribution(logit_mean, cov_diag, cov_factor, device, mask):
        """Convert the arrays to flattened tensors on ``device``.

        Returns:
            ``(logit_mean, cov_factor, cov_diag, shape, rank)``.
            ``logit_mean`` and ``cov_diag`` are flat vectors of length
            ``N = prod(shape)``, and ``cov_factor`` has shape ``(N, rank)``.
            ``shape`` is the ``(num_classes, *spatial)`` logit shape, and
            ``rank`` is ``cov_factor`` channels divided by ``num_classes``.
        """
        logit_mean = cast_to_tensor(logit_mean, device)
        cov_diag = cast_to_tensor(cov_diag, device)
        cov_factor = cast_to_tensor(cov_factor, device)
        if mask is not None:
            # Normalise the mask to bool whatever dtype it was stored with, and
            # cast it to the factor dtype before multiplying. A float64 mask
            # would otherwise promote cov_factor to float64, mismatching the
            # float32 noise. A non-binary mask would otherwise rescale the
            # covariance.
            mask = torch.tensor(mask, device=device).to(torch.bool)
            # A very large class-0 logit outside the mask makes argmax pick class 0 there.
            logit_mean[0, ~mask] = 100
            cov_factor = cov_factor * mask.unsqueeze(0).to(cov_factor.dtype)
        shape = logit_mean.shape
        num_classes = shape[0]
        rank = cov_factor.shape[0] // num_classes
        logit_mean = logit_mean.reshape(-1)
        cov_diag = cov_diag.reshape(-1)
        # (rank * C, *spatial) -> (rank, N) -> (N, rank)
        cov_factor = cov_factor.reshape((rank, -1)).transpose(1, 0)
        return logit_mean, cov_factor, cov_diag, shape, rank

    def get_manipulated_sample_slice_(self, slice_: int, temperature: float,
                                      class_weights: "list[float] | torch.Tensor"):
        """Return ``(samples, means)`` label maps for one index along the first spatial axis.

        The current noise is mapped through the covariance and scaled per class
        by ``class_weights``. It is then scaled globally by ``temperature`` and
        added to the mean logits, so ``temperature == 0`` reproduces the mean.

        Args:
            slice_: index along the first spatial axis.
            temperature: global scale of the sampled deviation from the mean.
            class_weights: one scale factor per class (length ``num_classes``).

        Returns:
            ``samples`` and ``means``, integer label maps of shape
            ``shape[2:]`` obtained by ``argmax`` over the class axis.
        """
        # Low-rank part, cov_factor @ eps_W, restricted to this slice.
        _cov_factor = self.cov_factor.view(self.shape + (self.rank, ))[:, slice_].reshape((-1, self.rank))
        factor_direction = _batch_mv(_cov_factor, self.eps_W).view((self.num_classes, ) + self.shape[2:])

        # Diagonal part, sqrt(cov_diag) * eps_D, restricted to this slice.
        _cov_diag = self.cov_diag.view(self.shape)[:, slice_]
        eps_D = self.eps_D.view(self.shape)[:, slice_]
        diag_direction = _cov_diag.sqrt() * eps_D

        dist_loc = self.logit_mean.view(self.shape)[:, slice_]
        # Shape the weights as (num_classes, 1, ..., 1) so they broadcast over
        # the slice instead of materialising a full-size weight map.
        weights = torch.as_tensor(class_weights, dtype=dist_loc.dtype, device=self.device)
        weights = weights.view((self.num_classes, ) + (1, ) * (len(self.shape) - 2))

        logit_samples = dist_loc + temperature * weights * (factor_direction + diag_direction)
        samples = torch.argmax(logit_samples, dim=0)
        means = torch.argmax(dist_loc, dim=0)
        return samples, means


class Data(object):
    """Loads one case at a time from ``data/<id>`` and renders RGB views of its slices.

    The currently loaded case is cached; files are only read again when a
    different image index is requested.
    """

    def __init__(self):
        # Image-slider index -> case directory name under ``data/``.
        self.ids = {0: 'BraTS17_2013_10_1',
                    1: 'BraTS17_2013_12_1',
                    2: 'BraTS17_2013_20_1'}
        # Label index -> RGB colour.
        self.color_map = np.array(COLOR_MAP)
        self.id_ = None  # id of the fully loaded case, or None if none is loaded
        self.image_rgb = None
        self.seg_rgb = None
        self.sampler = None

    def overlay_to_rgb(self, overlay):
        """Map an integer label array to a uint8 RGB array of shape ``overlay.shape + (3,)``."""
        # Fancy indexing looks up all three colour channels in a single pass.
        return self.color_map[overlay.astype(np.uint8)].astype(np.uint8)

    @staticmethod
    def image_to_rgb(image):
        """Min-max normalise ``image`` to [0, 1] and replicate it into a trailing RGB axis.

        Integer images are converted to float64 first, which avoids integer
        overflow in ``max - min``. Floating images keep their dtype. A
        constant image (zero intensity range) maps to all zeros instead of NaN.
        """
        image = np.asarray(image)
        if not np.issubdtype(image.dtype, np.floating):
            image = image.astype(np.float64)
        low = image.min()
        span = image.max() - low
        if span > 0:
            image = (image - low) / span
        else:
            image = np.zeros_like(image)
        return np.stack((image,) * 3, axis=-1)

    @staticmethod
    def mix_image_and_overlay(image, overlay, opacity=.5):
        """Alpha-blend ``overlay`` onto ``image`` wherever the overlay colour is non-black.

        Returns a new array; ``image`` itself is not modified.
        """
        overlay = overlay.reshape(image.shape)
        new_image = np.copy(image)
        ind = np.sum(overlay, axis=-1) > 0
        new_image[ind] = opacity * overlay[ind] + (1 - opacity) * image[ind]
        return new_image

    def generate_data(self, id_):
        """Read case ``id_`` from ``data/<id_>`` and rebuild the RGB views and the sampler."""
        # Drop the previous sampler first so its device tensors can be freed
        # before the new case's tensors are allocated.
        self.sampler = None
        path = f'data/{id_:s}'

        def read(name):
            # Load one NIfTI volume from the case directory as a NumPy array.
            return sitk.GetArrayFromImage(sitk.ReadImage(os.path.join(path, name)))

        image = read('t1ce.nii.gz')
        image_rgb = self.image_to_rgb(image)
        seg = self.overlay_to_rgb(read('seg.nii.gz').astype(np.uint8))
        self.image_rgb = image_rgb
        self.seg_rgb = self.mix_image_and_overlay(image_rgb, seg)
        mask = read('brainmask.nii.gz')
        logit_mean = read('logit_mean.nii.gz')
        cov_factor = read('cov_factor.nii.gz')
        cov_diag = read('cov_diag.nii.gz')
        self.sampler = InteractiveSampler(logit_mean, cov_diag, cov_factor, DEVICE, mask)

    def get_data(self, slice_, image, temperature, class_weights):
        """Return ``(image_rgb, seg_rgb, sample_rgb, mean_rgb)`` for slice ``slice_`` of case ``image``.

        Args:
            slice_: index along the first axis of the loaded volumes.
            image: key into ``self.ids`` selecting the case.
            temperature: global sample temperature forwarded to the sampler.
            class_weights: per-class weights forwarded to the sampler.
        """
        id_ = self.ids[image]
        if self.id_ != id_:
            # Invalidate the cache while loading. If generate_data raises, the
            # next call retries instead of treating the half-loaded case as ready.
            self.id_ = None
            self.generate_data(id_)
            self.id_ = id_

        sample, mean = self.sampler.get_manipulated_sample_slice_(slice_, temperature, class_weights)
        image_rgb = self.image_rgb[slice_]

        sample = self.overlay_to_rgb(sample.cpu().numpy().astype(np.uint8))
        sample_rgb = self.mix_image_and_overlay(image_rgb, sample)

        mean = self.overlay_to_rgb(mean.cpu().numpy().astype(np.uint8))
        mean_rgb = self.mix_image_and_overlay(image_rgb, mean)

        return image_rgb, self.seg_rgb[slice_], sample_rgb, mean_rgb


# Shared case cache (filled lazily by Data.get_data), used by plot() and by the
# "New Sample!" button callback in interactive_plot().
data = Data()


def plot(image, slice_, temperature, background, necrotic_core, oedema, enhancing_core):
    """Draw image, ground truth, mean prediction and current sample side by side.

    Invoked by the ``interactive`` widget with the current slider values. The
    four per-class sliders are passed to the sampler as class weights in the
    order background, necrotic core, oedema, enhancing core.
    """
    weights = [background, necrotic_core, oedema, enhancing_core]
    image_rgb, seg_rgb, sample_rgb, mean_rgb = data.get_data(slice_, image, temperature, weights)

    panels = (
        ('Image', image_rgb),
        ('Ground Truth', seg_rgb),
        ('Prediction (Mean)', mean_rgb),
        ('Sample', sample_rgb),
    )
    _, axes = plt.subplots(1, len(panels), figsize=(10, 2.5), gridspec_kw=dict(wspace=0, hspace=0))
    for axis, (title, picture) in zip(axes, panels):
        axis.set_title(title)
        axis.imshow(picture)
        # Hide frame and ticks so the panels read as plain images.
        axis.axis('off')
        axis.xaxis.set_major_locator(plt.NullLocator())
        axis.yaxis.set_major_locator(plt.NullLocator())

    plt.show()


def interactive_plot():
    """Build the slider/button UI around :func:`plot`, display it and draw the first figure."""
    style = {'description_width': '140px'}
    # CSS widths need a unit; a bare '100' is invalid and silently ignored.
    narrow_style = {'description_width': '100px'}
    layout = Layout(width='400px')

    def weight_slider(description):
        # Temperature and per-class sliders share range, step and update policy;
        # continuous_update=False re-renders only when the slider is released.
        return FloatSlider(value=1., min=-3., max=3., step=.5, continuous_update=False,
                           description=description, style=style, layout=layout)

    w = interactive(
        plot,
        image=IntSlider(min=0, max=len(data.ids) - 1, description='Image #', style=narrow_style, layout=layout),
        slice_=IntSlider(value=76, min=0, max=154, description='Slice #', style=narrow_style, layout=layout),
        temperature=weight_slider('Temperature'),
        background=weight_slider('Background'),
        necrotic_core=weight_slider('Necrotic Core (red)'),
        oedema=weight_slider('Oedema (green)'),
        enhancing_core=weight_slider('Enhancing Core (blue)'),
    )

    new_sample = Button(description="New Sample!")

    def new_sample_on_click(b):
        # Redraw the noise and re-render; a no-op until a case has been loaded.
        if data.sampler is not None:
            data.sampler.new_sample()
            w.update()

    new_sample.on_click(new_sample_on_click)
    message = HTML(value="Please be patient with the image slider until you see a new image.<br /> It takes a few seconds to load.")

    # w.children[0:7] are the controls in keyword order; w.children[-1] is the output area.
    controls = w.children
    ui = VBox([
        HBox([
            VBox([message, *controls[0:2]]),
            VBox(list(controls[2:7])),
            new_sample,
        ]),
        controls[-1],
    ])

    # NOTE(review): `display` is not imported in this cell; presumably supplied
    # by the IPython/Jupyter namespace.
    display(ui)
    w.update()


# Build and show the demo UI; this also renders the first figure.
interactive_plot()