In [None]:
# This notebook crops each input band image into 256 x 256 pixel patches within the image's bounding box.
# Patches that do not contain any feature from the marsh label GeoDataFrame are left out of the final set of patches.


In [None]:
import os
import numpy as np
import pandas as pd
import rasterio
import sys
import geopandas as gpd
from pathlib import Path
from tqdm import tqdm
from rasterio.mask import mask, raster_geometry_mask
from shapely.geometry import box
from rasterio.enums import Resampling
from itertools import product
from rasterio import windows
import rioxarray
import shapely
import shutil
from rasterio import features

# Absolute path of the current working directory; appended to sys.path so
# that modules located there can be imported.
path_cur = os.path.abspath('.')
sys.path.append(path_cur)

from os.path import dirname as up

In [None]:
# Setting working directory and input data directory

# Folder the *_wgs84.tif tiles are read from; resolved three directory levels
# above the current working directory.
base_path = Path(os.path.join(up(up(up(path_cur))), 'VIMS', 'NAIP', 'VA_NAIP_2018_8977', '2018_VA_wgs84'))
# (base_path / 'image_patches_256').mkdir(exist_ok=True, parents=True)
# (base_path / 'temp_patches').mkdir(exist_ok=True, parents=True)

# Despite the name, this is the path of a single GeoJSON file (the marsh label vectors),
# located one directory level above the current working directory.
Label_DATA_DIR = os.path.join(up(path_cur), 'data', 'processing_data', 'marsh_all_500.geojson')
# Shapefile holding the tile index of the imagery.
overlap_index_tile = os.path.join(up(up(up(path_cur))), 'VIMS/NAIP/VA_NAIP_2018_8977/tileindex_VA_NAIP_2018/2018_VA_tiles.shp')

In [None]:
# Read the tile index and collect the 'location' value of every tile as a plain list.
tile_gdf = gpd.read_file(overlap_index_tile)
overlap_tiles = tile_gdf['location'].to_list()

In [None]:
# Load the marsh label vectors into a GeoDataFrame.
label_df = gpd.read_file(Label_DATA_DIR)

In [None]:
# Display the column names available on the label GeoDataFrame.
label_df.columns

In [None]:
def cropping_bands(ref_img_path, ups_img, outfile):

    """
    Crop an opened raster to the bounding box of a reference raster and save the result.

    ref_img_path: path of the reference (high resolution) band whose bounds define the crop extent
    ups_img: input low resolution band, already opened (rasterio.open() output)
    outfile: output path of the low resolution band with geometry aligned with the reference

    The output profile starts from the reference raster's metadata; its width,
    height, transform and band count are replaced by those of the cropped array.
    """

    # Only the bounds and the metadata of the reference are needed, so the file
    # is closed as soon as they are read instead of leaking the handle.
    with rasterio.open(ref_img_path) as ref_img:
        # get the geometry of the reference high resolution band
        geom = box(*ref_img.bounds)
        meta = ref_img.meta.copy()

    cropped, crop_transf = mask(ups_img, [geom], crop=True, filled=False, all_touched=False)

    c, h, w = cropped.shape

    meta['width'], meta['height'] = w, h
    meta['transform'] = crop_transf
    # The cropped array has the band count of `ups_img`, which is not
    # guaranteed to equal the reference raster's band count.
    meta['count'] = c

    with rasterio.open(outfile, 'w', **meta) as dst:
        dst.write(cropped)


