In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import datasets, models, transforms

import numpy as np
import matplotlib.pyplot as plt
from imageio import imread, imwrite

import os, copy, time

from horsetools import list_files

# Switch matplotlib to interactive mode for the plotting cells below.
plt.ion()

## Load data

In [None]:
def pil_loader(path):
    """Load the image stored at ``path`` with PIL and return it converted to RGB.

    Args:
        path: Filesystem path of the image file.

    Returns:
        The result of ``Image.convert('RGB')`` on the opened image.
    """
    # Imported locally, as the sibling loaders do for their own backends, so this
    # function does not depend on a later notebook cell having imported ``Image``.
    from PIL import Image

    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        # convert() is called while the file is still open.
        return img.convert('RGB')

def accimage_loader(path):
    """Load ``path`` through accimage, using ``pil_loader`` if that raises IOError.

    Args:
        path: Filesystem path of the image file.

    Returns:
        An ``accimage.Image`` when loading succeeds, otherwise whatever
        ``pil_loader(path)`` returns.
    """
    import accimage
    try:
        loaded = accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        loaded = pil_loader(path)
    return loaded

def default_loader(path):
    """Load ``path`` with the loader matching torchvision's current image backend.

    ``accimage_loader`` is used when the backend is ``'accimage'``; every other
    backend goes through ``pil_loader``.
    """
    from torchvision import get_image_backend
    if get_image_backend() != 'accimage':
        return pil_loader(path)
    return accimage_loader(path)

    
class SegmentationDataset(Dataset):
    """Dataset of (image, mask) pairs read from ``<root>/images`` and ``<root>/masks``.

    Both folders are listed with ``list_files`` and sorted; item ``i`` pairs the
    i-th image with the i-th mask, so the two listings must line up one-to-one.

    NOTE: ``image_transforms`` and ``mask_transforms`` are applied independently,
    so transforms that draw random parameters would not stay aligned between an
    image and its mask.

    Args:
        root: Directory containing the ``images`` and ``masks`` subfolders.
        labels: Label values; stored on the instance as ``self.labels``.
        image_transforms: Optional callable applied to each loaded image.
        mask_transforms: Optional callable applied to each loaded mask.

    Raises:
        RuntimeError: If either subfolder is missing, or if the number of images
            differs from the number of masks.
    """
    IMG_EXTS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif')

    def _get_files(self, folder):
        """Return the sorted ``list_files`` result for ``folder`` (restricted to IMG_EXTS)."""
        if not os.path.isdir(folder):
            raise RuntimeError('No folder named "{}" found.'.format(folder))
        return sorted(list_files(folder, valid_exts=self.IMG_EXTS))

    def __init__(self, root, labels, image_transforms=None, mask_transforms=None):
        self.imgs = self._get_files(os.path.join(root, 'images'))
        self.masks = self._get_files(os.path.join(root, 'masks'))
        # Pairing is purely positional, so unequal counts would either mispair
        # samples silently or fail with an IndexError in the middle of an epoch.
        if len(self.imgs) != len(self.masks):
            raise RuntimeError(
                'Found {} images but {} masks under "{}"; every image needs exactly one mask.'
                .format(len(self.imgs), len(self.masks), root))
        self.labels = labels

        self.image_transforms = image_transforms
        self.mask_transforms = mask_transforms

    def __len__(self):
        """Number of image files found."""
        return len(self.imgs)

    def __getitem__(self, index):
        """Load the image and mask at ``index`` and apply their respective transforms."""
        img = default_loader(self.imgs[index])
        mask = default_loader(self.masks[index])
        if self.image_transforms is not None:
            img = self.image_transforms(img)
        if self.mask_transforms is not None:
            mask = self.mask_transforms(mask)

        return img, mask

In [None]:
# Estimate per-channel mean and std of the dataset for transforms.Normalize.
imgs_list = list_files('stroma_epithelia/images')
np_means = np.zeros((len(imgs_list), 3))
np_stds = np.zeros_like(np_means)
for i, img_name in enumerate(imgs_list):
    img = imread(img_name)
    # Normalize is applied after ToTensor, which rescales 8-bit images to [0, 1].
    # Bring integer pixel data to that same range so the statistics are expressed
    # in the units Normalize actually operates on.
    if np.issubdtype(img.dtype, np.integer):
        img = img / np.iinfo(img.dtype).max
    np_means[i] = np.mean(img, axis=(0, 1))
    np_stds[i] = np.std(img, axis=(0, 1))

channel_means = np.mean(np_means, axis=0)
# Average the per-image stds. Taking np.std here would measure how much the stds
# vary between images, not the spread of the pixel values themselves.
channel_stds = np.mean(np_stds, axis=0)

In [None]:
# Show the per-channel statistics, one line each.
print('Means: {}'.format(channel_means),
      'Stds: {}'.format(channel_stds),
      sep='\n')

