In [24]:
import torch
from torch import nn

from torchvision import transforms
from tests import _PATH_DATA
import torchvision
from torchvision import transforms
from os import path, listdir
import numpy as np
import pytest

In [5]:
def split_data(dir):
    """Partition the PNG file names under ``<dir>/images`` into fixed splits.

    The file names are sorted and then selected by position using the
    hard-coded index table below, so the partition is deterministic for a
    given set of files. The table holds 120 train, 15 validation and 16 test
    positions, which together cover positions 0..150 exactly once.

    Args:
        dir: Dataset root (string or path-like) that contains an ``images``
            subdirectory. The name shadows the builtin but is kept so that
            existing keyword callers keep working.

    Returns:
        dict mapping 'train', 'validation' and 'test' to lists of file names
        (names only, without the directory part).

    Raises:
        IndexError: if the directory holds fewer PNG files than the index
            table addresses.
    """
    idx_dict = {
        'train': [
            34, 101, 114, 82, 123, 57, 22, 15, 137, 83, 99, 72, 47,
            36, 96, 46, 120, 60, 19, 79, 58, 134, 39, 102, 126, 94,
            7, 106, 2, 40, 70, 52, 104, 12, 119, 76, 108, 90, 147,
            143, 43, 140, 142, 88, 93, 4, 51, 16, 121, 74, 64, 77,
            98, 107, 56, 13, 92, 3, 141, 136, 146, 78, 91, 35, 124,
            63, 130, 84, 17, 80, 25, 118, 6, 113, 117, 67, 100, 54,
            103, 95, 37, 23, 32, 30, 42, 144, 75, 38, 50, 31, 66,
            131, 68, 97, 85, 44, 69, 33, 5, 138, 49, 14, 128, 24,
            11, 89, 135, 10, 29, 116, 65, 18, 125, 20, 26, 111, 73,
            48, 59, 139],
        'validation': [86, 21, 55, 61, 45, 81, 105, 149, 27, 132, 28, 129, 1, 53, 133],
        'test': [115, 109, 87, 112, 8, 9, 122, 41, 148, 110, 145, 71, 150, 127, 0, 62]
    }

    # path.join (rather than string concatenation) also accepts path-like roots.
    image_dir = path.join(dir, 'images')

    # Sorting makes the positional indices independent of the listing order.
    image_names = sorted(name for name in listdir(image_dir) if name.endswith('.png'))

    # Fail with an explanatory message instead of a bare "list index out of range".
    n_required = max(max(idxs) for idxs in idx_dict.values()) + 1
    if len(image_names) < n_required:
        raise IndexError(
            f"Expected at least {n_required} .png files in {image_dir!r}, "
            f"found {len(image_names)}"
        )

    return {
        split: [image_names[idx] for idx in idxs]
        for split, idxs in idx_dict.items()
    }

In [6]:
# Root of the raw dataset, given as a path relative to the current working
# directory; split_data lists the PNG files in its `images` subdirectory.
raw_dir = '../data/raw'
# Fixed train/validation/test partition of the image file names.
name_dict = split_data(raw_dir)

In [18]:
## TEST DATA
# The split sizes must match the hard-coded index table in split_data.
assert len(name_dict['train']) == 120, "Train data did not have correct number of images"
assert len(name_dict['validation']) == 15, "Validation data did not have correct number of images"
assert len(name_dict['test']) == 16, "Test data did not have correct number of images"

# TODO: run the channel check below for every training image, not only the first.

# read_image returns the channel axis first, so shape[0] is the channel count.
first_image = torchvision.io.read_image(raw_dir+'/images/'+name_dict['train'][0])
assert first_image.shape[0] == 3, "First image did not have 3 channels"

# Every image needs a mask stored under the same file name.
image_names = sorted([image for image in listdir(raw_dir+'/images') if image.endswith('.png')])
mask_names  = sorted([image for image in listdir(raw_dir+'/masks') if image.endswith('.png')])
assert image_names == mask_names, "Image names did not match mask names"

