In [None]:
import ModelingToolkit as Model
import SymPy as sp
import Symbolics as Symb
using DomainSets
import ApproxFun as AF
using NonlinearSolve
import DifferentialEquations as DE
include("multiharmonic_balance.jl");
using MethodOfLines

In [None]:
# Angular frequency of the forcing term, which build_problem writes as sin(omega*t).
omega = 5.0;

# Coefficients of the Dt(u) and Dt(u)^3 terms in the PDE residual of build_problem.
gamma = 0;
gamma3 = 0.0;

# Physical constants; neither is referenced in the visible part of this notebook.
g0::Float64 = 9.80665;  # m / s^2
height = 5.0;           # m

In [None]:
# Bounds of the rectangular domain in x and y.
xleft::Float64 = 0.0;
xright::Float64 = 1.0;
yleft = 0.0;
yright = 1.0;

# Discretisation sizes.
Nt = 5;
Nx = 10;
Ny = Nx;
harmonics = 1;  # number of harmonics
order = 2;

# Grid spacing: domain length divided by the number of cells along each axis.
stepx = (xright - xleft) / Nx;
stepy = (yright - yleft) / Ny;

# Small random start vector with 2 * harmonics entries per (Nx, Ny) grid cell.
# NOTE(review): presumably one sine and one cosine coefficient per harmonic —
# confirm against solve_harmonicbalance.
u0 = 0.01 .* randn(2 * harmonics * Nx * Ny);

In [None]:
# Symbolic independent variables used throughout the notebook.
Model.@parameters x, y, t;

# Differential operators with respect to each of those variables.
const Dx = Model.Differential(x);
const Dy = Model.Differential(y);
const Dt = Model.Differential(t);

In [None]:
"""
    build_problem(x, y, t, omega, harmonics, xleft, xright, yleft, yright, gamma, gamma3)

Assemble the symbolic residual of the forced PDE together with its boundary conditions.

The residual is

    Dt(Dt(u)) - 9*(Dx(Dx(u)) + Dy(Dy(u))) + gamma*Dt(u) + gamma3*Dt(u)^3 - F

with the forcing `F = 50*exp(-40*x^2)*sin(omega*t)`. The field `u` comes from
`create_ansatz((x, y), t, omega, harmonics)`.

# Arguments
- `x`, `y`, `t`: symbolic independent variables; the derivatives are taken with respect
  to exactly these symbols.
- `omega`: angular frequency used by the ansatz and the forcing term.
- `harmonics`: number of harmonics forwarded to `create_ansatz`.
- `xleft`, `xright`, `yleft`, `yright`: domain bounds forwarded to `create_bcs`.
- `gamma`, `gamma3`: coefficients of the `Dt(u)` and `Dt(u)^3` terms.

# Returns
The tuple `(pde, bcs, var_exprs)`: the residual, the result of `create_bcs`, and the
second value returned by `create_ansatz`.
"""
function build_problem(x, y, t, omega, harmonics, xleft, xright, yleft, yright, gamma, gamma3)
    # Build the operators from the arguments instead of relying on global Dx/Dy/Dt, so
    # the residual is always differentiated with respect to the symbols passed in.
    ddx = Model.Differential(x)
    ddy = Model.Differential(y)
    ddt = Model.Differential(t)

    vars, var_exprs, u = create_ansatz((x, y), t, omega, harmonics)
    # NOTE(review): the trailing 0.0 is presumably the prescribed boundary value —
    # confirm against create_bcs.
    bcs = create_bcs(vars, ((xleft, xright), (yleft, yright)), (x, y), 0.0)

    forcing = 50 * exp(-40 * (x^2)) * sin(omega * t)
    ut = ddt(u)
    laplacian = ddx(ddx(u)) + ddy(ddy(u))

    # `Symb` is the alias under which this file imports Symbolics.
    pde::Symb.Num = ddt(ut) - 9 * laplacian + gamma * ut + gamma3 * ut * ut * ut - forcing
    return pde, bcs, var_exprs
end

In [None]:
# Build the symbolic residual, boundary conditions and ansatz expressions from the
# notebook-level parameters. @time prints the elapsed time and allocations of the call.
@time begin # check type stability, 56% recompilation
pde, bcs, var_exprs = build_problem(x, y, t, omega, harmonics, xleft, xright, yleft, yright, gamma, gamma3);
end;

In [None]:
# Pass the residual through expand_trig_jl (defined outside this notebook) and time it.
# NOTE(review): presumably this expands the trigonometric terms in t at frequency
# omega — confirm against multiharmonic_balance.jl.
@time begin
expanded = expand_trig_jl(pde, t, omega); # check type stability 14% recompilation
end;

In [None]:
# Turn the expanded residual into the equation set `eqns` via make_residual (defined
# outside this notebook) and time the call.
# NOTE(review): presumably one equation per harmonic coefficient — confirm.
@time begin
eqns = make_residual(expanded, harmonics, omega, t);
end;

