In [3]:
import argparse
from timeit import default_timer as timer
from tqdm import tqdm
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
matplotlib.use('Agg')

import torch
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader, random_split
from torchvision.utils import save_image
import torchvision.transforms as tr

from data import BraTSDatasetUnet, TestDataset, BraTSDatasetLSTM
from losses import DICELossMultiClass
from models import UNet


In [93]:
# Build the network and the loss object that later cells pass around.
# NOTE(review): num_channels=1 / num_classes=2 presumably mean single-channel
# input and a two-class output -- confirm against models.UNet.
model = UNet(num_channels=1, num_classes=2)
# Handed to test_only(), which accepts it but does not use it during inference.
criterion = DICELossMultiClass()

In [94]:
from PIL import Image
import os

dir_path = './Data/images/'
# Original PIL size (width, height) of every file in dir_path, stored in
# os.listdir order. test_only() later indexes this list by running image
# number to resize each prediction back to its source resolution.
# NOTE(review): os.listdir order is filesystem-dependent; this only lines up
# with the loader if TestDataset enumerates the directory the same way -- confirm.
sizes = []
for f in os.listdir(dir_path):
    # Only the header is needed for .size; the context manager closes the
    # underlying file handle instead of leaking one per image.
    with Image.open(os.path.join(dir_path, f)) as im:
        print(im.size)
        sizes.append(im.size)
print(sizes)

(424, 548)
(424, 550)
(424, 550)
(408, 552)
(412, 559)
(640, 833)
(424, 549)
(420, 549)
(420, 549)
(432, 550)
(432, 550)
(424, 549)
(636, 829)
[(424, 548), (424, 550), (424, 550), (408, 552), (412, 559), (640, 833), (424, 549), (420, 549), (420, 549), (432, 550), (432, 550), (424, 549), (636, 829)]


In [125]:
def test_only(model, loader, criterion, cuda=False):
    """Run inference over ``loader`` and write the predictions to disk.

    For every batch the per-pixel class index (argmax over dim 1 of the model
    output) is computed. Each prediction is then resized to the matching entry
    of the module-level ``sizes`` list and saved individually under
    ``./output/single/``; the input batch and the prediction batch are also
    saved as image grids under ``./output/images/`` and
    ``./output/predictions/``.

    Args:
        model: network to evaluate; it is switched to eval mode here.
        loader: iterable yielding image batches (tensors).
        criterion: accepted for signature compatibility; not used.
        cuda: when True, each input batch is moved to the GPU. The model
            itself is not moved here.

    Raises:
        IndexError: if the loader yields more images than ``sizes`` has entries.

    NOTE(review): assumes ``sizes`` is in the same order as the images the
    loader yields (requires shuffle=False) -- confirm against TestDataset.
    """
    import os  # local import: keeps the function independent of earlier cells

    for out_dir in ('./output/single', './output/images', './output/predictions'):
        os.makedirs(out_dir, exist_ok=True)

    model.eval()
    # Running image counter; unlike ``batch_idx * 4`` this stays correct for
    # any batch size and for a short final batch.
    index = 0
    # Wrapping the loader (not the enumerate object) lets tqdm show a total.
    for batch_idx, image in enumerate(tqdm(loader)):
        if cuda:
            image = image.cuda()
        with torch.no_grad():
            pred = model(image)
            # Index of the highest-scoring class per pixel; keepdim keeps a
            # singleton channel axis so each item can be turned into an image.
            _, out = torch.max(pred, 1, keepdim=True)

        for output in out:
            # ToPILImage scales float tensors by 255, so class 1 becomes white.
            back_img = tr.ToPILImage()(output.type(torch.float32))
            back_img = back_img.resize(sizes[index])
            back_img.save('./output/single/out-{}.png'.format(index))
            index += 1
        save_image(image, './output/images/images-batch-{}.png'.format(batch_idx))
        # Cast to float: save_image does in-place float arithmetic that newer
        # torchvision versions reject on an int64 tensor; pixel values are the same.
        save_image(out.float(), './output/predictions/outputs-batch-{}.png'.format(batch_idx))

In [126]:
# Inference-only run: build the test loader, restore the trained weights on
# the CPU, and write the predictions to ./output via test_only().
test_dataset = TestDataset(
    './Data/',
    transform=tr.ToTensor(),
    im_size=[256, 256],
)
# shuffle=False keeps the batches in dataset order.
test_loader = DataLoader(
    test_dataset,
    num_workers=1,
    shuffle=False,
    batch_size=4,
)
print("Test Data : ", len(test_dataset))
# NOTE(review): torch.load unpickles the checkpoint -- only load files from a
# trusted source.
model.load_state_dict(
    torch.load('unet-model-16-100-0.001', map_location='cpu')
)
test_only(model, test_loader, criterion)

