In training RFs to expand our validation dataset, we've come across two situations that are not captured by the approach thus far - 
1. The spectral characteristics of water within the validation chip is not representative of water surfaces present in the broader Planet image
2. A validation chip contains no water classification, while the broader Planet image contains some water. Since the RF trained on this data is not given any OSW examples, it will classify the entire Planet image as "not-water"

To mitigate this, we will train a random forest on validation data from multiple sites

In [3]:
import rasterio

import pandas as pd
from pathlib import Path
import numpy as np
from collections import defaultdict
import random

# ML imports
from sklearn.ensemble import RandomForestClassifier
from skimage.segmentation import felzenszwalb
from tools import get_superpixel_stds_as_features, get_superpixel_means_as_features, get_array_from_features, reproject_arr_to_match_profile
from sklearn.model_selection import train_test_split
import joblib
from joblib import dump

# local imports
from rf_funcs import calc_ndwi, calc_ndvi, return_grn_indices, return_img_bands, return_reflectance_coeffs

from tqdm import tqdm

# for repeatability
np.random.seed(42)

In [7]:
RETRAIN_MODEL=True
SUBSET_TRAINING=False # Set to True to use only a limited number of chips from each strata
SUBSET_NUMBER=4 # Number of chips to use from each strata

In [8]:
# Read the validation database
data_path = Path('../data/')
val_chips_db = data_path / 'validation_table.csv'
val_df = pd.read_csv(val_chips_db)

site_names = list(val_df['site_name'])
planet_ids = list(val_df['planet_id'])

# Extract planet IDs and associated strata
site_names_stratified = defaultdict(list)
for sn, planet_id in zip(site_names, planet_ids):
    site_names_stratified[sn[:2]].append(planet_id)

print(site_names_stratified.keys())

dict_keys(['4_', '1_', '3_', '2_'])


In [9]:
# We can either use 4 chips from each strata, or ALL chips from each strata
training_sites = []
for key in site_names_stratified.keys():
    if SUBSET_TRAINING:
        training_sites.extend(np.random.choice(site_names_stratified[key], 4))
    else:
        training_sites.extend(site_names_stratified[key])

print("Training sites: ", training_sites)

Training sites:  ['20210917_152712_47_2457', '20210914_103644_25_2413', '20211016_135440_48_2459', '20211002_155415_1009', '20210927_105543_66_2424', '20211028_012421_73_245f', '20210915_011051_87_240a', '20210915_173832_80_2307', '20211021_133031_75_245a', '20210925_072712_16_2254', '20211010_135831_84_227e', '20211021_182217_09_2456', '20211030_142613_41_227b', '20211012_083209_57_245c', '20210930_021156_60_2434', '20210929_073913_09_2453']


We have the name of the planet ids. For each, do the following - 
1. Read the cropped planet image and the corresponding validation labels
2. Generate superpixels and calculate mean and std dev.
3. Append to list
4. Train and save model
5. Apply model to broader Planet images

