In [1]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import random
import seaborn as sns

import sys
import os
sys.path.append('../')

from trainprostVQBMCtoRUN import Net as QNet
from trainprostunetBMCtoRUM import Net as Net
from skimage.util import random_noise


In [2]:
# Plain (non-quantised) UNet baseline: build the network, restore the weights
# stored under 'state_dict' in the checkpoint, switch to inference mode and
# move it to the GPU.
net = Net()
net.load_state_dict(
    torch.load(
        '/vol/biomedic2/agk21/PhDLogs/codes/Vector-Quantisation-for-Robust-Segmentation/Logs/outputprostatefinal/outputprostatefinal/unetBMCtoRUM/best_metric_model-v1.ckpt'
    )['state_dict']
)
net = net.eval().cuda()

In [15]:
from trainprostVQBMCtoRUN import ProstateDataModule

# Prostate multi-site data module built from the train / val / test CSV split
# files; only its validation loader is used in this notebook.
loader = ProstateDataModule(
    batch_size=1,   # one volume per batch
    num_workers=4,
    csv_train_img="/vol/biomedic3/as217/vqseg/ProstateDomain/Prostatedomaintr1.csv",
    csv_val_img="/vol/biomedic3/as217/vqseg/ProstateDomain/ProstatedomainvalBMC1.csv",
    csv_test_img="/vol/biomedic3/as217/vqseg/ProstateDomain/Prostatedomaints1.csv",
)
val_dataloader = loader.val_dataloader()

Loading Data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:00<00:00, 13940.12it/s]
Loading Data: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 7590.13it/s]
Loading Data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:00<00:00, 17351.91it/s]

#train:  25
#val:    5
#test:   25





In [16]:
def noisy(noise_typ, image, weightage = 0.0):
    """Return a noise-corrupted version of ``image`` (the input is not modified).

    Parameters
    ----------
    noise_typ : str
        One of ``"gauss"``, ``"s&p"``, ``"poisson"``, ``"speckle"``.
    image : np.ndarray
        Input array. ``"s&p"`` indexes it as ``[0, 0, x, y, z]`` and therefore
        needs a 5-D array; only the first batch/channel entry is corrupted.
    weightage : float
        Noise strength. ``0.0`` returns the image unchanged for every type:
          - gauss:   ``image + weightage * N(0, 1)``
          - s&p:     fraction ``weightage`` of voxels set to 1 (salt) / 0 (pepper),
                     half each
          - poisson: blend ``image + weightage * (poisson_resample(image) - image)``;
                     requires non-negative intensities when ``weightage > 0``
          - speckle: ``image + weightage * image * N(0, 1)``

    Returns
    -------
    np.ndarray with the same shape as ``image``.

    Raises
    ------
    ValueError
        If ``noise_typ`` is not one of the supported names.
    """
    if noise_typ == "gauss":
        mean = 0
        var = 1
        sigma = var ** 0.5
        gauss = np.random.normal(mean, sigma, image.shape)
        return image + weightage * gauss

    if noise_typ == "s&p":
        s_vs_p = 0.5
        amount = weightage
        out = np.copy(image)
        spatial_shape = image.shape[-3:]
        # np.random.randint's upper bound is exclusive, so pass the axis size
        # itself; the previous ``i - 1`` never touched the last slice of each axis
        # (and raised for an axis of size 1).
        # Salt mode
        num_salt = int(np.ceil(amount * image.size * s_vs_p))
        xs, ys, zs = [np.random.randint(0, dim, num_salt) for dim in spatial_shape]
        out[0, 0, xs, ys, zs] = 1
        # Pepper mode (applied after salt, so it wins on overlapping voxels)
        num_pepper = int(np.ceil(amount * image.size * (1. - s_vs_p)))
        xs, ys, zs = [np.random.randint(0, dim, num_pepper) for dim in spatial_shape]
        out[0, 0, xs, ys, zs] = 0
        return out

    if noise_typ == "poisson":
        # The previous form ``poisson(weightage * image * vals) / vals`` rescaled the
        # intensities by ``weightage`` -- weightage 0.0 (the clean baseline) produced
        # an all-zero image. Blend towards a Poisson resample instead so that 0 means
        # "clean" and 1 means "fully Poisson-resampled", like the other noise types.
        if weightage == 0:
            return np.copy(image)
        vals = len(np.unique(image))
        vals = 2 ** np.ceil(np.log2(vals))
        resampled = np.random.poisson(image * vals) / float(vals)
        return image + weightage * (resampled - image)

    if noise_typ == "speckle":
        gauss = np.random.randn(*image.shape)
        return image + weightage * image * gauss

    raise ValueError("unknown noise type: {!r}".format(noise_typ))
    

