# Train an ML model for trait retrieval
Based on spectra simulated with radiative transfer models (RTMs).

In [49]:
import sys
import os
from pathlib import Path
import pickle
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error
import numpy as np

# Project root = parent of the notebook's working directory.
# Notebooks have no `__file__`, so the working directory is the only anchor;
# resolve() follows symlinks, as os.path.realpath did in the previous version.
base_dir = Path.cwd().resolve().parent

## 1. Get data
Load the look-up tables (LUTs) of all scenes into one DataFrame and split them into training and test sets (70/30 split).

In [22]:
# Directory that contains one *.SAFE sub-directory per scene (despite the name,
# this is the parent folder, not a single .SAFE product).
safe_dir = base_dir / 'results/lut_based_inversion/eschikon'

In [41]:
# Load the LUT of every *.SAFE scene directory and stack them into one frame.
# sorted() fixes the row order: Path.glob order is filesystem-dependent, and the
# seeded train/test split below is only reproducible if the row order is too.
# NOTE(review): pickle.load can execute arbitrary code — only load LUT files
# that you generated yourself.
lut_frames = []
for scene_dir in sorted(safe_dir.glob('*.SAFE')):
    lut_path = scene_dir.joinpath('all_phases_lai-cab-ccc-car_lut.pkl')
    with open(lut_path, 'rb') as f:
        lut_frames.append(pickle.load(f))

if not lut_frames:
    raise FileNotFoundError(f'No *.SAFE scene directories found in {safe_dir}')

# Concatenate once (concat inside the loop is quadratic); ignore_index avoids
# duplicate index labels coming from the per-scene frames.
df = pd.concat(lut_frames, ignore_index=True)

In [42]:
# Drop the LUT columns that are not model inputs (presumably the RTM input
# parameters plus the other traits cab/car/ccc). What remains should be the
# target 'lai' and the simulated spectral columns used as features — TODO confirm.
# Reassigning (instead of inplace=True) together with errors='ignore' keeps this
# cell idempotent: re-running it no longer raises KeyError for columns that are
# already gone. Ignored columns do not exist, so they cannot leak in as features.
df = df.drop(
    columns=['n', 'cab', 'car', 'ant', 'cbrown', 'cw', 'cm', 'lidfa', 'hspot',
             'rsoil', 'psoil', 'lidfb', 'tts', 'tto', 'psi', 'typelidf', 'ccc'],
    errors='ignore',
)

In [43]:
# Target (y) is the 'lai' column; every other remaining column is a feature (X).
y = df['lai']
X = df.drop(columns='lai')

# Hold out 30 % of the rows for testing; the fixed seed makes the split repeatable.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42
)

## 2. Set up model: Random Forest

In [58]:
# Fit a seeded 100-tree random forest on the training split and predict the
# held-out test split (fit() returns the estimator, so the calls can be chained).
random_forest_regressor = RandomForestRegressor(n_estimators=100, random_state=42)
y_pred = random_forest_regressor.fit(X_train, y_train).predict(X_test)

# Hold-out error. RMSE is reported as well because it is in the target's units.
mse = mean_squared_error(y_test, y_pred)
rmse = np.sqrt(mse)
print(f"Mean Squared Error: {mse:.2f}")
print(f"Root Mean Squared Error: {rmse:.2f}")


Mean Squared Error: 0.35
Root Mean Squared Error: 0.59


In [59]:
# Save the fitted model so it can be reused for inference.
# The path is relative to the notebook's working directory.
# To reload it later:
#     with open(model_filename, 'rb') as model_file:
#         loaded_rf_model = pickle.load(model_file)
# NOTE(review): only unpickle files you created yourself — pickle.load can run
# arbitrary code.
model_filename = 'random_forest_model.pkl'
with open(model_filename, 'wb') as model_file:
    pickle.dump(random_forest_regressor, model_file)

## 3. Inference: retrieve LAI from (new) images
**TODO**: not implemented yet.

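## 4. Improve the model
- Standardise the input data (not needed for the random forest, whose splits are invariant to feature scaling, but required for scale-sensitive models)
- K-fold cross-validation
- Hyperparameter tuning
- Analyse the minimum amount of training data required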