In [1]:
from manim import *
import numpy as np
import random

# Functions

In [79]:
def curve_func(x):
    """Test function to approximate on [0, 1]: f(x) = x**2 * sin(pi * x)."""
    return x**2 * np.sin(np.pi * x)

def make_nodes(l):
    """Return every grid point of level l on [0, 1].

    The grid has 2**l equal intervals, hence 2**l + 1 points including both ends.
    """
    n_points = 2**l + 1
    return np.linspace(0, 1, num=n_points, endpoint=True)

def make_hierarchical_nodes(l):
    """Return the grid points that are *new* at hierarchical level l.

    Level 0 contributes the two boundary nodes {0, 1}. Each level l >= 1
    contributes the odd-indexed points i * 2**-l (i = 1, 3, ..., 2**l - 1),
    i.e. exactly the level-l points that are not already in level l - 1.

    Returns:
        A numpy array for l == 0 and a list for l >= 1 (the same types as the
        previous implementation, kept for backward compatibility).
    """
    all_nodes = np.linspace(0, 1, 2**l + 1, endpoint=True)
    if l == 0:
        return all_nodes
    # Even-indexed points of level l coincide with the level-(l-1) grid, so
    # the new ones are the odd indices. Slicing replaces an O(n^2) scan that
    # relied on exact float equality against the coarser grid.
    return list(all_nodes[1::2])

def make_basis(nodes, l):
    """Build one hat (tent) function per node for level l.

    Each returned callable evaluates max(1 - |2**l * x - node * 2**l|, 0):
    it peaks at 1 on its node and vanishes 2**-l away on either side.
    For l == 0, 2**l == 1, so a single formula covers every level.
    """
    scale = 2**l

    def _hat(centre):
        return lambda x, i=centre: max(1 - abs(scale * x - i * scale), 0)

    return [_hat(centre) for centre in nodes]

def make_coefficients(curve, nodes):
    """Evaluate `curve` at every node, returning the values as a list in node order."""
    return [curve(node) for node in nodes]

def make_scaled_basis(coeffs, basis):
    """Pair coefficients with basis functions, returning x -> alpha * phi(x) for each pair.

    Pairs are formed with zip, so surplus items in the longer input are ignored.
    """
    def _scaled(weight, phi):
        return lambda x, alpha=weight, fn=phi: alpha * fn(x)

    return [_scaled(weight, phi) for weight, phi in zip(coeffs, basis)]

def make_approximation(basis):
    """Return a callable that sums every function in `basis` at a point x."""
    def approximation(x):
        return sum(fn(x) for fn in basis)

    return approximation

def make_residual_function(resid, basis):
    """Return x -> resid(x) minus the sum of every function in `basis` at x."""
    def residual(x):
        base = resid(x)
        return base - sum(fn(x) for fn in basis)

    return residual

lmax = 5

# Full-grid data for each level l = 0, ..., lmax-1:
# grid nodes, hat functions, nodal values f(x_{l,i}), scaled hats and their sum.
nodes_l, phi_l, coeff_l, coeff_phi_l, approx_l = [], [], [], [], []

for l in range(lmax):
    nodes_l.append(make_nodes(l))
    phi_l.append(make_basis(nodes_l[-1], l))
    coeff_l.append(make_coefficients(curve_func, nodes_l[-1]))
    coeff_phi_l.append(make_scaled_basis(coeff_l[-1], phi_l[-1]))
    approx_l.append(make_approximation(coeff_phi_l[-1]))

# Introduction

In [26]:
%%manim -qh --disable_caching -v WARNING DynamicOptimisation

class DynamicOptimisation(Scene):
    """Write the Bellman equation, then underline four of its parts one at a time."""

    def construct(self):
        # Each argument becomes its own sub-mobject so pieces can be highlighted.
        bellman = MathTex(
            "V",
            "(",
            "x_{t}",
            ") = \\max_{\\{",
            "a_{t}",
            " \\in",
            "\\Gamma(x_{t})",
            "\\}} \\{F(x_{t}, a_{t}) + \\beta V(",
            "T(x_{t}, a_{t})",
            ")\\}",
        )
        underline = [Underline(part, color=YELLOW) for part in bellman]

        self.play(Write(bellman))
        self.wait(1)
        # Highlight V, x_t, a_t and Gamma(x_t) in turn.
        for idx in (0, 2, 4, 6):
            self.play(Create(underline[idx]))
            self.play(FadeOut(underline[idx]))
        self.wait(1)

                                                                                                                       

In [22]:
%%manim -qh --disable_caching -v WARNING Setting_Title

class Setting_Title(Scene):
    """Write the section heading "Setting", then fade it out."""

    def construct(self):
        heading = Text("Setting")
        heading.scale(1.5)
        self.play(Write(heading))
        self.play(FadeOut(heading))

                                                                                                                       

In [29]:
%%manim -qh --disable_caching -v WARNING Setup2