In [17]:
import pickle


def generate_TTA_results(name, data, noise_type='gauss', base_path='.', model=None,
                         nTTAs=100, noise_threshold_list=(0.0, 0.01, 0.1, 0.5, 0.75)):
    """Run noisy test-time augmentation (TTA) through the plain UNet and pickle the results.

    For every noise level, the first volume of ``data['image']`` is corrupted
    ``nTTAs`` times with ``noisy`` and passed through ``model._model``, which
    returns ``(outputs, encoding)``. The encodings and the arg-max over axis 1 of
    the outputs are collected.

    Parameters
    ----------
    name : str
        File stem of the pickle that is written.
    data : dict
        Batch with ``'image'`` and ``'label'`` tensors; only index 0 is used.
    noise_type : str
        One of ``'gauss'``, ``'s&p'``, ``'poisson'``, ``'speckle'``.
    base_path : str
        Output root; the pickle goes to ``<base_path>/<noise_type>/<name>.pickle``.
    model : torch.nn.Module, optional
        Network to evaluate; defaults to the notebook-level ``net``.
    nTTAs : int
        Number of noisy draws per noise level.
    noise_threshold_list : iterable of float
        Noise strengths passed to ``noisy`` as ``weightage``.

    Returns
    -------
    dict
        ``{noise_threshold: {'emb', 'recon', 'img', 'label'}}`` with numpy arrays.
    """
    assert noise_type in ['gauss', 's&p', 'poisson', 'speckle'], 'invalid noise_type'
    if model is None:
        model = net
    # Send inputs to wherever the model's weights live (cuda for ``net``).
    device = next(model.parameters()).device

    # Loop-invariant: the clean input volume and its dtype.
    clean_np = data['image'][:1, ...].numpy()
    dtype = data['image'].dtype

    zdata = {}
    for noise_threshold in noise_threshold_list:
        embs = []
        recons = []
        for _ in range(nTTAs):
            x = noisy(noise_type, clean_np, noise_threshold)
            x = torch.tensor(x).type(dtype).to(device)
            with torch.no_grad():
                outputs, encoding = model._model(x)
            # Arg-max and move each draw to the CPU immediately instead of holding
            # nTTAs full-volume outputs on the GPU (same result as cat-then-argmax).
            embs.append(encoding.cpu())
            recons.append(torch.argmax(outputs, 1).cpu())

        zdata[noise_threshold] = {'emb': torch.cat(embs, 0).numpy(),
                                  'recon': torch.cat(recons, 0).squeeze().numpy(),
                                  'img': data['image'][:1, ...].squeeze().numpy(),
                                  'label': data['label'][:1, ...].squeeze().numpy()}

    base_path = os.path.join(base_path, noise_type)
    os.makedirs(base_path, exist_ok=True)
    with open(os.path.join(base_path, '{}.pickle'.format(name)), 'wb') as file:
        pickle.dump(zdata, file)
    return zdata

In [18]:
from tqdm import tqdm

NOISE_TYPES = ['gauss', 's&p', 'poisson', 'speckle']

# Iterate the validation loader once. The previous ``next(iter(val_dataloader))``
# built a fresh iterator (and 4 worker processes) on every step, so it kept
# receiving the loader's first batch and ran 100 times over a 5-volume split.
all_zdata = {}
for i, data in enumerate(tqdm(val_dataloader)):
    all_zdata[i] = {
        noise_type: generate_TTA_results('image{}'.format(i), data, noise_type, 'PRST/UNet')
        for noise_type in NOISE_TYPES
    }

with open('all_data_PRST_UNet.pickle', 'wb') as file:
    pickle.dump(all_zdata, file)

  0%|                                                                                                                                                                         | 0/100 [00:00<?, ?it/s]


=== Transform input info -- LoadImaged ===
image statistics:
Type: <class 'str'> None
Value: /vol/biomedic3/as217/vqseg/ProstateDomain/BMC/Case25.nii.gz
label statistics:
Type: <class 'str'> None
Value: /vol/biomedic3/as217/vqseg/ProstateDomain/BMC/Case25_Segmentation.nii.gz