In [None]:
# Dataset root; SegmentationDataset reads the 'images' and 'masks' subfolders beneath it.
data_root = 'stroma_epithelia'
# Side length passed to CenterCrop for both images and masks.
crop_size = 224
# Mask pixel values that LabelToOnehot expands into one channel each.
labels = (0, 1, 2)
# Phases for which datasets and dataloaders are built; each needs an entry in data_transforms.
dataset_phases = ['train']

In [None]:
from PIL import Image

class LabelToOnehot(object):
    """Transform that expands a label mask into one binary plane per label.

    Calling the instance converts its input with ``np.array``; if the result has
    more than two dimensions only channel 0 is kept. The returned float array has
    shape ``(H, W, len(labels))``, where plane ``i`` is 1.0 wherever the mask
    equals ``labels[i]`` and 0.0 elsewhere.
    """

    def __init__(self, labels):
        self.labels = labels

    def __call__(self, img):
        label_map = np.array(img)
        if label_map.ndim > 2:
            # Multi-channel mask: only the first channel carries the labels used here.
            label_map = label_map[:, :, 0]

        # Compare the (H, W, 1) mask against all label values at once; broadcasting
        # yields one boolean plane per label, which is then stored as floats.
        wanted = np.asarray(self.labels)
        return (label_map[:, :, np.newaxis] == wanted).astype(np.float64)

    def __repr__(self):
        return '{}()'.format(type(self).__name__)

# Preprocessing pipelines, keyed by phase and then by 'imgs' / 'masks'.
data_transforms = {
    'train': {
        # Images: centre crop -> tensor -> per-channel normalisation.
        'imgs': transforms.Compose([
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize(channel_means, channel_stds),
        ]),
        # Masks: same centre crop, then label values -> one-hot planes -> tensor.
        'masks': transforms.Compose([
            transforms.CenterCrop(crop_size),
            LabelToOnehot(labels),
            transforms.ToTensor(),
        ]),
    },
}

# One dataset per phase, each wired to that phase's image and mask pipelines.
image_datasets = {
    phase: SegmentationDataset(
        data_root,
        labels,
        image_transforms=data_transforms[phase]['imgs'],
        mask_transforms=data_transforms[phase]['masks'],
    )
    for phase in dataset_phases
}

# Shuffled mini-batches of 4 samples, loaded by 4 worker processes.
dataloaders = {
    phase: DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)
    for phase, dataset in image_datasets.items()
}
dataset_sizes = {phase: len(dataset) for phase, dataset in image_datasets.items()}

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [None]:
# Sanity check: run LabelToOnehot on a single mask and view input and output side by side.
mask_name = 'stroma_epithelia/masks/0.1_110_2_1.png'
LO = LabelToOnehot(labels)
mask = imread(mask_name)
mask_oh = LO(mask)

# Left panel: the mask as read from disk.
plt.subplot(1, 2, 1)
plt.imshow(mask)
# Right panel: the one-hot result, drawn by imshow directly from the (H, W, len(labels)) array.
plt.subplot(1, 2, 2)
plt.imshow(mask_oh)

## Visualize images

In [None]:
def imshow(inp, means=None, stds=None, title=None):
    """Display a 3-D tensor laid out as (C, H, W) with matplotlib.

    Args:
        inp: Tensor whose first axis is moved last before plotting. It may
            require grad or live on a non-CPU device.
        means: Optional per-channel means used to undo normalisation.
        stds: Optional per-channel stds used to undo normalisation.
        title: Optional title for the plot.

    Normalisation is only undone (and the result clipped to [0, 1]) when both
    ``means`` and ``stds`` are given; otherwise the values are plotted as they are.
    """
    # detach()/cpu() leave plain CPU tensors untouched but allow tensors that
    # track gradients or sit on the GPU, which .numpy() alone would reject.
    inp = inp.detach().cpu().numpy().transpose((1, 2, 0))

    if means is not None and stds is not None:
        # convert tensor back to image range [0, 1]
        inp = np.array(stds) * inp + np.array(means)
        inp = np.clip(inp, 0, 1)

    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    # Short pause so the figure gets a chance to be drawn.
    plt.pause(0.001)
    
# Fetch one batch from the training loader and report the tensor shapes.
inputs, masks = next(iter(dataloaders['train']))
print('Image shape: {}'.format(inputs.shape),
      'Masks shape: {}'.format(masks.shape),
      sep='\n')

# Tile each batch into a single grid image. Only the inputs went through Normalize,
# so only they are given the statistics needed to undo it.
image_grid = torchvision.utils.make_grid(inputs)
imshow(image_grid, means=channel_means, stds=channel_stds)
mask_grid = torchvision.utils.make_grid(masks)
imshow(mask_grid)