In [None]:
# Display the first generated equation as a quick sanity check.
eqns[1]

In [None]:
using MacroTools: prewalk
using Symbolics: Differential

# Name of the grid field behind `operand`: `A` for a call such as `A(x, y)` or for the
# bare symbol `A`. Returns `nothing` for anything that cannot be indexed as `A[i, j]`
# (operator calls, nested derivatives, compound expressions).
function _symino_fieldname(operand)
    operand isa Symbol && return operand
    if operand isa Expr && operand.head == :call && !isempty(operand.args)
        head = operand.args[1]
        head isa Symbol && !Base.isoperator(head) && return head
    end
    return nothing
end

# Central-difference replacement for `D(operand)`, where `D` differentiates along grid
# axis `axis` (1 shifts index `i`, 2 shifts index `j`) with spacing `h`. A directly nested
# derivative in the same variable is collapsed into the 3-point second-derivative stencil.
# Returns `nothing` when the differentiated operand is not a plain field.
function _symino_stencil(D, operand, axis::Int, h)
    order = 1
    if operand isa Expr && operand.head == :call && length(operand.args) >= 2
        inner = operand.args[1]
        if inner isa Differential && Symbol(inner.x) == Symbol(D.x)
            order = 2
            operand = operand.args[2]
        end
    end

    u = _symino_fieldname(operand)
    u === nothing && return nothing

    up, down = axis == 1 ? (:($u[i+1, j]), :($u[i-1, j])) : (:($u[i, j+1]), :($u[i, j-1]))
    return order == 2 ? :(($up - 2*$u[i, j] + $down) / $h^2) : :(($up - $down) / (2*$h))
end

"""
    transform_symino(ex, var1::Symbol, var2::Symbol, dx, dy)

Rewrite the expression `ex` so that derivatives of fields become central finite-difference
stencils on a grid indexed by the fixed index names `i` (for `var1`) and `j` (for `var2`).

Rewrites applied during a `prewalk` over `ex`:
- `D(D(A(var1, var2)))` with both `D` in the same variable becomes
  `(A[i+1, j] - 2*A[i, j] + A[i-1, j]) / dx^2` (or the `j`/`dy` analogue).
- `D(A(var1, var2))` becomes `(A[i+1, j] - A[i-1, j]) / (2*dx)` (or the `j`/`dy` analogue).
- `A(var1, var2)` becomes `A[i, j]` and `A(var2, var1)` becomes `A[j, i]`.

Derivatives whose operand is not a plain field (mixed derivatives, products, ...) are left
in place instead of being turned into invalid indexing; their sub-expressions are still
visited. Calls whose head is an operator symbol are never treated as fields.

# Arguments
- `ex`: expression tree to transform; derivative calls must carry a `Differential` object
  as their callee.
- `var1`, `var2`: names of the two grid coordinates.
- `dx`, `dy`: grid spacings interpolated into the generated stencils.

# Returns
The transformed expression.
"""
function transform_symino(ex, var1::Symbol, var2::Symbol, dx, dy)
    return prewalk(ex) do node
        (node isa Expr && node.head == :call) || return node
        f = node.args[1]

        if f isa Differential && length(node.args) >= 2
            wrt = Symbol(f.x)
            stencil = if wrt == var1
                _symino_stencil(f, node.args[2], 1, dx)
            elseif wrt == var2
                _symino_stencil(f, node.args[2], 2, dy)
            else
                nothing
            end
            # Unsupported derivatives stay as they are so the output never contains
            # doubly indexed terms such as A[i, j][i+1, j].
            return stencil === nothing ? node : stencil
        end

        # Convert field calls like A(x, y) -> A[i, j]; skip operator calls such as x + y.
        if f isa Symbol && !Base.isoperator(f) && length(node.args) == 3
            arg1, arg2 = node.args[2], node.args[3]
            if arg1 == var1 && arg2 == var2
                return :($f[i, j])
            elseif arg1 == var2 && arg2 == var1
                return :($f[j, i])
            end
        end

        return node
    end
end

In [None]:
# Convert the symbolic equations into plain Julia expressions. `Symb` is the alias under
# which this notebook imports Symbolics; the bare module name is not bound by any import
# in this file.
expr = Symb.toexpr(eqns)

In [None]:
using Latexify
# Discretise the first equation on the (x, y) grid and render the result as LaTeX.
transform_symino(expr[1], :x, :y, stepx, stepy) |> latexify

In [None]:
# Hand the harmonic-balance equations to the solver defined outside this notebook.
# NOTE(review): u0 is presumably the initial guess for the unknown coefficients — confirm
# against solve_harmonicbalance.
solution_coeffs = solve_harmonicbalance(
    eqns,
    harmonics,
    bcs,
    ((xleft, xright), (yleft, yright)),  # domain bounds for x and y
    (stepx, stepy),                      # grid spacing for x and y
    (x, y),                              # coordinate symbols
    var_exprs,
    u0,
)