In [None]:
using LinearAlgebra
using SparseArrays

In [None]:
"""
    FDMesh2D

Uniform rectangular finite-difference mesh on `[xl, xr] × [yl, yr]` with a staggered
layout: unknowns live on the grid nodes and on the midpoints of the horizontal and
vertical cell edges. The three index sets are stored as `CartesianIndices`, which the
assembly routines iterate in column-major order (x varies fastest).
"""
struct FDMesh2D
    xl::Float64  # left end of the x interval
    xr::Float64  # right end of the x interval
    yl::Float64  # lower end of the y interval
    yr::Float64  # upper end of the y interval
    dx::Float64  # cell width in x
    dy::Float64  # cell height in y
    nx::Int64    # number of cells along x
    ny::Int64    # number of cells along y
    # Grid nodes; (nx + 1) × (ny + 1) as built by the 6-argument constructor.
    nodes_idxs::CartesianIndices{2}
    # Midpoints of horizontal edges (shifted half a cell along x); nx × (ny + 1).
    midpoints_idxs_x::CartesianIndices{2}
    # Midpoints of vertical edges (shifted half a cell along y); (nx + 1) × ny.
    midpoints_idxs_y::CartesianIndices{2}
end

In [None]:
"""
    FDMesh2D(xl, xr, nx, yl, yr, ny)

Build a uniform mesh of `[xl, xr] × [yl, yr]` with `nx` cells along x and `ny` cells
along y.

The cell sizes are `dx = (xr - xl) / nx` and `dy = (yr - yl) / ny`. The index sets are
`(nx + 1) × (ny + 1)` for the grid nodes, `nx × (ny + 1)` for the horizontal-edge
midpoints and `(nx + 1) × ny` for the vertical-edge midpoints.

Bounds may be any `Real` and cell counts any `Integer`; they are converted to the
`Float64`/`Int64` field types.

Throws `ArgumentError` if `nx` or `ny` is not positive (the spacing would be infinite
or meaningless).
"""
function FDMesh2D(xl::Real, xr::Real, nx::Integer, yl::Real, yr::Real, ny::Integer)
    nx > 0 || throw(ArgumentError("nx must be positive, got $nx"))
    ny > 0 || throw(ArgumentError("ny must be positive, got $ny"))

    # Convert counts up front so the CartesianIndices (and the x, y values the
    # index helpers receive via `Tuple(idx)`) are Int64.
    cells_x, cells_y = Int64(nx), Int64(ny)
    hx = Float64((xr - xl) / cells_x)
    hy = Float64((yr - yl) / cells_y)

    nodes_idxs = CartesianIndices((cells_x + 1, cells_y + 1))
    midpoints_idxs_x = CartesianIndices((cells_x, cells_y + 1))
    midpoints_idxs_y = CartesianIndices((cells_x + 1, cells_y))
    return FDMesh2D(
        Float64(xl), Float64(xr), Float64(yl), Float64(yr), hx, hy,
        cells_x, cells_y, nodes_idxs, midpoints_idxs_x, midpoints_idxs_y,
    )
end

In [None]:
"""
    gridPointToLeftEdgeNeighbor(mesh, x, y)

Return the linear index, within the u block, of the horizontal-edge midpoint to the
left of node `(x, y)`: midpoint `(x - 1, y)` in the column-major `nx × (ny + 1)` layout.
Only meaningful for `x ≥ 2`.
"""
gridPointToLeftEdgeNeighbor(mesh::FDMesh2D, x::Int, y::Int) = (x - 1) + mesh.nx * (y - 1)

"""
    gridPointToRightEdgeNeighbor(mesh, x, y)

Return the linear index, within the u block, of the horizontal-edge midpoint to the
right of node `(x, y)`: midpoint `(x, y)` in the column-major `nx × (ny + 1)` layout.
Only meaningful for `x ≤ nx`.
"""
gridPointToRightEdgeNeighbor(mesh::FDMesh2D, x::Int, y::Int) = x + mesh.nx * (y - 1)

"""
    gridPointToTopEdgeNeighbor(mesh, x, y)

Return the linear index, within the v block, of the vertical-edge midpoint above node
`(x, y)`: midpoint `(x, y)` in the column-major `(nx + 1) × ny` layout.
Only meaningful for `y ≤ ny`.
"""
gridPointToTopEdgeNeighbor(mesh::FDMesh2D, x::Int, y::Int) = x + (y - 1) * (mesh.nx + 1)

