In [1]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import random
import seaborn as sns
import os
import sys
sys.path.append('../')

from models import VQUNet, VQtransUNet
from trainer import Train
from skimage.util import random_noise

In [2]:
# NOTE(review): `Net` is not imported or defined in any visible cell (cell 1 imports
# VQUNet, VQtransUNet and Train), so this cell fails on a fresh kernel unless `Net`
# leaked from a deleted cell. TODO import the class that produced this checkpoint.
UNET_CKPT_PATH = '/vol/biomedic2/agk21/PhDLogs/codes/Vector-Quantisation-for-Robust-Segmentation/Logs/outputchestsegfinal/outputchestsegfinal/unetNIH/best_metric_model-v1.ckpt'

net = Net(num_classes=3)
# Deserialize onto the CPU first so the checkpoint opens even if it was saved from a
# GPU index that is not visible here; the model is moved to the GPU explicitly below.
checkpoint = torch.load(UNET_CKPT_PATH, map_location='cpu')
net.load_state_dict(checkpoint['state_dict'])
net = net.cuda().eval()

In [3]:
from dataloaders import ChestXrayDataModule
# Directory holding the train/validation/test CSV split files.
# NOTE(review): this is an absolute path; confirm it matches where the data lives.
CHEST_CSV_DIR = '/VQRobustSegmentation/data/Chest'

loader = ChestXrayDataModule(
    batch_size=1,
    num_workers=4,
    binarise=False,
    csv_train_img=os.path.join(CHEST_CSV_DIR, 'train.csv'),
    csv_val_img=os.path.join(CHEST_CSV_DIR, 'validation.csv'),
    csv_test_img=os.path.join(CHEST_CSV_DIR, 'test.csv'),
)

val_dataloader = loader.val_dataloader()

In [4]:
def noisy(noise_typ, image, weightage=0.0):
    """Return a noise-corrupted copy of ``image``.

    Parameters
    ----------
    noise_typ : str
        One of ``"gauss"``, ``"s&p"``, ``"poisson"`` or ``"speckle"``.
    image : np.ndarray
        Input array; it is never modified. The ``"s&p"`` mode writes into
        ``out[0, 0, row, col]``, so it needs at least 4 dimensions (presumably
        ``(N, C, H, W)``) and only corrupts the first sample / first channel.
        The other modes work for any shape.
    weightage : float
        Noise strength. For ``"gauss"``, ``"poisson"`` and ``"speckle"`` it scales
        the noise term. For ``"s&p"`` it is the fraction of the last-two-axes plane
        that is corrupted (half salt, half pepper). ``0.0`` leaves values unchanged
        in every mode.

    Returns
    -------
    np.ndarray
        Noisy array with the same shape as ``image``.

    Raises
    ------
    ValueError
        If ``noise_typ`` is not one of the supported modes.
    """
    if noise_typ == "gauss":
        # Zero-mean, unit-variance Gaussian noise scaled by `weightage`.
        gauss = np.random.normal(0, 1.0, image.shape)
        return image + weightage * gauss

    if noise_typ == "s&p":
        s_vs_p = 0.5  # share of corrupted pixels set to 1 (salt) vs 0 (pepper)
        out = np.copy(image)
        height, width = image.shape[-2:]
        # Count only the plane that is actually corrupted, not every channel/sample.
        n_pixels = height * width
        # No "+1" on the counts: weightage == 0.0 must leave the image untouched,
        # since threshold 0.0 is used as the clean baseline.
        num_salt = int(np.ceil(weightage * n_pixels * s_vs_p))
        # randint's upper bound is exclusive, so passing height/width (not "- 1")
        # keeps the last row and column reachable.
        rows = np.random.randint(0, height, num_salt)
        cols = np.random.randint(0, width, num_salt)
        out[0, 0, rows, cols] = 1

        num_pepper = int(np.ceil(weightage * n_pixels * (1.0 - s_vs_p)))
        rows = np.random.randint(0, height, num_pepper)
        cols = np.random.randint(0, width, num_pepper)
        out[0, 0, rows, cols] = 0
        return out

    if noise_typ == "poisson":
        # Gaussian noise whose standard deviation scales with sqrt(|intensity|).
        poisson_noise = np.sqrt(np.abs(image)) * np.random.normal(0, 1, image.shape)
        return image + weightage * poisson_noise

    if noise_typ == "speckle":
        # Multiplicative noise: image * (1 + weightage * N(0, 1)).
        gauss = np.random.randn(*image.shape)
        return image + weightage * image * gauss

    raise ValueError("unknown noise_typ: {!r}".format(noise_typ))
    

In [5]:
import pickle


