In [10]:
import os
import gc
import time
import pandas as pd
import numpy as np

import rasterio
import rioxarray as rioxr
import geopandas as gpd

import dask_gateway
import dask.array as da

from joblib import load

# custom modules
import data_sampling_workflow.sample_rasters as sr
import refactoring_modules as rm

In [2]:
# **************************************************************
# ************* SCENE TO PROCESS: SPECIFY ITEMID HERE **********
#
# itemid   : identifier of the NAIP scene opened further down in this notebook
# filename : name of the GeoTIFF the predictions are saved under
# NOTE: no LIDAR year is set in this cell; the canopy height raster is
#       looked up from the scene's acquisition year (raster.datetime.year).

itemid = 'ca_m_3412039_nw_10_060_20200522'
filename = f'LS_preds_{itemid}.tif'

# **************************************************************
# **************************************************************

In [8]:
# start a Dask Gateway cluster, connect a client to it and request workers
n_workers = 30   # number of workers requested from the gateway

cluster = dask_gateway.GatewayCluster()
client = cluster.get_client()
cluster.scale(n_workers)

# last expression of the cell: shows the client summary (dashboard link)
client

0,1
Connection method: Cluster object,Cluster type: dask_gateway.GatewayCluster
Dashboard: https://pccompute.westeurope.cloudapp.azure.com/compute/services/dask-gateway/clusters/prod.33bf89c5d0fb4feb9f83d2802f3d5c9c/status,


In [3]:
# ---------------------------------------
# open the NAIP scene identified by itemid
raster = rm.rioxr_from_itemid(itemid)

# ***********************************************************************
# ****************** ADD SPECTRAL AND DATE FEATURES *********************

# find the vegetation pixels that go into the model (is_veg),
# keep the indices of water pixels and of low-ndvi pixels,
# and add ndvi and ndwi feature columns for each pixel
band_names = ['r', 'g', 'b', 'nir']   # names of bands

t0 = time.time()
# the full-scene dataframe is built inline on purpose so that no extra
# reference to it stays alive in the kernel after the call returns
is_veg, water_index, not_veg_index = rm.add_spectral_features(
    df=rm.raster_as_df(raster.to_numpy(), band_names),
    ndwi_thresh=0.3,
    ndvi_thresh=0.05,
)

# drop the ndwi column, then add the date features taken from the scene's datetime
is_veg = is_veg.drop(columns='ndwi')
is_veg = rm.add_date_features(is_veg, raster.datetime)
print('time taken to assemble pixels into dataframe with features: ', time.time() - t0,' s')


# *******************************************************************
# ****************** ADD CANOPY HEIGHT FEATURES *********************

# Create auxiliary canopy height files to sample from
t0 = time.time()

# name shared by the auxiliary rasters derived from the canopy height raster
rast_name = 'SB_canopy_height_'+str(raster.datetime.year)

# Open the canopy height raster for the scene's year inside a context manager
# so the dataset handle is closed once the auxiliary rasters have been made
# (previously the reader was opened and never closed).
with rasterio.open(sr.path_to_lidar(raster.datetime.year)) as lidar_rast_reader:
    # save aux rasters in temp folder
    # NOTE(review): n=3 is presumably the neighbourhood size used for the
    # min/max/avg -- confirm in data_sampling_workflow.sample_rasters
    sr.min_raster(rast_reader = lidar_rast_reader, rast_name = rast_name, n=3)
    sr.max_raster(rast_reader = lidar_rast_reader, rast_name = rast_name, n=3)
    sr.avg_raster(rast_reader = lidar_rast_reader, rast_name = rast_name, n=3)

print('time to make auxiliary rasters: ', (time.time()-t0), 'seconds')

# ---------------------------------------
# Resample canopy height layers to match NAIP scene resolution and extent
t0 = time.time()

# file paths: the canopy height raster first, then the aux rasters in the temp folder
temp_dir = os.path.join(os.getcwd(), 'temp')
lidar_fps = [sr.path_to_lidar(raster.datetime.year)]
lidar_fps += [os.path.join(temp_dir, rast_name + tag + '.tif')
              for tag in ('_avgs', '_maxs', '_mins')]

# one flattened vector per layer, in the same order as lidar_fps
lidar_values = []
for fp in lidar_fps:
    match = sr.open_and_match(fp, raster)
    layer_rows, layer_cols = match.shape[0], match.shape[1]
    match_vector = match.to_numpy().reshape(layer_rows * layer_cols)
    lidar_values.append(match_vector)

