In [1]:
import os
import sys
import random
import math
import numpy as np
import skimage.io
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline 

In [None]:
# Import Mask RCNN
from mrcnn import utils
import mrcnn.model as modellib
from mrcnn import visualize
from mrcnn.config import Config
from mrcnn.model import log

In [2]:
# Root directory of the project
ROOT_DIR = os.path.abspath("../data/raw/stephtest/")

# Directory to save logs and trained model
MODEL_DIR = os.path.join(ROOT_DIR, "wv2/models")

# # Local path to trained weights file
# COCO_MODEL_PATH = os.path.join(MODEL_DIR, "mask_rcnn_coco.h5")
# # Download COCO trained weights from Releases if needed
# if not os.path.exists(COCO_MODEL_PATH):
#     utils.download_trained_weights(COCO_MODEL_PATH)

# Working directories, all located under ROOT_DIR/wv2.
EIGHTCHANNEL_DIR = os.path.join(ROOT_DIR, 'wv2/eightchannels')
TRAIN_DIR = os.path.join(ROOT_DIR, 'wv2/images')
# NOTE(review): despite its name, VALIDATION_DIR points at 'wv2/masks'.
VALIDATION_DIR = os.path.join(ROOT_DIR, 'wv2/masks')
TEST_DIR = os.path.join(ROOT_DIR, 'wv2/test')

# Create every working directory independently. A single try/except around
# sequential os.mkdir calls stops at the first directory that already exists
# and silently skips the rest; os.makedirs(..., exist_ok=True) is idempotent,
# also creates the missing 'wv2' parent, and still raises on real errors
# (e.g. permission denied) instead of swallowing them.
for _working_dir in (EIGHTCHANNEL_DIR, MODEL_DIR, TRAIN_DIR, VALIDATION_DIR, TEST_DIR):
    os.makedirs(_working_dir, exist_ok=True)

In [5]:
def _select_four_bands(image):
    """Return `image` reduced to 4 channels along its last axis.

    An image whose last axis already has length 4 is returned unchanged.
    Otherwise channels 1, 2, 4 and 6 are kept, in that order.
    # NOTE(review): presumably these are the B, G, R, NIR positions of an
    # 8-band product -- TODO confirm; will need an image config in future
    # to programmatically use the correct band mappings.
    """
    if image.shape[-1] != 4:
        image = np.dstack((image[:, :, 1:3], image[:, :, 4], image[:, :, 6]))
    return image


def load_merge_wv2(image_id):
    """Load the specified wv2 os/gs image pair and return the stacked result.

    Reads `<image_id>_MS_OS.tif` and `<image_id>_MS_GS.tif` from IMAGERY_DIR,
    reduces each to 4 channels if it has a different channel count, and
    stacks them along the channel axis, OS first. Channels are ordered
    [B, G, R, NIR, B, G, R, NIR].

    Parameters
    ----------
    image_id : str
        File-name prefix shared by the OS and GS tifs.

    Returns
    -------
    tuple
        `(stacked_image_path, stacked_image)` where `stacked_image_path` is
        the destination `<image_id>_OSGS_ms.tif` inside EIGHTCHANNEL_DIR and
        `stacked_image` is the [H,W,8] Numpy array. Nothing is written to
        disk here; the caller is responsible for saving.
    """
    # Load image pair
    os_path = os.path.join(IMAGERY_DIR, image_id + '_MS_OS.tif')
    gs_path = os.path.join(IMAGERY_DIR, image_id + '_MS_GS.tif')
    os_image = _select_four_bands(skimage.io.imread(os_path))
    gs_image = _select_four_bands(skimage.io.imread(gs_path))

    stacked_image = np.dstack((os_image, gs_image))
    stacked_image_path = os.path.join(EIGHTCHANNEL_DIR, image_id + '_OSGS_ms.tif')
    return (stacked_image_path, stacked_image)

IMAGERY_DIR = os.path.join(ROOT_DIR, 'projectedtiffs/')
GROUNDTRUTH_DIR = os.path.join(ROOT_DIR, 'rasterized_wv2_labels')

# all files, including ones we don't care about
file_ids_all = next(os.walk(IMAGERY_DIR))[2]
# all multispectral on and off season tifs
image_ids_all = [image_id for image_id in file_ids_all if 'MS' in image_id]
# check for duplicates (prints True if a file name appears more than once)
print(len(image_ids_all) != len(set(image_ids_all)))