class Setup2(Scene):
    """Introduce f: [0, 1] -> R and its full-grid linear interpolant at level l = 2.

    Reads the module-level `curve_func`, `nodes_l` and `approx_l`.
    """

    def construct(self):
        l = 2  # level shown: 2**l intervals, 2**l + 1 grid points

        ax = Axes(
            x_range=[0, 1],
            y_range=[0, 0.5],
            axis_config={
                "decimal_number_config": {"num_decimal_places": 2}
            },
            x_axis_config={
                "include_numbers": False,
                "include_ticks": False,
                "include_tip": False
            }
        )

        curve = ax.get_graph(curve_func, x_range=[0, 1], color=BLUE)
        curve_label = MathTex("f: [0,1] \\longrightarrow \\mathbb{R}, f(x)", color=BLUE)
        discrete_label = MathTex("f \\approx \\hat{f}", color=GREEN)

        # Initially only the interval end points carry labels.
        ax.add_coordinates({0: MathTex("0"), 1: MathTex("1")})

        dots = VGroup(*[Dot(point=ax.c2p(n, 0)) for n in nodes_l[l]])
        vert_lines = VGroup(*[ax.get_vertical_line(ax.i2gp(n, curve), color=YELLOW) for n in nodes_l[l]])
        approx = ax.get_graph(approx_l[l], x_range=[0, 1, 0.003], color=GREEN)

        # Grid-point labels "0, x_{1}, ..., x_{2^l - 1}, 1", derived from l so
        # they always match the dots and vertical lines above.
        node_names = ["0"] + [f"x_{{{i}}}" for i in range(1, 2**l)] + ["1"]
        node_labels = dict(zip(np.linspace(0, 1, 2**l + 1), (MathTex(s) for s in node_names)))

        # Animations
        self.play(Write(curve_label))
        self.play(curve_label.animate.shift(3*UP))
        self.play(Create(ax), Create(curve))
        self.wait(1)

        self.play(Write(discrete_label.shift(2*UP)))
        self.wait(1)
        ax.add_coordinates(node_labels)
        # Detach the freshly attached labels so they appear via Write below.
        ax.x_axis.remove(ax.x_axis.labels)
        self.play(Write(ax.x_axis.labels))

        self.play(FadeIn(dots))
        self.play(FadeIn(vert_lines))
        self.play(Create(approx), run_time=4)
        self.wait(1)
        self.play(FadeOut(vert_lines), FadeOut(dots), FadeOut(curve))

        self.wait(1)

                                                                                                                       

In [32]:
%%manim -qh --disable_caching -v WARNING Setup3

class Setup3(Scene):
    """Show the full-grid linear interpolant of f at level l = 3.

    Reads the module-level `curve_func`, `nodes_l` and `approx_l`.
    """

    def construct(self):
        l = 3  # level shown: 2**l intervals, 2**l + 1 grid points

        ax = Axes(
            x_range=[0, 1],
            y_range=[0, 0.5],
            axis_config={
                "decimal_number_config": {"num_decimal_places": 2}
            },
            x_axis_config={
                "include_numbers": False,
                "include_ticks": False,
                "include_tip": False
            }
        )

        curve = ax.get_graph(curve_func, x_range=[0, 1], color=BLUE)

        # Initially only the interval end points carry labels.
        ax.add_coordinates({0: MathTex("0"), 1: MathTex("1")})

        dots = VGroup(*[Dot(point=ax.c2p(n, 0)) for n in nodes_l[l]])
        vert_lines = VGroup(*[ax.get_vertical_line(ax.i2gp(n, curve), color=YELLOW) for n in nodes_l[l]])
        approx = ax.get_graph(approx_l[l], x_range=[0, 1, 0.003], color=GREEN)

        # Animations
        self.add(ax, curve)

        # Grid-point labels "0, x_{1}, ..., x_{2^l - 1}, 1", derived from l.
        node_names = ["0"] + [f"x_{{{i}}}" for i in range(1, 2**l)] + ["1"]
        node_labels = dict(zip(np.linspace(0, 1, 2**l + 1), (MathTex(s) for s in node_names)))

        ax.add_coordinates(node_labels)
        # Detach the freshly attached labels so they appear via Write below.
        ax.x_axis.remove(ax.x_axis.labels)
        self.play(Write(ax.x_axis.labels))

        self.play(FadeIn(dots))
        self.play(FadeIn(vert_lines))
        self.play(Create(approx), run_time=4)
        self.wait(1)
        self.play(FadeOut(vert_lines), FadeOut(dots))

        self.wait(1)

                                                                                                                       

In [36]:
%%manim -qh --disable_caching -v WARNING Basis_Blackboard

class Basis_Blackboard(Scene):
    """Blackboard summarising the full-grid linear spline construction."""

    def construct(self):
        title = Title("Linear Spline Approximation with Full Grids")
        # Raw strings: LaTeX backslashes must not be parsed as Python escapes
        # (invalid escapes such as "\l" or "\s" raise SyntaxWarning).
        no_intervals = MathTex(r"\text{Number of intervals} = 2^l")
        grid_spacing = MathTex(r"\text{Grid spacing: } h_{l} := 2^{-l}")
        omega = MathTex(r"\text{Set of grid points: } \Omega_{l} := \left\{ x_{l,i} \mid i = 0, \dots, 2^l \right\}")
        grid_points = MathTex(r"\text{Grid points: } x_{l, i} := i h_{l}")
        basis_functions = MathTex(r"\text{Basis functions: } \phi_{l,i}(x)", r":= \max (1-|2^{l}x - i|, 0)")
        interpolant = MathTex(r"\text{Interpolant: } \hat{f}(x) := \sum_{i=0}^{2^l} c_{l,i} \phi_{l,i}(x)")
        constants = MathTex(r"c_{l,i} = f(x_{l,i})")

        text = VGroup(
            no_intervals, grid_spacing, omega, grid_points, basis_functions, interpolant, constants
        ).scale(0.75).arrange(DOWN)

        self.play(Write(title))

        for line in text:
            self.play(Write(line))
            self.wait()

                                                                                                                       

