Copyright **`(c)`** 2024 Giovanni Squillero `<giovanni.squillero@polito.it>`  
[`https://github.com/squillero/computational-intelligence`](https://github.com/squillero/computational-intelligence)  
Free under certain conditions — see the [`license`](https://github.com/squillero/computational-intelligence/blob/master/LICENSE.md) for details.  

In [1]:
import numpy as np
from icecream import ic

In [2]:
def true_f(x: np.ndarray) -> np.ndarray:
    return x[0] + np.sin(x[1]) / 5

In [3]:
TEST_SIZE = 10_000
TRAIN_SIZE = 1000

x_validation = np.vstack(
    [
        np.random.random_sample(size=TEST_SIZE) * 2 * np.pi - np.pi,
        np.random.random_sample(size=TEST_SIZE) * 2 - 1,
    ]
)
y_validation = true_f(x_validation)
train_indexes = np.random.choice(TEST_SIZE, size=TRAIN_SIZE, replace=False)
x_train = x_validation[:, train_indexes]
y_train = y_validation[train_indexes]
assert np.all(y_train == true_f(x_train)), "D'ho"

np.savez('problem_0.npz', x=x_train, y=y_train)

## Evaluation

In [4]:
import d3584

In [5]:
problem = np.load('problem_0.npz')
x = problem['x']
y = problem['y']
x.shape

(2, 1000)

In [6]:
print(f"MSE (train): {100*np.square(y_train-d3584.f(x_train)).sum()/len(y_train):g}")
print(f"MSE (real) : {100*np.square(y_validation-d3584.f(x_validation)).sum()/len(y_validation):g}")

MSE (train): 0.015773
MSE (real) : 0.0145847


In [11]:
x_train

array([[-3.10960356,  1.03145469,  1.65966558, ...,  1.08914327,
         0.72309318,  1.17803602],
       [-0.4311445 , -0.7172712 , -0.88678894, ...,  0.81126928,
        -0.88722681,  0.73595542]])

In [10]:
true_f(x_train)

array([-3.19318573e+00,  8.99988549e-01,  1.50465625e+00, -2.41577934e+00,
       -2.12190501e+00,  1.65947341e+00, -8.94245265e-01,  1.02767224e+00,
       -1.49169988e+00, -8.05820049e-01, -1.45340541e+00,  2.09234969e-01,
       -2.05709611e-01,  3.28261801e+00, -2.27966700e+00,  9.33164812e-01,
       -6.61218028e-02,  5.00492374e-01,  2.84013365e+00,  2.24490602e+00,
        2.16357883e+00, -6.34193818e-01, -3.21711073e+00,  1.56125603e+00,
        1.93934183e+00,  8.00468932e-01, -1.26414388e+00, -3.59622261e-03,
       -1.71425417e+00, -2.69433727e+00,  2.71785169e+00,  8.93105075e-01,
        2.93430335e+00,  1.40247705e+00,  3.18915333e+00,  2.96450050e+00,
       -2.79606018e+00, -2.50923916e+00, -1.47317336e-01, -1.58718154e+00,
       -1.98607772e+00, -5.34999486e-01,  3.04583810e+00, -2.84979588e+00,
       -1.34906481e+00,  5.21810163e-01, -3.11458858e-01,  1.02442848e+00,
       -1.64180848e+00,  5.69117581e-01,  2.54448265e+00, -1.01905893e+00,
       -2.10566038e+00,  