In [14]:
import os
import numpy as np
import pandas as pd
import rasterio
import sys
import geopandas as gpd
from pathlib import Path
from tqdm import tqdm
from rasterio.mask import mask, raster_geometry_mask
from shapely.geometry import box
from rasterio.enums import Resampling
from itertools import product
from rasterio import windows
import rioxarray
import shapely
import shutil
from rasterio import features
from sklearn.model_selection import train_test_split

# Absolute path of the notebook's working directory; made importable for local helpers.
path_cur = os.path.abspath(os.curdir)
sys.path.extend([path_cur])

from os.path import dirname as up

In [15]:
# Setting working directory and input data directory

base_path = Path(os.path.join(up(path_cur), 'data', 'sentinel_data_processing'))

Marsh_DIR = os.path.join(up(path_cur), 'data', 'processing_data', 'marsh_all_500.geojson')
ncld_path = os.path.join(up(path_cur), 'data', 'NLCD_shoreline', 'clipped_nlcd.tif')

In [16]:
# VRGB image patches: each one serves as the reference grid for the label rasters built below
allfiles = list(filter(lambda name: name.endswith('tif') and 'VRGB' in name,
                       os.listdir(os.path.join(base_path, 'image_patches_128'))))

In [20]:
def cropping_bands(ref_img_path, ups_img, outfile):
    """Crop an already-open raster to the bounding box of a reference raster.

    Parameters
    ----------
    ref_img_path : str or path-like
        Path to the reference (high-resolution) raster. Only its bounds are used.
    ups_img : rasterio dataset
        Open low-resolution raster (the result of ``rasterio.open()``) to crop.
    outfile : str or path-like
        Destination raster receiving the cropped window of ``ups_img``.

    Notes
    -----
    The output metadata (dtype, nodata, CRS) is taken from ``ups_img``, because the
    written pixels and ``crop_transf`` both belong to that raster. Only the driver is
    taken from the reference image.
    NOTE(review): the reference bounds are used without reprojection, so both rasters
    are assumed to share a CRS. TODO confirm.
    """
    # Open the reference only long enough to read its footprint, so the handle is closed.
    with rasterio.open(ref_img_path) as ref_img:
        geom = box(*ref_img.bounds)
        ref_driver = ref_img.driver

    cropped, crop_transf = mask(ups_img, [geom], crop=True, filled=False, all_touched=False)
    count, height, width = cropped.shape

    meta = ups_img.meta.copy()
    meta.update(
        driver=ref_driver,
        count=count,
        height=height,
        width=width,
        transform=crop_transf,
    )

    with rasterio.open(outfile, 'w', **meta) as dst:
        dst.write(cropped)


def upsample(img_lres_path, img_hres_path, img_size, outf, method=Resampling.bilinear):
    """Resample a cropped low-resolution raster onto a high-resolution patch grid.

    Parameters
    ----------
    img_lres_path : str or path-like
        Low-resolution raster already cropped to the patch footprint.
    img_hres_path : str or path-like
        High-resolution patch whose georeferencing (transform, CRS, size) the output copies.
    img_size : int
        Target height and width in pixels. Must equal the high-resolution patch size.
    outf : str or path-like
        Destination raster for the resampled bands.
    method : rasterio.enums.Resampling, optional
        Resampling kernel. Defaults to bilinear; the label pipeline passes
        ``Resampling.nearest``.

    Raises
    ------
    ValueError
        If ``img_size`` does not match the high-resolution patch height/width.
    """
    size = int(img_size)

    # Reading with out_shape resamples to the target shape; the handle is closed afterwards.
    with rasterio.open(img_lres_path) as lres:
        data = lres.read(out_shape=(lres.count, size, size), resampling=method)
        lres_nodata = lres.nodata

    with rasterio.open(img_hres_path) as hres:
        meta = hres.meta.copy()

    if (meta['height'], meta['width']) != (size, size):
        raise ValueError(
            'img_size={} does not match the {}x{} grid of {}'.format(
                size, meta['height'], meta['width'], img_hres_path))

    # Georeferencing comes from the high-res patch. Pixel type and nodata belong to
    # the resampled data (e.g. class codes), not to the image patch.
    meta.update(count=data.shape[0], dtype=data.dtype.name, nodata=lres_nodata)

    with rasterio.open(outf, 'w', **meta) as dst:
        dst.write(data)


def replace_ncld(arr):
    """Collapse NLCD land-cover codes into the coarser label scheme used here.

    NLCD classification:
    https://www.mrlc.gov/data/legends/national-land-cover-database-class-legend-and-description

    Mapping (any other value, e.g. 12 or a nodata value, is returned unchanged):
        11               -> 1  water
        20 < code < 39   -> 2  developed (2X) and barren (3X)
        40 < code < 49   -> 4  forest (4X)
        50 < code < 59   -> 5  shrub (5X)
        70 < code < 79   -> 6  herbaceous (7X)
        80 < code < 89   -> 7  planted/cultivated (8X)
        90               -> 8  woody wetland
        95               -> 9  emergent wetland

    Parameters
    ----------
    arr : array-like
        NLCD class codes, any shape.

    Returns
    -------
    numpy.ndarray
        New array with the same shape and dtype as ``arr``. The input is not modified.
    """
    src = np.asarray(arr)
    out = src.copy()
    # Every mask is evaluated on the ORIGINAL codes. The targets (1-9) never fall inside
    # a later source range, so this matches the previous chained remapping while leaving
    # the caller's array untouched.
    out[src == 11] = 1
    out[(src > 20) & (src < 39)] = 2
    out[(src > 40) & (src < 49)] = 4
    out[(src > 50) & (src < 59)] = 5
    out[(src > 70) & (src < 79)] = 6
    out[(src > 80) & (src < 89)] = 7
    out[src == 90] = 8
    out[src == 95] = 9
    return out
    

