Use the CDL GeoTIFF as the reference grid to reproject each HLS scene into the CDL projection.

In [None]:
# Importing libraries
import xarray
import rioxarray
import pandas as pd
import numpy as np
import pyproj
import multiprocessing as mp
from rasterio.enums import Resampling
import json
from pathlib import Path
import os

# Never truncate rows or columns when a DataFrame is displayed in the notebook.
pd.set_option('display.max_columns', None, 'display.max_rows', None)

In [None]:
# CDL raster whose pixel grid every HLS band is reprojected onto
# (2022, 30 m according to the filename).
# NOTE(review): an earlier comment called this the cropped CDL from
# cdl_generate.ipynb, but the path points into cdl_raw/ -- confirm which file is intended.
cdl_file: str = "/data/requirements/cdl_raw/2022_30m_cdls.tif"

In [None]:
# Load the tracking table created by the hls_v2_pipeline. Each row is later
# handed to hls_process via to_dict('records').
track_df = pd.read_csv(Path("/home/data/fmask/track_df.csv"))

In [None]:
# Helpers used to locate each HLS scene's footprint on the CDL raster grid.
def point_transform(coor, src_crs, target_crs=5070):
    """Project a single coordinate pair from ``src_crs`` into ``target_crs``.

    ``coor`` is indexed as ``[x, y]``; ``always_xy=True`` keeps that
    (easting, northing) order regardless of each CRS's declared axis order.
    ``target_crs`` defaults to the integer EPSG code 5070.

    Returns a two-element list ``[x, y]`` in ``target_crs``.
    """
    transformer = pyproj.Transformer.from_crs(src_crs, target_crs, always_xy=True)
    x_out, y_out = transformer.transform(coor[0], coor[1])
    return [x_out, y_out]

def find_nearest(array, value):
    """Return the element of numpy ``array`` closest to ``value``.

    When two elements are equally close, the first one (lowest index) wins,
    because ``np.argmin`` returns the first occurrence of the minimum.
    """
    distances = np.abs(array - value)
    return array[np.argmin(distances)]

In [None]:
def _cdl_window(xds, cdl_ds, dst_crs):
    """Return the part of ``cdl_ds`` that covers HLS scene ``xds`` plus a buffer."""
    # Pad the scene footprint by half its x-extent on every side (presumably so
    # the window still covers the scene after the CRS change skews its outline).
    half_scene_len = np.abs(np.round((xds.x.max().data - xds.x.min().data) / 2))
    coor_min = point_transform(
        [xds.x.min().data - half_scene_len, xds.y.min().data - half_scene_len],
        xds.rio.crs,
        dst_crs,
    )
    coor_max = point_transform(
        [xds.x.max().data + half_scene_len, xds.y.max().data + half_scene_len],
        xds.rio.crs,
        dst_crs,
    )

    # Snap the projected corners onto existing CDL pixel coordinates.
    x0 = find_nearest(cdl_ds.x.data, coor_min[0])
    y0 = find_nearest(cdl_ds.y.data, coor_min[1])
    x1 = find_nearest(cdl_ds.x.data, coor_max[0])
    y1 = find_nearest(cdl_ds.y.data, coor_max[1])
    return cdl_ds.rio.slice_xy(x0, y0, x1, y1)


def reproject_hls(tile_path,
                  cdl_ds,
                  target_crs="EPSG:5070",
                  remove_original=True,
                  resampling_method=Resampling.bilinear):
    """
    Reproject one HLS band file onto the CDL grid and save it next to the original.

    The result is written to the same directory with the ``.tif`` suffix
    replaced by ``.reproject.tif``. Optionally, the raw HLS file is deleted, but
    only after the reprojected copy has been written.

    Assumptions:
    - tile_path is a full path that ends with .tif
    - cdl_ds is a rioxarray DataArray opened with `cache=False`.

    Inputs:
    - tile_path: Full path (str or Path) to a single HLS band file.
    - cdl_ds: The CDL raster. Its grid (CRS, resolution, pixel alignment) defines
      the output via `reproject_match`.
    - target_crs: Kept for backward compatibility; default "EPSG:5070". The output
      CRS always comes from cdl_ds. This value is only used to locate the CDL window
      if cdl_ds carries no CRS.
    - remove_original: Delete the raw HLS file after the reprojected file has been
      written. Default is True.
    - resampling_method: rasterio resampling used by `reproject_match`. Default is bilinear.
    """
    out_path = str(Path(tile_path).with_suffix(".reproject.tif"))

    # The window corners are matched against cdl_ds.x / cdl_ds.y, so they must be
    # expressed in cdl_ds's own CRS rather than a hard-coded one.
    dst_crs = cdl_ds.rio.crs if cdl_ds.rio.crs is not None else target_crs

    # Close the HLS file handle before (possibly) deleting the file below.
    with rioxarray.open_rasterio(tile_path) as xds:
        cdl_for_reprojection = _cdl_window(xds, cdl_ds, dst_crs)
        xds_new = xds.rio.reproject_match(cdl_for_reprojection, resampling=resampling_method)
        xds_new.rio.to_raster(raster_path=out_path)

    # Delete the raw tile only after the output was written successfully, so a
    # failed reprojection never destroys the input.
    if remove_original and Path(tile_path).is_file():
        os.remove(tile_path)

