In [1]:
import torch
import cv2
from PIL import Image
import numpy as np
import glob
import math
import os
import sys
from tqdm import tqdm
import torch
import torchvision.transforms as transforms
import torch.utils.data as torch_data
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from models import deeplab

In [2]:

class GTiffDataset(torch_data.Dataset):
    def __init__(self,
                 root_dir, split, tile_size = 256, stride = 256, debug = False, transform=None):
        """
        Args:
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.tile_size = tile_size
        self.stride = stride
        self.root_dir = root_dir
        self.transform = transform
        self.debug = debug
        self.images, self.masks = self.read_dir()

    def get_tiles(self, image, mask):
        i_tiles, m_tiles = [], []
        width = image.shape[2] - image.shape[2]%self.tile_size
        height = image.shape[1] - image.shape[1]%self.tile_size

        for i in range(0, height, self.stride):
            if i+self.tile_size > height:
                break
            for j in range(0, width, self.stride):
                img_tile = image[
                    :,
                    i:i+self.tile_size,
                    j:j+self.tile_size
                ]

                mask_tile = mask[
                    i:i+self.tile_size,
                    j:j+self.tile_size
                ]

                i_tiles.append(img_tile)
                m_tiles.append(mask_tile)

                if self.debug:
                    # Debugging the tiles
                    img_tile = np.moveaxis(img_tile, 0, -1)
                    cv2.imwrite("debug/" + str(i) + "_" + str(j) + "_img.png", img_tile)
                    cv2.imwrite("debug/" + str(i) + "_" + str(j) + "_mask.png", mask_tile)

        return i_tiles, m_tiles

    def read_dir(self):
        tiles = [[], []]
        images = sorted(glob.glob(self.root_dir + '/' + '*_3031.tif'))
        masks = sorted(glob.glob(self.root_dir + '/' + '*_3031_mask.tif'))
        for idx, [img, msk] in enumerate(zip(images, masks)):
            print('Reading item # {} - {}/{}'.format(img, idx+1, len(images)))
            image = Image.open(img)
            mask = Image.open(msk)
            image = np.asarray(image.transpose(Image.FLIP_TOP_BOTTOM))
            image = np.moveaxis(image, 2, 0)
            mask = np.asarray(mask)
            _, mask = cv2.threshold(mask, 127, 1, cv2.THRESH_BINARY)
            print(image.shape)
            i_tiles, m_tiles = self.get_tiles(image, mask)
            for im, ma in zip(i_tiles, m_tiles):
                tiles[0].append(im)
                tiles[1].append(ma)
            print(len(tiles[0]))
            del image
            del mask
        print()
        return tiles

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        sample = [self.images[idx], self.masks[idx]]
        return sample


In [3]:
def test(model, device, dataloader):
    model.eval()
    val_loss = 0.0
    tbar = tqdm(dataloader)
    num_samples = len(dataloader)
    with torch.no_grad():
        for i, sample in enumerate(tbar):
            image, target = sample[0].float(), sample[1].float()
            image, target = image.to(device), target.float().to(device)

            output = model(image)
            tbar.set_description('Val loss: %.3f' % (train_loss / (i + 1)))
    return output


In [6]:
device = torch.device("cpu" if torch.cuda.is_available() else "cpu")
gtiffdataset = GTiffDataset('../data/pre-processed/transantractics/WV02', split='test', stride=256, debug=False)
test_dataloader = torch_data.DataLoader(gtiffdataset, num_workers=0, batch_size=1)    

Reading item # ../data/pre-processed/transantractics/WV02\10100100038A4200_3031.tif - 1/2
(3, 8068, 8670)
1023
Reading item # ../data/pre-processed/transantractics/WV02\1010010007918200_3031.tif - 2/2
(3, 8164, 7708)
1953



In [10]:
if __name__=="__main__":
    model = deeplab.DeepLab(output_stride=16)
    model.load_state_dict(torch.load("../models/deeplabv3-resnet-bn2d-1.pth")["model"])
    model.to(device)

RuntimeError: cuda runtime error (30) : unknown error at ..\aten\src\THC\THCGeneral.cpp:87