=== Transform input info -- LoadImaged ===

=== Transform input info -- LoadImaged ===
image statistics:
Type: <class 'str'> None
Value: /vol/biomedic3/as217/vqseg/ProstateDomain/BMC/Case26.nii.gz
image statistics:
Type: <class 'str'> None
Value: /vol/biomedic3/as217/vqseg/ProstateDomain/BMC/Case28.nii.gz
label statistics:
Type: <class 'str'> None
Value: /vol/biomedic3/as217/vqseg/ProstateDomain/BMC/Case26_Segmentation.nii.gz
label statistics:
Type: <class 'str'> None
Value: /vol/biomedic3/as217/vqseg/ProstateDomain/BMC/Case28_Segmentation.nii.gz

=== Transform input info -- LoadImaged ===


  0%|                                                                                                                                                                         | 0/100 [00:00<?, ?it/s]


=== Transform input info -- LoadImaged ===

=== Transform input info -- LoadImaged ===
image statistics:
Type: <class 'str'> None
Value: /vol/biomedic3/as217/vqseg/ProstateDomain/BMC/Case29.nii.gz
image statistics:
Type: <class 'str'> None
Value: /vol/biomedic3/as217/vqseg/ProstateDomain/BMC/Case29.nii.gz





label statistics:
Type: <class 'str'> None
Value: /vol/biomedic3/as217/vqseg/ProstateDomain/BMC/Case29_Segmentation.nii.gz
label statistics:
Type: <class 'str'> None
Value: /vol/biomedic3/as217/vqseg/ProstateDomain/BMC/Case29_Segmentation.nii.gz
image statistics:
Type: <class 'str'> None
Value: /vol/biomedic3/as217/vqseg/ProstateDomain/BMC/Case27.nii.gz
label statistics:
Type: <class 'str'> None
Value: /vol/biomedic3/as217/vqseg/ProstateDomain/BMC/Case27_Segmentation.nii.gz


RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/vol/biomedic2/agk21/anaconda3/envs/quantization/lib/python3.7/site-packages/nibabel/loadsave.py", line 42, in load
    stat_result = os.stat(filename)
PermissionError: [Errno 13] Permission denied: '/vol/biomedic3/as217/vqseg/ProstateDomain/BMC/Case25.nii.gz'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/vol/biomedic2/agk21/anaconda3/envs/quantization/lib/python3.7/site-packages/monai/transforms/transform.py", line 82, in apply_transform
    return _apply_transform(transform, data, unpack_items)
  File "/vol/biomedic2/agk21/anaconda3/envs/quantization/lib/python3.7/site-packages/monai/transforms/transform.py", line 53, in _apply_transform
    return transform(parameters)
  File "/vol/biomedic2/agk21/anaconda3/envs/quantization/lib/python3.7/site-packages/monai/transforms/io/dictionary.py", line 121, in __call__
    data = self._loader(d[key], reader)
  File "/vol/biomedic2/agk21/anaconda3/envs/quantization/lib/python3.7/site-packages/monai/transforms/io/array.py", line 194, in __call__
    img = reader.read(filename)
  File "/vol/biomedic2/agk21/anaconda3/envs/quantization/lib/python3.7/site-packages/monai/data/image_reader.py", line 383, in read
    img = nib.load(name, **kwargs_)
  File "/vol/biomedic2/agk21/anaconda3/envs/quantization/lib/python3.7/site-packages/nibabel/loadsave.py", line 44, in load
    raise FileNotFoundError(f"No such file or no access: '{filename}'")
FileNotFoundError: No such file or no access: '/vol/biomedic3/as217/vqseg/ProstateDomain/BMC/Case25.nii.gz'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/vol/biomedic2/agk21/anaconda3/envs/quantization/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/vol/biomedic2/agk21/anaconda3/envs/quantization/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/vol/biomedic2/agk21/anaconda3/envs/quantization/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "../trainprostVQBMCtoRUN.py", line 185, in __getitem__
    sample = self.get_sample(item)
  File "../trainprostVQBMCtoRUN.py", line 197, in get_sample
    sample = self.val_transforms(sample)
  File "/vol/biomedic2/agk21/anaconda3/envs/quantization/lib/python3.7/site-packages/monai/transforms/compose.py", line 160, in __call__
    input_ = apply_transform(_transform, input_, self.map_items, self.unpack_items)
  File "/vol/biomedic2/agk21/anaconda3/envs/quantization/lib/python3.7/site-packages/monai/transforms/transform.py", line 106, in apply_transform
    raise RuntimeError(f"applying transform {transform}") from e
