In [None]:
#import libraries
import rasterio as rio
from rasterio.windows import Window
from matplotlib import pyplot as plt
import numpy as np
from itertools import product
import pandas as pd
from tqdm import tqdm
from pathlib import Path
import os
import json
from osgeo import gdal
import joblib
from rasterio.vrt import WarpedVRT

In [None]:
# --- Input rasters: CHANGE these paths before running ---
# Simulated image (VRT file generated using Simulation.ipynb).
src_path = Path("/path/S2A_T06VWP_HS.VRT")
# DEM layer (GeoTIFF generated using DEM_preprocessing.ipynb).
dem_path = Path("/path/DEMlayer.tif")

# --- Model files (available in model.zip inside Data; boreal Alaska only) ---
# Trained classifier, loaded with joblib in the next cell.
model_path = Path("/..location../model/RandomForest.joblib")
# JSON read by the processing cell: integer-like class keys mapping to
# 'RED', 'GREEN', 'BLUE' and 'ALPHA' values used to build the colour table.
meta_path = Path("/..location../model/Meta_Info.json")

# --- Output location and processing parameters ---
# Root output directory; the final VRT mosaic is written here.
dst_dir = Path("/path/prediction")
# Per-tile GeoTIFFs are written into this sub-directory.
tiles_dir = dst_dir / "Tiles"
# NOTE: this creates the directories on disk as soon as the cell is executed.
tiles_dir.mkdir(mode=0o755, parents=True, exist_ok=True)
# Prefix of every tile file name and base name of the output VRT.
dst_tag = "PredictedLabel"
# Pixels whose NDVI is below this value are masked out before prediction.
ndvi_ll = 0.3
# Pixels whose NDVI is above this value are masked out as well (max NDVI).
ndvi_ul = 1.0
# NIR band used for NDVI. The processing cell uses this value as a 0-based
# index into the band axis of the array read from src_path.
nir_bid = 96
# RED band used for NDVI; also a 0-based index into the band axis.
red_bid = 56
# Tile size in pixels (rows); tiles on the image edge are clipped to the image.
tile_height = 512
# Tile size in pixels (columns).
tile_width = 512
# Passed as `resampling=` to rasterio's WarpedVRT when the DEM is aligned to
# the image grid; 1 is the value of rasterio.enums.Resampling.bilinear.
resampling_alg = 1

In [None]:
#Load model
# SECURITY NOTE: joblib.load unpickles the file, which can execute arbitrary
# code. Only load model files obtained from a trusted source.
clf = joblib.load(model_path)

In [None]:
#Processing
def _to_byte(component):
  """Multiply a colour component by 255 and round it to the nearest int."""
  return int(np.round(255 * component, decimals=0))

# Read the class metadata shipped with the model.
with open(meta_path, 'r') as fp:
  meta_map = json.load(fp)

# Colour table: integer class label -> (r, g, b, a) tuple of ints.
color_dict = {
  int(label): tuple(
    _to_byte(entry[channel]) for channel in ('RED', 'GREEN', 'BLUE', 'ALPHA')
  )
  for label, entry in meta_map.items()
}

def _ndvi_reject_mask(nir, red, lower, upper):
  """Return a plain boolean array that is True where NDVI is unusable.

  NDVI is computed as (nir - red) / (nir + red). A pixel is rejected when its
  NDVI lies outside [lower, upper] OR is undefined (NaN input, or
  nir + red == 0). Expressing the test as "not inside the valid range" is
  what makes NaN pixels rejected: a plain `ndvi < lower` is False for NaN and
  would silently let such pixels through.

  Parameters
  ----------
  nir, red : array-like of equal shape, with NaN marking "no data".
  lower, upper : inclusive NDVI bounds.

  Returns
  -------
  numpy.ndarray of bool, same shape as the inputs.
  """
  nir = np.asarray(nir, dtype=np.float32)
  red = np.asarray(red, dtype=np.float32)
  band_sum = nir + red
  with np.errstate(divide='ignore', invalid='ignore'):
    ndvi = (nir - red) / band_sum
    ndvi[band_sum == 0] = np.nan  # make "division by zero" explicitly undefined
    in_range = (ndvi >= lower) & (ndvi <= upper)
  return ~in_range


