In [1]:
import numpy as np
import time
import matplotlib.pyplot as plt

from numba import jit,int32,float64,boolean


%matplotlib notebook

In [2]:
@jit(float64[:,:](int32),nopython=True,nogil=True)
def create_solution(m):
    """Exact solution of the model Poisson problem on the interior grid.

    Samples u(x, y) = (x**2 - x**4) * (y**4 - y**2) at the m interior nodes
    per axis of a uniform grid on [0, 1] (the endpoints 0 and 1 are dropped).
    Returns an (m, m) float64 array with sol[a, c] = u(x_a, y_c).
    """
    # m + 2 equispaced points with both endpoints removed -> m interior nodes.
    nodes = np.linspace(0,1,m+2)[1:-1]
    count = nodes.shape[0]
    sol = np.zeros((count, count))
    for a in range(count):
        xa = nodes[a]
        # The x-factor depends only on the row, so evaluate it once per row.
        fx = xa**2 - xa**4
        for c in range(count):
            yc = nodes[c]
            sol[a, c] = fx * (yc**4 - yc**2)
    return sol

In [3]:
# Benchmark the numba-compiled exact-solution builder on a 65 x 65 interior grid.
%timeit create_solution(65)

The slowest run took 6.80 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 3.68 µs per loop


In [4]:
@jit(float64[:,:](float64[:,:],float64[:,:], float64),nopython=True,nogil=True)
def gauss_seidel_iteration(x, b, h = None):
    """One lexicographic Gauss-Seidel sweep for the 5-point Poisson stencil.

    Each node is overwritten, in row-major order, with the value solving
        (x[i-1,j] + x[i+1,j] + x[i,j-1] + x[i,j+1] - 4*x[i,j]) / h**2 = b[i,j]
    using the most recently updated neighbour values. Neighbours outside the
    array are treated as zero (homogeneous Dirichlet boundary).

    Parameters
    ----------
    x : (n_x, n_y) float64 array, current iterate; modified in place.
    b : (n_x, n_y) float64 array, right-hand side.
    h : grid spacing. A falsy value (e.g. 0.0) falls back to 1/(n_x + 1),
        the spacing of n_x interior nodes on [0, 1].

    Returns
    -------
    The same array object ``x`` after the sweep.
    """
    n_x,n_y = x.shape

    if not h:
        # n_x interior nodes on [0, 1] leave n_x + 1 intervals; this matches the
        # linspace(0, 1, m+2)[1:-1] grids the solver drivers build.
        h = 1./(n_x + 1)
    h2 = h**2

    for i in range(n_x):
        for j in range(n_y):
            node_sum = b[i,j] * h2

            # Subtract each neighbour that lies inside the grid. A neighbour
            # outside the grid is a zero boundary value, so nothing is subtracted.
            if i > 0:
                node_sum -= x[i-1,j]
            if i < (n_x -1):
                node_sum -= x[i+1,j]
            if j > 0:
                node_sum -= x[i,j-1]
            if j < (n_y -1):
                node_sum -= x[i,j+1]

            # Solve (neighbour_sum - 4*x_ij) / h**2 = b_ij for x_ij.
            node_sum /= -4.

            x[i,j] = node_sum
    return x

In [5]:
@jit(float64[:,:](float64[:],float64[:]),nopython=True,nogil=True)
def create_right(x,y):
    """Right-hand side f of the model Poisson problem Laplace(u) = f.

    Evaluates
        f(x, y) = -2 * ((1 - 6x^2) y^2 (1 - y^2) + (1 - 6y^2) x^2 (1 - x^2)),
    which is the Laplacian of u(x, y) = (x^2 - x^4)(y^4 - y^2), at every pair
    of node coordinates.

    Parameters
    ----------
    x, y : 1-D float64 arrays of node coordinates along each axis. They may
        differ in length.

    Returns
    -------
    b : (len(x), len(y)) float64 array with b[i, j] = f(x[i], y[j]).
    """
    # Size each axis from its own coordinate array. Sizing both by len(x)
    # indexed past the end of y when len(y) < len(x); numba does not
    # bounds-check by default, so that read garbage silently.
    n_x = len(x)
    n_y = len(y)
    b = np.zeros((n_x,n_y))

    for i in range(n_x):
        xi = x[i]
        for j in range(n_y):
            yj = y[j]
            b[i,j] = - 2 * ((1-6*xi**2)
                            *yj**2
                            *(1-yj**2)
                            +(1-6*yj**2)
                            *xi**2
                            *(1 - xi**2))
    return b
