In [None]:
# Backend switch: when True the name `np` is bound to CuPy, otherwise to NumPy.
# Every later cell goes through `np`, so this flag selects the array library
# for the whole notebook.
use_cupy = False

if use_cupy:
    import cupy as np
else:
    import numpy as np
# import cupy as np
# `real_np` is always NumPy regardless of `use_cupy`. Later cells use it in
# `isinstance(..., real_np.ndarray)` checks and call `.get()` on arrays that
# are not NumPy arrays before handing them to matplotlib.
import numpy as real_np
import argparse
import matplotlib.pyplot as plt
#import gt4py


In [None]:
# Initialize model parameters

# Number of points in each direction and the padded array extents
# (one extra row/column is kept for the periodic halo).
M = 512 # args.M
N = 512 # args.N
M_LEN, N_LEN = M + 1, N + 1
SIZE = M_LEN * N_LEN

# Run-control switches: text output, live plotting (every VIS_DT cycles),
# validation against reference data, and number of time-loop iterations.
L_OUT = True # args.L_OUT
VIS = True
VIS_DT = 100
VAL = True
ITMAX = 1000

# Time step and grid spacing; tdt is the step actually used by the update
# formulas and starts out equal to dt.
dt = 90.
tdt = dt
dx = 100000.
dy = 100000.
fsdx = 4. / (dx)
fsdy = 4. / (dy)

# Constants used to build the initial stream function and p field,
# plus the time filter coefficient alpha.
a = 1000000.
alpha = 0.001
el = N * dx
pi = 4. * np.arctan(1.)
tpi = 2. * pi
d_i = tpi / M
d_j = tpi / N
pcf = (pi * pi * a * a) / (el * el)

# Model Variables: every field is an independent zero-filled
# (M_LEN, N_LEN) array.
u, v, p = (np.zeros((M_LEN, N_LEN)) for _ in range(3))
unew, vnew, pnew = (np.zeros((M_LEN, N_LEN)) for _ in range(3))
uold, vold, pold = (np.zeros((M_LEN, N_LEN)) for _ in range(3))
uvis, vvis, pvis = (np.zeros((M_LEN, N_LEN)) for _ in range(3))
cu, cv, z, h, psi = (np.zeros((M_LEN, N_LEN)) for _ in range(5))


In [None]:
from IPython.display import clear_output
from matplotlib import pyplot as plt
%matplotlib inline
    

def live_plot3(fu, fv, fp, title=''):
    """Redraw a three-panel snapshot of the p, u and v fields.

    The previous cell output is cleared first, so repeated calls replace the
    figure in place. Color ranges are fixed: p is shown on [49999, 50001],
    u and v on [-1, 1].

    Parameters
    ----------
    fu, fv, fp : 2-D array-like
        Fields to display; they are passed straight to ``Axes.imshow``.
    title : str, optional
        Figure-level title.
    """
    clear_output(wait=True)
    fig, axes = plt.subplots(figsize=(13, 3), ncols=3)

    # (data, panel title, colormap, lower limit, upper limit), left to right.
    panels = (
        (fp, 'p', 'Blues', 49999, 50001),
        (fu, 'u', 'Reds', -1, 1),
        (fv, 'v', 'Greens', -1, 1),
    )
    for axis, (field, label, cmap, lower, upper) in zip(axes, panels):
        axis.imshow(field, cmap=cmap, vmin=lower, vmax=upper, interpolation='none')
        axis.set_title(label)

    fig.suptitle(title)
    plt.show()

def live_plot_val(fu, fv, fp, title=''):
    """Plot three difference fields with symmetric color scales and colorbars.

    Each panel uses the range ``[-m, m]`` where ``m`` is the largest absolute
    value in that field. Using the signed maximum instead would give an
    inverted range (vmin > vmax) when a field is entirely negative, and would
    clip the data whenever the most extreme value is negative.

    Parameters
    ----------
    fu, fv, fp : 2-D array
        Fields to display; must support ``abs()`` and ``.max()`` and be
        accepted by ``Axes.imshow``.
    title : str, optional
        Figure-level title.
    """
    # Symmetric limits: largest magnitude in each field, as plain floats.
    mxu = float(abs(fu).max())
    mxv = float(abs(fv).max())
    mxp = float(abs(fp).max())
    clear_output(wait=True)
    fig, (ax1, ax2, ax3) = plt.subplots(figsize=(13, 3), ncols=3)

    pos1 = ax1.imshow(fp, cmap='Blues', vmin=-mxp, vmax=mxp, interpolation='none')
    ax1.set_title('p')
    plt.colorbar(pos1, ax=ax1)
    pos2 = ax2.imshow(fu, cmap='Reds', vmin=-mxu, vmax=mxu, interpolation='none')
    ax2.set_title('u')
    plt.colorbar(pos2, ax=ax2)
    pos3 = ax3.imshow(fv, cmap='Greens', vmin=-mxv, vmax=mxv, interpolation='none')
    ax3.set_title('v')
    plt.colorbar(pos3, ax=ax3)

    fig.suptitle(title)
    plt.show()


