In [3]:
import numpy as np
from scipy.interpolate import interp1d

def matvec(A,B,C,x):
    """Multiply a block-tridiagonal matrix by a block vector.

    Block row ``i`` of the result is
    ``B[:,:,i-1] @ x[:,i-1] + A[:,:,i] @ x[:,i] + C[:,:,i] @ x[:,i+1]``,
    with the out-of-range terms dropped in the first and last rows.

    Parameters
    ----------
    A : ndarray, shape (m, m, N)
        Diagonal blocks; ``A[:,:,i]`` is block (i, i).
    B : ndarray, shape (m, m, >= N-1)
        Sub-diagonal blocks; ``B[:,:,i]`` multiplies ``x[:,i]`` in row ``i+1``.
        Only the first ``N-1`` blocks are used (the same convention as
        ``solve_linear``), so a trailing extra block is ignored.
    C : ndarray, shape (m, m, >= N-1)
        Super-diagonal blocks; ``C[:,:,i]`` multiplies ``x[:,i+1]`` in row ``i``.
        Only the first ``N-1`` blocks are used.
    x : ndarray, shape (m, N)
        Block vector; column ``i`` is block ``i``.

    Returns
    -------
    ndarray, shape (m, N)
        The product, in the promoted dtype of all inputs (an integer ``x``
        multiplied by float blocks yields floats, not truncated integers).
    """
    xN = x.shape[1]
    n_off = max(xN - 1, 0)  # number of off-diagonal blocks actually coupled
    res = np.zeros(x.shape, dtype=np.result_type(A, B, C, x))
    # 'ijk,jk->ik': for every block column k, apply the m x m block to x[:, k].
    res += np.einsum('ijk,jk->ik', A[:, :, :xN], x)
    # Sub-diagonal: row i (i >= 1) gains B[:,:,i-1] @ x[:,i-1].
    res[:, 1:] += np.einsum('ijk,jk->ik', B[:, :, :n_off], x[:, :-1])
    # Super-diagonal: row i (i <= N-2) gains C[:,:,i] @ x[:,i+1].
    res[:, :-1] += np.einsum('ijk,jk->ik', C[:, :, :n_off], x[:, 1:])
    return res
    

def solve_linear(A,B,C,d):
    """Solve the block-tridiagonal system ``T x = d`` (block Thomas algorithm).

    ``T`` has block row ``i`` equal to
    ``B[:,:,i-1] x[:,i-1] + A[:,:,i] x[:,i] + C[:,:,i] x[:,i+1]``
    (the same layout ``matvec`` multiplies).

    Parameters
    ----------
    A : ndarray, shape (m, m, N)
        Diagonal blocks.
    B : ndarray, shape (m, m, >= N-1)
        Sub-diagonal blocks; ``B[:,:,i-1]`` couples row ``i`` to ``x[:,i-1]``.
    C : ndarray, shape (m, m, >= N-1)
        Super-diagonal blocks; ``C[:,:,i]`` couples row ``i`` to ``x[:,i+1]``.
    d : ndarray, shape (m, N)
        Right-hand side, one column per block.

    Returns
    -------
    ndarray, shape (m, N)
        The solution, at least float64 (complex inputs stay complex).

    Raises
    ------
    numpy.linalg.LinAlgError
        If a pivot block is singular. No block pivoting is done.
    """
    m, xN = d.shape
    dtype = np.result_type(A, B, C, d, np.float64)
    x = np.empty((m, xN), dtype=dtype)
    if xN == 0:
        return x
    if xN == 1:
        # No off-diagonal coupling: a single dense block solve.
        x[:, 0] = np.linalg.solve(A[:, :, 0], d[:, 0])
        return x

    Y = np.empty((m, xN), dtype=dtype)
    gamma = np.empty((m, m, xN - 1), dtype=dtype)

    # Forward elimination. Use solve() on the pivot block instead of forming its inverse.
    alpha = A[:, :, 0]
    gamma[:, :, 0] = np.linalg.solve(alpha, C[:, :, 0])
    Y[:, 0] = np.linalg.solve(alpha, d[:, 0])
    for i in range(1, xN):
        alpha = A[:, :, i] - B[:, :, i - 1] @ gamma[:, :, i - 1]
        Y[:, i] = np.linalg.solve(alpha, d[:, i] - B[:, :, i - 1] @ Y[:, i - 1])
        if i < xN - 1:
            gamma[:, :, i] = np.linalg.solve(alpha, C[:, :, i])

    # Back substitution.
    x[:, xN - 1] = Y[:, xN - 1]
    for i in range(xN - 2, -1, -1):
        x[:, i] = Y[:, i] - gamma[:, :, i] @ x[:, i + 1]
    return x


