In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import random
import seaborn as sns

import sys
sys.path.append('../')

from vqseg1 import Net

In [None]:
# Build the `vqseg1` network and restore its trained weights for inference.
net = Net()

net_ckpt_path = '/vol/biomedic2/agk21/PhDLogs/codes/Vector-Quantisation-for-Robust-Segmentation/output4/disease1/vqseg1/version_1/checkpoints/epoch=289-step=6959.ckpt'

# map_location='cpu' deserialises the checkpoint tensors onto the host instead
# of whichever device they were saved from; load_state_dict() then copies the
# values into the model's existing parameters, so the model's own device
# placement is unaffected.
net.load_state_dict(torch.load(net_ckpt_path, map_location='cpu')['state_dict'])

# eval() switches layers such as dropout / batch-norm to inference behaviour.
# NOTE(review): later cells feed CUDA tensors to this model but nothing here
# moves it to the GPU -- confirm that `Net` handles its own device placement.
net = net.eval()

In [None]:
# Ask the model to set up its data, then fetch the validation loader it exposes.
# NOTE(review): `prepare_data` / `val_dataloader` look like PyTorch-Lightning
# style hooks defined on `Net` -- what they download or build is not visible
# from this notebook; confirm against the `vqseg1` module.
net.prepare_data()
val_dataloader = net.val_dataloader()

In [None]:
# Grab exactly one batch (the first) from the validation loader; the cells
# below reuse `data` as the clean input for the noise experiments.
#
# A `for i, data in enumerate(...)` loop that breaks on `i > 0` is not
# equivalent: the loop target is re-bound *before* the break test runs, so
# `data` would end up holding the second batch (while the shapes printed were
# those of the first), and an extra batch would be loaded for nothing.
data = next(iter(val_dataloader))

print (data['image'].shape, data['label'].shape)

In [None]:
import pickle

# Noise-robustness probe for `net`: for every noise level, run `nTTAs` forward
# passes on the same sample corrupted with fresh additive Gaussian noise and
# collect both outputs of the model.
#
# Result: zdata[noise_level] -> {'emb', 'recon', 'img', 'label'} (numpy arrays),
# where 'emb' / 'recon' are the per-pass model outputs concatenated on axis 0.
#
# NOTE(review): no RNG seed is set in the visible cells, so the sampled noise
# (and therefore the stored arrays) differ from run to run.

nTTAs = 100  # number of noisy forward passes per noise level

zdata = {}

noise_threshold_list = [0.01, 0.1, 0.5, 0.75, 0.9]  # std-dev multipliers of the added noise

# Loop-invariant inputs: slice the first sample once and upload it once,
# instead of re-slicing and re-copying it on every single pass.
clean_image = data['image'][:1, ...]
clean_label = data['label'][:1, ...]
clean_image_gpu = clean_image.cuda()

# Inference only: without no_grad() every pass keeps its autograd graph alive
# on the GPU until the results are detached, which grows memory with nTTAs.
with torch.no_grad():
    for noise_threshold in noise_threshold_list:
        emb_batches = []
        recon_batches = []

        for _ in range(nTTAs):
            # Draw the noise directly on the GPU (no host-side tensor + copy).
            noise = torch.randn(clean_image.shape, device=clean_image_gpu.device)
            x = clean_image_gpu + noise_threshold * noise
            # Call the module rather than .forward() so nn.Module hooks run.
            x, encoding = net(x)
            # Move results to host memory right away to keep GPU usage flat.
            emb_batches.append(encoding.cpu())
            recon_batches.append(x.cpu())

        # `embs` / `recons` stay defined after the loop (last noise level);
        # a later cell inspects `embs.shape`.
        embs = torch.cat(emb_batches, 0)
        recons = torch.cat(recon_batches, 0)

        zdata[noise_threshold] = {'emb': embs.numpy(),
                                  'recon': recons.numpy(),
                                  'img': clean_image.numpy(),
                                  'label': clean_label.numpy()}


# Written to its own file: another cell of this notebook dumps `zqdata` to
# 'data.pickle', which would silently overwrite this result on disk.
with open('zdata.pickle', 'wb') as file:
    pickle.dump(zdata, file)

In [None]:
# Sanity check: shape of `embs` as last assigned by the noise loop above,
# i.e. the concatenated encodings for the final noise level.
embs.shape

In [None]:
from vqseg import Net as QNet

# Build the `vqseg` network (aliased QNet to avoid clashing with the `Net`
# imported from `vqseg1`) and restore its trained weights for inference.
qnet = QNet()

qnet_ckpt_path = '/vol/biomedic2/agk21/PhDLogs/codes/Vector-Quantisation-for-Robust-Segmentation/output4/disease1/vqseg/version_1/checkpoints/epoch=294-step=7079.ckpt'

# map_location='cpu' deserialises the checkpoint tensors onto the host instead
# of whichever device they were saved from; load_state_dict() then copies the
# values into the model's existing parameters, so the model's own device
# placement is unaffected.
qnet.load_state_dict(torch.load(qnet_ckpt_path, map_location='cpu')['state_dict'])

