In [5]:
import argparse
from tqdm import tqdm
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
matplotlib.use('Agg')

import torch
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader, random_split
from torchvision.utils import save_image
import torchvision.transforms as tr

from data import BraTSDatasetUnet, BraTSDatasetLSTM
from losses import DICELossMultiClass
from models import UNet

In [6]:
def train(model, epoch, loss_list, train_loader, optimizer, criterion, cuda=False):
    model.train()
    for batch_idx, (image, mask) in enumerate(train_loader):
        if cuda:
            image, mask = image.cuda(), mask.cuda()

        image, mask = Variable(image), Variable(mask)

        optimizer.zero_grad()

        output = model(image)
        loss = criterion(output, mask)
        loss_list.append(loss.item())

        loss.backward()
        optimizer.step()

        if batch_idx % 1 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tAverage DICE Loss: {:.6f}'.format(
                epoch, batch_idx * len(image), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

In [27]:
def to_colorimg(output):
    colors = [[0, 0, 255], [0, 255, 0], [255, 0, 0], [255, 255, 0], [0, 255, 255], [255, 0, 255], [255, 255, 255], [0, 0, 0]]
    #colors = np.asarray([(201, 58, 64), (242, 207, 1), (0, 152, 75), (101, 172, 228),(56, 34, 132), (160, 194, 56)])

    _, height, width = output.shape
    colorimg = np.zeros((height, width, 3), dtype=np.uint8)
    #colorimg = np.ones((masks.shape[1], masks.shape[2], 3), dtype=np.float32) * 255
    for y in range(height):
        for x in range(width):
            selected_color = colors[output[0,y,x]]
            colorimg[y,x,:] = selected_color
    return tr.ToTensor()(colorimg.astype(np.uint8))

In [28]:
def test(model, loader, criterion, cuda=False, validation=False, save_output=False):
    test_loss = 0
    model.eval()
    for batch_idx, (image, mask) in tqdm(enumerate(loader)):
        if cuda:
            image, mask = image.cuda(), mask.cuda()

        with torch.no_grad():
            image, mask = Variable(image), Variable(mask)
            pred = model(image)
            print('Test  output.size: ',pred.size())
            print(pred)
            #pred = torch.sigmoid(pred)
            maxes, out = torch.max(pred, 1, keepdim=True)
            print('Test  out.size: ',out.size())
            print(out)
            print(np.unique(out))
            

        
        if save_output:
            save_image(image, './output/images/images-batch-{}.png'.format(batch_idx))
            save_image(mask, './output/masks/masks-batch-{}.png'.format(batch_idx))
            save_image(out, './output/predictions/outputs-batch-{}.png'.format(batch_idx))
            
            new_outs = []
            for o in out:
                new_outs.append(to_colorimg(o))
            save_image(new_outs, './output/predictions/color-batch-{}.png'.format(batch_idx))

        test_loss += criterion(pred, mask).item()
    # Average Dice Coefficient
    test_loss /= len(loader)
    if validation:
        print('\nValidation Set: Average DICE Coefficient: {:.4f}\n'.format(test_loss))
    else:
        print('\nTest Set: Average DICE Coefficient: {:.4f}\n'.format(test_loss))

In [29]:
test(model, test_loader, criterion)

0it [00:00, ?it/s]