def upsample(img_lres_path, img_hres_path, img_size, outf, method=Resampling.bilinear):

    """
    Resample a low resolution raster to img_size x img_size pixels and save it
    using the metadata of a high resolution raster.

    img_lres_path: low resolution cropped band path
    img_hres_path: high resolution cropped band path (supplies the output metadata)
    img_size: the size (in pixels, used for both height and width) to resample to
    outf: output path of the resampled bands
    method: rasterio resampling method, bilinear by default

    The output width, height and transform are taken from the high resolution
    raster, so img_size is expected to match that raster's dimensions.
    """

    # resample data to target shape; the dataset is closed once the array is read
    with rasterio.open(img_lres_path) as dataset:
        data = dataset.read(
            out_shape=(
                dataset.count,
                int(img_size),
                int(img_size)
            ),
            resampling=method
        )

    # Only the metadata of the high resolution raster is needed.
    with rasterio.open(img_hres_path) as dataset_hres:
        meta = dataset_hres.meta.copy()

    # Write as many bands as were actually read from the low resolution raster.
    meta['count'] = data.shape[0]

    with rasterio.open(outf, 'w', **meta) as dst:
        dst.write(data)


def get_tile_geom(tile_tif, crs=None):

    """
    Return the bounding box of a raster as a shapely polygon.

    tile_tif: path of the raster to open
    crs: optional CRS string; when given, the raster is reprojected to it and
         the bounds of the reprojected raster are used

    Returns a shapely box built from (minx, miny, maxx, maxy).
    Raises AssertionError when crs is given but is not a string.
    """

    # The raster is only needed to read its bounds; closing it here avoids
    # accumulating open handles when this is called once per patch.
    with rioxarray.open_rasterio(tile_tif) as rds:

        if crs is not None:

            assert isinstance(crs, str)

            # NOTE: this reprojects the whole raster just to obtain its bounds.
            minx, miny, maxx, maxy = rds.rio.reproject(crs).rio.bounds()

        else:

            minx, miny, maxx, maxy = rds.rio.bounds()

    return shapely.geometry.box(minx, miny, maxx, maxy, ccw=True)

def get_tiles(ds, width=256, height=256):
    """
    Yield (window, transform) pairs that tile the raster `ds`.

    ds: opened raster dataset exposing `meta` and `transform`
    width, height: tile size in pixels

    Tiles are produced column offset by column offset, walking down the rows
    for each one. Every window is intersected with the full raster extent, so
    tiles on the right and bottom edges can be smaller than width x height.
    """
    total_cols = ds.meta['width']
    total_rows = ds.meta['height']
    full_extent = windows.Window(col_off=0, row_off=0, width=total_cols, height=total_rows)

    for col_start in range(0, total_cols, width):
        for row_start in range(0, total_rows, height):
            tile_window = windows.Window(
                col_off=col_start, row_off=row_start, width=width, height=height
            ).intersection(full_extent)
            yield tile_window, windows.transform(tile_window, ds.transform)

In [None]:

# Side length, in pixels, of the square patches cut from every tile.
N = 256
num = 0

# Set to True to (re)generate the patches; the step is skipped otherwise.
cropping = False

# temp_patch_dir receives every patch of a tile (one sub-folder per tile);
# patch_dir receives only the patches that contain label features.
temp_patch_dir = os.path.join(up(path_cur), 'data', 'NAIP_data_processing', 'temp_patch_2')
patch_dir = os.path.join(up(path_cur), 'data', 'NAIP_data_processing', 'image_patches_2')

