# Create land cover classes from a 4-band orthomosaic

In [None]:
import os
import rioxarray as rxr
import xarray as xr
import rasterio as rio
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import geopandas as gpd

In [None]:
# Paths: input orthomosaic, input road polygons, and the output mask raster
data_dir = '/Volumes/LaCie/raineyaberle/Research/PhD/SkySat-Stereo/study-sites/MCS'
ortho_fn = os.path.join(data_dir, '20240420', 'MCS_20240420-1_4band_orthomosaic.tif')
roads_vector_fn = os.path.join(data_dir, 'roads', 'MCS_roads_polygon.shp')
masks_fn = os.path.join(data_dir, '20240420', 'MCS_20240420-1_masks.tif')

# Read the multiband orthomosaic. The band order on disk is assumed to match
# band_names below (blue, green, red, NIR) -- confirm against the product spec.
ortho_rxr = rxr.open_rasterio(ortho_fn)
band_names = ['blue', 'green', 'red', 'NIR']

# Re-express the band axis as one named (y, x) variable per band
ortho = xr.Dataset(
    data_vars={
        name: (('y', 'x'), ortho_rxr.data[idx])
        for idx, name in enumerate(band_names)
    },
    coords={'y': ortho_rxr.y.data, 'x': ortho_rxr.x.data},
)

# Pixels equal to 0 are treated as missing data (-> NaN); all other values are
# divided by 1e4 (presumably reflectance stored as integers x 10,000 -- confirm)
ortho = xr.where(ortho != 0, ortho / 1e4, np.nan)
# The Dataset built above has no CRS of its own, so copy it from the source raster
ortho = ortho.rio.write_crs(ortho_rxr.rio.crs)

# Road polygons, later rasterized onto the orthomosaic grid
roads_vector = gpd.read_file(roads_vector_fn)

In [None]:
# Pixels without orthomosaic data. They are NaN (not 0) at this point because
# zeros were replaced with NaN when the orthomosaic was loaded, so comparing
# against 0 would never match anything.
nodata = ortho.green.isnull()

# Classify trees and other vegetation using an NDVI-style index.
# NOTE(review): this uses the green band rather than red (a "green NDVI");
# confirm that is intended.
ndvi = ((ortho.NIR - ortho.green) / (ortho.NIR + ortho.green)).where(~nodata)
ndvi_threshold = 0.1
trees_mask = (ndvi >= ndvi_threshold).astype(int)

# Rasterize the road polygons onto the orthomosaic grid: clip() with
# drop=False keeps the full grid and sets pixels outside the polygons to NaN,
# so any non-NaN pixel inside the data footprint is a road pixel.
roads_clip = ortho.blue.rio.clip(roads_vector.geometry.values, roads_vector.crs, drop=False)
roads_mask = xr.where(roads_clip.notnull() & ~nodata, 1, 0)

# Classify snow using an NDSI-style index built from the red and NIR bands.
# Tree and road pixels are excluded so each pixel gets at most one of these classes.
ndsi = ((ortho.red - ortho.NIR) / (ortho.red + ortho.NIR)).where(~nodata)
ndsi_threshold = 0.1
snow_mask = ((ndsi >= ndsi_threshold) & (trees_mask == 0) & (roads_mask == 0)).astype(int)

# Stable surfaces: everything that is neither snow nor trees (unclassified + roads)
ss_mask = ((snow_mask == 0) & (trees_mask == 0)).astype(int)

# Combine the masks into one dataset, with provenance metadata on each variable
masks = xr.Dataset(
    data_vars=dict(
        trees_mask=(('y', 'x'), trees_mask.data,
                    {'Description': 'Constructed by thresholding the NDVI of the orthomosaic image',
                     'NDVI bands': 'green, NIR',
                     'NDVI threshold': f'{ndvi_threshold}'}),
        roads_mask=(('y', 'x'), roads_mask.data,
                    {'Description': 'Constructed from the Source, buffered, rasterized, and interpolated to the orthomosaic image grid.',
                     'Source': 'U.S. Geological Survey National Transportation Dataset for Idaho (published 20240215) Shapefile: https://www.sciencebase.gov/catalog/item/5a5f36bfe4b06e28e9bfc1be'}),
        snow_mask=(('y', 'x'), snow_mask.data,
                   {'Description': 'Constructed by thresholding the NDSI of the orthomosaic image',
                    'NDSI bands': 'red, NIR',
                    'NDSI threshold': f'{ndsi_threshold}'}),
        stable_surfaces_mask=(('y', 'x'), ss_mask.data,
                              {'Description': 'Stable surfaces include all road-covered, snow-free, and tree-free surfaces according to the trees_mask, snow_mask, and roads_mask data variables.'}),
    ),
    coords=ortho.coords,
)

# Plot
plt.rcParams.update({'font.size': 12, 'font.sans-serif': 'Arial'})
# Image extent in km; assumes the x/y coordinates are in metres -- TODO confirm CRS units
extent_km = (np.min(ortho.x.data) / 1e3, np.max(ortho.x.data) / 1e3,
             np.min(ortho.y.data) / 1e3, np.max(ortho.y.data) / 1e3)
fig, ax = plt.subplots(2, 1, figsize=(8, 16))
ax[0].imshow(np.dstack([ortho.red, ortho.green, ortho.blue]) * 0.5, extent=extent_km)
ax[0].set_title('RGB orthoimage')
xmin, xmax = ax[0].get_xlim()
ymin, ymax = ax[0].get_ylim()
# Overlay each 0/1 mask: 0 -> fully transparent, 1 -> the class colour
colors = [(77/255, 175/255, 74/255, 1),  # trees
          (55/255, 126/255, 184/255, 1),  # snow
          (166/255, 86/255, 40/255, 1)]  # roads
for color, mask, mask_name in zip(colors,
                                  [trees_mask, snow_mask, roads_mask],
                                  ['trees mask', 'snow mask', 'roads mask']):
    cmap = matplotlib.colors.ListedColormap([(1, 1, 1, 0), color])
    ax[1].imshow(mask.data, cmap=cmap, vmin=0, vmax=1, extent=extent_km)
    # Dummy point so the mask gets a legend entry
    ax[1].plot(0, 0, 's', color=color, markersize=5, label=mask_name)
ax[1].set_title('Land cover masks')
# Reset axes limits: the dummy legend points at (0, 0) would otherwise expand them
ax[1].set_xlim(xmin, xmax)
ax[1].set_ylim(ymin, ymax)
ax[1].legend(loc='lower right', markerscale=2)
fig.tight_layout()

# Save masks to file. Nodata pixels get -9999. Dataset.where keeps each
# variable's attrs (xr.where would drop them). _FillValue is set per variable
# because rioxarray reads the nodata value of each band from its own attrs.
masks = masks.where(~nodata, -9999).astype(np.int16)
for var_name in masks.data_vars:
    masks[var_name].attrs['_FillValue'] = -9999
masks = masks.rio.write_crs(ortho_rxr.rio.crs)
masks.rio.to_raster(masks_fn)
print('Masks saved to file:', masks_fn)

# Save the figure before showing it, in case the backend closes it on show()
fig_fn = os.path.splitext(masks_fn)[0] + '.png'
fig.savefig(fig_fn, dpi=300, bbox_inches='tight')
print('Figure saved to file:', fig_fn)
plt.show()

In [None]:
# Display the trees mask (1 where the NDVI-style index >= threshold, else 0) for a quick inspection
trees_mask