In [None]:
# Quality control: make sure every tile appears exactly `num_tiles` times in track_df.

# `num_tiles` is the expected number of rows (scenes) per tile name. It should
# match either the number of time steps or the number of cloud masks you want per tile.
num_tiles = 3

# Count rows per tile in one pass instead of re-filtering the whole frame for
# every tile. sort=False keeps tiles in first-appearance order (like .unique()),
# and dropna=False keeps rows with a missing tile name visible to the check.
tile_counts = track_df.groupby("tile", sort=False, dropna=False).size()
failed_tiles = tile_counts[tile_counts != num_tiles].index.tolist()
if not failed_tiles:
    print("All tiles passed the quality test!")
else:
    print(f"Tile {failed_tiles} does not pass the quality test.")

In [None]:
# For the raw image: reproject the six B* bands plus Fmask.
# NOTE: the cloud-mask cell below overwrites "bands" with '["Fmask"]'. Run only
# one of these two cells before launching the pool.
track_df["cdl_file"] = cdl_file
track_df.loc[:, "bands"] = '["B02","B03","B04","B8A","B11","B12","Fmask"]'

In [None]:
# For the cloud masks: reproject only the Fmask band.
# NOTE: this overwrites the "bands" value set by the raw-image cell above. Run
# only one of these two cells before launching the pool.
track_df["cdl_file"] = cdl_file
track_df.loc[:, "bands"] = '["Fmask"]'

In [None]:
# Preview the first rows to check the "cdl_file" and "bands" columns.
track_df.head(n=5)

In [None]:
# Worker function passed to multiprocessing to run the reprojection task.
def hls_process(kwargs):
    """
    Reproject every band file listed in one tracking-table row onto the CDL grid.

    `kwargs` is one record of `track_df.to_dict('records')` with keys:
    - save_path: prefix joined directly with `filename` (no separator is added,
      so presumably it ends with '/' -- TODO confirm)
    - filename: scene name. Band files are f"{save_path}{filename}.{band}.tif"
    - bands: JSON-encoded list of band names, e.g. '["B02","Fmask"]'
    - cdl_file: path to the CDL GeoTIFF used as the reference grid

    Band files that do not exist are skipped. Raw band files are deleted after
    reprojection (remove_original is fixed to True).
    """
    remove_original = True

    save_path = kwargs["save_path"]
    filename = kwargs["filename"]
    bands = json.loads(kwargs["bands"])
    cdl_file = kwargs["cdl_file"]

    # Close the CDL handle when this row is done. Pool workers process many rows,
    # so an unclosed dataset per row would accumulate open file handles.
    with rioxarray.open_rasterio(cdl_file, cache=False) as cdl_ds:
        for band in bands:
            tile_path = f"{save_path}{filename}.{band}.tif"
            if not Path(tile_path).is_file():
                continue
            # Fmask is the cloud-mask layer: its values are discrete codes, so
            # interpolation would invent values. Use nearest neighbour there.
            resampling = Resampling.nearest if band == "Fmask" else Resampling.bilinear
            # Pass keyword arguments: the third positional slot of reproject_hls
            # is target_crs, not remove_original.
            reproject_hls(
                tile_path,
                cdl_ds,
                remove_original=remove_original,
                resampling_method=resampling,
            )

In [None]:
# Number of worker processes: all cores but 4 by default, leaving 4 idle because
# the reprojection is memory-hungry. Adjust as needed. Clamped to at least 1,
# because mp.Pool raises ValueError for processes < 1 (machines with <= 4 cores).
n_workers = max(1, mp.cpu_count() - 4)
with mp.Pool(processes=n_workers) as pool:
    pool.map(hls_process, track_df.to_dict('records'))