def generate_TTA_results(name, data, noise_type='gauss', base_path='.',
                         model=None, n_ttas=25,
                         noise_threshold_list=(0.0, 0.01, 0.1, 0.5, 0.75)):
    """Run noisy test-time augmentation (TTA) through the UNet and pickle the results.

    For each noise threshold, the first sample of ``data[0]`` is corrupted with
    ``noisy(noise_type, ..., noise_threshold)`` ``n_ttas`` times. Each corrupted
    copy is pushed through ``model._model``, and the encodings and argmax
    predictions are collected.

    Parameters
    ----------
    name : str
        File stem; results are written to ``<base_path>/<noise_type>/<name>.pickle``.
    data : sequence
        A dataloader batch: ``data[0]`` images, ``data[1]`` labels. They must be CPU
        tensors because ``.numpy()`` is called on them. Only the first sample is used.
    noise_type : str
        One of ``'gauss'``, ``'s&p'``, ``'poisson'``, ``'speckle'``.
    base_path : str
        Root output directory.
    model : optional
        Object exposing ``_model(x) -> (scores, encoding)``; defaults to the
        notebook-global ``net``.
    n_ttas : int
        Number of noisy forward passes per threshold.
    noise_threshold_list : iterable of float
        Noise strengths passed to ``noisy`` as ``weightage``.

    Returns
    -------
    dict
        ``{noise_threshold: {'emb', 'recon', 'img', 'label'}}`` of numpy arrays.
    """
    assert noise_type in ['gauss', 's&p', 'poisson', 'speckle'], 'invalid noise_type'
    if model is None:
        model = net

    # Loop-invariant: the clean input never changes across thresholds or TTA passes.
    image = data[0][:1, ...]
    image_np = image.numpy()
    dtype = image.dtype

    zdata = {}
    for noise_threshold in noise_threshold_list:
        embs, recons = [], []
        for _ in range(n_ttas):
            x = noisy(noise_type, image_np, noise_threshold)
            x = torch.tensor(x).type(dtype).cuda()
            with torch.no_grad():
                scores, encoding = model._model(x)
            embs.append(encoding)
            recons.append(scores)

        embs = torch.cat(embs, 0)
        # Class index per pixel, taken over dim 1 of the model output.
        recons = torch.argmax(torch.cat(recons, 0), 1)

        zdata[noise_threshold] = {'emb': embs.cpu().numpy(),
                                  'recon': recons.cpu().squeeze().numpy(),
                                  'img': data[0][:1, ...].squeeze().numpy(),
                                  'label': data[1][:1, ...].squeeze().numpy()}

    out_dir = os.path.join(base_path, noise_type)
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, '{}.pickle'.format(name)), 'wb') as file:
        pickle.dump(zdata, file)
    return zdata

In [None]:
from itertools import islice

from tqdm import tqdm

N_IMAGES = 100  # number of validation batches to evaluate

all_zdata = {}
# Use a single loader iterator. Calling next(iter(val_dataloader)) inside the loop builds
# a fresh iterator (and, with num_workers > 0, fresh worker processes) on every pass and
# always returns that iterator's first batch -- the same image each time unless the
# loader shuffles.
for i, data in enumerate(tqdm(islice(val_dataloader, N_IMAGES), total=N_IMAGES)):
    all_zdata[i] = {
        noise_type: generate_TTA_results('image{}'.format(i), data, noise_type, 'CXray/UNet')
        for noise_type in ('gauss', 's&p', 'poisson', 'speckle')
    }

with open('all_data_CXray_UNet.pickle', 'wb') as file:
    pickle.dump(all_zdata, file)

 23%|████████████████████████████████████▎                                                                                                                         | 23/100 [25:25<1:33:33, 72.91s/it]

In [None]:
# Sanity check: the device holding the last batch's images vs. the device reported by the model.
(data[0].device, net.device)

In [None]:
# Shell out (IPython `!` escape) to nvidia-smi to inspect the GPUs and their current usage.
!nvidia-smi

In [None]:
# ---- VQ-UNet test-time-augmentation experiments ----


In [None]:
import pickle