In [None]:
%matplotlib inline
# Initial values of the stream function and p
# Index vectors: rows as a column vector so the products below broadcast
# to the full (M_LEN, N_LEN) grid.
row_idx = np.arange(0, M_LEN)[:, np.newaxis]
col_idx = np.arange(0, N_LEN)

psi[...] = a * np.sin((row_idx + 0.5) * d_i) * np.sin((col_idx + 0.5) * d_j)
p[...] = pcf * (np.cos(2. * row_idx * d_i) + np.cos(2. * col_idx * d_j)) + 50000.

# Calculate initial u and v from differences of the stream function
u[1:, :-1] = -(psi[1:, 1:] - psi[1:, :-1]) / dy
v[:-1, 1:] = (psi[1:, 1:] - psi[:-1, 1:]) / dx

if VIS == True:
    # Arrays that are not NumPy arrays are fetched with .get() before plotting.
    if isinstance(u, real_np.ndarray):
        host_fields = (u, v, p)
    else:
        host_fields = (u.get(), v.get(), p.get())
    live_plot3(*host_fields, "init")
    # Print max then min of p, u and v, in that order.
    for field in (p, u, v):
        print(field.max())
        print(field.min())


In [None]:
%matplotlib inline

# Periodic Boundary conditions
# The initial u was filled on [1:, :-1] and v on [:-1, 1:]; the remaining
# row/column of each field is copied from the opposite edge. The statement
# order matters: later copies read values written by earlier ones.

u[0, :] = u[M, :]
v[M, 1:] = v[0, 1:]
u[1:, N] = u[1:, 0]
v[:, 0] = v[:, N]

# Corner points
u[0, N] = u[M, 0]
v[M, 0] = v[0, N]


if VIS:
    # Arrays that are not NumPy arrays are fetched with .get() before plotting.
    if isinstance(u, real_np.ndarray):
        bc_fields = (u, v, p)
    else:
        bc_fields = (u.get(), v.get(), p.get())
    live_plot3(*bc_fields, "Periodic Boundary Conditions")

# Save initial conditions
uold = np.copy(u)
vold = np.copy(v)
pold = np.copy(p)


In [None]:
# Print initial conditions
if L_OUT:
    # (label, value) pairs, printed one per line in this order. For the
    # fields only the diagonal is shown, without its last element.
    initial_report = (
        (" Number of points in the x direction: ", M),
        (" Number of points in the y direction: ", N),
        (" grid spacing in the x direction: ", dx),
        (" grid spacing in the y direction: ", dy),
        (" time step: ", dt),
        (" time filter coefficient: ", alpha),
        (" Initial p:\n", p.diagonal()[:-1]),
        (" Initial u:\n", u.diagonal()[:-1]),
        (" Initial v:\n", v.diagonal()[:-1]),
    )
    for label, value in initial_report:
        print(label, value)
        


In [None]:
%matplotlib inline

