In [1]:
import numpy as np
import matplotlib.pyplot as plt

from resnet import ResNet
from get_te import get_te

from tqdm.notebook import tqdm
import uncertainties as unc

In [2]:
def lorenz(x, y, z, s=10, r=28, b=2.667):
    '''
    Given:
       x, y, z: a point of interest in three dimensional space
       s, r, b: parameters defining the lorenz attractor
    Returns:
       x_dot, y_dot, z_dot: values of the lorenz attractor's partial
           derivatives at the point x, y, z
    '''
    x_dot = s*(y - x)
    y_dot = r*x - y - x*z
    z_dot = x*y - b*z
    return x_dot, y_dot, z_dot

dt = 0.02
num_steps = 6250

# Need one more for the initial values
xs = np.empty(num_steps + 1)
ys = np.empty(num_steps + 1)
zs = np.empty(num_steps + 1)

# Set initial values
xs[0], ys[0], zs[0] = (0., 1., 1.05)

# Step through "time", calculating the partial derivatives at the current point
# and using them to estimate the next point
for i in range(num_steps):
    x_dot, y_dot, z_dot = lorenz(xs[i], ys[i], zs[i])
    xs[i + 1] = xs[i] + (x_dot * dt)
    ys[i + 1] = ys[i] + (y_dot * dt)
    zs[i + 1] = zs[i] + (z_dot * dt)
    
traj = np.vstack((xs[1:],ys[1:],zs[1:]))
time = np.linspace(-100,25,num_steps)


# warmup
w = 50
trajtest = traj[:,5000+w-1:]

In [4]:
def robust_te(nensembles=10, prune_thr=None, verbose=True, **kwargs):
    resnets = []
    tes = []
    preds = []
    for i in range(nensembles):
        resnets.append(ResNet(**kwargs))
        resnets[i].train(traj[:,:5000])
        resnets[i].compute_Wout(traj[:,:5000])
        preds.append(resnets[i].test(traj[:,5000:], warmup=50))
        tes.append(get_te(preds[i],trajtest,eps=0.4,error_version='normnorm', verbose=verbose))
        if prune_thr and tes[-1] < prune_thr:
            raise ValueError('Pruning')
    te = unc.ufloat(np.mean(tes), np.std(tes))
    print(kwargs, te)
    return te, tes, resnets, preds

In [14]:
robust_te(N=300)

Error exceeds threshold value 0.4 after 177 time steps --> t_e = 3.540.
Error exceeds threshold value 0.4 after 85 time steps --> t_e = 1.700.
Error exceeds threshold value 0.4 after 229 time steps --> t_e = 4.580.
Error exceeds threshold value 0.4 after 186 time steps --> t_e = 3.720.
Error exceeds threshold value 0.4 after 239 time steps --> t_e = 4.780.
Error exceeds threshold value 0.4 after 266 time steps --> t_e = 5.320.
Error exceeds threshold value 0.4 after 177 time steps --> t_e = 3.540.
Error exceeds threshold value 0.4 after 237 time steps --> t_e = 4.740.
Error exceeds threshold value 0.4 after 265 time steps --> t_e = 5.300.
Error exceeds threshold value 0.4 after 180 time steps --> t_e = 3.600.
{'N': 300} 4.1+/-1.0


