# In[None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error
from sklearn.datasets import make_regression


def generate_dataset(n_samples=1000, n_features=20, noise=0.1):
    """Create a synthetic regression problem with a fixed seed.

    All three arguments are forwarded unchanged to ``make_regression``,
    which is always called with ``random_state=42`` so repeated calls with
    the same arguments use the same seed.

    Returns:
        The ``(X, y)`` pair produced by ``make_regression``.
    """
    settings = {
        "n_samples": n_samples,
        "n_features": n_features,
        "noise": noise,
        "random_state": 42,
    }
    features, targets = make_regression(**settings)
    return features, targets


def fit_decision_tree(X_train, X_test, y_train, y_test, max_depth=None):
    """Fit a seeded decision-tree regressor and score it on both splits.

    Args:
        X_train, y_train: Data the tree is fitted on.
        X_test, y_test: Data used only for scoring the fitted tree.
        max_depth: Forwarded to ``DecisionTreeRegressor``; defaults to
            ``None``.

    Returns:
        A ``(mse_train, mse_test, reg)`` tuple: the mean squared error on
        the training split, the mean squared error on the test split, and
        the fitted regressor itself.
    """
    # The seed is fixed so the same inputs always use the same random state.
    model = DecisionTreeRegressor(max_depth=max_depth, random_state=42)
    model.fit(X_train, y_train)

    train_error = mean_squared_error(y_train, model.predict(X_train))
    test_error = mean_squared_error(y_test, model.predict(X_test))
    return train_error, test_error, model


def calculate_mse_complexity(X_train, X_test, y_train, y_test, max_depth_values, n_repeats=10):
    """Compute mean train/test MSE of a decision tree for each depth.

    For every entry of ``max_depth_values`` the tree is fitted ``n_repeats``
    times via ``fit_decision_tree`` and the resulting errors are averaged
    with ``np.mean``.

    Args:
        X_train, y_train: Training split passed through to
            ``fit_decision_tree``.
        X_test, y_test: Test split passed through to ``fit_decision_tree``.
        max_depth_values: Iterable of ``max_depth`` settings to evaluate,
            in the order the results should be reported.
        n_repeats: Number of fits averaged per depth. Defaults to 10, the
            value that used to be hard-coded here.

    Returns:
        ``(mse_train_values, mse_test_values)``: two lists with one averaged
        MSE per entry of ``max_depth_values``.

    Raises:
        ValueError: If ``n_repeats`` is smaller than 1 (averaging an empty
            list would otherwise silently yield NaN).

    Note:
        ``fit_decision_tree`` in this file builds the tree with a fixed
        ``random_state=42``, so repeated fits on the same data are expected
        to produce identical scores; raising ``n_repeats`` only pays off if
        that fit becomes non-deterministic.
    """
    if n_repeats < 1:
        raise ValueError("n_repeats must be at least 1")

    mse_train_values = []
    mse_test_values = []
    for max_depth in max_depth_values:
        # Keep only the two error values; the fitted model is not needed here.
        runs = [
            fit_decision_tree(X_train, X_test, y_train, y_test, max_depth=max_depth)[:2]
            for _ in range(n_repeats)
        ]
        train_errors, test_errors = zip(*runs)
        mse_train_values.append(np.mean(train_errors))
        mse_test_values.append(np.mean(test_errors))
    return mse_train_values, mse_test_values


def main():
    """Run the depth-vs-error experiment and display the resulting plot.

    Generates the default dataset, holds out 20% of it as a test split
    (seeded with ``random_state=42``), evaluates tree depths 1 through 20
    with ``calculate_mse_complexity`` and plots both error curves.
    """
    features, targets = generate_dataset()
    X_train, X_test, y_train, y_test = train_test_split(
        features, targets, test_size=0.2, random_state=42
    )

    depths = list(range(1, 21))
    train_curve, test_curve = calculate_mse_complexity(
        X_train, X_test, y_train, y_test, depths
    )

    plt.figure(figsize=(10, 6))
    # Training curve is drawn first so legend order matches the plot order.
    series = ((train_curve, 'Training Data'), (test_curve, 'Test Data'))
    for curve, caption in series:
        plt.plot(depths, curve, label=caption, marker='o')
    plt.xlabel('Max Depth')
    plt.ylabel('Mean Squared Error (MSE)')
    plt.title('MSE vs Complexity')
    plt.legend()
    plt.grid(True)
    plt.show()

# Run the experiment only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()