def VQgenerate_TTA_results(name, data, noise_type='gauss', base_path='.',
                           model=None, n_ttas=25,
                           noise_threshold_list=(0.0, 0.01, 0.1, 0.5, 0.75)):
    """Run noisy test-time augmentation through the VQ model, pickle and return the results.

    For each noise threshold, the first sample of ``data[0]`` is corrupted with
    ``noisy(noise_type, ..., noise_threshold)`` ``n_ttas`` times and pushed through
    ``model._model``. That call returns ``(scores, emb_loss, encoding, info)``.
    The encodings, per-pass ``emb_loss`` values and argmax predictions are collected.

    Parameters
    ----------
    name : str
        File stem; results are written to ``<base_path>/<noise_type>/<name>.pickle``.
    data : sequence
        A dataloader batch: ``data[0]`` images, ``data[1]`` labels. They must be CPU
        tensors because ``.numpy()`` is called on them. Only the first sample is used.
    noise_type : str
        One of ``'gauss'``, ``'s&p'``, ``'poisson'``, ``'speckle'``.
    base_path : str
        Root output directory.
    model : optional
        Object exposing the 4-tuple ``_model`` described above; defaults to the
        notebook-global ``qnet``.
    n_ttas : int
        Number of noisy forward passes per threshold.
    noise_threshold_list : iterable of float
        Noise strengths passed to ``noisy`` as ``weightage``.

    Returns
    -------
    dict
        ``{noise_threshold: {'emb', 'emb_loss', 'recon', 'img', 'label'}}`` of numpy
        arrays. The function previously returned ``None``, so callers stored nothing.
    """
    assert noise_type in ['gauss', 's&p', 'poisson', 'speckle'], 'invalid noise_type'
    if model is None:
        # NOTE(review): `qnet` is not created in any visible cell of this notebook;
        # TODO add the cell that loads the VQ model before this function is used.
        model = qnet

    # Loop-invariant: the clean input never changes across thresholds or TTA passes.
    image = data[0][:1, ...]
    image_np = image.numpy()
    dtype = image.dtype

    zdata = {}
    for noise_threshold in noise_threshold_list:
        embs, recons, emb_losses = [], [], []
        for _ in range(n_ttas):
            x = noisy(noise_type, image_np, noise_threshold)
            x = torch.tensor(x).type(dtype).cuda()
            with torch.no_grad():
                scores, emb_loss, encoding, _info = model._model(x)
            embs.append(encoding)
            recons.append(scores)
            # emb_loss is presumably a scalar tensor; give it a leading axis so the
            # passes can be concatenated.
            emb_losses.append(emb_loss.unsqueeze(0))

        embs = torch.cat(embs, 0)
        emb_losses = torch.cat(emb_losses, 0)
        # Class index per pixel, taken over dim 1 of the model output.
        recons = torch.argmax(torch.cat(recons, 0), 1)

        # Tensors were produced under no_grad, so no .detach() is needed before .numpy().
        zdata[noise_threshold] = {'emb': embs.cpu().numpy(),
                                  'emb_loss': emb_losses.cpu().numpy(),
                                  'recon': recons.cpu().squeeze().numpy(),
                                  'img': data[0][:1, ...].squeeze().numpy(),
                                  'label': data[1][:1, ...].squeeze().numpy()}

    out_dir = os.path.join(base_path, noise_type)
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, '{}.pickle'.format(name)), 'wb') as file:
        pickle.dump(zdata, file)
    return zdata

In [None]:
from itertools import islice

from tqdm import tqdm

N_IMAGES = 100  # number of validation batches to evaluate

all_zqdata = {}
# Use a single loader iterator. Calling next(iter(val_dataloader)) inside the loop builds
# a fresh iterator (and, with num_workers > 0, fresh worker processes) on every pass and
# always returns that iterator's first batch -- the same image each time unless the
# loader shuffles.
for i, data in enumerate(tqdm(islice(val_dataloader, N_IMAGES), total=N_IMAGES)):
    all_zqdata[i] = {
        noise_type: VQgenerate_TTA_results('image{}'.format(i), data, noise_type, 'CXray/VQNet')
        for noise_type in ('gauss', 's&p', 'poisson', 'speckle')
    }

with open('all_data_CXray_VQNet.pickle', 'wb') as file:
    pickle.dump(all_zqdata, file)

In [None]:
def mean_dice(p, l):
    """Mean per-class Dice coefficient between a predicted and a ground-truth label map.

    For every class value ``c`` present in ``l`` (background included), Dice is
    ``2 * |(p == c) & (l == c)| / (|p == c| + |l == c|)``. The per-class scores are
    averaged.

    Parameters
    ----------
    p : array-like
        Predicted label map.
    l : array-like
        Ground-truth label map, same shape as ``p``.

    Returns
    -------
    float
        Mean Dice over the classes in ``l`` (``nan`` if ``l`` is empty).
    """
    p = np.asarray(p)
    l = np.asarray(l)
    scores = []
    for c in np.unique(l):
        pred_c = p == c
        true_c = l == c
        # Counts must be summed over pixels. The element-wise form averaged per-pixel
        # values instead, which is not a Dice score. true_c is non-empty because c
        # comes from np.unique(l), so the denominator is always > 0.
        overlap = np.logical_and(pred_c, true_c).sum()
        scores.append(2.0 * overlap / (pred_c.sum() + true_c.sum()))
    return np.mean(scores)
import pandas as pd