def calc_residual(A,B,C,d,x):
    """Return the residual of the block-tridiagonal system ``T x = d``.

    ``T`` is applied through ``matvec(A, B, C, x)``. The result is the
    negation of ``T x - d``, so it is positive where ``T x`` falls short of
    ``d``.
    """
    excess = matvec(A, B, C, x) - d
    return -excess

    
def newton_solver(A,B,C,J,d,x0,grid,r,eps = 1e-8, max_iter = 100):
    """Solve the x-dependent block-tridiagonal problem by Newton iteration.

    The residual is ``calc_residual(A, B, C, d + r, x)``, i.e.
    ``(d + r) - T(x) x``. Each step solves the linearised system whose
    diagonal blocks are ``A + J`` for a correction ``dx``, with the current
    residual as right-hand side. It then re-assembles the problem at the
    new iterate via ``generate_problem(grid, x)``.
    NOTE(review): ``J`` is presumably the extra diagonal Jacobian term
    coming from the x-dependence of the coefficients. Confirm against
    ``generate_problem``.

    Parameters
    ----------
    A, B, C : ndarray
        Block-tridiagonal coefficients at ``x0`` (see ``solve_linear``).
    J : ndarray
        Added to ``A`` to form the Newton system's diagonal blocks.
    d : ndarray
        Right-hand side at ``x0``.
    x0 : ndarray
        Initial iterate (not modified).
    grid : sequence
        Passed to ``generate_problem``.
    r : ndarray or scalar
        Extra forcing added to ``d``.
    eps : float
        Stop when the mean absolute residual drops below this.
    max_iter : int
        Maximum number of Newton steps.

    Returns
    -------
    tuple
        ``(A, B, C, J, d, x, resid)``: the problem re-assembled at the final
        iterate, the iterate itself, and its residual. This matches what
        ``solver`` unpacks.
    """
    # Float working copy so integer inputs cannot truncate the updates.
    x = np.array(x0, dtype=np.result_type(x0, np.float64))
    resid = calc_residual(A,B,C,d+r,x)
    for _ in range(max_iter):
        # Newton correction: (T + J) dx = residual.
        x += solve_linear(A+J,B,C,resid)
        # Re-assemble the problem at the new iterate.
        A,B,C,J,d = generate_problem(grid,x)
        resid = calc_residual(A,B,C,d+r,x)
        if np.abs(resid).mean() < eps:
            break
    return A,B,C,J,d,x,resid


def solver(A,B,C,J,d,x0,grid,r,eps = 1e-8, max_iter = 100,min_grid = 100):
    """Solve the problem on ``grid``, using multigrid cycles on large grids.

    Grids with fewer than ``min_grid`` points are handed directly to
    ``newton_solver``. Larger grids alternate a ``multigrid`` coarse-grid
    correction with a single Newton step until the mean absolute residual
    drops below ``eps`` or ``max_iter`` cycles have run.

    ``max_iter`` is now honoured. The old code silently forced 1000 cycles
    on large grids, regardless of the argument.

    Returns
    -------
    tuple
        ``(A, B, C, J, d, x, resid)``, as produced by ``newton_solver``.
    """
    if len(grid) < min_grid:
        # Coarse enough: plain Newton iteration.
        return newton_solver(A,B,C,J,d,x0,grid,r,eps=eps,max_iter=max_iter)

    x = x0.copy()
    # Defined up front so max_iter == 0 still returns a residual.
    resid = calc_residual(A,B,C,d+r,x)
    for _ in range(max_iter):
        # multigrid returns only the corrected iterate.
        x = multigrid(A,B,C,J,d,x,grid,r)
        A,B,C,J,d,x,resid = newton_solver(A,B,C,J,d,x,grid,r,eps=eps,max_iter=1)
        if np.abs(resid).mean() < eps:
            break
    return A,B,C,J,d,x,resid
    
    
