In [22]:
from ngsolve import *
from netgen.geom2d import SplineGeometry
from ngsolve.webgui import Draw
from netgen.occ import *
from netgen.webgui import Draw as DrawGeo
from ngsolve.krylovspace import GMResSolver
from time import perf_counter as time

tend = 0.5
maxh = 0.2

def solveWaveEq(maxh, method, draw=False): #for maxh in [0.2, 0.1, 0.05, 0.025]: 
    shape = Rectangle(2,2).Face().Move((-1,-1,0))
    mesh = Mesh(OCCGeometry(shape, dim=2).GenerateMesh(maxh=maxh))
    
    fes = H1(mesh, order=4)
    X = fes * fes
    (u,v), (ut, vt) = X.TnT()
    t = 0.0
    dt = 0.01
    idt = 1/dt
    mu = 0.1
    
    a = BilinearForm(X, symmetric=False)
    a += idt * v * vt * dx + idt * u * ut * dx - v * ut * dx #mu*grad(u)*grad(vt)*dx 
    aP = Preconditioner(a, "local")
    a.Assemble()
    
    g = BilinearForm(X, symmetric=False)
    g += -idt * v * vt * dx - idt * u * ut * dx + mu*grad(u)*grad(vt)*dx 
    g.Assemble()
    
    gf = GridFunction(X)
    gfu, gfv = gf.components
    rr = x**2 + y**2 
    C = 10.
    def tanh(xx): return (exp(xx) - exp(-xx)) / (exp(xx) + exp(-xx))
    u0 = 1 - tanh(10 * rr) #(exp(rr) - exp(-rr)) / (exp(rr) + exp(-rr))
    gfu.Set(u0)
    if draw:
        scene = Draw(u0, mesh)
    
    def TimeStepping(inv, t0 = 0, tend = 0.2,
                     saveEvery=10):
        cnt = 0; t = t0
        gfut = GridFunction(gfu.space,multidim=0)
        gfut.AddMultiDimComponent(gfu.vec)
        while t < tend - 0.5 * dt:
            res = -g.mat * gf.vec
            gf.vec.data = inv * res
            print("\r{:1.4f}".format(t),end="")
            if draw: scene.Redraw()
            if cnt % saveEvery == 0:
                gfut.AddMultiDimComponent(gfu.vec)
            cnt += 1; t = cnt * dt
        return gfut
    
    t0 = time()
    if method == "direct": 
        inv = a.mat.Inverse()
    elif method == "iterative":
        inv = GMResSolver(a.mat, aP.mat, printrates=False, maxiter=200, tol=1e-6)
    gfu.Set(u0)
    gfv.Set(0.0)
    gfut = TimeStepping(inv, tend=tend)
    tf = time() - t0
    if draw:
        Draw(gfut, mesh, interpolate_multidim=True, animate=True)
    return tf


t0 = time()
solveWaveEq(maxh, 'direct', draw=False)
t_direct = time() - t0

t0 = time()
solveWaveEq(maxh, 'iterative', draw=True)
t_iter = time() - t0

print("\nSolve time direct={:2.3f}".format(t_direct))
print("Solve time iterative={:2.3f}".format(t_iter))
print("Ratio:", t_iter/t_direct)
    

0.4900

WebGuiWidget(layout=Layout(height='50vh', width='100%'), value={'gui_settings': {}, 'ngsolve_version': '6.2.23…

0.4900

WebGuiWidget(layout=Layout(height='50vh', width='100%'), value={'gui_settings': {}, 'ngsolve_version': '6.2.23…


Solve time direct=0.426
Solve time iterative=6.384
Ratio: 14.969362044990156
