# Exercise 12 - Applying ESNs to the Lorenz System

- In the lecture, we trained an ESN to reproduce the dynamics of the Lorenz system, correctly predicting the Lorenz trajectory up to ~5 Lyapunov times. After this time, the predicted trajectory deviates from the actual Lorenz trajectory, which is expected since small errors grow exponentially due to the chaotic dynamics of the Lorenz system. 
<br>

- Do the long-term dynamics of the ESN still constitute a valid Lorenz trajectory? Investigate this question by:
    1. Plotting the attractor of the ESN for a long integration and comparing it qualitatively with the true Lorenz attractor.
    2. Computing the maximum Lyapunov exponent of the ESN dynamics and comparing it to the maximum Lyapunov exponent of the Lorenz system.<br><br>    

- Bonus: In the lecture, we performed a grid search over the ESN hyperparameters and kept the model with the lowest error on the validation set. However, because the reservoir is generated randomly, the performance of an ESN varies from one realisation to the next even for a fixed set of hyperparameters, so the hyperparameters selected by cross-validation may not be deterministic. Extend the cross-validation procedure to generate and fit multiple ESNs for each combination of hyperparameters, keeping only the best one. Does this improve the prediction horizon of the ESN?
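
As a reference point for the Lyapunov exponent comparison, the following is a minimal sketch of the two-trajectory (Benettin-style) estimate of the maximum Lyapunov exponent, applied to the Lorenz system itself. It is not the method from the lecture: the helper names `lorenz`, `rk4_step`, and `max_lyapunov`, the step size, and the perturbation size `d0` are all assumptions made for illustration. The estimator only needs a function that advances the state by one time step, so the same idea should carry over to an ESN running in generative mode, provided a one-step update of its state is available.

```julia
using LinearAlgebra

# Right-hand side of the Lorenz system with the standard parameters (assumed here).
lorenz(u; σ = 10.0, ρ = 28.0, β = 8 / 3) =
    [σ * (u[2] - u[1]), u[1] * (ρ - u[3]) - u[2], u[1] * u[2] - β * u[3]]

# One classical fourth-order Runge-Kutta step of size dt.
function rk4_step(f, u, dt)
    k1 = f(u)
    k2 = f(u + dt / 2 * k1)
    k3 = f(u + dt / 2 * k2)
    k4 = f(u + dt * k3)
    return u + dt / 6 * (k1 + 2k2 + 2k3 + k4)
end

# Estimate the maximum Lyapunov exponent of the map `step`, where one call to
# `step` advances the state by `dt` time units. A reference and a perturbed
# trajectory are evolved side by side; after every step the logarithmic growth
# of their separation is accumulated and the separation is rescaled back to d0.
function max_lyapunov(step, u0; dt, n_steps = 100_000, n_transient = 1_000, d0 = 1e-8)
    u = copy(u0)
    for _ in 1:n_transient      # discard the transient so that u lies on the attractor
        u = step(u)
    end

    v = u + d0 * normalize(ones(length(u)))   # perturbed copy at distance d0
    log_growth = 0.0
    for _ in 1:n_steps
        u = step(u)
        v = step(v)
        d = norm(v - u)
        log_growth += log(d / d0)
        v = u + (d0 / d) * (v - u)            # renormalise the separation to d0
    end
    return log_growth / (n_steps * dt)
end

dt = 0.01
λ_lorenz = max_lyapunov(u -> rk4_step(lorenz, u, dt), [1.0, 1.0, 1.0]; dt = dt)
println("Estimated λ_max of the Lorenz system: ", round(λ_lorenz, digits = 2))
```

The estimate converges slowly and fluctuates with the integration length, so when comparing the ESN against the Lorenz system it is worth using the same `dt`, `n_steps`, and `d0` for both.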

In [None]:
using Pkg
Pkg.activate(; temp = true)
Pkg.add(["DynamicalSystems", "ReservoirComputing", "Plots", "Printf", "MKL"])

In [None]:
using DynamicalSystems, ReservoirComputing, Plots, Printf, MKL

In [None]:
# Helper functions from the lecture

"""
    train_val_test_split(data; val_seconds, test_seconds, Δt = 0.1)

Split the given data into training, validation, and test sets.
"""
function train_val_test_split(data; val_seconds, test_seconds, Δt = 0.1)
    N = size(data, 2)
    N_val = round(Int, val_seconds / Δt)
    N_test = round(Int, test_seconds / Δt)
    
    ind1 = N - N_test - N_val
    ind2 = N - N_test
    
    train_data = data[:, 1:ind1]
    val_data = data[:, ind1+1:ind2]
    test_data = data[:, ind2+1:end]
    
    return train_data, val_data, test_data
end


"""
    generate_esn(input_signal, reservoir_size = 1000, spectral_radius = 1.0, sparsity = 0.1, input_scale = 0.1)

Generate an Echo State Network consisting of the reservoir weights W and the input weights Wᵢₙ.
"""
function generate_esn(input_signal, reservoir_size = 1000, spectral_radius = 1.0, sparsity = 0.1, input_scale = 0.1)
    W = RandSparseReservoir(reservoir_size, radius = spectral_radius, sparsity = sparsity)
    Wᵢₙ = WeightedLayer(scaling = input_scale)
    return ESN(input_signal, reservoir = W, input_layer = Wᵢₙ)
end


"""
    train_esn!(esn, y_target, ridge_param)

Given an Echo State Network, train it on the target sequence y_target and return the optimised output weights Wₒᵤₜ.
"""
function train_esn!(esn, y_target, ridge_param)
    training_method = StandardRidge(ridge_param)
    return train(esn, y_target, training_method)
end


"""
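    cross_validate_esn(train_data, val_data, param_grid; iters = 1)

Do a grid search on the given param_grid to find the optimal hyperparameters. For each combination of hyperparameters, `iters` ESNs are generated and trained, and the one with the lowest validation loss is kept.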
"""
function cross_validate_esn(train_data, val_data, param_grid; iters = 1)
    best_loss = Inf
    best_params = nothing
    best_esn = nothing

    # We want to predict one step ahead, so the input signal is equal to the target signal from the previous step
    u_train = train_data[:, 1:end-1]
    y_train = train_data[:, 2:end]
        
    for hyperparams in param_grid
        # Unpack the hyperparams struct
        (;reservoir_size, spectral_radius, sparsity, input_scale, ridge_param) = hyperparams

        for _ in 1:iters
            # Generate and train an ESN
            esn = generate_esn(u_train, reservoir_size, spectral_radius, sparsity, input_scale)
            Wₒᵤₜ = train_esn!(esn, y_train, ridge_param)

            # Evaluate the loss on the validation set
            steps_to_predict = size(val_data, 2)
            prediction = esn(Generative(steps_to_predict), Wₒᵤₜ)
            loss = sum(abs2, prediction - val_data)

            # Keep track of the best hyperparameter values
            if loss < best_loss
                best_loss = loss
                best_params = hyperparams
                best_esn = esn
                println(hyperparams)
                @printf "Validation loss = %.1e\n" best_loss
            end
        end
    end
    
    # Retrain the model using the optimal hyperparameters on both the training and validation data
    # Otherwise the ESN would first have to predict through the validation window, and the errors accumulated there would carry over into the test error
    (;reservoir_size, spectral_radius, sparsity, input_scale, ridge_param) = best_params
    data = hcat(train_data, val_data)
    u = data[:, 1:end-1]
    y = data[:, 2:end]
    esn = ESN(u, reservoir = best_esn.reservoir_matrix, input_layer = best_esn.input_matrix)
    Wₒᵤₜ = train_esn!(esn, y, ridge_param)
    
    return esn, Wₒᵤₜ
end


"""
    plot_prediction(esn, Wₒᵤₜ, test_data, λ_max)

Given an Echo State Network, plot its predictions versus the given test set.
"""
function plot_prediction(esn, Wₒᵤₜ, test_data, λ_max)
    steps_to_predict = size(test_data, 2)
    prediction = esn(Generative(steps_to_predict), Wₒᵤₜ)
    
    label = ["actual" "predicted"]
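    # Time axis in units of the Lyapunov time 1 / λ_max, matching the x-axis label below
    # (Δt is not an argument, so it must be defined in the enclosing scope)
    times = Δt * (0:steps_to_predict-1) * λ_max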

    p1 = plot(times, [test_data[1, :], prediction[1, :]], label = label, ylabel = "x(t)")
    p2 = plot(times, [test_data[2, :], prediction[2, :]], label = label, ylabel = "y(t)")
    p3 = plot(times, [test_data[3, :], prediction[3, :]], label = label, ylabel = "z(t)", xlabel = "t * λ_max")
    plot(p1, p2, p3, layout = (3, 1), size = (800, 600))
end


"""
Hyperparameters for an Echo State Network.
"""
struct ESNHyperparams
    reservoir_size
    spectral_radius
    sparsity
    input_scale
    ridge_param
end
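
The grid search above only requires `param_grid` to be an iterable whose elements have the five fields `reservoir_size`, `spectral_radius`, `sparsity`, `input_scale`, and `ridge_param`. The following is a minimal sketch of building such a grid as a Cartesian product of candidate values. The candidate values are illustrative assumptions, not the ones used in the lecture, and NamedTuples are used here only so that the snippet runs on its own; `ESNHyperparams(n, r, s, a, b)` would work equally well.

```julia
# Candidate values for each hyperparameter (illustrative only).
reservoir_sizes = [512, 1024]
spectral_radii = [0.8, 1.0, 1.2]
sparsities = [0.03, 0.05]
input_scales = [0.1]
ridge_params = [1e-6, 1e-5]

# All combinations of the candidate values.
combos = Iterators.product(reservoir_sizes, spectral_radii, sparsities, input_scales, ridge_params)

# One NamedTuple per combination, flattened into a vector. The field names match
# those unpacked by `(; reservoir_size, ...) = hyperparams` in the grid search.
param_grid = vec([
    (; reservoir_size = n, spectral_radius = r, sparsity = s, input_scale = a, ridge_param = b)
    for (n, r, s, a, b) in combos
])

println("Number of hyperparameter combinations: ", length(param_grid))

# With the data split in hand, several ESNs per combination can then be drawn via `iters`:
# esn, Wₒᵤₜ = cross_validate_esn(train_data, val_data, param_grid; iters = 5)
```

Note that the total number of ESNs trained is `iters * length(param_grid)`, so the cost of the search grows linearly with `iters`.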