@jit(float64[:,:](float64[:,:],float64[:,:], float64),nopython=True,nogil=True)
def calculate_residue(x,b, h = None):
    """Residual b - A_h x of the 5-point Poisson discretisation.

    A_h x at node (p, q) is (sum of in-grid neighbours - 4*x[p, q]) / h**2.
    Neighbours outside the array count as zero (homogeneous Dirichlet).
    Returns a new array with the same shape as ``x``.
    """
    rows, cols = x.shape
    last_row = rows - 1
    last_col = cols - 1

    res = np.zeros((rows, cols))
    for p in range(rows):
        for q in range(cols):
            # Accumulate the discrete Laplacian numerator, centre term first.
            lap = -4*x[p,q]
            if p > 0:
                lap += x[p-1,q]
            if p < last_row:
                lap += x[p+1,q]
            if q > 0:
                lap += x[p,q-1]
            if q < last_col:
                lap += x[p,q+1]
            res[p,q] = b[p,q] - lap/h**2
    return res

In [6]:
def gauss_seidel(m, atol = 1e-3, plot=False, random_seed = -1, max_iter=None):
    """Solve the model Poisson problem with repeated Gauss-Seidel sweeps.

    Builds the m x m interior grid of [0, 1]^2 and the right-hand side from
    ``create_right``. It then sweeps until the Frobenius norm of the residual
    is at most ``atol``.

    Parameters
    ----------
    m : interior nodes per axis.
    atol : stopping tolerance on ||b - A_h v||_F.
    plot : if True, redraw the iterate on a new figure after every sweep.
    random_seed : if > 0, seed numpy and start from uniform random values in
        [0, 1). Otherwise start from zeros.
    max_iter : optional cap on the number of sweeps. ``None`` (default) keeps
        the loop unbounded. When the cap is reached, a RuntimeWarning is
        issued and the current iterate is returned.

    Returns
    -------
    (v, iters) : final iterate and number of sweeps performed.
    """
    import warnings

    if plot:
        fig,ax = plt.subplots(1,1)

    x = np.linspace(0,1,m+2)[1:-1]
    y = np.linspace(0,1,m+2)[1:-1]

    h = x[1] - x[0]

    if random_seed > 0:
        np.random.seed(random_seed)
        v = np.random.rand(m,m)
    else:
        v = np.zeros((m,m))

    f = create_right(x,y)

    r = np.linalg.norm(calculate_residue(v,f,h))
    iters = 0

    while r > atol:
        # Guard against tolerances below what floating point can reach.
        if max_iter is not None and iters >= max_iter:
            warnings.warn("gauss_seidel: stopped after {} iterations, residual {} > atol {}".format(iters, r, atol), RuntimeWarning)
            break
        v = gauss_seidel_iteration(v,f,h)

        r = np.linalg.norm(calculate_residue(v,f,h))

        iters += 1
        if plot:
            ax.set_title("Iteration: {}, Residue: {}".format(iters,r))
            ax.imshow(v)
            fig.canvas.draw()
    return v,iters

In [7]:
@jit(float64[:,:](float64[:,:], boolean), nopython=True,nogil=True)
def restrict(r, fullweight=True):
    """Restrict a fine-grid array to the next coarser grid.

    Coarse node (i, j) coincides with fine node (2i, 2j). An m x m fine array
    therefore maps to a ((m-1)//2 + 1) square coarse array.

    Parameters
    ----------
    r : (m, m) float64 array on the fine grid.
    fullweight : if True, use full weighting with the stencil
        1/16 * [[1, 2, 1], [2, 4, 2], [1, 2, 1]]; neighbours outside the
        array contribute nothing. If False, use straight injection,
        rC[i, j] = r[2i, 2j].

    Returns
    -------
    rC : coarse float64 array.
    """
    m = len(r)
    mC = (m-1)//2 + 1

    rC = np.zeros((mC,mC))

    for i in range(mC):
        fi = 2*i
        for j in range(mC):
            fj = 2*j

            if not fullweight:
                # Injection copies the coincident fine value unscaled. The old
                # trailing *1/4 also hit this path and shrank the coarse RHS.
                rC[i,j] = r[fi,fj]
                continue

            # Weights relative to the final 1/4 factor: centre 1, edges 1/2,
            # corners 1/4. This gives 4/16, 2/16 and 1/16 overall.
            acc = r[fi,fj]
            if fi > 0:
                acc += 0.5 * r[fi-1,fj]
            if fi < m-1:
                acc += 0.5 * r[fi+1,fj]
            if fj > 0:
                acc += 0.5 * r[fi,fj-1]
            if fj < m-1:
                acc += 0.5 * r[fi,fj+1]

            if fi > 0 and fj < m-1:
                acc += 0.25 * r[fi-1,fj+1]
            if fi < m-1 and fj < m-1:
                acc += 0.25 * r[fi+1,fj+1]
            if fi > 0 and fj > 0:
                acc += 0.25 * r[fi-1,fj-1]
            if fi < m-1 and fj > 0:
                acc += 0.25 * r[fi+1,fj-1]

            rC[i,j] = 0.25 * acc

    return rC