"""
    gridPointToBottomEdgeNeighbor(mesh, x, y)

Return the linear index, within the v block, of the vertical-edge midpoint below node
`(x, y)`: midpoint `(x, y - 1)` in the column-major `(nx + 1) × ny` layout.
Only meaningful for `y ≥ 2`.
"""
gridPointToBottomEdgeNeighbor(mesh::FDMesh2D, x::Int, y::Int) = x + (y - 2) * (mesh.nx + 1)

In [None]:
"""
    generateEtaMatrix(mesh::FDMesh2D)

Assemble the rows of the semi-discrete operator for the η unknowns (grid nodes).

The result has one row per node of `mesh.nodes_idxs`, in its iteration order. It has
one column per entry of the full state vector, laid out as `[η; u; v]` with
`(nx + 1) * (ny + 1)` η entries, `nx * (ny + 1)` u entries and `(nx + 1) * ny` v
entries. Row `i` holds the discrete divergence

    (u_right - u_left) / dx + (v_top - v_bottom) / dy

around node `i`. A velocity neighbour that would fall outside the mesh is dropped:
right of the last column, above the last row, or below the first row. Rows of the
left-boundary nodes (`x == 1`, both left corners included) are left empty.
"""
function generateEtaMatrix(mesh::FDMesh2D)
    nx, ny = mesh.nx, mesh.ny
    n_eta = (nx + 1) * (ny + 1)
    n_u = nx * (ny + 1)
    n_v = (nx + 1) * ny
    u_start = n_eta        # the u block follows the η block
    v_start = n_eta + n_u  # the v block follows the u block

    # Collect (row, column, value) triplets and build the matrix with one `sparse`
    # call. Inserting entries one at a time into a SparseMatrixCSC shifts its
    # internal arrays on every new entry, which is quadratic in nnz.
    rows = Int[]
    cols = Int[]
    vals = Float64[]
    sizehint!(rows, 4 * n_eta)
    sizehint!(cols, 4 * n_eta)
    sizehint!(vals, 4 * n_eta)
    add_entry!(i, j, v) = (push!(rows, i); push!(cols, j); push!(vals, v); nothing)

    for (i, idx) in enumerate(mesh.nodes_idxs)
        x, y = Tuple(idx)
        # Left boundary: the row stays empty.
        x == 1 && continue

        # Every remaining node has a u neighbour on its left.
        add_entry!(i, u_start + gridPointToLeftEdgeNeighbor(mesh, x, y), -1 / mesh.dx)
        # No u neighbour to the right of the last column (x == nx + 1).
        if x <= nx
            add_entry!(i, u_start + gridPointToRightEdgeNeighbor(mesh, x, y), 1 / mesh.dx)
        end
        # No v neighbour above the top row (y == ny + 1).
        if y <= ny
            add_entry!(i, v_start + gridPointToTopEdgeNeighbor(mesh, x, y), 1 / mesh.dy)
        end
        # No v neighbour below the bottom row (y == 1).
        if y >= 2
            add_entry!(i, v_start + gridPointToBottomEdgeNeighbor(mesh, x, y), -1 / mesh.dy)
        end
    end
    return sparse(rows, cols, vals, n_eta, n_eta + n_u + n_v)
end

In [None]:
"""
    horizontalEdgePointLeftGridPointNeighbor(mesh, x, y)

Return the linear η index of the grid node at the left end of horizontal-edge midpoint
`(x, y)`: node `(x, y)` in the column-major `(nx + 1) × (ny + 1)` node layout.
"""
horizontalEdgePointLeftGridPointNeighbor(mesh::FDMesh2D, x::Int, y::Int) =
    x + (y - 1) * (mesh.nx + 1)

"""
    horizontalEdgePointRightGridPointNeighbor(mesh, x, y)

Return the linear η index of the grid node at the right end of horizontal-edge midpoint
`(x, y)`: node `(x + 1, y)` in the column-major `(nx + 1) × (ny + 1)` node layout.
"""
horizontalEdgePointRightGridPointNeighbor(mesh::FDMesh2D, x::Int, y::Int) =
    (x + 1) + (y - 1) * (mesh.nx + 1)

