In [None]:
import Pkg
Pkg.activate("..")

In [None]:
# Print the package list of the environment activated above, as a record of what this run used.
Pkg.status()

In [None]:
include("burgers_common.jl")

In [None]:
import DiffResults
import Random

In [None]:
"""
    svd_mem(nsv, ngrid)

Return `nsv * (m + n + 1)`, where `(m, n) = svd_dimensions(ngrid)`.

The callers multiply the result by an element size, so it is a count of stored
values (presumably `nsv` columns of length `m`, `nsv` rows of length `n` and
`nsv` singular values).
"""
function svd_mem(nsv, ngrid)
    rows, cols = svd_dimensions(ngrid)
    values_per_rank = rows + cols + 1
    return nsv * values_per_rank
end

In [None]:
# Compare three byte counts for a rank-1 SVDVector over the grid sizes 2^6 .. 2^12:
#   mem  - the full Float64 state vector,
#   smem - the analytic estimate from svd_mem,
#   imem - what Base.summarysize reports for the constructed object.
nsv = 1
for k in 6:12

    Nx = 2^k

    smem = svd_mem(nsv, Nx) * sizeof(Float64)
    mem = sizeof(Float64) * Nx

    (m,n) = svd_dimensions(Nx)
    # Unseeded draw: the values differ between runs; only sizes are inspected here.
    x = randn(Nx)
    xsvd = IAD.SVDVector(x, m, n, nsv)
    imem = Base.summarysize(xsvd)

    @show (Nx, mem, smem, imem)
    # Bytes by which the measured object exceeds the analytic estimate.
    @show imem - smem
    # sizeof counts only the struct's own fields, not the arrays they reference.
    @show sizeof(xsvd)

end

In [None]:
# Repeat the size check on the 2^12 grid with Float32 input and report each stored
# field separately. `nsv` is whatever an earlier cell assigned (1 in the cell above).
Nx = 2^12
(m,n) = svd_dimensions(Nx)
# Disabled experiment: 32-bit integer dimensions and rank.
# m = convert(Int32, m)
# n = convert(Int32, n)
# nsv = convert(Int32, nsv)
x = randn(Float32, Nx)
xsvd = IAD.SVDVector(x, m, n, nsv)
# For each field: its array shape, and the bytes occupied by its element data.
@show size(xsvd.U), sizeof(xsvd.U)
@show size(xsvd.Vt), sizeof(xsvd.Vt)
@show size(xsvd.singular_values), sizeof(xsvd.singular_values)
# A bare `;` keeps the notebook from echoing the last value.
;

In [None]:
# Base.summarysize per field; comparing with the sizeof values printed in the
# previous cell exposes the per-array overhead.
@show Base.summarysize(xsvd.U)
@show Base.summarysize(xsvd.Vt)
@show Base.summarysize(xsvd.singular_values)

In [None]:
# Estimated fixed overhead, in bytes, of one SVDVector with 64-bit fields:
#   Vector overhead (40) + 2 Matrix overheads (2*48)
#   + 6 struct fields (3 Ints, 3 pointers) of 8 bytes each (48).
40 + 2*48 + 48

In [None]:
# Same estimate if the 6 struct fields (3 Ints, 3 pointers) were 32-bit:
#   Vector overhead (40) + 2 Matrix overheads (2*48) + 6 fields of 4 bytes each (24).
# NOTE(review): this treats the pointer fields as 32-bit too; confirm that is intended.
40 + 2*48 + 24

In [None]:
# Fixed per-object overhead in bytes, subtracted from the per-step savings below;
# 184 = 40 + 2*48 + 48, the 64-bit estimate evaluated in an earlier cell.
const OVERHEAD = 184
# const OVERHEAD32 = 160
# Passed as the last argument of IAD.SVDVector in the grid-size loop.
const SVD_TOL = 1e-5
# Upper bound on the rank assumed by the memory estimate (nsv = min(m, n, nsv_max)).
nsv_max = 5

# Values stored in my_params under :tf and :cfl.
tf = 1.0
cfl = 0.85
# Seed for the Random.MersenneTwister created in every loop iteration.
seed = 1

# The random initial guess is (umax - umin) .* randn(...) .+ 0.5 * (umax + umin):
# normal draws scaled by the difference and centred on the midpoint of the two values.
umax = 0.55
umin = 0.45

