## Setup

In [None]:
from manim import *

# Keep Manim's console output quiet: only warnings and errors are logged
# while the cells below render.
config.verbosity = "WARNING"

# Size of the rendered video embedded in the notebook output, given as a
# CSS width (here three quarters of the output cell).
config.media_width = "75%"

In [23]:
from IPython.display import display, HTML
# Inject a CSS rule into the notebook page that forces elements with the
# `.container` class to 90% width (`!important` overrides the page's own rule).
display(HTML("<style>.container { width:90% !important; }</style>"))

In [8]:
%%manim -qh  LinearModel 

class LinearModel(Scene):
    """Animate how the slope ``w`` of ``f_w(x) = w x`` determines the cost ``J(w)``.

    The scene shows two plots side by side. On the left, the training data
    and the line ``w * x`` for several values of ``w``, together with the
    vertical error segments between the line and each data point. On the
    right, one red dot per ``w`` at height ``J(w)``, and finally the cost
    curve passing through those dots.
    """

    @staticmethod
    def _squared_error_cost(w, xs, ys):
        """Return ``J(w) = 1/(2m) * sum((w*x - y)**2)`` over the pairs of xs, ys."""
        squared_errors = sum((w * x - y) ** 2 for x, y in zip(xs, ys))
        return (1 / (2 * len(xs))) * squared_errors

    def construct(self):
        # Training data: used for the dots, the error lines and every cost
        # value, so the three always agree with each other.
        train_x = [1, 2, 3]
        train_y = [1, 2, 3]

        # plot area for the linear function (left half of the frame)
        plot_linear = Axes(
            x_range=[0, 5, 1],
            y_range=[0, 5, 1],
            x_length=4,
            y_length=3.5,
            axis_config={
                # if not specified, no numbers are drawn on the axes
                "numbers_to_include": np.arange(0, 5, 1),
                "font_size": 20,
            },
            tips=False,  # no arrow tips at the ends of the axes
        ).shift(3.5 * LEFT).shift(DOWN)  # built directly at its final position

        y_label = plot_linear.get_y_axis_label(
            "y", edge=LEFT, direction=LEFT, buff=0.3
        ).scale(0.8)
        x_label = plot_linear.get_x_axis_label("x").scale(0.8)
        plot_labels = VGroup(x_label, y_label)

        # Datapoints - training data
        dots = VGroup(
            *[
                Dot(point=plot_linear.c2p(x, y, 0), color=WHITE)
                for x, y in zip(train_x, train_y)
            ]
        )

        # plot area for the cost function (right half of the frame)
        plot_cost = Axes(
            x_range=[-0.5, 2.5, 0.5],
            y_range=[0, 3, 1],
            x_length=4,
            y_length=3.5,
            y_axis_config={
                # if not specified, no numbers are drawn on the axis
                "numbers_to_include": np.arange(0, 3.0, 0.5),
                "font_size": 20,
            },
            x_axis_config={
                # if not specified, no numbers are drawn on the axis
                "numbers_to_include": np.arange(-0.5, 3.0, 0.5),
                "font_size": 20,
            },
            tips=False,  # no arrow tips at the ends of the axes
        ).shift(4 * RIGHT).shift(DOWN)  # built directly at its final position

        y_label_cost = plot_cost.get_y_axis_label(
            "J(w)", edge=LEFT, direction=LEFT, buff=0.3
        ).scale(0.8)
        x_label_cost = plot_cost.get_x_axis_label("w").scale(0.8)
        plot_labels_cost = VGroup(x_label_cost, y_label_cost)

        # Formulas to display: the general forms first, then the simplified
        # forms (b dropped) that replace them at the top of the frame.
        linear = MathTex(r"f_{w,b}(x^{(i)}) = wx^{(i)} + b", font_size=41)
        cost_formula = MathTex(
            r"J(w,b) = \frac{1}{2m} \sum\limits_{i = 0}^{m-1} (f_{w,b}(x^{(i)}) - y^{(i)})^2",
            font_size=33,
        )

        linear2 = MathTex(
            r"f_{w}(x^{(i)}) = wx^{(i)}", font_size=41
        ).shift(3 * UP).shift(4.3 * LEFT)
        cost2 = MathTex(
            r"J(w) = \frac{1}{2m} \sum\limits_{i = 0}^{m-1} (f_{w}(x^{(i)}) - y^{(i)})^2",
            font_size=33,
        ).shift(3 * UP).shift(3.5 * RIGHT)

        # ANIMATION

        # Writes the formulas and moves them to their intended locations
        self.play(Write(linear), run_time=3)
        self.play(linear.animate.shift(3 * UP).shift(4.5 * LEFT), run_time=3)

        self.play(Write(cost_formula), run_time=3)
        self.play(cost_formula.animate.shift(3 * UP).shift(3.5 * RIGHT), run_time=3)

        self.play(TransformMatchingShapes(linear, linear2))
        self.play(TransformMatchingShapes(cost_formula, cost2))

        # TODO: cross out the b and make it equal to zero, then vanish it
        # (switch / transform / permute?).

        # Create the plots, their axis labels and the training data
        self.play(Create(plot_linear), Create(plot_labels), Create(dots))
        self.play(Create(plot_cost), Create(plot_labels_cost))

        self.wait()

        # Plots the fitted line and its cost for each value of w
        for w in (0, 0.5, 1, 1.5):
            plot_kwargs = {"color": GREY}
            if w == 1.5:
                # Stop the steepest line at x = 3.3 so it stays inside the
                # y-range of the axes (1.5 * 3.3 < 5).
                plot_kwargs["x_range"] = [0, 3.3]

            # `w=w` binds the current slope to the function instead of
            # relying on the loop variable at call time.
            plot_linear_function = plot_linear.plot(
                lambda x, w=w: w * x, **plot_kwargs
            )

            # text indicating the value of w
            w_text = Text(f"w = {w}", font_size=31).shift(4 * LEFT).shift(2 * UP)

            # shows the value of w, each time
            self.play(Write(w_text))
            self.wait()

            # shows the fitted line each time
            self.play(AnimationGroup(Create(plot_linear_function)))

            # one vertical error line per training point, from the data
            # point to the prediction w * x
            error_lines = VGroup()
            for x, y in zip(train_x, train_y):
                error_lines += plot_linear.plot_line_graph(
                    [x, x],
                    [y, w * x],
                    add_vertex_dots=False,
                    line_color=RED,
                )

            # cost for this value of w, and its dot on the cost plot
            total_cost = self._squared_error_cost(w, train_x, train_y)
            cost_dot = VGroup(Dot(point=plot_cost.c2p(w, total_cost), color=RED))

            # creates all error lines (almost) at the same time
            self.play(
                AnimationGroup(
                    *[Create(line) for line in error_lines], lag_ratio=0.05
                )
            )
            self.wait()

            # shows the cost calculated for that value of w
            j_text = Text(
                f"J = {total_cost:.2f}", font_size=31
            ).shift(4 * RIGHT).shift(2 * UP)
            self.play(Write(j_text))
            self.wait()

            # plots the cost dot for that value of w (it stays on screen)
            self.play(Create(cost_dot))
            self.wait()

            # fades out the error lines, the fitted line and both texts
            # before moving on to the next value of w
            self.play(
                AnimationGroup(
                    *[FadeOut(line) for line in error_lines], lag_ratio=0.05
                )
            )
            self.play(AnimationGroup(FadeOut(plot_linear_function)))
            self.play(FadeOut(w_text), FadeOut(j_text))

        # fades out the training data
        self.play(FadeOut(dots))

        # Draws the cost curve from the same cost function that positioned
        # the dots, so the curve passes exactly through them.
        cost_curve = plot_cost.plot(
            lambda w: self._squared_error_cost(w, train_x, train_y),
            x_range=[-0.07, 2.07],
        )

        # pass a line through the dots of the cost function
        self.play(Create(cost_curve))

        self.wait(5)

                                                                                                                                                                  