In [8]:
@jit(float64[:,:](float64[:,:]), nopython=True,nogil=True)
def interpolate(rC):
    """Prolongate a coarse-grid array to the next finer grid (bilinear).

    Coarse value (i, j) is scattered onto fine nodes (2i+di, 2j+dj) for
    di, dj in {-1, 0, 1}. The weights are 1 on the coincident node, 1/2 on
    edge neighbours and 1/4 on diagonal neighbours. Fine nodes outside the
    (2*(n-1)+1) square output are skipped.
    """
    n_coarse = rC.shape[0]
    m = int(((n_coarse-1)*2)+1)

    fine = np.zeros((m,m))

    for i in range(n_coarse):
        for j in range(rC.shape[1]):
            val = rC[i,j]
            for di in range(-1, 2):
                p = 2*i + di
                if p < 0 or p > m-1:
                    continue
                wi = 1.0 if di == 0 else 0.5
                for dj in range(-1, 2):
                    q = 2*j + dj
                    if q < 0 or q > m-1:
                        continue
                    wj = 1.0 if dj == 0 else 0.5
                    fine[p,q] += (wi*wj) * val
    return fine

In [9]:
@jit(int32(int32,int32,int32), nopython=True, nogil=True)
def linearize(x,y,m):
    """Flat index of node (x, y) on a grid with m nodes per axis (x varies fastest)."""
    offset = y*m
    return offset + x

@jit(float64[:,:](int32,float64), nopython=True,nogil=True)
def create_system(m,h=None):
    """Assemble the dense 5-point Laplacian matrix for an m x m grid.

    Unknown (i, j) is stored at row/column ``linearize(i, j, m)``. Each row has
    -4/h**2 on the diagonal and 1/h**2 for each neighbour that lies inside
    the grid. Neighbours outside the grid are omitted, i.e. treated as zero
    Dirichlet values.

    Parameters
    ----------
    m : nodes per axis.
    h : grid spacing. A falsy value (e.g. 0.0) falls back to 1/(m-1).

    Returns
    -------
    A : (m*m, m*m) float64 array.
    """
    # The fallback used to be stored in h_x/h_y, which were never read, so a
    # falsy h still reached the 1/h**2 coefficients. Rebind h itself instead.
    if not h:
        h = 1./(m-1)
    diag = -4./(h**2)
    off = 1./(h**2)

    n = m**2
    A = np.zeros((n,n))
    for i in range(m):
        for j in range(m):
            row = linearize(i,j,m)
            A[row, row] = diag
            # Interior nodes pass all four checks, so one code path serves
            # both interior and boundary rows.
            if i > 0:
                A[row, linearize(i-1,j,m)] = off
            if i < m-1:
                A[row, linearize(i+1,j,m)] = off
            if j > 0:
                A[row, linearize(i,j-1,m)] = off
            if j < m-1:
                A[row, linearize(i,j+1,m)] = off

    return A