# column names follow the order of lidar_fps: original, avg, max, min
lidar_cols = ['lidar', 'avg_lidar', 'max_lidar', 'min_lidar']
df_lidar = pd.DataFrame(dict(zip(lidar_cols, lidar_values)))
df_lidar['min_max_diff'] = df_lidar['max_lidar'] - df_lidar['min_lidar']
print('time to resample and reshape rasters: ', (time.time()-t0), 'seconds')

# delete aux canopy height rasters (every path except the first one)
for aux_fp in lidar_fps[1:]:
    os.remove(aux_fp)

# ---------------------------------------
# add LIDAR features to vegetation dataframe

# columns kept for the model input, in this order
feature_order = ['r', 'g', 'b', 'nir', 'ndvi',
                 'year', 'month', 'day_in_year',
                 'lidar', 'max_lidar', 'min_lidar', 'min_max_diff', 'avg_lidar']

# select the canopy height rows of the vegetation pixels and join them
# column-wise with the spectral/date features
scene_features = pd.concat([is_veg, df_lidar.iloc[is_veg.index]], axis=1)
scene_features = scene_features[feature_order]

# ---------------------------------------
# free memory
# `match` still references the last resampled canopy height layer (it is
# what match_vector was built from), so it is deleted as well; otherwise
# that full-scene layer is kept alive after this step.
del is_veg, df_lidar, match_vector, lidar_values, match
gc.collect()


# *******************************************************************
# ****************** PREDICT USING DASK *****************************

# load the pre-trained random forest classifier from disk
rfc = load('spectral_rfc.joblib')
print('loaded model')

# ---------------------------------------
# convert the feature table into a dask.array and predict using the model
chunk_size = 728802   # chunk size passed to dask -- TODO document how it was chosen
da_pixels = da.from_array(scene_features.to_numpy(), chunks=chunk_size)
scene_preds = rfc.predict(da_pixels)

# only the compute() call is timed
t0 = time.time()
preds = scene_preds.compute()
print('time taken to predict: ', time.time() - t0,' s')

# ---------------------------------------
# recover pixel indices for iceplant classifications:
# predictions are labelled with the pixel index of scene_features
preds_df = pd.DataFrame(preds,
                        columns=['is_iceplant'],
                        index=scene_features.index)

iceplant_mask = (preds_df['is_iceplant'] == 1).to_numpy()
non_iceplant_mask = (preds_df['is_iceplant'] == 0).to_numpy()
is_iceplant_index = preds_df.index[iceplant_mask].to_numpy()
non_iceplant_index = preds_df.index[non_iceplant_mask].to_numpy()


# *******************************************************************
# ****************** RECONSTRUCT INTO IMAGE *************************

# value written to the output image for each group of pixel indices
class_pixels = [
    (0, non_iceplant_index),
    (1, is_iceplant_index),
    (2, not_veg_index),
    (3, water_index),
]
values = [val for val, idx in class_pixels]
indices = [idx for val, idx in class_pixels]

# image size is taken from the second and third dimensions of the scene
n_rows, n_cols = raster.shape[1], raster.shape[2]

t0 = time.time()
# NOTE(review): back_value=100 is presumably the value given to pixels that are
# in none of the index groups -- confirm in refactoring_modules.indices_to_image
reconstruct = rm.indices_to_image(n_rows, n_cols, indices, values, back_value=100)
print('reconstructed image\n   time taken to reconstruct: ', time.time() - t0,' s')


# *******************************************************************
# ****************** SAVE PREDICTIONS AS RASTER *********************

# output goes to the temp folder under the current working directory
out_fp = os.path.join(os.getcwd(), 'temp', filename)

# single-band uint8 GeoTIFF; CRS and transform are taken from the NAIP scene
out_profile = {
    'driver': 'GTiff',
    'height': reconstruct.shape[0],
    'width': reconstruct.shape[1],
    'count': 1,
    'dtype': rasterio.uint8,
    'crs': raster.rio.crs,
    'transform': raster.rio.transform(),
}

with rasterio.open(out_fp, 'w', **out_profile) as dst:
    # write the classification into band 1
    dst.write(reconstruct.astype(rasterio.uint8), 1)
