In [None]:
# This code crops the input band image into square image patches (128 x 128 pixels with the N = 128 setting used below).
# Only patches that contain marsh features from the label dataframe are copied to the final patch folder.


In [3]:
import os
import numpy as np
import pandas as pd
import rasterio
import sys
import geopandas as gpd
from pathlib import Path
from tqdm import tqdm
from rasterio.mask import mask, raster_geometry_mask
from shapely.geometry import box
from rasterio.enums import Resampling
from itertools import product
from rasterio import windows
import rioxarray
import shapely
import shutil

# Absolute path of the current working directory; the data folders used below
# are resolved relative to its parent. It is also appended to the import path.
path_cur = os.path.abspath('.')
sys.path.append(path_cur)

from os.path import dirname as up

In [5]:
# Setting working directory and input data directory

# Root folder for everything this notebook produces.
base_path = Path(os.path.join(up(path_cur), 'data', 'sentinel_data_processing'))

# Create every folder that later cells write into or list, so that a fresh
# run does not fail on a missing directory:
#   image_patches_128   - patches kept for training/validation
#   temp_patches        - every tile cut from the VRGB composite
#   temp_coarse         - cropped coarse bands before resampling
#   image_multiband_128 - stacked multiband patches
for out_dir in ('image_patches_128', 'temp_patches', 'temp_coarse', 'image_multiband_128'):
    (base_path / out_dir).mkdir(exist_ok=True, parents=True)

# Folder with the input band mosaics and path of the marsh label file.
IMG_DATA_DIR = os.path.join(up(path_cur), 'data', 'sentinel_shoreline_2017')
Label_DATA_DIR = os.path.join(up(path_cur), 'data', 'processing_data', 'marsh_all_500.geojson')

In [6]:
# Load the marsh label features from the GeoJSON file into a GeoDataFrame.
label_df = gpd.read_file(Label_DATA_DIR)

In [7]:
# Display the attribute columns available in the label data.
label_df.columns

Index(['MarshType', 'CmtyType', 'Reedgrass', 'PrcntPhrag', 'FieldDate',
       'CommunType', 'Dominant_P', 'COUNTY', 'PubYear', 'Acres', 'FieldCheck',
       'MarshNo', 'FIPS', 'FIPSCode', 'FIPSMRSHNO', 'Shape_Leng', 'PrevPubYr',
       'Shape_Le_1', 'Shape_Area', 'Comments', 'RefImagery', 'area',
       'unique_id', 'year', 'geometry'],
      dtype='object')

In [8]:
# File names of the band rasters in the input folder. os.listdir returns
# entries in arbitrary order, so sort them to process bands in the same
# order on every run.
all_bands = sorted(b for b in os.listdir(IMG_DATA_DIR) if b.endswith('tif'))

In [9]:
def cropping_bands(ref_img_path, ups_img, outfile):

    """
    Crop a band to the bounding box of a reference image and write the result.

    ref_img_path: path of the reference image (10m resolution band) whose
                  bounds define the crop extent
    ups_img: low resolution band to crop, passed as an already opened dataset
             (rasterio.open() output); it is not closed here
    outfile: output path of the cropped low resolution band, with its extent
             aligned with ref_img
    """

    # Open the reference inside a context manager so its file handle is
    # released even when masking or writing fails.
    with rasterio.open(ref_img_path) as ref_img:
        # get the geometry of the reference high resolution band
        geom = box(*ref_img.bounds)
        # the output profile starts from the reference image's metadata
        meta = ref_img.meta

    cropped, crop_transf = mask(ups_img, [geom], crop=True, filled=False, all_touched=False)

    c, h, w = cropped.shape

    # size, transform and band count are taken from the cropped data
    meta['width'], meta['height'] = w, h
    meta['transform'] = crop_transf
    meta["count"] = c

    with rasterio.open(outfile, 'w', **meta) as dst:
        dst.write(cropped)


def upsample(img_lres_path, img_hres_path, img_size, outf, method=Resampling.bilinear):

    """
    Resample a low resolution band to img_size x img_size pixels and write it
    using the metadata of the matching high resolution image.

    img_lres_path: low resolution cropped band path
    img_hres_path: high resolution cropped band path
    img_size: the size (in pixels, both height and width) to resample to
    outf: output path of the resampled bands
    method: rasterio resampling method used when reading

    NOTE: width, height and transform of the output are copied from the high
    resolution image, so img_size is expected to equal that image's size.
    """

    # Both inputs are closed before the output is opened, which also keeps
    # the write safe if outf points at one of the input files.
    with rasterio.open(img_lres_path) as dataset:
        band_count = dataset.count

        # resample data to target shape
        data = dataset.read(
            out_shape=(band_count, int(img_size), int(img_size)),
            resampling=method
        )

    with rasterio.open(img_hres_path) as dataset_hres:
        meta = dataset_hres.meta

    meta['count'] = band_count

    with rasterio.open(outf, 'w', **meta) as dst:
        dst.write(data)


