In [None]:
#9/4/2024

using FFTW
using LinearAlgebra
using MultiFloats
MultiFloats.use_bigfloat_transcendentals()
using Plots
using BenchmarkTools

flush(stdout)

"""
    fourier_diff_matrix_naive(N; L = 2π, T = Float64)

Build the dense `N × N` Fourier spectral first-derivative matrix for `N` equispaced
points on a periodic domain of length `L`, with entries of element type `T`.

The matrix is formed explicitly as `real(F' * Diagonal(im .* k) * F)`, where `F` is the
unitary DFT matrix and `k` are the wavenumbers in FFT order. This costs O(N^3) work, so
it suits small `N` and element types (e.g. extended precision) without FFT support.
All trigonometric work is done in precision `T`, so the matrix is accurate to `eps(T)`.

Throws an `ArgumentError` if `N` is not a positive even integer.
"""
function fourier_diff_matrix_naive(N::Int; L::Real = 2π, T = Float64)
    N > 0 || throw(ArgumentError("N must be positive."))
    iseven(N) || throw(ArgumentError("N must be even."))

    # Wavenumbers in FFT order: 0, 1, …, N/2, -N/2+1, …, -1.
    k = vcat(0:N÷2, -N÷2+1:-1)
    k = T.(2π / L .* k)
    D1_operator = Complex{T}.(im .* k)

    # DFT matrix built entirely in precision T. Only the N distinct roots of unity
    # e^{-2πij/N} are evaluated, and each entry picks one via (m-1)(n-1) mod N.
    # A Float64 ω raised to large powers would cap accuracy at ~1e-15 for any T.
    θ = 2 * T(big(π)) / N
    roots = [cis(-θ * j) for j in 0:N-1]
    F = [roots[mod((m - 1) * (n - 1), N) + 1] for m in 1:N, n in 1:N] ./ sqrt(T(N))

    # F is unitary, so its inverse is its adjoint: exact, and no O(N^3) inversion.
    D = real(F' * Diagonal(D1_operator) * F)
    return Matrix{T}(D)  # force type T here
end

"""
    newtons_method(f, x0, Dx, dt, tol=1e-10, maxiter=1000)

Solve `f(x) = 0` with Newton's method, starting from `x0`.

The Jacobian is not derived from `f`. It is fixed as
`(dt/2) * -Dx * Diagonal(x) - I`, the Jacobian of the implicit-midpoint stage
residual used in this file. The current iterate is returned as soon as
`norm(f(x)) < tol`. An error is raised after `maxiter` Newton steps without
reaching that tolerance.
"""
function newtons_method(f, x0, Dx, dt, tol=1e-10, maxiter=1000)
    n = length(x0)
    Id = Matrix{typeof(x0[1])}(I, n, n)
    x = x0
    iter = 0
    while iter < maxiter
        iter += 1
        r = f(x)
        norm(r) < tol && return x
        # Hard-wired Jacobian of the stage residual at the current iterate.
        Jx = (dt / 2) * -Dx * Diagonal(x) - Id
        x = x - Jx \ r
    end
    error("Newton's method did not converge")
end

"""
    broydens_method(f, Ji, x0, Dx, dt, good, tol=1e-10, maxiter=1000)

Solve `f(x) = 0` with Broyden's quasi-Newton method. The iteration starts from `x0`,
with `Ji` as the initial approximation of the *inverse* Jacobian.

# Arguments
- `good`: selects the inverse-Jacobian update, where `s` is the step and `y` the
  change in `f`:
  - `true` (good Broyden): `Ji += (s - Ji*y) * (s' * Ji) / (s' * Ji * y)`
  - `false` (bad Broyden): `Ji += (s - Ji*y) * y' / (y' * y)`
- `Dx`, `dt`: accepted to mirror `newtons_method`'s signature; not used here.

# Returns
The first iterate with `norm(f(x)) < tol`.

# Throws
`ErrorException` after `maxiter` updates without reaching the tolerance. The message
reports the final residual's infinity norm.
"""
function broydens_method(f, Ji, x0, Dx, dt, good, tol=1e-10, maxiter=1000)
    x = x0
    fx = f(x)
    for _ in 1:maxiter
        norm(fx) < tol && return x

        dx = -Ji * fx
        x = x + dx
        # One evaluation of f per iteration: the new residual feeds the secant
        # difference now and the convergence check on the next pass.
        fnew = f(x)
        y = fnew - fx

        # Row vector of the rank-one update (good vs. bad Broyden).
        # If Ji has a lower precision than f's output, this update promotes it.
        b = good ? dx' * Ji : y'
        Ji = Ji + ((dx - (Ji * y)) * (b)) ./ (b * y)

        fx = fnew
    end
    variant = good ? "Good" : "Bad"
    error("$variant Broyden's method did not converge (final residual inf-norm = $(norm(fx, Inf)))")
end

"""
    implicit_midpoint_step_newton(u, D, Dlow, dt, H, L)

Advance `u` by one implicit-midpoint step of `u_t = -0.5 * D * (u.^2)`.

The stage value solving `y = u + (dt/2) * (-0.5 * Dlow * y.^2)` is found with
`newtons_method` entirely in the low precision `L`. That solve starts from `L.(u)`
and uses a residual tolerance of `10 * eps(L)`. The stage is then converted to `H`,
and the final update `u + dt * (-0.5 * D * y.^2)` is carried out with the
high-precision operator `D`.
"""
function implicit_midpoint_step_newton(u::Vector{H}, D::Matrix{H}, Dlow::Matrix{L}, dt::H, ::Type{H}, ::Type{L}) where {H, L}
    # Everything inside the nonlinear solve lives in precision L.
    dtlow = L(dt)
    u_l = L.(u)
    stage_residual(y) = -y + u_l + dtlow/2 * (L(-0.5) * Dlow * (y.^2))

    # A low-precision stage is used deliberately; only the final update is in H.
    ystage = newtons_method(stage_residual, u_l, Dlow, dtlow, 10 * eps(L))
    yhigh = H.(ystage)
    return u + dt*(-0.5*D*(yhigh.^2))
end

"""
    implicit_midpoint_step_broyden(u, Dx_high, Dxlow, dt, H, L, good)

Advance `u` by one implicit-midpoint step of `u_t = -0.5 * Dx * (u.^2)`.

The stage equation `y = u + (dt/2) * (-0.5 * Dx_high * y.^2)` is solved in precision
`H` with `broydens_method`. `good` selects the good (`true`) or bad (`false`) update.
The initial inverse Jacobian `inv(-I - (dt/2) * Dx * Diagonal(u))` is assembled and
inverted in the low precision `L` from `Dxlow`. The final update is done in `H`.
"""
function implicit_midpoint_step_broyden(u::Vector{H}, Dx_high::Matrix{H}, Dxlow::Matrix{L}, dt::H, ::Type{H}, ::Type{L}, good::Bool) where {H, L}
    uhigh = H.(u)
    f = y -> -y + uhigh + H.(dt)/2 * (H.(-0.5) * Dx_high * (y.^2))

    # Initial inverse Jacobian, formed and inverted in low precision L. This dense
    # O(N^3) inverse is the expensive part the mixed-precision scheme offloads to L.
    ulow = L.(u)
    Jlow = (L(dt) / 2) * -Dxlow * Diagonal(ulow) - I
    Ji = inv(Jlow)

    # NOTE(review): an absolute 2-norm tolerance of 10*eps(H) may sit near the
    # rounding floor of f for large N — confirm it is reachable.
    tol = 10 * eps(H)

    # Argument order must match broydens_method(f, Ji, x0, Dx, dt, good, tol).
    y1 = broydens_method(f, Ji, uhigh, Dx_high, H.(dt), good, tol)
    y1 = H.(y1)
    return u + dt*(-0.5*Dx_high*(y1.^2))
end

"""
    rk4_step(u, D, dt)

Take one classical fourth-order Runge–Kutta step of size `dt` for the semi-discrete
system `u_t = -0.5 * D * (u.^2)`. Return the new state; `u` itself is not modified.
"""
function rk4_step(u, D, dt)
    rhs(v) = -0.5*D*(v.^2)   # right-hand side evaluated at state v

    s1 = rhs(u)
    s2 = rhs(u + 0.5 * dt * s1)
    s3 = rhs(u + 0.5 * dt * s2)
    s4 = rhs(u + dt * s3)

    increment = s1 + 2*s2 + 2*s3 + s4
    return u + (dt / 6) * increment
end

"""
    norm_L2(u, dx)

Discrete L2 norm of `u` on a uniform grid with spacing `dx`: `sqrt(Σ |uᵢ|² · dx)`.
"""
norm_L2(u, dx) = sqrt(sum(abs2, u) * dx)

using Dates  # for timing

# Apply `step` to a copy of `u0` `nt` times and return `(final_state, seconds)`.
# `@timed` already reports its `time` field in seconds, so no unit conversion is needed.
function _timed_march(step, u0, nt)
    result = @timed begin
        state = copy(u0)
        for _ in 1:nt
            state = step(state)
        end
        state
    end
    return result.value, result.time
end

"""
    run_mixed_precision_burgers(low, high; N=200, T_Final=0.8, dt=0.001,
                                dt_rk=0.00001, good=true)

Integrate `u_t = -0.5 * Dx * (u.^2)` on `N` periodic points of `[0, 2π)`, starting
from `u(x, 0) = sin(x)` and running to `T_Final`. Two mixed-precision
implicit-midpoint solvers are used with step `dt`: a Broyden stage solve and a
Newton stage solve. Each is compared against an RK4 reference computed in `high`
with step `dt_rk`.

# Keywords
- `good`: `true` for the good Broyden update, `false` for the bad one.

# Returns
A named tuple with these fields:
- `error`, `time`, `u`: results of the Broyden run.
- `Nerror`, `Ntime`, `Nu`: results of the Newton run.
- `uref`: the RK4 reference solution.
- `x`: the grid.

Errors are discrete L2 norms against `uref`. Times are wall-clock seconds from
`@timed` and may include first-call compilation.
"""
function run_mixed_precision_burgers(low::Type, high::Type; N=200, T_Final=0.8,
                                     dt=0.001, dt_rk=0.00001, good::Bool=true)
    L = 2π
    x = L * (0:N-1) / N
    dx = L / N

    # Dense spectral differentiation matrices in both precisions.
    Dx_high = fourier_diff_matrix_naive(N; L=L, T=high)
    Dx_low  = fourier_diff_matrix_naive(N; L=L, T=low)

    # Initial condition u(x, 0) = sin(x). sin is evaluated in BigFloat at the
    # Float64 grid points `x`, then converted to `high`.
    u0 = high.(sin.(BigFloat.(x)))

    # Reference solution: classical RK4 with a much smaller step, entirely in `high`.
    nt_rk = round(Int, T_Final / dt_rk)
    uref = copy(u0)
    for _ in 1:nt_rk
        uref = rk4_step(uref, Dx_high, dt_rk)
    end

    # Mixed-precision implicit-midpoint runs; the step size is converted to `high` once.
    nt = round(Int, T_Final / dt)
    dt_high = high(dt)

    u, runtime = _timed_march(u0, nt) do v
        implicit_midpoint_step_broyden(v, Dx_high, Dx_low, dt_high, high, low, good)
    end
    Nu, Nruntime = _timed_march(u0, nt) do v
        implicit_midpoint_step_newton(v, Dx_high, Dx_low, dt_high, high, low)
    end

    return (
        error = norm_L2(uref - u, dx),
        time = runtime,
        u = u,
        uref = uref,
        x = x,

        Nerror = norm_L2(uref - Nu, dx),
        Ntime = Nruntime,
        Nu = Nu,
    )
end

# Result accumulators (declared here; the loop below only prints its results).
errors = Float64[]
times = Float64[]
labels = String[]

# (low, high) precision pairs to test: `low` is the inner-solve precision,
# `high` the time-stepping and reference precision.
precision_pairs = [(Float32, Float64x2)]

println("Running mixed precision Burgers' equation solver...")
for pair in precision_pairs
    low, high = pair
    println("\nTesting low = $(low), high = $(high)\n")
    res = run_mixed_precision_burgers(low, high)

    # (heading, L2 error, runtime, text printed after the runtime)
    summaries = (
        ("Broyden's Results:", res.error, res.time, " seconds\n"),
        ("Newton's Results:", res.Nerror, res.Ntime, " seconds"),
    )
    for (heading, err, secs, tail) in summaries
        println(heading)
        println("L2 error = ", err)
        println("Runtime = ", secs, tail)
    end
end


ErrorException: syntax: optional positional arguments must occur at end around /Users/tejsaikakumanu/Documents/MATLAB/Broydens Method/Current_VersionControl/PDEsResearch/Latest_June20/MixedPrecision/MixedPrecision/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_W0sZmlsZQ==.jl:41