In [26]:
import numpy as np
from utils.data_utils import get_imdb_dataset
from utils.evaluator import binarize_label
import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F


In [2]:
# Dataset configuration handed to get_imdb_dataset: the data directory, the
# HDF5 file names for the train and test splits, and the display name of every
# label index (index 0 is the background class).
# NOTE(review): "Left Lateral ventricle" is cased differently from
# "Right Lateral Ventricle"; kept verbatim in case other code matches on it.
data_params = dict(
    data_dir="datasets/silver_corpus",
    train_data_file="Data_train.h5",
    train_label_file="Label_train.h5",
    train_class_weights_file="Class_Weight_train.h5",
    train_weights_file="Weight_train.h5",
    test_data_file="Data_test.h5",
    test_label_file="Label_test.h5",
    test_class_weights_file="Class_Weight_test.h5",
    test_weights_file="Weight_test.h5",
    labels=[
        "Background",
        "Left WM",
        "Left Cortex",
        "Left Lateral ventricle",
        "Left Inf LatVentricle",
        "Left Cerebellum WM",
        "Left Cerebellum Cortex",
        "Left Thalamus",
        "Left Caudate",
        "Left Putamen",
        "Left Pallidum",
        "3rd Ventricle",
        "4th Ventricle",
        "Brain Stem",
        "Left Hippocampus",
        "Left Amygdala",
        "CSF (Cranial)",
        "Left Accumbens",
        "Left Ventral DC",
        "Right WM",
        "Right Cortex",
        "Right Lateral Ventricle",
        "Right Inf LatVentricle",
        "Right Cerebellum WM",
        "Right Cerebellum Cortex",
        "Right Thalamus",
        "Right Caudate",
        "Right Putamen",
        "Right Pallidum",
        "Right Hippocampus",
        "Right Amygdala",
        "Right Accumbens",
        "Right Ventral DC",
    ],
)

def load_data(data_params):
    """Load the train and test splits described by ``data_params``.

    The size of each split is printed as a quick sanity check.

    Args:
        data_params: dataset configuration forwarded unchanged to
            ``get_imdb_dataset``.

    Returns:
        The ``(train, test)`` pair produced by ``get_imdb_dataset``.
    """
    print("Loading dataset")
    train_split, test_split = get_imdb_dataset(data_params)
    for split_name, split in (("Train", train_split), ("Test", test_split)):
        print("%s size: %i" % (split_name, len(split)))
    return train_split, test_split

# Load both splits once; later cells read train_data.X / train_data.y.
train_data, test_data = load_data(data_params=data_params)

Loading dataset
Train size: 2576
Test size: 304


In [283]:
# Choose the target structure and cross-validation fold, then draw one query
# slice and one *different* support slice that both contain that structure.
class_label = 7  # index into data_params['labels'] (7 -> "Left Thalamus")
fold = '2'
MIN_CLASS_PIXELS = 10  # a slice counts as containing the class above this many labelled pixels
SEED = 0  # fixed so Restart & Run All draws the same query/support pair every time

X, y = train_data.X, train_data.y
# Binary float mask of the target class (float so it can later be concatenated with the image).
y = (y == class_label).astype(np.float32)

# Keep only slices where the class is actually present. Flattening everything
# after the leading (slice) axis avoids hard-coding the mask's rank.
slice_with_class = y.reshape(len(y), -1).sum(axis=1) > MIN_CLASS_PIXELS
X = X[slice_with_class]
y = y[slice_with_class]
if len(X) < 2:
    raise ValueError("need at least 2 training slices containing class %d, found %d"
                     % (class_label, len(X)))

rng = np.random.default_rng(SEED)
# Sample without replacement: using the same slice as both support and query
# would make the prediction trivially easy and the visualisation misleading.
query_slice, support_slice = map(int, rng.choice(len(X), size=2, replace=False))


# Run the few-shot model (conditioner + segmentor) on the sampled pair.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): the checkpoint is a fully pickled nn.Module, so torch.load
# runs arbitrary pickled code -- only load trusted files. On torch >= 2.6,
# torch.load defaults to weights_only=True and will refuse this file unless
# weights_only=False is passed.
# map_location lets a checkpoint saved on a GPU be opened on a CPU-only machine.
no_skip_model = torch.load('saved_models/sne_position_all_type_spatial_fold' + fold + '.pth.tar',
                           map_location=device)
no_skip_model.to(device)
no_skip_model.eval()
# The skip-connection baseline lives at
# 'saved_models/sne_position_all_type_spatial_skipconn_baseline_fold' + fold + '.pth.tar'
# and can be evaluated the same way for comparison.

query_gt = torch.tensor(y[query_slice])
support_input = torch.tensor(X[support_slice])
# Give the support mask a leading channel axis so it stacks with the support image.
support_gt = torch.tensor(y[support_slice]).unsqueeze(0)

# Conditioner input = support image and support mask concatenated on dim 0
# (assumes X[i] is (1, H, W) and y[i] is (H, W) -- TODO confirm), then both
# network inputs get a leading batch dimension of 1. They stay in host memory
# so later cells can plot them; they are moved to `device` only for the forward pass.
condition_input = torch.cat((support_input, support_gt), dim=0).unsqueeze(0)
query_input = torch.tensor(X[query_slice]).unsqueeze(0)

# Pure inference: no autograd graph is needed.
with torch.no_grad():
    weights = no_skip_model.conditioner(condition_input.to(device))
    out = no_skip_model.segmentor(query_input.to(device), weights)

# softmax is monotonic, so argmax over the logits equals argmax over the
# probabilities. Move the label map back to the CPU so it can be plotted.
segmentation_no_chhapa = torch.argmax(out, dim=1).cpu()

# Side-by-side view: query image, its ground truth, the model's prediction,
# and the support image/mask the model was conditioned on.
# (A skip-connection baseline panel can be added by appending
#  (segmentation_chhapa, "Chhapa") to this list.)
panels = [
    (query_input, "Query Input"),
    (query_gt, "Query GT"),
    (segmentation_no_chhapa, "No Chhapa"),
    (support_input, "Support Input"),
    (support_gt, "Support GT"),
]

ncols = len(panels)
fig, ax = plt.subplots(nrows=1, ncols=ncols, figsize=(20, 10), squeeze=False)
for col, (image, title) in enumerate(panels):
    # imshow converts its input through numpy, which fails for GPU tensors
    # and for tensors that require grad, so detach and move to the CPU first.
    ax[0][col].imshow(torch.squeeze(image).detach().cpu(), cmap='gray', vmin=0, vmax=1)
    ax[0][col].set_title(title, fontsize=10, color="blue")
    ax[0][col].axis('off')

# Figure.set_tight_layout is deprecated (matplotlib >= 3.6); tight_layout() works on all versions.
fig.tight_layout()
plt.show()
