In [None]:
import numpy as np
import tensorflow as tf

from random import randint
import math
import itertools

import higra as hg
import multiprocessing as mp
import numpy as np
import skimage as sk

import datasets.I3 as D1
import datasets.LW4 as D2
import patch

%matplotlib inline
from matplotlib import pyplot as plt

In [None]:
# Side length (pixels) of the square patches sampled from each slice.
PATCH_SIZE = 256
# Area thresholds for the tree-of-shapes filters; each one adds an extra input
# channel to the local generator's image batch. Empty => raw intensity only.
LAMBDAS = []
# Patches per batch. NOTE: the local generator allocates float64 buffers, so one
# batch takes about BATCH_SIZE * PATCH_SIZE**2 * (len(LAMBDAS) + 1 + n_classes) * 8
# bytes (~5 GiB with 4 label channels).
BATCH_SIZE = 512 * 4

# Seed the RNGs so the sampled batches (and the percentages printed below) are
# reproducible on Restart & Run All.
import random

SEED = 0
random.seed(SEED)     # drives `randint`, used by the local patch sampler
np.random.seed(SEED)  # in case `patch` samples with NumPy's global RNG — TODO confirm

In [None]:
def patches_batch_tos_area_p(batch_image, i, j, l):
    """Area-filter patch ``i`` through its tree of shapes and store it in channel ``j + 1``.

    The tree of shapes is built on channel 0 of ``batch_image[i]``. Nodes whose area
    is strictly below ``l`` are marked as deleted, and the leaf image is rebuilt with
    ``hg.reconstruct_leaf_data``.

    The result is written in place into ``batch_image[i, :, :, j + 1]`` and that slice
    is returned. When this runs inside a process pool, the in-place write only affects
    the worker's copy, so callers must use the returned array.
    """
    out_channel = j + 1
    source = batch_image[i, :, :, 0]

    tree, altitudes = hg.component_tree_tree_of_shapes_image2d(source)
    too_small = hg.attribute_area(tree) < l

    batch_image[i, :, :, out_channel] = hg.reconstruct_leaf_data(tree, altitudes, too_small)
    return batch_image[i, :, :, out_channel]