# Animating basis functions 

In [37]:
%%manim -qh --disable_caching -v WARNING Basis

class Basis(Scene):
    """Animate the nodal hat basis functions phi_{l,i} for levels l = 0, ..., 3."""

    def construct(self):
        ax = Axes(
            x_range=[0, 1, 1/4],
            y_range=[0, 1.5, 1],
            x_axis_config={
                "numbers_to_include": np.linspace(0, 1, 5),
                "include_ticks": True,
                "include_tip": False
            },
            y_axis_config={
                # y_range stops at 1.5, so only 0 and 1 lie on the axis.
                "numbers_to_include": [0, 1],
                "include_ticks": True
            }
        )

        axis_labels = ax.get_axis_labels()
        self.add(ax, axis_labels)

        # NOTE(review): TEAL_E appears twice, so two hats share a colour at l = 3.
        colors = [BLUE_E, TEAL_E, GREEN_E, YELLOW_E, GOLD_E, TEAL_E, RED_E, MAROON_E, PURPLE_E]
        lmax = 4
        nodes_l = [make_nodes(l) for l in range(lmax)]
        phi_l = [make_basis(nodes_l[l], l) for l in range(lmax)]

        phi = []
        dots = []

        for l in range(lmax):
            # Cycle the palette so levels with more than len(colors) hats still work.
            phi.append(VGroup(*[
                ax.get_graph(p, x_range=[0, 1, 0.003], color=colors[i % len(colors)])
                for i, p in enumerate(phi_l[l])
            ]))
            dots.append(VGroup(*[Dot(point=ax.c2p(n, 0)) for n in nodes_l[l]]))

            omega_string = r"\Omega_{" + str(l) + r"} = \left\{"
            omega_string += ", ".join(str(element) for element in nodes_l[l])
            omega_string += r"\right\}"
            omega_label = MathTex(omega_string)
            phi_label = MathTex(r"\phi_{" + str(l) + r",i} = \max(1-|" + str(2**l) + "x - i|, 0)")
            level_labels = VGroup(omega_label, phi_label).arrange(DOWN).shift(3*UP).scale(0.8)

            self.play(FadeIn(level_labels))
            self.play(FadeIn(dots[l]))
            self.play(Create(phi[l]), run_time=(l + 1))
            self.wait(1)
            self.play(FadeOut(phi[l]), FadeOut(dots[l]), FadeOut(level_labels))

                                                                                                                       

# Linear spline approximation

## Animating linear spline approximation

In [6]:
%%manim -qh -v WARNING LinearSplineApproximation

class LinearSplineApproximation(Scene):
    """Animate full-grid linear spline interpolants of curve_func for levels 2 .. lmax-1.

    Reads the module-level `lmax`, `nodes_l`, `coeff_phi_l` and `approx_l`.
    """

    def construct(self):
        # Create axes
        ax = Axes(
            x_range=[0, 1, 1/8],
            y_range=[0, 0.5],
            axis_config={
                "decimal_number_config": {"num_decimal_places": 2}
            },
            x_axis_config={
                "numbers_to_include": np.linspace(0, 1, 5),
                "include_ticks": True,
                "include_tip": False
            }
        )

        labels = ax.get_axis_labels()

        curve = ax.get_graph(curve_func, x_range=[0, 1], color=BLUE_C)
        # Raw string so "\sin" / "\pi" are not parsed as Python escapes.
        curve_label = MathTex(r"f(x) = x^2 \sin(\pi x)", color=BLUE_C).shift(3*UP)

        self.add(ax, labels)
        self.play(Create(curve))
        self.play(Write(curve_label))
        self.play(FadeOut(curve_label))

        # Only levels 2 .. lmax-1 are animated, so mobjects are built just for those.
        for l in range(2, lmax):
            annotation = MathTex("l = " + str(l)).shift(3*UP)
            dots = VGroup(*[Dot(point=ax.c2p(n, 0)) for n in nodes_l[l]])
            coeff_phi = VGroup(*[ax.get_graph(p, x_range=[0, 1, 0.003], color=ORANGE) for p in coeff_phi_l[l]])
            vert_lines = VGroup(*[ax.get_vertical_line(ax.i2gp(n, curve), color=YELLOW) for n in nodes_l[l]])
            approx = ax.get_graph(approx_l[l], x_range=[0, 1, 0.003], color=GREEN)

            self.play(Write(annotation))
            self.play(FadeIn(dots))
            self.play(FadeIn(vert_lines))
            self.play(Create(coeff_phi), run_time=l)
            self.play(Create(approx), run_time=l)
            self.play(FadeOut(coeff_phi), FadeOut(vert_lines))
            self.wait(1)
            self.play(FadeOut(approx), FadeOut(dots), FadeOut(annotation))

                                                                                                                       