In [10]:
def vcycle_iteration(x,b, h = None, v1 = 2, v2=2, fullweight=True,plot=False, ax=None, fig=None):
    """Run one recursive V-cycle for the 5-point Poisson problem A_h x = b.

    The cycle first pre-smooths with ``v1`` Gauss-Seidel sweeps.
    On an odd-sized grid it then:
      - restricts the residual,
      - recursively solves for the coarse-grid error (spacing 2*h, zero
        initial guess),
      - interpolates that error and adds it to x,
      - post-smooths with ``v2`` sweeps.
    On an even-sized grid the pre-smoothing sweeps still run, but the system
    is then assembled densely and solved directly. No post-smoothing is done
    in that case.

    Parameters
    ----------
    x : (m, m) array, initial guess. The pre-smoothing sweeps write into this
        array in place. On odd grids the returned array is a new object
        (``x + interpolate(eC)``).
    b : (m, m) right-hand side.
    h : grid spacing. A falsy value falls back to 1/(m-1).
        NOTE(review): the drivers build interior grids with spacing 1/(m+1),
        so this default does not match them -- confirm the intended convention.
    v1, v2 : number of pre-/post-smoothing Gauss-Seidel sweeps.
    fullweight : forwarded to ``restrict`` (full weighting vs. injection).
    plot, ax, fig : when ``plot`` is set and ``ax``/``fig`` are truthy, the
        iterate is drawn on ``ax`` after every smoothing sweep.

    Returns
    -------
    (x, ops) : the updated iterate and a work estimate. The estimate is
        v1 + v2 plus a quarter of the coarse level's estimate; the direct
        solve counts as 0.

    Raises
    ------
    Exception
        If the grid is even-sized and m > 10 (the dense direct solve is refused).

    NOTE(review): m == 1 is odd and ``restrict`` keeps a 1x1 array at 1x1, so
    this appears to recurse without terminating -- verify before using m == 1.
    """
    
    m = x.shape[0]
    if not h:
        h = 1./(m-1)
    
    #Do v1 gauss_seidel iterations
    # (gauss_seidel_iteration updates x in place; its return value is not needed)
    for i in range(v1):
        gauss_seidel_iteration(x,b,h)
        
        if plot and fig and ax:
            ax.set_title("Resolution: {}x{}".format(m,m))
            ax.imshow(x)
            fig.canvas.draw()    

    #If m is even solve the problem directly
    #otherwise calculate residue, restrict the residue
    #and recursively do a vcycle iteration in a coarser grid
    #then interpolate the result
    if m % 2 != 0:
        r = calculate_residue(x,b,h)
        
        #Restrict r
        residC = restrict(r,fullweight=fullweight)
        
        #Calculate Ae=r in coarser grid
        eC = np.zeros_like(residC)
        eC, iter_ops = vcycle_iteration(eC,residC,h = 2*h, v1 = v1, v2=v2, fullweight=fullweight,plot=plot, ax=ax, fig=fig)
        
        # The coarse grid has roughly a quarter of the nodes, so its sweeps
        # are presumably counted as a quarter of a fine-grid sweep.
        iter_ops /= 4.
        
        #Interpolate x
        # (the coarse error is added to x; this creates a new array)
        x = x + interpolate(eC)
        
    else:
        if m > 10:
            raise Exception("Grid Too big for direct solver")
        #Create full matrix
        #And solve using a direct solver
        A = create_system(m,h=h)
        shapeF = b.shape
        b = b.flatten()
        # Direct solve: no post-smoothing and no work counted on this level.
        return np.linalg.solve(A,b).reshape(shapeF), 0
    
    #Run v2 times gauss seidel iterations
    for i in range(v2):
        gauss_seidel_iteration(x,b,h)
        
        if plot and fig and ax:
            ax.set_title("Resolution: {}x{}".format(m,m))
            ax.imshow(x)
            fig.canvas.draw()
    

    return x, v1+v2+iter_ops

