In [None]:
import os
import time
import pandas as pd

from sklearn.ensemble import RandomForestClassifier
# wrapper so that model predicts using dask
from dask_ml.wrappers import ParallelPostFit  

from sklearn.inspection import permutation_importance
import matplotlib.pyplot as plt

from joblib import dump

import model_prep_and_evals as mpe 

In [None]:
# **************************************************************
# Run configuration
#
# whole_set = True  => merge train+test sets and train the model on the combined data
# whole_set = False => train the model on the train set only (test set held out)
whole_set = False

# Dataset location: <cwd>/models/<dataset_name>/<dataset_name>_{train,test}.csv
dataset_name = 'model_feb14'
root = os.path.join(os.getcwd(), 'models', dataset_name)
train_name = dataset_name + '_train.csv'
test_name = dataset_name + '_test.csv'

# Column holding the target label
label_name = 'iceplant'

# filter_year = True => keep only rows whose 'year' column equals `year`,
# then drop the 'year' column so it is not used as a feature
filter_year = False
year = 2020

# Saved-model file name (the model is written as <model_name>.joblib)
save_model = True
model_name = dataset_name + '_rfc'

# Whether to compute and plot permutation feature importances after training
calculate_feature_importance = True

In [None]:
# Feature columns fed to the model, in this exact order: for each band the raw
# value followed by its *_max11, *_min11, *_avg11 and *_entr11 statistics
# (presumably computed over an 11-pixel window — TODO confirm against the
# feature-extraction step), then the two date features month and day_in_year.
# The *_max7/*_min7/*_avg7/*_entr7 variants are intentionally excluded.
_bands = ['r', 'g', 'b', 'nir', 'ndvi']
_window_stats = ['max11', 'min11', 'avg11', 'entr11']

cols = []
for _band in _bands:
    cols.append(_band)
    cols.extend(_band + '_' + _stat for _stat in _window_stats)
cols += ['month', 'day_in_year']


In [None]:
# ------------------------------
# LOAD TRAIN / TEST DATA
# Each CSV is read once; features and label come from the same frame, so
# their rows stay aligned.
# 'year' is kept next to the feature columns only when filter_year is on,
# because the year-filter step selects rows on it (and then drops it).
# Selecting `cols` alone would remove it and make that step fail.
feature_cols = (cols + ['year']) if filter_year else cols

train_df = pd.read_csv(os.path.join(root, train_name))
test_df = pd.read_csv(os.path.join(root, test_name))

X_train = train_df[feature_cols]
y_train = train_df.loc[:, label_name]

X_test = test_df[feature_cols]
y_test = test_df.loc[:, label_name]

In [None]:
# ------------------------------
# Optionally keep only rows from `year`, then drop the 'year' column so it is
# not used as a feature. Labels are aligned by index label (.loc), which stays
# correct even if the frames do not carry a default 0..n-1 index.
if filter_year:
    X_train = X_train.loc[X_train.year == year].drop(columns=['year'])
    y_train = y_train.loc[X_train.index]

    X_test = X_test.loc[X_test.year == year].drop(columns=['year'])
    y_test = y_test.loc[X_test.index]

# ------------------------------
# Optionally train on train+test combined (no held-out evaluation afterwards).
if whole_set:
    X_train = pd.concat([X_train, X_test], axis=0)
    y_train = pd.concat([y_train, y_test], axis=0)

# Train and test must expose identical feature columns in the same order,
# otherwise predictions on X_test would silently use misaligned features.
assert X_test.columns.equals(X_train.columns), "train/test feature columns differ"

In [None]:
# Display the feature columns (and their order) the model will be trained on
train_columns = X_train.columns
train_columns

In [None]:
# Presumably compares label proportions between the train and test labels
# (project helper — verify). When whole_set is True, y_train already
# includes the test rows.
proportions = mpe.test_train_proportions(y_train, y_test)
proportions

In [None]:
# ------------------------------
# Convert training data to plain NumPy arrays for fitting. Guarded so that
# re-running this cell is a no-op instead of an AttributeError (an ndarray has
# no .to_numpy()).
if isinstance(X_train, pd.DataFrame):
    X_train = X_train.to_numpy()
if isinstance(y_train, pd.Series):
    y_train = y_train.to_numpy()

In [None]:
# Fit a 100-tree random forest with a fixed seed for reproducibility.
# ParallelPostFit (dask_ml) parallelizes predict/score over dask inputs;
# fitting is delegated to the wrapped RandomForestClassifier.
t0 = time.perf_counter()  # monotonic clock, appropriate for elapsed time
rfc = ParallelPostFit(RandomForestClassifier(n_estimators=100,
                                             random_state=42))
rfc.fit(X_train, y_train)
print('Training time: {:.1f} s'.format(time.perf_counter() - t0))

# The path is relative, so the file lands in the current working directory
# (not in `root`).
if save_model:
    dump(rfc, model_name + '.joblib')

In [None]:
# Score on the held-out test set. Skipped when the test rows were merged into
# training (whole_set), since the score would not be out-of-sample.
if not whole_set:
    preds = rfc.predict(X_test.to_numpy())
    y_test_true = y_test.to_numpy()
    mpe.print_accuracy_info(y_test_true, preds)

In [None]:
# Permutation importance: how much the model's score drops when one feature's
# values are shuffled, averaged over n_repeats shuffles of the test split.
# Runs only when calculate_feature_importance is set in the config cell.
# NOTE: when whole_set is True, X_test was also used for training, so these
# importances are measured on training data and will look optimistic.
if calculate_feature_importance:
    result = permutation_importance(
        rfc,
        X_test.to_numpy(),
        y_test.to_numpy(),
        n_repeats=10,
        random_state=42,
        n_jobs=2
    )

    # Label each importance with its feature name (X_test is still a DataFrame).
    forest_importances = pd.Series(result.importances_mean, index=X_test.columns)

    fig, ax = plt.subplots()
    forest_importances.plot.bar(yerr=result.importances_std, ax=ax)
    ax.set_title("Feature importances using permutation on full model")
    ax.set_ylabel("Mean accuracy decrease")
    fig.tight_layout()
    plt.show()