In [7]:
import os
import imageio
import argparse
import numpy as np
import pandas as pd
import geopandas as gp
from sklearn.model_selection import StratifiedShuffleSplit

In [170]:
# Root directory that holds the dataset folder(s).
DATA_PATH = '../data'

# Sub-folder name and file extension for each kind of artefact.
IMAGES_FOLDER = 'images'
IMAGE_TYPE = 'tiff'

MASKS_FOLDER = 'masks'
MASK_TYPE = 'png'

INSTANCES_FOLDER = 'instance_masks'
INSTANCE_TYPE = 'geojson'

# Channel identifiers used as the third part of dataset / file names.
CHANNELS = ['rgb', 'ndvi', 'ndvi_color', 'b2']

In [185]:
def get_data_info(data_path=DATA_PATH):
    """Build a table describing every instance folder of the dataset.

    The instances directory is located with ``get_data_pathes`` and each of
    its sub-folder names is split on ``'_'`` by ``split_fullname``. Parts 0,
    1, 3 and 4 of the name become the ``date``, ``name``, ``ix`` and ``iy``
    columns; part 2 is not stored (other functions in this file place the
    channel in that position when they rebuild the name).

    Parameters
    ----------
    data_path : str
        Root directory that contains the dataset folder.

    Returns
    -------
    pandas.DataFrame
        One row per instance folder with a fresh ``0..n-1`` index. Values are
        the raw string parts of the folder name. Columns are ordered
        alphabetically (``date``, ``ix``, ``iy``, ``name``); when no instance
        folder exists an empty frame with columns
        ``date``, ``name``, ``ix``, ``iy`` is returned.

    Raises
    ------
    IndexError
        If a folder name has fewer than five ``'_'``-separated parts.
    """
    _, _, instances_path = get_data_pathes(data_path)

    cols = ['date', 'name', 'ix', 'iy']

    # Collect plain records first and build the frame once: appending to a
    # DataFrame inside the loop copies the whole frame on every iteration, and
    # DataFrame.append no longer exists in pandas >= 2.0.
    records = []
    for instance in get_folders(instances_path):
        name_parts = split_fullname(instance)
        records.append({
            'date': name_parts[0],
            'name': name_parts[1],
            'ix': name_parts[3],
            'iy': name_parts[4],
        })

    if not records:
        return pd.DataFrame(columns=cols)

    # The previous append(..., sort=True) implementation produced
    # alphabetically sorted columns; keep that layout for existing callers.
    return pd.DataFrame(records, columns=sorted(cols))


def get_data_pathes(data_path,
                    images_folder=IMAGES_FOLDER,
                    masks_folder=MASKS_FOLDER,
                    instances_folder=INSTANCES_FOLDER):
    """Return the images, masks and instances directories of the dataset.

    The dataset is the first sub-folder that ``get_folders`` reports for
    ``data_path``; the three folder names are joined onto it.

    Returns
    -------
    tuple of str
        ``(images_path, masks_path, instances_path)``.
    """
    dataset_root = os.path.join(data_path, get_folders(data_path)[0])

    return tuple(
        os.path.join(dataset_root, folder)
        for folder in (images_folder, masks_folder, instances_folder)
    )
    
    
def get_folders(path):
    """Return the names of the directories located directly inside ``path``.

    Only the first entry produced by ``os.walk`` is consumed, so nested
    directories are never traversed (materialising the whole walk just to
    read its first element scans every sub-folder of the tree).

    Parameters
    ----------
    path : str
        Directory to list.

    Returns
    -------
    list of str
        Sub-directory names, in the order ``os.walk`` reports them.

    Raises
    ------
    IndexError
        If ``os.walk`` yields nothing for ``path`` (e.g. the directory does
        not exist). This is the same exception type the previous
        ``list(os.walk(path))[0]`` form raised.
    """
    try:
        _, folders, _ = next(os.walk(path))
    except StopIteration:
        raise IndexError(
            'no directory listing available for {!r}'.format(path)
        ) from None
    return folders