RuntimeError: applying transform <monai.transforms.io.dictionary.LoadImaged object at 0x7fa97043bba8>


In [None]:
# Show GPU utilisation / memory usage before loading the second (VQ) model.
!nvidia-smi

In [None]:

# VQ-UNet counterpart of ``net``: restore the weights stored under 'state_dict',
# switch to inference mode and move it to the GPU. The VQ TTA helper's original
# form sends its inputs to CUDA, so leaving ``qnet`` on the CPU caused a device
# mismatch.
qnet = QNet()
qnet.load_state_dict(torch.load('/vol/biomedic2/agk21/PhDLogs/codes/Vector-Quantisation-for-Robust-Segmentation/Logs/outputprostatefinal/outputprostatefinal/vqBCMtoRUM/best_metric_model-v1.ckpt')['state_dict'])
qnet = qnet.eval()
qnet = qnet.cuda()

In [None]:
import pickle

def VQgenerate_TTA_results(name, data, noise_type='gauss', base_path='.', model=None,
                           nTTAs=100, noise_threshold_list=(0.0, 0.01, 0.1, 0.5, 0.75)):
    """Run noisy test-time augmentation (TTA) through the VQ-UNet and pickle the results.

    Same protocol as the plain-UNet helper, but ``model._model`` returns
    ``(outputs, emb_loss, encoding)``, and the per-draw ``emb_loss`` is stored too.

    Parameters
    ----------
    name : str
        File stem of the pickle that is written.
    data : dict
        Batch with ``'image'`` and ``'label'`` tensors; only index 0 is used.
    noise_type : str
        One of ``'gauss'``, ``'s&p'``, ``'poisson'``, ``'speckle'``.
    base_path : str
        Output root; the pickle goes to ``<base_path>/<noise_type>/<name>.pickle``.
    model : torch.nn.Module, optional
        Network to evaluate; defaults to the notebook-level ``qnet``.
    nTTAs : int
        Number of noisy draws per noise level.
    noise_threshold_list : iterable of float
        Noise strengths passed to ``noisy`` as ``weightage``.

    Returns
    -------
    dict
        ``{noise_threshold: {'emb', 'emb_loss', 'recon', 'img', 'label'}}`` with numpy arrays.
    """
    assert noise_type in ['gauss', 's&p', 'poisson', 'speckle'], 'invalid noise_type'
    if model is None:
        model = qnet
    # Send inputs to wherever the model's weights live; the original hard-coded
    # ``.cuda()`` even though ``qnet`` was never moved to the GPU.
    device = next(model.parameters()).device

    # Loop-invariant: the clean input volume and its dtype.
    clean_np = data['image'][:1, ...].numpy()
    dtype = data['image'].dtype

    zdata = {}
    for noise_threshold in noise_threshold_list:
        embs = []
        recons = []
        emb_losses = []
        for _ in range(nTTAs):
            x = noisy(noise_type, clean_np, noise_threshold)
            x = torch.tensor(x).type(dtype).to(device)
            with torch.no_grad():
                outputs, emb_loss, encoding = model._model(x)
            # Arg-max and offload each draw immediately instead of holding nTTAs
            # full-volume outputs on the GPU (same result as cat-then-argmax).
            embs.append(encoding.cpu())
            recons.append(torch.argmax(outputs, 1).cpu())
            # torch.cat rejects 0-d tensors; if emb_loss is a scalar (TODO confirm
            # its shape), store it as shape (1,) so the draws can be concatenated.
            loss = torch.as_tensor(emb_loss).detach().cpu()
            emb_losses.append(loss.unsqueeze(0) if loss.dim() == 0 else loss)

        zdata[noise_threshold] = {'emb': torch.cat(embs, 0).numpy(),
                                  'emb_loss': torch.cat(emb_losses, 0).numpy(),
                                  'recon': torch.cat(recons, 0).squeeze().numpy(),
                                  'img': data['image'][:1, ...].squeeze().numpy(),
                                  'label': data['label'][:1, ...].squeeze().numpy()}

    base_path = os.path.join(base_path, noise_type)
    os.makedirs(base_path, exist_ok=True)
    with open(os.path.join(base_path, '{}.pickle'.format(name)), 'wb') as file:
        pickle.dump(zdata, file)
    return zdata

