In [None]:
import os
import rasterio
import geopandas as gpd
import rasterio.features
from rasterio.plot import reshape_as_image
import pickle

In [None]:
# Load the previously trained classifier from disk.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file —
# only load model files that you produced yourself or otherwise trust.
model_path = os.path.join('data', 'trained_model.pkl')
with open(model_path, "rb") as model_file:
    rf = pickle.load(model_file)

In [None]:
# Path to the multi-band stacked GeoTIFF that will be classified.
stacked_tif_dr = os.path.join('data', 'stacked_bands.tif')

# Read pixels and metadata inside a context manager so the file handle is
# released as soon as reading is done instead of staying open in the kernel.
with rasterio.open(stacked_tif_dr) as src:
    # Keep our own copy of the profile; it is reused for the mask transform
    # and as the template for the exported GeoTIFF.
    meta = src.meta.copy()
    img = src.read()
print(img.shape)  # (bands, rows, cols)

# rasterio returns bands first; move the band axis last so each pixel's
# band values are contiguous.
reshaped_img = reshape_as_image(img)
print(reshaped_img.shape)  # (rows, cols, bands)

# Flatten to 2D so the classifier's `predict` sees one row per pixel and
# one column per band.
class_input = reshaped_img.reshape(-1, reshaped_img.shape[-1])
print(class_input.shape)  # (rows*cols, bands)

In [None]:
# Classify every pixel of the Sentinel-2 stack: one predicted label per row of class_input.
pixel_labels = rf.predict(class_input)
# Fold the flat per-pixel labels back onto the (rows, cols) image grid so the
# result can be masked, visualised and written as a single raster band.
n_rows, n_cols = reshaped_img.shape[:2]
class_RF_S2 = pixel_labels.reshape(n_rows, n_cols)

In [None]:
# Mask out non-crop/agricultural areas so only agricultural pixels keep a class label.
agri_area = gpd.read_file(os.path.join('area', 'ag_only.geojson'))  # read the agricultural-area polygons

# The polygons are burned in using the raster's affine transform, which only
# lines up if both share a CRS. GeoJSON is usually read as EPSG:4326 while the
# stacked raster may be in a projected CRS, so reproject the polygons first.
if meta['crs'] is not None and agri_area.crs is not None:
    agri_area = agri_area.to_crs(meta['crs'].to_wkt())

# With invert=False, geometry_mask is True for pixels OUTSIDE the polygons,
# i.e. exactly the non-agricultural pixels to blank out.
agri_mask = rasterio.features.geometry_mask(
    agri_area.geometry, out_shape=class_RF_S2.shape, transform=meta['transform'], invert=False
)
# 255 matches the nodata value used when the classified map is exported.
class_RF_S2[agri_mask] = 255

In [None]:
# Export the classified map as a single-band uint8 GeoTIFF with 255 as nodata.
NODATA_VALUE = 255

# Derive the output profile from a copy so the input raster's `meta` stays unchanged.
out_meta = meta.copy()
out_meta.update(count=1, dtype='uint8', nodata=NODATA_VALUE)

# Output file path for the classified GeoTIFF; create the folder if it is
# missing so the write does not fail on a fresh checkout.
output_path = os.path.join('results', 'classified.tif')
os.makedirs(os.path.dirname(output_path), exist_ok=True)

# An implicit narrowing to uint8 would silently wrap labels outside 0..255,
# so check the range and cast explicitly.
if class_RF_S2.min() < 0 or class_RF_S2.max() > 255:
    raise ValueError('class labels must fit in uint8 (0-255) to be written as a uint8 GeoTIFF')
classified_uint8 = class_RF_S2.astype('uint8')

# Write the classified map into band 1 of the output GeoTIFF.
with rasterio.open(output_path, 'w', **out_meta) as dest:
    dest.write(classified_uint8, 1)