## Blackboard for linear spline approximation

In [7]:
%%manim -qh --disable_caching -v WARNING Linear_Spline_Blackboard

class Linear_Spline_Blackboard(Scene):
    """Blackboard with the interpolant, its coefficients and the hat basis functions."""

    def construct(self):
        # The level index is l throughout (as in the sum's upper limit 2^l);
        # raw strings keep LaTeX backslashes literal.
        interpolant = MathTex(r"\hat{f}(x) := \sum_{i=0}^{2^l} c_{l,i} \phi_{l,i}(x)")
        constants = MathTex(r"c_{l,i} = f(x_{l,i})")
        basis_functions = MathTex(r"\phi_{l,i}(x) := \max (1-|2^{l}x - i|, 0)")
        self.play(Write(interpolant.shift(2*UP)))
        self.play(Write(constants))
        self.play(Write(basis_functions.shift(2*DOWN)))
        self.wait(1)

                                                                                                                       

# Hierarchical Linear Spline Approximation

In [80]:
def curve_func(x):
    """Test function to approximate on [0, 1]: f(x) = x**2 * sin(pi * x).

    Same definition as in the functions cell; redefining it here keeps this
    cell's results independent of any later cell that rebinds `curve_func`.
    """
    return x**2 * np.sin(np.pi * x)

lmax = 5

# Level 0: boundary nodes {0, 1}; their surpluses are just the function values.
new_nodes_l = [make_hierarchical_nodes(0)]
phi_l = [make_basis(new_nodes_l[0], 0)]
coeff_l = [make_coefficients(curve_func, new_nodes_l[0])]
coeff_phi_l = [make_scaled_basis(coeff_l[0], phi_l[0])]
residual_func_l = [make_residual_function(curve_func, coeff_phi_l[0])]

for l in range(1, lmax):
    # New (odd-index) nodes of level l; l >= 1 always holds inside this loop.
    new_nodes_l.append(make_hierarchical_nodes(l))

    # Add linear basis functions
    phi_l.append(make_basis(new_nodes_l[l], l))

    # Hierarchical surpluses: residual of all coarser levels at the new nodes
    coeff_l.append(make_coefficients(residual_func_l[l-1], new_nodes_l[l]))

    # Create scaled linear basis functions
    coeff_phi_l.append(make_scaled_basis(coeff_l[l], phi_l[l]))

    # Residual after subtracting this level's contribution
    residual_func_l.append(make_residual_function(residual_func_l[l-1], coeff_phi_l[l]))


## Animating hierarchical basis functions

In [61]:
%%manim -qh --disable_caching -v WARNING HBasis

class HBasis(Scene):
    """Animate the hierarchical hat basis functions for levels l = 0, ..., 3."""

    def construct(self):
        ax = Axes(
            x_range=[0, 1, 1/4],
            y_range=[0, 1.5],
            x_axis_config={
                "numbers_to_include": np.linspace(0, 1, 5),
                "include_ticks": True,
                "include_tip": False
            },
            y_axis_config={
                # y_range stops at 1.5, so only 0 and 1 lie on the axis.
                "numbers_to_include": [0, 1],
                "include_ticks": True
            }
        )

        labels = ax.get_axis_labels()

        self.add(ax, labels)
        # NOTE(review): TEAL_E appears twice in this palette.
        colors = [BLUE_E, TEAL_E, GREEN_E, YELLOW_E, GOLD_E, TEAL_E, RED_E, MAROON_E, PURPLE_E]
        lmax = 4
        hierarchical_nodes_l = [make_hierarchical_nodes(l) for l in range(lmax)]
        phi_l = [make_basis(hierarchical_nodes_l[l], l) for l in range(lmax)]

        phi = []
        dots = []
        for l in range(lmax):
            # Cycle the palette so levels with more than len(colors) hats still work.
            phi.append(VGroup(*[
                ax.get_graph(p, x_range=[0, 1, 0.003], color=colors[i % len(colors)])
                for i, p in enumerate(phi_l[l])
            ]))
            dots.append(VGroup(*[Dot(point=ax.c2p(n, 0)) for n in hierarchical_nodes_l[l]]))

            # Raw strings keep "\left\{" / "\right\}" literal for LaTeX.
            string = r"\tilde{\Omega}_{" + str(l) + r"} = \left\{"
            string += ", ".join(str(element) for element in hierarchical_nodes_l[l])
            string += r"\right\}"
            omega = MathTex(string)

            self.play(FadeIn(omega.shift(3*UP)))

            self.play(FadeIn(dots[l]))
            self.play(Create(phi[l]), run_time=len(phi[l]))
            self.wait(1)
            self.play(FadeOut(phi[l]), FadeOut(dots[l]), FadeOut(omega))

                                                                                                                       

# Text for hierarchical basis functions

## Animation for hierarchical linear spline approximation

