# In [None]:
## PREPARATIONS BEFORE PARALLEL
const USE_GPU     = false                       # Whether to use GPU
using ParallelStencil
using ParallelStencil.FiniteDifferences2D
@static if USE_GPU
    @init_parallel_stencil(CUDA,    Float64, 2)
else
    @init_parallel_stencil(Threads, Float64, 2)
end
using Images, FileIO, Interpolations, ColorTypes, TestImages, Plots, ImageTransformations
using LinearAlgebra: norm



## CREATE PARALLEL FUNCTIONS TO BE USED IN THE PSEUDO-TRANSIENT LOOP
# Velocity gradients via ParallelStencil's FiniteDifferences2D macros:
# @d_xa / @d_ya difference adjacent entries along x / y, @all(A) writes every entry of A,
# and dividing by dx / dy turns the differences into derivatives.
# With the sizes allocated in SaltInclusion2D (Vx, Vy: (nx+1, ny+1)), dVxdx and dVydx
# are (nx, ny+1) while dVydy and dVxdy are (nx+1, ny).
@parallel function compute_Vgrad!(dVxdx::Data.Array, dVydy::Data.Array, dVxdy::Data.Array, dVydx::Data.Array, Vx::Data.Array, Vy::Data.Array, dx::Data.Number, dy::Data.Number)
    @all(dVxdx)   = @d_xa(Vx)/dx   # dVx/dx
    @all(dVydy)   = @d_ya(Vy)/dy   # dVy/dy
    @all(dVxdy)   = @d_ya(Vx)/dy   # dVx/dy (shear part)
    @all(dVydx)   = @d_xa(Vy)/dx   # dVy/dx (shear part)
    return
end

# Pseudo-transient pressure update on the (nx, ny) cell grid.
# divV: velocity divergence; dVxdx is averaged along y and dVydy along x so both
#       terms land on the same points.
# P:    relaxed towards incompressibility, P <- P - r*Gdτ*divV, where r is the
#       bulk-to-shear modulus ratio and Gdτ is G * pseudo time step (per SaltInclusion2D).
# dx and dy are accepted but not used by this kernel.
@parallel function compute_P!(divV::Data.Array, P::Data.Array, dVxdx::Data.Array, dVydy::Data.Array, Gdτ::Data.Array, r::Data.Number, dx::Data.Number, dy::Data.Number)
    @all(divV)    = @av_ya(dVxdx) + @av_xa(dVydy)
    @all(P)       = @all(P) - r*@all(Gdτ)*@all(divV)
    return
end

# Strain-rate components (labelled "Strain" in SaltInclusion2D) on the (nx, ny) cell grid.
# The velocity gradients from compute_Vgrad! are averaged onto common points, and
# Exy is the symmetric shear part 0.5*(dVx/dy + dVy/dx).
# NOTE(review): Exx/Eyy are used as-is (no divV/2 or divV/3 subtracted), so they are
# not strictly deviatoric while divV != 0 — confirm this is intended.
@parallel function compute_E!(Exx::Data.Array, Eyy::Data.Array, Exy::Data.Array, dVxdx::Data.Array, dVydy::Data.Array, dVxdy::Data.Array, dVydx::Data.Array)
    @all(Exx)     = @av_ya(dVxdx)
    @all(Eyy)     = @av_xa(dVydy)
    @all(Exy)     = 0.5 * ( @av_xa(dVxdy) + @av_ya(dVydx) )
    return
end

# Pseudo-transient stress update, applied to each component independently:
#   τ <- (τ + 2*Gdτ*E) / (Gdτ/etan + 1)
# Its fixed point is τ = 2*etan*E, i.e. the update relaxes τ towards the viscous stress.
@parallel function compute_τ!(τxx::Data.Array, τyy::Data.Array, τxy::Data.Array, Exx::Data.Array, Eyy::Data.Array, Exy::Data.Array, etan::Data.Array, Gdτ::Data.Array)
    @all(τxx)     = ( @all(τxx) + 2.0*@all(Gdτ)*@all(Exx) )  /  ( @all(Gdτ)/@all(etan) + 1.0 )
    @all(τyy)     = ( @all(τyy) + 2.0*@all(Gdτ)*@all(Eyy) )  /  ( @all(Gdτ)/@all(etan) + 1.0 )
    @all(τxy)     = ( @all(τxy) + 2.0*@all(Gdτ)*@all(Exy) )  /  ( @all(Gdτ)/@all(etan) + 1.0 )
    return