"""
    weierstrass(x, a=0.825, b=7, N=2)

Evaluate a truncated Weierstrass-type cosine series at `x`, rescaled into the
range `0.1` to `0.9`.

The raw partial sum `a^n * cos(b^n * pi * x)` for `n = 1, ..., N` is bounded in
magnitude by `wbd`, the sum of `a^n` over the same `n` (for `a > 0`). The result
maps `-wbd` to `0.1` and `+wbd` to `0.9` affinely.

# Arguments
- `x`: evaluation point.
- `a`: amplitude ratio between successive terms.
- `b`: frequency ratio between successive terms.
- `N`: number of terms kept.

# Returns
The rescaled series value. For `N < 1` no term is summed, `wbd` is zero and the
result is `NaN`.
"""
function weierstrass(x, a=0.825, b=7, N=2)
    vmax = 0.9
    vmin = 0.1

    val = 0.0
    wbd = 0.0
    for n in 1:N
        val += a^n * cos(b^n * pi * x)
        wbd += a^n
    end

    # Affine map from [-wbd, wbd] onto [vmin, vmax].
    val = 0.5 * (val + wbd) * (vmax - vmin) / wbd + vmin

    return val
end

"""
    periodic_weierstrass(x, a=0.825, b=7, N=2)

Return `weierstrass(x, a, b, N)` minus the linear ramp
`(vmin - vmax) * x + vmax - vmid`, with `vmax = 0.9`, `vmin = 0.1` and `vmid`
their midpoint.
"""
function periodic_weierstrass(x, a=0.825, b=7, N=2)
    vmax = 0.9
    vmin = 0.1
    vmid = 0.5 * (vmax + vmin)
    slope = vmin - vmax
    ramp = slope * x + vmax - vmid
    return weierstrass(x, a, b, N) - ramp
end

# Parameter set shared by target_condition and burger_setup in the loop below.
# :Nx is left out here and set per grid size; :target selects the weierstrass
# profile defined above (periodic_weierstrass is the disabled alternative).
my_params = Dict(
    # :Nx => Nx,
    :cfl => cfl,
    :tf => tf,
    :flux => :lf,
    :scale => 1e2,
    :ic => grid_control,
    # :target => periodic_weierstrass,
    :target => weierstrass
)

# Grid sizes are 2^k for k in grid_sizes; one slot per grid size for the
# estimated savings with Float64 and with Float32 SVD storage.
grid_sizes = 4:20
mem_save_bytes = zeros(Int, length(grid_sizes))
mem_save_bytes32 = zeros(Int, length(grid_sizes))

# Plot the result of target_condition on the finest grid as the unlabeled
# reference curve; the loop below adds one labeled curve per grid size.
my_params[:Nx] = 2^maximum(grid_sizes)
x0 = ones(my_params[:Nx])
x0 = target_condition(x0, my_params)
p = plot(BurgersEquation.space_grid(my_params[:Nx]),
    BurgersEquation.expand_solution(x0),
    label=nothing,
    legend=:outerright,
)

# For every grid size Nx = 2^k: build the target state, compress it with
# IAD.SVDVector, add its reconstruction to the current plot, and record the
# estimated memory saved over bp.Nt time steps when a rank-`nsv` SVD is stored
# per step instead of the full Float64 state.
for (idx, k) in enumerate(grid_sizes)

    Nx = 2^k
    # Matrix shape for this grid size; computed once and reused for the rank
    # cap and for the SVDVector constructor.
    (m,n) = svd_dimensions(Nx)
    # Rank assumed by the memory estimate: at most nsv_max and never more than
    # either matrix dimension.
    nsv = min(m, n, nsv_max)

    # A fresh generator from the same seed on every iteration, so the draw for
    # a given grid size does not depend on the other iterations.
    rng = Random.MersenneTwister(seed)
    x0 = (umax - umin) .* randn(rng, Nx) .+ 0.5 * (umax + umin)

    # Update the shared parameter set with the current grid size before it is
    # passed to target_condition and burger_setup.
    my_params[:Nx] = Nx

    x0 = target_condition(x0, my_params)
    # SVD_TOL is passed where the earlier cells pass a rank.
    # NOTE(review): presumably this selects tolerance-based truncation, with
    # the kept rank reported as xsvd.nsv -- confirm against IAD.
    xsvd = IAD.SVDVector(x0, m, n, SVD_TOL)

    plot!(BurgersEquation.space_grid(Nx),
        BurgersEquation.expand_solution(xsvd);
        label=Nx)

    # Assumed rank (nsv) printed next to the rank stored in the object.
    @show k, Nx, nsv, xsvd.nsv

    # nsv = min(nsv, xsvd.nsv)

    # Only the step count bp.Nt is read from this setup.
    bp = burger_setup(x0, my_params; save=false)

    # Float64 storage: Nx values per step versus svd_mem(nsv, Nx) values plus
    # the fixed per-object overhead.
    savings_per_step = (Nx - svd_mem(nsv, Nx)) * sizeof(Float64) - OVERHEAD
    total_savings = savings_per_step * bp.Nt

    # Non-positive savings are recorded as 1 byte; the plots further down use a
    # log10 y-axis, which presumably is why zero and negatives are avoided.
    mem_save_bytes[idx] = total_savings > 0 ? total_savings : 1

    # Same estimate with the SVD values counted as Float32 while the full state
    # stays Float64.
    # savings_per_step = (Nx - svd_mem(nsv, Nx)) * sizeof(Float32) - OVERHEAD
    savings_per_step = Nx * sizeof(Float64) - svd_mem(nsv, Nx) * sizeof(Float32) - OVERHEAD
    total_savings = savings_per_step * bp.Nt

    mem_save_bytes32[idx] = total_savings > 0 ? total_savings : 1

