In [1]:
import numpy as np
import pandas as pd
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error



X, y = make_regression(n_samples=500, n_features=1, noise=10, random_state=42)
y = (y - y.min()) / (y.max() - y.min())  # Normalize target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)


class GBM:
    def __init__(self, n_estimators=100, learning_rate=0.1):
        self.n_estimators = n_estimators
        self.learning_rate = learning_rate
        self.models = []

    def fit(self, X, y):
        # Initialize residuals
        residuals = y
        for _ in range(self.n_estimators):
            # Fit a simple regression tree
            tree = DecisionTreeRegressor(max_depth=3)
            tree.fit(X, residuals)
            self.models.append(tree)
            predictions = tree.predict(X)
            residuals -= self.learning_rate * predictions

    def predict(self, X):
        pred = np.zeros(X.shape[0])
        for model in self.models:
            pred += self.learning_rate * model.predict(X)
        return pred


from sklearn.tree import DecisionTreeRegressor

gbm = GBM(n_estimators=50, learning_rate=0.1)
gbm.fit(X_train, y_train)
y_pred = gbm.predict(X_test)

print(y_pred)
print(y_test)

print("MSE:", mean_squared_error(y_test, y_pred))



[0.52686469 0.3977175  0.49101848 0.34907882 0.54429263 0.56189249
 0.37554505 0.22910309 0.36048092 0.50025191 0.47638719 0.67376685
 0.56189249 0.47638719 0.40668067 0.37623698 0.29946109 0.54429263
 0.67376685 0.45550772 0.52862029 0.34907882 0.56634187 0.45298931
 0.45298931 0.78116225 0.28312479 0.36153315 0.42030755 0.37554505
 0.39855469 0.41670261 0.56634187 0.45219194 0.51351882 0.24074751
 0.51199042 0.34907882 0.36745581 0.64438088 0.33586509 0.62554572
 0.52862029 0.28312479 0.33586509 0.56189249 0.36153315 0.47638719
 0.45550772 0.54429263 0.47604735 0.51019375 0.34907882 0.45298931
 0.37554505 0.43595974 0.59583147 0.45219194 0.55561292 0.56634187
 0.19158115 0.31036453 0.51351882 0.55561292 0.60272649 0.47183519
 0.53996901 0.55036644 0.52111097 0.68675002 0.64438088 0.50025191
 0.64438088 0.50474758 0.34907882 0.28312479 0.43595974 0.59583147
 0.54429263 0.49101848 0.41065883 0.37864175 0.47638719 0.41065883
 0.31036453 0.45298931 0.60272649 0.44478217 0.56189249 0.5618