end

# Spatial derivatives of the pressure and the stress components via @d_xa / @d_ya.
# With (nx, ny) inputs, the x-derivatives are (nx-1, ny) and the y-derivatives are
# (nx, ny-1), matching the buffers allocated in SaltInclusion2D.
@parallel function compute_Pτgrad!(dPdx::Data.Array, dPdy::Data.Array, dτxxdx::Data.Array, dτyydy::Data.Array, dτxydx::Data.Array, dτxydy::Data.Array, P::Data.Array, τxx::Data.Array, τyy::Data.Array, τxy::Data.Array, dx::Data.Number, dy::Data.Number)
    @all(dPdx)    = @d_xa(P)  /dx
    @all(dPdy)    = @d_ya(P)  /dy
    @all(dτxxdx)  = @d_xa(τxx)/dx
    @all(dτyydy)  = @d_ya(τyy)/dy
    @all(dτxydx)  = @d_xa(τxy)/dx
    @all(dτxydy)  = @d_ya(τxy)/dy
    return
end

# Momentum residuals and the resulting velocity increments on the (nx-1, ny-1) grid.
# Gradients are averaged onto those points with @av_xa / @av_ya; the cell fields
# (rho_Rock, dτ_Rho) are averaged with the 4-point @av.
#   Rx = -dP/dx + PI2*(dτxy/dy + dτxx/dx)
#   Ry = -dP/dy + PI2*(dτxy/dx + dτyy/dy) - PI2*rho_Rock/rho_sc   (body-force term)
#   dV = (pseudo time step / density) * R
# NOTE(review): PI2 scales the stress and body-force terms but not the pressure
# gradient. SaltInclusion2D constructs PI2 to equal 1, so this only matters if that changes.
@parallel function compute_dV!(Rx::Data.Array, Ry::Data.Array, dVx::Data.Array, dVy::Data.Array, dPdx::Data.Array, dPdy::Data.Array, dτxxdx::Data.Array, dτyydy::Data.Array, dτxydx::Data.Array, dτxydy::Data.Array, dτ_Rho::Data.Array, PI2::Data.Number, rho_Rock::Data.Array, rho_sc::Data.Number)
    @all(Rx)      = - @av_ya(dPdx) + PI2 * @av_xa(dτxydy) + PI2 * @av_ya(dτxxdx)
    @all(Ry)      = - @av_xa(dPdy) + PI2 * @av_ya(dτxydx) + PI2 * @av_xa(dτyydy) - PI2 * @av(rho_Rock) / rho_sc
    @all(dVx)     =   @av(dτ_Rho)  * @all(Rx)
    @all(dVy)     =   @av(dτ_Rho)  * @all(Ry)
    return
end

# Explicit pseudo-time step for the velocities: the increments from compute_dV! are
# added to the interior entries (@inn) of Vx and Vy. The outermost rows and columns
# are not touched here.
@parallel function compute_V!(Vx::Data.Array, Vy::Data.Array, dVx::Data.Array, dVy::Data.Array)
    @inn(Vx)      = @inn(Vx) + @all(dVx)
    @inn(Vy)      = @inn(Vy) + @all(dVy)
    return
end

# Average the cell-centred P and etan along x (-> P_x, etan_x) and along y
# (-> P_y, etan_y). TopBottom! consumes the x-averages and LeftRight! consumes the
# y-averages when setting the boundary velocities.
@parallel function boundary_value!(P_x::Data.Array, P_y::Data.Array, etan_x::Data.Array, etan_y::Data.Array, P::Data.Array, etan::Data.Array)
    @all(P_x)     = @av_xa(P)
    @all(P_y)     = @av_ya(P)
    @all(etan_x)  = @av_xa(etan)
    @all(etan_y)  = @av_ya(etan)
    return
end

# Stress boundary condition on the first (iy = 1) and last (iy = end) rows of Vy.
# Each boundary velocity is extrapolated from its inner neighbour using
# dVy/dy = (P - p_applied) / (2*etan). Row 1 carries p_water + p_NVP, which
# presumably makes it the deep boundary; the last row carries only p_water.
# One work item per x-index ix; Vy is offset by one because it has an extra column.
@parallel_indices (ix) function TopBottom!(Vy::Data.Array, dy::Data.Number, p_water::Data.Number, p_NVP::Data.Number, P_x::Data.Array, etan_x::Data.Array)
    ixn = ix + 1

    first_load = -p_water - p_NVP + P_x[ix, 1]
    Vy[ixn, 1] = Vy[ixn, 2] - dy * first_load / etan_x[ix, 1] / 2

    last_load = -p_water + P_x[ix, end]
    Vy[ixn, end] = Vy[ixn, end-1] + dy * last_load / etan_x[ix, end] / 2

    return