In [None]:
"""
    generateUMatrix(mesh::FDMesh2D)

Assemble the rows of the semi-discrete operator for the u unknowns (horizontal-edge
midpoints, in the iteration order of `mesh.midpoints_idxs_x`).

Row `i` holds the difference `(η_right - η_left) / dx` of the two grid nodes bounding
the i-th horizontal edge. Columns span the full state vector `[η; u; v]`; only η
columns receive entries.
"""
function generateUMatrix(mesh::FDMesh2D)
    n_eta = (mesh.nx + 1) * (mesh.ny + 1)
    n_u = mesh.nx * (mesh.ny + 1)
    n_v = (mesh.nx + 1) * mesh.ny
    eta_start = 0  # η occupies the first n_eta columns of the state vector

    # Two entries per row, gathered as triplets and assembled in one `sparse` call
    # instead of being inserted one at a time into a SparseMatrixCSC.
    rows = Int[]
    cols = Int[]
    vals = Float64[]
    sizehint!(rows, 2 * n_u)
    sizehint!(cols, 2 * n_u)
    sizehint!(vals, 2 * n_u)

    for (i, idx) in enumerate(mesh.midpoints_idxs_x)
        x, y = Tuple(idx)
        left = eta_start + horizontalEdgePointLeftGridPointNeighbor(mesh, x, y)
        right = eta_start + horizontalEdgePointRightGridPointNeighbor(mesh, x, y)
        append!(rows, (i, i))
        append!(cols, (left, right))
        append!(vals, (-1 / mesh.dx, 1 / mesh.dx))
    end
    return sparse(rows, cols, vals, n_u, n_eta + n_u + n_v)
end

In [None]:
"""
    verticalEdgePointTopGridPointNeighbor(mesh, x, y)

Return the linear η index of the grid node at the upper end of vertical-edge midpoint
`(x, y)`: node `(x, y + 1)` in the column-major `(nx + 1) × (ny + 1)` node layout.
"""
verticalEdgePointTopGridPointNeighbor(mesh::FDMesh2D, x::Int, y::Int) =
    x + y * (mesh.nx + 1)

"""
    verticalEdgePointBottomGridPointNeighbor(mesh, x, y)

Return the linear η index of the grid node at the lower end of vertical-edge midpoint
`(x, y)`: node `(x, y)` in the column-major `(nx + 1) × (ny + 1)` node layout.
"""
verticalEdgePointBottomGridPointNeighbor(mesh::FDMesh2D, x::Int, y::Int) =
    x + (y - 1) * (mesh.nx + 1)

In [None]:
"""
    generateVMatrix(mesh::FDMesh2D)

Assemble the rows of the semi-discrete operator for the v unknowns (vertical-edge
midpoints, in the iteration order of `mesh.midpoints_idxs_y`).

Row `i` holds the difference `(η_top - η_bottom) / dy` of the two grid nodes bounding
the i-th vertical edge. Columns span the full state vector `[η; u; v]`; only η columns
receive entries.
"""
function generateVMatrix(mesh::FDMesh2D)
    n_eta = (mesh.nx + 1) * (mesh.ny + 1)
    n_u = mesh.nx * (mesh.ny + 1)
    n_v = (mesh.nx + 1) * mesh.ny
    eta_start = 0  # η occupies the first n_eta columns of the state vector

    # Two entries per row, gathered as triplets and assembled in one `sparse` call
    # instead of being inserted one at a time into a SparseMatrixCSC.
    rows = Int[]
    cols = Int[]
    vals = Float64[]
    sizehint!(rows, 2 * n_v)
    sizehint!(cols, 2 * n_v)
    sizehint!(vals, 2 * n_v)

    for (i, idx) in enumerate(mesh.midpoints_idxs_y)
        x, y = Tuple(idx)
        top = eta_start + verticalEdgePointTopGridPointNeighbor(mesh, x, y)
        bottom = eta_start + verticalEdgePointBottomGridPointNeighbor(mesh, x, y)
        # The two nodes are one cell apart in y, so this is a y-derivative and is
        # divided by dy (not dx).
        append!(rows, (i, i))
        append!(cols, (top, bottom))
        append!(vals, (1 / mesh.dy, -1 / mesh.dy))
    end
    return sparse(rows, cols, vals, n_v, n_eta + n_u + n_v)
end

In [None]:
"""
    generateFullMatrix(mesh::FDMesh2D, H::Float64, g::Float64)

Stack the η, u and v operator rows into the square system matrix for the state vector
`[η; u; v]`. The η rows are scaled by `-H`; the u and v rows are scaled by `-g`.
"""
function generateFullMatrix(mesh::FDMesh2D, H::Float64, g::Float64)
    eta_rows = generateEtaMatrix(mesh)
    u_rows = generateUMatrix(mesh)
    v_rows = generateVMatrix(mesh)
    return vcat((-H) .* eta_rows, (-g) .* u_rows, (-g) .* v_rows)
end

