In [52]:
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, _tree
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import OneHotEncoder
from sklearn.metrics import accuracy_score, classification_report, precision_score, recall_score
from sklearn.model_selection import GroupShuffleSplit

path = "data_distanceFilter_best_before_selection.csv"
df = pd.read_csv(path)

target = "yield_effect"
X = df.drop(columns=[target])

#dropping potential leakage cols
leakingcolumns = ["Unnamed: 0", "yield", "results_detailed", "results_actual", "farming_practice", "adm_id"]
X = X.drop(columns = [c for c in leakingcolumns if c in X.columns], errors="ignore")

y = df[target].astype(str)

#split data types
cat_cols = [c for c in X.columns if X[c].dtype == "object" or str(X[c].dtype).startswith("category")]
num_cols = [c for c in X.columns if c not in cat_cols]

preprocess = ColumnTransformer([
    ("num", Pipeline([
        ("imputer", SimpleImputer(strategy="median"))
    ]), num_cols),
    ("cat", Pipeline([
        ("imputer", SimpleImputer(strategy="most_frequent")),
        ("onehot", OneHotEncoder(handle_unknown="ignore"))
    ]), cat_cols),
])

clf = DecisionTreeClassifier(
    max_depth = 8,
    min_samples_leaf = 20,
    random_state = 42,
    class_weight = "balanced"
)
pipe = Pipeline(
    [("preprocess", preprocess), ("clf", clf)]
)

groups = df["adm_id"]
gss = GroupShuffleSplit(test_size = 0.25, random_state = 42)

train_idx, test_idx = next(gss.split(X, y, groups = groups))

X_train = X.iloc[train_idx]
X_test  = X.iloc[test_idx]
y_train = y.iloc[train_idx]
y_test  = y.iloc[test_idx]

pipe.fit(X_train, y_train)
pred = pipe.predict(X_test)
print("Accuracy:", accuracy_score(y_test, pred))

#rules
ohe = pipe.named_steps["preprocess"].named_transformers_["cat"].named_steps["onehot"] if cat_cols else None
cat_feature_names = ohe.get_feature_names_out(cat_cols).tolist() if cat_cols else []
feature_names = num_cols + cat_feature_names

tree = pipe.named_steps["clf"].tree_
class_names = pipe.named_steps["clf"].classes_.tolist()

#tree traversal
def walk(node=0, conditions=None):
    if conditions is None:
        conditions = []
    feat = tree.feature[node]
    if feat != _tree.TREE_UNDEFINED:
        name = feature_names[feat]
        thr = tree.threshold[node]
        walk(tree.children_left[node], conditions + [f"{name} <= {thr:.4g}"])
        walk(tree.children_right[node], conditions + [f"{name} > {thr:.4g}"])
    else:
        dist = tree.value[node][0]
        pred_class = class_names[dist.argmax()]
        conf = dist.max() / dist.sum()
        support = tree.n_node_samples[node]
        print(f"IF {' AND '.join(conditions)} THEN yield_effect={pred_class}  (conf={conf:.3f}, n={support})")

walk()

print(
    classification_report(
        y_test,
        pred,
        digits=3
    )
)

Accuracy: 0.9153005464480874
IF clay_0-5cm_mean <= 403.8 AND lat <= 45.76 AND lon <= 3.194 AND lat <= 43.86 THEN yield_effect=49.1  (conf=1.000, n=239)
IF clay_0-5cm_mean <= 403.8 AND lat <= 45.76 AND lon <= 3.194 AND lat > 43.86 THEN yield_effect=25.0  (conf=0.836, n=35)
IF clay_0-5cm_mean <= 403.8 AND lat <= 45.76 AND lon > 3.194 AND yearly_rain_mean <= 399.5 THEN yield_effect=20.42  (conf=0.977, n=26)
IF clay_0-5cm_mean <= 403.8 AND lat <= 45.76 AND lon > 3.194 AND yearly_rain_mean > 399.5 AND phh2o_0-5cm_mean <= 6.999 AND yearly_avg_mean_temp_mean <= 13.01 THEN yield_effect=26.04  (conf=1.000, n=131)
IF clay_0-5cm_mean <= 403.8 AND lat <= 45.76 AND lon > 3.194 AND yearly_rain_mean > 399.5 AND phh2o_0-5cm_mean <= 6.999 AND yearly_avg_mean_temp_mean > 13.01 AND clay_15-30cm_mean <= 256.4 THEN yield_effect=25.0  (conf=1.000, n=20)
IF clay_0-5cm_mean <= 403.8 AND lat <= 45.76 AND lon > 3.194 AND yearly_rain_mean > 399.5 AND phh2o_0-5cm_mean <= 6.999 AND yearly_avg_mean_temp_mean > 13.0

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