In [11]:
def vcycle(m, v1 = 2, v2=2, atol = 1e-3, rest_type=None,plot=False, ax = None, fig= None, random_seed = -1, max_iter=None):
    """Solve the model Poisson problem by repeated V-cycles.

    Parameters
    ----------
    m : interior nodes per axis; must be odd.
    v1, v2 : pre-/post-smoothing Gauss-Seidel sweeps per level.
    atol : stopping tolerance on the Frobenius norm of the residual.
    rest_type : "injection" or "full_weight". A falsy value selects full weighting.
    plot : draw iterates while smoothing.
    ax, fig : optional matplotlib axes/figure to draw on when ``plot`` is
        True. A new figure is created only if either one is missing.
        (Previously both were always overwritten.)
    random_seed : if > 0, seed numpy and start from uniform random values.
        Otherwise start from zeros.
    max_iter : optional cap on the number of V-cycles. ``None`` (default)
        keeps the loop unbounded. When the cap is reached, a RuntimeWarning
        is issued and the current iterate is returned.

    Returns
    -------
    (v, iters, ops_total) : final iterate, number of V-cycles, and the sum of
        the work estimates returned by ``vcycle_iteration``.

    Raises
    ------
    Exception
        If ``m`` is even or ``rest_type`` is not recognised.
    """
    import warnings

    if m%2==0:
        raise Exception("This implementation works only with odd grids")

    if rest_type:
        if rest_type == "injection":
            fullweight = False
        elif rest_type =="full_weight":
            fullweight = True
        else:
            raise Exception("Restriction not supported")
    else:
        fullweight = True

    if plot:
        # Respect caller-supplied axes/figure; only create one when missing.
        if fig is None or ax is None:
            fig,ax = plt.subplots(1,1)
    else:
        fig,ax = (None,None)

    x = np.linspace(0,1,m+2)[1:-1]
    y = np.linspace(0,1,m+2)[1:-1]

    h = x[1] - x[0]

    if random_seed > 0:
        np.random.seed(random_seed)
        v = np.random.rand(m,m)
    else:
        v = np.zeros((m,m))

    f = create_right(x,y)

    r = np.linalg.norm(calculate_residue(v,f,h))
    iters = 0
    ops_total = 0
    while r > atol:
        # Guard against tolerances below what floating point can reach.
        if max_iter is not None and iters >= max_iter:
            warnings.warn("vcycle: stopped after {} iterations, residual {} > atol {}".format(iters, r, atol), RuntimeWarning)
            break
        v,ops = vcycle_iteration(v,f, h = h, v1 = v1, v2=v2, fullweight = fullweight,plot = plot, fig=fig, ax = ax)
        r = np.linalg.norm(calculate_residue(v,f,h))
        iters += 1
        ops_total +=ops

    return v,iters,ops_total

In [12]:
def calculate_coarser(m):
    """Return the size of the coarsest grid reachable by repeated restriction.

    An odd size m is coarsened to (m - 1) // 2 + 1, the size ``restrict``
    produces, until the size becomes even. An even ``m`` is returned
    unchanged. For example: 33 -> 17 -> 9 -> 5 -> 3 -> 2.

    Raises
    ------
    ValueError
        If ``m == 1``. One node coarsens to itself, so the loop would never
        terminate.
    """
    if m == 1:
        raise ValueError("A 1-node grid cannot be coarsened further")
    while m % 2 != 0:
        # Floor division keeps m an integer on Python 3 as well as Python 2.
        m = (m - 1) // 2 + 1
    return m

In [13]:
def multigrid_iter(x,b, h = None,v0=1, v1 = 2, v2=2, fullweight=True, plot=False, ax=None, fig=None):
    """One recursive multigrid step built from V-cycles.

    On the coarsest level (where ``calculate_coarser(m) == m``, i.e. m is
    even) it just runs ``v0`` V-cycles. On finer levels it:
      - restricts the residual,
      - recurses for the coarse-grid correction (spacing 2*h, zero initial
        guess),
      - interpolates that correction and adds it,
      - runs ``v0`` V-cycles on this level.

    Parameters
    ----------
    x : (m, m) current iterate.
    b : (m, m) right-hand side.
    h : grid spacing. A falsy value falls back to 1/(m-1).
    v0 : V-cycles per level. v1, v2: smoothing sweeps inside each V-cycle.
    fullweight, plot, ax, fig : forwarded to ``restrict``/``vcycle_iteration``.

    Returns
    -------
    (x, ops) : updated iterate and work estimate. The estimate is the sum
        over this level's V-cycles plus a quarter of the coarse level's estimate.
    """
    m = len(x)
    if not h:
        h = 1./(m-1)

    # Accumulate over all v0 cycles. Previously only the last cycle's count
    # was kept, and v0 == 0 left the count unbound.
    ops = 0

    if m == calculate_coarser(m):
        for _ in range(v0):
            x,cycle_ops = vcycle_iteration(x,b,h=h,v1=v1,v2=v2,fullweight=fullweight,plot=plot,ax=ax,fig=fig)
            ops += cycle_ops
        return x,ops

    # Residual on this level, restricted to the next coarser grid.
    resid = calculate_residue(x,b,h=h)
    residC = restrict(resid,fullweight=fullweight)
    eC = np.zeros_like(residC)

    eC,opsC = multigrid_iter(eC,residC,h=2*h,v0=v0,v1=v1,v2=v2,fullweight=fullweight,plot=plot,ax=ax,fig=fig)

    x = x + interpolate(eC)
    for _ in range(v0):
        x,cycle_ops = vcycle_iteration(x, b, h=h , v1=v1,v2=v2,fullweight=fullweight,plot=plot,ax=ax,fig=fig)
        ops += cycle_ops
    return x,ops+opsC/4.