image_ids_gs = [image_id for image_id in image_ids_all if 'GS' in image_id]
image_ids_os = [image_id for image_id in image_ids_all if 'OS' in image_id]

# check for equality (prints True if there are as many OS files as GS files)
print(len(image_ids_os) == len(image_ids_gs))

# Recover each scene id by stripping the exact suffix that load_merge_wv2
# appends when it rebuilds the file name. Slicing the first two characters
# only works for two-character ids and produces duplicate ids when sidecar
# files (anything else containing 'GS') sit next to the tifs. Sorting makes
# the processing order independent of the file-system listing order.
GS_SUFFIX = '_MS_GS.tif'
image_ids_short = sorted(
    image_id[:-len(GS_SUFFIX)]
    for image_id in image_ids_gs
    if image_id.endswith(GS_SUFFIX)
)

# Map each destination path to its stacked 8-channel array.
stacked_dict = {}
for imid in image_ids_short:
    path, arr = load_merge_wv2(imid)
    stacked_dict[path] = arr

# We save an 8 channel numpy array holding both the GS and OS bands.
# BUT mold_inputs() in
# https://github.com/matterport/Mask_RCNN/blob/master/mrcnn/model.py
# indicates that the input png/array must only have three channels...
# We could change mold_inputs to not have this requirement and also change
#         input_image = KL.Input(
#            shape=[None, None, 3], name="input_image")
# on line 1841 of mrcnn/model.py, but it is not clear whether those are all
# the changes that would be required.
# This issue indicates the fix is simpler:
# https://github.com/matterport/Mask_RCNN/issues/314

for key, val in stacked_dict.items():
    skimage.io.imsave(key, val, plugin='tifffile')

False
True


In [6]:
import random
import shutil
# Seed the global RNG so the train/test split below is repeatable.
random.seed(42)

def train_test_split(imagerydir, traindir, testdir, kprop):
    """Splits tifs into train and test dir.

    Every file directly inside `imagerydir` is copied to exactly one of the
    two destinations: a random sample of `round(kprop * n_files)` files goes
    to `testdir`, all remaining files go to `traindir`. The source files are
    left in place. The number of train files and then the number of test
    files are printed.

    Parameters
    ----------
    imagerydir : str
        Directory whose files are split (sub-directories are ignored).
    traindir : str
        Destination for the training files; created if missing.
    testdir : str
        Destination for the held-out files; created if missing.
    kprop : float
        Proportion of files to hold out for testing.

    Notes
    -----
    Sampling uses the global `random` state. Files already present in the
    destination directories from an earlier run are not removed.
    """
    # Sort the listing: os.walk returns names in file-system order, which
    # would make the seeded sample differ between machines.
    image_list = sorted(next(os.walk(imagerydir))[2])
    k = round(kprop*len(image_list))
    test_list = random.sample(image_list, k)

    # Derive the train set from the listing we already have instead of
    # walking the directory a second time.
    held_out = set(test_list)
    train_list = [name for name in image_list if name not in held_out]

    for destination in (traindir, testdir):
        os.makedirs(destination, exist_ok=True)

    for test in test_list:
        shutil.copyfile(os.path.join(imagerydir, test), os.path.join(testdir, test))
    for train in train_list:
        shutil.copyfile(os.path.join(imagerydir, train), os.path.join(traindir, train))

    print(len(train_list))
    print(len(test_list))
    
# Hold out 10% of the stacked 8-channel tifs in TEST_DIR; the rest are
# copied to TRAIN_DIR.
train_test_split(EIGHTCHANNEL_DIR, TRAIN_DIR, TEST_DIR, kprop=.1)

# Copy every file found directly inside GROUNDTRUTH_DIR into VALIDATION_DIR,
# keeping the file names unchanged.
groundtruth_list = next(os.walk(GROUNDTRUTH_DIR))[2]
for label_name in groundtruth_list:
    label_source = os.path.join(GROUNDTRUTH_DIR, label_name)
    label_target = os.path.join(VALIDATION_DIR, label_name)
    shutil.copyfile(label_source, label_target)

7
1
