## Standard RMM experiments

This notebook collects the experiments for all benchmark datasets that can be handled with standard reservoir memory machines (RMMs), namely the latch, copy, repeat copy, signal copy, and image copy tasks. FSM learning and associative recall require some special concepts and are handled in separate notebooks.

In [1]:
# Experimental meta-parameters that are shared by all datasets in this notebook.

# size of the training set (number of time series)
N = 90
# size of the test set (number of time series)
N_test = 10
# how often each experiment is repeated
R = 20
# per-task settings, one row per task:
# (task name, number of neurons, number of input dimensions, horizon)
_task_settings = [
    ('latch',        64,  1, 256),
    ('copy',        256,  9,  24),
    ('repeat_copy', 256,  9,  16),
    ('signal_copy',  64,  2, 312),
    ('image_copy',  256, 28,  32),
]
# split the table column-wise into the parallel lists used by the later cells
tasks, num_neurons, ns, Ts = (list(column) for column in zip(*_task_settings))

## Hyperparameter Optimization

In [2]:
# how many random hyperparameter combinations are evaluated per model
hyper_R = 20
# how often each hyperparameter combination is evaluated
hyper_num_repeats = 3
# names of all models under comparison; the 'RMM_' variants share the
# reservoir hyperparameters of the model named after the prefix
models = ['ESN', 'CRJ', 'LMU', 'RMM_ESN', 'RMM_CRJ', 'RMM_LMU']

# candidate values for the hyperparameters of the plain reservoir models
_reservoir_ranges = {
    'ESN' : {
        'radius' : [0.5, 0.7, 0.9],
        'sparsity' : [0.1, 0.2, 0.5],
        'regul' : [1E-7, 1E-5, 1E-3]
    },
    'CRJ' : {
        'v' : [0.1, 0.3, 0.5],
        'w_c' : [0.1, 0.7, 0.9],
        'w_j' : [0.1, 0.2, 0.4],
        'l' : [4, 8, 16],
        'regul' : [1E-7, 1E-5, 1E-3]
    },
    # note: a range for 'T' ([16, 32, 128, 384]) is deliberately not part
    # of the search space of the LMU models
    'LMU' : {
        'regul' : [1E-7, 1E-5, 1E-3]
    }
}
# candidate values that every 'RMM_' variant searches in addition
_rmm_extra_ranges = {
    'C' : [1., 100., 10000.],
    'svm_kernel' : ['linear', 'rbf']
}

# assemble the ranges for all models: first the plain reservoirs, then
# their 'RMM_' counterparts. Every entry receives its own list copies.
hyperparam_ranges = {
    name : {key : list(values) for key, values in ranges.items()}
    for name, ranges in _reservoir_ranges.items()
}
hyperparam_ranges.update({
    'RMM_' + name : {
        key : list(values)
        for key, values in list(ranges.items()) + list(_rmm_extra_ranges.items())
    }
    for name, ranges in _reservoir_ranges.items()
})

import numpy as np
import rmm2.esn as esn
import rmm2.crj as crj
import rmm2.lmu as lmu
import rmm2.rmm as rmm

# set up a function to initialize an instance for each model
def setup_model(model, m, n, hyperparams):
    """Creates an untrained network instance for the given model name.

    The suffix of the model name ('ESN', 'CRJ', or 'LMU') selects how the
    input matrix U and the reservoir matrix W are constructed; the prefix
    'RMM_' selects whether U and W are wrapped into an rmm.RMM or into a
    plain esn.ESN.

    Args:
        model: The model name, e.g. 'ESN' or 'RMM_LMU'.
        m: The number of neurons.
        n: The number of input dimensions.
        hyperparams: A dictionary of hyperparameter values. 'regul' is
            always read; 'radius' and 'sparsity' are read for ESN
            reservoirs; 'v', 'w_c', 'w_j', and 'l' for CRJ reservoirs; 'T'
            for LMU reservoirs; and 'C' and 'svm_kernel' for 'RMM_' models.

    Returns:
        The freshly constructed esn.ESN or rmm.RMM instance.

    Raises:
        ValueError: If the model name ends with none of the known
            reservoir types.
    """
    # step 1: construct the reservoir matrices and pick the nonlinearity
    if model.endswith('ESN'):
        U, W = esn.initialize_reservoir(
            m, n,
            radius = hyperparams['radius'],
            sparsity = hyperparams['sparsity'])
        nonlin = np.tanh
    elif model.endswith('CRJ'):
        U = crj.setup_input_weight_matrix(n, m, v = hyperparams['v'])
        W = crj.setup_reservoir_matrix(
            m,
            w_c = hyperparams['w_c'],
            w_j = hyperparams['w_j'],
            l = hyperparams['l'])
        nonlin = np.tanh
    elif model.endswith('LMU'):
        # the polynomial degree is chosen from the neuron budget m and the
        # number of input dimensions n
        degree = int(m/n)-1
        U, W = lmu.initialize_reservoir(n, degree, hyperparams['T'])
        # the LMU reservoir is used with the identity as nonlinearity
        nonlin = lambda x : x
    else:
        raise ValueError('Unknown model: %s' % model)

    # step 2: wrap the reservoir into the requested network type
    shared_args = {
        'regul' : hyperparams['regul'],
        'input_normalization' : False,
        'nonlin' : nonlin
    }
    if model.startswith('RMM_'):
        return rmm.RMM(U, W, C = hyperparams['C'],
                       svm_kernel = hyperparams['svm_kernel'], **shared_args)
    return esn.ESN(U, W, **shared_args)

## Experiment

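After the hyperparameter setup above, we can now iterate over all tasks. For each task, we first
perform hyperparameter optimization (or load previously selected hyperparameters from the task's `*_hyperparams.json` file, if it exists) and then run the actual experiment.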

In [3]:
import json
import os
import random
import time
from dataset_generators import generate_data
from dataset_generators import _permutation_sampling

def _fit_and_score(model, net, Xs, Qs, Ys, Xs_test, Ys_test, num_test):
    """Fits net to the training data and returns its RMSE on the test data.

    Args:
        model: The model name; names starting with 'RMM_' are fitted with
            the additional Qs argument.
        net: The untrained network as returned by setup_model.
        Xs, Qs, Ys: The training data as returned by generate_data.
        Xs_test, Ys_test: The test data as returned by generate_data.
        num_test: The number of test time series to evaluate.

    Returns:
        The square root of the mean squared error, averaged over the first
        num_test test time series, as a plain Python float (so that it can
        be written to JSON irrespective of the numpy scalar type).
    """
    if model.startswith('RMM_'):
        net.fit(Xs, Qs, Ys)
    else:
        net.fit(Xs, Ys)
    # accumulate the mean squared error of each test time series
    mse = 0.
    for i in range(num_test):
        Ypred = net.predict(Xs_test[i])
        mse  += np.mean((Ypred - Ys_test[i]) ** 2)
    return float(np.sqrt(mse / num_test))


def _select_best_hyperparams(model, candidates):
    """Returns the candidate with the lowest mean over its 'errors' list.

    Ties are resolved in favor of the earlier candidate.

    Args:
        model: The model name (only used for the error message).
        candidates: A list of hyperparameter dictionaries, each with an
            'errors' list.

    Raises:
        ValueError: If no candidate has a finite mean error, e.g. because
            all error lists are empty or contain NaN/inf values.
    """
    best = None
    min_err = np.inf
    for params_r in candidates:
        # an empty error list has no mean and can never be selected
        if len(params_r['errors']) == 0:
            continue
        mean_err = np.mean(params_r['errors'])
        # note: this comparison is False for NaN and inf, so such
        # candidates are skipped
        if mean_err < min_err:
            min_err = mean_err
            best = params_r
    if best is None:
        raise ValueError('No hyperparameter combination with a finite mean error for model %s' % model)
    return best


# iterate over all tasks
for task_idx, task in enumerate(tasks):
    print('------ Task %d of %d: %s -----' % (task_idx+1, len(tasks), task))
    m = num_neurons[task_idx]
    n = ns[task_idx]
    # try to load the selected hyperparameters from file
    hyperparam_path = '%s_hyperparams.json' % task
    if os.path.isfile(hyperparam_path):
        print('loading hyperparameters from %s' % hyperparam_path)
        with open(hyperparam_path, 'r') as hyperparam_file:
            hyperparams = json.load(hyperparam_file)
    else:
        # perform a hyperoptimization where we test hyper_R random
        # hyperparameter settings for each model and perform
        # hyper_num_repeats repeats to obtain statistics. The hyperparameters
        # with the best mean performance across repeats will be selected
        print('performing hyperparameter optimization (this may take a while)')
        # generate random parameter combination for all models
        hyperparams = {}
        for model in models:
            # initialize a hyperparameter dictionary for each combination
            hyperparams[model] = [{} for _ in range(hyper_R)]
            # then iterate over each key and sample the parameter values
            for key in hyperparam_ranges[model]:
                param_range = hyperparam_ranges[model][key]
                param_value_indices = _permutation_sampling(hyper_R, 0, len(param_range)-1)
                for r in range(hyper_R):
                    hyperparams[model][r][key] = param_range[param_value_indices[r]]
            for params_r in hyperparams[model]:
                # set up an extra key for the errors
                params_r['errors'] = []
                # for the signal_copy dataset, use the 'pseudo' SVM, because everything
                # else takes too long to train
                if task == 'signal_copy' and model.startswith('RMM'):
                    params_r['svm_kernel'] = 'pseudo'
                # set time horizon for LMU models
                if model.endswith('LMU'):
                    params_r['T'] = Ts[task_idx]

        for repeat in range(hyper_num_repeats):
            print('--- repeat %d of %d ---' % (repeat+1, hyper_num_repeats))
            # sample training and test data
            Xs, Qs, Ys = generate_data(N, task)
            Xs_test, Qs_test, Ys_test = generate_data(N_test, task)
            # now iterate over all models
            for model in models:
                print('-- model: %s --' % model)
                # and iterate over all parameter combinations for this model
                for params_r in hyperparams[model]:
                    # set up a model instance, fit it, and measure the RMSE
                    # on the test data
                    net = setup_model(model, m, n, params_r)
                    rmse = _fit_and_score(model, net, Xs, Qs, Ys, Xs_test, Ys_test, N_test)
                    params_r['errors'].append(rmse)
                    print('error: %g' % rmse)
        # write the results to a JSON file
        with open(hyperparam_path, 'w') as hyperparam_file:
            json.dump(hyperparams, hyperparam_file)

    # select best hyperparameters for each model
    hyperparams_opt = {}
    for model in models:
        hyperparams_opt[model] = _select_best_hyperparams(model, hyperparams[model])
        print('\nSelected the following hyper-parameters for %s' % model)
        for key in hyperparams_opt[model]:
            print('%s: %s' % (key, str(hyperparams_opt[model][key])))
    # hyperparameter optimization complete

    # ACTUAL EXPERIMENT

    # initialize error and runtime arrays
    errors   = np.zeros((len(models), R))
    runtimes = np.zeros((len(models), R))
    # iterate over all experimental repeats
    for r in range(R):
        print('--- repeat %d of %d ---' % (r+1, R))
        # sample training and test data
        Xs, Qs, Ys = generate_data(N, task)
        Xs_test, Qs_test, Ys_test = generate_data(N_test, task)
        # now iterate over all models
        for model_idx, model in enumerate(models):
            # the measured runtime covers model setup, training, and
            # prediction on the test data
            start_time = time.time()
            # set up the model with the best selected hyperparameters
            net = setup_model(model, m, n, hyperparams_opt[model])
            rmse = _fit_and_score(model, net, Xs, Qs, Ys, Xs_test, Ys_test, N_test)
            runtimes[model_idx, r] = time.time() - start_time
            errors[model_idx, r] = rmse
    # print results
    for model_idx, model in enumerate(models):
        print('%s: %g +- %g (took %g seconds)' % (model, np.mean(errors[model_idx, :]), np.std(errors[model_idx, :]), np.mean(runtimes[model_idx, :])))
    # write results to file
    np.savetxt('%s_errors.csv' % task, errors.T, delimiter='\t', header='\t'.join(models), comments='')
    np.savetxt('%s_runtimes.csv' % task, runtimes.T, delimiter='\t', header='\t'.join(models), comments='')

------ Task 5 of 5: image_copy -----
performing hyperparameter optimization (this may take a while)
--- repeat 1 of 3 ---
-- model: ESN --
error: 76.8398
error: 77.2181
error: 73.2067
error: 76.2422
error: 71.3716
error: 77.2876
error: 67.1829
error: 75.7641
error: 77.2603
error: 74.4404
error: 63.0365
error: 77.249
error: 76.718
error: 77.2887
error: 71.3956
error: 77.1826
error: 76.0015
error: 73.9041
error: 77.2717
error: 76.2544
-- model: CRJ --
error: 164.184
error: 60.234
error: 161.464
error: 160.749
error: 62.1944
error: 60.2356
error: 61.5848
error: 60.2399
error: 239.804
error: 61.7713
error: 60.3019
error: 81.3662
error: 60.3117
error: 60.2354
error: 208.129
error: 60.234
error: 80.2717
error: 70.8796
error: 62.7334
error: 60.2356
-- model: LMU --
error: 57.1207
error: 57.1207
error: 57.1298
error: 57.1207
error: 57.1207
error: 57.1298
error: 57.1207
error: 57.1298
error: 57.1207
error: 57.1298
error: 57.1207
error: 57.1207
error: 57.1207
error: 57.1298
error: 57.1207
error:



State prediction recall: 1; precision: 0.974409
error: 62.0654
State prediction recall: 1; precision: 0.970588
error: 65.1099
State prediction recall: 1; precision: 0.970588
error: 77.1399
State prediction recall: 1; precision: 0.970588
error: 60.2974
State prediction recall: 1; precision: 0.970588
error: 128.92
State prediction recall: 1; precision: 0.970588
error: 60.2352
State prediction recall: 1; precision: 0.970588
error: 69.5167
State prediction recall: 1; precision: 0.970588
error: 108.747
State prediction recall: 1; precision: 0.970588
error: 60.234
State prediction recall: 1; precision: 0.970588
error: 62.1944
State prediction recall: 1; precision: 0.970588
error: 62.9589
State prediction recall: 1; precision: 0.976331
error: 60.2927
State prediction recall: 1; precision: 0.970588
error: 62.9895
State prediction recall: 1; precision: 0.970588
error: 60.2356
State prediction recall: 1; precision: 0.970588
error: 106.866
State prediction recall: 1; precision: 0.970588
error: 60



State prediction recall: 1; precision: 0.930451
error: 58.4506
State prediction recall: 1; precision: 0.926966
error: 1387.73




State prediction recall: 1; precision: 0.933962
error: 58.4518
State prediction recall: 1; precision: 0.926966
error: 58.8461
State prediction recall: 1; precision: 0.932203
error: 78.7082
State prediction recall: 1; precision: 0.932203
error: 58.4505
State prediction recall: 1; precision: 0.932203
error: 82.3162




State prediction recall: 1; precision: 0.930451
error: 58.4512
State prediction recall: 1; precision: 0.926966
error: 58.5065
State prediction recall: 1; precision: 0.928705
error: 96.3894




State prediction recall: 1; precision: 0.930451
error: 58.4505
State prediction recall: 1; precision: 0.932203
error: 58.5853
State prediction recall: 1; precision: 0.928705
error: 71.9246
State prediction recall: 1; precision: 0.933962
error: 58.4505
State prediction recall: 1; precision: 0.930451
error: 58.4698
State prediction recall: 1; precision: 0.930451
error: 58.4505
State prediction recall: 1; precision: 0.932203
error: 78.8331
State prediction recall: 1; precision: 0.932203
error: 58.478
State prediction recall: 1; precision: 0.932203
error: 58.4505
State prediction recall: 1; precision: 0.932203
error: 59.0213
-- model: RMM_LMU --
State prediction recall: 1; precision: 1
error: 37.4485
State prediction recall: 1; precision: 1
error: 37.4778
State prediction recall: 1; precision: 1
error: 37.4776
State prediction recall: 1; precision: 1
error: 37.4776
State prediction recall: 1; precision: 1
error: 37.4778
State prediction recall: 1; precision: 1
error: 37.4485
State predicti



State prediction recall: 1; precision: 0.932203
error: 57.04
State prediction recall: 1; precision: 0.928705
error: 56.9684




State prediction recall: 1; precision: 0.930451
error: 70.4252
State prediction recall: 1; precision: 0.932203
error: 56.4213
State prediction recall: 1; precision: 0.930451
error: 260.977
State prediction recall: 1; precision: 0.930451
error: 56.4194
State prediction recall: 1; precision: 0.928705
error: 57.2953
State prediction recall: 1; precision: 0.930451
error: 63.1154
State prediction recall: 1; precision: 0.930451
error: 56.4189
State prediction recall: 1; precision: 0.932203
error: 63.9867
State prediction recall: 1; precision: 0.930451
error: 58.186
State prediction recall: 1; precision: 0.933962
error: 56.4212
State prediction recall: 1; precision: 0.930451
error: 58.7967
State prediction recall: 1; precision: 0.930451
error: 56.4287
State prediction recall: 1; precision: 0.930451
error: 72.3193
State prediction recall: 1; precision: 0.930451
error: 56.4654
State prediction recall: 1; precision: 0.932203
error: 56.4204
State prediction recall: 1; precision: 0.930451
error: 5



State prediction recall: 1; precision: 0.951923
State prediction recall: 1; precision: 1
--- repeat 12 of 20 ---
State prediction recall: 1; precision: 1




State prediction recall: 1; precision: 0.978261
State prediction recall: 1; precision: 1
--- repeat 13 of 20 ---
State prediction recall: 1; precision: 1




State prediction recall: 1; precision: 0.966797
State prediction recall: 1; precision: 1
--- repeat 14 of 20 ---
State prediction recall: 1; precision: 1




State prediction recall: 1; precision: 0.939279
State prediction recall: 1; precision: 1
--- repeat 15 of 20 ---
State prediction recall: 1; precision: 1
State prediction recall: 1; precision: 0.968689
State prediction recall: 1; precision: 1
--- repeat 16 of 20 ---
State prediction recall: 1; precision: 1
State prediction recall: 1; precision: 0.964912
State prediction recall: 1; precision: 1
--- repeat 17 of 20 ---
State prediction recall: 1; precision: 1
State prediction recall: 1; precision: 0.970588
State prediction recall: 1; precision: 1
--- repeat 18 of 20 ---
State prediction recall: 1; precision: 1
State prediction recall: 1; precision: 0.953757
State prediction recall: 1; precision: 1
--- repeat 19 of 20 ---
State prediction recall: 1; precision: 1




State prediction recall: 1; precision: 0.970588
State prediction recall: 1; precision: 1
--- repeat 20 of 20 ---
State prediction recall: 1; precision: 1
State prediction recall: 1; precision: 0.961165
State prediction recall: 1; precision: 1
ESN: 70.4444 +- 3.92433 (took 0.621732 seconds)
CRJ: 59.2492 +- 1.67078 (took 0.310672 seconds)
LMU: 56.1459 +- 1.67954 (took 0.283554 seconds)
RMM_ESN: 71.0736 +- 2.4786 (took 8.7071 seconds)
RMM_CRJ: 59.3849 +- 1.7262 (took 29.5849 seconds)
RMM_LMU: 40.5545 +- 2.89842 (took 4.0279 seconds)