end

# Stress boundary condition on the first (ix = 1) and last (ix = end) columns of Vx.
# Each boundary velocity is extrapolated from its inner neighbour using
# dVx/dx = (P - p_NMHP*DL) / (2*etan), where DL scales the horizontal load with depth.
# One work item per y-index iy; Vx is offset by one because it has an extra row.
@parallel_indices (iy) function LeftRight!(Vx::Data.Array, dx::Data.Number, p_NMHP::Data.Number, P_y::Data.Array, etan_y::Data.Array, DL::Data.Array)
    iyn = iy + 1

    left_load = -p_NMHP * DL[1, iyn] + P_y[1, iy]
    Vx[1, iyn] = Vx[2, iyn] - dx * left_load / etan_y[1, iy] / 2

    right_load = -p_NMHP * DL[end, iyn] + P_y[end, iy]
    Vx[end, iyn] = Vx[end-1, iyn] + dx * right_load / etan_y[end, iy] / 2

    return
end



## CREATE OTHER FUNCTIONS TO BE USED OUTSIDE THE PSEUDO-TRANSIENT LOOP
"""
    smooth(A, factor)

Smooth the viscosity: return one explicit Laplacian-smoothing step of matrix `A`.

For every interior entry,

    A2[i, j] = A[i, j] + (1/4.1/factor) * (A[i-1,j] - 2A[i,j] + A[i+1,j] + A[i,j-1] - 2A[i,j] + A[i,j+1])

The first/last rows and columns are copied unchanged. A larger `factor` gives a weaker
step. The result is always a `Matrix{Float64}`, whatever the element type of `A`.
Matrices without interior entries (fewer than 3 rows or columns, including empty ones)
come back as an unchanged `Float64` copy.
"""
function smooth(A::AbstractMatrix, factor::Number)
    Base.require_one_based_indexing(A)
    m, n = size(A)
    A2   = Matrix{Float64}(A)        # boundary rows/columns are copied by construction
    c    = 1.0/4.1/factor            # loop-invariant step coefficient
    # j outer / i inner: walk memory contiguously (Julia arrays are column-major)
    for j = 2:n-1, i = 2:m-1
        A2[i, j] = A[i, j] + c * (A[i-1, j] - 2 * A[i, j] + A[i+1, j] + A[i, j-1] - 2 * A[i, j] + A[i, j+1])
    end
    return A2
end

"""
    compute_maxloc(matrix)

Return a matrix of the same size and element type as `matrix`. Each entry is the
maximum of the corresponding entry and its four edge neighbours (up, down, left,
right). Neighbour indices are clamped to the matrix, so cells on the border only
compare against neighbours that exist.
"""
function compute_maxloc(matrix::AbstractMatrix)
    Base.require_one_based_indexing(matrix)
    rows, cols = size(matrix)
    result     = similar(matrix)
    # c outer / r inner for column-major access; max(...) avoids a temporary Vector per cell
    for c = 1:cols, r = 1:rows
        result[r, c] = max(matrix[max(r-1, 1), c],
                           matrix[min(r+1, rows), c],
                           matrix[r, max(c-1, 1)],
                           matrix[r, min(c+1, cols)],
                           matrix[r, c])
    end
    return result
end