time = 0.0
# Main time loop
# Each cycle: (1) build cu, cv, z and h from the current u, v, p,
# (2) fill their periodic halos, (3) advance to unew/vnew/pnew from the *old*
# fields using step tdt, (4) fill the halos of the new fields, (5) rotate the
# time levels. The first cycle steps by dt from the initial state and then
# doubles tdt; later cycles also apply the alpha time filter to the old level.
for ncycle in range(ITMAX):
    # Text progress is printed only when live plotting is switched off.
    if ncycle % 100 == 0 and not VIS:
        print("cycle number ", ncycle)

    # Calculate cu, cv, z, and h
    # Each field is computed on a different sub-block of the grid:
    # cu on [1:, :-1], cv on [:-1, 1:], z on [1:, 1:], h on [:-1, :-1].
    cu[1:,:-1] = .5 * (p[1:,:-1] + p[:-1,:-1]) * u[1:,:-1]
    cv[:-1,1:] = .5 * (p[:-1,1:] + p[:-1,:-1]) * v[:-1,1:]
    z[1:,1:] = (fsdx * (v[1:,1:] - v[:-1,1:]) - fsdy * (u[1:,1:] - u[1:,:-1])) / (p[:-1,:-1] + p[1:,:-1] + p[1:,1:] + p[:-1,1:])
    h[:-1,:-1] = p[:-1,:-1] + 0.25 * (u[1:,:-1] * u[1:,:-1] + u[:-1,:-1] * u[:-1,:-1] + v[:-1,1:] * v[:-1,1:] + v[:-1,:-1] * v[:-1,:-1])

    # Periodic Boundary conditions: copy from the computed region into the
    # row/column that was not computed above.
    cu[0, :] = cu[M, :]
    h[M, :] = h[0, :]
    cv[M, 1:] = cv[0, 1:]
    z[0, 1:] = z[M, 1:]

    cv[:, 0] = cv[:, N]
    h[:, N] = h[:, 0]
    cu[1:, N] = cu[1:, 0]
    # z is computed on [1:, 1:], so column 0 is the halo and must be filled
    # from column N. Copying the other way would overwrite the freshly
    # computed column N and leave column 0 (read by the unew update) stale.
    z[1:, 0] = z[1:, N]

    # Corner points
    cu[0, N] = cu[M, 0]
    cv[M, 0] = cv[0, N]
    z[0, 0] = z[M, N]
    h[M, N] = h[0, 0]

    # Calculate new values of u, v, and p
    # These factors depend on tdt, which changes after the first cycle, so
    # they are recomputed here rather than hoisted out of the loop.
    tdts8 = tdt / 8.
    tdtsdx = tdt / dx
    tdtsdy = tdt / dy

    unew[1:,:-1] = uold[1:,:-1] + tdts8 * (z[1:,1:] + z[1:,:-1]) * (cv[1:,1:] + cv[1:,:-1] + cv[:-1,1:] + cv[:-1,:-1]) - tdtsdx * (h[1:,:-1] - h[:-1,:-1])
    vnew[:-1,1:] = vold[:-1,1:] - tdts8 * (z[1:,1:] + z[:-1,1:]) * (cu[1:,1:] + cu[1:,:-1] + cu[:-1,1:] + cu[:-1,:-1]) - tdtsdy * (h[:-1,1:] - h[:-1,:-1])
    pnew[:-1,:-1] = pold[:-1,:-1] - tdtsdx * (cu[1:,:-1] - cu[:-1,:-1]) - tdtsdy * (cv[:-1,1:] - cv[:-1,:-1])

    # Periodic Boundary conditions for the new time level
    unew[0, :] = unew[M, :]
    pnew[M, :] = pnew[0, :]
    vnew[M, 1:] = vnew[0, 1:]
    unew[1:, N] = unew[1:, 0]
    vnew[:, 0] = vnew[:, N]
    pnew[:, N] = pnew[:, 0]

    unew[0, N] = unew[M, 0]
    vnew[M, 0] = vnew[0, N]
    pnew[M, N] = pnew[0, 0]

    time = time + dt

    if ncycle > 0:
        # Time filter: the old level becomes the current level nudged by
        # alpha towards the average of its neighbours in time.
        uold[...] = u + alpha * (unew - 2. * u + uold)
        vold[...] = v + alpha * (vnew - 2. * v + vold)
        pold[...] = p + alpha * (pnew - 2. * p + pold)

        u[...] = unew
        v[...] = vnew
        p[...] = pnew

    else:
        # First cycle: no filtering; from now on the step spans two dt.
        tdt = tdt + tdt

        uold = np.copy(u)
        vold = np.copy(v)
        pold = np.copy(p)
        u = np.copy(unew)
        v = np.copy(vnew)
        p = np.copy(pnew)

    if VIS and ncycle % VIS_DT == 0:
        # Arrays that are not NumPy arrays are fetched with .get() before plotting.
        if isinstance(u, real_np.ndarray):
            live_plot3(u, v, p, "ncycle: " + str(ncycle))
        else:
            live_plot3(u.get(), v.get(), p.get(), "ncycle: " + str(ncycle))


In [None]:
# Print the final state after the last cycle
if L_OUT:
    print("cycle number ", ITMAX)
    # Show the diagonal of each new-time-level field, without its last element.
    final_report = (
        (" diagonal elements of p:\n", pnew),
        (" diagonal elements of u:\n", unew),
        (" diagonal elements of v:\n", vnew),
    )
    for label, field in final_report:
        print(label, field.diagonal()[:-1])

In [None]:
if VAL:
    # Compare the final unew/vnew/pnew against reference fields read from
    # text files, then plot and print the differences.

    # NOTE(review): these file names say 64x64 / IT4000, while this notebook
    # is configured with M = N = 512 and ITMAX = 1000 -- confirm the
    # reference data matches the run being validated.
    u_val_f = 'ref/u.64.64.IT4000.txt'
    v_val_f = 'ref/v.64.64.IT4000.txt'
    p_val_f = 'ref/p.64.64.IT4000.txt'

    # NOTE(review): read_arrays is not defined in the cells visible here and
    # must be provided before this cell runs. The files are passed in v, u, p
    # order but the results are unpacked as u, v, p -- verify against the
    # helper's signature and return order.
    uref, vref, pref = read_arrays(v_val_f, u_val_f, p_val_f)
    uval = uref - unew
    vval = vref - vnew
    pval = pref - pnew

    # NOTE(review): for 2-D input, np.linalg.norm(x, np.inf) is the maximum
    # absolute row sum, not the largest absolute element -- confirm this is
    # the intended error measure.
    uLinfN = np.linalg.norm(uval, np.inf)
    vLinfN = np.linalg.norm(vval, np.inf)
    pLinfN = np.linalg.norm(pval, np.inf)

    # As in the other plotting cells, arrays that are not NumPy arrays are
    # fetched with .get() before being handed to matplotlib.
    if isinstance(uval, real_np.ndarray):
        live_plot_val(uval, vval, pval, "Val")
    else:
        live_plot_val(uval.get(), vval.get(), pval.get(), "Val")
    print("uLinfN: ", uLinfN)
    print("vLinfN: ", vLinfN)
    print("pLinfN: ", pLinfN)
    print("udiff max: ", uval.max())
    print("vdiff max: ", vval.max())
    print("pdiff max: ", pval.max())