In [22]:
import random
import os
from os import makedirs, symlink, readlink
from os.path import join, basename, dirname, exists, islink, split
from glob import glob
from shutil import copy

from skimage.io import imread
from tqdm import tqdm

In [23]:
# Constants
# Root directory of the dataset.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
DATASET_DIR = "/home/gdf/data/LEVIR-CD/"

# Existing subsets that this notebook reads from.
OLD_TRAIN_DIR = join(DATASET_DIR, '256x256_2', 'train')
# NOTE(review): the source folder is named 'val2', not 'val' — confirm this is intended.
OLD_VAL_DIR = join(DATASET_DIR, '256x256_2', 'val2')
OLD_TEST_DIR = join(DATASET_DIR, '256x256_2', 'test')

# Subsets that this notebook creates.
NEW_TRAIN_DIR = join(DATASET_DIR, '256x256_3', 'train')
NEW_VAL_DIR = join(DATASET_DIR, '256x256_3', 'val')
NEW_TEST_DIR = join(DATASET_DIR, '256x256_3', 'test')

# Fraction of each shuffled group of label files that goes to the new train
# subset; the remainder goes to the new val subset.
RATIO = 0.9
# How files are placed in the new subsets: 'LINK' creates symlinks (mklink),
# 'COPY' copies the files (mkcopy). Any other value makes _copy_file raise.
COPY_MODE = 'LINK'  # LINK, COPY
# Seed passed to random.seed() before the re-split so the shuffles are repeatable.
SEED = 114514

In [24]:
# Utility functions
def mklink(src, dst):
    """Create a symbolic link at ``dst`` that points to ``src``.

    Missing parent directories of ``dst`` are created first. If ``src`` is
    itself a symbolic link, the new link points at that link's target (one
    level is followed) instead of at ``src``, so links are not chained.

    Args:
        src: Path the new link should point to.
        dst: Path of the symbolic link to create.

    Raises:
        OSError: Propagated from ``os.symlink`` (e.g. ``FileExistsError``
            when ``dst`` already exists) or from ``os.makedirs``.
    """
    dst_dir = dirname(dst)
    # dirname() is '' for a bare filename and makedirs('') would raise, so
    # only create the parent when there is one. exist_ok avoids the
    # check-then-create race of a separate exists() test.
    if dst_dir:
        makedirs(dst_dir, exist_ok=True)
    if islink(src):
        # readlink() may return a path relative to the directory that holds
        # ``src``. Anchor it there so the new link still resolves from
        # ``dst``'s directory; join() drops the prefix for absolute targets.
        src = join(dirname(src), readlink(src))
    symlink(src, dst)


def mkcopy(src, dst):
    """Copy the file ``src`` to ``dst``, creating missing parent directories.

    Args:
        src: Path of the file to copy.
        dst: Destination path, passed to ``shutil.copy`` unchanged.

    Raises:
        OSError: Propagated from ``os.makedirs`` or ``shutil.copy``.
    """
    dst_dir = dirname(dst)
    # dirname() is '' for a bare filename and makedirs('') would raise, so
    # only create the parent when there is one. exist_ok avoids the
    # check-then-create race of a separate exists() test.
    if dst_dir:
        makedirs(dst_dir, exist_ok=True)
    copy(src, dst)
    

In [25]:
# Merge old train and val subsets and re-split them
def _copy_file(src, dst):
    if COPY_MODE == 'COPY':
        mkcopy(src, dst)
    elif COPY_MODE == 'LINK':
        mklink(src, dst)
    else:
        raise ValueError

def _split_parts(path):
    parts = []
    head = path
    while True:
        head, tail = split(head)
        if tail != '':
            parts.append(tail)
        elif head != '':
            parts.append(head)
            break
    parts.reverse()
    return parts

def _random_split(label_paths, train_dir, val_dir):
    random.shuffle(label_paths)
    end = int(len(label_paths)*RATIO)
    
    for i, p in enumerate(label_paths):
        parts = _split_parts(p)
        src_dir = join(*parts[:-3])
        tag, name = parts[-2:]
        dst_dir = train_dir if i < end else val_dir
        _copy_file(join(src_dir, 'A', tag, name), join(dst_dir, 'A', tag, name))
        _copy_file(join(src_dir, 'B', tag, name), join(dst_dir, 'B', tag, name))
        _copy_file(join(src_dir, 'label', tag, name), join(dst_dir, 'label', tag, name))

# Fix the RNG state so the shuffles performed by _random_split are repeatable.
random.seed(SEED)

# One entry per sub-directory of 'label' in the old train and val subsets,
# processed in sorted order.
subdirs = sorted(
    glob(join(OLD_TRAIN_DIR, 'label', '*/')) + glob(join(OLD_VAL_DIR, 'label', '*/'))
)
for subdir in subdirs:
    label_paths = glob(join(subdir, '*.png'))
    label_paths.sort()

    # Partition the label files by whether their pixel sum is positive.
    pos_paths = []
    neg_paths = []
    for p in label_paths:
        is_positive = imread(p).sum() > 0
        (pos_paths if is_positive else neg_paths).append(p)

    # Positives are split first, then negatives, each on its own and each
    # with the same RATIO; empty groups are skipped.
    for group in (pos_paths, neg_paths):
        if len(group) > 0:
            _random_split(group, NEW_TRAIN_DIR, NEW_VAL_DIR)


In [27]:
# Link test subset: the old test directory is not re-split; it is exposed at
# NEW_TEST_DIR through one directory-level symlink.
# NOTE(review): mklink always creates a symlink, so COPY_MODE is not honoured
# for the test subset — confirm this is intended.
mklink(OLD_TEST_DIR, NEW_TEST_DIR)

In [29]:
# Count
def _show_info(paths):
    print(f"{len(paths)}\t{_count_ratio(paths)}")

def _count_ratio(paths):
    pos_cnt = 0
    cnt = 0
    for p in paths:
        im = imread(p)
        pos_cnt += (im>0).sum()
        assert len(im.shape) == 2
        cnt += im.shape[0]*im.shape[1]
    return pos_cnt / cnt

# Report file count and positive-pixel ratio for train, val and test, in that
# order. glob() is called without recursive=True, so '**' here matches exactly
# one directory level, like '*'.
for subset_dir in (NEW_TRAIN_DIR, NEW_VAL_DIR, NEW_TEST_DIR):
    _show_info(glob(join(subset_dir, 'label', '**', '*.png')))

22016	0.04571745076844858
2925	0.044757674779647434
2048	0.05094262957572937