end

# Show the reconstruction plot assembled by the grid-size loop.
display(p)

# # bp = burger_solution(x0, my_params; progress=true)

# savings_per_step = (Nx - svd_mem(nsv, Nx)) * sizeof(Float64) - overhead
# total_savings = savings_per_step * bp.Nt
# @show Nx, bp.Nt
# @show savings_per_step, savings_per_step / 1024, savings_per_step / 1024^2
# @show total_savings, total_savings / 1024, total_savings / 1024^2
@show mem_save_bytes

"""
    plot_memory_savings(exponents, bytes64, bytes32, bytes_per_unit, unit)

Build the log-log figure of estimated memory savings against grid size.

# Arguments
- `exponents`: the grid sizes are `2^k` for each `k` in `exponents`.
- `bytes64`: savings in bytes, drawn as a line plus markers in color 1.
- `bytes32`: savings in bytes, drawn as a line plus markers in color 2.
- `bytes_per_unit`: divisor converting bytes to the plotted unit.
- `unit`: unit name placed in the y-axis label.

# Returns
The plot object; the caller is responsible for displaying it.
"""
function plot_memory_savings(exponents, bytes64, bytes32, bytes_per_unit, unit)
    sizes = 2.0 .^ exponents
    saved64 = bytes64 ./ bytes_per_unit
    saved32 = bytes32 ./ bytes_per_unit

    fig = plot(sizes, saved64,
        legend=false, xscale=:log2, yscale=:log10,
        ylabel="Memory Saved($(unit))",
        xlabel="N Grid Points",
        yticks=[10.0^ell for ell in -8:2:8],
        xticks=[2^k for k in exponents],
        color=1,
    )
    scatter!(fig, sizes, saved64, label=nothing, color=1)
    plot!(fig, sizes, saved32, label=nothing, color=2)
    scatter!(fig, sizes, saved32, label=nothing, color=2)
    return fig
end

# The same figure twice, once in units of 1024^2 bytes and once in 1024^3 bytes.
p = plot_memory_savings(grid_sizes, mem_save_bytes, mem_save_bytes32, 1024^2, "MB")
display(p)
p = plot_memory_savings(grid_sizes, mem_save_bytes, mem_save_bytes32, 1024^3, "GB")
display(p)
;

In [None]:
# Estimated savings per grid size with Float64 SVD storage, in units of 1024^3 bytes.
mem_save_bytes ./ 1024^3

In [None]:
# Estimated savings per grid size with Float32 SVD storage, in units of 1024^3 bytes.
mem_save_bytes32 ./ 1024^3

In [None]:
# bp = burger_solution(x0, my_params; save=true, progress=true)
# make_gif(bp, "blah"; fps=20)

In [None]:
# Set up the gradient comparison on a 2^14-point grid. This cell reassigns the
# globals tf, cfl, seed, umax, umin, Nx and x0 used by earlier cells.
k = 14
tf = 1.0
Nx = 2^k
cfl = 0.85
seed = 1
rng = Random.MersenneTwister(seed)
umax = 0.505
umin = 0.495
# Seeded normal draws scaled by (umax - umin) and centred on the midpoint 0.5.
x0 = (umax - umin) .* randn(rng, Nx) .+ 0.5 * (umax + umin)

# Optimizer settings.
# NOTE(review): my_options is not referenced by any cell visible in this notebook section.
tol = 1e-4
xtol = 1e-8
my_options = Optim.Options(
    g_abstol=tol,
    x_abstol=xtol,
    outer_g_abstol=tol,
    outer_x_abstol=xtol,
    store_trace=false,
    extended_trace=false,
    show_trace=true,
)

# Baseline parameter set, using :mode => :implicit.
imr_params = Dict(
    :Nx => Nx,
    :cfl => cfl,
    :tf => tf,
    :flux => :lf,
    :scale => 1e2,
    :ic => grid_control,
    :target => tf_sin,
    # :mode => :normal,
    :mode => :implicit,
)

@show imr_params

# Replace the :target entry (tf_sin) with the value returned by target_condition.
# This happens before the copies below, so all three Dicts hold the same object.
imr_params[:target] = target_condition(x0, imr_params)