def gen_patches_batch_augmented_tos_area_label_random_one_hot_p(patch_size, image, label, batch_size=32, lambdas=()):
    """Yield endless batches of randomly located, randomly rotated/flipped patches.

    Each sample is a ``patch_size`` x ``patch_size`` crop of ``image`` and the matching
    crop of ``label``. The crop is taken at a uniformly random slice ``z`` and top-left
    offset ``(y, x)``, indexed ``image[z, y, x]``. The same random 90-degree rotation and
    random left-right / up-down flips are applied to the image crop and its label crop.

    Channel 0 of the image batch holds the raw crop. Channel ``j + 1`` holds the crop
    area-filtered with ``lambdas[j]`` via ``patches_batch_tos_area_p``, computed in a
    process pool.

    Parameters
    ----------
    patch_size : int
        Side of the square patches; must not exceed the image height or width.
    image : ndarray
        Volume indexed as ``image[z, y, x]`` (presumably (Z, H, W); TODO confirm).
    label : ndarray
        Labels indexed as ``label[z, y, x, c]``; the size of the last axis gives the
        number of label channels.
    batch_size : int
        Number of patches per batch.
    lambdas : sequence of numbers
        Area thresholds, one extra image channel per threshold. When empty, no
        worker pool is started and only the raw channel is produced.

    Yields
    ------
    (batch_image, batch_label) : tuple of float64 ndarrays
        Shapes ``(batch_size, patch_size, patch_size, len(lambdas) + 1)`` and
        ``(batch_size, patch_size, patch_size, n_label)``. New arrays are allocated
        for every batch, so previously yielded batches are never overwritten.

    Raises
    ------
    ValueError
        On the first ``next()`` if ``patch_size`` exceeds the image height or width.
    """
    n_lambdas = len(lambdas)
    n_label = label[0].shape[-1]
    n_slices, height, width = image.shape[0], image.shape[1], image.shape[2]

    # randint is inclusive at both ends: the largest valid top-left offset is
    # size - patch_size (the old "- 1" skipped the last row/column and crashed
    # when the image was exactly patch_size wide or tall).
    max_y = height - patch_size
    max_x = width - patch_size
    if max_x < 0 or max_y < 0:
        raise ValueError(
            f"patch_size={patch_size} exceeds image size {height}x{width}"
        )

    # Only pay for worker processes when there is filtering work to do.
    pool = mp.Pool(min(mp.cpu_count(), 16)) if n_lambdas else None
    try:
        while True:
            batch_image = np.zeros((batch_size, patch_size, patch_size, n_lambdas + 1))
            batch_label = np.zeros((batch_size, patch_size, patch_size, n_label))

            for i in range(batch_size):
                x = randint(0, max_x)
                y = randint(0, max_y)
                z = randint(0, n_slices - 1)

                batch_image[i, :, :, 0] = image[z, y:y + patch_size, x:x + patch_size]
                batch_label[i] = label[z, y:y + patch_size, x:x + patch_size]

                # Augmentations: random 90-degree rotation, then random flips.
                # Only channel 0 is populated at this point, so only it is transformed.
                rot = randint(0, 3)
                batch_image[i, :, :, 0] = np.rot90(batch_image[i, :, :, 0], rot)
                batch_label[i] = np.rot90(batch_label[i], rot)

                if randint(0, 1) == 1:
                    batch_image[i, :, :, 0] = np.fliplr(batch_image[i, :, :, 0])
                    batch_label[i] = np.fliplr(batch_label[i])

                if randint(0, 1) == 1:
                    batch_image[i, :, :, 0] = np.flipud(batch_image[i, :, :, 0])
                    batch_label[i] = np.flipud(batch_label[i])

            if n_lambdas:
                # Ship each task only its one-patch slice (shape (1, P, P, C)) rather
                # than the whole batch, so per-task pickling stays small.
                tasks = [
                    (batch_image[i:i + 1], 0, j, lam)
                    for i in range(batch_size)
                    for j, lam in enumerate(lambdas)
                ]
                filtered = pool.starmap(patches_batch_tos_area_p, tasks)
                # starmap preserves task order: task k is (i, j) = divmod(k, n_lambdas).
                for k, patch_filtered in enumerate(filtered):
                    i, j = divmod(k, n_lambdas)
                    batch_image[i, :, :, j + 1] = patch_filtered
                del tasks, filtered

            yield batch_image, batch_label
    finally:
        # Runs on generator.close() or garbage collection. The previous
        # `pool.close()` after the infinite loop was unreachable, leaking workers.
        if pool is not None:
            pool.terminate()
            pool.join()

In [None]:
train_image = D1.train_i3_image_normalized_f32

# Merge the three class masks into one map (class k -> value k, assuming 0/1 masks);
# any non-zero pixel belongs to some class, the rest is background.
train_labels = np.sum(
    [D1.train_i3_label_1, 2 * D1.train_i3_label_2, 3 * D1.train_i3_label_3],
    axis=0,
).astype(np.uint8)
train_background = np.where(train_labels > 0, 0, 1)

# Per-class index lists passed to the index-based sampler in `patch`.
train_labels_indexes = [
    D1.train_i3_label_1_indexes,
    D1.train_i3_label_2_indexes,
    D1.train_i3_label_3_indexes,
]

# One-hot target: background is channel 0, classes 1..3 follow.
train_labels_one_hot = np.stack(
    [train_background, D1.train_i3_label_1, D1.train_i3_label_2, D1.train_i3_label_3],
    axis=-1,
)

