In [None]:
def crank_nicolson(T, h, k, t_arr, tol, metal, alpha=None):
    """
    Crank-Nicolson method for solving the 1-D heat equation.

    The grid ``T`` is filled in place, one time column at a time. Rows
    ``T[0, :]`` and ``T[-1, :]`` are treated as fixed (Dirichlet) boundary
    values and are only read. Column ``T[:, 0]`` is the initial condition.
    For every step ``j`` the interior values ``T[1:-1, j+1]`` solve

        D1 @ T[1:-1, j+1] = D2 @ T[1:-1, j] + boundary terms

    where D1 / D2 are the tridiagonal Crank-Nicolson matrices built from the
    mesh ratio ``r = alpha * k / h**2``. Integration stops early once the
    mean absolute change of a whole column between two consecutive steps
    drops below ``tol``.

    Parameters
    ----------
    T : np.ndarray, shape (n, m)
        Temperature grid. Axis 0 is space (n >= 3 points) and axis 1 is time.
        Modified in place.
    h : float
        Spatial step. Presumably in cm, since the diffusivity is converted to
        cm^2/s before forming r (confirm against callers).
    k : float
        Time step (presumably seconds, matching the "s" in the report).
    t_arr : sequence of float
        Time values; ``t_arr[j]`` is the time of column ``j``.
    tol : float
        Threshold on the mean absolute per-step change used to declare a
        steady state.
    metal : hashable
        Key into the module-level ``alpha_dic`` table of diffusivities
        (mm^2/s). Ignored when ``alpha`` is given.
    alpha : float, optional
        Thermal diffusivity in mm^2/s. When given it overrides the
        ``alpha_dic[metal]`` lookup.

    Returns
    -------
    t0_indx : int or None
        Column index at which steady state was detected, otherwise None.
    t0 : float or None
        ``t_arr[t0_indx]``, otherwise None.
    T : np.ndarray
        The same array object that was passed in, now filled.

    Raises
    ------
    ValueError
        If ``T`` has fewer than 3 spatial points (no interior nodes).
    """
    # Diffusivity in mm^2/s: caller-supplied, or looked up by metal name.
    if alpha is None:
        alpha = alpha_dic[metal]

    # mm^2/s -> cm^2/s (1 mm^2 = 1e-2 cm^2), then the dimensionless mesh ratio.
    alpha = alpha * 1e-2
    r_factor = alpha * k / h**2

    n, m = T.shape[:2]  # points in space, points in time
    if n < 3:
        raise ValueError(
            f"T needs at least 3 spatial points (got {n}); the two boundary "
            "rows leave no interior nodes to solve for."
        )
    n_int = n - 2  # interior (unknown) nodes per time step

    # Implicit-side (LHS) matrix: 2+2r on the diagonal, -r on both off-diagonals.
    D1_matrix = (np.diag([2 + 2*r_factor] * n_int)
                 + np.diag([-r_factor] * (n_int - 1), -1)
                 + np.diag([-r_factor] * (n_int - 1), +1))

    # Explicit-side (RHS) matrix: 2-2r on the diagonal, +r on both off-diagonals.
    D2_matrix = (np.diag([2 - 2*r_factor] * n_int)
                 + np.diag([r_factor] * (n_int - 1), -1)
                 + np.diag([r_factor] * (n_int - 1), +1))

    t0 = None
    t0_indx = None
    dT_mean = None  # stays None when there is no time step to take (m < 2)

    for j in range(m - 1):
        # Explicit half of the scheme applied to the current interior values
        # (the matmul returns a fresh array, so T itself is not touched here).
        b = D2_matrix @ T[1:-1, j]

        # Boundary values are known at both time levels; move them to the RHS.
        # For a single interior node b[0] and b[-1] are the same element, which
        # correctly picks up both neighbours.
        b[0] += r_factor * (T[0, j + 1] + T[0, j])
        b[-1] += r_factor * (T[-1, j + 1] + T[-1, j])

        # Implicit half: solve D1 @ x = b for the new interior values.
        T[1:-1, j + 1] = np.linalg.solve(D1_matrix, b)

        # Steady-state test: mean absolute change of the whole column between
        # consecutive time steps.
        dT_mean = np.mean(np.abs(T[:, j] - T[:, j + 1]))
        if dT_mean < tol:
            t0 = t_arr[j + 1]
            t0_indx = j + 1
            break

    if dT_mean is not None:
        print(dT_mean)

    if t0_indx is not None:
        print(f"Steady state reached at t = {t0:.2f} s")
    else:
        print("Steady state not reached")

    return t0_indx, t0, T