In [None]:
import pickle

import matplotlib.pyplot as plt
import numpy as np

# Load the fitted forecaster from disk.
# NOTE: pickle.load can execute arbitrary code -- only load model files you trust.
model_path = "../models/bikes_model.pkl"
with open(model_path, "rb") as model_file:
    lgbm_model = pickle.load(model_file)

# The rest of this notebook relies on LightGBM-specific internals of this wrapper.
assert (
    type(lgbm_model).__name__ == "LGBMQuantileForecaster"
), "only lgbm model supported"

# Extract individual quantile models from the MultipleQuantileRegressor.
# `regressors_` maps each quantile level (e.g. 0.5) to its fitted model.
quantile_models = lgbm_model.model.regressors_

# Step 1: Feature Importance Analysis
# One importance vector per quantile-specific model. `feature_importances_` follows
# each estimator's `importance_type` (LightGBM's sklearn default is "split", i.e.
# split counts) -- confirm which type these models were configured with.
feature_importances = [model.feature_importances_ for model in quantile_models.values()]
assert feature_importances, "no quantile models found in lgbm_model.model.regressors_"

# Calculate average feature importance across all quantile models
average_importance = np.mean(feature_importances, axis=0)
feature_names = lgbm_model.feature_names_in_

# Guard against silently mislabelled bars: the wrapper's feature names must map
# one-to-one onto the columns the inner quantile models report importances for.
assert len(feature_names) == len(average_importance), (
    f"{len(feature_names)} feature names vs "
    f"{len(average_importance)} importance values"
)

# Sort feature importances (most important first)
sorted_indices = np.argsort(average_importance)[::-1]
sorted_feature_names = [feature_names[i] for i in sorted_indices]
sorted_importances = average_importance[sorted_indices]

# Horizontal bar chart of the `n_max` highest-ranked features.
n_max = 55
top_names = sorted_feature_names[:n_max]
top_values = sorted_importances[:n_max]

fig, ax = plt.subplots(figsize=(8, 10))
ax.barh(top_names, top_values, align="center")
ax.set_xlabel("Average Feature Importance")
ax.set_ylabel("Feature")
ax.set_title("Feature Importance (Averaged Across Quantile-Levels)")
# barh draws the first entry at the bottom; flip so the top-ranked feature is on top.
ax.invert_yaxis()
fig.tight_layout()
plt.show()

In [None]:
# Inspect the median (q=0.5) quantile model and its underlying LightGBM Booster.
median_model = lgbm_model.model.regressors_[0.5]
# `booster_` is LightGBM's public accessor for the fitted Booster (instead of the
# private `_Booster` attribute); it raises a clear error if the model is unfitted.
bst = median_model.booster_

# Number of features the median model reports importances for -- compare with
# len(lgbm_model.feature_names_in_) used to label the averaged-importance plot.
len(median_model.feature_importances_)

In [None]:
from lightgbm import plot_importance, plot_split_value_histogram, plot_tree

# Histogram of the split thresholds the median Booster used for precipitation_sum.
plot_split_value_histogram(booster=bst, feature="precipitation_sum")

In [None]:
# Top-20 features of the upper (q=0.975) quantile model, ranked by total gain.
# Uses the public `booster_` accessor rather than the private `_Booster` attribute.
plot_importance(
    lgbm_model.model.regressors_[0.975].booster_,
    importance_type="gain",
    max_num_features=20,
)

In [None]:
# Same gain-based top-20 ranking for the median model, for comparison with q=0.975.
plot_importance(booster=bst, max_num_features=20, importance_type="gain")

In [None]:
# Structure of the first tree in the median model's ensemble (rendering needs graphviz).
plot_tree(booster=bst, tree_index=0, figsize=(20, 20))

In [None]:
# Structure of the last tree in the median model's ensemble. The index is derived
# from the booster instead of the hardcoded 249, which only matches the last tree
# when the model has exactly 250 trees and fails outright if it has fewer.
plot_tree(bst, figsize=(20, 20), tree_index=bst.num_trees() - 1)