In [1]:
# Imports for the whole notebook, followed by the TensorBoard writer used by
# the evaluation cell to log input images and colourised predictions.
import numpy as np
import torch
import torch.nn as nn
# NOTE(review): this binds F to `torch.functional`, not the conventional
# `torch.nn.functional`. F is not used in the visible cells — confirm which
# module was intended before relying on it.
import torch.functional as F
import os
from torchvision.datasets import VOCSegmentation
from torch.utils.data import Dataset , DataLoader
import copy
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
# Event files are written under ./original_testlogs (relative to the CWD).
writer = SummaryWriter('./original_testlogs')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from typing import Any, Tuple
import PIL
from PIL import Image

# Compatibility shim: re-create the `Image.ANTIALIAS` alias as `Image.LANCZOS`.
# Presumably needed because a dependency still references the old name on a
# Pillow release that dropped it — TODO confirm which library requires it.
# The submodule is imported explicitly first; a bare `import PIL` does not
# guarantee that `PIL.Image` is loaded.
Image.ANTIALIAS = Image.LANCZOS

# Notebook-level constants.
# NOTE(review): neither is referenced in the visible cells (the model is built
# with a literal 21 and the loader uses batch_size=1) — confirm they are needed.
num_classes = 21
batch_size = 4
class VOCSegDataset(VOCSegmentation):
    """VOC segmentation dataset that returns integer class-index masks.

    Both ``transform`` and ``target_transform`` must be provided, and
    ``target_transform`` is expected to yield a float tensor scaled by 1/255
    (as ``ToTensor`` does for 8-bit images), because the label is multiplied
    by 255 to recover the original pixel values.
    """

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Load sample ``index``.

        Returns:
            A tuple ``(image, label)`` where ``image`` is the transformed RGB
            image and ``label`` is the transformed mask as an int64 tensor.
        """
        image = Image.open(self.images[index]).convert('RGB')
        # The mask is opened without conversion so its stored pixel values
        # are kept as-is.
        label = Image.open(self.targets[index])
        image = self.transform(image)
        label = self.target_transform(label)
        # Undo the 1/255 scaling. Round before casting: `.long()` truncates,
        # and the float round-trip (k / 255) * 255 can land just below k,
        # which would silently turn class k into class k-1.
        label = torch.round(label * 255).long()

        return image, label

In [3]:

# Input pipeline: resize to 256x256, convert to a float tensor and normalise
# each channel with mean 0.5 / std 0.5.
image_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Label pipeline: masks hold class indices, so they must be resized with
# nearest-neighbour sampling. Any interpolating filter would blend neighbouring
# ids into values that correspond to no real class. Stated explicitly instead
# of relying on the default.
target_transform = transforms.Compose([
    transforms.Resize((256, 256),
                      interpolation=transforms.InterpolationMode.NEAREST),
    transforms.ToTensor(),
])

# NOTE(review): the "test" set is built from the 'trainval' split. If the model
# was trained on the VOC train split these metrics include training images —
# confirm whether 'val' was intended.
test_dataset = VOCSegDataset(
    './data',
    year='2012',
    download=False,
    image_set='trainval',
    transform=image_transforms,
    target_transform=target_transform,
)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)



In [4]:
# The UNet class must be importable here: the checkpoint stores a whole pickled
# model object, so torch.load needs the class definition to rebuild it.
from UNet import UNet

# Fall back to CPU when no GPU is present instead of failing on "cuda".
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the full serialized model (not a state_dict). The previous throw-away
# `UNet(3, 21)` instance was overwritten immediately and has been dropped.
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
# NOTE(review): this unpickles arbitrary Python objects — only load checkpoints
# from a trusted source. Newer PyTorch releases may also require passing
# `weights_only=False` to load a full model object; verify against the
# installed version.
model = torch.load('./UNet_with_dice.pth', map_location=device)
model = model.to(device)
# Evaluation mode: BatchNorm layers use their running statistics rather than
# per-batch statistics, and do not update them during testing.
model.eval()
print(model)

import torch.nn as nn
#import torchmetrics

# Pixels labelled 255 in the masks are excluded from the loss.
celoss = nn.CrossEntropyLoss(ignore_index=255)

UNet(
  (inc): double_conv(
    (conv): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): contraction_path(
    (contract): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): double_conv(
        (conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm2d(128, eps=1e-05, momentum=0

In [5]:
def make_sementic_image(img): # img is 256 256 tensor
    """Convert a 2-D map of class indices into an RGB image tensor.

    Args:
        img: 2-D integer tensor (or array) of shape (H, W) whose values index
            into the 21-entry palette below (0..20).

    Returns:
        A ``torch.uint8`` tensor of shape (3, H, W) holding the colour of each
        pixel's class.

    Raises:
        IndexError: if ``img`` contains an index outside the palette.
    """
    # Palette, one RGB triple per class index:
    #  0 [128, 128, 128] - gray, background
    #  1 [0, 0, 255]     - blue, aeroplane
    #  2 [0, 255, 0]     - green, bicycle
    #  3 [255, 0, 0]     - red, bird
    #  4 [255, 255, 0]   - yellow, boat
    #  5 [0, 255, 255]   - cyan, bottle
    #  6 [255, 0, 255]   - magenta, bus
    #  7 [192, 192, 192] - light gray, car
    #  8 [128, 128, 128] - gray, cat
    #  9 [128, 0, 0]     - dark red, chair
    # 10 [128, 128, 0]   - khaki, cow
    # 11 [0, 128, 0]     - dark green, dining table
    # 12 [128, 0, 128]   - purple, dog
    # 13 [0, 128, 128]   - teal, horse
    # 14 [0, 0, 128]     - dark blue, motorbike
    # 15 [139, 69, 19]   - brown, person
    # 16 [255, 165, 0]   - orange, potted plant
    # 17 [255, 192, 203] - light pink, sheep
    # 18 [255, 255, 255] - white, sofa
    # 19 [255, 105, 180] - hot pink, train
    # 20 [240, 230, 140] - light beige, tv/monitor
    # NOTE(review): background (0) and cat (8) share the same gray, so they are
    # indistinguishable in the rendered image — confirm whether that is intended.
    colors = [
        [128, 128, 128],
        [0, 0, 255],
        [0, 255, 0],
        [255, 0, 0],
        [255, 255, 0],
        [0, 255, 255],
        [255, 0, 255],
        [192, 192, 192],
        [128, 128, 128],
        [128, 0, 0],
        [128, 128, 0],
        [0, 128, 0],
        [128, 0, 128],
        [0, 128, 128],
        [0, 0, 128],
        [139, 69, 19],
        [255, 165, 0],
        [255, 192, 203],
        [255, 255, 255],
        [255, 105, 180],
        [240, 230, 140],
    ]
    palette = np.array(colors, dtype=np.uint8)

    # Bring the index map to host memory once; indexing a (possibly CUDA)
    # tensor element by element in Python is extremely slow.
    if isinstance(img, torch.Tensor):
        indices = img.detach().cpu().numpy()
    else:
        indices = np.asarray(img)

    # Vectorised lookup: (H, W) indices -> (H, W, 3) colours.
    rt = palette[indices]
    # Channels-first layout, (3, H, W).
    return torch.from_numpy(rt).permute(2, 0, 1)

In [6]:
# Running totals for the evaluation pass. Only the test/CE loss totals are
# updated in this cell; the dice and mIoU totals are initialised here but not
# computed in this loop.
total_test_loss = 0.0
total_dice_score = 0.0
total_dice_nobg= 0.0
total_ce_loss = 0.0
total_miou = 0.0

# Evaluation mode so BatchNorm uses its running statistics (and does not update
# them) while testing.
model.eval()
with torch.no_grad():
    for i, data in enumerate(test_loader):
        inputs, labels = data
        inputs = inputs.to(device)
        # Drop the channel axis of the mask: presumably (B, 1, H, W) -> (B, H, W),
        # the layout CrossEntropyLoss expects for class-index targets.
        labels = labels.squeeze(1).to(device)
        outputs = model(inputs)

        # CrossEntropyLoss applies log-softmax internally, so it must receive
        # the raw model outputs; feeding it softmax probabilities squashes the
        # scores twice and distorts the reported loss.
        # NOTE(review): assumes UNet.forward returns unnormalised logits —
        # verify against the model definition.
        celossval = celoss(outputs, labels)

        loss = celossval
        total_ce_loss += celossval.item()
        total_test_loss += loss.item()

        # argmax over the class axis is unchanged by softmax, so take it on the
        # raw outputs; move the prediction to the CPU once before colourising.
        prediction = outputs.argmax(dim=1)[0].cpu()
        out_with_color = make_sementic_image(prediction)
        # Undo the (0.5, 0.5) normalisation so the logged input is in [0, 1].
        writer.add_image('original_image', (inputs[0]*0.5+0.5), i)
        writer.add_image('model_out image', out_with_color, i)

# Make sure all logged images reach disk.
writer.flush()

# The loss is a per-batch mean, so average over the number of batches
# (identical to the dataset size while batch_size is 1).
num_batches = max(len(test_loader), 1)
print('test_loss : ', total_test_loss/num_batches)
print('CE loss : ', total_ce_loss/num_batches)