In [68]:
%%manim -qh --disable_caching -v WARNING Hierarchical_Basis_Blackboard

class Hierarchical_Basis_Blackboard(Scene):
    """Blackboard defining the hierarchical node sets and how they build Omega_l."""

    def construct(self):
        title = Title("Hierarchical Linear Spline Approximation with Full Grids")
        nested = MathTex(r"\Omega_{l-1} \subset \Omega_{l}")
        omega_tilde0 = MathTex(r"\tilde{\Omega}_{0} := \left\{0, 1 \right\}")
        # The odd-index rule applies only from level 1 on; level 0 is the line above.
        omega_tilde = MathTex(r"\tilde{\Omega}_{l} := \left\{x_{l,i} \mid i = 0, \dots, 2^l, i \text{ odd} \right\}, \; l \geq 1")
        # The union must start at n = 0: the boundary nodes {0, 1} live only in tilde-Omega_0.
        omega = MathTex(r"\Omega_{l} := \bigcup_{n=0}^{l} \tilde{\Omega}_{n}")

        blackboard = VGroup(nested, omega_tilde0, omega_tilde, omega).arrange(DOWN)

        self.play(Write(title))
        for t in blackboard:
            self.play(Write(t))
        self.wait(1)

                                                                                                                       

In [18]:
%%manim -qh -v WARNING Hierarchical_Linear_Spline_Approximation

class Hierarchical_Linear_Spline_Approximation(Scene):
    """Animate the hierarchical interpolant level by level.

    At each level the scaled hats are drawn and then morphed, together with the
    current curve, into that level's residual. At the end, all scaled hats are
    shown and morphed into the total approximation. Reads the module-level
    `curve_func`, `new_nodes_l`, `coeff_phi_l` and `residual_func_l`.
    """

    def construct(self):
        ax = Axes(
            x_range=[0, 1, 1/8],
            y_range=[-0.1, 0.5, 0.05],
            axis_config={
                "decimal_number_config": {"num_decimal_places": 2}
            },
            x_axis_config={
                "numbers_to_include": np.linspace(0, 1, 5),
                "include_ticks": True,
                "include_tip": False
            }
        )

        labels = ax.get_axis_labels()
        curve = ax.get_graph(curve_func, x_range=[0, 1], color=BLUE_C)
        # Raw string so "\sin" / "\pi" are not parsed as Python escapes.
        curve_label = MathTex(r"y = x^2 \sin(\pi x)", color=BLUE_C)
        curve_label.shift(3*UP)
        self.add(ax, labels)

        self.play(Create(curve))
        self.play(Write(curve_label))
        self.play(FadeOut(curve_label))

        lmax = 5  # levels 0 .. lmax-1 are animated
        dots = []
        coeff_phi = []
        residual = []
        levels = []
        for l in range(lmax):
            dots.append(VGroup(*[Dot(point=ax.c2p(n, 0)) for n in new_nodes_l[l]]))
            coeff_phi.append(VGroup(*[ax.get_graph(p, x_range=[0, 1, 0.003], color=ORANGE) for p in coeff_phi_l[l]]))
            residual.append(ax.get_graph(residual_func_l[l], x_range=[0, 1, 0.003], color=BLUE))
            levels.append(MathTex("l = " + str(l)).shift(3*UP))

        all_new_nodes = [j for i in new_nodes_l for j in i]
        all_new_dots = VGroup(*[Dot(point=ax.c2p(n, 0)) for n in all_new_nodes])
        all_basis = [j for i in coeff_phi_l for j in i]
        all_coeff_phi = VGroup(*[ax.get_graph(p, x_range=[0, 1, 0.003], color=ORANGE) for p in all_basis])

        total_approx_func = lambda x: curve_func(x) - residual_func_l[lmax-1](x)
        total_approx = ax.get_graph(total_approx_func, x_range=[0, 1, 0.003], color=GREEN)

        # The highest level shown is lmax - 1; derive the label so it cannot drift.
        lmax_label = MathTex(f"l_{{max}} = {lmax - 1}").shift(5.5*RIGHT, 3*UP).scale(0.7)
        self.play(Write(lmax_label))

        for l in range(lmax):
            self.play(Write(levels[l]))
            self.play(FadeIn(dots[l]))
            self.play(Create(coeff_phi[l]))
            self.wait(1)
            if l == 0:
                self.play(FadeOut(coeff_phi[l]), FadeOut(dots[l]), FadeOut(levels[l]))
            else:
                self.play(Transform(VGroup(curve, coeff_phi[l]), residual[l]), FadeOut(dots[l]), FadeOut(levels[l]))

        self.play(FadeOut(VGroup(curve, residual[lmax-1])))
        self.play(FadeIn(VGroup(all_coeff_phi, all_new_dots)))
        self.wait(1)
        self.play(Transform(all_coeff_phi, total_approx))
        self.wait(1)

                                                                                                                       

# Adaptive Hierarchical Linear Spline Approximation

In [81]:
def curve_func(x):
    """Test function to approximate on [0, 1]: f(x) = x**2 * sin(pi * x).

    Same definition as in the functions cell; redefining it here keeps this
    cell's results independent of any later cell that rebinds `curve_func`.
    """
    return x**2 * np.sin(np.pi * x)