(4.082+/-1.0396711018394231,
 [3.54, 1.7, 4.58, 3.72, 4.78, 5.32, 3.54, 4.74, 5.3, 3.6],
 [<resnet.ResNet at 0x7f62f0f4ab20>,
  <resnet.ResNet at 0x7f62f0f50280>,
  <resnet.ResNet at 0x7f62f0f48c70>,
  <resnet.ResNet at 0x7f62f0f48ca0>,
  <resnet.ResNet at 0x7f62f0f48ee0>,
  <resnet.ResNet at 0x7f62f0f48c10>,
  <resnet.ResNet at 0x7f62f0f48d30>,
  <resnet.ResNet at 0x7f62f0f48730>,
  <resnet.ResNet at 0x7f62f0f48670>,
  <resnet.ResNet at 0x7f62f0f57f40>],
 [array([[-16.05818245, -17.2025555 , -17.78682059, ...,   5.70829562,
            6.34874803,   7.07474646],
         [-21.77643082, -20.12252645, -16.59032128, ...,   8.9136768 ,
            9.98047167,  11.15154011],
         [ 31.79081316,  37.08840114,  42.03578481, ...,  17.09591177,
           17.19868428,  17.54746175]]),
  array([[-16.05754082, -17.20197693, -17.78701484, ...,  -2.66582523,
          -14.87082569, -21.77052371],
         [-21.77916655, -20.12696577, -16.59740488, ..., -52.45088986,
          -53.00539869, -56

In [5]:
import optuna

In [19]:
prune_thr = 3
def objective(trial):
    N = trial.suggest_int('N', 50, 1000)
    sigma = trial.suggest_float('sigma', 0.01, 1)
    degree = trial.suggest_int('degree', 1, 20)
    spr = trial.suggest_float('spr', 0.5, 2)
    beta = trial.suggest_float('beta', 0, 1e-3)
    alpha = trial.suggest_float('alpha', 0.1, 1)
    bias = trial.suggest_float('bias', 0, 3)
    
#     N = 500
#     alpha = 1
    
    try:
        te = robust_te(N=N, sigma=sigma, degree=degree, spr=spr, beta=beta, alpha=alpha, bias=bias,
                       prune_thr=prune_thr, verbose=False)[0]
    except KeyboardInterrupt:
        raise KeyboardInterrupt
    except:
        raise optuna.TrialPruned()
    
    return te.n

In [20]:
# name = 'study1'
# name = 'study_N300_alpha1'
name = 'study2'
study = optuna.create_study(storage=f'sqlite:///{name}.db',
                            direction='maximize', study_name=name,
                            load_if_exists=True,
                           )

[32m[I 2022-04-26 22:08:54,708][0m A new study created in RDB with name: study2[0m


In [21]:
study.optimize(objective, n_trials=200, show_progress_bar=True)


Progress bar is experimental (supported from v1.2.0). The interface can change in the future.



  0%|          | 0/200 [00:00<?, ?it/s]

[32m[I 2022-04-26 22:09:07,935][0m Trial 0 pruned. [0m
[32m[I 2022-04-26 22:09:09,109][0m Trial 1 pruned. [0m
[32m[I 2022-04-26 22:09:12,175][0m Trial 2 pruned. [0m
[32m[I 2022-04-26 22:09:12,520][0m Trial 3 pruned. [0m
{'N': 967, 'sigma': 0.8031136068240048, 'degree': 3, 'spr': 1.6216489974598638, 'beta': 0.0009553787312819102, 'alpha': 0.534780953477636, 'bias': 2.5036188180684524} 4.3+/-0.9
[32m[I 2022-04-26 22:09:39,299][0m Trial 4 finished with value: 4.324 and parameters: {'N': 967, 'sigma': 0.8031136068240048, 'degree': 3, 'spr': 1.6216489974598638, 'beta': 0.0009553787312819102, 'alpha': 0.534780953477636, 'bias': 2.5036188180684524}. Best is trial 4 with value: 4.324.[0m
[32m[I 2022-04-26 22:09:39,571][0m Trial 5 pruned. [0m
[32m[I 2022-04-26 22:09:39,969][0m Trial 6 pruned. [0m
[32m[I 2022-04-26 22:09:40,362][0m Trial 7 pruned. [0m
[32m[I 2022-04-26 22:09:40,788][0m Trial 8 pruned. [0m
[32m[I 2022-04-26 22:09:41,167][0m Trial 9 pruned. [0m
{'N': 

{'N': 896, 'sigma': 0.361587335828802, 'degree': 8, 'spr': 1.2696444765861765, 'beta': 0.0006254963738308148, 'alpha': 0.9961177876032155, 'bias': 1.5441712923422444} 5.0+/-0.8
[32m[I 2022-04-26 22:14:23,295][0m Trial 30 finished with value: 4.986 and parameters: {'N': 896, 'sigma': 0.361587335828802, 'degree': 8, 'spr': 1.2696444765861765, 'beta': 0.0006254963738308148, 'alpha': 0.9961177876032155, 'bias': 1.5441712923422444}. Best is trial 16 with value: 6.436.[0m
{'N': 754, 'sigma': 0.15240279827480918, 'degree': 12, 'spr': 0.7103860689674467, 'beta': 0.0005770611901317859, 'alpha': 0.9386082605777846, 'bias': 2.2147056426495086} 4.6+/-1.7
[32m[I 2022-04-26 22:14:33,449][0m Trial 31 finished with value: 4.564 and parameters: {'N': 754, 'sigma': 0.15240279827480918, 'degree': 12, 'spr': 0.7103860689674467, 'beta': 0.0005770611901317859, 'alpha': 0.9386082605777846, 'bias': 2.2147056426495086}. Best is trial 16 with value: 6.436.[0m
{'N': 900, 'sigma': 0.29667691001710145, 'degr

{'N': 831, 'sigma': 0.5077360422832468, 'degree': 15, 'spr': 1.1261997165961477, 'beta': 4.532799490036826e-05, 'alpha': 0.6036437102573835, 'bias': 0.8362084011470736} 4.5+/-1.1
[32m[I 2022-04-26 22:20:17,756][0m Trial 48 finished with value: 4.546 and parameters: {'N': 831, 'sigma': 0.5077360422832468, 'degree': 15, 'spr': 1.1261997165961477, 'beta': 4.532799490036826e-05, 'alpha': 0.6036437102573835, 'bias': 0.8362084011470736}. Best is trial 16 with value: 6.436.[0m
{'N': 965, 'sigma': 0.4412369654186921, 'degree': 20, 'spr': 1.2916082307950705, 'beta': 0.00010783258467396798, 'alpha': 0.7141180480558521, 'bias': 1.2653711658956732} 5.0+/-0.9
[32m[I 2022-04-26 22:20:44,988][0m Trial 49 finished with value: 5.002000000000001 and parameters: {'N': 965, 'sigma': 0.4412369654186921, 'degree': 20, 'spr': 1.2916082307950705, 'beta': 0.00010783258467396798, 'alpha': 0.7141180480558521, 'bias': 1.2653711658956732}. Best is trial 16 with value: 6.436.[0m
{'N': 792, 'sigma': 0.39201660

{'N': 644, 'sigma': 0.26275638537314705, 'degree': 8, 'spr': 0.8900080403490116, 'beta': 0.00012341626312289703, 'alpha': 0.4097817718379954, 'bias': 2.0368334997869537} 6.0+/-0.7
[32m[I 2022-04-26 22:24:35,012][0m Trial 66 finished with value: 6.0280000000000005 and parameters: {'N': 644, 'sigma': 0.26275638537314705, 'degree': 8, 'spr': 0.8900080403490116, 'beta': 0.00012341626312289703, 'alpha': 0.4097817718379954, 'bias': 2.0368334997869537}. Best is trial 61 with value: 6.675999999999999.[0m
{'N': 641, 'sigma': 0.18806351620769374, 'degree': 7, 'spr': 0.8879208549447052, 'beta': 0.0001343749267651669, 'alpha': 0.34271913931495046, 'bias': 2.072714804108093} 5.55+/-0.32
[32m[I 2022-04-26 22:24:42,961][0m Trial 67 finished with value: 5.554 and parameters: {'N': 641, 'sigma': 0.18806351620769374, 'degree': 7, 'spr': 0.8879208549447052, 'beta': 0.0001343749267651669, 'alpha': 0.34271913931495046, 'bias': 2.072714804108093}. Best is trial 61 with value: 6.675999999999999.[0m
{'N

{'N': 838, 'sigma': 0.3013292336783148, 'degree': 9, 'spr': 1.0462120459729545, 'beta': 5.770737011036481e-05, 'alpha': 0.26000979961098697, 'bias': 1.5521309156676362} 5.84+/-0.31
[32m[I 2022-04-26 22:27:21,978][0m Trial 83 finished with value: 5.836 and parameters: {'N': 838, 'sigma': 0.3013292336783148, 'degree': 9, 'spr': 1.0462120459729545, 'beta': 5.770737011036481e-05, 'alpha': 0.26000979961098697, 'bias': 1.5521309156676362}. Best is trial 61 with value: 6.675999999999999.[0m
{'N': 758, 'sigma': 0.41046210610897227, 'degree': 8, 'spr': 1.3404329382578692, 'beta': 0.0001559106843444095, 'alpha': 0.2843942319630738, 'bias': 1.9304195982458978} 4.8+/-0.8
[32m[I 2022-04-26 22:27:32,084][0m Trial 84 finished with value: 4.7540000000000004 and parameters: {'N': 758, 'sigma': 0.41046210610897227, 'degree': 8, 'spr': 1.3404329382578692, 'beta': 0.0001559106843444095, 'alpha': 0.2843942319630738, 'bias': 1.9304195982458978}. Best is trial 61 with value: 6.675999999999999.[0m
{'N':

{'N': 903, 'sigma': 0.1632785050680431, 'degree': 11, 'spr': 0.922826781867531, 'beta': 7.342726772864826e-06, 'alpha': 0.5134702733292626, 'bias': 2.339052269768439} 6.6+/-0.9
[32m[I 2022-04-26 22:32:50,430][0m Trial 101 finished with value: 6.603999999999999 and parameters: {'N': 903, 'sigma': 0.1632785050680431, 'degree': 11, 'spr': 0.922826781867531, 'beta': 7.342726772864826e-06, 'alpha': 0.5134702733292626, 'bias': 2.339052269768439}. Best is trial 98 with value: 7.206.[0m
{'N': 970, 'sigma': 0.18532413886790466, 'degree': 11, 'spr': 0.9215653163851869, 'beta': 4.5781179096408455e-06, 'alpha': 0.5132084401161564, 'bias': 2.3171034830026187} 6.11+/-0.07
[32m[I 2022-04-26 22:33:17,825][0m Trial 102 finished with value: 6.114 and parameters: {'N': 970, 'sigma': 0.18532413886790466, 'degree': 11, 'spr': 0.9215653163851869, 'beta': 4.5781179096408455e-06, 'alpha': 0.5132084401161564, 'bias': 2.3171034830026187}. Best is trial 98 with value: 7.206.[0m
{'N': 961, 'sigma': 0.165235

{'N': 908, 'sigma': 0.162806214499538, 'degree': 14, 'spr': 0.7316154944562117, 'beta': 2.1154718503452753e-05, 'alpha': 0.5357164434041196, 'bias': 2.4294642365009684} 6.4+/-0.9
[32m[I 2022-04-26 22:39:09,392][0m Trial 119 finished with value: 6.362 and parameters: {'N': 908, 'sigma': 0.162806214499538, 'degree': 14, 'spr': 0.7316154944562117, 'beta': 2.1154718503452753e-05, 'alpha': 0.5357164434041196, 'bias': 2.4294642365009684}. Best is trial 98 with value: 7.206.[0m
{'N': 914, 'sigma': 0.161089425838927, 'degree': 16, 'spr': 0.810209211678566, 'beta': 2.6261885692407197e-05, 'alpha': 0.39062964343011686, 'bias': 2.4303948273810803} 6.1+/-0.6
[32m[I 2022-04-26 22:39:29,612][0m Trial 120 finished with value: 6.146 and parameters: {'N': 914, 'sigma': 0.161089425838927, 'degree': 16, 'spr': 0.810209211678566, 'beta': 2.6261885692407197e-05, 'alpha': 0.39062964343011686, 'bias': 2.4303948273810803}. Best is trial 98 with value: 7.206.[0m
{'N': 937, 'sigma': 0.17583796987234734, '

{'N': 965, 'sigma': 0.18171318609009598, 'degree': 13, 'spr': 0.9374622836831519, 'beta': 4.6428366644823216e-05, 'alpha': 0.4227716198459963, 'bias': 2.07709553813064} 6.4+/-0.9
[32m[I 2022-04-26 22:45:19,991][0m Trial 137 finished with value: 6.360000000000001 and parameters: {'N': 965, 'sigma': 0.18171318609009598, 'degree': 13, 'spr': 0.9374622836831519, 'beta': 4.6428366644823216e-05, 'alpha': 0.4227716198459963, 'bias': 2.07709553813064}. Best is trial 98 with value: 7.206.[0m
[32m[I 2022-04-26 22:45:20,990][0m Trial 138 pruned. [0m
{'N': 941, 'sigma': 0.22762637789767376, 'degree': 12, 'spr': 0.8421544286580364, 'beta': 3.602869062554782e-05, 'alpha': 0.4692246680195506, 'bias': 2.268496015426249} 6.2+/-0.5
[32m[I 2022-04-26 22:45:44,757][0m Trial 139 finished with value: 6.208 and parameters: {'N': 941, 'sigma': 0.22762637789767376, 'degree': 12, 'spr': 0.8421544286580364, 'beta': 3.602869062554782e-05, 'alpha': 0.4692246680195506, 'bias': 2.268496015426249}. Best is tr

{'N': 881, 'sigma': 0.24392911435954845, 'degree': 9, 'spr': 0.6995176371333555, 'beta': 4.5322671211463604e-05, 'alpha': 0.44729664681187986, 'bias': 2.086555747100717} 5.84+/-0.31
[32m[I 2022-04-26 22:52:02,603][0m Trial 156 finished with value: 5.843999999999999 and parameters: {'N': 881, 'sigma': 0.24392911435954845, 'degree': 9, 'spr': 0.6995176371333555, 'beta': 4.5322671211463604e-05, 'alpha': 0.44729664681187986, 'bias': 2.086555747100717}. Best is trial 98 with value: 7.206.[0m
{'N': 981, 'sigma': 0.18136396937953086, 'degree': 10, 'spr': 0.9176354725029962, 'beta': 7.401920626926233e-05, 'alpha': 0.47104744018428235, 'bias': 2.1725897979380764} 6.06+/-0.08
[32m[I 2022-04-26 22:52:30,759][0m Trial 157 finished with value: 6.064 and parameters: {'N': 981, 'sigma': 0.18136396937953086, 'degree': 10, 'spr': 0.9176354725029962, 'beta': 7.401920626926233e-05, 'alpha': 0.47104744018428235, 'bias': 2.1725897979380764}. Best is trial 98 with value: 7.206.[0m
{'N': 847, 'sigma': 

{'N': 876, 'sigma': 0.17597028193315095, 'degree': 12, 'spr': 0.9386297877263212, 'beta': 1.6544122470011867e-05, 'alpha': 0.39346092059399773, 'bias': 1.5162236543790135} 6.5+/-0.8
[32m[I 2022-04-26 22:58:03,770][0m Trial 174 finished with value: 6.526000000000001 and parameters: {'N': 876, 'sigma': 0.17597028193315095, 'degree': 12, 'spr': 0.9386297877263212, 'beta': 1.6544122470011867e-05, 'alpha': 0.39346092059399773, 'bias': 1.5162236543790135}. Best is trial 98 with value: 7.206.[0m
{'N': 879, 'sigma': 0.16449706738961334, 'degree': 12, 'spr': 0.9367359083713294, 'beta': 3.9208831445344595e-05, 'alpha': 0.40216194376475123, 'bias': 1.4015197411484965} 5.99+/-0.21
[32m[I 2022-04-26 22:58:21,061][0m Trial 175 finished with value: 5.994000000000001 and parameters: {'N': 879, 'sigma': 0.16449706738961334, 'degree': 12, 'spr': 0.9367359083713294, 'beta': 3.9208831445344595e-05, 'alpha': 0.40216194376475123, 'bias': 1.4015197411484965}. Best is trial 98 with value: 7.206.[0m
{'N'

{'N': 879, 'sigma': 0.18263836402008, 'degree': 10, 'spr': 0.9994041711062052, 'beta': 1.7984251663829122e-05, 'alpha': 0.24358827114606763, 'bias': 2.2296041142829064} 6.5+/-0.9
[32m[I 2022-04-26 23:04:41,894][0m Trial 191 finished with value: 6.544000000000001 and parameters: {'N': 879, 'sigma': 0.18263836402008, 'degree': 10, 'spr': 0.9994041711062052, 'beta': 1.7984251663829122e-05, 'alpha': 0.24358827114606763, 'bias': 2.2296041142829064}. Best is trial 98 with value: 7.206.[0m
{'N': 890, 'sigma': 0.1808933555691659, 'degree': 10, 'spr': 0.9374338034145464, 'beta': 3.534447556926674e-05, 'alpha': 0.22917982363820966, 'bias': 2.2321026092718768} 5.85+/-0.30
[32m[I 2022-04-26 23:04:59,893][0m Trial 192 finished with value: 5.8500000000000005 and parameters: {'N': 890, 'sigma': 0.1808933555691659, 'degree': 10, 'spr': 0.9374338034145464, 'beta': 3.534447556926674e-05, 'alpha': 0.22917982363820966, 'bias': 2.2321026092718768}. Best is trial 98 with value: 7.206.[0m
{'N': 942, 's

In [22]:
study.best_value, study.best_params

(7.206,
 {'N': 949,
  'alpha': 0.4702160137921147,
  'beta': 3.148463444839121e-06,
  'bias': 2.36812363839121,
  'degree': 10,
  'sigma': 0.14127058956202954,
  'spr': 1.029366494669368})

In [23]:
optuna.visualization.plot_param_importances(study)

In [25]:
optuna.visualization.plot_contour(study, ['N', 'spr'])