In [1]:
import os
from glob import glob
import numpy as np
import cv2
from osgeo import gdal

In [2]:
# Input/output locations (absolute paths; adjust when running on another machine).
# Directory that is globbed for the source image rasters (*.tif).
IMG_ROOT = '/tank2/home/public/iceplant/rgb/Images_clip'
# Directory holding the label rasters; each mask is looked up by the same
# file name as its image.
MASK_ROOT = '/tank2/home/public/iceplant/rgb/Label_clip'
# Root directory under which the train/ and val/ patch folders are created.
OUTPUT_ROOT = '/tank2/home/public/iceplant/rgb'

In [4]:
# Patch geometry as (rows, cols). With STRIDE == PSIZE the patches tile each
# raster without overlap.
PSIZE = np.array([128, 128])
STRIDE = np.array([128, 128])

# Any patch whose mask contains this label value is discarded.
# NOTE(review): presumably a nodata / unlabeled class code -- confirm against
# the label schema.
SKIP_LABEL = 15

# Raster whose patches form the validation split; every other raster goes to train.
VAL_FILE = 'clip1.tif'

# make folders
for split_name in ('train', 'val'):
    for subdir in ('images', 'masks'):
        os.makedirs(os.path.join(OUTPUT_ROOT, split_name, subdir), exist_ok=True)

# get all image file names; the matching mask is looked up under MASK_ROOT
# using the same file name
files = sorted([os.path.basename(i) for i in glob(os.path.join(IMG_ROOT, '*.tif'))])

count = 0

for fname in files:
    img_path = os.path.join(IMG_ROOT, fname)
    mask_path = os.path.join(MASK_ROOT, fname)

    # read image and mask
    img_ds = gdal.Open(img_path)
    mask_ds = gdal.Open(mask_path)

    # Unless GDAL exceptions are enabled, gdal.Open signals failure by
    # returning None; fail with a clear message instead of an AttributeError.
    if img_ds is None:
        raise RuntimeError(f'GDAL could not open image raster: {img_path}')
    if mask_ds is None:
        raise RuntimeError(f'GDAL could not open mask raster: {mask_path}')

    # GDAL returns (bands, rows, cols); move bands last -> (rows, cols, bands).
    img = img_ds.ReadAsArray().transpose(1, 2, 0)
    mask = mask_ds.ReadAsArray()

    # Only emit windows that are fully covered by BOTH rasters, so an image
    # patch is never paired with a truncated mask patch.
    rows = min(img.shape[0], mask.shape[0])
    cols = min(img.shape[1], mask.shape[1])
    if img.shape[:2] != mask.shape[:2]:
        print(f'WARNING: {fname}: image extent {img.shape[:2]} != mask extent '
              f'{mask.shape[:2]}; using the common top-left {rows}x{cols} region')

    # The split depends only on the source file, so decide it once per raster.
    split = 'val' if fname == VAL_FILE else 'train'

    # split by PSIZE and STRIDE
    # save patches to disk
    # The range bounds stop at the last offset where a full PSIZE window fits,
    # which skips incomplete edge windows.
    for i in range(0, rows - PSIZE[0] + 1, STRIDE[0]):
        for j in range(0, cols - PSIZE[1] + 1, STRIDE[1]):
            img_patch = img[i:i+PSIZE[0], j:j+PSIZE[1]]
            mask_patch = mask[i:i+PSIZE[0], j:j+PSIZE[1]]

            if SKIP_LABEL in mask_patch:
                continue

            # Shift label values down by one.
            # NOTE(review): if the mask dtype is unsigned, a label of 0 wraps
            # around (e.g. to 255 for uint8) -- confirm whether that is the
            # intended "ignore" value.
            mask_patch = mask_patch - 1

            img_out = os.path.join(OUTPUT_ROOT, split, 'images', f'{fname}_{i}_{j}.png')
            mask_out = os.path.join(OUTPUT_ROOT, split, 'masks', f'{fname}_{i}_{j}.png')

            # NOTE(review): cv2.imwrite treats 3-channel arrays as BGR. If the
            # GDAL bands are ordered R,G,B the PNGs will show red/blue swapped
            # in non-OpenCV readers -- confirm the downstream loader expects this.
            #
            # cv2.imwrite reports failure through its return value rather than
            # an exception; check it so a failed write is not counted.
            if not cv2.imwrite(img_out, img_patch):
                raise IOError(f'Failed to write image patch: {img_out}')
            if not cv2.imwrite(mask_out, mask_patch):
                raise IOError(f'Failed to write mask patch: {mask_out}')
            count += 1

print("Total patches:", count)


Total patches: 8542


In [4]:
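# # DISABLED (kept for reference): randomly split all patches into train and val sets at a ratio of 8:2
# # by moving files from flat OUTPUT_ROOT/images and OUTPUT_ROOT/masks folders into the train and val folders.
# # NOTE: the extraction cell above writes directly into train/ and val/, so it does not create those flat folders.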
# files = sorted([os.path.basename(i) for i in glob(os.path.join(OUTPUT_ROOT, 'images', '*.png'))])
# np.random.shuffle(files)

# train_files = files[:int(len(files) * 0.8)]
# val_files = files[int(len(files) * 0.8):]

# # make folders
# os.makedirs(os.path.join(OUTPUT_ROOT, 'train', 'images'), exist_ok=True)
# os.makedirs(os.path.join(OUTPUT_ROOT, 'train', 'masks'), exist_ok=True)

# os.makedirs(os.path.join(OUTPUT_ROOT, 'val', 'images'), exist_ok=True)
# os.makedirs(os.path.join(OUTPUT_ROOT, 'val', 'masks'), exist_ok=True)

# for fname in train_files:
#     os.rename(os.path.join(OUTPUT_ROOT, 'images', fname), os.path.join(OUTPUT_ROOT, 'train', 'images', fname))
#     os.rename(os.path.join(OUTPUT_ROOT, 'masks', fname), os.path.join(OUTPUT_ROOT, 'train', 'masks', fname))

# for fname in val_files:
#     os.rename(os.path.join(OUTPUT_ROOT, 'images', fname), os.path.join(OUTPUT_ROOT, 'val', 'images', fname))
#     os.rename(os.path.join(OUTPUT_ROOT, 'masks', fname), os.path.join(OUTPUT_ROOT, 'val', 'masks', fname))

# # remove empty folders
# os.rmdir(os.path.join(OUTPUT_ROOT, 'images'))
# os.rmdir(os.path.join(OUTPUT_ROOT, 'masks'))
