Copyright **`(c)`** 2024 Giovanni Squillero `<giovanni.squillero@polito.it>`  
[`https://github.com/squillero/computational-intelligence`](https://github.com/squillero/computational-intelligence)  
Free under certain conditions — see the [`license`](https://github.com/squillero/computational-intelligence/blob/master/LICENSE.md) for details.  

In [1]:
import numpy as np
from icecream import ic

In [2]:
def true_f(x: np.ndarray) -> np.ndarray:
    return x[0] + np.sin(x[1]) / 5

In [3]:
TEST_SIZE = 10_000
TRAIN_SIZE = 1000

x_validation = np.vstack(
    [
        np.random.random_sample(size=TEST_SIZE) * 2 * np.pi - np.pi,
        np.random.random_sample(size=TEST_SIZE) * 2 - 1,
    ]
)
y_validation = true_f(x_validation)
train_indexes = np.random.choice(TEST_SIZE, size=TRAIN_SIZE, replace=False)
x_train = x_validation[:, train_indexes]
y_train = y_validation[train_indexes]
assert np.all(y_train == true_f(x_train)), "D'ho"

np.savez('problem_0.npz', x=x_train, y=y_train)

## Evaluation

In [4]:
import d3584

In [5]:
problem = np.load('problem_0.npz')
x = problem['x']
y = problem['y']
x.shape

(2, 1000)

In [6]:
print(f"MSE (train): {100*np.square(y_train-d3584.f(x_train)).sum()/len(y_train):g}")
print(f"MSE (real) : {100*np.square(y_validation-d3584.f(x_validation)).sum()/len(y_validation):g}")

MSE (train): 1.13489e-05
MSE (real) : 1.15919e-05


In [7]:
x_train

array([[ 2.81943915, -2.22258618,  2.60328071, ...,  1.78345499,
        -0.26135778, -2.46892904],
       [ 0.1424488 ,  0.13945924, -0.69495683, ..., -0.28033546,
        -0.79030503, -0.1194973 ]])

In [8]:
true_f(x_train)

array([ 2.84783266e+00, -2.19478466e+00,  2.47521026e+00,  3.47991973e-01,
       -2.58484131e+00,  1.15180500e+00, -2.66713652e+00,  1.92431514e+00,
       -2.52162118e+00,  2.36263193e+00, -3.05981004e+00,  1.27130790e+00,
        1.46048340e+00, -2.93294954e+00,  5.81054238e-01, -1.94718717e+00,
       -4.03701215e-01,  1.02781695e+00, -6.48833084e-01,  2.41552635e+00,
        8.63702197e-01,  1.17389950e+00, -2.25818142e+00,  1.84083828e+00,
       -1.69872514e+00, -2.37069370e+00,  2.62122403e+00, -2.04227325e+00,
        2.11122650e+00, -2.26568041e+00,  7.70675361e-01, -4.11935192e-01,
        3.31045873e-01, -2.02571431e+00, -2.84494709e+00, -2.65423657e+00,
        1.49644495e+00, -1.49930168e+00, -2.54907289e+00,  2.52379825e+00,
        2.86861310e+00, -2.47377994e-01,  3.61780147e-01, -1.42550038e+00,
       -3.83312170e-02, -2.92123124e+00,  2.15375209e+00,  2.16302911e-01,
       -2.92146271e+00,  1.33302376e+00,  2.40807373e+00, -1.27810322e+00,
       -2.41778007e+00, -