# eval() switches layers such as dropout / batch-norm to inference behaviour.
# NOTE(review): later cells feed CUDA tensors to this model but nothing here
# moves it to the GPU -- confirm that the model handles its own device placement.
qnet = qnet.eval()

In [None]:
import pickle

# Noise-robustness probe for `qnet`: for every noise level, run `nTTAs` forward
# passes on the same sample corrupted with fresh additive Gaussian noise and
# collect both outputs of the model.
#
# Result: zqdata[noise_level] -> {'emb', 'recon', 'img', 'label'} (numpy arrays),
# where 'emb' / 'recon' are the per-pass model outputs concatenated on axis 0.
#
# NOTE(review): no RNG seed is set in the visible cells, so the sampled noise
# (and therefore the stored arrays) differ from run to run.

nTTAs = 100  # number of noisy forward passes per noise level

zqdata = {}

noise_threshold_list = [0.01, 0.1, 0.5, 0.75, 0.9]  # std-dev multipliers of the added noise

# Loop-invariant inputs: slice the first sample once and upload it once,
# instead of re-slicing and re-copying it on every single pass.
clean_image = data['image'][:1, ...]
clean_label = data['label'][:1, ...]
clean_image_gpu = clean_image.cuda()

# Inference only: without no_grad() every pass keeps its autograd graph alive
# on the GPU until the results are detached, which grows memory with nTTAs.
with torch.no_grad():
    for noise_threshold in noise_threshold_list:
        emb_batches = []
        recon_batches = []

        for _ in range(nTTAs):
            # Draw the noise directly on the GPU (no host-side tensor + copy).
            noise = torch.randn(clean_image.shape, device=clean_image_gpu.device)
            x = clean_image_gpu + noise_threshold * noise
            # Call the module rather than .forward() so nn.Module hooks run.
            x, encoding = qnet(x)
            # Move results to host memory right away to keep GPU usage flat.
            emb_batches.append(encoding.cpu())
            recon_batches.append(x.cpu())

        # `embs` / `recons` stay defined after the loop (last noise level).
        embs = torch.cat(emb_batches, 0)
        recons = torch.cat(recon_batches, 0)

        zqdata[noise_threshold] = {'emb': embs.numpy(),
                                   'recon': recons.numpy(),
                                   'img': clean_image.numpy(),
                                   'label': clean_label.numpy()}


with open('data.pickle', 'wb') as file:
    pickle.dump(zqdata, file)

In [None]:
# Sanity check: show the stored embedding shapes of both models at the lowest
# noise level (0.01) side by side -- `zqdata` comes from `qnet`, `zdata` from `net`.
zqdata[0.01]['emb'].shape, zdata[0.01]['emb'].shape

In [None]:
# For every noise level, draw heatmaps of the stored embeddings of both models
# side by side (left: `zdata`, right: `zqdata`).  Each heatmap keeps the first
# two axes of the embedding array and fixes the remaining three indices.
# Axis 0 holds the repeated noisy passes (the arrays were concatenated on
# axis 0); axis 1 is presumably the channel axis -- TODO confirm.

# One figure is produced per trailing index triple, in this order.
slice_positions = [(0, 0, 0), (1, 1, 1), (1, 0, 1), (1, 0, 0)]

for noise_threshold in noise_threshold_list:
    print ("============================ noise: {} ===========".format(noise_threshold))
    z = zdata[noise_threshold]['emb']
    zq = zqdata[noise_threshold]['emb']

    for d0, d1, d2 in slice_positions:
        plt.figure(figsize=(50, 50))
        # Column 1 shows `z`, column 2 shows `zq`, at the same position.
        for column, emb in enumerate((z, zq), start=1):
            plt.subplot(1, 2, column)
            plt.imshow(emb[:, :, d0, d1, d2], cmap='coolwarm')
        plt.show()


In [None]:
# For every noise level, plot how much each embedding entry varies across the
# repeated noisy passes: variance over axis 0, divided by the per-entry maximum
# of `z`.  Top row: `zdata`; bottom row: `zqdata`.  Both rows are divided by
# the same `maxZ` (taken from `z`) and share the fixed colour range [0, 1].
#
# NOTE(review): the reshape hard-codes 100 and 256 -- 100 presumably equals
# `nTTAs` and 256 the embedding channel count; confirm before changing either.
# NOTE(review): `maxZ` can be zero or negative, in which case the ratio is
# inf/nan or negative and falls outside the displayed range -- verify intent.
for noise_threshold in noise_threshold_list:
    print ("============================ noise: {} ===========".format(noise_threshold))
    z = zdata[noise_threshold]['emb'].reshape(100, 256, -1)
    zq = zqdata[noise_threshold]['emb'].reshape(100, 256, -1)

    # Per-entry maximum over the passes, transposed to match the plotted layout.
    maxZ = np.max(z, axis=0).T

    plt.figure(figsize=(50, 5))
    for row, samples in enumerate((z, zq), start=1):
        plt.subplot(2, 1, row)
        normalised_var = np.var(samples, axis=0).T/maxZ
        plt.imshow(normalised_var, cmap='coolwarm', vmin=0, vmax=1)
    plt.show()