In [6]:
#9/4/2024

using FFTW
using LinearAlgebra
using MultiFloats
MultiFloats.use_bigfloat_transcendentals()
using Plots
using BenchmarkTools
using Printf

flush(stdout)

"""
    fourier_diff_matrix_naive(N::Int; L::Real = 2π, T = Float64)

Return the dense `N × N` Fourier spectral first-derivative matrix for `N`
equispaced points on a periodic domain of length `L`, with element type `T`.

For even `N` the real part of `F⁻¹ * Diagonal(im * k) * F` (with the Nyquist
mode contributing only an imaginary part) has the closed form

    D[m, n] = (2π / L) * (-1)^(m - n) / 2 * cot(π * (m - n) / N),   m ≠ n
    D[m, m] = 0

Each entry is evaluated in 512-bit `BigFloat` arithmetic and rounded once to
`T`, so the accuracy of the result is set by `T` and not by Float64.

# Arguments
- `N`: number of grid points; must be positive and even.

# Keywords
- `L`: domain length. The scale factor `2π / L` is formed in the precision of
  `L` before being widened.
- `T`: element type of the returned matrix; must be constructible from a
  `BigFloat`.

# Throws
- `ArgumentError` if `N` is not a positive even integer.
"""
function fourier_diff_matrix_naive(N::Int; L::Real = 2π, T = Float64)
    N > 0 || throw(ArgumentError("N must be positive."))
    iseven(N) || throw(ArgumentError("N must be even."))

    scale = 2π / L

    # The matrix is circulant: an entry depends only on r = mod(m - n, N).
    # (N is even, so both (-1)^r and cot(π r / N) are N-periodic in r.)
    offdiag = setprecision(BigFloat, 512) do
        s = BigFloat(scale)
        map(1:(N - 1)) do r
            # cot(π/2) is exactly zero; avoid returning rounding residue.
            2r == N && return zero(T)
            θ = BigFloat(π) * r / N
            sgn = isodd(r) ? -1 : 1
            return T(s * sgn * cos(θ) / (2 * sin(θ)))
        end
    end

    return T[m == n ? zero(T) : offdiag[mod(m - n, N)] for m in 1:N, n in 1:N]
end


"""
    newtons_method(f, x0, Dx, dt, tol=1e-10, maxiter=100)

Iterate Newton's method on the residual `f`, starting from `x0`.

The Jacobian is not derived from `f`; it is hard-wired as
`(dt / 2) * -Dx * Diagonal(x) - I`, so `f` must be a residual whose derivative
has that form. Iteration stops as soon as `norm(f(x)) < tol`.

Returns the current iterate. If the tolerance is not met within `maxiter`
iterations, the last iterate is returned without raising an error.
"""
function newtons_method(f, x0, Dx, dt, tol=1e-10, maxiter=100)
    n = length(x0)
    identity_mat = Matrix{typeof(x0[1])}(I, n, n)
    half_dt = dt / 2

    iterate = x0
    for _ in 1:maxiter
        residual = f(iterate)
        # Converged: leave the loop and hand back the current iterate.
        norm(residual) < tol && break
        jacobian = half_dt * -Dx * Diagonal(iterate) - identity_mat
        iterate = iterate - jacobian \ residual
    end
    return iterate
    #error("Newton's method did not converge")
end


"""
    implicit_midpoint_step(u, D, Dlow, dt, H, L; correct=false, n_corr=2, B, update=true)

Advance `u` by one implicit-midpoint step of `u_t = -0.5 * D * (u.^2)`.

The stage value is first solved in the low precision `L` with
`newtons_method`, using `Dlow` and a tolerance of `10 * eps(L)`, and then
widened to `H`. When `correct` is true, up to `n_corr` quasi-Newton sweeps
`y ← y - B * r(y)` are applied in precision `H` to the residual
`r(y) = y - u + dt/2 * (0.5 * D * (y.^2))`, stopping early once
`norm(r) < 10 * eps(H)`.

# Arguments
- `u`: current solution in precision `H`.
- `D`, `Dlow`: differentiation matrices in precisions `H` and `L`.
- `dt`: time step.
- `H`, `L`: the high and low precision types (used for dispatch).

# Keywords
- `correct`: enable the high-precision correction sweeps.
- `n_corr`: maximum number of correction sweeps.
- `B`: approximate inverse Jacobian of the residual (required). It is
  **modified in place** when `update` is true, so the caller's matrix carries
  the refined approximation into later calls.
- `update`: apply a Broyden rank-one update to `B` after each sweep. The
  residual is re-evaluated after every sweep regardless of this flag.

Returns the new solution `u + dt * (-0.5 * D * (y.^2))`, where `y` is the
stage value.
"""
function implicit_midpoint_step(u::Vector{H}, D::Matrix{H}, Dlow::Matrix{L}, dt::H, ::Type{H}, ::Type{L}; correct::Bool= false, n_corr::Int = 2, B::Matrix{H}, update::Bool = true) where {H, L}
    # Low-precision solve for the midpoint stage value.
    ulow = L.(u)
    f_low = y -> -y + ulow + L(dt)/2 * (L(-0.5) * Dlow * (y.^2))
    tol = 10 * eps(L)
    y1 = H.(newtons_method(f_low, ulow, Dlow, L(dt), tol))

    if correct
        tol_corr = 10 * eps(H)
        f_high = y -> y - u + dt/2 * (0.5 * D * (y.^2))
        y0 = copy(y1)
        r0 = f_high(y0)
        r1 = similar(r0)
        dy = similar(y0)
        df = similar(y0)

        for _ in 1:n_corr
            y1 .= y0 .- B*r0

            # Always refresh the residual: without this every sweep after the
            # first would re-apply the same stale step when `update` is false.
            r1 .= f_high(y1)
            if norm(r1) < tol_corr
                break
            end

            if update
                dy .= y1 .- y0
                df .= r1 .- r0
                denom = dot(df, df)
                if denom != 0
                    # Broyden rank-one update of the inverse-Jacobian estimate.
                    B .+= ((dy .- B*df) * df') ./ denom
                end
            end

            y0 .= y1
            r0 .= r1
        end
    end
    return u + dt*(-0.5*D*(y1.^2))
end

"""
    rk4_step(u, D, dt)

Advance `u` by one classical fourth-order Runge–Kutta step of size `dt` for
the system `u_t = -0.5 * D * (u.^2)` and return the new state.
"""
function rk4_step(u, D, dt)
    # Right-hand side shared by all four stages.
    rhs(v) = -0.5*D*(v.^2)

    s1 = rhs(u)
    s2 = rhs(u + 0.5 * dt * s1)
    s3 = rhs(u + 0.5 * dt * s2)
    s4 = rhs(u + dt * s3)

    increment = (dt / 6) * (s1 + 2*s2 + 2*s3 + s4)
    return u + increment
end

"""
    norm_L2(u, dx)

Return `sqrt(sum(abs2, u) * dx)`, the discrete L2 norm of the samples `u`
weighted by the grid spacing `dx`.
"""
function norm_L2(u, dx)
    energy = sum(abs2, u)
    return sqrt(energy * dx)
end


using Dates  # for timing

"""
    run_mixed_precision_burgers(low, high; N=800, T_Final=0.7, correct=false, corr_id=1, n_corr=1)

Integrate the periodic Burgers problem with initial data `sin(x)` on `[0, 2π)`
up to `T_Final` using the mixed-precision implicit midpoint rule, and compare
against an RK4 reference computed entirely in precision `high`.

# Arguments
- `low`, `high`: precision types for the nonlinear solve and for everything
  else, respectively.

# Keywords
- `N`: number of grid points.
- `T_Final`: final time. The step counts are `round(Int, T_Final / dt)` with
  `dt = 1e-4` for the reference and `dt = 1e-2` for the implicit midpoint rule.
- `correct`: enable high-precision correction sweeps in each step.
- `corr_id`: choice of the initial approximate inverse Jacobian `B` used when
  `correct` is true:
  - `1`: identity, never updated;
  - `2`: inverse of the full Jacobian at `u0`, formed in precision `low`;
  - `3`: `0.01 * I`;
  - any other value: inverse of the diagonal of that Jacobian.
  Variants other than `1` let `implicit_midpoint_step` update `B` in place, so
  `B` evolves from step to step.
- `n_corr`: number of correction sweeps per step.

# Returns
A named tuple `(error, time, u, uref, x)`: the L2 error against the reference,
the elapsed time of the implicit-midpoint loop in seconds as reported by
`@timed`, the computed and reference solutions, and the grid.
"""
function run_mixed_precision_burgers(low::Type, high::Type; N=800, T_Final=0.7, correct::Bool= false, corr_id::Int = 1, n_corr = 1)
    L = 2π
    x = L * (0:N-1) / N
    dx = L / N

    # Same dense differentiation matrix, built once per precision.
    Dx_high = fourier_diff_matrix_naive(N; L=L, T=high)
    Dx_low  = fourier_diff_matrix_naive(N; L=L, T=low)

    # Initial condition: sin is evaluated in BigFloat at the Float64 grid
    # nodes, then rounded to `high`.
    u0 = high.(sin.(BigFloat.(x)))

    # Reference RK4 solution
    dt_rk = 0.0001
    nt_rk = round(Int, T_Final / dt_rk)
    uref = copy(u0)
    for _ in 1:nt_rk
        uref = rk4_step(uref, Dx_high, dt_rk)
    end

    # Mixed-precision implicit midpoint rule
    dt = 1e-2
    nt = round(Int, T_Final / dt)

    # Initial approximate inverse Jacobian and whether it gets Broyden updates.
    Id_high = Matrix{high}(I, N, N)
    if !correct || corr_id == 1
        B = Id_high
        upd = false
    elseif corr_id == 3
        B = 0.01*Id_high
        upd = true
    else
        # Jacobian of the stage residual at u0, assembled in low precision.
        Id = Matrix{low}(I, N, N)
        J = (low(dt)/2) * (Dx_low * Diagonal(low.(u0))) + Id
        if corr_id == 2
            B = Matrix{high}(inv(J))   # invert in low precision, then convert once
        else
            B = Matrix{high}(inv(Diagonal(diag(J))))
        end
        upd = true
    end

    result = @timed begin
        temp_u = copy(u0)
        for _ in 1:nt
            temp_u = implicit_midpoint_step(temp_u, Dx_high, Dx_low, high(dt), high, low, correct = correct,n_corr= n_corr, B = B, update = upd)
        end
        temp_u
    end

    runtime = result.time  # elapsed seconds, as reported by @timed
    u = result.value       # final state after all steps

    # L2 error against the reference
    error_L2 = norm_L2(uref - u, dx)

    return (
        error = error_L2,
        time = runtime,
        u = u,
        uref = uref,
        x = x,
    )

end




run_mixed_precision_burgers (generic function with 1 method)

In [7]:
# Accumulators for the sweep below: one entry per (low, high) pair.
errors = Float64[]
times = Float64[]
labels = String[]

precision_pairs = [
    (Float16, Float64),
    #(Float16, Float64x2),
    #(Float32, Float64),
    (Float64, Float64),
    #(Float32, Float64x2),
    #(Float64, Float64x2),
    #(Float64x2, Float64x2),
    #(Float64x2, Float64x8),
]


println("Running mixed precision Burgers' equation solver...")
for (low, high) in precision_pairs
    println("\nTesting low = $(low), high = $(high)")
    result = run_mixed_precision_burgers(low, high)
    @printf("L2 error = %.4e\n", result.error)
    #println("Runtime = ", result.time, " seconds")

    # Record the outcome so it is available after the loop.
    push!(errors, Float64(result.error))
    push!(times, result.time)
    push!(labels, "$(low)/$(high)")
end

Running mixed precision Burgers' equation solver...

Testing low = Float16, high = Float64
L2 error = 5.2946e-03

Testing low = Float64, high = Float64
L2 error = 2.4460e-05


In [9]:

# Sweep the number of correction iterations for preconditioner variant 4
# (any corr_id other than 1, 2, 3 selects the diagonal inverse-Jacobian).
precision_pairs = [
    (Float16, Float64),
]
n_corr_list = [2,5, 8, 10, 15]

println("Running mixed precision Burgers' equation solver with corrections...")
for n_corr in n_corr_list
    println("\nNumber of corrections = $n_corr ")
    for i in 4:4
        for (low, high) in precision_pairs
            # Report the precision actually under test rather than a fixed name.
            println("\nTesting low = $(low), high = $(high), v = $i")
            result = run_mixed_precision_burgers(low, high, correct = true, corr_id = i, n_corr = n_corr)
            @printf("L2 error = %.4e\n", result.error)
        end
    
    end
end 

Running mixed precision Burgers' equation solver with corrections...

Number of corrections = 8 

Testing low = Float16, high = Float64, v = 4
L2 error = 3.7974e-04


In [5]:
# Sweep the number of correction iterations for variant 1 (B = identity,
# no Broyden updates).
precision_pairs = [
    (Float16, Float64),
]
n_corr_list = [2, 5, 8, 10, 15]

println("Running mixed precision Burgers' equation solver with corrections...")
for n_corr in n_corr_list
    println("\nNumber of corrections = $n_corr ")
    for (low, high) in precision_pairs
        # Report the precision actually under test rather than a fixed name.
        println("\nTesting low = $(low), high = $(high), v = 1")
        result = run_mixed_precision_burgers(low, high, correct = true, corr_id = 1, n_corr = n_corr)
        @printf("L2 error = %.4e\n", result.error)
    end
end 

Running mixed precision Burgers' equation solver with corrections...

Number of corrections = 2 

Testing low = Float16, high = Float64, v = 1


LoadError: ArgumentError: matrix contains Infs or NaNs