## Standard RMM experiments

This notebook collects the experiments for all benchmark datasets that can be handled with standard reservoir memory machines (RMMs), namely the latch, copy, repeat copy, and signal copy tasks. FSM learning and associative recall require some special concepts and are handled in separate notebooks.

In [1]:
# Experimental meta-parameters that are shared across all datasets.

# number of training and of test time series, respectively
N, N_test = 90, 10
# number of repeats for the experiments
R = 20
# Per-task settings, one row per task:
#   (task name, number of neurons, number of input dimensions, horizon)
# The rows are transposed into the four parallel lists used by the later cells.
tasks, num_neurons, ns, Ts = (list(column) for column in zip(
    ('latch',        64, 1, 256),
    ('copy',        256, 9,  24),
    ('repeat_copy', 256, 9,  16),
    ('signal_copy',  64, 2, 312),
))

## Hyperparameter Optimization

In [2]:
# number of random hyperparameter combinations that are tested per model
hyper_R = 20
# number of repeats used to score each hyperparameter combination
hyper_num_repeats = 3
# names of all models under comparison; the 'RMM_' prefix selects the
# rmm.RMM wrapper around the respective reservoir type (see setup_model)
models = ['ESN', 'CRJ', 'LMU', 'RMM_ESN', 'RMM_CRJ', 'RMM_LMU']

# candidate values per hyperparameter for the baseline models
hyperparam_ranges = {
    'ESN': {
        'radius': [0.5, 0.7, 0.9],
        'sparsity': [0.1, 0.2, 0.5],
        'regul': [1E-7, 1E-5, 1E-3],
    },
    'CRJ': {
        'v': [0.1, 0.3, 0.5],
        'w_c': [0.1, 0.7, 0.9],
        'w_j': [0.1, 0.2, 0.4],
        'l': [4, 8, 16],
        'regul': [1E-7, 1E-5, 1E-3],
    },
    # NOTE: the LMU horizon 'T' is not searched (a range such as
    # [16, 32, 128, 384] is disabled); the experiment cell sets it per task.
    'LMU': {
        'regul': [1E-7, 1E-5, 1E-3],
    },
}
# Each RMM variant searches the ranges of its base model, followed by the SVM
# settings 'C' and 'svm_kernel'. The value lists are copied so that no list
# object is shared between two models.
hyperparam_ranges.update({
    'RMM_' + base: dict(
        [(key, list(values)) for key, values in hyperparam_ranges[base].items()]
        + [('C', [1., 100., 10000.]), ('svm_kernel', ['linear', 'rbf'])]
    )
    for base in ('ESN', 'CRJ', 'LMU')
})

import numpy as np
import rmm2.esn as esn
import rmm2.crj as crj
import rmm2.lmu as lmu
import rmm2.rmm as rmm

# factory that builds a fresh, untrained network for a given model name
def setup_model(model, m, n, hyperparams):
    """Create a network instance for the model called ``model``.

    The suffix of ``model`` ('ESN', 'CRJ' or 'LMU') decides how the input
    weight matrix and the reservoir matrix are constructed and which
    nonlinearity is used (tanh for ESN and CRJ, the identity for LMU). If
    ``model`` starts with 'RMM_', the matrices are passed to ``rmm.RMM``,
    otherwise to ``esn.ESN``.

    Args:
        model: model name, e.g. 'ESN', 'CRJ', 'LMU' or one of these with an
            'RMM_' prefix.
        m: number of neurons.
        n: number of input dimensions.
        hyperparams: dictionary of hyperparameters. Always read: 'regul'.
            Read for ESN reservoirs: 'radius', 'sparsity'. For CRJ
            reservoirs: 'v', 'w_c', 'w_j', 'l'. For LMU reservoirs: 'T'.
            Additionally for 'RMM_' models: 'C', 'svm_kernel'.

    Returns:
        The newly constructed ``esn.ESN`` or ``rmm.RMM`` object.

    Raises:
        ValueError: if ``model`` ends with none of the known suffixes.
    """
    # step 1: reservoir matrices and nonlinearity, chosen by the name suffix
    if model.endswith('ESN'):
        in_weights, res_weights = esn.initialize_reservoir(
            m, n,
            radius = hyperparams['radius'],
            sparsity = hyperparams['sparsity'],
        )
        nonlin = np.tanh
    elif model.endswith('CRJ'):
        in_weights = crj.setup_input_weight_matrix(n, m, v = hyperparams['v'])
        res_weights = crj.setup_reservoir_matrix(
            m,
            w_c = hyperparams['w_c'],
            w_j = hyperparams['w_j'],
            l = hyperparams['l'],
        )
        nonlin = np.tanh
    elif model.endswith('LMU'):
        # the degree is chosen from the neuron budget m and the input size n
        degree = int(m/n)-1
        in_weights, res_weights = lmu.initialize_reservoir(n, degree, hyperparams['T'])
        # LMU reservoirs are run without a nonlinearity
        nonlin = lambda x : x
    else:
        raise ValueError('Unknown model: %s' % model)

    # step 2: wrap the matrices; these arguments are common to both wrappers
    shared_args = {
        'regul' : hyperparams['regul'],
        'input_normalization' : False,
        'nonlin' : nonlin,
    }
    if not model.startswith('RMM_'):
        return esn.ESN(in_weights, res_weights, **shared_args)
    return rmm.RMM(in_weights, res_weights,
                   C = hyperparams['C'],
                   svm_kernel = hyperparams['svm_kernel'],
                   **shared_args)

## Experiment

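With the models and hyperparameter ranges set up above, we can now iterate over all tasks.
For each task, we first perform hyperparameter optimization (or load the previously evaluated
hyperparameters from `{task}_hyperparams.json` if that file already exists), followed by the
actual experiment, whose errors and runtimes are written to `{task}_errors.csv` and
`{task}_runtimes.csv`.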

In [3]:
import json
import os
import random
import time
from dataset_generators import generate_data
from dataset_generators import _permutation_sampling

# --- helpers shared by the hyperparameter search and the final experiment ---

def _fit_model(model, net, Xs, Qs, Ys):
    """Fit ``net`` to the training data.

    Models whose name starts with 'RMM_' are fit on ``Xs``, ``Qs`` and ``Ys``
    (the three values returned by ``generate_data``); all other models are fit
    on ``Xs`` and ``Ys`` only.
    """
    if model.startswith('RMM_'):
        net.fit(Xs, Qs, Ys)
    else:
        net.fit(Xs, Ys)


def _test_rmse(net, Xs_test, Ys_test, num_test):
    """Return the RMSE of ``net`` over the first ``num_test`` test sequences.

    The mean squared error is computed per sequence, averaged over the
    sequences, and the square root of that average is returned. The result is
    converted to a built-in float so that it can always be written by
    ``json.dump``, whatever numpy scalar type the prediction arithmetic yields.
    """
    mse = 0.
    for i in range(num_test):
        Ypred = net.predict(Xs_test[i])
        mse += np.mean((Ypred - Ys_test[i]) ** 2)
    return float(np.sqrt(mse / num_test))


def _select_best_hyperparams(candidates):
    """Return the entry of ``candidates`` with the lowest mean of 'errors'.

    Ties are resolved in favour of the earliest entry. Entries whose mean error
    is NaN or infinite, or whose 'errors' list is empty, are treated as having
    infinite error. If no entry has a finite mean error, a warning is printed
    and the first entry is returned, so that the caller always gets a result.

    Raises:
        ValueError: if ``candidates`` is empty.
    """
    if len(candidates) == 0:
        raise ValueError('no hyperparameter combinations to select from')
    mean_errors = []
    for params in candidates:
        mean_err = np.mean(params['errors']) if len(params['errors']) > 0 else np.inf
        mean_errors.append(mean_err if np.isfinite(mean_err) else np.inf)
    # np.argmin returns the first occurrence of the minimum
    best_idx = int(np.argmin(mean_errors))
    if not np.isfinite(mean_errors[best_idx]):
        print('warning: no hyperparameter combination has a finite mean error; '
              'falling back to the first one')
    return candidates[best_idx]


# iterate over all tasks
for task_idx, task in enumerate(tasks):
    print('------ Task %d of %d: %s -----' % (task_idx+1, len(tasks), task))
    m = num_neurons[task_idx]
    n = ns[task_idx]
    # try to load the evaluated hyperparameter combinations from file
    hyperparam_path = '%s_hyperparams.json' % task
    if os.path.isfile(hyperparam_path):
        print('loading hyperparameters from %s' % hyperparam_path)
        with open(hyperparam_path, 'r') as hyperparam_file:
            hyperparams = json.load(hyperparam_file)
    else:
        # perform a hyperparameter optimization where we test hyper_R random
        # hyperparameter settings for each model and perform hyper_num_repeats
        # repeats to obtain statistics. The hyperparameters with the best mean
        # performance across repeats will be selected
        print('performing hyperparameter optimization (this may take a while)')
        # generate random parameter combinations for all models
        hyperparams = {}
        for model in models:
            # initialize one hyperparameter dictionary per combination
            hyperparams[model] = [{} for _ in range(hyper_R)]
            # then iterate over each key and sample the parameter values
            for key, param_range in hyperparam_ranges[model].items():
                param_value_indices = _permutation_sampling(hyper_R, 0, len(param_range)-1)
                for r in range(hyper_R):
                    hyperparams[model][r][key] = param_range[param_value_indices[r]]
            for params_r in hyperparams[model]:
                # set up an extra key for the errors
                params_r['errors'] = []
                # for the signal_copy dataset, use the 'pseudo' SVM, because everything
                # else takes too long to train
                if task == 'signal_copy' and model.startswith('RMM'):
                    params_r['svm_kernel'] = 'pseudo'
                # set time horizon for LMU models
                if model.endswith('LMU'):
                    params_r['T'] = Ts[task_idx]

        for repeat in range(hyper_num_repeats):
            print('--- repeat %d of %d ---' % (repeat+1, hyper_num_repeats))
            # sample training and test data
            Xs, Qs, Ys = generate_data(N, task)
            Xs_test, Qs_test, Ys_test = generate_data(N_test, task)
            # now iterate over all models
            for model in models:
                print('-- model: %s --' % model)
                # and iterate over all parameter combinations for this model
                for params_r in hyperparams[model]:
                    # set up a model instance and fit it to the data
                    net = setup_model(model, m, n, params_r)
                    _fit_model(model, net, Xs, Qs, Ys)
                    # measure the RMSE on the test data
                    rmse = _test_rmse(net, Xs_test, Ys_test, N_test)
                    params_r['errors'].append(rmse)
                    print('error: %g' % rmse)
        # write the results to a JSON file
        with open(hyperparam_path, 'w') as hyperparam_file:
            json.dump(hyperparams, hyperparam_file)

    # select best hyperparameters for each model
    hyperparams_opt = {}
    for model in models:
        hyperparams_opt[model] = _select_best_hyperparams(hyperparams[model])
        print('\nSelected the following hyper-parameters for %s' % model)
        for key in hyperparams_opt[model]:
            print('%s: %s' % (key, str(hyperparams_opt[model][key])))
    # hyperparameter optimization complete

    # ACTUAL EXPERIMENT

    # initialize error and runtime arrays (one row per model, one column per repeat)
    errors   = np.zeros((len(models), R))
    runtimes = np.zeros((len(models), R))
    # iterate over all experimental repeats
    for r in range(R):
        print('--- repeat %d of %d ---' % (r+1, R))
        # sample training and test data
        Xs, Qs, Ys = generate_data(N, task)
        Xs_test, Qs_test, Ys_test = generate_data(N_test, task)
        # now iterate over all models
        for model_idx, model in enumerate(models):
            # the measured runtime covers model setup, fitting and test prediction
            start_time = time.time()
            # set up the model with the best selected hyperparameters
            net = setup_model(model, m, n, hyperparams_opt[model])
            # fit the model to the data
            _fit_model(model, net, Xs, Qs, Ys)
            # measure the RMSE on the test data
            rmse = _test_rmse(net, Xs_test, Ys_test, N_test)
            runtimes[model_idx, r] = time.time() - start_time
            errors[model_idx, r] = rmse
    # print results
    for model_idx, model in enumerate(models):
        print('%s: %g +- %g (took %g seconds)' % (model, np.mean(errors[model_idx, :]), np.std(errors[model_idx, :]), np.mean(runtimes[model_idx, :])))
    # write results to file (one column per model, one row per repeat)
    np.savetxt('%s_errors.csv' % task, errors.T, delimiter='\t', header='\t'.join(models), comments='')
    np.savetxt('%s_runtimes.csv' % task, runtimes.T, delimiter='\t', header='\t'.join(models), comments='')

------ Task 1 of 4: latch -----
performing hyperparameter optimization (this may take a while)
--- repeat 1 of 3 ---
-- model: ESN --
error: 0.650006
error: 1.8574
error: 0.748828
error: 0.538785
error: 0.564488
error: 0.557255
error: 0.792649
error: 1.22745
error: 0.543965
error: 2.01097
error: 0.908025
error: 0.533619
error: 0.527513
error: 0.549726
error: 0.616828
error: 0.562745
error: 0.710813
error: 0.526645
error: 0.632329
error: 0.54988
-- model: CRJ --
error: 0.557901
error: 0.531944
error: 0.595287
error: 0.556377
error: 0.609112
error: 0.552196
error: 0.585155
error: 0.639562
error: 0.605366
error: 0.566212
error: 0.558781
error: 0.554382
error: 0.594282
error: 0.572832
error: 0.517214
error: 0.534563
error: 0.550437
error: 0.564799
error: 0.538154
error: 0.588388
-- model: LMU --
error: 0.554904
error: 0.554903
error: 0.554916
error: 0.554903
error: 0.554904
error: 0.554916
error: 0.554904
error: 0.554903
error: 0.554916
error: 0.554916
error: 0.554904
error: 0.554903
error

State prediction recall: 0.972102; precision: 0.972102
error: 0.729559
State prediction recall: 1; precision: 1
error: 3.14139e-07
State prediction recall: 0.972102; precision: 0.972102
error: 0.729559
State prediction recall: 1; precision: 1
error: 3.14139e-07
State prediction recall: 1; precision: 1
error: 2.41139e-07
-- model: RMM_LMU --
State prediction recall: 0.972102; precision: 0.972102
error: 0.729559
State prediction recall: 1; precision: 1
error: 2.77876e-08
State prediction recall: 1; precision: 1
error: 1.44273e-06
State prediction recall: 0.972102; precision: 0.972102
error: 0.729559
State prediction recall: 1; precision: 1
error: 1.44273e-06
State prediction recall: 0.972102; precision: 0.972102
error: 0.729559
State prediction recall: 0.972102; precision: 0.972102
error: 0.729559
State prediction recall: 1; precision: 1
error: 1.44273e-06
State prediction recall: 0.972102; precision: 0.972102
error: 0.729559
State prediction recall: 1; precision: 1
error: 1.26675e-08
St

State prediction recall: 1; precision: 1
--- repeat 6 of 20 ---
State prediction recall: 1; precision: 1
State prediction recall: 1; precision: 1
State prediction recall: 1; precision: 1
--- repeat 7 of 20 ---
State prediction recall: 1; precision: 1
State prediction recall: 1; precision: 1
State prediction recall: 1; precision: 1
--- repeat 8 of 20 ---
State prediction recall: 1; precision: 1
State prediction recall: 1; precision: 1
State prediction recall: 1; precision: 1
--- repeat 9 of 20 ---
State prediction recall: 1; precision: 1
State prediction recall: 1; precision: 1
State prediction recall: 1; precision: 1
--- repeat 10 of 20 ---
State prediction recall: 1; precision: 1
State prediction recall: 1; precision: 1
State prediction recall: 1; precision: 1
--- repeat 11 of 20 ---
State prediction recall: 1; precision: 1
State prediction recall: 1; precision: 1
State prediction recall: 1; precision: 1
--- repeat 12 of 20 ---
State prediction recall: 1; precision: 1
State prediction



State prediction recall: 0.979904; precision: 0.979904
error: 0.493503
State prediction recall: 0.826316; precision: 0.826316
error: 0.524623
State prediction recall: 0.998086; precision: 0.998086
error: 0.527704
State prediction recall: 0.542105; precision: 0.542105
error: 0.493508
State prediction recall: 0.998086; precision: 0.998086
error: 0.463222
State prediction recall: 0.65311; precision: 0.65311
error: 0.884699
State prediction recall: 1; precision: 1
error: 0.45299
State prediction recall: 0.343541; precision: 0.343541
error: 0.476869




State prediction recall: 0.947368; precision: 0.947368
error: 0.589613
State prediction recall: 1; precision: 1
error: 0.684673
State prediction recall: 0.412919; precision: 0.412919
error: 0.452001
State prediction recall: 0.537799; precision: 0.537799
error: 0.489204
State prediction recall: 0.991388; precision: 0.991388
error: 0.639925
State prediction recall: 0.37512; precision: 0.37512
error: 0.44387
State prediction recall: 1; precision: 1
error: 0.485011
State prediction recall: 0.848325; precision: 0.848325
error: 0.449967




State prediction recall: 0.96555; precision: 0.96555
error: 0.741676
State prediction recall: 0.466507; precision: 0.466507
error: 0.465552
State prediction recall: 0.955502; precision: 0.955502
error: 0.495975
State prediction recall: 0.41244; precision: 0.41244
error: 0.583674
-- model: RMM_LMU --
State prediction recall: 1; precision: 1
error: 0.445881
State prediction recall: 1; precision: 1
error: 0.104583
State prediction recall: 1; precision: 1
error: 0.445874
State prediction recall: 1; precision: 1
error: 0.0433013
State prediction recall: 0.990431; precision: 0.990431
error: 0.457413
State prediction recall: 1; precision: 1
error: 0.043301
State prediction recall: 1; precision: 1
error: 0.0433013
State prediction recall: 1; precision: 1
error: 0.445874
State prediction recall: 1; precision: 1
error: 0.104583
State prediction recall: 0.990431; precision: 0.990431
error: 0.457406
State prediction recall: 1; precision: 1
error: 0.445881
State prediction recall: 1; precision: 1
e



State prediction recall: 0.984937; precision: 0.984937
error: 0.567906
State prediction recall: 0.831876; precision: 0.831876
error: 1.43789
State prediction recall: 0.998542; precision: 0.998542
error: 0.504682
State prediction recall: 0.54276; precision: 0.54276
error: 0.594984
State prediction recall: 0.998542; precision: 0.998542
error: 0.512499
State prediction recall: 0.664723; precision: 0.664723
error: 1.30414
State prediction recall: 1; precision: 1
error: 0.48636
State prediction recall: 0.367833; precision: 0.367833
error: 0.519098




State prediction recall: 0.956268; precision: 0.956268
error: 0.75314
State prediction recall: 1; precision: 1
error: 0.751393
State prediction recall: 0.420797; precision: 0.420797
error: 0.475778
State prediction recall: 0.547619; precision: 0.547619
error: 0.616592
State prediction recall: 0.990768; precision: 0.990768
error: 1.12173
State prediction recall: 0.393586; precision: 0.393586
error: 0.488852
State prediction recall: 1; precision: 1
error: 0.558903
State prediction recall: 0.866375; precision: 0.866375
error: 0.498245
State prediction recall: 0.97036; precision: 0.97036
error: 0.957612
State prediction recall: 0.48105; precision: 0.48105
error: 0.495717
State prediction recall: 0.95724; precision: 0.95724
error: 0.560981
State prediction recall: 0.428571; precision: 0.428571
error: 0.611725
-- model: RMM_LMU --
State prediction recall: 1; precision: 1
error: 0.381816
State prediction recall: 1; precision: 1
error: 0.0887438
State prediction recall: 1; precision: 1
error: 



State prediction recall: 0.984791; precision: 0.984791
error: 0.537532
State prediction recall: 0.815589; precision: 0.815589
error: 0.872757
State prediction recall: 1; precision: 1
error: 0.50013
State prediction recall: 0.548954; precision: 0.548954
error: 0.539442
State prediction recall: 0.999049; precision: 0.999049
error: 0.497459
State prediction recall: 0.659221; precision: 0.659221
error: 1.11272
State prediction recall: 1; precision: 1
error: 0.502647
State prediction recall: 0.355038; precision: 0.355038
error: 0.528261
State prediction recall: 0.946293; precision: 0.946293
error: 0.960377
State prediction recall: 1; precision: 1
error: 0.806349
State prediction recall: 0.420627; precision: 0.420627
error: 0.480647
State prediction recall: 0.54943; precision: 0.54943
error: 0.543671
State prediction recall: 0.987643; precision: 0.987643
error: 0.714622
State prediction recall: 0.382605; precision: 0.382605
error: 0.506567
State prediction recall: 1; precision: 1
error: 0.53



State prediction recall: 0.974335; precision: 0.974335
error: 1.37292
State prediction recall: 0.460076; precision: 0.460076
error: 0.498358
State prediction recall: 0.948194; precision: 0.948194
error: 0.551911
State prediction recall: 0.41635; precision: 0.41635
error: 0.556861
-- model: RMM_LMU --
State prediction recall: 1; precision: 1
error: 0.368126
State prediction recall: 1; precision: 1
error: 5.50633e-06
State prediction recall: 1; precision: 1
error: 0.368123
State prediction recall: 1; precision: 1
error: 0.0829156
State prediction recall: 0.990494; precision: 0.990494
error: 0.412811
State prediction recall: 1; precision: 1
error: 0.0829141
State prediction recall: 1; precision: 1
error: 0.0829156
State prediction recall: 1; precision: 1
error: 0.368123
State prediction recall: 1; precision: 1
error: 5.50633e-06
State prediction recall: 0.990494; precision: 0.990494
error: 0.412807
State prediction recall: 1; precision: 1
error: 0.368126
State prediction recall: 1; precis

State prediction recall: 1; precision: 1
error: 0.408127
State prediction recall: 0.993832; precision: 0.993832
error: 0.403237
State prediction recall: 1; precision: 1
error: 0.441976
State prediction recall: 0.980424; precision: 0.980424
error: 0.362964
State prediction recall: 1; precision: 1
error: 0.409192
State prediction recall: 0.875838; precision: 0.875838
error: 0.481593
State prediction recall: 1; precision: 1
error: 0.475356
-- model: RMM_CRJ --
State prediction recall: 0.784393; precision: 0.784393
error: 0.434292
State prediction recall: 0.848485; precision: 0.848485
error: 0.925253
State prediction recall: 0.999464; precision: 0.999464
error: 2.83627
State prediction recall: 0.999464; precision: 0.999464
error: 0.668045
State prediction recall: 0.727809; precision: 0.727809
error: 0.52825
State prediction recall: 0.836149; precision: 0.836149
error: 0.958406
State prediction recall: 0.999464; precision: 0.999464
error: 1.65919
State prediction recall: 0.994368; precision

error: 0.469896
error: 0.446683
error: 0.477989
error: 0.499584
error: 0.508572
error: 0.453452
error: 0.473459
error: 0.45415
error: 0.506684
error: 0.471692
error: 0.47273
error: 0.463746
error: 0.467957
error: 0.483672
error: 0.485187
-- model: LMU --
error: 0.425221
error: 0.423716
error: 0.426045
error: 0.423716
error: 0.426045
error: 0.425221
error: 0.425221
error: 0.423716
error: 0.426045
error: 0.426045
error: 0.425221
error: 0.423716
error: 0.425221
error: 0.426045
error: 0.423716
error: 0.426045
error: 0.425221
error: 0.423716
error: 0.425221
error: 0.426045
-- model: RMM_ESN --
State prediction recall: 1; precision: 1
error: 0.424292
State prediction recall: 0.909818; precision: 0.909818
error: 0.450733
State prediction recall: 1; precision: 1
error: 0.404783
State prediction recall: 0.955783; precision: 0.955783
error: 0.439591
State prediction recall: 1; precision: 1
error: 0.447708
State prediction recall: 1; precision: 1
error: 0.45864
State prediction recall: 1; precisi

State prediction recall: 1; precision: 1
--- repeat 18 of 20 ---
State prediction recall: 0.981851; precision: 0.981851
State prediction recall: 0.788696; precision: 0.788696
State prediction recall: 1; precision: 1
--- repeat 19 of 20 ---
State prediction recall: 0.96763; precision: 0.96763
State prediction recall: 0.81969; precision: 0.81969
State prediction recall: 1; precision: 1
--- repeat 20 of 20 ---
State prediction recall: 0.982493; precision: 0.982493
State prediction recall: 0.792042; precision: 0.792042
State prediction recall: 1; precision: 1
ESN: 0.43746 +- 0.00746174 (took 0.233494 seconds)
CRJ: 0.456347 +- 0.00542269 (took 0.0639194 seconds)
LMU: 0.439826 +- 0.0166023 (took 0.0709186 seconds)
RMM_ESN: 0.373992 +- 0.022018 (took 1.88442 seconds)
RMM_CRJ: 0.425487 +- 0.0129858 (took 6.20142 seconds)
RMM_LMU: 0.0107131 +- 0.0238775 (took 1.80946 seconds)
------ Task 4 of 4: signal_copy -----
performing hyperparameter optimization (this may take a while)
--- repeat 1 of 3 -

error: 15.8497
State prediction recall: 0.962963; precision: 0.965825
error: 16.1895
-- model: RMM_CRJ --
State prediction recall: 0.958519; precision: 0.961367
error: 18.1149
State prediction recall: 0.971852; precision: 0.982036
error: 14.6787
State prediction recall: 0.967407; precision: 0.971726
error: 14.6519
State prediction recall: 0.974815; precision: 1
error: 27.6779
State prediction recall: 0.988148; precision: 0.983776
error: 14.4528
State prediction recall: 0.98963; precision: 0.998505
error: 13.8988
State prediction recall: 0.958519; precision: 0.974398
error: 14.3305
State prediction recall: 0.971852; precision: 0.980568
error: 15.6241
State prediction recall: 0.991111; precision: 0.985272
error: 14.5396
State prediction recall: 0.986667; precision: 0.989599
error: 14.646
State prediction recall: 0.958519; precision: 0.974398
error: 14.3306
State prediction recall: 0.98963; precision: 0.976608
error: 17.3044
State prediction recall: 0.961481; precision: 0.986322
error: 14

State prediction recall: 0.965926; precision: 0.968796
State prediction recall: 0.995556; precision: 0.994083
State prediction recall: 1; precision: 1
--- repeat 2 of 20 ---
State prediction recall: 0.983704; precision: 0.973607
State prediction recall: 0.998519; precision: 0.998519
State prediction recall: 0.998519; precision: 0.998519
--- repeat 3 of 20 ---
State prediction recall: 0.986667; precision: 0.980854
State prediction recall: 0.998519; precision: 0.998519
State prediction recall: 1; precision: 1
--- repeat 4 of 20 ---
State prediction recall: 0.951111; precision: 0.945508
State prediction recall: 0.998519; precision: 0.998519
State prediction recall: 1; precision: 1
--- repeat 5 of 20 ---
State prediction recall: 0.985185; precision: 0.977941
State prediction recall: 0.992593; precision: 0.992593
State prediction recall: 1; precision: 1
--- repeat 6 of 20 ---
State prediction recall: 0.971852; precision: 0.986466
State prediction recall: 0.995556; precision: 0.997033
State 