In [28]:
if RETRAIN_MODEL:
    X, class_features = None, None
    for idx, site in enumerate(training_sites):
        print(f"Currently processing site # {idx}")

        current_img_path = data_path / site
        cropped_img_path = data_path / 'planet_images_cropped' / site
        
        xml_file = list(current_img_path.glob('*.xml'))[0]
        chip = list(cropped_img_path.glob(f'cropped_{site}*.tif'))[0]
        classification = list(cropped_img_path.glob(f'classification_*.tif'))[0] 

        band_idxs = return_grn_indices(xml_file)
        coeffs = return_reflectance_coeffs(xml_file, band_idxs)
        chip_img = return_img_bands(chip, band_idxs, denoising_weight=None)

        with rasterio.open(chip) as src_ds:
            ref_profile = src_ds.profile

        green = chip_img[0]*coeffs[band_idxs[0]]
        red = chip_img[1]*coeffs[band_idxs[1]]
        nir = chip_img[2]*coeffs[band_idxs[2]]

        with rasterio.open(classification) as src_ds:
            cl = src_ds.read(1)
            cl_profile = src_ds.profile

        # some classification extents are not the same as the corresponding planet chip extent
        # if they are not the same, reproject the validation data to match the profile of the planet data
        if ((ref_profile['transform'] != cl_profile['transform']) | 
            (ref_profile['width'] != cl_profile['width']) | 
            (ref_profile['height'] != cl_profile['height'])):

            cl, _ = reproject_arr_to_match_profile(cl, cl_profile, ref_profile)
            cl = np.squeeze(cl)

        ndwi = calc_ndwi(green, nir)
        ndvi = calc_ndvi(red, nir)

        # segment image using green, nir, and NDWI channels
        img_stack = np.stack([green, nir, ndwi], axis=-1)
        segments = felzenszwalb(img_stack, sigma=0, min_size=10)

        # create training data that includes other channels as well
        img_stack = np.stack([red, nir, green, ndwi, ndvi], axis=-1)     
        std_features = get_superpixel_stds_as_features(segments, img_stack)
        mean_features = get_superpixel_means_as_features(segments, img_stack)

        if X is None:
            X = np.concatenate([mean_features, std_features], axis = 1)
        else:
            X_temp = np.concatenate([mean_features, std_features], axis = 1)
            X = np.concatenate([X, X_temp], axis=0)

        # We have superpixels, we now need to map each of the segments to the associated label
        # A 0 value indicates no label for the segment
        
        class_features_temp = np.zeros((mean_features.shape[0], 1))
        for class_id in [0, 1]:
            # Get all superpixel labels with particular id
            superpixel_labels_for_class = np.unique(segments[class_id == cl])
            # Label those superpixels with approrpriate class
            class_features_temp[superpixel_labels_for_class] = class_id

        if class_features is None:
            class_features = class_features_temp
        else:
            class_features = np.concatenate([class_features, class_features_temp], axis=0)


    print("Beginning model training")
    # Define an RF to be trained. setting n_jobs = -1 uses all available processors
    rf = RandomForestClassifier(n_estimators=300, class_weight='balanced', oob_score=True, random_state=0, n_jobs=-1)

    # train model on all of the available data
    rf.fit(X, class_features.ravel())

    rf_model_folder = data_path / 'trained_model' / 'rf_model'
    rf_model_folder.mkdir(exist_ok=True, parents=True)
    model_path = rf_model_folder/"rf_model_alldata.joblib"

    # save for later use
    dump(rf, model_path)

Currently processing site # 0


  means = sums / counts


Currently processing site # 1


  means = sums / counts


Currently processing site # 2


  means = sums / counts


Currently processing site # 3


  means = sums / counts


Currently processing site # 4


  means = sums / counts


Currently processing site # 5


  means = sums / counts


Currently processing site # 6


  means = sums / counts


Currently processing site # 7


  means = sums / counts


Currently processing site # 8


  means = sums / counts


Currently processing site # 9


  means = sums / counts


Currently processing site # 10


  means = sums / counts


Currently processing site # 11


  means = sums / counts


Currently processing site # 12


  means = sums / counts


Currently processing site # 13


  means = sums / counts


Currently processing site # 14


  means = sums / counts


Currently processing site # 15


  means = sums / counts


Currently processing site # 16


  means = sums / counts


Currently processing site # 17


  means = sums / counts


Currently processing site # 18


  means = sums / counts


Currently processing site # 19


  means = sums / counts


Currently processing site # 20


  means = sums / counts


Currently processing site # 21


  means = sums / counts


Currently processing site # 22


  means = sums / counts


Currently processing site # 23


  means = sums / counts


Currently processing site # 24


  means = sums / counts


Currently processing site # 25


  means = sums / counts


Currently processing site # 26


  means = sums / counts


Currently processing site # 27


  means = sums / counts


Currently processing site # 28


  means = sums / counts


Currently processing site # 29


  means = sums / counts


Currently processing site # 30


  means = sums / counts


Currently processing site # 31


  means = sums / counts


Currently processing site # 32


  means = sums / counts


Currently processing site # 33


  means = sums / counts


Currently processing site # 34


  means = sums / counts


Currently processing site # 35


  means = sums / counts


Currently processing site # 36


  means = sums / counts


Currently processing site # 37


  means = sums / counts


Currently processing site # 38


  means = sums / counts


Currently processing site # 39


  means = sums / counts


Currently processing site # 40


  means = sums / counts


Currently processing site # 41


  means = sums / counts


Currently processing site # 42


  means = sums / counts


Currently processing site # 43


  means = sums / counts


Currently processing site # 44


  means = sums / counts


Currently processing site # 45


  means = sums / counts


Currently processing site # 46


  means = sums / counts


Currently processing site # 47


  means = sums / counts


Currently processing site # 48


  means = sums / counts


Currently processing site # 49


  means = sums / counts