## MAIN FUNCTION
@views function SaltInclusion2D()
    # Physics (dimensional inputs)
    A_Salt      = 1e17                          # Salt viscosity [Pa-s]
    A_Shale     = 1e15                          # Shale viscosity [Pa-s]
    g           = 9.81                          # Gravity [m/s2]
    p_grad_h    = 0.75*22620.6                  # Minimum horizontal stress gradient [Pa/m]
    p_grad_v    = 0.85*22620.6                  # Vertical stress gradient [Pa/m]
    rho_Shale   = p_grad_v/g                    # Bulk Shale density [kg/m3]
    rho_Salt    = 2200                          # Salt density [kg/m3]
    rho_Water   = 1020                          # Water density [kg/m3]
    D_sf        = 800                           # Seafloor depth [m]

    # Input dimensions
    Ly_phys     = 10000.0                       # Domain's Y-length [m]
    aspect      = 3                             # Aspect ratio Lx/Ly

    # Scales (all fields below are non-dimensionalized with these)
    rho_sc      = rho_Shale                     # Density scale
    eta_sc      = A_Shale                       # Rock viscosity scale
    L_sc        = ((eta_sc^2)/g/(rho_sc^2))^(1/3)  # Length scale
    PI2         = g*(L_sc^3)*(rho_sc^2)/(eta_sc^2) # Equals 1 (up to rounding) by the choice of L_sc, to stabilize
    Ly          = Ly_phys/L_sc                  # Non-dimensionalized Y-length
    v_sc        = rho_sc*g*(L_sc^2)/eta_sc      # Velocity scale (not used below; for converting Vx/Vy to m/s)
    p_sc        = (eta_sc^2)/(rho_sc*L_sc^2)    # Pressure scale
    p_water     = rho_Water*g*D_sf/p_sc         # Seawater pressure at the top boundary
    p_NMHP      = p_grad_h*Ly*L_sc/p_sc         # Nondimensionalized minimum horizontal pressure over the full depth
    p_NVP       = p_grad_v*Ly*L_sc/p_sc         # Nondimensionalized vertical pressure over the full depth

    # Numerics
    nx          = 50                            # Number of blocks in X-direction
    ny          = 50                            # Number of blocks in Y-direction
    epsi        = 1e-8                          # Tolerance
    niter       = 50_000_000                    # Maximum iterations
    nout_iter   = 1000                          # Output iterations
    CFL         = 0.9/sqrt(2)                   # Courant-Friedrichs-Lewy
    Re          = 3/2*sqrt(10)*pi               # Reynolds number
    r           = 1.0                           # Bulk to shear modulus ratio
    err         = 2*epsi                        # Initialize errors
    err_evo1    = Float64[]                     # History of err
    err_evo2    = Int[]                         # Iteration at which each err was recorded
    iter        = 0                             # Initialize iteration

    # Pre-processing
    Lx          = aspect * Ly                   # X-length
    dx          = Lx/(nx-1)                     # Block X-size
    dy          = Ly/(ny-1)                     # Block Y-size
    xn          = -dx/2:dx:Lx+dx/2              # Grid X-location
    yn          = -dy/2:dy:Ly+dy/2              # Grid Y-location
    xc          = 0:dx:Lx                       # Block center X-location
    yc          = 0:dy:Ly                       # Block center Y-location
    max_lxy     = max(Lx, Ly)                   # Maximum block length
    Vpdτ        = min(dx, dy) * CFL             # P-wave velocity * pseudo time step
    Vx          = @zeros(nx+1, ny+1)            # Initialize Velocity-X
    Vy          = @zeros(nx+1, ny+1)            # Initialize Velocity-Y
    P           = @zeros(nx  , ny  )            # Initialize Total Pressure
    τxx         = @ones (nx  , ny  )            # Initialize Deviatoric Stress XX (NOTE(review): ones rather than zeros — confirm intent)
    τyy         = @ones (nx  , ny  )            # Initialize Deviatoric Stress YY
    τxy         = @zeros(nx  , ny  )            # Initialize Deviatoric Stress XY
    divV        = @zeros(nx  , ny  )            # Initialize Divergence of Velocity
    dVxdx       = @zeros(nx  , ny+1)            # Initialize dVx/dx
    dVydy       = @zeros(nx+1, ny  )            # Initialize dVy/dy
    dVxdy       = @zeros(nx+1, ny  )            # Initialize dVx/dy
    dVydx       = @zeros(nx  , ny+1)            # Initialize dVy/dx
    Exx         = @zeros(nx  , ny  )            # Initialize Strain XX
    Eyy         = @zeros(nx  , ny  )            # Initialize Strain YY
    Exy         = @zeros(nx  , ny  )            # Initialize Strain XY
    dPdx        = @zeros(nx-1, ny  )            # Initialize dP/dx
    dPdy        = @zeros(nx  , ny-1)            # Initialize dP/dy
    dτxxdx      = @zeros(nx-1, ny  )            # Initialize dτxx/dx
    dτyydy      = @zeros(nx  , ny-1)            # Initialize dτyy/dy
    dτxydx      = @zeros(nx-1, ny  )            # Initialize dτxy/dx
    dτxydy      = @zeros(nx  , ny-1)            # Initialize dτxy/dy
    Rx          = @zeros(nx-1, ny-1)            # Initialize Residual in the X-direction
    Ry          = @zeros(nx-1, ny-1)            # Initialize Residual in the Y-direction
    dVx         = @zeros(nx-1, ny-1)            # Initialize Velocity-X Difference during convergence
    dVy         = @zeros(nx-1, ny-1)            # Initialize Velocity-Y Difference during convergence
    P_x         = @zeros(nx-1, ny  )            # Initialize Pressure matrix after averaging adjacent row values
    P_y         = @zeros(nx  , ny-1)            # Initialize Pressure matrix after averaging adjacent column values
    etan_x      = @zeros(nx-1, ny  )            # Initialize Viscosity matrix after averaging adjacent row values
    etan_y      = @zeros(nx  , ny-1)            # Initialize Viscosity matrix after averaging adjacent column values

    # Salt location from the geometry image: SaltFraction = 1.0 for salt, 0.0 for shale
    SaltFraction = zeros(nx, ny)                # Initialize Salt Fraction
    img          = load("SaltGeometry2D.JPG")   # Load the Salt Geometry
    gray_img     = Gray.(img)                   # Convert the image to grayscale
    rotated_img  = imrotate(gray_img, pi/2)     # Rotate the grayscale image by 90 degrees
    bw_float     = channelview(rotated_img)     # Numeric gray levels of the rotated image
    bw           = bw_float .> 0.5              # Threshold at 0.5 to get a black/white mask
    resized_img  = imresize(bw, (nx, ny))       # Resize the mask to dimensions [nx, ny]
    SaltFraction .= resized_img .< 0.5          # Dark pixels are salt: invert the mask
    is_salt      = SaltFraction .== 1.0

    # Host-side (plain Array) material fields. smooth/compute_maxloc index element-wise,
    # so the fields are finished here and only then copied to the backend array type.
    etan_h             = ones(nx, ny)           # Rock Viscosity (Shale = 1 after scaling by eta_sc)
    etan_h[is_salt]   .= A_Salt/A_Shale         # Replace Shale Viscosity by Salt Viscosity
    rho_Rock_h         = fill(rho_Shale, nx, ny) # Shale Density
    rho_Rock_h[is_salt] .= rho_Salt             # Salt Density

    # Smear out the coefficients of Stokes equation
    for _ = 1:10
        etan_h = smooth(etan_h, 1.0)            # Smoothed Rock Viscosity
    end

    # Maximum rock viscosity among each cell and its neighbours; edges copied inward
    etanmax         = compute_maxloc(etan_h)
    etanmax[1, :]   = etanmax[2, :]
    etanmax[end, :] = etanmax[end - 1, :]
    etanmax[:, 1]   = etanmax[:, 2]
    etanmax[:, end] = etanmax[:, end - 1]

    # Pre-processing of numerics
    dτ_Rho_h = Vpdτ * max_lxy / Re ./ etanmax   # Pseudo Time Step / Density
    Gdτ_h    = (Vpdτ^2) ./ dτ_Rho_h / (r+2)     # G * Pseudo Time Step

    # Dimensionless depth DL = grid depth / domain depth: 1 at iy = 1, 0 at the seafloor (iy = end).
    # DL depends on depth only, so every x-position gets the same profile. LeftRight!
    # reads both DL[1, :] and DL[end, :], so the right boundary needs it as well
    # (previously only DL[1, :] was set and the right side saw no horizontal stress).
    DL_h = repeat(collect(LinRange(1, 0, ny+1))', nx+1, 1)

    # Copy host-built fields into the backend array type (a plain copy on the CPU backend, required on GPU)
    etan     = Data.Array(etan_h)
    rho_Rock = Data.Array(rho_Rock_h)
    dτ_Rho   = Data.Array(dτ_Rho_h)
    Gdτ      = Data.Array(Gdτ_h)
    DL       = Data.Array(DL_h)

    # Pseudo-Transient Iterations
    while err > epsi && iter <= niter
        # Compute Velocity Gradients
        @parallel compute_Vgrad!(dVxdx, dVydy, dVxdy, dVydx, Vx, Vy, dx, dy)

        # Compute Pressure
        @parallel compute_P!(divV, P, dVxdx, dVydy, Gdτ, r, dx, dy)

        # Compute Strain
        @parallel compute_E!(Exx, Eyy, Exy, dVxdx, dVydy, dVxdy, dVydx)

        # Compute Stress
        @parallel compute_τ!(τxx, τyy, τxy, Exx, Eyy, Exy, etan, Gdτ)

        # Compute Gradients of Pressure and Deviatoric Stress
        @parallel compute_Pτgrad!(dPdx, dPdy, dτxxdx, dτyydy, dτxydx, dτxydy, P, τxx, τyy, τxy, dx, dy)

        # Compute Residuals of Stokes Equation
        @parallel compute_dV!(Rx, Ry, dVx, dVy, dPdx, dPdy, dτxxdx, dτyydy, dτxydx, dτxydy, dτ_Rho, PI2, rho_Rock, rho_sc)

        # Compute Velocity
        @parallel compute_V!(Vx, Vy, dVx, dVy)

        # Apply Boundary Conditions
        @parallel boundary_value!(P_x, P_y, etan_x, etan_y, P, etan)
        @parallel (1:size(P_x,1)) TopBottom!(Vy, dy, p_water, p_NVP, P_x, etan_x)
        @parallel (1:size(P_y,2)) LeftRight!(Vx, dx, p_NMHP, P_y, etan_y, DL)

        # Update iteration
        iter += 1

        # Compare error and tolerance
        if iter % nout_iter == 0
            Vmin    = minimum(Vx)
            Vmax    = maximum(Vx)
            Pmin    = minimum(P)
            Pmax    = maximum(P)
            norm_Rx = norm(Rx  ) / (Pmax - Pmin) * Lx / sqrt(length(Rx)  )
            norm_Ry = norm(Ry  ) / (Pmax - Pmin) * Lx / sqrt(length(Ry)  )
            norm_dV = norm(divV) / (Vmax - Vmin) * Lx / sqrt(length(divV))
            err     = max(norm_Rx, norm_Ry, norm_dV)
            push!(err_evo1, err )
            push!(err_evo2, iter)
            println("Total steps = $iter, err = $err [norm_Rx=$norm_Rx, norm_Ry=$norm_Ry, norm_dV=$norm_dV]")
            # A NaN err makes `err > epsi` false and would end the loop as if converged;
            # stop loudly instead. An Inf err likewise means the solution has blown up.
            isfinite(err) || error("Pseudo-transient solver diverged at iter = $iter (err = $err)")
        end

        # Make plots
        if iter % nout_iter == 1
            p1 = heatmap(xc, yc, (Array(τxx))'  , aspect_ratio=1, xlims=extrema(xc), ylims=extrema(yc), c=:viridis, title="Deviatoric Stress τxx")
            p2 = heatmap(xc, yc, (Array(τyy))'  , aspect_ratio=1, xlims=extrema(xc), ylims=extrema(yc), c=:viridis, title="Deviatoric Stress τyy")
            p3 = heatmap(xc, yc, (Array(τxy))'  , aspect_ratio=1, xlims=extrema(xc), ylims=extrema(yc), c=:viridis, title="Deviatoric Stress τxy")
            p4 = heatmap(xc, yc, (Array(P))'    , aspect_ratio=1, xlims=extrema(xc), ylims=extrema(yc), c=:viridis, title="Pressure"             )
            p5 = heatmap(xn, yn, (Array(Vx))'   , aspect_ratio=1, xlims=extrema(xn), ylims=extrema(yn), c=:viridis, title="Velocity Vx"          )
            p6 = heatmap(xn, yn, (Array(Vy))'   , aspect_ratio=1, xlims=extrema(xn), ylims=extrema(yn), c=:viridis, title="Velocity Vy"          )
            p7 = heatmap(xc, yc, (Array(τxx+P))', aspect_ratio=1, xlims=extrema(xc), ylims=extrema(yc), c=:viridis, title="Stress Sxx"           )
            p8 = heatmap(xc, yc, (Array(τyy+P))', aspect_ratio=1, xlims=extrema(xc), ylims=extrema(yc), c=:viridis, title="Stress Syy"           )
            p9 = heatmap(xc, yc, (Array(etan))' , aspect_ratio=1, xlims=extrema(xc), ylims=extrema(yc), c=:viridis, title="Viscosity"            )
            display(plot(p1, p2, p3, p4, p5, p6, p7, p8, p9, layout=(3,3)))
        end
    end
    return nothing
end

## RUN CODE
# Run the model. The trailing `;` suppresses echoing the (nothing) result in a notebook/REPL.
# (Do not assign to `SaltInclusion2D`: that name is the function's constant binding.)
SaltInclusion2D();