In [28]:
using Statistics: mean

import FFTW

using Interact
using PyPlot

In [52]:
# Integer wavenumber indices in FFT output order for an N-point transform:
# 0, 1, …, N ÷ 2, followed by the negative indices up to -1. Works for both odd
# and even N (for even N the Nyquist index N/2 is listed as positive, which is
# harmless here because only k² is used).
_fft_wavenumber_indices(N::Integer) = vcat(0:(N ÷ 2), (N ÷ 2 + 1 - N):-1)

"""
    solve_poisson_3d_pbc(f, Lx, Ly, Lz)

Solve the 3D Poisson equation ∇²ϕ = f on a triply periodic box of size
`Lx × Ly × Lz` using a pseudo-spectral (FFT) method.

`f` is a real 3D array of source values sampled on a uniform grid that excludes
the periodic end point, so `size(f) == (Nx, Ny, Nz)` with grid spacings
`Lx/Nx`, `Ly/Ny`, `Lz/Nz`. A periodic Poisson problem is only solvable when
`f` has zero mean. The mean (k = 0) Fourier mode of `f` is ignored, and the
returned solution is the one with zero mean.

Returns a real array `ϕ` with the same size as `f`.

Throws `ArgumentError` if `f` is not 3-dimensional.
"""
function solve_poisson_3d_pbc(f, Lx, Ly, Lz)
    ndims(f) == 3 ||
        throw(ArgumentError("f must be a 3D array, got ndims(f) = $(ndims(f))"))
    Nx, Ny, Nz = size(f)  # Number of grid points (excluding the periodic end point).

    # Forward transform of the real-valued source term. rfft keeps only the
    # non-negative x-wavenumbers, so size(fh) == (Nx ÷ 2 + 1, Ny, Nz).
    fh = FFTW.rfft(f)

    # Wavenumbers: kx matches the rfft half-spectrum; ky and kz span the full spectrum.
    kx = reshape((2π / Lx) .* (0:(Nx ÷ 2)), :, 1, 1)
    ky = reshape((2π / Ly) .* _fft_wavenumber_indices(Ny), 1, :, 1)
    kz = reshape((2π / Lz) .* _fft_wavenumber_indices(Nz), 1, 1, :)

    k² = @. kx^2 + ky^2 + kz^2

    # ∇² becomes multiplication by -k² in Fourier space, so ϕh = -fh / k².
    # k² vanishes at the mean mode. Use a dummy value there to avoid 0/0,
    # then set the mean of the solution to zero explicitly.
    k²[1, 1, 1] = one(eltype(k²))
    ϕh = @. -fh / k²
    ϕh[1, 1, 1] = 0

    # Inverse real transform. The original x-length Nx must be given, because
    # it cannot be recovered from the half-spectrum length alone.
    return FFTW.irfft(ϕh, Nx)
end

solve_poisson_3d_pbc (generic function with 1 method)

In [51]:
Lx, Ly, Lz = 8, 8, 8  # Domain size.
Nx, Ny, Nz = 64, 64, 64    # Number of grid points.
Δx, Δy, Δz = Lx/Nx, Ly/Ny, Lz/Nz  # Grid spacing.

# Grid point locations. The periodic end point (x = Lx, etc.) is excluded
# because on a periodic domain it duplicates the point at 0.
x = Δx * (0:(Nx-1))
y = Δy * (0:(Ny-1))
z = Δz * (0:(Nz-1))

# Primed coordinates, each reshaped along its own axis so that broadcasting
# builds full 3D fields. They are used to center the Gaussian at
# (Lx/2, Ly/2, Lz/2).
x′ = reshape(x .- Lx/2, (Nx, 1, 1))
y′ = reshape(y .- Ly/2, (1, Ny, 1))
z′ = reshape(z .- Lz/2, (1, 1, Nz))

# Source term f = ∇²ϕa for the Gaussian ϕa below: in 3D,
# ∇² exp(-r²) = (4r² - 6) exp(-r²).
f = @. (4*x′^2 + 4*y′^2 + 4*z′^2 - 6) * exp(-(x′^2 + y′^2 + z′^2))

# A periodic Poisson problem is only solvable if the source integrates to zero.
# The truncated, discretized source has a small residual mean, so remove it.
f .-= mean(f)

ϕa = @. exp(-(x′^2 + y′^2 + z′^2))  # Analytic solution (defined up to an additive constant).

ϕs = solve_poisson_3d_pbc(f, Lx, Ly, Lz);

size(fh) = (33, 64, 64)


In [46]:
# Interactive slider over z-slices of the source term f.
# pcolormesh(x, y, C) expects C[j, i] to be the value at (x[i], y[j]), i.e. rows
# indexed by y. The slice f[:, :, n] is indexed [x, y], so it is transposed.
fmin, fmax = extrema(f)  # One fixed color scale shared by all slices.
fig = figure()
@manipulate for n in 1:Nz
    withfig(fig) do
        PyPlot.pcolormesh(x, y, permutedims(f[:, :, n]), vmin=fmin, vmax=fmax)
        PyPlot.colorbar()
        PyPlot.title("f at z = $(z[n])")
    end
end

In [56]:
# The periodic solution is only defined up to an additive constant (the solver
# returns the zero-mean one). Shift it so its minimum is 0. This aligns it with
# the analytic Gaussian ϕa, which decays to ≈ 0 at the box edges.
ϕs = ϕs .- minimum(ϕs)
@show minimum(ϕs)
@show maximum(ϕs)
@show maximum(abs, ϕs .- ϕa)  # Max pointwise error vs. the analytic solution.

# Interactive slider over z-slices of the shifted numerical solution.
# The slice ϕs[:, :, n] is indexed [x, y]; pcolormesh wants rows indexed by y.
fig = figure()
@manipulate for n in 1:Nz
    withfig(fig) do
        PyPlot.pcolormesh(x, y, permutedims(ϕs[:, :, n]), vmin=0, vmax=1)
        PyPlot.colorbar()
    end
end

minimum(ϕs) = 0.0
maximum(ϕs) = 1.0000000622106684