def split_fullname(fullname):
    """Split ``fullname`` on underscores and return the list of parts."""
    separator = '_'
    parts = fullname.split(separator)
    return parts


def get_fullname(*name_parts):
    """Join the given parts with underscores, converting each with ``str``."""
    return '_'.join(str(part) for part in name_parts)


def get_filepath(*path_parts, file_type):
    """Join ``path_parts`` into a path and append ``.file_type`` to it.

    ``file_type`` is keyword-only and is written after a dot without any
    further processing.
    """
    base_path = join_pathes(*path_parts)
    return '{}.{}'.format(base_path, file_type)


def join_pathes(*pathes):
    """Combine the given path components with ``os.path.join``."""
    joined = os.path.join(*pathes)
    return joined


def stratify(data_info, test_size=0.2, random_state=42,
             channel='rgb', instance_type=INSTANCE_TYPE,
             data_path=DATA_PATH, instances_folder=INSTANCES_FOLDER):
    """Create a stratified train/test split of the rows of ``data_info``.

    For every row the instance file
    ``<data_path>/<date>_<name>_<channel>/<instances_folder>/<full name>/<full name>.<instance_type>``
    is read with ``get_area``; the resulting values are turned into class
    labels by ``get_labels`` and those labels drive the stratification.

    Parameters
    ----------
    data_info : pandas.DataFrame
        Frame with ``date``, ``name``, ``ix`` and ``iy`` columns.
    test_size : float
        Forwarded to ``StratifiedShuffleSplit``.
    random_state : int
        Forwarded to ``StratifiedShuffleSplit``.
    channel : str
        Channel part used when building dataset and instance names.
    instance_type : str
        File extension of the instance files.
    data_path : str
        Root data directory. Previously read from an undefined name inside
        the function, so the call only worked if a global happened to exist.
    instances_folder : str
        Folder holding the instance files (same remark as ``data_path``).

    Returns
    -------
    generator
        The generator returned by ``StratifiedShuffleSplit.split``; with
        ``n_splits=1`` it yields a single ``(train_indices, test_indices)``
        pair.
    """
    areas = []
    for _, row in data_info.iterrows():
        dataset = get_fullname(row['date'], row['name'], channel)
        instance_name = get_fullname(row['date'], row['name'], channel, row['ix'], row['iy'])
        instance_path = get_filepath(
            data_path,
            dataset,
            instances_folder,
            instance_name,  # folder named after the instance ...
            instance_name,  # ... containing a file with the same base name
            file_type=instance_type
        )
        areas.append(get_area(instance_path))

    labels = get_labels(np.array(areas))

    # split() only needs X for its number of samples, so a zero placeholder
    # replaces loading every image from disk just to count them.
    placeholder = np.zeros(len(labels))

    sss = StratifiedShuffleSplit(
        n_splits=1,
        test_size=test_size,
        random_state=random_state
    )

    return sss.split(placeholder, labels)


def get_data(data_info, channel='rgb',
             data_path=DATA_PATH,
             image_folder=IMAGES_FOLDER,
             mask_folder=MASKS_FOLDER,
             image_type=IMAGE_TYPE,
             mask_type=MASK_TYPE):
    """Read the image and mask that belong to every row of ``data_info``.

    File locations are derived from the ``date``, ``name``, ``ix`` and ``iy``
    columns: the dataset folder is ``<date>_<name>_<channel>`` and the file
    base name is ``<date>_<name>_<channel>_<ix>_<iy>``.

    Returns
    -------
    tuple of numpy.ndarray
        ``(x, y)`` where ``x`` stacks the images and ``y`` stacks the masks
        with one extra trailing axis of length 1.
    """
    images, masks = [], []

    for _, record in data_info.iterrows():
        date, name = record['date'], record['name']
        dataset = get_fullname(date, name, channel)
        tile_name = get_fullname(date, name, channel, record['ix'], record['iy'])

        image_file = get_filepath(
            data_path, dataset, image_folder, tile_name,
            file_type=image_type
        )
        mask_file = get_filepath(
            data_path, dataset, mask_folder, tile_name,
            file_type=mask_type
        )

        # image first, then mask, for each row
        images.append(read_tensor(image_file))
        masks.append(read_tensor(mask_file))

    x = np.array(images)
    y = np.array(masks)

    # append a singleton axis to the stacked masks
    return x, y.reshape(y.shape + (1,))