In [20]:
%%manim -qh CostEquation

class CostEquation(Scene):
    """Build up the squared-error cost function one step at a time.

    Each step shows a short caption, morphs the formula on screen into the
    next form, and then fades the caption out. The final formula is framed
    with a red rectangle.
    """

    def construct(self):
        title = Text("Creating the cost function", font_size=31).shift(1 * UP)

        # (caption, formula) pairs in the order they are presented.
        steps = [
            (
                "Taking the distance between prediction and data",
                r"(\hat{y} - y^{(i)})",
            ),
            (
                "Squaring this distance to avoid cancelling them out",
                r"(\hat{y} - y^{(i)})^2",
            ),
            (
                "Summing the distance across all data points",
                r"\sum\limits_{i = 0}^{m-1} (\hat{y} - y^{(i)})^2",
            ),
            (
                "Averaging the distance so it won't get larger as the data gets larger",
                r"\frac{1}{m}\sum\limits_{i = 0}^{m-1} (\hat{y} - y^{(i)})^2",
            ),
            (
                "Dividing by two to make calculations better",
                r"\frac{1}{2m}\sum\limits_{i = 0}^{m-1} (\hat{y} - y^{(i)})^2",
            ),
            (
                "The cost function is usually represented like this:",
                r"J(w,b) = \frac{1}{2m} \sum\limits_{i = 0}^{m-1} (\hat{y} - y^{(i)})^2",
            ),
            (
                "The prediction is obtained by the original line function, so...",
                r"J(w,b) = \frac{1}{2m} \sum\limits_{i = 0}^{m-1} (f_{w,b}(x^{(i)}) - y^{(i)})^2",
            ),
        ]

        # All mobjects are built up front; captions sit just above the formula.
        captions = [
            Text(caption, font_size=31).shift(1.2 * UP) for caption, _ in steps
        ]
        formulas = [MathTex(tex) for _, tex in steps]

        # Frame for the final form of the cost function.
        rectangle = SurroundingRectangle(
            mobject=formulas[-1], color=RED, buff=0.15
        )

        self.play(Create(title, run_time=2))
        self.wait()
        self.play(title.animate.shift(2.5 * UP), run_time=3)
        self.wait(1)

        shown = None  # formula currently on screen
        last_index = len(steps) - 1
        for index, (caption, formula) in enumerate(zip(captions, formulas)):
            if shown is None:
                # First step: nothing to morph from, so draw the formula.
                self.play(Create(caption))
                self.play(Create(formula))
            else:
                self.play(Create(caption, run_time=2))
                self.play(TransformMatchingShapes(shown, formula))
            self.wait()
            self.play(FadeOut(caption))
            if index < last_index:
                # No pause after the last caption: the rectangle follows
                # immediately.
                self.wait(1)
            shown = formula

        self.play(Create(rectangle), run_time=5)
        self.play(title.animate.shift(2 * DOWN), run_time=3)
        self.wait(5)
        

                                                                                                                                                      