# copy on a Dict is shallow: the @show lines check that :target is shared (===).
svd_params = copy(imr_params)
svd_params[:mode] = :svd
svd_params[:matdim] = svd_dimensions(Nx)
# NOTE(review): the tolerance is negative -- presumably to disable tolerance-based
# truncation so that the fixed rank :nsv below is used; confirm against the solver.
svd_params[:tol] = -1e-5
svd_params[:nsv] = 1
@show svd_params[:target] === imr_params[:target]

rvs_params = copy(imr_params)
rvs_params[:mode] = :direct
@show rvs_params[:target] === imr_params[:target]
;

In [None]:
# Echo the random initial state built in the previous cell.
x0

In [None]:
# Objective, in-place gradient and tape constructor for `imr_params`.

"""Evaluate `cost` at `x` with the `imr_params` parameter set."""
function imr_objective(x)
    return cost(x, imr_params)
end

"""Call `RD.gradient!` to fill `g` with the gradient of `imr_objective` at `x`; returns `nothing`."""
function imr_gradient(g, x)
    RD.gradient!(g, imr_objective, x)
    return nothing
end

"""Construct an `RD.GradientTape` for `imr_objective` from the input `x`."""
imr_tape(x) = RD.GradientTape(imr_objective, x)

# Objective, in-place gradient and tape constructor for `rvs_params`.

"""Evaluate `cost` at `x` with the `rvs_params` parameter set."""
function rvs_objective(x)
    return cost(x, rvs_params)
end

"""Call `RD.gradient!` to fill `g` with the gradient of `rvs_objective` at `x`; returns `nothing`."""
function rvs_gradient(g, x)
    RD.gradient!(g, rvs_objective, x)
    return nothing
end

"""Construct an `RD.GradientTape` for `rvs_objective` from the input `x`."""
rvs_tape(x) = RD.GradientTape(rvs_objective, x)

# Objective, in-place gradient and tape constructor for `svd_params`.

"""Evaluate `cost` at `x` with the `svd_params` parameter set."""
function svd_objective(x)
    return cost(x, svd_params)
end

"""Call `RD.gradient!` to fill `g` with the gradient of `svd_objective` at `x`; returns `nothing`."""
function svd_gradient(g, x)
    RD.gradient!(g, svd_objective, x)
    return nothing
end

"""Construct an `RD.GradientTape` for `svd_objective` from the input `x`."""
svd_tape(x) = RD.GradientTape(svd_objective, x)

# my_objective(x) = cost(x, my_params)

In [None]:
# Output buffers with the shape and element type of x0 (contents uninitialized).
grad_rev = similar(x0)
grad_svd = similar(x0)
# grad_rvs = similar(x0)
# Time one gradient evaluation per mode; a first call also includes compilation time.
@time imr_gradient(grad_rev, x0)
@time svd_gradient(grad_svd, x0)
# @time rvs_gradient(grad_rvs, x0)

In [None]:
# grad_rev

In [None]:
# grad_svd

In [None]:
# grad_rvs

In [None]:
# 2-norm of the difference between the :implicit and :svd gradients computed above.
LA.norm(grad_rev - grad_svd, 2)

In [None]:
# Build one gradient tape per mode at x0, for the length comparison below.
svd_tp = svd_tape(x0)
rvs_tp = rvs_tape(x0)
imr_tp = imr_tape(x0)

In [None]:
# Length of the :svd-mode tape.
length(svd_tp)

In [None]:
# Length of the :direct-mode tape.
length(rvs_tp)

In [None]:
# Length of the :implicit-mode tape.
length(imr_tp)

In [None]:
# Extra tape length of :svd mode over :implicit mode, per grid point.
(length(svd_tp) - length(imr_tp)) / Nx

In [None]:
# Result containers shaped after x0; each is filled with an objective value and
# a gradient by the RD.gradient! calls in the following cells.
rimr = DiffResults.GradientResult(x0)
rsvd = DiffResults.GradientResult(x0)
;

In [None]:
# Run the :implicit-mode tape at x0, storing the result in rimr, then show the objective value.
RD.gradient!(rimr, imr_tp, x0)
rimr.value

In [None]:
# Run the :svd-mode tape at x0, storing the result in rsvd, then show the objective value.
RD.gradient!(rsvd, svd_tp, x0)
rsvd.value

In [None]:
# Discrepancy between the two modes: absolute difference of the objective
# values, and 2-norm of the difference of the stored gradients.
@show abs(rimr.value - rsvd.value)
@show LA.norm(rimr.derivs[1] - rsvd.derivs[1], 2)
;

In [None]:
# Solve from x0 with the :svd parameter set; save=true is requested because the
# result is passed to make_gif in the next cell.
bp = burger_solution(x0, svd_params; save=true)
;

In [None]:
# Render the saved solution with make_gif at 20 frames per second.
# NOTE(review): the output name "blah.gif" is a placeholder, presumably relative to the working directory.
make_gif(bp, "blah.gif"; fps=20)