In [5]:
import numpy as np
from gplearn.genetic import SymbolicRegressor
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

In [6]:
# Synthetic example: a noisy quadratic, y = 3x^2 - 2x + 5 + noise, sampled on [-10, 10).
# NOTE(review): the data-loading cell that follows rebinds `x` and `y`, so this
# synthetic set looks unused by the later cells -- confirm before removing it.
np.random.seed(42)
n_points = 100
x = np.random.uniform(low=-10, high=10, size=n_points).reshape(-1, 1)  # one feature, stored as a column
feature = x[:, 0]
noise = np.random.normal(loc=0, scale=10, size=n_points)  # Gaussian noise with standard deviation 10
y = 3 * feature**2 - 2 * feature + 5 + noise  # target values

In [7]:
# Load the problem arrays from the .npz archive.
# np.load on an .npz returns an NpzFile that keeps the underlying file open;
# the context manager closes it once both arrays have been read into memory.
with np.load('problem_0.npz') as problem:
    x = problem['x'].T  # transpose so the first axis lines up with y (see the shapes displayed below)
    y = problem['y']

# Display the shapes as the cell output (sanity check on the layout).
x.shape, y.shape

((1000, 2), (1000,))

In [8]:
# Hold out 20% of the samples for validation; the split is seeded so it is repeatable.
x_train, x_valid, y_train, y_valid = train_test_split(
    x, y, test_size=0.2, random_state=42
)

# Hyperparameters for the genetic-programming search, gathered in one place
# so they are easy to scan and tweak.
gp_params = {
    "population_size": 2000,
    "generations": 20,
    "stopping_criteria": 0.01,
    "p_crossover": 0.7,
    "p_subtree_mutation": 0.1,
    "p_hoist_mutation": 0.05,
    "p_point_mutation": 0.1,
    "max_samples": 0.9,
    "verbose": 1,
    "parsimony_coefficient": 0.01,
    "random_state": 42,
}
est = SymbolicRegressor(**gp_params)

# Run the evolutionary search on the training split (verbose=1 prints per-generation progress).
est.fit(x_train, y_train)

# Predict on the held-out validation split.
y_pred = est.predict(x_valid)

# Show the best expression found by the search.
print("Best formula:", est._program)

    |   Population Average    |             Best Individual              |
---- ------------------------- ------------------------------------------ ----------
 Gen   Length          Fitness   Length          Fitness      OOB Fitness  Time Left
   0    37.95      1.17003e+07        5         0.090893        0.0948581     53.65s
   1    10.29          4.55047       35        0.0352202        0.0346501     30.66s
   2     6.03          1.41631        7        0.0205792        0.0207401     24.54s
   3     1.62          1.56535        7        0.0204394        0.0219985     21.98s
   4     1.80          0.65177        7        0.0203898        0.0224446     20.76s
   5     2.58          1.11226        5       0.00406248       0.00451982     19.08s
Best formula: sub(X0, mul(-0.188, X1))


In [10]:
# Score the fitted model on the validation split using mean squared error.
mse = mean_squared_error(y_true=y_valid, y_pred=y_pred)
print(f"Mean Squared Error on Validations Set: {mse}")

Mean Squared Error on Validations Set: 3.915270591451484e-05
