In [12]:
# Random Forest (per-pixel) classification from a multi-band VRT and training polygons
# deps: rasterio, geopandas, scikit-learn, numpy

import os
import numpy as np
import rasterio
from rasterio.features import rasterize
import geopandas as gpd
from sklearn.ensemble import RandomForestClassifier

# Workspace folder holding the band stack and the training shapefile.
wks = r'Q:\dss_workarea\mlabiadh\workspace\20250813_luscidEye_imagery_assessement\OBIA_test\classification'

# Inputs: the multi-band VRT stack and the training polygons, both inside the workspace.
vrt_path, train_shp = (
    os.path.join(wks, fname)
    for fname in ("stack_noNDVI.vrt", "training_areas.shp")
)

# Forest size; the value is also embedded in the output file name so runs
# with different tree counts do not overwrite each other.
n_trees = 500
out_tif = os.path.join(wks, f"class_RF_{n_trees}trees.tif")


def _valid_pixel_mask(bands, nodatavals):
    """Return a flat boolean mask of the pixels that are usable in every band.

    Parameters
    ----------
    bands : ndarray, shape (n_bands, rows, cols)
        Image stack as returned by ``src.read()``.
    nodatavals : sequence
        One nodata value per band (``None`` or NaN where no finite nodata
        value is declared), as returned by ``src.nodatavals``.

    Returns
    -------
    ndarray of bool, shape (rows * cols,)
        True where every band holds a finite value that is not that band's
        own nodata value.

    Notes
    -----
    The mask is built band by band because the bands of a stacked VRT may
    declare different nodata values (including NaN, which ``np.isclose``
    never matches). Working on one 2-D band at a time also avoids
    full-stack-sized temporaries.
    """
    valid = np.ones(bands.shape[1:], dtype=bool)
    for band, band_nodata in zip(bands, nodatavals):
        # NaN / inf are never usable features, whatever the declared nodata is.
        valid &= np.isfinite(band)
        # A finite nodata value has to be matched explicitly.
        if band_nodata is not None and np.isfinite(band_nodata):
            valid &= ~np.isclose(band, band_nodata)
    return valid.reshape(-1)


with rasterio.open(vrt_path) as src:
    print("Reading input imagery")
    img = src.read()  # (bands, rows, cols)
    # Per-band nodata values: src.nodata would only report the first band's.
    nodatavals = src.nodatavals
    rows, cols = src.height, src.width
    crs, transform = src.crs, src.transform

print("Reading and rasterizing training polygons")
# Read the training polygons and bring them into the raster CRS
gdf = gpd.read_file(train_shp)
if gdf.crs != crs:
    gdf = gdf.to_crs(crs)

# Features without a geometry or without a class id cannot be rasterized.
usable = gdf.geometry.notna() & ~gdf.geometry.is_empty & gdf["class_id"].notna()
gdf = gdf[usable]
if gdf.empty:
    raise ValueError(f"No usable training polygons found in {train_shp}")

# Rasterize class labels (assumes integer column 'class_id').
# 0 is the background fill, so only polygons with class_id > 0 end up
# contributing training pixels.
shapes = [(geom, int(cid)) for geom, cid in zip(gdf.geometry, gdf["class_id"])]
y = rasterize(
    shapes,
    out_shape=(rows, cols),
    transform=transform,
    fill=0,
    all_touched=True,
    dtype="int32",
).reshape(-1)

# Feature matrix (pixels x bands), in the same row-major pixel order as y
X = np.moveaxis(img, 0, -1).reshape(-1, img.shape[0])

# Pixels valid in every band (computed once, reused for training and prediction)
valid_all = _valid_pixel_mask(img, nodatavals)

# Training pixels: labeled (> 0) and valid imagery
valid_pix = (y > 0) & valid_all
X_train = X[valid_pix]
y_train = y[valid_pix]

if y_train.size == 0:
    raise ValueError(
        "No valid training pixels: the training polygons do not overlap "
        "valid imagery, or no polygon has class_id > 0"
    )
if y_train.max() > np.iinfo(np.int16).max:
    raise ValueError("class_id values do not fit in the int16 output raster")

print("Running Random Forest classification")
# Train RF and predict over all valid pixels
clf = RandomForestClassifier(n_estimators=n_trees, n_jobs=-1, random_state=42)
clf.fit(X_train, y_train)

# Invalid pixels keep 0, which is also the output nodata value
y_pred = np.zeros(y.shape, dtype=np.int16)
y_pred[valid_all] = clf.predict(X[valid_all])

print("Writing classified raster")
# Write classified raster (LZW-compressed)
gtiff_profile = {
    "driver": "GTiff",
    "height": rows,
    "width": cols,
    "count": 1,
    "dtype": "int16",
    "crs": crs,
    "transform": transform,
    "compress": "LZW",
    "predictor": 2,        # better LZW compression for integer data
    "tiled": True,
    "blockxsize": 256,
    "blockysize": 256,
    "nodata": 0,
    "BIGTIFF": "IF_SAFER", # avoids 4GB limit if needed
}

with rasterio.open(out_tif, "w", **gtiff_profile) as dst:
    dst.write(y_pred.reshape(rows, cols), 1)


Reading input imagery
Reading and rasterizing training polygons
Running Random Forest classification
Writing classified raster


In [11]:
# Inspect the band stack: one summary line for the dataset, then one line per band.
vrt_path = os.path.join(wks, "stack_noNDVI.vrt")

with rasterio.open(vrt_path) as src:
    print(f"Driver: {src.driver}  |  Size: {src.width} x {src.height}  |  Bands: {src.count}")
    for i in range(1, src.count + 1):
        # Band numbers are 1-based; the per-band metadata tuples are 0-based.
        if src.dtypes:
            dt = src.dtypes[i-1]
        else:
            dt = "unknown"

        if src.nodatavals:
            nd = src.nodatavals[i-1]
        else:
            nd = None

        # Fall back to "" only when the descriptions tuple is missing or too
        # short; a band whose description is None is printed as 'None'.
        if src.descriptions and i <= len(src.descriptions):
            desc = src.descriptions[i-1]
        else:
            desc = ""

        print(f"Band {i:02d}: dtype={dt}, nodata={nd}, desc='{desc}'")

Driver: VRT  |  Size: 3704 x 3052  |  Bands: 28
Band 01: dtype=float32, nodata=-3.4028234663852886e+38, desc='None'
Band 02: dtype=float32, nodata=-3.4028234663852886e+38, desc='None'
Band 03: dtype=float32, nodata=-3.4028234663852886e+38, desc='None'
Band 04: dtype=float32, nodata=-3.4028234663852886e+38, desc='None'
Band 05: dtype=float32, nodata=nan, desc='None'
Band 06: dtype=float32, nodata=nan, desc='None'
Band 07: dtype=float32, nodata=nan, desc='None'
Band 08: dtype=float32, nodata=nan, desc='None'
Band 09: dtype=float32, nodata=nan, desc='None'
Band 10: dtype=float32, nodata=nan, desc='None'
Band 11: dtype=float32, nodata=nan, desc='None'
Band 12: dtype=float32, nodata=nan, desc='None'
Band 13: dtype=float32, nodata=nan, desc='None'
Band 14: dtype=float32, nodata=nan, desc='None'
Band 15: dtype=float32, nodata=nan, desc='None'
Band 16: dtype=float32, nodata=nan, desc='None'
Band 17: dtype=float32, nodata=nan, desc='None'
Band 18: dtype=float32, nodata=nan, desc='None'
Band 19: