In [None]:
import Logging: global_logger
import TerminalLoggers: TerminalLogger
# Install a terminal-oriented global logger; presumably so that the solver's
# `progress=true` progress bars further down render in the cell output.
global_logger(TerminalLogger())

In [None]:
using LinearAlgebra
using SparseArrays
using ModelingToolkit
using MacroTools: @capture, postwalk, prewalk
using DifferentialEquations
using NonlinearSolve
using Plots

Nx = Ny = 50;
harmonics = 3; # number of harmonics
# Initial Newton guess for the harmonic-balance unknowns: 2 coefficient fields
# per harmonic on the (Nx+1)×(Ny+1) grid. It is all zeros; scaling a zero
# vector (e.g. `0.005 * zeros(...)`) would not change that.
u0_HB = zeros((Nx+1) * (Ny+1) * harmonics * 2);
order = 2;  # not referenced in this notebook; presumably read by the included HB code
Nt = 5      # not referenced in this notebook
N = Nx + 1  # points per side of the (square) FD grid

# The FD model and plots use a single N for both axes, so the grid must be square.
Nx == Ny || throw(ArgumentError("FD grid must be square; got Nx = $Nx, Ny = $Ny"))

xleft::Float64 = 0.0
xright::Float64 = 1.0
yleft::Float64 = 0.0
yright::Float64 = 1.0
stepx = (xright - xleft) / Nx
stepy = (yright - yleft) / Ny

#Definition of constants (placeholders):
gamma::Float64 = 0.1
gamma_3::Float64 = 0.1
c::Float64 = 0.5
omega::Float64 = 1.0



@parameters x, y, t
@variables u(..)

Dx = Differential(x)
Dy = Differential(y)
Dt = Differential(t)
# Full PDE (the external forcing is applied separately, in the generated ODE
# right-hand side / the HB residual, not in y_eq):
#   Dt(Dt(u)) - c*c*(Dx(Dx(u)) + Dy(Dy(u))) + gamma*Dt(u) + gamma_3*Dt(u)^3 = forcing
# Rearranged as ddu = y_eq:
y_eq = c*c*(Dx(Dx(u(x, y))) + Dy(Dy(u(x, y)))) - gamma*Dt(u(x, y)) - gamma_3*Dt(u(x, y))*Dt(u(x, y))*Dt(u(x, y))
# y_eq = expand_derivatives(y_eq)

In [None]:
"""
    transform_sym(ex)

Rewrite a parsed (`Meta.parse`d) Symbolics expression in `u(x, y)` into
finite-difference array code for an interior grid point `(i, j)`.

Rewrite rules, applied top-down with `prewalk`:
- `Differential(x)(Differential(x)(s(x, y)))` or `Differential(x, 2)(s(x, y))`
  → `(s[i+1, j] - 2*s[i, j] + s[i-1, j]) / dx^2` (likewise for `y`, `j`, `dy`).
- `Differential(x)(s(x, y))` or `Differential(x, 1)(s(x, y))`
  → `(s[i+1, j] - s[i-1, j]) / (2*dx)` (likewise for `y`).
- `Differential(t)(s(x, y))` or `Differential(t, 1)(s(x, y))` → `ds_array[i, j]`.
- `s(x, y)` → `s_array[i, j]`.
- `s[a, b]` → `s_array[a, b]`. This is what renames the stencil references
  produced by the difference rules, because `prewalk` descends into every
  rewritten node.

`i`, `j`, `dx`, `dy` and the `*_array` names are left as free symbols for the
surrounding generated code to bind.
"""
function transform_sym(ex)
    # Stencil builders; `f` is the captured function symbol (e.g. `:u`).
    d2x(f) = :(($f[i+1, j]-2* $f[i, j]+$f[i-1, j])/dx^2)
    d2y(f) = :(($f[i, j+1]-2* $f[i, j]+$f[i, j-1])/dy^2)
    d1x(f) = :(($f[i+1, j] - $f[i-1, j]) / (2*dx))
    d1y(f) = :(($f[i, j+1] - $f[i, j-1]) / (2*dy))
    d1t(f) = :($(Symbol("d$(f)_array"))[i, j])

    # Every pattern below is mutually exclusive with the others, so pairing the
    # nested and counted `Differential` spellings with `||` keeps the original
    # first-match semantics.
    return prewalk(ex) do node
        if @capture(node, Differential(x)(Differential(x)(s_(x, y)))) ||
           @capture(node, Differential(x, 2)(s_(x, y)))
            return d2x(s)
        elseif @capture(node, Differential(y)(Differential(y)(s_(x, y)))) ||
               @capture(node, Differential(y, 2)(s_(x, y)))
            return d2y(s)
        elseif @capture(node, Differential(x)(s_(x, y))) ||
               @capture(node, Differential(x, 1)(s_(x, y)))
            return d1x(s)
        elseif @capture(node, Differential(y)(s_(x, y))) ||
               @capture(node, Differential(y, 1)(s_(x, y)))
            return d1y(s)
        elseif @capture(node, Differential(t)(s_(x, y))) ||
               @capture(node, Differential(t, 1)(s_(x, y)))
            return d1t(s)
        elseif @capture(node, s_(x, y))
            return :($(Symbol("$(s)_array"))[i, j])
        elseif @capture(node, s_[a_, b_])
            return :($(Symbol("$(s)_array"))[$(a), $(b)])
        end
        return node
    end
end
# Debug output: the symbolic right-hand side, a separator, and its
# finite-difference rewrite as produced by `transform_sym`.
println(y_eq)
println("_____________________________________________")
println(transform_sym(Meta.parse(string(y_eq))))

"""
    create_ODE_function(N, y_expr; xleft = 0.0, omega = 1.0)

Generate, via `eval`, an in-place second-order right-hand side
`secODE!(ddu, du, u, p, t)` for an `N × N` grid. The result is suitable for
`SecondOrderODEProblem`.

- `y_expr` is spliced into the interior update. It may reference `u_array`,
  `du_array`, `i`, `j`, `dx`, `dy` (e.g. the output of `transform_sym`).
- `p` must destructure as `dx, dy = p`.
- Boundary rows and columns of `ddu` are set to zero (Dirichlet).
- Interior points additionally get the forcing term
  `- 250 * exp(-40 * x_i^2) * sin(omega * t)`, where
  `x_i = xleft + (i - 1) * dx` is the x-coordinate of grid node `i`.
  The defaults `xleft = 0.0` and `omega = 1.0` match this notebook's setup.

NOTE(review): the forcing is subtracted here, whereas `build_problem`'s residual
`... - F = 0` corresponds to adding it. Because this equation is odd in `u`, the
two solutions differ only in sign, so their spectral magnitudes agree.

Returns the generated function. It is defined with `eval`, so a caller that
invokes it from within the same function body needs `Base.invokelatest`.
"""
function create_ODE_function(N, y_expr; xleft = 0.0, omega = 1.0)
    function_code = quote
        function secODE!(ddu, du, u, p, t)
            dx, dy = p
            grid_size = $N * $N

            u_array = reshape(@view(u[1:grid_size]), $N, $N)
            du_array = reshape(@view(du[1:grid_size]), $N, $N)
            ddu_array = reshape(@view(ddu[1:grid_size]), $N, $N)

            # BCs (Dirichlet for now): boundary nodes never accelerate.
            ddu_array[1, :] .= 0; ddu_array[end, :] .= 0
            ddu_array[:, 1] .= 0; ddu_array[:, end] .= 0

            # Inner points. `i` is the first (x) index, so it is the inner loop
            # for column-major memory order; each point is independent.
            for j in 2:$(N - 1)
                for i in 2:$(N - 1)
                    # Node i lies at x = xleft + (i - 1) * dx (1-based index).
                    xnode = $xleft + (i - 1) * dx
                    ddu_array[i, j] = ($y_expr) - 250 * exp(-40 * xnode^2) * sin($omega * t)
                end
            end

            return ddu
        end
    end
    return eval(function_code)
end

# Rewrite the symbolic right-hand side into finite-difference array code and
# generate the in-place RHS for the N×N grid.
y_expr = transform_sym(Meta.parse(string(y_eq)))

ODEfunc = create_ODE_function(N, y_expr)
# Start from rest: zero displacement and zero velocity everywhere.
u0_FD = zeros(N*N)
du0 = zeros(N*N)
tspan = (0.0, 120.0)
# Unpacked by the generated RHS as `dx, dy = p`.
par = [stepx, stepy]
# Argument order is (f, du0, u0, tspan, p): the initial velocity comes first.
prob = SecondOrderODEProblem(ODEfunc, du0, u0_FD, tspan, par)

# DPRKN6: a Runge–Kutta–Nyström integrator for second-order problems; the
# state is saved every 0.1 s (this sampling rate is relied on by the spectral
# analysis later).
@time sol = solve(prob, DPRKN6(), saveat=0.1, progress=true, progress_steps=200)

tgrid = sol.t


xgrid = range(xleft, xright, length=N)
ygrid = range(yleft, yright, length=N)


# `SecondOrderODEProblem` stores its state as `ArrayPartition(du, u)`: component 1
# is the velocity `du` and component 2 is the displacement `u` we want here.
u_data = [reshape(sol.u[k].x[2], N, N) for k in eachindex(sol.t)]

# u_data[k] is indexed (x, y): the FD stencil uses the first index for x.
# Plots' `heatmap(x, y, z)` expects `z[iy, ix]`, so transpose before plotting.
# NOTE(review): clims may need re-tuning now that u (not du) is shown.
anim = @animate for (i, t) in enumerate(tgrid)
    heatmap(xgrid, ygrid, permutedims(u_data[i]),
            color=:magma,
            xlabel="x", ylabel="y",
            title="u(x,y,t) at t=$t s",
            clims=(-7.5, 7.5),
            aspect_ratio=1)
end

gif(anim, "FD_sol.gif", fps=25)

In [None]:
# Report the element type of the stored FD snapshots.
u_data |> first |> typeof |> println

## Harmonic Balance Solution

In [None]:
import SymPy as sp
import Symbolics as Symb
using DomainSets
import ApproxFun as AF
include("multiharmonic_balance.jl");

In [None]:
"""
    build_problem(x, y, t, omega, harmonics, xleft, xright, yleft, yright,
                  gamma, gamma3; c = 0.5)

Build the symbolic harmonic-balance residual of the forced wave equation with
linear and cubic damping:

    Dt(Dt(u)) - c^2 * (Dx(Dx(u)) + Dy(Dy(u))) + gamma*Dt(u) + gamma3*Dt(u)^3 - F
    F = 250 * exp(-40 x^2) * sin(omega t)

`u` is the ansatz returned by `create_ansatz` (from `multiharmonic_balance.jl`).
The derivative operators are built from the `x`, `y`, `t` passed in.
`c` defaults to 0.5, so `c^2 == 0.25`, the value previously hard-coded; this is
the same wave speed as the finite-difference model's `c`.

`xleft`, `xright`, `yleft`, `yright` are accepted for interface compatibility
but are not used.

Returns `(pde, var_exprs, vars)`; the last two come from `create_ansatz`.
"""
function build_problem(x, y, t, omega, harmonics, xleft, xright, yleft, yright,
                       gamma, gamma3; c = 0.5)
    Dx = Differential(x)
    Dy = Differential(y)
    Dt = Differential(t)

    vars, var_exprs, (u,) = create_ansatz((x, y), t, omega, harmonics)
    F = 250 * exp(-40 * x^2) * sin(omega * t)
    pde::Symbolics.Num = Dt(Dt(u)) - c^2 * (Dx(Dx(u)) + Dy(Dy(u))) + gamma * Dt(u) +
                         gamma3 * Dt(u) * Dt(u) * Dt(u) - F
    return pde, var_exprs, vars
end

"""
    simplify_problem(pde, t, omega, harmonics, Nx, Ny, vars)

Turn the symbolic PDE residual into residual-function code:
1. expand the trigonometric terms in `t` (`expand_trig_jl`);
2. collect the harmonic-balance equations (`make_equations`);
3. rewrite each equation to array code with `transform_sym(Nx, Ny)`;
4. assemble them with `create_residual_function`.

The helpers are defined in `multiharmonic_balance.jl`. The return value is
whatever `create_residual_function` produces; presumably an expression, since
`solve_problem` `eval`s it.
"""
function simplify_problem(pde, t, omega, harmonics, Nx, Ny, vars)
    trig_expanded = expand_trig_jl(pde, t, omega)
    balance_eqns = make_equations(trig_expanded, harmonics, omega, t)
    rewrite = transform_sym(Nx, Ny)
    array_code = map(eqn -> rewrite(Meta.parse(string(eqn))), balance_eqns)
    return create_residual_function(array_code, vars, Nx, Ny)
end

"""
    solve_problem(resid, harmonics, Nx, Ny, u0, stepx, stepy)

`eval` the generated residual code `resid` and solve
`residual!(R, U, [stepx, stepy]) = 0` with Newton–Raphson, starting from `u0`.
The solution vector is then split into `2 * harmonics` coefficient fields, each
reshaped to `(Nx + 1) × (Ny + 1)`.

The Jacobian sparsity pattern is detected once with
`Symbolics.jacobian_sparsity` and passed as `jac_prototype`. Prints the solver's
`retcode`. Returns `(solutions, sol)`.
"""
function solve_problem(resid, harmonics, Nx, Ny, u0, stepx, stepy)
    residual! = eval(resid)
    block_len = (Nx + 1) * (Ny + 1)  # unknowns per coefficient field

    R = similar(u0)

    # `residual!` was created by `eval` during this call, so every call goes
    # through `invokelatest` to avoid world-age errors.
    traced! = (R, U) -> Base.invokelatest(residual!, R, U, [stepx, stepy])
    jac_sparsity = Symbolics.jacobian_sparsity(traced!, R, u0)

    nlfunc = NonlinearFunction((R, U, p) -> Base.invokelatest(residual!, R, U, p);
                               jac_prototype = float.(jac_sparsity))
    nlprob = NonlinearProblem(nlfunc, u0, [stepx, stepy])
    sol = solve(nlprob, NewtonRaphson(); reltol = 1e-5, abstol = 1e-5, maxiters = 2000)

    println(sol.retcode)
    block_range(k) = ((k - 1) * block_len + 1):(k * block_len)
    solutions = [reshape(sol.u[block_range(k)], Nx + 1, Ny + 1) for k in 1:(2 * harmonics)]
    return solutions, sol
end

In [None]:
# Run the full harmonic-balance pipeline (build → simplify → solve) and time it.
# This is a top-level `begin` block, so its assignments are globals: `solutions`
# is used by later cells, and `sol` here rebinds the global that previously held
# the FD ODE solution.
compute_time = @elapsed begin
    pde, var_exprs, vars = build_problem(x, y, t, omega, harmonics, xleft, xright, yleft, yright, gamma, gamma_3)
    println(harmonics)
    println(var_exprs)
    resid = simplify_problem(pde, t, omega, harmonics, Nx, Ny, vars)
    solutions, sol = solve_problem(resid, harmonics, Nx, Ny, u0_HB, stepx, stepy)
end;
println("Compute time: $compute_time seconds")

In [None]:
## Analysis

In [None]:
# Sampling rate of the FD output: `saveat = 0.1` s means 10 samples per second.
fs = 1/0.1
fs_time = fs
# Number of leading FD samples intended as start-up transient (500 × 0.1 s = 50 s).
t_steady0 = 500
# Frequency in Hz of the third harmonic of the drive, 3ω / 2π.
target = 3 * omega / (2 * pi)

"""
    generateComplexAmplitudeFDSpectrumMatrix(sol, target; fs = fs_time, t_steady = t_steady0)

Per-grid-point magnitude of the discrete Fourier component nearest `target` Hz,
taken from a time series of equally sized snapshot matrices `sol`.

Details:
- The first `t_steady` samples are discarded as transient.
- On the remaining `M` samples, the DFT bin whose frequency (same layout as
  `fftfreq(M, fs)`) is nearest `target` is evaluated.
- The result is `abs(X_k) / M`. For an on-bin sinusoid of amplitude `A` this
  equals `A / 2`.

Only that single bin is computed (O(M) per point), rather than a full FFT.

Throws `ArgumentError` unless `0 <= t_steady < length(sol)`.
"""
function generateComplexAmplitudeFDSpectrumMatrix(sol, target; fs = fs_time,
                                                  t_steady = t_steady0)
    nsamples = length(sol)
    0 <= t_steady < nsamples || throw(ArgumentError(
        "t_steady must lie in 0:$(nsamples - 1), got $t_steady"))
    window = (t_steady + 1):nsamples
    M = length(window)

    # Bin frequencies in fftfreq order: 0, 1, …, then the negative bins.
    freqs = [(k <= (M - 1) ÷ 2 ? k : k - M) * fs / M for k in 0:(M - 1)]
    kbin = argmin(abs.(freqs .- target)) - 1
    # e^{-2πi k n / M}; reduce k*n mod M so the phase argument stays small.
    twiddle = [cis(-2π * mod(kbin * n, M) / M) for n in 0:(M - 1)]

    acc = zeros(ComplexF64, size(first(sol)))
    for (n, s) in enumerate(window)
        acc .+= sol[s] .* twiddle[n]
    end
    return abs.(acc) ./ M
end

using FFTW
# Magnitude of the FD response at `target` Hz for every grid point.
fd_complex_spectrum = generateComplexAmplitudeFDSpectrumMatrix(u_data, target)
# Plotted as a raw matrix (no axis coordinates): rows (the first, x, index) run
# along the vertical axis.
heatmap(fd_complex_spectrum)

In [None]:
# Harmonic-balance coefficient fields (two per harmonic) from the HB solve.
coefficientsHB = solutions
coefficientsHB |> first |> size |> println
"""
    generateComplexAmplitudeHBSpectrumMatrix(h, coeffs = coefficientsHB)

Per-grid-point magnitude of the complex amplitude of harmonic `h`, computed
from the harmonic-balance coefficient fields `coeffs`.

Harmonic `h` uses the two fields `coeffs[2h - 1]` and `coeffs[2h]`. The result
is `abs(0.5 * coeffs[2h] - im * 0.5 * coeffs[2h - 1])` elementwise, i.e. half of
`hypot` of the two coefficients. Its size is that of the coefficient fields.

Throws `ArgumentError` if `h` is not in `1:length(coeffs) ÷ 2`.
"""
function generateComplexAmplitudeHBSpectrumMatrix(h, coeffs = coefficientsHB)
    nharm = length(coeffs) ÷ 2
    1 <= h <= nharm || throw(ArgumentError("h must be in 1:$nharm, got $h"))
    a = coeffs[2 * h]
    b = coeffs[2 * h - 1]
    #f = h * omega / 2π
    return abs.(0.5 .* a .- im .* 0.5 .* b)
end
# Harmonic 3 of ω, i.e. the same frequency as the FD `target` (3ω / 2π).
hb_complex_spectrum = generateComplexAmplitudeHBSpectrumMatrix(3)
heatmap(hb_complex_spectrum)