if cropping:

    # shutil.copyfile does not create missing directories, so make sure the
    # destination folder exists before the first patch is copied.
    Path(patch_dir).mkdir(exist_ok=True, parents=True)

    for num, tile in enumerate(tqdm(overlap_tiles), start=1):

        # e.g. 'abc.tif' -> tile_name 'abc_wgs84', file 'abc_wgs84.tif'
        tile_name = tile.split('.')[0] + '_wgs84'
        tile = tile_name + '.tif'

        wgs84_tiff = Path(base_path) / tile

        print("------------------------------------------------------")

        print("Processing tile {}: ".format(wgs84_tiff))

        print("Image {}".format(str(num)))

        # select vectors that are within the tile
        # NOTE(review): `within` keeps only features lying entirely inside the
        # geometry; features straddling an edge are ignored - confirm this is
        # intended rather than `intersects`.
        tif_geom = get_tile_geom(wgs84_tiff)
        sub_gdf = label_df[label_df.within(tif_geom)]

        # Skip tiles that contain no label vectors at all
        if sub_gdf.empty:
            continue

        tile_temp_dir = os.path.join(temp_patch_dir, tile_name)
        Path(tile_temp_dir).mkdir(exist_ok=True, parents=True)

        output_filename = tile_name + '_tile_{}-{}.tif'

        with rasterio.open(wgs84_tiff) as inds:

            meta = inds.meta.copy()

            for window, transform in get_tiles(inds, N, N):

                meta['transform'] = transform
                meta['width'], meta['height'] = window.width, window.height

                # Patch files are named after the pixel offset of their window.
                patch_name = output_filename.format(int(window.col_off), int(window.row_off))
                outpath = os.path.join(tile_temp_dir, patch_name)

                with rasterio.open(outpath, 'w', **meta) as outds:
                    outds.write(inds.read(window=window))

                patch_geom = get_tile_geom(outpath)

                # A feature inside a patch is also inside the parent tile, so
                # only the tile-level subset needs to be tested here instead of
                # the full label GeoDataFrame.
                if sub_gdf.within(patch_geom).any():

                    # copy all sub-tiles that contain label data to a separate folder; the imagery in this folder will be used
                    # to create training/validation data
                    shutil.copyfile(outpath, os.path.join(patch_dir, patch_name))


In [None]:
# p = Path(temp_patch_dir).glob('**/*')
# files = [x for x in p if x.is_file()]

In [None]:
# tmi label
# Load the marsh label vectors and add a constant 'marsh_presence' column equal to 1.
Marsh_df = gpd.read_file(Label_DATA_DIR).assign(marsh_presence=1)

In [None]:
# Root folder of the NAIP processing outputs.
naip_base_path = os.path.join(up(path_cur), 'data', 'NAIP_data_processing')
# Names of the files in the 'image_patch' folder whose name ends with 'tif'.
allfiles = [
    fname
    for fname in os.listdir(os.path.join(naip_base_path, 'image_patch'))
    if fname.endswith('tif')
]

In [None]:

# Set to True to rasterize the marsh labels for every image patch; the step is skipped otherwise.
tmi_label = False

if tmi_label:

    label_dir = os.path.join(naip_base_path, 'tmi_labels')
    # Make sure the output folder exists before listing it or writing into it.
    os.makedirs(label_dir, exist_ok=True)
    # List the already-written label rasters once instead of once per patch.
    existing_labels = set(os.listdir(label_dir))

    for hres in tqdm(allfiles):

        print('Working on patch {}'.format(hres))

        # Patches that already have a label raster are not rasterized again.
        if hres in existing_labels:
            continue

        hres_path = os.path.join(naip_base_path, 'image_patch', hres)
        rst_path = os.path.join(label_dir, hres)

        # Only the metadata of the image patch is needed; close it right away.
        with rasterio.open(hres_path) as rst:
            meta = rst.meta.copy()
        meta.update(compress='lzw')
        # The label raster has a single band regardless of the image band count.
        meta['count'] = 1

        with rasterio.open(rst_path, 'w+', **meta) as out:
            out_arr = out.read(1)

            # (geometry, value) pairs to burn into the raster
            shapes = zip(Marsh_df.geometry, Marsh_df.marsh_presence)

            burned = features.rasterize(shapes=shapes, fill=0, out=out_arr, transform=out.transform)
            out.write_band(1, burned)

        existing_labels.add(hres)

In [None]:
# Table pairing each patch file name with the path of its label raster, saved as CSV.
df = pd.DataFrame({
    'patch_name': allfiles,
    'label': [os.path.join(naip_base_path, 'tmi_labels', patch) for patch in allfiles],
})
df.to_csv(os.path.join(naip_base_path, 'labels_tmi.csv'), index=False)