# ABOUT

__Author__: Pat McCornack

__Date__: 5/28/2024

__Purpose__: Update fuelscape data to reflect the most recent disturbances on the landscape. This script trains models to predict FVT/FVC/FVH/F40 and applies those models to generate predicted rasters of each target. Alternatively, the user can specify existing models to be used to generate the predicted rasters. Inputs are (1) paths to the sample point layers used to train the models and (2) paths to the LANDFIRE input raster datasets.

__Instructions__:
If generating new models and applying them to predict rasters, update the file paths and then hit 'Run All'. This trains new models and applies them to generate predicted rasters. The predicted FVT/FVC/FVH rasters are used as inputs to predict F40.

If using existing models, update the file paths and then see step 3 below. The predicted FVT/FVC/FVH rasters are still used as inputs to predict F40.

1. Update paths if necessary: 
- Update the sample points file paths
- Update the model file paths
- Update the raster_fpaths_dict with the most recent rasters

__Main__
2. Train models: If creating new models, run the cell under _Train Models_. By default it creates a model for each of FVT/FVC/FVH/F40. Modify 'targets' if only creating a subset of these models.

3. Predict rasters: If using existing models, specify them in 'model_fpaths_dict' under _Run Model Predictions_ in the form 'target' : 'path' (e.g. 'LF22_FVT' : 'path_to_model'). Otherwise, simply run the cell.



------

In [2]:
import os
import glob

import datetime as dt
from joblib import dump, load, Parallel, delayed

import numpy as np
import pandas as pd
import geopandas as gpd

import rasterio
from rasterio.merge import merge
from rasterio.windows import Window

from sklearn.ensemble import HistGradientBoostingClassifier


## Define Paths
Define a set of file paths to conveniently switch between working off local files or the PNNL drive. Set active_root_dir to either local_root_dir or pnnl_root_dir depending on which you're working off of.

In [50]:
local_root_dir = r"C:\Users\mcco573\OneDrive - PNNL\Documents\_Projects\BPA Wildfire\fuelscape_modeling"
pnnl_root_dir = r"\\pnl\projects\BPAWildfire\data\Landfire\fuels_modeling\fuelscape_modeling"

## ! Specify these ##

# Define which data directory to work off of
active_root_dir = local_root_dir

# Define paths to sample points
sample_points_dir = os.path.join(active_root_dir, r"sample_points")
sample_points_fnames = {
    "LF22_F40" : "LF22_F40_sample_points_2024-05-29_100k.shp",
    "LF22_FVT" : "LF22_FVT_Disturbed_sample_points_2024-05-29_200k.shp",
    "LF22_FVC" : "LF22_FVT_Disturbed_sample_points_2024-05-29_200k.shp",
    "LF22_FVH" : "LF22_FVT_Disturbed_sample_points_2024-05-29_200k.shp"
}

# Define paths to models
models_dir = os.path.join(active_root_dir, "models")
model_fpaths_dict = {
    'LF22_F40' : os.path.join(models_dir, "LF22_F40_model_2024-05-29_09-28-44"),
    'LF22_FVT' : os.path.join(models_dir, "LF22_FVT_HGBC_model_2024-05-29_09-37-32"),
    'LF22_FVC' : os.path.join(models_dir, "LF22_FVC_HGBC_model_2024-05-29_09-31-57"),
    'LF22_FVH' : os.path.join(models_dir, "LF22_FVH_HGBC_model_2024-05-29_09-35-38"),
}


## ! These likely don't need to be updated ##

# Directory paths
paths_dict = {
    "out_base_dir" : os.path.join(active_root_dir, r"model_outputs\tabular"),  # Where to save result outputs 
    "ref_data_dir" : os.path.join(active_root_dir, r"..\LF_raster_data\_tables"),  # Location of LF csvs (e.g. LF22_FVT_230.csv)
}

# Define the source raster file paths
data_dir =  os.path.join(active_root_dir, r'..\LF_raster_data\bpa_service_territory')
ref_data_dir = os.path.join(data_dir, r"..\_tables")
raster_fpaths_dict = {
    "LF20_F40" : os.path.join(data_dir, "LC22_F40_220_bpa.tif"),
    "LF20_FVT" : os.path.join(data_dir, "LC22_FVT_220_bpa.tif"),
    "LF22_FVT" : os.path.join(data_dir, "LC22_FVT_230_bpa.tif"),
    "LF20_FVC" : os.path.join(data_dir, "LC22_FVC_220_bpa.tif"),
    "LF22_FVC" : os.path.join(data_dir, "LC22_FVC_230_bpa.tif"),
    "LF20_FVH" : os.path.join(data_dir, "LC22_FVH_220_bpa.tif"),
    "LF22_FVH" : os.path.join(data_dir, "LC22_FVH_230_bpa.tif"),
    "LF22_FDST" : os.path.join(data_dir, "LC22_FDst_230_bpa.tif"),
    "BPS" : os.path.join(data_dir, "LC20_BPS_220_bpa.tif"),
    "ZONE" : os.path.join(data_dir, "us_lf_zones_bpa.tif"),
    "ASPECT" : os.path.join(data_dir, "LC20_Asp_220_bpa.tif"),
    "SLOPE" : os.path.join(data_dir, "LC20_SlpD_220_bpa.tif"),
    "ELEVATION" : os.path.join(data_dir, "LC20_Elev_220_bpa.tif"),
    "BPS_FRG_NE" : os.path.join(data_dir, "BPS_FRG_NEW.tif")
}

out_chunk_dir = os.path.join(active_root_dir, r"outputs\geospatial")
out_raster_dir = os.path.join(data_dir, r'_predicted_rasters')
out_fname_dict = {
    'LF22_F40' : 'Pred_LF22_F40',
    'LF22_FVT' : 'Pred_LF22_FVT',
    'LF22_FVC' : 'Pred_LF22_FVC',
    'LF22_FVH' : 'Pred_LF22_FVH'
}
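Because every later cell depends on these paths, a quick existence check can catch a stale filename before any model is trained. A minimal sketch with a hypothetical helper `find_missing_paths` (not part of the original notebook), exercised here on a temporary directory:

```python
import os
import tempfile


def find_missing_paths(fpaths_dict):
    """Return the keys whose file paths do not exist on disk (hypothetical helper)."""
    return [key for key, fpath in fpaths_dict.items() if not os.path.exists(fpath)]


# Usage with a temporary directory standing in for data_dir
with tempfile.TemporaryDirectory() as tmp_dir:
    present = os.path.join(tmp_dir, "LC22_FVT_230_bpa.tif")
    open(present, "w").close()  # create an empty placeholder file
    example_paths = {
        "LF22_FVT": present,
        "LF22_FVC": os.path.join(tmp_dir, "LC22_FVC_230_bpa.tif"),  # never created
    }
    missing = find_missing_paths(example_paths)

print(missing)  # ['LF22_FVC']
```

In the notebook it could be called as `find_missing_paths(raster_fpaths_dict)` right after this cell.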

# __Functions__
----


## __Helper Functions__

### Create Output Directory
Names the output directory using the datetime at which the function is called.
Returns the full path to the new directory, which is used as the output location for results (e.g. the predicted raster chunks).

In [51]:
def make_dir(base_dir, new_dir_name='model_results'):
    """
    Returns path to a directory created at the specified base_dir location.

    The name of the created directory can optionally be specified using the new_dir_name argument.
    The current datetime is appended to the name so runs started at different times get distinct directories.
    """

    datetime = dt.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    output_dir = os.path.join(base_dir, new_dir_name + "_" + datetime)

    os.makedirs(output_dir)
    return output_dir

### Create Data Dictionaries to Append Features
Some features are separate attributes of the LANDFIRE dataset (e.g. BPS Fire Regime) and others are useful for results analysis (e.g. FDst attributes). These can be mapped to points using LANDFIRE CSVs. The function below creates dictionaries to perform that mapping.

This function is called by join_features.
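A toy illustration of the mapping that read_ref_data and join_features perform: a value-to-attribute dictionary built from two CSV columns, then applied with `Series.map`. The BPS codes and names below are made up for illustration, not taken from LF20_BPS_220.csv:

```python
import pandas as pd

# Toy stand-in for LF20_BPS_220.csv; the codes and names are made up for illustration
bps_df = pd.DataFrame({
    "VALUE": [10080, 10090, 10100],
    "BPS_NAME": ["Forest A", "Shrubland B", "Grassland C"],
    "FRG_NEW": [1, 3, 2],
})

# Build raster value -> attribute lookups, the same way read_ref_data does
bps_name_lookup = dict(bps_df[["VALUE", "BPS_NAME"]].values)
frg_lookup = dict(bps_df[["VALUE", "FRG_NEW"]].values)

# Toy sample points carrying raw BPS raster values (99999 is deliberately absent from the table)
points = pd.DataFrame({"BPS": [10090, 10080, 10090, 99999]})

# Series.map replaces each raster value with its attribute; unmatched values become NaN
points["BPS_NAME"] = points["BPS"].map(bps_name_lookup)
points["BPS_FRG_NE"] = points["BPS"].map(frg_lookup)
print(points)
```

Values missing from the CSV do not raise an error; they silently become NaN, so `points["BPS_NAME"].isna().sum()` is a cheap check after joining.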

In [52]:
def read_ref_data(ref_data_dir=paths_dict["ref_data_dir"]):
    """
    Returns a dictionary of dictionaries of mappings between LANDFIRE raster values and other attributes associated with those values. 
    """
    data_dir = ref_data_dir
    BPS_fname = "LF20_BPS_220.csv"
   
    # Create empty dictionary
    LF_ref_dicts = {}

    # Get BPS reference dictionary
    BPS_df = pd.read_csv(os.path.join(data_dir, BPS_fname))
    LF_ref_dicts["BPS_NAME"] = dict(BPS_df[['VALUE', 'BPS_NAME']].values)
    LF_ref_dicts["BPS_FRG_NE"] = dict(BPS_df[['VALUE', 'FRG_NEW']].values)

    return LF_ref_dicts
                         
        

### Append Features using Data Dictionaries
Append in selected features using the LANDFIRE data dictionaries.

__Note:__ Items in feature_list must be in the source_layers dictionary. 

In [53]:
def join_features(sample_points, feature_list = ['BPS_NAME']):
    """
    Returns the sample_points layer with the features in feature_list appended. 

    Items in feature_list must be in the source_layers dictionary.  
    """
    
    LF_ref_dicts = read_ref_data()
    
    source_layers = {
        'BPS_NAME' : 'BPS',
        'BPS_FRG_NE' : 'BPS',
    }

    # Iterate through feature_list and append features to sample_points
    for feature in feature_list:
        sample_points[feature] = sample_points[source_layers[feature]].map(LF_ref_dicts[feature]).copy()

    return sample_points

### Low Count Filter
Some classes have exceedingly low representation in the dataset. In order to run a stratified train_test_split, each target class must have a count greater than 1. The following filters out classes with fewer than 5 observations. Given the size of the dataset, excluding these classes should not detrimentally impact model performance.
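The filter amounts to counting each class and keeping those at or above a threshold. A minimal pandas sketch, using a hypothetical `low_count_filter_sketch` with the threshold exposed as `min_count` (5 matches the notebook's cutoff):

```python
import pandas as pd


def low_count_filter_sketch(sample_points, target, min_count=5):
    """Drop rows whose target class has fewer than min_count observations."""
    counts = sample_points[target].value_counts()
    keep_classes = counts[counts >= min_count].index
    return sample_points.loc[sample_points[target].isin(keep_classes)]


# Class 1 has 6 rows, class 2 has 4, class 3 has 1
df = pd.DataFrame({"LF22_FVT": [1] * 6 + [2] * 4 + [3]})
filtered = low_count_filter_sketch(df, "LF22_FVT")
print(filtered["LF22_FVT"].value_counts().to_dict())  # {1: 6}
```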

In [54]:
def low_count_filter(sample_points, target, min_count=5):
    """
    Returns modified sample_points dataframe. Classes in the target feature with fewer than min_count observations are removed.
    """
    # Count observations per target class
    obs_counts = sample_points[target].value_counts()

    # Identify classes with low observation counts
    low_count_classes = obs_counts[obs_counts < min_count].index.tolist()

    # Remove those classes from sample_points
    sample_points = sample_points.loc[~sample_points[target].isin(low_count_classes)]

    return sample_points

## __Train Model Functions__

### Prepare the Sample Points

In [55]:
def model_data_prep(sample_point_fpath, target):
    """
    Returns dataframe of processed sample points. Sample points are read in from shapefile. 
    """
    # Read in gdf
    sample_points = gpd.read_file(sample_point_fpath)

    # Filter out near points
    sample_points = sample_points.loc[sample_points['NEAR_FID'] == -1]

    # Drop unneeded columns if present
    sample_points = sample_points.drop(['Classified', 'GrndTruth', 'NEAR_FID', 'NEAR_DIST'], axis=1,
                                       errors='ignore')
    
    # Remove observations with -9999/-1111/-32768 in any field 
    sample_points = sample_points.loc[~sample_points.isin([-1111, -9999, -32768]).any(axis=1)]

    # Join in additional features
    sample_points = join_features(sample_points)

    # Drop non-disturbed values if not predicting F40
    if target != 'LF22_F40':     
        # Filter out points that weren't disturbed
        sample_points = sample_points.loc[sample_points['LF22_FDST'] != 0]

    # Filter classes based on target
    if target == 'LF22_F40':
        F40_NB = [91, 93]
        sample_points = sample_points.loc[~sample_points['LF22_F40'].isin(F40_NB)]

    elif target == 'LF22_FVT':
        developed_fvt = list(range(20,33)) + list(range(2901,2906))
        ag_fvt = [80, 81, 82] + list(range(2960, 2971))
        fvt_filter = developed_fvt + ag_fvt
        sample_points = sample_points.loc[~sample_points['LF22_FVT'].isin(fvt_filter)]

    elif target == 'LF22_FVH' or target == 'LF22_FVC':
        # Agricultural/Developed classes
        fvc_fvh_filter = list(range(20, 70)) + list(range(80, 86))
        sample_points = sample_points.loc[~sample_points[target].isin(fvc_fvh_filter)]

    # Filter out classes in the target feature with very low counts
    sample_points = low_count_filter(sample_points, target)

    return sample_points


### Instantiate Histogram Based Gradient Boosting Classifier
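Instantiates scikit-learn's HistGradientBoostingClassifier with native categorical handling, balanced class weights by default, a learning rate of 0.01, and 100 boosting iterations.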

In [56]:
def histGradientBoostingClassifier(categorical_feature_list, class_weight='balanced', seed=1234):
    """
    Returns specified histogram-based gradient boosting classifier. 
    """

    hgb_classifier = HistGradientBoostingClassifier(
        categorical_features=categorical_feature_list,  # Natively handle categorical variables
        class_weight=class_weight,
        random_state=seed,
        learning_rate=0.01,
        max_iter=100
    )

    return hgb_classifier

### Train Model
Trains and returns HGBC model. Subsets the training data to predictors based on the target model and predicts the specified target feature.

__Note__: Aspect, Elevation, and Slope are the only continuous LANDFIRE datasets, therefore any feature not in that list is assumed to be categorical.
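Since only ASPECT, ELEVATION, and SLOPE are continuous, the categorical set can be derived from each target's predictor list. scikit-learn's `categorical_features` accepts either feature names (when fitting on a DataFrame, in recent versions) or a boolean mask aligned with the columns; this hypothetical helper `categorical_spec` produces both for the FVT predictors:

```python
CONTINUOUS_FEATURES = {"ASPECT", "ELEVATION", "SLOPE"}


def categorical_spec(predictors):
    """Return (categorical feature names, boolean mask) for a predictor list."""
    names = [p for p in predictors if p not in CONTINUOUS_FEATURES]
    mask = [p not in CONTINUOUS_FEATURES for p in predictors]
    return names, mask


fvt_predictors = ['LF20_FVT', 'LF22_FDST', 'ZONE', 'ASPECT', 'SLOPE',
                  'ELEVATION', 'BPS_NAME', 'LF20_FVC', 'LF20_FVH']
names, mask = categorical_spec(fvt_predictors)
print(names)
print(mask)
```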

In [57]:
def fit_model(train_data, predictors, target):
    """
    Returns trained model of specified type given a target. 
    """

    class_weight = "balanced"

    # Treat every predictor except the continuous topographic layers as categorical
    cat_features = [x for x in predictors if x not in ['ASPECT', 'ELEVATION', 'SLOPE']]
    
    # Separate training data predictors/response
    y_train = train_data[target].copy()
    X_train = train_data[predictors].copy()

    # Instantiate the model
    model = histGradientBoostingClassifier(categorical_feature_list=cat_features, class_weight=class_weight)
      
    # Fit the model with the training data
    print(f'Fitting {target} Model...')
    model.fit(X_train, y_train)

    # Return the fit model
    return model         
    

### Train model wrapper function 

In [59]:
def train_model_wrapper(target, sample_points_fname):
    """
    Trains and saves model based on specified target and returns path to saved model.
    """
    # Specify predictors based on target. 
    predictors_dict = {
        'LF22_FVT' : ['LF20_FVT', 'LF22_FDST', 'ZONE', 'ASPECT', 'SLOPE', 'ELEVATION', 'BPS_NAME', 'LF20_FVC', 'LF20_FVH'],
        'LF22_FVC' : ['LF22_FVT', 'LF20_FVC', 'LF22_FDST', 'ZONE', 'BPS_NAME'],
        'LF22_FVH' : ['LF22_FVT', 'LF20_FVH', 'LF22_FDST', 'ZONE', 'BPS_NAME'],
        'LF22_F40' : ['LF22_FVT', 'LF22_FVH', 'LF22_FVC', 'LF22_FDST', 'ZONE', 'BPS_FRG_NE']
    }

    # Define path to sample points
    sample_points_fpath = os.path.join(sample_points_dir, sample_points_fname)

    # Read in and prepare data
    sample_points = model_data_prep(sample_points_fpath, target)

    # Train model
    model = fit_model(train_data=sample_points, predictors=predictors_dict[target], target=target)

    # Save model 
    datetime = dt.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    model_fname = f"{target}_HGBC_model_{datetime}"
    model_fpath = os.path.join(models_dir, model_fname)
    dump(model, model_fpath)
    print(f"{target} model written to: {model_fpath}")
    return model_fpath

## __Raster Predict Functions__

### Preprocess window dataframe
Prepares the data to be run through the model. Separates out null values, non-disturbed pixels (FVT/FVC/FVH only), and agricultural/developed/nonburnable classes. Returns a tuple with:
1. A clean dataframe to be run through the model.
2. A series of the dropped observations (carrying their LF20 values) to be rejoined to the model predictions. This allows the data to be reshaped to a 2D numpy array and written as a raster.
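The drop-then-rejoin pattern works because the flattened window keeps pandas' default integer index, so sorting by index after concatenation restores row-major pixel order before the reshape. A minimal sketch on a toy 2 x 3 window, where a hypothetical rule (add 100) stands in for `model.predict` and null pixels keep their prior value:

```python
import numpy as np
import pandas as pd

NULL_VALUES = [-1111, -9999, -32768]

# A 2 x 3 window flattened row-major, as windowed_read(...)[0].ravel() would give
prior = np.array([[5, -9999, 7], [8, 9, -9999]])
df = pd.DataFrame({"LF20_FVT": prior.ravel()})

# Split off null pixels, keeping their original value under the prediction name
null_mask = df.isin(NULL_VALUES).any(axis=1)
dropped = df.loc[null_mask, "LF20_FVT"].rename("FVT_Pred")
clean = df.loc[~null_mask]

# Stand-in for model.predict(X): a hypothetical rule that adds 100 to each class
preds = pd.Series(clean["LF20_FVT"].to_numpy() + 100, index=clean.index, name="FVT_Pred")

# Rejoin and restore the original pixel order so the reshape lines up
combined = pd.concat([dropped, preds]).sort_index()
out_2d = combined.to_numpy().reshape(prior.shape)
print(out_2d)
```

Resetting or reordering the index anywhere between reading the window and reshaping would break this alignment.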


In [None]:
def predict_data_prep(df, target):
    """
    Prepares data to run model on.
    Input: 
    - A dataframe created from the flattened numpy arrays returned when reading in the raster chunks.
    Returns: 
    - Clean dataframe without NULL data or agricultural/developed/nonburnable classes.
    - Dataframe with the dropped observations - to be reappended after prediction
    """

    # Based on the target, rename the LF20 column of the dropped rows to the Pred column so dropped pixels keep their LF20 value
    column_dict = {
        "LF22_F40" : {'LF20_F40' : 'F40_Pred'},
        "LF22_FVC" : {"LF20_FVC" : "FVC_Pred"},
        "LF22_FVH" : {"LF20_FVH" : "FVH_Pred"},
        "LF22_FVT" : {"LF20_FVT" : "FVT_Pred"}
    }

    pred_dict = {
        "LF22_F40" : "F40_Pred",
        "LF22_FVC" : "FVC_Pred",
        "LF22_FVH" : "FVH_Pred",
        "LF22_FVT" : "FVT_Pred"
    }

    # Remove -9999/-1111/-32768 (Null values)
    null = df[(df.isin([-1111, -9999, -32768])).any(axis=1)] 
    df = df.drop(null.index, axis=0) 

    # If not predicting F40 - join in BPS_NAME as predictor and drop undisturbed points
    # Drop non-disturbed values if not predicting F40
    if target != 'LF22_F40':
        df = join_features(df, feature_list = ['BPS_NAME'])
        non_disturbed = df.loc[df['LF22_FDST'] == 0]
        df = df.drop(non_disturbed.index, axis=0)

    # Filter classes based on the target
    if target == 'LF22_F40':
        F40_NB = [91, 93]  # Nonburnable F40 Classes
        filtered = df.loc[df['LF20_F40'].isin(F40_NB)]  # Drop NB classes
        df = df.drop(filtered.index, axis=0)
        
    elif target == 'LF22_FVC':
        fvc_filter = list(range(20, 70)) + list(range(80, 86)) # Agricultural/Developed classes
        filtered = df.loc[df['LF20_FVC'].isin(fvc_filter)]
        df = df.drop(filtered.index, axis=0)
    
    elif target == 'LF22_FVH':
        fvh_filter = list(range(20, 70)) + list(range(80, 86)) # Agricultural/Developed classes
        filtered = df.loc[df['LF20_FVH'].isin(fvh_filter)]
        df = df.drop(filtered.index, axis=0)

    elif target == 'LF22_FVT':
        developed_fvt = list(range(20,33)) + list(range(2901,2906)) # Agricultural/Developed classes
        ag_fvt = [80, 81, 82] + list(range(2960, 2971))
        fvt_filter = developed_fvt + ag_fvt
        filtered = df.loc[df['LF20_FVT'].isin(fvt_filter)]
        df = df.drop(filtered.index, axis=0)

    # Join the filtered values together so they can be readded later
    if target == 'LF22_F40':
        dropped = pd.concat([null, filtered], axis=0)
        dropped = dropped.rename(columns=column_dict[target])
    else:
        dropped = pd.concat([null, non_disturbed, filtered], axis=0)
        dropped = dropped.rename(columns=column_dict[target])

    
    return df, dropped[pred_dict[target]]


### Window Read Function
This function is used to read in chunks of the full raster to be processed. The raster data is stored as blocks with height=1 and width=raster width, so we read in chunks composed of these blocks (e.g. 1000 blocks at a time). The chunk height corresponds to the row_slice argument.
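The row windows that window_predict iterates over can be generated without opening any raster. A small sketch with a hypothetical generator `row_windows`; each (row_start, row_end) pair has the form windowed_read expects as row_slice, with the last window clipped to the raster height:

```python
def row_windows(n_rows, win_height):
    """Yield (row_start, row_end) pairs covering n_rows, clipping the last window."""
    for row_start in range(0, n_rows, win_height):
        yield row_start, min(row_start + win_height, n_rows)


windows = list(row_windows(2500, 1000))
print(windows)  # [(0, 1000), (1000, 2000), (2000, 2500)]
```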

In [None]:
def windowed_read(ras, row_slice):
  """
  Reads in a subset (window) of the data to be processed.
  Inputs:
  - ras: Raster object read in using rasterio
  - row_slice: Used to define the window height - in the form (row_start, row_end). 

  Returns:
  - data: The data in the window as a 2D numpy array
  - win: The Window object used to define the subset of the data. 
  - win_transform: The affine transform associated with the window. Used to update the metadata of the output of that chunk. 
  """
  
  with rasterio.open(ras) as src:
    col_slice = (0, src.width)  # Read the full raster width
    win = Window.from_slices(row_slice, col_slice) 
    data = src.read(window=win)
    win_transform = src.window_transform(win)
    
  return data, win, win_transform

### Apply Trained Model 
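The models were not trained on null values (-9999, -1111, -32768), on non-disturbed pixels (FVT/FVC/FVH models), or on agricultural/developed/nonburnable classes. These pixels are set aside by predict_data_prep before prediction and rejoined afterward, carrying their LF20 values.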

In [None]:
def predict_raster(model, df, target):
    """
    Predicts target class given a trained model and data to predict on.
    Returns:
    - A dataframe of predicted target values joined to the previously dropped values. 
    """
    # Specify output variable name as function of target
    var_name_dict = {
        "LF22_F40" : "F40_Pred",
        "LF22_FVC" : "FVC_Pred",
        "LF22_FVH" : "FVH_Pred",
        "LF22_FVT" : "FVT_Pred",
    }

    # Prep the data - get a clean dataframe and the dropped observations
    clean_df, dropped = predict_data_prep(df, target)

    # Get the list of predictors the model was trained on
    predictors = list(model.feature_names_in_)

    X = clean_df[predictors].copy()

    # Run model to predict
    # If clean_df is empty, then all values were NULL and are in dropped
    if clean_df.shape[0] > 0:
        y_pred = model.predict(X)
    else:
        return dropped
    
    # Join the dropped observations back in 
    # This allows the result dataframe to be reshaped back to a raster
    df = pd.DataFrame({var_name_dict[target] : y_pred},
                       index=X.index)
    df = pd.concat([dropped, df])
    df.sort_index(inplace=True)


    # Return predictions
    return df


### Apply Model to Window

In [None]:
def window_predict(row, target, model_fpath, win_height, raster_dict, ras_shape, out_dir, out_meta):
    """
    Predicts the target for one window of rows and writes the result as a GeoTIFF chunk in out_dir.
    """
    model = load(model_fpath)

    # Define the window to be processed, making sure it doesn't exceed the raster height
    row_start = row
    row_end = min(row + win_height, ras_shape[0])
    row_slice = (row_start, row_end)

    data_dict = {}

    # For the current window, load data from each raster
    for var, fpath in raster_dict.items():
        data_chunk, data_win, data_transform = windowed_read(fpath, row_slice)
        data_dict[var] = data_chunk.ravel()

    # Create a dataframe from the dictionary of data chunks
    df = pd.DataFrame(data_dict)

    # Run model to predict target classes for the window
    out_arr = predict_raster(model, df, target)

    # Reshape to 2D (data_chunk has shape (bands, rows, cols); keep the single band)
    out_arr_2D = out_arr.to_numpy().reshape(data_chunk.shape)[0]

    # Update output metadata for the chunk (copy so the shared dict isn't modified)
    chunk_meta = out_meta.copy()
    chunk_meta.update({
        'height': out_arr_2D.shape[0],
        'width': out_arr_2D.shape[1],
        'transform': data_transform
    })

    # Write chunk out, named by its starting row so names are unique across parallel workers
    out_file = f"data_chunk_{row_start:08d}.tif"
    with rasterio.open(os.path.join(out_dir, out_file), 'w', **chunk_meta) as out:
        out.write(out_arr_2D, indexes=1)

### Run the model to predict raster
Applies the model to chunks of the raster then mosaics those chunks, saves the resultant predicted raster, and returns the path to the saved raster. 
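rasterio's merge positions each chunk using that chunk's own transform, which is why window_predict writes the window transform into each chunk's metadata. The arithmetic behind `src.window_transform(win)` for a window's upper-left corner can be sketched as follows, assuming the coefficient order (a, b, c, d, e, f) of the `affine` package used by rasterio, where x = a*col + b*row + c and y = d*col + e*row + f. The grid values are hypothetical:

```python
def window_origin(transform, row_off, col_off=0):
    """Return the (x, y) upper-left corner of a window starting at (row_off, col_off)."""
    a, b, c, d, e, f = transform
    x = a * col_off + b * row_off + c
    y = d * col_off + e * row_off + f
    return x, y


# Hypothetical north-up 30 m grid with its upper-left corner at (-2,000,000, 3,000,000)
full_transform = (30.0, 0.0, -2_000_000.0, 0.0, -30.0, 3_000_000.0)
for row_off in (0, 1000, 2000):
    print(row_off, window_origin(full_transform, row_off))
```

Each 1000-row chunk starts 30 km further south, so the chunks tile the full extent regardless of the order in which they are passed to merge.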

In [None]:
def predict_wrapper(target, raster_fpaths_dict, model_fpath, n_jobs=24, win_height=1000):
    """
    Applies the model to the raster in row chunks, mosaics the chunks, writes the predicted raster,
    and returns the path to the written raster.
    """
    raster_fpaths = raster_fpaths_dict

    # Arbitrarily grab metadata and shape from a raster to use for the outputs
    with rasterio.open(raster_fpaths['LF20_F40']) as src:
        out_meta = src.meta.copy()
        ras_shape = src.shape

    # Create directory using current datetime to output data chunks to
    datetime = dt.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')  # Used to name output file
    out_dir = make_dir(base_dir=out_chunk_dir, new_dir_name=out_fname_dict[target])

    # Predict each window of win_height rows in parallel
    Parallel(n_jobs=n_jobs)(
        delayed(window_predict)(row, target, model_fpath, win_height, raster_fpaths, ras_shape, out_dir, out_meta)
        for row in range(0, ras_shape[0], win_height)
    )

    ## Mosaic the data chunks
    # Get the file paths of the generated data chunks
    chunk_fpaths = glob.glob(os.path.join(out_dir, "*.tif"))

    # Open each chunk and merge them into a single raster
    src_files_to_mosaic = [rasterio.open(fpath) for fpath in chunk_fpaths]
    mosaic, out_trans = merge(src_files_to_mosaic)

    # Get the metadata for writing
    out_meta = src_files_to_mosaic[0].meta.copy()
    out_meta.update({
        "driver": "GTiff",
        "height": mosaic.shape[1],
        "width": mosaic.shape[2],
        "transform": out_trans
    })

    # Close the chunk datasets
    for chunk_src in src_files_to_mosaic:
        chunk_src.close()

    # Write out the mosaic raster
    os.makedirs(out_raster_dir, exist_ok=True)
    out_fpath = os.path.join(out_raster_dir, f"{out_fname_dict[target]}_{datetime}.tif")
    with rasterio.open(out_fpath, "w", **out_meta) as dest:
        dest.write(mosaic)

    print(f"Raster written to {out_fpath}")
    return out_fpath


# __Main__

----

## __Train Models__
Trains the models in the targets list, saves them, and stores the paths in model_fpaths_dict.

In [60]:
# Specify the models to train
targets = ["LF22_FVT", "LF22_FVC", "LF22_FVH", "LF22_F40"]
model_fpaths_dict = {}

# Train the models
for target in targets:
    model_fpaths_dict[target] = train_model_wrapper(target, sample_points_fnames[target])

print(model_fpaths_dict)

Fitting LF22_FVC Model...
LF22_FVC model written to: C:\Users\mcco573\OneDrive - PNNL\Documents\_Projects\BPA Wildfire\fuelscape_modeling\models\LF22_FVC_HGBC_model_2024-05-29_12-06-42
Fitting LF22_FVH Model...
LF22_FVH model written to: C:\Users\mcco573\OneDrive - PNNL\Documents\_Projects\BPA Wildfire\fuelscape_modeling\models\LF22_FVH_HGBC_model_2024-05-29_12-07-58
Fitting LF22_F40 Model...
LF22_F40 model written to: C:\Users\mcco573\OneDrive - PNNL\Documents\_Projects\BPA Wildfire\fuelscape_modeling\models\LF22_F40_HGBC_model_2024-05-29_12-08-48


## __Run Model Predictions__
Generate predictions for FVT/FVC/FVH/F40. Predicted FVT is used as an input to predict FVC/FVH. Predicted FVT/FVC/FVH are used as inputs to predict F40.
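The ordering in the cell below matters: FVC/FVH read the predicted FVT raster, and F40 reads all three predictions. A sketch of the same chain as a function, assuming a hypothetical `run_prediction_chain` that takes a `predict_fn` with predict_wrapper's argument order; a fake predictor records its inputs so the data flow can be checked without any rasters:

```python
def run_prediction_chain(predict_fn, raster_fpaths, model_fpaths):
    """Run FVT, then FVC/FVH, then F40, feeding each stage's outputs into later stages."""
    raster_fpaths = dict(raster_fpaths)  # copy so the caller's dict is not mutated
    outputs = {}
    stages = [["LF22_FVT"], ["LF22_FVC", "LF22_FVH"], ["LF22_F40"]]
    for stage in stages:
        # Look up each model by the stage's own target name
        stage_outputs = {t: predict_fn(t, raster_fpaths, model_fpaths[t]) for t in stage}
        raster_fpaths.update(stage_outputs)  # later stages read the predicted rasters
        outputs.update(stage_outputs)
    return outputs


# Fake predictor: records (target, FVT input, model) and returns a made-up output path
calls = []

def fake_predict(target, raster_fpaths, model_fpath):
    calls.append((target, raster_fpaths.get("LF22_FVT"), model_fpath))
    return f"pred_{target}.tif"


models = {t: f"model_{t}" for t in ["LF22_FVT", "LF22_FVC", "LF22_FVH", "LF22_F40"]}
inputs = {"LF22_FVT": "LC22_FVT_230_bpa.tif"}
result = run_prediction_chain(fake_predict, inputs, models)
print(calls)
```

Copying the input dict means rerunning the chain starts from the source rasters rather than from predictions written by a previous run.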

In [None]:
## ! If using pre-trained models, specify them here (key=target, value=model path):
#model_fpaths_dict = {}

## Run models
# Predict FVT
pred_fvt_path = predict_wrapper('LF22_FVT',
                                raster_fpaths_dict=raster_fpaths_dict,
                                model_fpath=model_fpaths_dict['LF22_FVT'])

# Save path to predicted FVT to dict
raster_fpaths_dict['LF22_FVT'] = pred_fvt_path

# Predict FVC/FVH
targets = ["LF22_FVH", "LF22_FVC"]
pred_fpaths = {}

for target in targets:
    print(f"Processing {target}...")
    pred_fpaths[target] = predict_wrapper(target=target,
                                          raster_fpaths_dict=raster_fpaths_dict,
                                          model_fpath=model_fpaths_dict[target])
    
# Update raster_fpaths using the predicted FVC/FVH
for raster, fpath in pred_fpaths.items():
    raster_fpaths_dict[raster] = fpath

# Predict F40
print(raster_fpaths_dict)
print("Processing F40...")
predict_wrapper(target="LF22_F40",
                raster_fpaths_dict=raster_fpaths_dict,
                model_fpath=model_fpaths_dict["LF22_F40"])