In [2]:
import numpy as np
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error as sklearn_ms
from sklearn.metrics import mean_absolute_error as sklearn_ma
from sklearn.metrics import r2_score as sklearn_r

from metrics.MeanSquaredError import MeanSquaredError
from metrics.RootMeanSquaredError import RootMeanSquaredError
from metrics.MeanAbsoluteError import MeanAbsoluteError
from metrics.RSquared import RSquared

def test_regression_metrics():
    # Load data
    data = fetch_california_housing()
    X, y = data.data, data.target

    # Split data
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

    # Fit model
    model = LinearRegression()
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)

    # Initialize metric objects
    mse = MeanSquaredError()
    rmse = RootMeanSquaredError()
    mae = MeanAbsoluteError()
    r_squared = RSquared()

    # Calculate custom metrics
    custom_mse = mse.score(y_test, y_pred)
    custom_rmse = rmse.score(y_test, y_pred)
    custom_mae = mae.score(y_test, y_pred)
    custom_r2 = r_squared.score(y_test, y_pred)

    # Calculate sklearn metrics
    sklearn_mse = sklearn_ms(y_test, y_pred)
    sklearn_rmse = np.sqrt(sklearn_mse)
    sklearn_mae = sklearn_ma(y_test, y_pred)
    sklearn_r2 = sklearn_r(y_test, y_pred)

    # Output results
    print(f"Custom MSE: {custom_mse}, Sklearn MSE: {sklearn_mse}")
    print(f"Custom RMSE: {custom_rmse}, Sklearn RMSE: {sklearn_rmse}")
    print(f"Custom MAE: {custom_mae}, Sklearn MAE: {sklearn_mae}")
    print(f"Custom R2: {custom_r2}, Sklearn R2: {sklearn_r2}")

test_regression_metrics()

Custom MSE: 0.5305677824766755, Sklearn MSE: 0.5305677824766755
Custom RMSE: 0.7284008391515454, Sklearn RMSE: 0.7284008391515454
Custom MAE: 0.527247453830617, Sklearn MAE: 0.527247453830617
Custom R2: 0.5957714480625373, Sklearn R2: 0.5957702326061662