In [44]:
def _axis_slice_bounds(length, tile):
    """Return an (n, 2) int array of [start, stop) bounds tiling one axis."""
    if tile <= 0:
        raise ValueError(f"Slice size must be positive, got {tile}")
    if tile > length:
        raise ValueError(
            f"Slice size {tile} is larger than the image dimension {length}"
        )

    # Ceiling division: the fewest tiles that cover the axis. (`length // tile + 1`
    # would add a redundant tile whenever `length` is an exact multiple of `tile`.)
    n_tiles = -(-length // tile)

    # Total surplus that has to be absorbed as overlap, spread evenly (rounded
    # down to whole pixels) over the tiles: the first tile is not shifted and
    # the last one is shifted back by the full surplus so it ends at `length`.
    surplus = n_tiles * tile - length
    shifts = np.linspace(0, surplus, n_tiles, dtype=int)

    starts = np.arange(n_tiles) * tile - shifts
    return np.stack((starts, starts + tile), axis=1)


def get_slice_idxs(size, down_size):
    """Compute the row and column bounds of fixed-size tiles covering an image.

    Each axis is covered by the minimum number of tiles of the requested
    size; when the axis length is not a multiple of the tile size the tiles
    overlap, with the overlap distributed evenly between neighbours.

    Args:
        size: Sequence whose first two entries are the image height and
            width; any further entries are ignored.
        down_size: Sequence whose first two entries are the tile height and
            width.

    Returns:
        Tuple ``(idxs_height, idxs_width)`` of integer arrays of shape
        ``(n_tiles, 2)``; every row is a ``[start, stop)`` pair along the
        respective axis.

    Raises:
        ValueError: if a tile dimension is not positive or exceeds the
            corresponding image dimension (such a tile cannot be cut from
            the image).
    """
    height, width = size[:2]
    down_height, down_width = down_size[:2]

    idxs_height = _axis_slice_bounds(height, down_height)
    idxs_width = _axis_slice_bounds(width, down_width)

    return idxs_height, idxs_width

def slice_image(image, size, idxs_height, idxs_width):
    """Cut an image into a grid of tiles.

    Args:
        image: Array-like indexed as ``image[rows, cols]``.
        size: Shape of a single tile, used to allocate the output.
        idxs_height: Sequence of ``(start, stop)`` row bounds, one per grid row.
        idxs_width: Sequence of ``(start, stop)`` column bounds, one per grid
            column.

    Returns:
        NumPy array of shape ``(len(idxs_height), len(idxs_width), *size)``
        with NumPy's default dtype (float64); entry ``[i, j]`` holds the tile
        at grid row ``i`` and grid column ``j``.
    """
    n_rows, n_cols = len(idxs_height), len(idxs_width)
    tiles = np.empty((n_rows, n_cols) + tuple(size))

    for row in range(n_rows):
        top, bottom = idxs_height[row]
        for col in range(n_cols):
            left, right = idxs_width[col]
            tiles[row, col] = image[top:bottom, left:right]

    return tiles

def unslice_images(images, size, idxs_height, idxs_width, combine_func=lambda x: np.mean(x, axis=0)):
    """Reassemble a grid of (possibly overlapping) tiles into one image.

    Args:
        images: Array indexed as ``images[i, j]`` giving the tile at grid row
            ``i`` and grid column ``j``.
        size: Shape of the reassembled image.
        idxs_height: Sequence of ``(start, stop)`` row bounds, one per grid row.
        idxs_width: Sequence of ``(start, stop)`` column bounds, one per grid
            column.
        combine_func: Called with a 2-tuple ``(existing, incoming)`` of 1-D
            arrays holding the pixels where a tile overlaps already written
            content; must return the merged values. It is applied pairwise in
            iteration order, so where more than two tiles overlap the default
            yields a running pairwise mean rather than the mean of all tiles.

    Returns:
        Float array of shape ``size``. Positions covered by no tile stay NaN.
    """
    # NaN marks "not written yet".
    image = np.full(size, np.nan)
    for i, (sy, ey) in enumerate(idxs_height):
        for j, (sx, ex) in enumerate(idxs_width):
            region = image[sy:ey, sx:ex]  # view: writes go straight into `image`
            tile = images[i, j]

            # Take the mask BEFORE writing anything. Recomputing it after the
            # fill would also hand freshly written pixels to combine_func,
            # which corrupts them for any non-idempotent combiner (e.g. a sum).
            unwritten = np.isnan(region)
            written = ~unwritten

            if written.any():
                region[written] = combine_func((region[written], tile[written]))
            region[unwritten] = tile[unwritten]
    return image

# unslice_images(slices, image.shape, idxs_height, idxs_width)

In [50]:
## TEST PREPROCESS
# Move the channel axis last so that the first two axes are (height, width),
# which is the layout the slicing helpers index with.
image = torch.moveaxis(torchvision.io.read_image(raw_dir+'/images/'+name_dict['train'][0]), 0, -1)

slice_size = (512, 512)
tile_shape = slice_size + (3,)

idxs_height, idxs_width = get_slice_idxs(image.shape, slice_size)
image_slices = slice_image(image, tile_shape, idxs_height, idxs_width)

# All tiles live in one array, so a single check on the trailing axes covers
# every tile.
assert image_slices.shape[2:] == tile_shape, "Image slice did not have correct dimensions"

unsliced_image = unslice_images(image_slices, image.shape, idxs_height, idxs_width)
assert unsliced_image.shape == image.shape, "Unsliced image dimensions did not match original image"
# Shapes alone cannot reveal misplaced tiles: overlapping pixels come from the
# same source image, so slicing followed by unslicing must reproduce it exactly.
assert np.array_equal(unsliced_image, np.asarray(image)), "Unsliced image content did not match original image"