### Load model

In [44]:
# Input locations, relative to the current working directory.
# Alternative dataset (full training file): '../Datasets/Dataset_train_CF.csv'
DATA_PATH = '../Datasets/dataset_reduced.csv'
MODEL_PATH = '../Models/Pres_hybrid_CF.h5'

SEED = 32  # random_state for the decision tree, for reproducible fits

In [45]:
# Restore the trained Keras network; its predictions become the tree's targets.
from tensorflow.keras.models import load_model

model = load_model(filepath=MODEL_PATH)

### Read data and predict with the network

In [46]:
import pandas as pd

# Target column name (assumed to be the quantity the network predicts — TODO
# confirm). It must not be passed to the network as an input feature.
TARGET_COL = 'P_res'

data = pd.read_csv(DATA_PATH, index_col=0)

# Drop the target if this CSV carries it; errors='ignore' makes this a no-op
# for files that contain only feature columns, so either dataset path works.
X = data.drop(columns=TARGET_COL, errors='ignore')

# Network predictions: these (not the ground truth) are what the surrogate
# tree is trained to reproduce.
y_pred_nn = model.predict(X)
assert len(y_pred_nn) == len(X), "network returned a different number of rows than it was given"



### Train a surrogate decision tree on the network's predictions

In [47]:
from sklearn.tree import DecisionTreeRegressor
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.metrics import make_scorer, mean_squared_error

CV_FOLDS = 5

# The tree is a surrogate: it is fit on the network's predictions rather than
# on ground truth, so its error measures how faithfully it mimics the network.
# For a hold-out estimate instead, split X / y_pred_nn with
# train_test_split(..., test_size=0.2, random_state=SEED) and fit on the
# training part only.
X_train = X
y_train = y_pred_nn

tree = DecisionTreeRegressor(random_state=SEED)

# Only max_depth is actually searched; the single-valued entries effectively
# fix those hyperparameters. Float min_samples_* values are interpreted by
# sklearn as fractions of the samples being fit.
tree_params = {
    "max_depth": [4, 16, 32, 64],
    "criterion": ['squared_error'],
    "min_samples_split": [0.01],
    "min_samples_leaf": [0.01],
}

# Built-in equivalent of make_scorer(mean_squared_error, greater_is_better=False):
# GridSearchCV maximises the score, so the MSE is negated.
scoring = 'neg_mean_squared_error'
# NOTE(review): an integer cv uses unshuffled KFold for regressors; if the CSV
# rows are ordered (e.g. by time), consider KFold(shuffle=True, random_state=SEED).
grid = GridSearchCV(tree, tree_params, scoring=scoring, n_jobs=-1, refit=True, cv=CV_FOLDS, verbose=1)

grid.fit(X_train, y_train)

# With refit=True this is the best configuration retrained on all of X_train.
model_tree = grid.best_estimator_


Fitting 5 folds for each of 4 candidates, totalling 20 fits


### Evaluate the tree against the network and print its rules

In [48]:
from sklearn.metrics import mean_squared_error
from sklearn.tree import export_text
import numpy as np

# In-sample fidelity: how closely the tree reproduces the network on the same
# rows it was fit on. This is optimistic; the cross-validated figure below is
# a less biased estimate.
y_pred_tree = model_tree.predict(X_train)
mse_tree = mean_squared_error(y_train, y_pred_tree)

# Root mean squared error, in the same units as the network's output.
rmse_tree = np.sqrt(mse_tree)
print(f"RMSE tree: {rmse_tree}")

# best_score_ is the mean (negated) MSE over the CV folds for the selected
# hyperparameters; depth was chosen on this score, so it is still mildly optimistic.
rmse_tree_cv = np.sqrt(-grid.best_score_)
print(f"RMSE tree (cross-validated): {rmse_tree_cv}")

tree_rules = export_text(model_tree, feature_names=list(X_train.columns))
# Concatenate instead of passing two arguments: print's default sep=" " would
# indent the first rule line by one space and misalign the tree.
print("Tree rules:\n" + tree_rules)


RMSE tree: 0.5529425735050145
Tree rules:
 |--- hub_temperature <= -0.19
|   |--- V <= 0.20
|   |   |--- Out_temperature <= -0.83
|   |   |   |--- V <= -0.46
|   |   |   |   |--- V <= -0.87
|   |   |   |   |   |--- V <= -1.07
|   |   |   |   |   |   |--- value: [0.06]
|   |   |   |   |   |--- V >  -1.07
|   |   |   |   |   |   |--- value: [0.10]
|   |   |   |   |--- V >  -0.87
|   |   |   |   |   |--- Out_temperature <= -1.44
|   |   |   |   |   |   |--- value: [0.23]
|   |   |   |   |   |--- Out_temperature >  -1.44
|   |   |   |   |   |   |--- Out_temperature <= -1.15
|   |   |   |   |   |   |   |--- value: [0.17]
|   |   |   |   |   |   |--- Out_temperature >  -1.15
|   |   |   |   |   |   |   |--- value: [0.13]
|   |   |   |--- V >  -0.46
|   |   |   |   |--- Out_temperature <= -1.39
|   |   |   |   |   |--- V <= -0.17
|   |   |   |   |   |   |--- value: [0.33]
|   |   |   |   |   |--- V >  -0.17
|   |   |   |   |   |   |--- value: [0.46]
|   |   |   |   |--- Out_temperature >  -1.