In [None]:
"""
    SWEEquationsLinear!(dw, w, p, t)

In-place right-hand side of the semi-discrete linear system, in the `f!(du, u, p, t)`
form passed to `ODEProblem`.

`p` is the tuple `(mesh, A, ω, F_zero)`. `dw` is set to `A * w`. The rows of the
left-boundary nodes `(1, y)` with `ny/3 ≤ y ≤ 2ny/3` are then overwritten with
`ω * F_zero * cos(ω * t)`, the time derivative of `F_zero * sin(ω * t)`. This drives
that boundary strip with a prescribed oscillation. Returns `nothing`.
"""
function SWEEquationsLinear!(dw, w, p, t)
    mesh, A, ω, F_zero = p
    # In-place product: this is evaluated at every solver stage, so avoid allocating
    # a temporary vector for `A * w`.
    mul!(dw, A, w)

    # Only nodes with x == 1 can be forced, so visit exactly those rows instead of
    # scanning every node. Node (1, y) has linear index 1 + (nx + 1) * (y - 1)
    # (x varies fastest, as in `mesh.nodes_idxs`). The integer bounds are equivalent
    # to ny/3 ≤ y ≤ 2ny/3, clipped to the valid rows 1:ny+1.
    forcing = ω * F_zero * cos(ω * t)
    y_first = max(1, cld(mesh.ny, 3))
    y_last = min(mesh.ny + 1, fld(2 * mesh.ny, 3))
    for y in y_first:y_last
        dw[1 + (mesh.nx + 1) * (y - 1)] = forcing
    end
    return nothing
end

In [None]:
# Square domain [0, 20] × [0, 20] split into 75 × 75 cells.
nx, ny = 75, 75
xl, xr = 0.0, 20.0
yl, yr = 0.0, 20.0
mesh = FDMesh2D(xl, xr, nx, yl, yr, ny)

# Model parameters: H scales the divergence term in the η equation, g scales the
# elevation-gradient terms in the u and v equations (see `generateFullMatrix`).
H, g = 2.0, 9.81

In [None]:
# Full semi-discrete operator: d/dt [η; u; v] = A * [η; u; v], before the left-boundary
# forcing that `SWEEquationsLinear!` applies on top of it.
A = generateFullMatrix(mesh, H, g)

In [None]:
using DifferentialEquations

In [None]:
using Plots

In [None]:
# Boundary forcing: angular frequency ω and amplitude F_zero of the oscillation
# imposed on the left-boundary strip by `SWEEquationsLinear!`.
ω = 1.0 * pi
F_zero = 0.5
p = (mesh, A, ω, F_zero)

# Start from rest: every η, u and v entry is zero.
u0 = zeros(Float64,
    (mesh.nx + 1) * (mesh.ny + 1) + mesh.nx * (mesh.ny + 1) + (mesh.nx + 1) * mesh.ny)

# Integrate with tight tolerances and keep every accepted step; the animation later
# picks frames out of `sol.t`.
tspan = (0.0, 30.0)
prob = ODEProblem(SWEEquationsLinear!, u0, tspan, p)
sol = @time solve(prob, reltol = 1e-8, abstol = 1e-8, save_everystep = true)

In [None]:
t_start, t_end = tspan
framerate = 30
dt = 1 / framerate

# For frame k (k = 1, …, nframes), pick the first saved solver step at or after
# t_start + k * dt. `sol.t` is sorted, so a binary search replaces the manual scan.
# (The manual scan also returned the step *after* that one.)
# `searchsortedfirst` returns length(sol.t) + 1 when the target lies past the last
# saved time, e.g. through round-off on the final frame, so the result is clamped.
nframes = round(Int, (t_end - t_start) * framerate)
indices::Vector{Int64} = [
    min(searchsortedfirst(sol.t, t_start + k * dt), length(sol.t)) for k in 1:nframes
]

In [None]:
# One heatmap of the elevation η per frame. η is the first (nx+1)*(ny+1) entries of
# each state vector, stored x-fastest, so after `reshape` the matrix is eta_mat[x, y].
# `heatmap(x, y, z)` expects z[y, x], hence the `permutedims`; without it the x and y
# axes of the picture are swapped. Axes are labelled in physical coordinates.
n_eta = (mesh.nx + 1) * (mesh.ny + 1)
xs = range(mesh.xl, mesh.xr; length = mesh.nx + 1)
ys = range(mesh.yl, mesh.yr; length = mesh.ny + 1)
anim = @animate for k in 1:length(indices)
    eta = sol.u[indices[k]][1:n_eta]
    eta_mat = reshape(eta, mesh.nx + 1, mesh.ny + 1)
    heatmap(xs, ys, permutedims(eta_mat), clims=(-2.0, 2.0),
        title=string("t = ", round(k * dt, digits=2)))
end

gif(anim, "SWE_2D.gif", fps=framerate)