def interpolate(in_grid,x,out_grid):
    """Linearly interpolate ``x`` from ``in_grid`` onto ``out_grid``.

    ``x`` holds samples along its axis 1 at the points ``in_grid``. Points of
    ``out_grid`` outside ``[min(in_grid), max(in_grid)]`` make ``interp1d``
    raise ``ValueError`` (its default bounds check).
    """
    return interp1d(in_grid, x, axis=1, kind='linear')(out_grid)
    

def generate_problem(grid,x):
    """Assemble the discretised problem on ``grid`` at the iterate ``x``.

    Placeholder. ``newton_solver`` unpacks the result as
    ``A, B, C, J, d``, so an implementation must return those five arrays.
    Raising here replaces a silent ``None``, which only surfaced as an
    opaque unpacking TypeError in the caller.

    Raises
    ------
    NotImplementedError
        Always, until the assembly is written.
    """
    raise NotImplementedError("generate_problem is not implemented yet")
    
    
def generate_grid(h):
    """Build the computational grid.

    Placeholder. NOTE(review): ``h`` is presumably the grid spacing; confirm
    when implementing. Raising here replaces a silent ``None`` return.

    Raises
    ------
    NotImplementedError
        Always, until the grid construction is written.
    """
    raise NotImplementedError("generate_grid is not implemented yet")

def decimation(mat, factor=2):
    """Restrict an array to a coarser grid by keeping every ``factor``-th entry.

    Subsampling is along the last axis, starting at index 0. Grids (1-D),
    block vectors ``(m, N)`` and block coefficients ``(m, m, N)`` are all
    handled. The old slice ``[..., :2:]`` kept only the first two entries
    instead of every other one.

    Parameters
    ----------
    mat : array_like
        Array to subsample; lists are converted with ``np.asarray``.
    factor : int, optional
        Coarsening ratio (default 2).

    Returns
    -------
    ndarray
        A view (for ndarray input) with ``ceil(N / factor)`` entries along the
        last axis.
    """
    return np.asarray(mat)[..., ::factor]
    
    
def multigrid(A,B,C,J,d,v,grid,r):
    """Apply one coarse-grid correction to the iterate ``v``.

    All data is restricted with ``decimation``. The coarse problem is solved
    recursively with ``solver``, starting from the restricted iterate. The
    coarse change ``u2 - v2`` is linearly interpolated back onto ``grid`` and
    added to ``v``.

    NOTE(review): the coarse forcing ``(T2 v2 - d2) + r2`` is kept as
    written. In a textbook FAS cycle the forcing would instead be
    ``T2 v2 + r2 - decimation(T v)``, i.e. it would include the restricted
    fine residual. Confirm which scheme is intended.
    NOTE(review): for an even number of fine points, the last fine point lies
    beyond the coarse grid's last point, so ``interpolate``'s default bounds
    check raises ``ValueError``. Odd-length grids are safe.

    Returns
    -------
    ndarray
        The corrected fine-grid iterate, same shape as ``v``.
    """
    grid2 = decimation(grid)
    v2 = decimation(v)
    A2, B2, C2, J2, d2, r2 = (decimation(a) for a in (A, B, C, J, d, r))
    r_coarse = (matvec(A2,B2,C2,v2)-d2)+r2
    # solver returns (A, B, C, J, d, x, resid); only the iterate x is needed.
    # The coarse solve starts from the restricted fine iterate v2.
    u2 = solver(A2,B2,C2,J2,d2,v2,grid2,r_coarse)[5]
    # Interpolate from the coarse grid (input) onto the fine grid (output).
    u = v+interpolate(grid2,u2-v2,grid)
    return u
    
    