In [None]:
from tqdm import tqdm

NOISE_TYPES = ['gauss', 's&p', 'poisson', 'speckle']

# Iterate the validation loader once. The previous ``next(iter(val_dataloader))``
# built a fresh iterator (and its worker processes) on every step, so it kept
# receiving the loader's first batch and ran 100 times over a 5-volume split.
all_zqdata = {}
for i, data in enumerate(tqdm(val_dataloader)):
    all_zqdata[i] = {
        noise_type: VQgenerate_TTA_results('image{}'.format(i), data, noise_type, 'PRST/VQNet')
        for noise_type in NOISE_TYPES
    }

with open('all_data_PRST_VQNet.pickle', 'wb') as file:
    pickle.dump(all_zqdata, file)

In [None]:
# Encoding shapes (VQ-UNet first, plain UNet second) for image 0 under Gaussian
# noise at threshold 0.01. ``zqdata``/``zdata`` only exist inside the TTA helper
# functions, so read them from the collected results instead.
all_zqdata[0]['gauss'][0.01]['emb'].shape, all_zdata[0]['gauss'][0.01]['emb'].shape

In [None]:
# Side-by-side TTA embedding matrices at a few fixed latent positions: plain
# UNet on the left, VQ-UNet on the right. Each panel is ``emb[:, :, i, j, k]``,
# i.e. the first two axes of the 5-D embedding array (axis 0 = TTA draws;
# axis 1 presumably channels -- confirm).
# ``zdata``/``zqdata``/``noise_threshold_list`` are locals of the TTA helpers, so
# pick one image / noise type from the collected results explicitly.
zdata = all_zdata[0]['gauss']
zqdata = all_zqdata[0]['gauss']
noise_threshold_list = list(zdata.keys())
latent_positions = [(0, 0, 0), (1, 1, 1), (1, 0, 1), (1, 0, 0)]

for noise_threshold in noise_threshold_list:
    print ("============================ noise: {} ===========".format(noise_threshold))
    z, zq = zdata[noise_threshold]['emb'], zqdata[noise_threshold]['emb']

    for (i, j, k) in latent_positions:
        fig, (ax_unet, ax_vq) = plt.subplots(1, 2, figsize=(50, 50))
        ax_unet.imshow(z[:, :, i, j, k], cmap='coolwarm')
        ax_unet.set_title('UNet, position {}'.format((i, j, k)))
        ax_vq.imshow(zq[:, :, i, j, k], cmap='coolwarm')
        ax_vq.set_title('VQ-UNet, position {}'.format((i, j, k)))
        plt.show()


In [None]:
def mean_dice(p, l):
    """Mean Dice coefficient of prediction ``p`` against label map ``l``.

    For each label value ``i`` present in ``l`` (background included) this
    computes ``2 * |p==i AND l==i| / (|p==i| + |l==i| + 1e-3)`` from voxel
    counts, then averages over labels. The previous lambda applied the formula
    voxel-wise to boolean masks (where ``+`` acts as logical OR), so it did not
    produce a Dice score.
    """
    scores = []
    for i in np.unique(l):
        pred_mask = (p == i)
        true_mask = (l == i)
        intersection = np.logical_and(pred_mask, true_mask).sum()
        scores.append(2.0 * intersection / (pred_mask.sum() + true_mask.sum() + 1e-3))
    return np.mean(scores)
import pandas as pd