def make_new_nodes(parents, l):
    """Return the two level-l children of every parent node.

    Children sit 2**-l to the left and right of each parent. The result is a
    flat list ordered [left_0, right_0, left_1, right_1, ...].
    """
    offset = 2**-l
    return [child for node in parents for child in (node - offset, node + offset)]

def make_parent_nodes(coeffs, nodes, tol):
    """Keep the nodes whose surplus magnitude |alpha| is at least `tol`.

    These nodes are refined at the next level; the others become terminal.
    """
    return [node for alpha, node in zip(coeffs, nodes) if abs(alpha) >= tol]

tolerance = 0.02
lmax = 5

# Level 0 (boundary nodes) and level 1 (the midpoint) are always included, so
# both are seeded here; tolerance-driven refinement starts at level 2.
new_nodes_l = [make_hierarchical_nodes(0), make_hierarchical_nodes(1)]
parent_nodes_l = [new_nodes_l[0]]
terminal_nodes_l = [[]]
phi_l = [make_basis(new_nodes_l[0], 0)]
coeff_l = [[curve_func(n) for n in new_nodes_l[0]]]
coeff_phi_l = [make_scaled_basis(coeff_l[0], phi_l[0])]
residual_func_l = [make_residual_function(curve_func, coeff_phi_l[0])]

for l in range(1, lmax):
    if l > 1:
        # Only last level's parent nodes spawn children.
        new_nodes_l.append(make_new_nodes(parent_nodes_l[-1], l))

    # Hats, surpluses (residual of the coarser levels) and scaled hats for level l
    phi_l.append(make_basis(new_nodes_l[l], l))
    coeff_l.append(make_coefficients(residual_func_l[-1], new_nodes_l[l]))
    coeff_phi_l.append(make_scaled_basis(coeff_l[-1], phi_l[-1]))

    # Residual after subtracting this level's contribution
    residual_func_l.append(make_residual_function(residual_func_l[-1], coeff_phi_l[-1]))

    # Split the level's nodes by surplus magnitude
    parent_nodes_l.append(make_parent_nodes(coeff_l[-1], new_nodes_l[l], tolerance))
    terminal_nodes_l.append(np.setdiff1d(new_nodes_l[l], parent_nodes_l[-1]))

## Adaptive Blackboard

In [77]:
%%manim -qh --disable_caching -v WARNING Adaptive_Hierarchical_Basis_Blackboard

class Adaptive_Hierarchical_Basis_Blackboard(Scene):
    """Blackboard stating the tolerance rule that classifies nodes as parent or terminal."""

    def construct(self):
        title = Title("Adaptive Hierarchical Linear Spline Approximation")
        rules = VGroup(
            MathTex(r"\text{Tolerance level: } \varepsilon"),
            MathTex(r"\text{If } |\alpha_{l,i}| \geq \varepsilon"),
            MathTex(r"\text{Node becomes parent node}"),
            MathTex(r"\text{If } |\alpha_{l,i}| < \varepsilon"),
            MathTex(r"\text{Node becomes terminal node}"),
        )
        rules.arrange(DOWN)

        self.play(Write(title))
        for entry in rules:
            self.play(Write(entry))
        self.wait(1)

                                                                                                                       

## Animation for adaptive hierarchical linear spline approximation

In [82]:
%%manim -qh -v WARNING Adaptive_Hierarchical_Linear_Spline_Approximation