In [21]:
ncld_resample_path = os.path.join(up(path_cur), 'data', 'NLCD_shoreline', 'nlcd_reclass.tif')
# Set to True to (re)build the reclassified NLCD raster; otherwise the file on disk is reused.
ncld_reclass = False

if ncld_reclass:
    # Read and reclassify, then close the source before writing the output.
    with rasterio.open(ncld_path) as ncld:
        reclass_array = replace_ncld(ncld.read())
        meta = ncld.meta

    with rasterio.open(ncld_resample_path, 'w', **meta) as dst:
        dst.write(reclass_array)

In [23]:

# Set to True to regenerate the NLCD label patches; otherwise the files on disk are reused.
coarse_clipping_resample = False
upsample_size = 128

if coarse_clipping_resample:
    patch_dir = os.path.join(base_path, 'image_patches_128')
    temp_dir = os.path.join(base_path, 'temp_label')
    label_dir = os.path.join(base_path, 'ncld_labels')
    os.makedirs(temp_dir, exist_ok=True)
    os.makedirs(label_dir, exist_ok=True)

    # NOTE(review): relies on `allfiles` still holding the VRGB listing; `allfiles` is
    # reassigned to the B03 listing further down the notebook.
    with rasterio.open(ncld_resample_path) as ncld_img:
        for hres in tqdm(allfiles, desc='NLCD labels'):
            # Label file is named after the last two '_'-separated tokens of the patch stem.
            label_name = '_'.join(os.path.splitext(hres)[0].split('_')[-2:]) + '.tif'

            hres_path = os.path.join(patch_dir, hres)
            crop_path = os.path.join(temp_dir, label_name)
            resample_path = os.path.join(label_dir, label_name)

            cropping_bands(hres_path, ncld_img, crop_path)
            # Nearest-neighbour keeps the class codes discrete.
            upsample(crop_path, hres_path, upsample_size, resample_path, method=Resampling.nearest)

In [7]:
# TMI label source: marsh polygons, each tagged with burn value 1 for rasterization
Marsh_df = gpd.read_file(Marsh_DIR).assign(marsh_presence=1)

In [24]:
# Set to True to burn the marsh polygons into per-patch label rasters (existing files are skipped).
tmi_label = False

if tmi_label:
    patch_dir = os.path.join(base_path, 'image_patches_128')
    tmi_dir = os.path.join(base_path, 'tmi_labels')
    os.makedirs(tmi_dir, exist_ok=True)

    # List the output folder once, instead of once per patch.
    existing_labels = set(os.listdir(tmi_dir))
    # (geometry, value) pairs to rasterize. A list, so it can be reused for every patch.
    marsh_shapes = list(zip(Marsh_df.geometry, Marsh_df.marsh_presence))

    for hres in tqdm(allfiles, desc='TMI labels'):
        label_name = '_'.join(os.path.splitext(hres)[0].split('_')[-2:]) + '.tif'
        if label_name in existing_labels:
            continue

        hres_path = os.path.join(patch_dir, hres)
        rst_path = os.path.join(tmi_dir, label_name)

        # Only the patch metadata is needed; close the source right away.
        with rasterio.open(hres_path) as rst:
            meta = rst.meta.copy()
        meta.update(compress='lzw', count=1)

        with rasterio.open(rst_path, 'w+', **meta) as out:
            out_arr = out.read(1)
            burned = features.rasterize(shapes=marsh_shapes, fill=0, out=out_arr, transform=out.transform)
            out.write_band(1, burned)

        existing_labels.add(label_name)

In [28]:
# Re-list the patch folder, this time keeping the B03 band GeoTIFFs
allfiles = list(filter(lambda name: 'B03' in name and name.endswith('tif'),
                       os.listdir(os.path.join(base_path, 'image_patches_128'))))

In [29]:
# Fail fast here rather than with an opaque error in train_test_split further down.
assert allfiles, 'no B03 .tif patches found in image_patches_128'
len(allfiles)

2837

In [None]:
# Patch IDs = last '_'-separated token of each B03 file stem. Sorted so the seeded
# train/val/test split does not depend on os.listdir's filesystem-specific ordering.
allfile_names = sorted(
    os.path.splitext(i)[0].split('_')[-1]
    for i in os.listdir(os.path.join(base_path, 'image_patches_128'))
    if 'B03' in i and i.endswith('tif')
)

In [None]:
# Peek at one patch ID to sanity-check the file-name parsing above
allfile_names[0]

In [None]:
# Patch IDs must be unique; a duplicate could put the same patch into more than one split.
assert len(allfile_names) == len(set(allfile_names)), 'duplicate patch names found'
len(allfile_names)

In [None]:
# One row per patch ID; this frame is what gets split into train/val/test below.
df = pd.DataFrame({'patch_name': allfile_names})


In [None]:
SPLIT_SEED = 32
# 60 / 20 / 20 split: hold out 40%, then halve the hold-out into validation and test.
# A separate `holdout` frame keeps the cell idempotent; the original overwrote `test` with its own split.
train, holdout = train_test_split(df, test_size=0.4, random_state=SPLIT_SEED)
val, test = train_test_split(holdout, test_size=0.5, random_state=SPLIT_SEED)
assert len(train) + len(val) + len(test) == len(df)