In [1]:
import os
import imageio
import numpy as np
import pandas as pd
import geopandas as gp
from sklearn.model_selection import StratifiedShuffleSplit

In [71]:
# Root directory holding one sub-directory per dataset.
DATA = '../data'
# Per-dataset sub-directory names for images, masks and per-instance vector files.
IMAGES_PATH = 'images'
MASKS_PATH = 'masks'
INSTANCES_PATH = 'instance_masks'
# Expected image size and channel count.
WIDTH, HEIGHT = 224, 224
WIDHT = WIDTH  # deprecated misspelling, kept so cells that still reference it keep running
CHANNELS = 3

In [176]:
def get_area(instance_path):
    """Return the median scaled area of the geometries stored at ``instance_path``.

    The file is read with ``geopandas.read_file``. Each geometry's ``.area`` is
    divided by 100 before the median is taken.
    NOTE(review): the meaning of the 100 scale factor (unit conversion?) is not
    documented here; confirm it.
    """
    geometries = gp.read_file(instance_path)['geometry']
    scaled_areas = geometries.area / 100
    return scaled_areas.median()

    
def get_labels(distr):
    """Assign each value of ``distr`` to its quartile class.

    Parameters
    ----------
    distr : array_like
        Values to bin, e.g. the per-sample median areas computed with ``get_area``.

    Returns
    -------
    np.ndarray
        Integer array with the same shape as ``distr``:
        0 for values below the 25th percentile, 1 for [25th, 50th),
        2 for [50th, 75th) and 3 for values at or above the 75th percentile.
        Empty input yields an empty array.
    """
    distr = np.asarray(distr)
    res = np.full(distr.shape, 3)
    if distr.size == 0:
        # np.quantile raises on empty input; there is nothing to label.
        return res
    # Compute all three cut points in one pass instead of three separate calls.
    q25, q50, q75 = np.quantile(distr, [0.25, 0.5, 0.75])
    # Masks are built from ``distr`` itself (never from an outer/global variable),
    # so the result depends only on the argument. The later, lower cut points
    # overwrite the earlier ones.
    res[distr < q75] = 2
    res[distr < q50] = 1
    res[distr < q25] = 0
    return res

In [179]:
def _subdirs(path):
    """Return the sorted names of the immediate sub-directories of ``path``."""
    # next(os.walk(...)) lists only the top level; list(os.walk(...))[0] would
    # walk the whole tree first. Sorting makes the sample order, and therefore
    # the seeded split, independent of the file system's listing order.
    return sorted(next(os.walk(path))[1])


def stratify(
    data_path, width=224, height=224, channels=3, images_path_name='images',
    masks_path_name='masks', instances_path_name='instance_masks',
    test_size=0.2, random_state=42
    ):
    """Load every dataset under ``data_path`` and split it into train/test sets.

    The split is stratified by the quartile class of each sample's median
    instance area (``get_area`` + ``get_labels``).

    Layout read from disk: ``data_path`` holds one sub-directory per dataset.
    Inside each dataset, the sample ids are the sub-directory names of
    ``instances_path_name``. For every id there must be
    ``images_path_name/<id>.jpeg``, ``masks_path_name/<id>.png`` and
    ``instances_path_name/<id>/<id>.geojson``.

    Parameters
    ----------
    data_path : str
        Root directory containing the dataset sub-directories.
    width, height, channels : int
        Expected per-image shape ``(width, height, channels)``; masks are expected
        to be ``(width, height)``.
    images_path_name, masks_path_name, instances_path_name : str
        Sub-directory names inside each dataset.
    test_size, random_state
        Forwarded to ``StratifiedShuffleSplit`` (defaults match the previous
        hard-coded values).

    Returns
    -------
    X_train, y_train, X_test, y_test : np.ndarray
        float64 images of shape ``(n, width, height, channels)`` and float64 masks
        of shape ``(n, width, height, 1)``.

    Raises
    ------
    ValueError
        If no samples are found, or an image/mask has an unexpected shape.
    """
    images, masks, areas = [], [], []
    for dataset in _subdirs(data_path):
        dataset_path = os.path.join(data_path, dataset)
        images_path = os.path.join(dataset_path, images_path_name)
        masks_path = os.path.join(dataset_path, masks_path_name)
        instances_path = os.path.join(dataset_path, instances_path_name)
        for instance in _subdirs(instances_path):
            images.append(imageio.imread(os.path.join(images_path, instance + '.jpeg')))
            masks.append(imageio.imread(os.path.join(masks_path, instance + '.png')))
            areas.append(get_area(os.path.join(instances_path, instance, instance + '.geojson')))

    if not images:
        raise ValueError(f'no samples found under {data_path!r}')

    # Stack once at the end rather than re-concatenating inside the loop, which
    # copied everything per dataset. float64 matches the dtype the previous
    # np.empty-based accumulation produced.
    X = np.asarray(images, dtype=np.float64)
    y = np.asarray(masks, dtype=np.float64)
    # NOTE(review): imageio arrays are (rows, cols, ...), so ``width`` is matched
    # against the row count here. This ordering is kept for compatibility; confirm intent.
    if X.shape[1:] != (width, height, channels):
        raise ValueError(
            f'expected images of shape {(width, height, channels)}, got {X.shape[1:]}')
    if y.shape[1:] != (width, height):
        raise ValueError(f'expected masks of shape {(width, height)}, got {y.shape[1:]}')
    # Add the trailing channel axis once, after all datasets are collected.
    y = y[..., np.newaxis]
    labels = get_labels(np.asarray(areas, dtype=np.float64))

    sss = StratifiedShuffleSplit(n_splits=1, test_size=test_size, random_state=random_state)
    # n_splits=1, so the generator yields exactly one (train, test) pair.
    train_ix, test_ix = next(sss.split(X, labels))
    return X[train_ix], y[train_ix], X[test_ix], y[test_ix]

In [181]:
# Build the stratified train/test split and show the shape of each array:
# train images, train masks, test images, test masks.
a, b, c, d = stratify(DATA)
print(*(split.shape for split in (a, b, c, d)), sep='\n')

(196, 224, 224, 3)
(196, 224, 224, 1)
(49, 224, 224, 3)
(49, 224, 224, 1)