def get_tile_geom(tile_tif, crs=None):

    """
    Return the bounding box of a raster tile as a shapely polygon.

    tile_tif: path of the raster tile
    crs: optional CRS given as a string; when set, the tile is reprojected
         to it before its bounds are read
    """

    raster = rioxarray.open_rasterio(tile_tif)

    if crs is None:
        bounds = raster.rio.bounds()
    else:
        assert isinstance(crs, str)
        bounds = raster.rio.reproject(crs).rio.bounds()

    left, bottom, right, top = bounds

    return shapely.geometry.box(left, bottom, right, top, ccw=True)

def get_tiles(ds, width=256, height=256):
    """
    Yield (window, transform) pairs that tile an opened raster dataset.

    ds: opened raster dataset
    width, height: tile size in pixels; tiles at the right and bottom edge
                   are clipped to the dataset extent
    """
    n_cols = ds.meta['width']
    n_rows = ds.meta['height']
    full_extent = windows.Window(col_off=0, row_off=0, width=n_cols, height=n_rows)

    # columns vary in the outer loop, rows in the inner loop
    for col_start in range(0, n_cols, width):
        for row_start in range(0, n_rows, height):
            tile = windows.Window(col_off=col_start, row_off=row_start, width=width, height=height)
            tile = tile.intersection(full_extent)
            yield tile, windows.transform(tile, ds.transform)
        
# Helpers for creating multiband images

def img_to_array(image):

    """
    Read every band of a raster file.

    image: path of the raster file
    returns: (array with all bands, metadata dict of the raster)
    """

    # the context manager closes the file once the data has been read
    with rasterio.open(image) as img:
        img_array = img.read()
        meta = img.meta

    return img_array, meta
    

def stack_bands(bd_list, bdir):

    """
    Read band 1 of every raster in bd_list and collect the arrays in order.

    bd_list: file names of the rasters, in stacking order
    bdir: directory that contains the files
    returns: (list of 2-D arrays, metadata dict of the first raster)

    NOTE: a later cell of this notebook defines another `stack_bands`; once
    that cell has run, it replaces this definition.
    """

    # every file is opened in a context manager so no handle stays open
    with rasterio.open(os.path.join(bdir, bd_list[0])) as b1:
        meta = b1.meta
        array_list = [b1.read(1)]

    for band_file in bd_list[1:]:
        with rasterio.open(os.path.join(bdir, band_file)) as src:
            array_list.append(src.read(1))

    return array_list, meta

In [10]:
# Single-band mosaics that are stacked, in this order, into one multiband image.
VRGB = ['merge_B02_2017.tif', 'merge_B03_2017.tif', 'merge_B04_2017.tif', 'merge_B08_2017.tif']
out_img = os.path.join(IMG_DATA_DIR, 'merge_VRGB_2017.tif')

# Manual switch: set to True to (re)build the composite when running this cell.
composite_VRGB = False

if composite_VRGB:

    arrays, meta = stack_bands(VRGB, IMG_DATA_DIR)
    # one output band per stacked array
    meta.update({"count": len(arrays)})

    with rasterio.open(out_img, 'w', **meta) as dest:
        for band_nr, src in enumerate(arrays, start=1):
            dest.write(src, band_nr)


In [11]:
# Manual switch: set to True to (re)cut the composite into patches when running this cell.
image_clipping = False

# Patch edge length in pixels.
N = 128