In [None]:
# Sampler from `patch`, which additionally receives the per-class index lists.
train1 = patch.gen_patches_batch_augmented_tos_area_label_indexes_one_hot_p(
    PATCH_SIZE, train_image, train_labels_one_hot, train_labels_indexes,
    batch_size=BATCH_SIZE, lambdas=LAMBDAS,
)
# Local sampler: crops at uniformly random positions.
train2 = gen_patches_batch_augmented_tos_area_label_random_one_hot_p(
    patch_size=PATCH_SIZE, image=train_image, label=train_labels_one_hot,
    batch_size=BATCH_SIZE, lambdas=LAMBDAS,
)

In [None]:
X, Y = next(train1)
# Percentage of pixels in each one-hot channel over the whole batch. The channel
# order follows train_labels_one_hot (background first). The mean over batch, rows
# and columns replaces the hard-coded 256*256 / BATCH_SIZE, so any PATCH_SIZE works.
print(*(Y.mean(axis=(0, 1, 2)) * 100), sep="\n")

* 77.57105976343155
* 14.720947295427322
* 1.1825524270534515
* 6.525498628616333

In [None]:
X, Y = next(train2)
# Percentage of pixels in each one-hot channel over the whole batch. The channel
# order follows train_labels_one_hot (background first). The mean over batch, rows
# and columns replaces the hard-coded 256*256 / BATCH_SIZE, so any PATCH_SIZE works.
print(*(Y.mean(axis=(0, 1, 2)) * 100), sep="\n")

* 92.0101523399353
* 5.138735473155975
* 0.323670357465744
* 2.5274470448493958

In [None]:
train_image = D2.train_lw4_image_normalized_f32

# Merge the three class masks into one map (class k -> value k, assuming 0/1 masks);
# any non-zero pixel belongs to some class, the rest is background.
train_labels = np.sum(
    [D2.train_lw4_label_1, 2 * D2.train_lw4_label_2, 3 * D2.train_lw4_label_3],
    axis=0,
).astype(np.uint8)
train_background = np.where(train_labels > 0, 0, 1)

# Per-class index lists passed to the index-based sampler in `patch`.
train_labels_indexes = [
    D2.train_lw4_label_1_indexes,
    D2.train_lw4_label_2_indexes,
    D2.train_lw4_label_3_indexes,
]

# One-hot target: background is channel 0, classes 1..3 follow.
train_labels_one_hot = np.stack(
    [train_background, D2.train_lw4_label_1, D2.train_lw4_label_2, D2.train_lw4_label_3],
    axis=-1,
)

In [None]:
# Sampler from `patch`, which additionally receives the per-class index lists.
train1 = patch.gen_patches_batch_augmented_tos_area_label_indexes_one_hot_p(
    PATCH_SIZE, train_image, train_labels_one_hot, train_labels_indexes,
    batch_size=BATCH_SIZE, lambdas=LAMBDAS,
)
# Local sampler: crops at uniformly random positions.
train2 = gen_patches_batch_augmented_tos_area_label_random_one_hot_p(
    patch_size=PATCH_SIZE, image=train_image, label=train_labels_one_hot,
    batch_size=BATCH_SIZE, lambdas=LAMBDAS,
)

In [None]:
X, Y = next(train1)
# Percentage of pixels in each one-hot channel over the whole batch. The channel
# order follows train_labels_one_hot (background first). The mean over batch, rows
# and columns replaces the hard-coded 256*256 / BATCH_SIZE, so any PATCH_SIZE works.
print(*(Y.mean(axis=(0, 1, 2)) * 100), sep="\n")

* 84.41567122936249
* 9.082133322954178
* 2.6188924908638
* 3.8833029568195343

In [None]:
X, Y = next(train2)
# Percentage of pixels in each one-hot channel over the whole batch. The channel
# order follows train_labels_one_hot (background first). The mean over batch, rows
# and columns replaces the hard-coded 256*256 / BATCH_SIZE, so any PATCH_SIZE works.
print(*(Y.mean(axis=(0, 1, 2)) * 100), sep="\n")

* 90.60298949480057
* 5.268880724906921
* 1.3633936643600464
* 2.7647361159324646