In [None]:
using Distributions
using LinearAlgebra
using Plots
using Test
using Sobol
using Optim
using Profile
using PProf
using Random

In [None]:
include("../rollout.jl")
include("../testfns.jl")

### Pseudo-code for Rollout Bayesian Optimization
1. Generate low-discrepancy sequence for Quasi-Monte Carlo
2. Gather initial samples/experimental data
3. Construct the ground truth surrogate model
4. Setup hyperparameters for stochastic gradient descent
5. While budget has not been exhausted
<ol>
    <li>
        Construct a batch of samples for stochastic gradient descent. For each sample
        <ol>
            <li>Create a copy of the ground truth surrogate at the sample location and the pairwise perturbed surrogate.</li>
            <li style="color: #f66">Initialize our trajectory struct with the fantasized surrogate, the fantasized perturbed surrogate, and the fantasy start location.</li>
            <li>Perform rollout on the trajectory for $r$ steps $M_0$ times for Quasi-Monte Carlo integration.</li>
            <li>Update values for $\alpha$ and $\nabla\alpha$</li>
        </ol>
    </li>
    <li>Once SGD has converged, update sample location using update rule</li>
    <li>Save location and value at location for each sample in batch.</li>
    <li>Select the best sample from the batch and sample original process at new sample location.</li>
    <li>Update surrogate model with value found at new sample location.</li>
    <li>Repeat until budget is exhausted.</li>
</ol>

### Issues
- Use control variates to see how they affect the rollout acquisition functions

#### Probability of Improvement
The probability of improvement (POI) is defined as follows:

$$
POI(x) = \Phi\left( \frac{\mu(x) - f^+ - \xi}{\sigma(x)} \right)
$$

where $f^+$ denotes the best value (maximum) known, $\mu(x)$ is the predictive mean at $x$, $\sigma(x)$ is the predictive standard deviation, $\xi$ is our exploration parameter, and $\Phi$ is the standard normal cumulative distribution function.
<hr>

#### Expected Improvement
The expected improvement (EI) is defined as follows:

$$
EI(x) = (\mu(x) - f^+ - \xi)\Phi\left( \frac{\mu(x) - f^+ - \xi}{\sigma(x)} \right) +
        \sigma(x)\phi\left(\frac{\mu(x) - f^+ - \xi}{\sigma(x)}\right)
$$

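where $\phi$ is the standard normal probability density function. Note that the formulas above are written for maximization, whereas the `ei` and `poi` helpers below use the minimization form $z = (f^+ - \mu(x)) / \sigma(x)$, with $f^+$ passed as `fbest` and no exploration parameter ($\xi = 0$).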
<hr>

In [None]:
"""
    ei(μ, σ, fbest)

Return the expected improvement for predictive mean `μ`, predictive standard
deviation `σ` and incumbent value `fbest`.

The standardized gap is `z = (fbest - μ) / σ`, i.e. improvement means falling
*below* `fbest` (minimization form), and the result is `σ * (z * Φ(z) + ϕ(z))`.

When `σ` is not positive the predictive distribution is treated as a point mass
at `μ`, so the function returns `max(fbest - μ, 0)` instead of dividing by zero
(which would yield `NaN`).
"""
function ei(μ, σ, fbest)
    gap = fbest - μ

    # Degenerate case: no predictive uncertainty, so the improvement is known exactly.
    if σ <= zero(σ)
        return max(gap, zero(gap))
    end

    z = gap / σ
    Φz = Distributions.normcdf(z)
    ϕz = Distributions.normpdf(z)
    return σ * (z * Φz + ϕz)
end

"""
    poi(μ, σ, fbest)

Return the probability of improvement for predictive mean `μ`, predictive
standard deviation `σ` and incumbent value `fbest`.

The result is `Φ((fbest - μ) / σ)`, i.e. the probability of falling *below*
`fbest` (minimization form).

When `σ` is not positive the predictive distribution is treated as a point mass
at `μ`: the result is one if `μ` is strictly below `fbest` and zero otherwise.
This avoids the `NaN` produced by `0 / 0` when `μ == fbest`.
"""
function poi(μ, σ, fbest)
    gap = fbest - μ

    # Degenerate case: the outcome is deterministic, so the probability is a step.
    if σ <= zero(σ)
        return gap > zero(gap) ? one(float(gap)) : zero(float(gap))
    end

    return Distributions.normcdf(gap / σ)
end

In [None]:
# Experiment-wide settings shared by the cells below.

# Outer loop: number of Bayesian-optimization iterations.
BUDGET = 10
# Number of candidate start locations requested per iteration.
BATCH_SIZE = 10
# Upper bound on gradient epochs run from each start location.
MAX_SGD_ITERS = 1_500
# Number of Monte Carlo samples passed to the trajectory simulation.
MC_SAMPLES = 150
# Rollout horizon; the random-number streams are sized with HORIZON + 1.
HORIZON = 0;

### 1. Generate low-discrepancy sequence for Quasi-Monte Carlo

In [None]:
# Objective under study. A noisy constant-zero toy function is kept here as an
# alternative:
# testfn = TestFunction(
#     1, [0. 1.], [.5],
#     x -> 0. + 1e-6*randn(),
#     ∇x -> [0. + 1e-6*randn()]
# )
testfn = TestGramacyLee()

# Lower and upper bounds are the first and second columns of the bounds matrix.
lbs = testfn.bounds[:, 1]
ubs = testfn.bounds[:, 2]

# Random-number streams: a low-discrepancy one and a plain Gaussian one.
lds_rns = gen_low_discrepancy_sequence(MC_SAMPLES, testfn.dim, HORIZON + 1);
rns = randn(MC_SAMPLES, testfn.dim + 1, HORIZON + 1);

In [None]:
# Visualize the objective function.
testfn |> tplot

### 2. Gather initial samples/experimental data

In [None]:
# Initial experimental design and kernel.
N = 1
θ = [0.25]

# Two starting sample locations stored as a 1×2 matrix.
# Alternative design: X = [.15, .85]
X = reshape([2.0, 2.5], 1, :)

# Scaled Matérn 5/2 kernel; the unscaled variant would be kernel_matern52(θ).
# NOTE(review): the parameter vector is presumably [scale, length-scale], matching
# the [σ, ℓ] naming used when the kernel is rebuilt later — confirm.
ψ = kernel_scale(kernel_matern52, vcat(1.0, θ));

### 3. Construct the ground truth surrogate model

In [None]:
# Build the initial surrogate from the kernel, the design matrix and the objective.
sur = fit_surrogate(
    ψ,
    X,
    testfn.f,
);
# Optional hyperparameter fit, currently disabled:
# θ, sur = optimize_hypers(ψ.θ, kernel_matern52, sur.X, testfn.f);

In [None]:
# Evaluation grid along the first input dimension, excluding grid points that are
# already elements of X.
domain = filter(!in(X), lbs[1]:0.01:ubs[1])

# Observed samples.
scatter(sur.X', sur.y)

# Surrogate mean with a band of twice the surrogate's σ on either side.
plot!(
    domain,
    map(x -> sur([x]).μ, domain);
    ribbons=2 .* map(x -> sur([x]).σ, domain),
    label="μ ± 2σ",
)

# Expected improvement reported by the surrogate over the same grid.
plot!(domain, map(x -> sur([x]).EI, domain); label="EI")

### 4. Setup hyperparameters for stochastic gradient descent

In [None]:
# Optimizer settings
λ = 0.01     # learning rate
β1 = 0.9     # decay rate of the first-moment estimate
β2 = 0.999   # decay rate of the second-moment estimate
ϵ = 1e-8     # epsilon value

# Moment estimates, one entry per input dimension, initialised to zero
m, v = zeros(testfn.dim), zeros(testfn.dim)

# Stopping thresholds: change in objective value and gradient norm
ϵsgd = 1e-12
grad_tol = 1e-5

### 5. While budget has not been exhausted
Each location in our minibatch is going to be our $x^0$ that serves as our deterministic start location. Then, we perform rollout from that point forward, computing several sample trajectories to then be averaged.

We need a few mechanisms:
* We shouldn't sample at locations that are near known locations in sur.X
* We should perform the evaluations in parallel to save time

In [None]:
# Bookkeeping that the loop below overwrites on every pass.
∇αxs = []
batch = []

final_locations = []

# Outer Bayesian-optimization loop: each pass picks one new sample location from a
# batch of locally optimized candidates, evaluates the objective there and refits
# the surrogate.
for b in 1:BUDGET
    # Generate a batch of candidate start locations and drop every candidate that
    # is already an element of sur.X.
    # NOTE(review): this is an exact-membership test, so candidates that are merely
    # close to a known sample location are kept — confirm that is intended.
    batch = generate_batch(BATCH_SIZE; lbs=lbs, ubs=ubs)
    batch = convert(Matrix{Float64}, filter(x -> !(x in sur.X), batch)')

    batch_evals = []
    final_locations = []

    # This should be a parallel for loop
    println("---------- BO Iteration #$b ----------")
    for (bndx, x0) in enumerate(eachcol(batch))
        try
            x0 = convert(Vector{Float64}, x0)

            # Per-start-location history of the acquisition value and of the first
            # gradient component.
            αxs = []
            ∇αxs = Float64[]

            print("\n(Batch #$bndx - $x0) Gradient Ascent Iteration Count: ")
            # Run SGD until convergence
            fprev, fnow = 0., 1.
            for epoch in 1:MAX_SGD_ITERS
                # Progress marker every 25 epochs
                if mod(epoch, 25) == 0 print("|") end
                μx, ∇μx = simulate_trajectory(
                    sur; mc_iters=MC_SAMPLES, rnstream=lds_rns, lbs=lbs, ubs=ubs, x0=x0
                )

                # Update value and gradient histories
                push!(αxs, μx)
                push!(∇αxs, first(∇μx))

                fprev = fnow
                fnow = μx

                # Update x0 based on gradient computation
                # x0, m, v = update_x_adam(x0; ∇g=-∇μx, λ=λ, β1=β1, β2=β2, ϵ=ϵ, m=m, v=v, lbs=lbs, ubs=ubs)
                x0 = update_x(x0; λ=λ, ∇g=∇μx, lbs=lbs, ubs=ubs)

                # Stop when the value stalls or the gradient is (numerically) zero.
                # NOTE(review): x0 has already been stepped at this point, so the
                # value stored below belongs to the previous iterate — confirm that
                # the one-step mismatch is acceptable.
                if abs(fnow - fprev) < ϵsgd || norm(∇μx) < grad_tol
                    println("\nConverged after $epoch epochs")
                    println("abs(fnow - fprev): $(abs(fnow - fprev)) - fnow: $fnow - fprev: $fprev")
                    break
                end

            end

            push!(batch_evals, αxs[end])
            push!(final_locations, x0)
        catch e
            # Best effort per candidate: report the failure and move on, but never
            # swallow a user interrupt, otherwise the cell cannot be stopped.
            e isa InterruptException && rethrow()
            println(e)
        end
    end
    # Iterate over batch for best response and sample original process afterwards
    if length(batch_evals) > 0
        println()
        for (location, value) in zip(final_locations, batch_evals)
            println("α($location) = $value")
        end
        ndx = argmax(batch_evals)
        xnew = final_locations[ndx]

        # Sample original process at xnew
        println("\nFinal xnew: $xnew")
        println("--------------------------------------\n")
        # NOTE(review): hyperparameters are optimized on the surrogate *before* the
        # new observation is appended — confirm this ordering is intended.
        res = optimize_hypers_optim(sur, kernel_matern52)
        σ, ℓ = Optim.minimizer(res)
        ψ = kernel_scale(kernel_matern52, [σ, ℓ]);
        sur = fit_surrogate(
            ψ,
            hcat(sur.X, xnew),
            vcat(sur.y, testfn.f(xnew))
        )
    end
end

In [None]:
# Evaluation grid along the first input dimension, excluding grid points that are
# already elements of sur.X.
domain = filter(!in(sur.X), lbs[1]:0.01:ubs[1])

# Observed samples.
scatter(sur.X', sur.y)

# Surrogate mean with a band of twice the surrogate's σ on either side.
plot!(
    domain,
    map(x -> sur([x]).μ, domain);
    ribbons=2 .* map(x -> sur([x]).σ, domain),
    label="μ ± 2σ",
)

# Expected improvement reported by the surrogate over the same grid.
plot!(domain, map(x -> sur([x]).EI, domain); label="EI")

In [None]:
# Display the surrogate's sample locations, transposed.
adjoint(sur.X)

In [None]:
# Scratch arithmetic: gap between the initial sample at 2.5 and 2.375
# (presumably a location chosen by the loop — verify).
2.5 - 2.375