# Classify the image tile by tile. Each tile that contains at least one valid
# pixel is written as a single-band uint8 GeoTIFF (0 = nodata) with a colormap.
with rio.open(src_path, 'r') as src, rio.open(dem_path, 'r') as dem:
  assert src.crs == dem.crs, "CRS Mismatch!"
  # nir_bid / red_bid are used as indices into the band axis of the read array.
  assert max(nir_bid, red_bid) < src.count, "NDVI band index out of range!"
  # Present the DEM on exactly the same pixel grid as the image so that the
  # same window can be read from both datasets.
  with WarpedVRT(
    dem,
    height=src.height,
    width=src.width,
    transform=src.transform,
    resampling=resampling_alg,
  ) as vrt_dem:
    meta = src.meta.copy()
    img_h = src.height
    img_w = src.width
    big_win = Window(row_off=0, col_off=0, height=img_h, width=img_w)
    r_offsets = list(range(0, img_h, tile_height))
    c_offsets = list(range(0, img_w, tile_width))
    # ((tile_row, tile_col), (row_offset, col_offset)) in row-major order.
    pointers = [
      ((i, j), (r_off, c_off))
      for (i, r_off), (j, c_off)
      in product(enumerate(r_offsets), enumerate(c_offsets))
    ]
    meta['dtype'] = np.uint8
    meta['nodata'] = 0
    meta['driver'] = 'GTiff'
    color_dict[meta["nodata"]] = (0, 0, 0, 0)  # nodata is fully transparent
    tile_list = list()

    for (i, j), (r_off, c_off) in tqdm(pointers):
      # Clip the nominal tile to the image extent (edge tiles are smaller).
      win = Window(
        row_off=r_off, col_off=c_off, height=tile_height, width=tile_width
      ).intersection(big_win)
      img_arr = src.read(window=win, boundless=False, masked=True)
      dem_arr = vrt_dem.read(window=win, boundless=False, masked=True)

      # Work on plain float arrays with NaN for missing values, so no
      # masked-array semantics leak into the comparisons or the indexing below.
      nir_band = img_arr[nir_bid]
      red_band = img_arr[red_bid]
      nir = nir_band.astype(np.float32).filled(np.nan)
      red = red_band.astype(np.float32).filled(np.nan)
      if src.nodata is not None:
        nir[np.ma.getdata(nir_band) == src.nodata] = np.nan
        red[np.ma.getdata(red_band) == src.nodata] = np.nan
      ndvi_mask = _ndvi_reject_mask(nir, red, ndvi_ll, ndvi_ul)

      # A pixel is unusable if ANY image band or DEM band is masked there,
      # or if its NDVI is out of range / undefined.
      dmask = np.any(np.ma.getmaskarray(dem_arr), axis=0)
      imask = np.any(np.ma.getmaskarray(img_arr), axis=0)
      mask = dmask | imask | ndvi_mask
      valid = ~mask
      if not valid.any():
        continue  # nothing to classify in this tile; no file is written

      # Feature stack: all image bands followed by the DEM band(s), with the
      # DEM cast to the image dtype.
      dat_arr = np.concatenate(
        (
          np.ma.getdata(img_arr),
          np.ma.getdata(dem_arr).astype(img_arr.dtype),
        ),
        axis=0,
      )
      # (bands, n_valid) -> (n_valid, bands): one row per pixel for predict().
      data = np.moveaxis(dat_arr[:, valid], 0, -1)
      pred = clf.predict(data)

      pred_img = np.full(valid.shape, meta['nodata'], dtype=meta['dtype'])
      pred_img[valid] = pred.astype(meta['dtype'])
      pred_img = pred_img[np.newaxis, ...]  # add the band axis
      meta['count'], meta['height'], meta['width'] = pred_img.shape
      meta['transform'] = src.window_transform(win)
      dst_path = tiles_dir / "{}_{}_{}.{}".format(dst_tag, i, j, 'tiff')
      with rio.open(dst_path, 'w', **meta) as dst:
        dst.write(pred_img)
        dst.write_colormap(1, color_dict)
      tile_list.append(str(dst_path.relative_to(dst_dir)))

# Mosaic the tiles into one VRT, once the input datasets are closed.
if not tile_list:
  raise RuntimeError(
    "No tile contained valid pixels; there is nothing to mosaic into a VRT."
  )
vrt_path = dst_dir / '{}.{}'.format(dst_tag, "VRT")
vrt_options = gdal.BuildVRTOptions(resampleAlg='near', addAlpha=False)
# The VRT and tile paths are given relative to dst_dir, so GDAL has to run
# from inside dst_dir. The previous working directory is restored afterwards
# so the kernel's state does not change behind the user's back.
previous_cwd = os.getcwd()
os.chdir(dst_dir)
try:
  ds = gdal.BuildVRT(
    str(vrt_path.relative_to(dst_dir)), tile_list, options=vrt_options
  )
  if ds is None:
    # Without gdal.UseExceptions(), failures are reported by returning None.
    raise RuntimeError("gdal.BuildVRT failed for {}".format(vrt_path))
  ds.FlushCache()
finally:
  os.chdir(previous_cwd)