Test  output.size:  torch.Size([4, 3, 128, 128])
tensor([[[[0.3366, 0.3362, 0.3363,  ..., 0.3363, 0.3362, 0.3357],
          [0.3368, 0.3362, 0.3362,  ..., 0.3362, 0.3361, 0.3354],
          [0.3366, 0.3364, 0.3360,  ..., 0.3362, 0.3359, 0.3355],
          ...,
          [0.3368, 0.3364, 0.3359,  ..., 0.3361, 0.3360, 0.3357],
          [0.3370, 0.3363, 0.3360,  ..., 0.3361, 0.3361, 0.3352],
          [0.3367, 0.3363, 0.3361,  ..., 0.3362, 0.3361, 0.3360]],

         [[0.3336, 0.3341, 0.3342,  ..., 0.3344, 0.3341, 0.3352],
          [0.3337, 0.3347, 0.3345,  ..., 0.3343, 0.3340, 0.3352],
          [0.3341, 0.3345, 0.3347,  ..., 0.3343, 0.3345, 0.3353],
          ...,
          [0.3338, 0.3342, 0.3342,  ..., 0.3343, 0.3345, 0.3353],
          [0.3339, 0.3342, 0.3345,  ..., 0.3343, 0.3349, 0.3355],
          [0.3345, 0.3347, 0.3348,  ..., 0.3347, 0.3347, 0.3349]],

         [[0.3298, 0.3297, 0.3295,  ..., 0.3293, 0.3296, 0.3292],
          [0.3295, 0.3291, 0.3293,  ..., 0.3295, 0.3300, 0.

1it [00:07,  7.32s/it]

Test  output.size:  torch.Size([4, 3, 128, 128])
tensor([[[[0.3366, 0.3363, 0.3362,  ..., 0.3362, 0.3362, 0.3357],
          [0.3365, 0.3361, 0.3358,  ..., 0.3360, 0.3360, 0.3355],
          [0.3365, 0.3361, 0.3358,  ..., 0.3360, 0.3360, 0.3355],
          ...,
          [0.3368, 0.3364, 0.3359,  ..., 0.3362, 0.3359, 0.3356],
          [0.3370, 0.3364, 0.3361,  ..., 0.3361, 0.3361, 0.3352],
          [0.3367, 0.3363, 0.3361,  ..., 0.3362, 0.3361, 0.3360]],

         [[0.3336, 0.3342, 0.3343,  ..., 0.3341, 0.3341, 0.3352],
          [0.3340, 0.3343, 0.3341,  ..., 0.3341, 0.3340, 0.3351],
          [0.3341, 0.3343, 0.3344,  ..., 0.3342, 0.3345, 0.3352],
          ...,
          [0.3338, 0.3342, 0.3343,  ..., 0.3344, 0.3346, 0.3354],
          [0.3339, 0.3341, 0.3344,  ..., 0.3342, 0.3348, 0.3355],
          [0.3345, 0.3347, 0.3348,  ..., 0.3347, 0.3347, 0.3349]],

         [[0.3298, 0.3295, 0.3295,  ..., 0.3298, 0.3298, 0.3292],
          [0.3295, 0.3296, 0.3301,  ..., 0.3299, 0.3300, 0.

2it [00:11,  6.46s/it]

Test  output.size:  torch.Size([4, 3, 128, 128])
tensor([[[[0.3366, 0.3363, 0.3362,  ..., 0.3362, 0.3362, 0.3357],
          [0.3365, 0.3361, 0.3358,  ..., 0.3360, 0.3360, 0.3355],
          [0.3365, 0.3361, 0.3358,  ..., 0.3360, 0.3360, 0.3355],
          ...,
          [0.3368, 0.3364, 0.3359,  ..., 0.3362, 0.3359, 0.3356],
          [0.3370, 0.3364, 0.3361,  ..., 0.3361, 0.3361, 0.3352],
          [0.3367, 0.3363, 0.3361,  ..., 0.3362, 0.3361, 0.3360]],

         [[0.3336, 0.3342, 0.3343,  ..., 0.3341, 0.3341, 0.3352],
          [0.3340, 0.3343, 0.3341,  ..., 0.3341, 0.3340, 0.3351],
          [0.3341, 0.3343, 0.3344,  ..., 0.3342, 0.3345, 0.3352],
          ...,
          [0.3338, 0.3342, 0.3343,  ..., 0.3344, 0.3346, 0.3354],
          [0.3339, 0.3341, 0.3344,  ..., 0.3342, 0.3348, 0.3355],
          [0.3345, 0.3347, 0.3348,  ..., 0.3347, 0.3347, 0.3349]],

         [[0.3298, 0.3295, 0.3295,  ..., 0.3298, 0.3298, 0.3292],
          [0.3295, 0.3296, 0.3301,  ..., 0.3299, 0.3300, 0.

3it [00:16,  5.85s/it]

Test  output.size:  torch.Size([1, 3, 128, 128])
tensor([[[[0.3366, 0.3363, 0.3362,  ..., 0.3362, 0.3362, 0.3357],
          [0.3365, 0.3361, 0.3358,  ..., 0.3360, 0.3360, 0.3355],
          [0.3365, 0.3361, 0.3358,  ..., 0.3360, 0.3360, 0.3355],
          ...,
          [0.3368, 0.3364, 0.3359,  ..., 0.3362, 0.3359, 0.3356],
          [0.3370, 0.3364, 0.3361,  ..., 0.3361, 0.3361, 0.3352],
          [0.3367, 0.3363, 0.3361,  ..., 0.3362, 0.3361, 0.3360]],

         [[0.3336, 0.3342, 0.3343,  ..., 0.3341, 0.3341, 0.3352],
          [0.3340, 0.3343, 0.3341,  ..., 0.3341, 0.3340, 0.3351],
          [0.3341, 0.3343, 0.3344,  ..., 0.3342, 0.3345, 0.3352],
          ...,
          [0.3338, 0.3342, 0.3343,  ..., 0.3344, 0.3346, 0.3354],
          [0.3339, 0.3341, 0.3344,  ..., 0.3342, 0.3348, 0.3355],
          [0.3345, 0.3347, 0.3348,  ..., 0.3347, 0.3347, 0.3349]],

         [[0.3298, 0.3295, 0.3295,  ..., 0.3298, 0.3298, 0.3292],
          [0.3295, 0.3296, 0.3301,  ..., 0.3299, 0.3300, 0.

4it [00:17,  4.53s/it]



Test Set: Average DICE Coefficient: 0.5800



In [8]:
DATA_FOLDER = './Data/'

full_dataset = BraTSDatasetUnet(DATA_FOLDER, im_size=[128, 128], transform=tr.ToTensor())
train_size = int(0.9 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, validation_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
    
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=1)
validation_loader = DataLoader(validation_dataset, batch_size=4, shuffle=False, num_workers=1)
test_loader = DataLoader(full_dataset, batch_size=4, shuffle=False, num_workers=1)

print("Training Data : ", len(train_loader.dataset))
print("Validaion Data : ", len(validation_loader.dataset))
print("Test Data : ", len(test_loader.dataset))

model = UNet(num_channels=1, num_classes=3)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.99)
criterion = DICELossMultiClass()

loss_list = []
for i in tqdm(range(2)):
    train(model, i, loss_list, train_loader, optimizer, criterion)
    test(model, validation_loader, criterion, validation=True)

        


Training Data :  11
Validaion Data :  2
Test Data :  13


  0%|                                                    | 0/2 [00:00<?, ?it/s]




0it [00:00, ?it/s]

Test  output.size:  torch.Size([2, 3, 128, 128])
tensor([[[[0.3361, 0.3359, 0.3359,  ..., 0.3359, 0.3359, 0.3356],
          [0.3361, 0.3359, 0.3358,  ..., 0.3358, 0.3358, 0.3355],
          [0.3361, 0.3358, 0.3358,  ..., 0.3359, 0.3359, 0.3355],
          ...,
          [0.3363, 0.3360, 0.3359,  ..., 0.3359, 0.3358, 0.3355],
          [0.3364, 0.3359, 0.3358,  ..., 0.3359, 0.3358, 0.3352],
          [0.3362, 0.3358, 0.3357,  ..., 0.3358, 0.3358, 0.3357]],

         [[0.3330, 0.3333, 0.3333,  ..., 0.3332, 0.3333, 0.3337],
          [0.3330, 0.3333, 0.3332,  ..., 0.3332, 0.3334, 0.3337],
          [0.3332, 0.3334, 0.3333,  ..., 0.3331, 0.3335, 0.3336],
          ...,
          [0.3330, 0.3334, 0.3333,  ..., 0.3333, 0.3336, 0.3337],
          [0.3330, 0.3335, 0.3333,  ..., 0.3332, 0.3336, 0.3337],
          [0.3332, 0.3335, 0.3335,  ..., 0.3335, 0.3335, 0.3335]],

         [[0.3310, 0.3308, 0.3308,  ..., 0.3309, 0.3308, 0.3307],
          [0.3309, 0.3309, 0.3310,  ..., 0.3309, 0.3308, 0.


1it [00:04,  4.79s/it]


Validation Set: Average DICE Coefficient: 0.6459



 50%|██████████████████████                      | 1/2 [00:42<00:42, 42.52s/it]




0it [00:00, ?it/s]

Test  output.size:  torch.Size([2, 3, 128, 128])
tensor([[[[0.3366, 0.3363, 0.3362,  ..., 0.3362, 0.3362, 0.3357],
          [0.3365, 0.3361, 0.3358,  ..., 0.3360, 0.3360, 0.3355],
          [0.3365, 0.3361, 0.3358,  ..., 0.3360, 0.3360, 0.3355],
          ...,
          [0.3368, 0.3364, 0.3359,  ..., 0.3362, 0.3359, 0.3356],
          [0.3370, 0.3364, 0.3361,  ..., 0.3361, 0.3361, 0.3352],
          [0.3367, 0.3363, 0.3361,  ..., 0.3362, 0.3361, 0.3360]],

         [[0.3336, 0.3342, 0.3343,  ..., 0.3341, 0.3341, 0.3352],
          [0.3340, 0.3343, 0.3341,  ..., 0.3341, 0.3340, 0.3351],
          [0.3341, 0.3343, 0.3344,  ..., 0.3342, 0.3345, 0.3352],
          ...,
          [0.3338, 0.3342, 0.3343,  ..., 0.3344, 0.3346, 0.3354],
          [0.3339, 0.3341, 0.3344,  ..., 0.3342, 0.3348, 0.3355],
          [0.3345, 0.3347, 0.3348,  ..., 0.3347, 0.3347, 0.3349]],

         [[0.3298, 0.3295, 0.3295,  ..., 0.3298, 0.3298, 0.3292],
          [0.3295, 0.3296, 0.3301,  ..., 0.3299, 0.3300, 0.


1it [00:04,  4.62s/it]


Validation Set: Average DICE Coefficient: 0.6458



100%|████████████████████████████████████████████| 2/2 [01:13<00:00, 39.15s/it]


In [5]:
test(model, test_loader, criterion, save_output=True)

0it [00:00, ?it/s]

Test  output.size:  torch.Size([4, 2, 128, 128])
tensor([[[[0.5539, 0.5544, 0.5544,  ..., 0.5540, 0.5546, 0.5547],
          [0.5540, 0.5546, 0.5545,  ..., 0.5541, 0.5541, 0.5549],
          [0.5544, 0.5541, 0.5542,  ..., 0.5543, 0.5547, 0.5552],
          ...,
          [0.5542, 0.5543, 0.5542,  ..., 0.5544, 0.5548, 0.5550],
          [0.5539, 0.5544, 0.5545,  ..., 0.5546, 0.5550, 0.5550],
          [0.5540, 0.5543, 0.5544,  ..., 0.5543, 0.5545, 0.5543]],

         [[0.4461, 0.4456, 0.4456,  ..., 0.4460, 0.4454, 0.4453],
          [0.4460, 0.4454, 0.4455,  ..., 0.4459, 0.4459, 0.4451],
          [0.4456, 0.4459, 0.4458,  ..., 0.4457, 0.4453, 0.4448],
          ...,
          [0.4458, 0.4457, 0.4458,  ..., 0.4456, 0.4452, 0.4450],
          [0.4461, 0.4456, 0.4455,  ..., 0.4454, 0.4450, 0.4450],
          [0.4460, 0.4457, 0.4456,  ..., 0.4457, 0.4455, 0.4457]]],


        [[[0.5542, 0.5543, 0.5544,  ..., 0.5539, 0.5541, 0.5546],
          [0.5541, 0.5542, 0.5539,  ..., 0.5545, 0.5540, 

1it [00:06,  6.27s/it]

Test  output.size:  torch.Size([4, 2, 128, 128])
tensor([[[[0.5542, 0.5543, 0.5545,  ..., 0.5544, 0.5546, 0.5547],
          [0.5541, 0.5541, 0.5540,  ..., 0.5543, 0.5545, 0.5549],
          [0.5543, 0.5546, 0.5545,  ..., 0.5546, 0.5550, 0.5552],
          ...,
          [0.5542, 0.5543, 0.5542,  ..., 0.5544, 0.5547, 0.5551],
          [0.5539, 0.5544, 0.5544,  ..., 0.5546, 0.5550, 0.5551],
          [0.5541, 0.5542, 0.5544,  ..., 0.5543, 0.5545, 0.5543]],

         [[0.4458, 0.4457, 0.4455,  ..., 0.4456, 0.4454, 0.4453],
          [0.4459, 0.4459, 0.4460,  ..., 0.4457, 0.4455, 0.4451],
          [0.4457, 0.4454, 0.4455,  ..., 0.4454, 0.4450, 0.4448],
          ...,
          [0.4458, 0.4457, 0.4458,  ..., 0.4456, 0.4453, 0.4449],
          [0.4461, 0.4456, 0.4456,  ..., 0.4454, 0.4450, 0.4449],
          [0.4459, 0.4458, 0.4456,  ..., 0.4457, 0.4455, 0.4457]]],


        [[[0.5542, 0.5543, 0.5545,  ..., 0.5544, 0.5546, 0.5547],
          [0.5541, 0.5541, 0.5540,  ..., 0.5543, 0.5545, 

2it [00:09,  5.46s/it]

Test  output.size:  torch.Size([4, 2, 128, 128])
tensor([[[[0.5542, 0.5543, 0.5545,  ..., 0.5544, 0.5546, 0.5547],
          [0.5541, 0.5541, 0.5540,  ..., 0.5543, 0.5545, 0.5549],
          [0.5543, 0.5546, 0.5545,  ..., 0.5546, 0.5550, 0.5552],
          ...,
          [0.5542, 0.5543, 0.5542,  ..., 0.5544, 0.5547, 0.5551],
          [0.5539, 0.5544, 0.5544,  ..., 0.5546, 0.5550, 0.5551],
          [0.5541, 0.5542, 0.5544,  ..., 0.5543, 0.5545, 0.5543]],

         [[0.4458, 0.4457, 0.4455,  ..., 0.4456, 0.4454, 0.4453],
          [0.4459, 0.4459, 0.4460,  ..., 0.4457, 0.4455, 0.4451],
          [0.4457, 0.4454, 0.4455,  ..., 0.4454, 0.4450, 0.4448],
          ...,
          [0.4458, 0.4457, 0.4458,  ..., 0.4456, 0.4453, 0.4449],
          [0.4461, 0.4456, 0.4456,  ..., 0.4454, 0.4450, 0.4449],
          [0.4459, 0.4458, 0.4456,  ..., 0.4457, 0.4455, 0.4457]]],


        [[[0.5542, 0.5543, 0.5545,  ..., 0.5539, 0.5543, 0.5546],
          [0.5541, 0.5541, 0.5540,  ..., 0.5546, 0.5547, 

3it [00:13,  4.85s/it]

Test  output.size:  torch.Size([1, 2, 128, 128])
tensor([[[[0.5542, 0.5543, 0.5545,  ..., 0.5544, 0.5546, 0.5547],
          [0.5541, 0.5541, 0.5540,  ..., 0.5543, 0.5545, 0.5549],
          [0.5543, 0.5546, 0.5545,  ..., 0.5546, 0.5550, 0.5552],
          ...,
          [0.5542, 0.5543, 0.5542,  ..., 0.5544, 0.5547, 0.5551],
          [0.5539, 0.5544, 0.5544,  ..., 0.5546, 0.5550, 0.5551],
          [0.5541, 0.5542, 0.5544,  ..., 0.5543, 0.5545, 0.5543]],

         [[0.4458, 0.4457, 0.4455,  ..., 0.4456, 0.4454, 0.4453],
          [0.4459, 0.4459, 0.4460,  ..., 0.4457, 0.4455, 0.4451],
          [0.4457, 0.4454, 0.4455,  ..., 0.4454, 0.4450, 0.4448],
          ...,
          [0.4458, 0.4457, 0.4458,  ..., 0.4456, 0.4453, 0.4449],
          [0.4461, 0.4456, 0.4456,  ..., 0.4454, 0.4450, 0.4449],
          [0.4459, 0.4458, 0.4456,  ..., 0.4457, 0.4455, 0.4457]]]])
Test  out.size:  torch.Size([1, 1, 128, 128])
tensor([[[[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
      

4it [00:14,  3.72s/it]



Test Set: Average DICE Coefficient: 0.5175



In [None]:
torch.save(model.state_dict(), 'unet-final-model')

In [None]:
model.load_state_dict(torch.load('unet-final-model'))
test(model, test_loader, criterion, save_output=True)