if image_clipping:

    output_filename = 'VRGB_2017_tile_{}-{}.tif'

    with rasterio.open(out_img) as inds:

        meta = inds.meta.copy()

        for window, transform in get_tiles(inds, N, N):

            meta['transform'] = transform
            meta['width'], meta['height'] = window.width, window.height

            # tiles are named after their column/row pixel offset in the source image
            tile_name = output_filename.format(int(window.col_off), int(window.row_off))
            outpath = os.path.join(base_path, 'temp_patches', tile_name)

            with rasterio.open(outpath, 'w', **meta) as outds:
                outds.write(inds.read(window=window))

            patch_geom = get_tile_geom(outpath)
            # NOTE(review): `within` selects label features lying entirely inside
            # the patch, while the comments in this notebook speak of patches that
            # overlap/intersect the labels - confirm which predicate is intended.
            # NOTE(review): the tile bounds are in the raster's CRS; this assumes
            # label_df uses the same CRS - TODO confirm.
            patch_gdf = label_df[label_df.within(patch_geom)]

            if not patch_gdf.empty:

                # copy the tiles that contain label features to a separate folder;
                # the images in this folder are used to create training/validation data
                patch_path = os.path.join(base_path, 'image_patches_128', tile_name)

                shutil.copyfile(outpath, patch_path)


In [96]:

# image_clipping_b3 = False
# N = 128
# # bands10m = ['B02', 'B03', 'B04', 'B08'] # RGB,NIR

# while image_clipping_b3:
    
#     image_clipping_b3 = False

#     for band in tqdm(all_bands):

#         if os.path.basename(band).split('_')[1] == 'B03':

#             output_filename = band.split('.')[0] + '_tile_{}-{}.tif'
#             band_path = os.path.join(IMG_DATA_DIR, band)
            
#             with rasterio.open(band_path) as inds:

#                 meta = inds.meta.copy()

#                 for window, transform in get_tiles(inds, N, N):

#                     meta['transform'] = transform
#                     meta['width'], meta['height'] = window.width, window.height

#                     outpath = os.path.join(base_path, 'temp_patches', output_filename.format(int(window.col_off), int(window.row_off)))
                    
#                     with rasterio.open(outpath, 'w', **meta) as outds:
#                         outds.write(inds.read(window=window))


#                     patch_geom = get_tile_geom(outpath)
#                     patch_gdf = label_df[label_df.within(patch_geom)]

#                     if not patch_gdf.empty:

#                         # move all subtiles that are inter-sect with the CUSP data to a separate folder, the imageries in this folder will be used
#                         # to create training/validation data

#                         patch_path = os.path.join(base_path, 'image_patches_128', output_filename.format(int(window.col_off), int(window.row_off)))

#                         shutil.copyfile(outpath, patch_path)


In [97]:

# image_clipping = False
# bands10m = ['B02', 'B04', 'B08']

# allpatches = [i for i in os.listdir(os.path.join(base_path, 'image_patches_128')) if 'B03' in i and i.endswith('tif')]


# while image_clipping:
    
#     image_clipping = False

#     for band in tqdm(all_bands):

#         if os.path.basename(band).split('_')[1] in bands10m:
            
#             print('Working on band {}'.format(band))
            
#             band_name = os.path.basename(band).split('_')[1]
#             band_path = os.path.join(IMG_DATA_DIR, band)
            
#             with rasterio.open(band_path) as inds:
                
#                 for patch in allpatches:
#                     hres_path = os.path.join(base_path, 'image_patches_128', patch)
#                     crop_path = os.path.join(base_path, 'image_patches_128', patch.replace('B03', band_name))                    
#                     cropping_bands(hres_path, inds, crop_path)


In [17]:
# Image resampling

# Manual switch: set to True to crop and resample the coarse bands when running this cell.
coarse_clipping_resample = False
pixel_size = N
bands10m = ['B02', 'B03', 'B04', 'B08']

# Only the VRGB patches act as high resolution references. Listing every tif
# would also pick up bands resampled by an earlier run, and because their
# names contain no 'VRGB' to replace, they would be overwritten in place.
all_hres = [f for f in os.listdir(os.path.join(base_path, 'image_patches_128'))
            if f.endswith('tif') and 'VRGB' in f]

# (patch dimensions are checked in the cell that follows this one)

if coarse_clipping_resample:

    # cropped coarse bands are written here before they are resampled
    os.makedirs(os.path.join(base_path, 'temp_coarse'), exist_ok=True)

    for band in tqdm(all_bands):

        band_name = band.split('_')[1]

        # The 10m bands are already contained in the VRGB patches. The VRGB
        # composite itself is stored in IMG_DATA_DIR as well, so it can show up
        # in all_bands; processing it would overwrite the reference patches.
        if band_name in bands10m or band_name == 'VRGB':
            continue

        print('Working on band {}'.format(band))

        # keep the source band open only while its patches are produced
        with rasterio.open(os.path.join(IMG_DATA_DIR, band)) as raw_img:

            for hres in all_hres:

                band_patch = hres.replace('VRGB', band_name)

                hres_path = os.path.join(base_path, 'image_patches_128', hres)
                crop_path = os.path.join(base_path, 'temp_coarse', band_patch)
                resample_path = os.path.join(base_path, 'image_patches_128', band_patch)

                cropping_bands(hres_path, raw_img, crop_path)
                upsample(crop_path, hres_path, pixel_size, resample_path)