Currently processing site # 50


  means = sums / counts


Currently processing site # 51


  means = sums / counts


Beginning model training


In [18]:
# X_train, X_test, y_train, y_test = train_test_split(X, class_features, test_size=0.15, random_state=0)

# print("Beginning model training")
# # Define an RF to be trained. setting n_jobs = -1 uses all available processors
# rf = RandomForestClassifier(n_estimators=500, class_weight='balanced', oob_score=True, random_state=0, n_jobs=-1)

# # train model on all of the available data
# rf.fit(X_train, y_train.ravel())

# print("Model test score: ", rf.score(X_test, y_test))

# # rf_model_folder = data_path / 'trained_model' / 'rf_model'
# # rf_model_folder.mkdir(exist_ok=True, parents=True)
# # model_path = rf_model_folder/"rf_model.joblib"

# # # save for later use
# # dump(rf, model_path)

In [19]:
# print("Beginning model training")
# # Define an RF to be trained. setting n_jobs = -1 uses all available processors
# rf = RandomForestClassifier(n_estimators=300, class_weight='balanced', oob_score=True, random_state=0, n_jobs=-1)

# # train model on all of the available data
# rf.fit(X, class_features.ravel())

# rf_model_folder = data_path / 'trained_model' / 'rf_model'
# rf_model_folder.mkdir(exist_ok=True, parents=True)
# model_path = rf_model_folder/"rf_model.joblib"

# # save for later use
# dump(rf, model_path)

In [6]:
if not RETRAIN_MODEL:
    rf_model_folder = data_path / 'trained_model' / 'rf_model'
    model_path = rf_model_folder/"rf_model.joblib"
    rf = joblib.load(model_path)

Let's make inferences on the broader planet images

In [20]:
def generate_inference_helper(rf, img:str|Path, xml_file:str|Path):
    band_idxs = return_grn_indices(xml_file)
    coeffs = return_reflectance_coeffs(xml_file, band_idxs)
    
    full_img = return_img_bands(img, band_idxs, denoising_weight=None)

    green = full_img[0]*coeffs[band_idxs[0]]
    red = full_img[1]*coeffs[band_idxs[1]]
    nir = full_img[2]*coeffs[band_idxs[2]]

    ndwi = calc_ndwi(green, nir)
    ndvi = calc_ndvi(red, nir)

    img_stack = np.stack([green, nir, ndwi], axis=-1)
    segments = felzenszwalb(img_stack, sigma=0, min_size=10)

    # for inference we include other channels as well
    img_stack = np.stack([red, nir, green, ndwi, ndvi], axis=-1)
    std_features = get_superpixel_stds_as_features(segments, img_stack)
    mean_features = get_superpixel_means_as_features(segments, img_stack)

    X = np.concatenate([mean_features, std_features], axis = 1)
    y = rf.predict(X)

    return get_array_from_features(segments, np.expand_dims(y, axis=1))

def generate_inference(planet_id):
    """ 
    This function takes in a planet_id and generates inferences for the overlapping planet image
    """
    data_path = Path('../data')
    
    current_img_path = data_path / planet_id
    cropped_img_path = data_path / 'planet_images_cropped' / planet_id
    xml_file = list(current_img_path.glob('*.xml'))[0]
    classification = list(cropped_img_path.glob(f'classification_*.tif'))[0]

    img = list(current_img_path.glob(f'{planet_id}*.tif'))[0]

    inference = generate_inference_helper(rf, img, xml_file)

    # use planet image to mask out regions of no data in the model inference
    with rasterio.open(img) as src_ds:
        nodata_mask = np.where(src_ds.read(1) == src_ds.profile['nodata'], 1, 0)
        inference[nodata_mask==1] = 255
        profile_copy = src_ds.profile
        profile_copy.update({'count':1, 'dtype':np.uint8, 'nodata':255})

        # write out model inference
        with rasterio.open(f"{classification.parent}/new_full_img_rf_classification_{classification.name}", 'w', **profile_copy) as dst_ds:
            dst_ds.write(inference.astype(np.uint8).reshape(1, *inference.shape))

    print(f"Completed inference for planet id {planet_id}")

In [29]:
_ = list(map(generate_inference, tqdm(planet_ids)))

  means = sums / counts
  2%|▏         | 1/52 [16:01<13:37:37, 961.91s/it]

Completed inference for planet id 20210903_150800_60_2458


  means = sums / counts
  4%|▍         | 2/52 [25:09<9:58:29, 718.18s/it] 