class Adaptive_Hierarchical_Linear_Spline_Approximation(Scene):
    """Animate the adaptive hierarchical interpolant of x**2 sin(pi x).

    Parent nodes are drawn green and terminal nodes red. Reads the module-level
    `curve_func`, `tolerance`, `new_nodes_l`, `parent_nodes_l`,
    `terminal_nodes_l`, `coeff_phi_l` and `residual_func_l`.
    """

    def construct(self):
        ax = Axes(
            x_range=[0, 1, 1/8],
            y_range=[-0.1, 0.5, 0.05],
            axis_config={
                "decimal_number_config": {"num_decimal_places": 2}
            },
            x_axis_config={
                "numbers_to_include": np.linspace(0, 1, 5),
                "include_ticks": True,
                "include_tip": False
            },
            y_axis_config={
                "numbers_to_include": [0.05],
                "include_ticks": True,
            }
        )

        lmax = 5  # levels 0 .. lmax-1 are animated
        labels = ax.get_axis_labels()
        curve = ax.get_graph(curve_func, x_range=[0, 1], color=BLUE_C)
        # Raw string so "\sin" / "\pi" are not parsed as Python escapes.
        curve_label = MathTex(r"y = x^2 \sin(\pi x)", color=BLUE_C)
        curve_label.shift(3*UP)
        self.add(ax, labels)
        self.play(Create(curve))
        self.play(Write(curve_label))
        self.play(FadeOut(curve_label))

        # Labels derived from the parameters actually used, so they cannot go stale.
        lmin_label = MathTex("l_{min} = 1").shift(5.5*RIGHT, 3*UP).scale(0.5)
        lmax_label = MathTex(f"l_{{max}} = {lmax - 1}").shift(5.5*RIGHT, 2.5*UP).scale(0.5)
        tolerance_label = MathTex(f"\\varepsilon = {tolerance}").shift(5.5*RIGHT, 2*UP).scale(0.6)
        g = VGroup(lmax_label, lmin_label, tolerance_label)
        self.play(FadeIn(g))

        dots = []
        parent_dots = []
        terminal_dots = []
        coeff_phi = []
        residual = []
        levels = []

        for l in range(lmax):
            dots.append(VGroup(*[Dot(point=ax.c2p(n, 0)) for n in new_nodes_l[l]]))
            parent_dots.append(VGroup(*[Dot(point=ax.c2p(n, 0), color=GREEN) for n in parent_nodes_l[l]]))
            terminal_dots.append(VGroup(*[Dot(point=ax.c2p(n, 0), color=RED) for n in terminal_nodes_l[l]]))
            coeff_phi.append(VGroup(*[ax.get_graph(p, x_range=[0, 1, 0.003], color=ORANGE) for p in coeff_phi_l[l]]))
            residual.append(ax.get_graph(residual_func_l[l], x_range=[0, 1, 0.003], color=BLUE))
            levels.append(MathTex("l = " + str(l)).shift(3*UP))

        all_new_nodes = [j for i in new_nodes_l for j in i]
        all_new_dots = VGroup(*[Dot(point=ax.c2p(n, 0)) for n in all_new_nodes])
        all_basis = [j for i in coeff_phi_l for j in i]
        all_coeff_phi = VGroup(*[ax.get_graph(p, x_range=[0, 1, 0.003], color=ORANGE) for p in all_basis])

        total_approx_func = lambda x: curve_func(x) - residual_func_l[lmax-1](x)
        total_approx = ax.get_graph(total_approx_func, x_range=[0, 1, 0.003], color=GREEN)

        # Full grid of level lmax-1 has 2**(lmax-1) + 1 points; the adaptive grid
        # uses every node that received a basis function.
        grid_table = Table(
            [["Full grid", "Adaptive grid"],
             [str(2**(lmax - 1) + 1), str(len(all_new_nodes))]]
        ).scale(0.5).shift(2.5*UP, 2*LEFT)

        for l in range(lmax):
            self.play(Write(levels[l]))
            if l == 0:
                self.play(FadeIn(dots[l]))
            else:
                self.play(FadeOut(terminal_dots[l-1]), Transform(parent_dots[l-1], dots[l]))
            self.play(Create(coeff_phi[l]))
            if l == 0:
                self.play(FadeOut(dots[l]), FadeIn(parent_dots[l]), FadeIn(terminal_dots[l]))
            else:
                self.play(FadeOut(parent_dots[l-1]), FadeIn(parent_dots[l]), FadeIn(terminal_dots[l]))
            if l == 0:
                self.play(FadeOut(coeff_phi[l]), FadeOut(levels[l]))
            else:
                self.play(Transform(VGroup(curve, coeff_phi[l]), residual[l]), FadeOut(levels[l]))

        self.play(FadeOut(VGroup(curve, residual[lmax-1], parent_dots[lmax-1], terminal_dots[lmax-1])))
        self.play(FadeIn(VGroup(all_coeff_phi, all_new_dots)))
        self.wait(1)
        self.play(Transform(all_coeff_phi, total_approx))
        self.play(Write(grid_table))
        self.wait(1)

                                                                                                                       

In [142]:
def curve_func(x):
    """Oscillatory test function on [0, 1]: f(x) = sin(pi / (x + 0.2))."""
    shifted = 0.2 + x
    return np.sin(np.pi / shifted)

tolerance = 0.02
lmax = 8

# Level 0 (boundary nodes) and level 1 (the midpoint) are always included, so
# both are seeded here; tolerance-driven refinement starts at level 2.
new_nodes_l = [make_hierarchical_nodes(0), make_hierarchical_nodes(1)]
parent_nodes_l = [new_nodes_l[0]]
terminal_nodes_l = [[]]
phi_l = [make_basis(new_nodes_l[0], 0)]
coeff_l = [[curve_func(n) for n in new_nodes_l[0]]]
coeff_phi_l = [make_scaled_basis(coeff_l[0], phi_l[0])]
residual_func_l = [make_residual_function(curve_func, coeff_phi_l[0])]

for l in range(1, lmax):
    if l > 1:
        # Only last level's parent nodes spawn children.
        new_nodes_l.append(make_new_nodes(parent_nodes_l[-1], l))

    # Hats, surpluses (residual of the coarser levels) and scaled hats for level l
    phi_l.append(make_basis(new_nodes_l[l], l))
    coeff_l.append(make_coefficients(residual_func_l[-1], new_nodes_l[l]))
    coeff_phi_l.append(make_scaled_basis(coeff_l[-1], phi_l[-1]))

    # Residual after subtracting this level's contribution
    residual_func_l.append(make_residual_function(residual_func_l[-1], coeff_phi_l[-1]))

    # Split the level's nodes by surplus magnitude
    parent_nodes_l.append(make_parent_nodes(coeff_l[-1], new_nodes_l[l], tolerance))
    terminal_nodes_l.append(np.setdiff1d(new_nodes_l[l], parent_nodes_l[-1]))
    