def _draw_(dict_info, noise_type='gauss', base_dir= '.'):
    """Plot latent-variance heatmaps and clean-vs-perturbed Dice scores for one noise type.

    Parameters
    ----------
    dict_info : dict
        Nested results ``dict_info[image_index][noise_type][noise_threshold]``, as
        built by the TTA driver cells. Each leaf holds ``'emb'`` (axis 0 = TTA
        draws), ``'recon'`` (one label map per TTA draw along axis 0) and
        ``'label'``. Image indices are assumed to be ``0 .. len(dict_info) - 1``.
    noise_type : str
        Noise family to plot (a key of ``dict_info[i]``).
    base_dir : str
        Output root; figures are written to ``<base_dir>/<noise_type>/plots``.

    Returns
    -------
    pandas.DataFrame
        One row per (non-zero threshold, image) pair. Columns: 'actual-dice'
        (clean, threshold 0.0), 'perturbed- dice', and 'categories'.

    Raises
    ------
    ValueError
        If threshold 0.0 (the clean baseline) is missing.
    """
    nimgs = len(dict_info)
    reference = dict_info[0][noise_type]
    noise_threshold_list = list(reference.keys())
    if 0.0 not in noise_threshold_list:
        raise ValueError('noise threshold 0.0 (clean baseline) is required')
    perturbed_thresholds = [t for t in noise_threshold_list if t != 0.0]

    test_embs = reference[noise_threshold_list[0]]['emb']  # nttas, ...
    nttas = test_embs.shape[0]
    z_dim = int(np.prod(test_embs.shape[1:]))

    # Build the output directory once. The original joined onto an undefined
    # ``base_path`` inside the loop (UnboundLocalError, and ever-deeper nesting
    # if it had been defined).
    plot_dir = os.path.join(base_dir, noise_type, 'plots')
    os.makedirs(plot_dir, exist_ok=True)

    clean_dice = []                                       # one entry per image
    perturbed_dice = {t: [] for t in perturbed_thresholds}  # per threshold, per image

    for noise_threshold in noise_threshold_list:
        result_img = np.zeros((nimgs, z_dim))

        for iimg in range(nimgs):
            record = dict_info[iimg][noise_type][noise_threshold]

            # Min-max normalise this image's embeddings, then take the per-unit
            # variance across TTA draws (constant embeddings -> zero variance
            # instead of a 0/0 NaN).
            z = record['emb'].reshape(nttas, -1)
            z_span = z.max() - z.min()
            z = (z - z.min()) / z_span if z_span > 0 else np.zeros_like(z, dtype=float)
            result_img[iimg, :] = np.var(z, axis=0)

            dice = np.mean([mean_dice(record['recon'][itta, ...], record['label'])
                            for itta in range(nttas)])
            if noise_threshold == 0.0:
                clean_dice.append(dice)
            else:
                perturbed_dice[noise_threshold].append(dice)

        fig = plt.figure(figsize=(1, 10))
        plt.imshow(result_img, cmap='coolwarm', vmin=0, vmax=1)
        plt.axis('off')
        plt.tight_layout()
        fig.savefig(os.path.join(plot_dir, 'NoiseT{}_NoiseType{}.png'.format(noise_threshold, noise_type)))
        plt.close(fig)

    # Pair each image's clean Dice with its Dice at every non-zero threshold.
    # (The original repeated the clean scores len-2 times instead of once per
    # non-zero threshold, so the DataFrame columns had mismatched lengths.)
    records = [
        {'actual-dice': clean_dice[iimg],
         'perturbed- dice': perturbed_dice[t][iimg],
         'categories': 'Noise Threshold: {}'.format(t)}
        for t in perturbed_thresholds
        for iimg in range(nimgs)
    ]
    df = pd.DataFrame(records, columns=['actual-dice', 'perturbed- dice', 'categories'])

    p = sns.jointplot(data=df, x='actual-dice', y='perturbed- dice',  hue="categories", alpha=0.5)
    p.plot_joint(sns.kdeplot, levels=20)
    p.fig.suptitle("aligned latent vectors")
    p.fig.tight_layout()
    p.fig.subplots_adjust(top=0.95)
    p.fig.savefig(os.path.join(plot_dir, 'dice_score_distribution_{}.png'.format(noise_type)))
    plt.close(p.fig)
    return df

In [None]:
# Plain-UNet TTA results written earlier in this notebook as
# 'all_data_PRST_UNet.pickle'. The original read an ABDCT results file and
# drew into 'PRST/VQNet', where the VQ plots later overwrote them.
# NOTE: pickle.load can execute arbitrary code -- only load trusted result files.
with open('all_data_PRST_UNet.pickle', 'rb') as file:
    dict_info = pickle.load(file)

for noise_type in ['gauss', 's&p', 'poisson', 'speckle']:
    _draw_(dict_info, noise_type, 'PRST/UNet')

In [None]:
# VQ-UNet TTA results written earlier in this notebook as
# 'all_data_PRST_VQNet.pickle'; the original read an ABDCT results file.
# NOTE: pickle.load can execute arbitrary code -- only load trusted result files.
with open('all_data_PRST_VQNet.pickle', 'rb') as file:
    dict_info = pickle.load(file)

for noise_type in ['gauss', 's&p', 'poisson', 'speckle']:
    _draw_(dict_info, noise_type, 'PRST/VQNet')