def read_tensor(filepath):
    """Load the file at ``filepath`` with ``imageio.imread`` and return the result."""
    tensor = imageio.imread(filepath)
    return tensor


def get_area(instance_path):
    """Return the median geometry area of a vector file, divided by 100.

    The file is read with geopandas and the ``area`` of its ``geometry``
    column is used.
    """
    geometries = gp.read_file(instance_path)['geometry']
    # NOTE(review): the division by 100 is presumably a unit conversion — confirm.
    scaled_areas = geometries.area / 100
    return scaled_areas.median()


def get_labels(distr):
    """Assign each value of ``distr`` a quartile label from 0 to 3.

    Values below the 0.25 quantile get 0, below the median 1, below the
    0.75 quantile 2 and everything else 3.
    """
    labels = np.full(distr.shape, 3)
    # Walk the thresholds from highest to lowest so that lower quantiles
    # overwrite the labels written for higher ones.
    for label, level in ((2, 0.75), (1, 0.5), (0, 0.25)):
        labels[distr < np.quantile(distr, level)] = label
    return labels

In [186]:
# Build the table of instances found under the default DATA_PATH.
data_info = get_data_info()
# Last expression of the cell: display the first rows of the table.
data_info.head()

Unnamed: 0,date,ix,iy,name
0,20160103,20,20,66979721-be1b-4451-84e0-4a573236defd
1,20160103,26,13,66979721-be1b-4451-84e0-4a573236defd
2,20160103,12,22,66979721-be1b-4451-84e0-4a573236defd
3,20160103,16,5,66979721-be1b-4451-84e0-4a573236defd
4,20160103,30,16,66979721-be1b-4451-84e0-4a573236defd


In [187]:
# Read the image and mask for every row of data_info (default 'rgb' channel).
x, y = get_data(data_info)
# Last expression of the cell: display the shape of the stacked image array.
x.shape

(245, 224, 224, 3)

In [188]:
# NOTE(review): stratify returns the generator produced by StratifiedShuffleSplit.split and it
# is not consumed here; use next(stratify(data_info)) to get the (train, test) index arrays.
stratify(data_info)

[571.0459217901257, 30.108072562493362, 291.68015376469407, 271.8175202065232, 928.4628718328879, 269.41835292902397, 352.13214357329264, 174.76759558231566, 255.00304770915005, 55.63056148236994, 329.898966943304, 272.6688394137549, 182.99908878319653, 496.29280122490263, 252.75213571793046, 172.1105268442197, 162.76928611841475, 143.46291070775004, 255.72082779391965, 684.3760283516647, 113.50763241952889, 190.3767675822798, 28.012210000521517, 196.4297116522409, 30.358243547493046, 17.782047963230347, 141.1894143663663, 163.84738889172093, 117.78763803238165, 149.56919527189538, 1004.8441099843471, 771.5376922713334, 99.25886307180834, 283.5785990136587, 275.778776508931, 124.61176166492427, 148.0672807642901, 85.40664020178008, 265.3454790239226, 344.64852395519233, 96.27140971959923, 592.6559290670873, 374.76237064664895, 234.5854510522106, 110.39663753673294, 426.60717256091596, 357.7917263088333, 75.44156536364537, 244.97269154777413, 1681.503008996537, 1183.8701721768257, 397.8

<generator object BaseShuffleSplit.split at 0x7f645c9d1f68>