In [35]:
import argparse
from timeit import default_timer as timer
from tqdm import tqdm
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
matplotlib.use('Agg')

import torch
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader, random_split
from torchvision.utils import save_image
import torchvision.transforms as tr

from data import BraTSDatasetUnet, TestDataset, BraTSDatasetLSTM
from losses import DICELossMultiClass
from models import UNet

import os
import itertools
import random
%matplotlib inline

In [36]:
def generate_colorimg(output):
    colors = [[0, 0, 0], [0, 0, 255], [255, 0, 0], [0, 255, 0], [255, 255, 0], [0, 255, 255], [255, 0, 255], [255, 255, 255]]
    _, height, width = output.shape
    colorimg = np.zeros((height, width, 3), dtype=np.uint8)
    #colorimg = np.full((height, width, 3), 255, dtype=np.uint8)
    #colorimg = np.ones((masks.shape[1], masks.shape[2], 3), dtype=np.float32) * 255
    for y in range(height):
        for x in range(width):
            selected_color = colors[output[0,y,x]]
            colorimg[y,x,:] = selected_color
    return tr.ToTensor()(colorimg.astype(np.uint8))

In [37]:
def test_only(model, loader, criterion, cuda=False):
    model.eval()
    for batch_idx, image in tqdm(enumerate(loader)):
        if cuda:
            image = image.cuda()
        with torch.no_grad():
            image = Variable(image)
            pred = model(image)
            maxes, out = torch.max(pred, 1, keepdim=True)
            
        if (batch_idx == 1):
            out[out == 1] = 2
            out[:,:,:,120:240] = 1
            
        print(np.unique(out))
        print(out)
        save_image(image, './output/images/images-batch-{}.png'.format(batch_idx))
        save_image(out, './output/predictions/outputs-batch-{}.png'.format(batch_idx), normalize=True)
        new_outs = []
        for o in out:
            new_outs.append(generate_colorimg(o))
        save_image(new_outs, './output/rgb-outs/rgb-batch-{}.png'.format(batch_idx))

In [38]:
model = UNet(num_channels=1, num_classes=3)
criterion = DICELossMultiClass()
model.load_state_dict(torch.load('unet-multiclass-model-16-100-0.001', map_location='cpu'))

In [39]:
test_dataset = TestDataset('./Data/', im_size=[256, 256], transform=tr.ToTensor())
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=1)
print("Test Data : ", len(test_loader.dataset))
test_only(model, test_loader, criterion)

Test Data :  13



0it [00:00, ?it/s]

[0 1]
tensor([[[[1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          ...,
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1]]],


        [[[1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          ...,
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1]]],


        [[[1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          ...,
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1]]],


        [[[1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          ...,
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1]]]])



1it [00:19, 19.85s/it]

[0 1 2]
tensor([[[[2, 2, 2,  ..., 2, 2, 2],
          [2, 2, 2,  ..., 2, 2, 2],
          [2, 2, 2,  ..., 2, 2, 2],
          ...,
          [2, 2, 2,  ..., 2, 2, 2],
          [2, 2, 2,  ..., 2, 2, 2],
          [2, 2, 2,  ..., 2, 2, 2]]],


        [[[2, 2, 2,  ..., 2, 2, 2],
          [2, 2, 2,  ..., 2, 2, 2],
          [2, 2, 2,  ..., 2, 2, 2],
          ...,
          [2, 2, 2,  ..., 2, 2, 2],
          [2, 2, 2,  ..., 2, 2, 2],
          [2, 2, 2,  ..., 2, 2, 2]]],


        [[[2, 2, 2,  ..., 2, 2, 2],
          [2, 2, 2,  ..., 2, 2, 2],
          [2, 2, 2,  ..., 2, 2, 2],
          ...,
          [2, 2, 2,  ..., 2, 2, 2],
          [2, 2, 2,  ..., 2, 2, 2],
          [2, 2, 2,  ..., 2, 2, 2]]],


        [[[2, 2, 2,  ..., 2, 2, 2],
          [2, 2, 2,  ..., 2, 2, 2],
          [2, 2, 2,  ..., 2, 2, 2],
          ...,
          [2, 2, 2,  ..., 2, 2, 2],
          [2, 2, 2,  ..., 2, 2, 2],
          [2, 2, 2,  ..., 2, 2, 2]]]])



2it [00:36, 18.83s/it]

[0 1]
tensor([[[[1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          ...,
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1]]],


        [[[1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          ...,
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1]]],


        [[[1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          ...,
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1]]],


        [[[1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          ...,
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1]]]])



3it [00:52, 18.13s/it]

[0 1]
tensor([[[[1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          ...,
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1],
          [1, 1, 1,  ..., 1, 1, 1]]]])



4it [00:57, 14.18s/it]