In [7]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split

# ==============================
# CONFIGURATION
# ==============================
DATASETS = {
    "PRF-Landsat5-9": "PRF-Landsat5-9",
    "PRF-Landsat5-54": "PRF-Landsat5-54",
    "PRF-LiDAR-36": "PRF-LiDAR-36",
    "PRF-LiDAR-702": "PRF-LiDAR-702",
}

TARGET_COL = "Target"
TOP_K = {
    "PRF-Landsat5-9": 9,
    "PRF-Landsat5-54": 20,
    "PRF-LiDAR-36": 10,
    "PRF-LiDAR-702": 25,
}

OUTPUT_DIR = "Exploratory_Analysis"
os.makedirs(OUTPUT_DIR, exist_ok=True)

RANDOM_STATE = 42

# ==============================
# FUNCTIONS
# ==============================
def train_rf_and_get_importance(X, y):
    rf = RandomForestRegressor(
        n_estimators=500,
        max_depth=None,
        min_samples_split=2,
        min_samples_leaf=1,
        random_state=RANDOM_STATE,
        n_jobs=-1
    )
    rf.fit(X, y)

    importance = pd.Series(
        rf.feature_importances_,
        index=X.columns
    ).sort_values(ascending=False)

    return importance


def plot_feature_importance(importance, dataset_name, top_k):
    top_features = importance.head(top_k)

    plt.figure(figsize=(8, max(4, top_k * 0.35)))
    sns.barplot(
        x=top_features.values,
        y=top_features.index,
        palette="viridis"
    )
    plt.xlabel("Mean Decrease in Impurity")
    plt.ylabel("Feature")
    plt.title(f"{dataset_name} – Top {top_k} Feature Importances")
    plt.tight_layout()

    out_path = os.path.join(
        OUTPUT_DIR, f"{dataset_name}_FI.png"
    )
    plt.savefig(out_path, dpi=300)
    plt.close()


def plot_correlation_heatmap(X, dataset_name, top_features):
    corr = X[top_features].corr(method="pearson")

    plt.figure(figsize=(10, 8))
    sns.heatmap(
        corr,
        cmap="coolwarm",
        center=0,
        square=True,
        linewidths=0.5,
        cbar_kws={"shrink": 0.8}
    )
    plt.title(f"{dataset_name} – Pearson Correlation (Top Features)")
    plt.tight_layout()

    out_path = os.path.join(
        OUTPUT_DIR, f"{dataset_name}.png"
    )
    plt.savefig(out_path, dpi=300)
    plt.close()



for dataset_name, path in DATASETS.items():
    print(f"[INFO] Processing {dataset_name}")

    df = pd.read_csv(path)

    X = df.drop(columns=[TARGET_COL])
    y = df[TARGET_COL]

    # Train RF + importance
    importance = train_rf_and_get_importance(X, y)

    # Select top-k features
    k = TOP_K[dataset_name]
    top_features = importance.head(k).index.tolist()

    # Plots
    plot_feature_importance(importance, dataset_name, k)
    plot_correlation_heatmap(X, dataset_name, top_features)

    print(f"[DONE] Saved figures for {dataset_name}\n")


[INFO] Processing PRF-Landsat5-9



Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `y` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(


[DONE] Saved figures for PRF-Landsat5-9

[INFO] Processing PRF-Landsat5-54



Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `y` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(


[DONE] Saved figures for PRF-Landsat5-54

[INFO] Processing PRF-LiDAR-36



Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `y` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(


[DONE] Saved figures for PRF-LiDAR-36

[INFO] Processing PRF-LiDAR-702



Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `y` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(


[DONE] Saved figures for PRF-LiDAR-702