In [143]:
%%manim -qh -v WARNING complicated_function

class complicated_function(Scene):
    """Animate the adaptive hierarchical interpolant of sin(pi / (x + 0.2)).

    Reads the module-level `lmax`, `tolerance`, `curve_func`, `new_nodes_l`,
    `parent_nodes_l`, `terminal_nodes_l`, `coeff_phi_l` and `residual_func_l`.
    """

    def construct(self):
        ax = Axes(
            x_range=[0, 1, 1/8],
            y_range=[-1.5, 1.5, 0.1],
            axis_config={
                "decimal_number_config": {"num_decimal_places": 2}
            },
            x_axis_config={
                "numbers_to_include": np.linspace(0, 1, 5),
                "include_ticks": True,
                "include_tip": False
            },
            y_axis_config={
                # linspace includes both ends, so -1.5 and 1.5 are labelled symmetrically
                # (np.arange(-1.5, 1.5, 0.5) stops at 1.0).
                "numbers_to_include": np.linspace(-1.5, 1.5, 7),
                "include_ticks": True,
            }
        )

        labels = ax.get_axis_labels()
        curve = ax.get_graph(curve_func, x_range=[0, 1], color=BLUE_C)
        curve_label = MathTex("y = \\sin\\left(\\frac{\\pi}{x+0.2}\\right)", color=BLUE_C)
        curve_label.shift(3*UP)
        self.add(ax, labels)

        self.play(Create(curve))
        self.play(Write(curve_label))

        # Labels derived from the parameters actually used, so they cannot go stale.
        lmin_label = MathTex("l_{min} = 1").shift(5.5*RIGHT, 3*UP).scale(0.5)
        lmax_label = MathTex(f"l_{{max}} = {lmax - 1}").shift(5.5*RIGHT, 2.5*UP).scale(0.5)
        tolerance_label = MathTex(f"\\varepsilon = {tolerance}").shift(5.5*RIGHT, 2*UP).scale(0.5)
        g = VGroup(lmax_label, lmin_label, tolerance_label)
        self.play(FadeIn(g))

        self.play(FadeOut(curve_label))

        dots = []
        parent_dots = []
        terminal_dots = []
        coeff_phi = []
        residual = []
        levels = []

        for l in range(lmax):
            dots.append(VGroup(*[Dot(point=ax.c2p(n, 0)) for n in new_nodes_l[l]]))
            parent_dots.append(VGroup(*[Dot(point=ax.c2p(n, 0), color=GREEN) for n in parent_nodes_l[l]]))
            terminal_dots.append(VGroup(*[Dot(point=ax.c2p(n, 0), color=RED) for n in terminal_nodes_l[l]]))
            coeff_phi.append(VGroup(*[ax.get_graph(p, x_range=[0, 1, 0.003], color=ORANGE) for p in coeff_phi_l[l]]))
            residual.append(ax.get_graph(residual_func_l[l], x_range=[0, 1, 0.003], color=BLUE))
            levels.append(MathTex("l = " + str(l)).shift(3*UP))

        all_new_nodes = [j for i in new_nodes_l for j in i]
        all_new_dots = VGroup(*[Dot(point=ax.c2p(n, 0), radius=0.5*DEFAULT_DOT_RADIUS) for n in all_new_nodes])
        all_basis = [j for i in coeff_phi_l for j in i]
        all_coeff_phi = VGroup(*[ax.get_graph(p, x_range=[0, 1, 0.003], color=ORANGE) for p in all_basis])

        total_approx_func = lambda x: curve_func(x) - residual_func_l[lmax-1](x)
        total_approx = ax.get_graph(total_approx_func, x_range=[0, 1, 0.003], color=GREEN)

        # Full grid of level lmax-1 has 2**(lmax-1) + 1 points; the adaptive grid
        # uses every node that received a basis function.
        grid_table = Table(
            [["Full grid", "Adaptive grid"],
             [str(2**(lmax - 1) + 1), str(len(all_new_nodes))]]
        ).scale(0.5).shift(2.5*UP, 2*RIGHT)

        for l in range(lmax):
            self.play(Write(levels[l]))
            if l == 0:
                self.play(FadeIn(dots[l]))
            else:
                self.play(FadeOut(terminal_dots[l-1]), Transform(parent_dots[l-1], dots[l]))
            self.play(Create(coeff_phi[l]))
            if l == 0:
                self.play(FadeOut(dots[l]), FadeIn(parent_dots[l]), FadeIn(terminal_dots[l]))
            else:
                self.play(FadeOut(parent_dots[l-1]), FadeIn(parent_dots[l]), FadeIn(terminal_dots[l]))
            self.play(Transform(VGroup(curve, coeff_phi[l]), residual[l]), FadeOut(levels[l]))

        self.play(FadeOut(VGroup(curve, residual[lmax-1], parent_dots[lmax-1], terminal_dots[lmax-1])))
        self.play(FadeIn(VGroup(all_coeff_phi, all_new_dots)))
        self.play(Transform(all_coeff_phi, total_approx))
        self.play(Write(grid_table))
        self.wait(1)

                                                                                                                       