In [22]:
# Sanity check: print the name of every patch that is not N x N pixels.
all_img_id = [i for i in os.listdir(os.path.join(base_path, 'image_patches_128')) if i.endswith('tif')]

for f in all_img_id:
    fpath = os.path.join(base_path, 'image_patches_128', f)
    # Height and width are available on the opened dataset, so the pixel data
    # does not have to be read; the context manager closes each file again.
    with rasterio.open(fpath) as src:
        if src.height != N or src.width != N:
            print(f)

In [23]:

# Band prefixes that make up one multiband patch, in stacking order.
bands4_list = ['VRGB', 'B01', 'B05', 'B06', 'B07', 'B09', 'B10', 'B11', 'B12', 'B8A']

# Maps a tile id (the '<col>-<row>' suffix of the patch name) to its band files.
img_dict = dict()

all_img_id = [i[:-4].split('_')[-1] for i in os.listdir(os.path.join(base_path, 'image_patches_128')) if 'VRGB' in i]

for img_id in all_img_id:

    bands = list()
    for band in bands4_list:
        band_name = '{}_2017_tile_{}.tif'.format(band, img_id)
        if not os.path.isfile(os.path.join(base_path, 'image_patches_128', band_name)):
            # a bare `raise` has no active exception here; name the missing file instead
            raise FileNotFoundError('Missing band patch {} for tile {}'.format(band_name, img_id))
        bands.append(band_name)

    img_dict[img_id] = bands


In [26]:
def stack_bands(bd_list, bdir):

    """
    Collect every band of the first raster, followed by band 1 of each
    remaining raster, as a list of 2-D arrays.

    bd_list: file names of the rasters, in stacking order
    bdir: directory that contains the files
    returns: (list of 2-D arrays, metadata dict of the first raster)

    For a 4-band first raster (the VRGB patch) this yields its bands 1-4 as
    before. Because all bands of the first raster are read, a single-band
    first raster works too instead of failing on band 2.
    """

    # every file is opened in a context manager so no handle stays open
    with rasterio.open(os.path.join(bdir, bd_list[0])) as b1:
        meta = b1.meta
        # read() returns a (bands, rows, cols) array; split it per band
        array_list = list(b1.read())

    for band_file in bd_list[1:]:
        with rasterio.open(os.path.join(bdir, band_file)) as src:
            array_list.append(src.read(1))

    return array_list, meta

In [28]:
bdir = os.path.join(base_path, 'image_patches_128')

# Manual switch: set to True to (re)build the multiband patches when running this cell.
composite_bands = False

if composite_bands:

    # make sure the output folder exists before writing into it
    multiband_dir = os.path.join(base_path, 'image_multiband_128')
    os.makedirs(multiband_dir, exist_ok=True)

    for img_id, band_list in img_dict.items():

        # A separate name is used so that `out_img`, the VRGB composite path
        # read by the patch clipping cell, is not overwritten.
        multi_path = os.path.join(multiband_dir, '{}.tif'.format(img_id))

        arrays, meta = stack_bands(band_list, bdir)

        # one output band per stacked array
        meta.update({"count": len(arrays)})

        with rasterio.open(multi_path, 'w', **meta) as dest:
            for band_nr, src in enumerate(arrays, start=1):
                dest.write(src, band_nr)


In [29]:
# Names of all multiband patches, sorted because os.listdir returns entries in
# arbitrary order and the CSV written below should have a reproducible row order.
allfile_names = sorted(i for i in os.listdir(os.path.join(base_path, 'image_multiband_128')) if i.endswith('tif'))

In [30]:
# One row per multiband patch; `patch_name` holds the patch file name.
df = pd.DataFrame(allfile_names, columns =['patch_name'])

In [31]:
# Path of the label raster for each patch: <base_path>/tmi_labels/tile_<patch_name>.
# Mapping over the single column avoids a row-wise DataFrame.apply and always yields a Series.
df['label'] = df['patch_name'].map(lambda name: os.path.join(base_path, 'tmi_labels', 'tile_{}'.format(name)))

In [32]:
# Save the patch name / label path table as CSV without the index column.
df.to_csv(os.path.join(base_path, 'labels_tmi.csv'), index=False)