This code attempts to implement the Crank–Nicolson method for a 1D advection–diffusion–reaction equation.

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import scipy.sparse
import scipy.linalg as la

class CrankNicolson1D:
    """Solve du/dt + V du/dx = D d^2u/dx^2 + f(u) on a uniform 1D grid.

    Diffusion and advection are discretised with central differences and
    advanced with the Crank-Nicolson scheme; the reaction term f(u) is
    treated explicitly with second-order Adams-Bashforth (which reduces to
    forward Euler on the first step).

    Typical use: ``set_grid(...)``, ``set_parameters(...)``, then
    ``solve(u_init)``. The solution history is stored in ``self.u_matrix``
    with shape (n_t, n_x); row n holds u at time ``self.t_pts[n]``.
    """

    def set_grid(self, x_min, x_max, n_x, t_min, t_max, n_t):
        """Define the uniform space and time grids.

        Both grids are built with ``endpoint=False``, so
        ``delta_x = (x_max - x_min) / n_x`` and the last spatial point is
        ``x_max - delta_x`` (likewise for time).
        """
        self.x_min, self.x_max, self.n_x = x_min, x_max, n_x
        self.t_min, self.t_max, self.n_t = t_min, t_max, n_t
        self.x_pts, self.delta_x = np.linspace(x_min, x_max, n_x, retstep=True, endpoint=False)
        self.t_pts, self.delta_t = np.linspace(t_min, t_max, n_t, retstep=True, endpoint=False)

    def set_parameters(self, D, V, f):
        """Set diffusion coefficient D, advection velocity V and reaction term f.

        ``f`` is called as ``f(u)`` on the current state array and its result
        is added to the right-hand side, so it should return an array of the
        same length as u (or a scalar).
        """
        self.V, self.D, self.f = V, D, f

    def solve(self, u_init, sparse=True, boundary_conditions=('neumann', 'neumann')):
        """Integrate from ``u_init`` over the time grid, filling ``self.u_matrix``.

        Parameters
        ----------
        u_init : array_like, length n_x
            Initial state at t = t_min; stored as ``self.u_matrix[0]``.
        sparse : bool
            If True (default), solve with the banded tridiagonal solver and
            apply the explicit half-step with a sparse CSR matrix. Otherwise
            use dense matrices, LU-factorising the implicit matrix once.
        boundary_conditions : (str, str)
            Conditions at the left and right ends. Each is 'dirichlet'
            (u = 0) or 'neumann' (du/dx = 0, via a mirrored ghost point).

        Raises
        ------
        ValueError
            If ``u_init`` does not have length n_x, or a boundary condition
            is not one of the recognised names.
        """
        # Crank-Nicolson splits each spatial operator half implicit, half explicit.
        sig = self.D * self.delta_t / 2. / self.delta_x**2
        nu = self.V * self.delta_t / 4. / self.delta_x

        u = np.asarray(u_init)
        if u.shape != (self.n_x,):
            raise ValueError(f"u_init must have shape ({self.n_x},), got {u.shape}")

        # Figure the data type (e.g. complex if u_init or the coefficients are complex).
        data_type = type(sig * nu * u[0])

        if sparse:
            # Banded storage + specialised tridiagonal solver keeps each step O(n_x).
            A = self._fillA_sp(sig, nu, self.n_x, data_type)
            B = self._fillB_sp(sig, nu, self.n_x, data_type)
            dirichlet_rows = self._apply_boundary_conditions(
                A, B, sig, boundary_conditions, banded=True)

            def advance(rhs):
                return la.solve_banded((1, 1), A, rhs, check_finite=False)
        else:
            A = self._make_tridiag(sig, nu, self.n_x, data_type)
            B = self._make_tridiag(-sig, -nu, self.n_x, data_type)
            dirichlet_rows = self._apply_boundary_conditions(
                A, B, sig, boundary_conditions, banded=False)
            # A does not change between steps, so factor it once and reuse it.
            lu_piv = la.lu_factor(A)

            def advance(rhs):
                return la.lu_solve(lu_piv, rhs)

        self.u_matrix = np.zeros([self.n_t, self.n_x], dtype=data_type)
        self._propagate(u, B, advance, dirichlet_rows)

    def get_final_u(self):
        """Return a copy of the last stored state (the final row of ``u_matrix``)."""
        return self.u_matrix[-1, :].copy()

    def _apply_boundary_conditions(self, A, B, sig, boundary_conditions, banded):
        """Rewrite the boundary rows of A and B in place; return the Dirichlet rows.

        ``A`` is either a dense (n, n) matrix or, when ``banded`` is True, the
        compact (3, n) form used by ``scipy.linalg.solve_banded``, where the
        full entry (i, j) is stored at ``A[1 + i - j, j]``. ``B`` is indexed
        by the full (i, j) either way.
        """
        if len(boundary_conditions) != 2:
            raise ValueError("boundary_conditions must be a (left, right) pair")

        def set_a(i, j, value):
            # Translate a full-matrix index (i, j) into the active storage layout.
            if banded:
                A[1 + i - j, j] = value
            else:
                A[i, j] = value

        n = self.n_x
        dirichlet_rows = []
        for side, bc in enumerate(boundary_conditions):
            # The boundary row and its only in-grid neighbour column.
            row, nbr = (0, 1) if side == 0 else (n - 1, n - 2)
            if bc == 'dirichlet':
                # u(x,t) = 0: identity row in A, zero row in B; the matching
                # right-hand-side entry is also zeroed in _propagate.
                set_a(row, row, 1.0)
                set_a(row, nbr, 0.0)
                B[row, row] = 0.0
                B[row, nbr] = 0.0
                dirichlet_rows.append(row)
            elif bc == 'neumann':
                # u'(x,t) = 0 via a mirrored ghost point equal to the neighbour:
                # its coefficient folds onto the neighbour, the advection parts
                # cancel, leaving -2*sig in A and +2*sig in B.
                set_a(row, nbr, -2 * sig)
                B[row, nbr] = 2 * sig
            else:
                raise ValueError(
                    f"unknown boundary condition {bc!r}; expected 'dirichlet' or 'neumann'")
        return dirichlet_rows

    def _propagate(self, u, B, advance, dirichlet_rows):
        """Time-step from ``u``, storing every state in ``self.u_matrix``.

        Each step solves A u^{n+1} = B u^n + dt * (1.5 f(u^n) - 0.5 f(u^{n-1}))
        through ``advance`` (which applies A^{-1}); on the first step
        f(u^{n-1}) is taken as f(u^n).
        """
        if self.n_t == 0:
            return
        f = self.f
        dt = self.delta_t
        self.u_matrix[0, :] = u
        fu_old = None
        for n in range(1, self.n_t):
            fu = f(u)
            if fu_old is None:
                fu_old = fu
            rhs = B.dot(u) + dt * (1.5 * fu - 0.5 * fu_old)
            if dirichlet_rows:
                # Keep the explicit reaction term out of u = 0 boundary nodes.
                rhs[dirichlet_rows] = 0.0
            u = advance(rhs)
            self.u_matrix[n, :] = u
            fu_old = fu

    def _make_tridiag(self, sig, nu, n, data_type):
        """Return a dense n x n tridiagonal matrix.

        Diagonal 1+2*sig, superdiagonal -(sig-nu), subdiagonal -(sig+nu).
        Called with (sig, nu) it gives the implicit matrix A; called with
        (-sig, -nu) it gives the explicit matrix B.
        """
        M = np.diagflat(np.full(n, (1 + 2 * sig), dtype=data_type)) + \
            np.diagflat(np.full(n - 1, -(sig - nu), dtype=data_type), 1) + \
            np.diagflat(np.full(n - 1, -(sig + nu), dtype=data_type), -1)
        return M

    def _fillA_sp(self, sig, nu, n, data_type):
        """Return the implicit matrix A in compact banded form ab[1+i-j, j] = a[i, j]."""
        A = np.zeros([3, n], dtype=data_type)  # three diagonals, each of length n
        A[0] = -(sig - nu)  # superdiagonal (A[0, 0] is unused padding)
        A[1] = 1 + 2 * sig  # diagonal
        A[2] = -(sig + nu)  # subdiagonal (A[2, -1] is unused padding)
        return A

    def _fillB_sp(self, sig, nu, n, data_type):
        """Return the explicit tridiagonal matrix B as a sparse CSR matrix."""
        _o = np.ones(n, dtype=data_type)
        supdiag = (sig - nu) * _o[:-1]
        diag = (1 - 2 * sig) * _o
        subdiag = (sig + nu) * _o[:-1]
        return scipy.sparse.diags([supdiag, diag, subdiag], [1, 0, -1], (n, n), format="csr")