In [14]:
def full_multigrid(m,v0=1,v1=2,v2=2, plot=False,rest_type=None, atol=1e-3, random_seed= -1, max_iter=None):
    """Solve the model Poisson problem by repeated ``multigrid_iter`` steps.

    Parameters
    ----------
    m : interior nodes per axis; must be odd.
    v0 : V-cycles per level. v1, v2: smoothing sweeps inside each V-cycle.
    plot : draw iterates on a new figure while smoothing.
    rest_type : "injection" or "full_weight". A falsy value selects full weighting.
    atol : stopping tolerance on the Frobenius norm of the residual.
    random_seed : if > 0, seed numpy and start from uniform random values.
        Otherwise start from zeros.
    max_iter : optional cap on the number of multigrid steps. ``None``
        (default) keeps the loop unbounded. When the cap is reached, a
        RuntimeWarning is issued and the current iterate is returned.

    Returns
    -------
    (v, iters, ops_total) : final iterate, number of steps, and summed work
        estimates.

    Raises
    ------
    Exception
        If ``m`` is even or ``rest_type`` is not recognised.
    """
    import warnings

    if m%2==0:
        raise Exception("This implementation works only with odd grids")

    if rest_type:
        if rest_type == "injection":
            fullweight = False
        elif rest_type =="full_weight":
            fullweight = True
        else:
            raise Exception("Restriction not supported")
    else:
        fullweight = True

    if plot:
        fig,ax = plt.subplots(1,1)
    else:
        fig,ax = (None,None)

    x = np.linspace(0,1,m+2)[1:-1]
    y = np.linspace(0,1,m+2)[1:-1]

    # Use the spacing of the grid f is actually sampled on, as vcycle and
    # gauss_seidel do. The previous 1/(m-1) disagreed with this interior grid.
    h = x[1] - x[0]

    f = create_right(x,y)

    if random_seed > 0:
        np.random.seed(random_seed)
        v = np.random.rand(m,m)
    else:
        v = np.zeros((m,m))

    r = np.linalg.norm(calculate_residue(v,f,h))
    iters = 0
    ops_total = 0
    while r > atol:
        # Guard against tolerances below what floating point can reach.
        if max_iter is not None and iters >= max_iter:
            warnings.warn("full_multigrid: stopped after {} iterations, residual {} > atol {}".format(iters, r, atol), RuntimeWarning)
            break
        v,ops = multigrid_iter(v,f,h = h, v0 = v0, v1=v1, v2=v2, fullweight=fullweight, plot=plot, ax=ax, fig=fig)
        r = np.linalg.norm(calculate_residue(v,f,h))
        iters += 1
        ops_total += ops
    return v,iters,ops_total

In [15]:
# Solve on a 33 x 33 interior grid and build the exact solution for comparison.
# NOTE(review): the recorded run of this cell ended in KeyboardInterrupt;
# atol=1e-13 may be below the residual norm reachable in floating point.
nodes = 33
v,i,ops = full_multigrid(nodes,atol=1e-13)
u = create_solution(nodes)
# A single formatted argument prints the same text on Python 2 and Python 3.
print("{} {}".format(i, ops))

KeyboardInterrupt: 

In [None]:
# Pointwise error between the exact and the computed solution. It is drawn on
# its own figure: with `%matplotlib notebook`, plt.imshow would otherwise draw
# into whatever figure is current (e.g. one left over from a solver plot).
fig_err, ax_err = plt.subplots(1, 1)
err_img = ax_err.imshow(u - v)
ax_err.set_title("Pointwise error u - v")
fig_err.colorbar(err_img, ax=ax_err)
plt.show()

In [None]:
fig = 