Completed inference for planet id 20210903_152641_60_105c


  means = sums / counts
  6%|▌         | 3/52 [49:19<14:19:32, 1052.50s/it]

Completed inference for planet id 20210904_093422_44_1065


  means = sums / counts
  8%|▊         | 4/52 [1:12:47<15:54:17, 1192.87s/it]

Completed inference for planet id 20210906_101112_28_225a


  means = sums / counts
 10%|▉         | 5/52 [1:33:25<15:47:02, 1208.99s/it]

Completed inference for planet id 20210909_000649_94_222b


  means = sums / counts
 12%|█▏        | 6/52 [2:00:04<17:08:32, 1341.57s/it]

Completed inference for planet id 20210911_001230_44_2262


  means = sums / counts
 13%|█▎        | 7/52 [2:06:12<12:47:33, 1023.40s/it]

Completed inference for planet id 20210911_005129_82_106a


  means = sums / counts
 15%|█▌        | 8/52 [2:18:09<11:18:47, 925.63s/it] 

Completed inference for planet id 20210912_034049_22_2421


  means = sums / counts
 17%|█▋        | 9/52 [2:45:12<13:39:36, 1143.65s/it]

Completed inference for planet id 20210912_094213_84_240f


  means = sums / counts
 19%|█▉        | 10/52 [3:20:33<16:51:46, 1445.39s/it]

Completed inference for planet id 20210914_094548_30_2406


  means = sums / counts
 21%|██        | 11/52 [3:35:13<14:29:30, 1272.46s/it]

Completed inference for planet id 20210914_103644_25_2413


  means = sums / counts
 23%|██▎       | 12/52 [4:05:54<16:03:41, 1445.54s/it]

Completed inference for planet id 20210915_011051_87_240a


  means = sums / counts
 25%|██▌       | 13/52 [5:00:38<21:41:35, 2002.45s/it]

Completed inference for planet id 20210915_172340_12_245f


  means = sums / counts
 27%|██▋       | 14/52 [5:27:49<19:57:05, 1890.14s/it]

Completed inference for planet id 20210915_173832_80_2307


  means = sums / counts
 29%|██▉       | 15/52 [5:38:45<15:36:16, 1518.27s/it]

Completed inference for planet id 20210916_010848_94_2407


  means = sums / counts
 31%|███       | 16/52 [5:50:15<12:41:18, 1268.86s/it]

Completed inference for planet id 20210917_140704_93_2262


  means = sums / counts
 33%|███▎      | 17/52 [6:51:19<19:20:12, 1988.93s/it]

Completed inference for planet id 20210917_152712_47_2457


  means = sums / counts
 35%|███▍      | 18/52 [7:17:44<17:38:26, 1867.84s/it]

Completed inference for planet id 20210922_171337_39_2420


  means = sums / counts
 37%|███▋      | 19/52 [7:48:09<17:00:08, 1854.80s/it]

Completed inference for planet id 20210924_000522_94_2421


  means = sums / counts
 38%|███▊      | 20/52 [8:18:34<16:24:23, 1845.74s/it]

Completed inference for planet id 20210924_082025_48_2424


  means = sums / counts
 40%|████      | 21/52 [8:30:01<12:54:01, 1498.12s/it]

Completed inference for planet id 20210924_133812_95_2420


  means = sums / counts
 42%|████▏     | 22/52 [8:46:23<11:11:34, 1343.16s/it]

Completed inference for planet id 20210925_072712_16_2254


  means = sums / counts
 44%|████▍     | 23/52 [9:09:28<10:55:15, 1355.71s/it]

Completed inference for planet id 20210926_020646_15_2231


  means = sums / counts
 46%|████▌     | 24/52 [9:40:43<11:45:22, 1511.53s/it]

Completed inference for planet id 20210927_105543_66_2424


  means = sums / counts
 48%|████▊     | 25/52 [10:11:43<12:07:12, 1616.02s/it]

Completed inference for planet id 20210928_141837_16_2407


  means = sums / counts
 50%|█████     | 26/52 [10:25:48<10:00:05, 1384.81s/it]

Completed inference for planet id 20210928_211311_91_2457


  means = sums / counts
 52%|█████▏    | 27/52 [10:58:04<10:45:56, 1550.28s/it]

Completed inference for planet id 20210929_073913_09_2453


  means = sums / counts
 54%|█████▍    | 28/52 [11:36:42<11:52:12, 1780.50s/it]