Test Data :  13
























0it [00:00, ?it/s]

[0. 1.]
[0. 1.]
[0. 1.]
[0. 1.]
























1it [00:16, 16.06s/it]

[0. 1.]
[0. 1.]
[0. 1.]
[0. 1.]
























2it [00:29, 15.22s/it]

[0. 1.]
[0. 1.]
[0. 1.]
[0. 1.]
























3it [00:42, 14.59s/it]

[0. 1.]
























4it [00:46, 11.44s/it]

In [5]:
import os
import itertools

In [6]:
# One entry per saved batch, appended in batch-index order.
images_list, masks_list, output_list = [], [], []

# Batches are numbered from 0. The "-images" file acts as the sentinel:
# loading stops at the first index for which it does not exist.
batch_no = 0
while os.path.isfile('./npy-files/out-files/OutMasks-batch-{}-images.npy'.format(batch_no)):
    outs = np.load('./npy-files/out-files/OutMasks-batch-{}-outs.npy'.format(batch_no))
    masks = np.load('./npy-files/out-files/OutMasks-batch-{}-masks.npy'.format(batch_no))
    images = np.load('./npy-files/out-files/OutMasks-batch-{}-images.npy'.format(batch_no))
    # Append only after all three arrays of the batch loaded successfully.
    output_list.append(outs)
    masks_list.append(masks)
    images_list.append(images)
    batch_no += 1

In [7]:
# Shape of the first prediction in the first loaded batch.
output_list[0][0].shape

(1, 128, 128)

In [8]:
# Raw values of the first prediction in the first loaded batch.
output_list[0][0]

array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=int64)

In [9]:
# Sanity check: indexing with [0, :, :] selects the same sub-array as [0],
# so the element-wise comparison is True everywhere.
output_list[0][0, :, :] == output_list[0][0]

array([[[ True,  True,  True, ...,  True,  True,  True],
        [ True,  True,  True, ...,  True,  True,  True],
        [ True,  True,  True, ...,  True,  True,  True],
        ...,
        [ True,  True,  True, ...,  True,  True,  True],
        [ True,  True,  True, ...,  True,  True,  True],
        [ True,  True,  True, ...,  True,  True,  True]]])

In [None]:
#colorimg = np.ones((masks.shape[1], masks.shape[2], 3), dtype=np.float32) * 255
#colors = np.asarray([(255, 0, 0), (0, 255, 0), (0, 0, 255)])

In [None]:
def _save_chw_array(chw, path):
    """Save a channel-first ``uint8`` array as an image file at ``path``.

    A single-channel array is written as 8-bit grayscale ('L'). Anything else
    is handed to PIL as 'RGB', which requires exactly three channels.
    """
    hwc = np.ascontiguousarray(np.transpose(chw, (1, 2, 0)))
    if hwc.shape[2] == 1:
        # PIL rejects an (H, W, 1) array for mode 'L' ("Too many dimensions"),
        # so drop the singleton channel axis first.
        Image.fromarray(hwc[:, :, 0], 'L').save(path)
    else:
        Image.fromarray(hwc, 'RGB').save(path)


# The PNG writers do not create missing directories.
os.makedirs('./output', exist_ok=True)

# Write every loaded image, mask and prediction as ./output/<kind>-<batch>-<item>.png.
for i, (batch_images, batch_masks, batch_outs) in enumerate(
        zip(images_list, masks_list, output_list)):
    for j in range(batch_images.shape[0]):
        # NOTE(review): the * 255 scaling assumes values in [0, 1] -- confirm
        # against the code that wrote the .npy files.
        np_image = (batch_images[j] * 255).astype(np.uint8)
        _save_chw_array(np_image, './output/image-{}-{}.png'.format(i, j))

        np_mask = (batch_masks[j] * 255).astype(np.uint8)
        _save_chw_array(np_mask, './output/mask-{}-{}.png'.format(i, j))

        # Predictions hold class indices; squeeze the channel axis and map
        # class 1 to white.
        np_out = 255 * np.squeeze(batch_outs[j]).astype('uint8')
        out_img = Image.fromarray(np_out, 'L')
        out_img.save('./output/out-{}-{}.png'.format(i, j))