Multi-level Extension
===

In [None]:
from ngsolve import *
from netgen.geom2d import unit_square
from ngsolve.webgui import Draw

In [None]:
# Coarse mesh of the unit square. Both the space and the grid function are
# created with autoupdate=True so they follow the refinements below.
mesh = Mesh(unit_square.GenerateMesh(maxh=0.3))
fes = H1(mesh, order=1, dirichlet="left|bottom", autoupdate=True)
u, v = fes.TnT()
gfu = GridFunction(fes, autoupdate=True)

# Refine three times; the resulting mesh levels are used by the multilevel
# extension further below (via fes.mesh.levels).
for step in range(3):
    mesh.Refine()

# Laplace stiffness matrix, assembled once on the finest mesh.
laplace = grad(u) * grad(v) * dx
a = BilinearForm(laplace)
a.Assemble()

In [None]:
class MLExtension:
    """Multilevel extension of boundary data into the domain.

    One object per refinement level, linked recursively down to level 0.
    Each level stores its mass matrix, the inverse of the lumped (row-sum)
    diagonal of that matrix restricted to ``dofs``, and -- for ``level > 0``
    -- the prolongation ``P`` from ``level-1`` to ``level``, its transpose,
    and the extension object for the next coarser level.

    Parameters
    ----------
    fes : FESpace
        Space whose ``Prolongation()`` provides the inter-level operators.
    level : int
        Refinement level handled by this object (0 is the coarsest).
    bndmass : BaseMatrix
        Mass matrix on this level (on the finest level: the assembled
        boundary mass matrix; on coarser levels: the Galerkin product
        ``P^T M P``).
    dofs : BitArray
        Dofs on which this level overwrites the extension with its own
        (lumped-mass) value.
    """

    def __init__ (self, fes, level, bndmass, dofs):
        self.fes = fes
        self.level = level
        self.bndmass = bndmass
        self.dofs = dofs

        # Lumped mass: row sums M*1, inverted only on the selected dofs.
        ones = bndmass.CreateRowVector()
        ones[:] = 1
        Mones = (bndmass*ones).Evaluate()
        self.inv = DiagonalMatrix(Mones).Inverse(dofs)

        if level > 0:
            self.prol = fes.Prolongation().CreateMatrix(level)
            self.rest = self.prol.CreateTranspose()
            # Galerkin coarse mass matrix P^T M P (lazy matrix product).
            coarsebndmass = self.rest @ bndmass @ self.prol
            # Copy the leading entries of `dofs` into the coarse bit array.
            # NOTE(review): assumes coarse dof i coincides with fine dof i
            # (hierarchical numbering, coarse vertices first, as for order-1
            # H1) -- confirm before using with other spaces.
            coarsedofs = BitArray(self.prol.width)
            coarsedofs[:] = False
            for i in range(len(coarsedofs)):
                coarsedofs[i] = dofs[i]
            self.coarseext = MLExtension(fes, level-1, coarsebndmass, coarsedofs)

    def Extend (self, x):
        """Replace the vector ``x`` in place by its multilevel extension."""
        Mx = (self.bndmass * x).Evaluate()
        x.data = self.ExtendRec(Mx)

    def ExtendRec (self, Mx):
        """Return this level's extension for the load vector ``Mx``.

        ``Mx`` must have the dimension of this level's space. The coarser
        levels are solved recursively, and the result is prolongated back to
        this level. On ``self.dofs`` it is then overwritten with the
        lumped-mass solution ``inv * Mx`` of this level.
        """
        sol = (self.inv * Mx).Evaluate()
        if self.level == 0:
            return sol

        # Restrict with the same P^T used to build the coarse mass matrix,
        # so the coarse level receives a vector of its own (coarse) size
        # instead of a fine-length vector restricted in place.
        Mxc = (self.rest * Mx).Evaluate()
        xc = self.coarseext.ExtendRec(Mxc)
        pxc = (self.prol * xc).Evaluate()

        # On this level's selected dofs, keep this level's own value.
        pxc[self.dofs] = sol
        return pxc

In [None]:
# Boundary data: gfu is set only on the left/bottom boundary; all other
# dofs keep their previous (initially zero) values.
bnd = mesh.Boundaries("left|bottom")
gfu.Set (1-x-y+0.3*sin(30*x), definedon=bnd)

Draw (gfu)
# Printed value is u^T A u with A the Laplace stiffness matrix `a`
# (an energy value, shown under the label "Norm").
print ("Norm(u) = ", InnerProduct((a.mat*gfu.vec).Evaluate(), gfu.vec))

# Boundary mass matrix on the finest level, then extend gfu in place
# through all mesh levels.
bndmass = BilinearForm(u*v*ds(bnd)).Assemble().mat
ext = MLExtension(fes, fes.mesh.levels-1, bndmass, fes.GetDofs(bnd))
ext.Extend(gfu.vec)

Draw (gfu)
# Same energy value after the extension; evaluated exactly like the first
# print so both numbers are computed identically.
print ("Norm(uext) = ", InnerProduct((a.mat*gfu.vec).Evaluate(), gfu.vec))