In [1]:
import numpy as np
import pandas as pd

from GradientBoostedRegressionTree import GBRT

In [24]:
from sklearn.datasets import make_regression

# Synthetic regression problem: 2000 rows and 10 features, of which only 2
# actually drive the target. Gaussian noise (std 10) is added to the target,
# and the fixed seed makes the generated data reproducible.
X, y = make_regression(
    n_samples=2000,
    n_features=10,
    n_informative=2,
    noise=10.0,
    coef=False,  # only (X, y) is returned, not the ground-truth coefficients
    random_state=42,
)

# Preview the features next to the target; feature columns keep their
# default integer labels and the target is appended as column 'y'.
df = pd.DataFrame(X).assign(y=y)

df.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,y
0,-1.892762,0.149342,0.161699,0.239628,0.311491,-1.411057,0.803155,0.102859,0.733054,-0.699415,6.689182
1,-2.299358,-0.767507,0.886843,-0.629867,0.169827,-2.143967,0.102382,-0.452424,0.09643,-0.729997,91.133535
2,1.179297,1.277677,-1.124642,1.551152,0.067518,0.332314,-0.748487,-1.534114,0.115675,0.711615,-107.620402
3,0.283911,-1.601181,0.388425,-1.026354,-0.007063,0.625591,1.484642,-1.027252,1.589441,0.187107,40.732998
4,-0.038135,-2.388531,0.412071,0.11835,-0.210535,-0.834338,-0.691054,0.525169,-0.80336,-1.380652,27.098789


In [25]:
from sklearn.model_selection import train_test_split

# Hold out 20% of the samples for evaluation; the fixed seed keeps the
# partition identical from run to run.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

In [26]:
def r2_score(y_true, y_pred):
    """Return the coefficient of determination (R^2) of a set of predictions.

    R^2 is computed as ``1 - SSE / SST``, where SSE is the sum of squared
    residuals ``(y_true - y_pred)`` and SST is the sum of squared deviations
    of ``y_true`` from its own mean.

    Parameters
    ----------
    y_true : array-like
        Observed target values.
    y_pred : array-like
        Predicted values; must be broadcastable against ``y_true``.

    Returns
    -------
    float
        1.0 for an exact fit, 0.0 for predictions no better than the mean of
        ``y_true``, and a negative value for predictions that are worse than
        the mean. ``nan`` is returned for empty input.

    Notes
    -----
    If ``y_true`` is constant, SST is zero and the ratio is undefined. In that
    case the function returns 1.0 when the predictions match exactly and 0.0
    otherwise, instead of producing ``nan``/``-inf`` and a runtime warning.
    """
    # Work in floating point so integer inputs (unsigned ones in particular)
    # cannot wrap around when the residuals are formed.
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)

    # Computed first so that incompatible shapes still raise as before.
    residuals = y_true - y_pred

    if y_true.size == 0:
        # No samples: the score is undefined; avoid the mean-of-empty warning.
        return float("nan")

    sse = np.sum(residuals ** 2)

    # A zero range means every target is identical. Testing the range rather
    # than SST avoids being fooled by rounding in the mean.
    if np.ptp(y_true) == 0:
        return 1.0 if sse == 0 else 0.0

    sst = np.sum((y_true - np.mean(y_true)) ** 2)
    return 1 - (sse / sst)

In [27]:
# Fit the boosted ensemble on the training split and score the held-out set.
# NOTE(review): alpha is presumably the learning rate / shrinkage factor —
# confirm against the GBRT implementation.
model = GBRT(num_trees=50, alpha=0.2)
model.fit(X_train, y_train)
preds = model.predict(X_test)

# Evaluation metrics on the test split.
squared_errors = (preds - y_test) ** 2
rmse = np.sqrt(np.mean(squared_errors))
r2 = r2_score(y_test, preds)

print(f"Root Mean Squared Error: {rmse:.2f}")
print(f'R2 Score: {r2:.2f}')

# Test features alongside the true target and the model's prediction.
df_res = pd.DataFrame(X_test).assign(y=y_test, preds=preds)

df_res.head()

Root Mean Squared Error: 11.52
R2 Score: 0.99


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,y,preds
0,0.757032,-0.784164,1.263856,1.600098,-0.252625,-0.470755,1.121188,0.262503,1.528332,-0.54342,129.867302,116.171189
1,-0.346772,0.436236,1.466783,-0.929136,-0.585793,-1.190917,0.678728,0.505558,-1.13402,-0.241431,133.765477,143.529969
2,-2.604214,-0.012089,1.511155,0.886887,0.198948,-1.252393,0.363632,-1.451176,-0.420762,-0.127549,150.297287,152.395993
3,1.05307,-0.157178,-0.878136,-0.718138,0.148203,-0.524228,0.336322,-1.319172,-0.615392,0.383655,-61.390321,-89.371922
4,-1.02366,0.550128,0.343598,-0.03542,-0.285687,-0.800293,-0.017928,-1.436352,-0.527066,-0.074593,26.200784,37.77382
