In [1]:
import os
import torch
from torch import Tensor
from torch.utils.data import Dataset, DataLoader

from torchvision import transforms

import PIL
from PIL import Image

In [2]:
# Load the exported TorchScript network. map_location="cuda" puts the weights on
# the GPU regardless of the device they were saved from, because the inference
# cell moves every input batch there with .cuda(). eval() switches the module to
# inference mode; it returns the module, so no repr is displayed.
model = torch.jit.load("TorchScript/ternausnet.pt", map_location="cuda").eval()

In [3]:
import itertools
import re

class CustomDataset(Dataset):
    """Flat dataset over the ``*RGB_T<n>.png`` frames of every patient folder.

    Each immediate sub-directory of ``directory`` is one patient. Patients and
    the files inside them are visited in sorted order, and all matching frames
    are concatenated into ``self.files``, which ``__getitem__`` indexes.
    """

    # File names like "<prefix>RGB_T12.png": at least one non-whitespace
    # character, then "RGB_T", digits and a ".png" extension at the very end.
    # Raw string so "\S" and "\." are regex escapes, not invalid str escapes.
    # Depth frames as well would be: r"\S+(Depth|RGB)_T[0-9]+\.png$"
    _FILE_PATTERN = re.compile(r"\S+RGB_T[0-9]+\.png$")

    def __init__(self, directory: str, transform=None) -> None:
        """
        Args:
            directory: root folder; each sub-directory holds one patient's frames.
            transform: optional callable applied to every loaded PIL image.
        """
        # Only sub-directories are patients. Stray files in the root
        # (".DS_Store", notes, ...) would otherwise be handed to get_files()
        # and make os.listdir fail with NotADirectoryError.
        self.patients_dir = [
            os.path.join(directory, entry)
            for entry in sorted(os.listdir(directory))
            if os.path.isdir(os.path.join(directory, entry))
        ]
        self.files = list(itertools.chain.from_iterable(
            self.get_files(patient_dir) for patient_dir in self.patients_dir
        ))
        self.transform = transform

    def __len__(self) -> int:
        """Total number of frames across all patients."""
        return len(self.files)

    def __getitem__(self, idx) -> "Tensor | Image.Image":
        """Load frame ``idx`` as an RGB image and apply ``transform`` if set.

        Returns whatever ``transform`` produces, or the PIL image itself when
        no transform was given.
        """
        # The context manager closes the file handle immediately instead of
        # leaving it to the garbage collector. convert("RGB") forces the pixel
        # data to load before the file is closed, and guarantees 3 channels
        # (RGBA / greyscale / palette PNGs would break a 3-channel Normalize).
        with Image.open(self.files[idx]) as image:
            img = image.convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img

    def get_files(self, directory: str) -> list:
        """Return full paths of the RGB frames in ``directory``, in sorted order."""
        return [
            os.path.join(directory, filename)
            for filename in sorted(os.listdir(directory))
            if self._FILE_PATTERN.search(filename)
        ]


In [4]:
# Network input size (height, width).
INPUT_SIZE = [512, 512]
# Per-channel normalisation statistics. These are the standard ImageNet values;
# presumably the network was trained with them (confirm against training code).
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

trans = transforms.Compose([
    # The InterpolationMode enum is the documented torchvision argument; passing
    # PIL's integer constant is deprecated and only converted for compatibility.
    transforms.Resize(INPUT_SIZE, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

dataset = CustomDataset("../DataTest/Images", trans)

In [9]:
from torchvision.utils import save_image

# Frames per batch; output files are numbered per batch, not per frame.
BATCH_SIZE = 5
# Sigmoid probability at or above which a pixel is set to 1 in the saved mask.
MASK_THRESHOLD = 0.75

# Undo transforms.Normalize before saving the inputs. save_image expects values
# in [0, 1], so normalized tensors come out clipped and colour-shifted.
# These values must match the mean/std used in the transform cell.
denorm_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
denorm_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

loader = DataLoader(dataset, batch_size=BATCH_SIZE)
# Created once instead of on every iteration.
sigmoid = torch.nn.Sigmoid()

with torch.no_grad():
    for idx, img in enumerate(loader):
        img = img.cuda()
        probs = sigmoid(model(img))
        # Binary mask: 1.0 where probs >= MASK_THRESHOLD, 0.0 elsewhere.
        output = (probs >= MASK_THRESHOLD).float()

        # Denormalize on the CPU (save_image moves data there anyway).
        img_vis = (img.cpu() * denorm_std + denorm_mean).clamp(0, 1)
        save_image(img_vis, "img{}.png".format(idx))
        save_image(output, "test{}.png".format(idx))