Completed inference for planet id 20210930_021156_60_2434


  means = sums / counts
 56%|█████▌    | 29/52 [11:48:28<9:18:56, 1458.13s/it] 

Completed inference for planet id 20210930_070548_00_2442


  means = sums / counts
 58%|█████▊    | 30/52 [11:58:40<7:21:34, 1204.29s/it]

Completed inference for planet id 20211002_155415_1009


  means = sums / counts
 60%|█████▉    | 31/52 [12:09:50<6:05:26, 1044.12s/it]

Completed inference for planet id 20211003_161639_91_241d


  means = sums / counts
 62%|██████▏   | 32/52 [12:42:22<7:18:45, 1316.28s/it]

Completed inference for planet id 20211004_132710_80_240c


  means = sums / counts
 63%|██████▎   | 33/52 [13:32:43<9:38:49, 1827.87s/it]

Completed inference for planet id 20211004_135131_16_2276


  means = sums / counts
 65%|██████▌   | 34/52 [14:19:46<10:37:52, 2126.27s/it]

Completed inference for planet id 20211005_030345_52_241f


  means = sums / counts
 67%|██████▋   | 35/52 [14:42:46<8:59:03, 1902.54s/it] 

Completed inference for planet id 20211010_135831_84_227e


  means = sums / counts
 69%|██████▉   | 36/52 [14:57:23<7:05:16, 1594.79s/it]

Completed inference for planet id 20211011_065101_82_2274


  means = sums / counts
 71%|███████   | 37/52 [15:20:15<6:21:59, 1527.96s/it]

Completed inference for planet id 20211011_155455_52_2262


  means = sums / counts
 73%|███████▎  | 38/52 [15:47:28<6:03:51, 1559.36s/it]

Completed inference for planet id 20211012_002401_90_2459


  means = sums / counts
 75%|███████▌  | 39/52 [16:15:21<5:45:15, 1593.49s/it]

Completed inference for planet id 20211012_083209_57_245c


  means = sums / counts
 77%|███████▋  | 40/52 [16:44:29<5:27:57, 1639.76s/it]

Completed inference for planet id 20211016_135440_48_2459


  means = sums / counts
 79%|███████▉  | 41/52 [17:21:26<5:32:24, 1813.16s/it]

Completed inference for planet id 20211016_174304_93_245f


  means = sums / counts
 81%|████████  | 42/52 [17:35:21<4:13:17, 1519.71s/it]

Completed inference for planet id 20211019_180019_07_2453


  means = sums / counts
 83%|████████▎ | 43/52 [18:09:18<4:11:13, 1674.88s/it]

Completed inference for planet id 20211021_133031_75_245a


  means = sums / counts
 85%|████████▍ | 44/52 [18:24:29<3:12:44, 1445.55s/it]

Completed inference for planet id 20211021_182217_09_2456


  means = sums / counts
 87%|████████▋ | 45/52 [18:39:24<2:29:23, 1280.46s/it]

Completed inference for planet id 20211022_175213_55_2440


  means = sums / counts
 88%|████████▊ | 46/52 [18:58:18<2:03:38, 1236.40s/it]

Completed inference for planet id 20211023_115642_67_105d


  means = sums / counts
 90%|█████████ | 47/52 [19:37:37<2:11:06, 1573.30s/it]

Completed inference for planet id 20211028_012421_73_245f


  means = sums / counts
 92%|█████████▏| 48/52 [19:51:27<1:30:00, 1350.22s/it]

Completed inference for planet id 20211028_045455_02_2459


  means = sums / counts
 94%|█████████▍| 49/52 [20:04:53<59:21, 1187.07s/it]  

Completed inference for planet id 20211028_063039_20_2429


  means = sums / counts
 96%|█████████▌| 50/52 [20:42:50<50:27, 1513.98s/it]

Completed inference for planet id 20211028_134803_20_227a


  means = sums / counts
 98%|█████████▊| 51/52 [21:27:57<31:11, 1871.88s/it]

Completed inference for planet id 20211028_144231_39_227b


  means = sums / counts
100%|██████████| 52/52 [22:14:55<00:00, 1540.29s/it]

Completed inference for planet id 20211030_142613_41_227b





In [30]:
rf.feature_importances_

array([0.03855211, 0.3848263 , 0.03254691, 0.21088557, 0.12055762,
       0.01563617, 0.02419592, 0.01149294, 0.08922672, 0.07207974])