def _draw_(dict_info, noise_type='gauss', base_dir='.'):
    """Plot embedding-variance heatmaps and a clean-vs-perturbed Dice joint plot.

    For every noise threshold, each image's TTA embeddings are flattened to
    ``(n_ttas, z_dim)`` and min-max normalised. Their per-dimension variance across
    TTA passes becomes one row of an ``(n_images, z_dim)`` heatmap, saved as a PNG.
    Separately, each image's TTA-averaged Dice at every non-zero threshold is paired
    with the same image's Dice at threshold 0.0 and drawn as a seaborn joint plot.

    Parameters
    ----------
    dict_info : dict
        ``{image_index: {noise_type: {noise_threshold: record}}}`` with image indices
        ``0 .. len(dict_info) - 1``. Each ``record`` holds ``'emb'`` and ``'recon'``
        (TTA passes along axis 0) and ``'label'``.
    noise_type : str
        Noise-family key to plot.
    base_dir : str
        Root output directory; figures go to ``<base_dir>/<noise_type>/plots``.

    Raises
    ------
    ValueError
        If threshold 0.0 (the clean baseline) is missing.
    """
    nimgs = len(dict_info)
    first_record = dict_info[0][noise_type]
    noise_threshold_list = list(first_record.keys())
    if 0.0 not in noise_threshold_list:
        raise ValueError('noise threshold 0.0 (clean baseline) is required')

    first_embs = first_record[noise_threshold_list[0]]['emb']  # TTA passes on axis 0
    nttas = first_embs.shape[0]
    z_dim = int(np.prod(first_embs.shape[1:]))

    # Build the output directory once, from the `base_dir` parameter.
    plot_dir = os.path.join(base_dir, noise_type, 'plots')
    os.makedirs(plot_dir, exist_ok=True)

    baseline_dice = {}  # image index -> TTA-mean Dice at threshold 0.0
    perturbed = []      # (image index, threshold, TTA-mean Dice) for thresholds != 0.0

    for noise_threshold in noise_threshold_list:
        result_img = np.zeros((nimgs, z_dim))

        for iimg in range(nimgs):
            record = dict_info[iimg][noise_type][noise_threshold]

            z = record['emb'].reshape(nttas, -1)
            z_min = z.min()
            z_range = z.max() - z_min
            # Guard the min-max normalisation against a constant embedding (zero range).
            z = (z - z_min) / z_range if z_range > 0 else np.zeros(z.shape)
            result_img[iimg, :] = np.var(z, axis=0)

            dice = np.mean([mean_dice(record['recon'][itta, ...], record['label'])
                            for itta in range(nttas)])
            if noise_threshold == 0.0:
                baseline_dice[iimg] = dice
            else:
                perturbed.append((iimg, noise_threshold, dice))

        fig = plt.figure(figsize=(1, 10))
        plt.imshow(result_img, cmap='coolwarm', vmin=0, vmax=1)
        plt.axis('off')
        plt.tight_layout()
        fig.savefig(os.path.join(plot_dir, 'NoiseT{}_NoiseType{}.png'.format(noise_threshold, noise_type)))
        # Close each heatmap so repeated calls do not pile up open figures.
        plt.close(fig)

    # Pair each perturbed score with the same image's clean score, so every column has
    # one row per (image, non-zero threshold) regardless of how many thresholds exist.
    df = pd.DataFrame({
        'actual-dice': [baseline_dice[iimg] for iimg, _, _ in perturbed],
        'perturbed- dice': [dice for _, _, dice in perturbed],
        'categories': ['Noise Threshold: {}'.format(t) for _, t, _ in perturbed],
    })

    p = sns.jointplot(data=df, x='actual-dice', y='perturbed- dice', hue="categories", alpha=0.5)
    p.plot_joint(sns.kdeplot, levels=20)
    p.fig.suptitle("aligned latent vectors")
    p.fig.tight_layout()
    p.fig.subplots_adjust(top=0.95)
    p.fig.savefig(os.path.join(plot_dir, 'dice_score_distribution_{}.png'.format(noise_type)))

In [None]:
# Draw the UNet results into CXray/UNet. Writing them to CXray/VQNet would overwrite
# the identically named VQ-UNet figures.
with open('all_data_CXray_UNet.pickle', 'rb') as file:
    dict_info = pickle.load(file)

for noise_type in ('gauss', 's&p', 'poisson', 'speckle'):
    _draw_(dict_info, noise_type, 'CXray/UNet')

In [None]:
# Load the pickled VQ-UNet TTA results and draw the plots for each noise family.
with open('all_data_CXray_VQNet.pickle', 'rb') as file:
    dict_info = pickle.load(file)

for noise_type in ('gauss', 's&p', 'poisson', 'speckle'):
    _draw_(dict_info, noise_type, 'CXray/VQNet')