In [None]:
%%manim -qm -v WARNING NeuralNetworkSceneLastMB


class NeuralNetworkSceneLastMB(Scene):
    """Animate a small fully connected network (3 inputs, 4 hidden, 2 outputs).

    The scene plays a forward pass, then backpropagation. It ends by showing
    the network again with the updated (tilde) weight labels after a
    gradient-descent step.
    """

    def construct(self):
        """Build the network diagram, then play the forward and backward passes."""
        # Create layers (neurons + per-neuron labels, no layer titles yet)
        input_layer = self.create_layer(3, "x")
        hidden_layer = self.create_layer(4, "h")
        output_layer = self.create_layer(2, "z")

        # Title block in the top-left corner
        intro_text = Text("MB", font_size=24, weight=BOLD, color=BLUE_D).to_edge(UP).shift(LEFT * 6.05)
        self.add(intro_text)
        neural_net_text = Text("Neural Network", font_size=16, weight=BOLD, color=BLUE_E).next_to(
            intro_text, DOWN, buff=0.2
        )
        self.add(neural_net_text)

        # Position layers side by side
        layers = VGroup(input_layer, hidden_layer, output_layer)
        layers.arrange(RIGHT, buff=2.7)
        self.add(layers)
        self.layers = layers

        # Layer titles, all placed at the same offset above their layer's centre
        layer_labels = VGroup(
            Text("Input Layer", font_size=24),
            Text("Hidden Layer", font_size=24),
            Text("Output Layer", font_size=24),
        )
        for label, layer in zip(layer_labels, layers):
            label.move_to(layer.get_center())
            label.shift(UP * 3.2)
        self.add(layer_labels)
        self.layer_labels = layer_labels

        # Draw connections with weight labels (plus the "after update" tilde labels)
        connections_ih, weights_ih, weights_ih_after = self.connect_layers(input_layer, hidden_layer, "w")
        connections_ho, weights_ho, weights_ho_after = self.connect_layers(hidden_layer, output_layer, "v")
        self.connections_oh = connections_ho
        self.connections_ih = connections_ih

        # Combine all connections and weights
        connections = connections_ih + connections_ho
        weights = weights_ih + weights_ho
        self.weights_after = weights_ih_after + weights_ho_after

        self.add(connections, weights)

        self.forward_pass(input_layer, hidden_layer, output_layer, connections, weights)
        self.backpropagation(input_layer, hidden_layer, output_layer, connections, weights)

    def animate_wave(self, connections, color, direction='forward', run_time=1):
        """Move one dot along every connection, all at the same time.

        Args:
            connections: iterable of Line mobjects to travel along.
            color: colour of the travelling dots.
            direction: 'backward' runs each dot from a line's end to its start;
                any other value runs it from start to end.
            run_time: duration of the animation in seconds.
        """
        dots = VGroup()
        animations = []
        for connection in connections:
            if direction == 'backward':
                # A reversed copy of the line so the dot travels end -> start
                path = Line(
                    connection.get_end(),
                    connection.get_start(),
                    stroke_color=connection.get_stroke_color(),
                    stroke_width=connection.get_stroke_width(),
                )
            else:
                path = connection.copy()

            dot = Dot(color=color, radius=0.05)
            dot.move_to(path.get_start())
            dots.add(dot)
            animations.append(MoveAlongPath(dot, path, run_time=run_time, rate_func=linear))

        self.add(dots)
        self.play(*animations, lag_ratio=0)
        self.remove(dots)

    def create_layer(self, num_neurons, label_prefix):
        """Create a vertical column of neurons labelled ``{label_prefix}_1 .. _n``.

        Returns:
            VGroup(neurons, neuron_labels): index 0 holds the circles and
            index 1 the matching MathTex labels, in the same order.
        """
        neurons = VGroup(*[Circle(radius=0.3, color=BLUE, fill_opacity=0) for _ in range(num_neurons)])
        neurons.arrange(DOWN, buff=1)  # generous vertical spacing between neurons

        neuron_labels = VGroup()
        for i, neuron in enumerate(neurons):
            label = MathTex(f"{label_prefix}_{{{i+1}}}")
            label.scale(0.6)
            # Label sits just above its circle; pass another direction (e.g. LEFT)
            # to next_to to place it elsewhere
            label.next_to(neuron, UP, buff=0.05)
            neuron_labels.add(label)
        return VGroup(neurons, neuron_labels)

    def connect_layers(self, layer1, layer2, weight_prefix):
        """Fully connect two layers built by ``create_layer``.

        For source neuron i and target neuron j (0-based), the weight label is
        ``{weight_prefix}_{j+1 i+1}``, with the target index first.

        Returns:
            (connections, weight_labels, weight_labels_after): the lines, the
            weight labels, and tilde-marked labels at the same positions that
            are used to show the updated weights.
        """
        connections = VGroup()
        weight_labels = VGroup()
        weight_labels_after = VGroup()
        neurons1 = layer1[0]
        neurons2 = layer2[0]
        for i, neuron1 in enumerate(neurons1):
            for j, neuron2 in enumerate(neurons2):
                start = neuron1.get_right()
                end = neuron2.get_left()
                connection = Line(start, end, stroke_color=GREY, stroke_width=2)
                connections.add(connection)

                # Offset labels perpendicular to the line so they don't overlap it
                direction = end - start
                unit_direction = direction / np.linalg.norm(direction)
                perp_direction = np.array([-unit_direction[1], unit_direction[0], 0])
                label_pos = connection.get_midpoint() + perp_direction * 0.5

                weight_label = MathTex(f"{weight_prefix}_{{{j+1}{i+1}}}")
                weight_label.scale(0.5)
                weight_label.move_to(label_pos)
                weight_labels.add(weight_label)

                weight_label_new = MathTex(f"\\tilde {weight_prefix}_{{{j+1}{i+1}}}", font_size=16)
                weight_label_new.move_to(label_pos)
                weight_labels_after.add(weight_label_new)

        return connections, weight_labels, weight_labels_after

    def forward_pass(self, input_layer, hidden_layer, output_layer, connections, weights):
        """Animate the forward pass: light up each layer in turn and show its formula.

        Stores the on-screen captions as ``forward_pass_text``, ``SIGMOID`` and
        ``Bias`` so that ``backpropagation`` can fade them out later.
        """
        n_inputs = len(input_layer[0])
        n_hidden = len(hidden_layer[0])
        n_outputs = len(output_layer[0])

        title = Text("Forward Pass", font_size=30)
        title.to_corner(DOWN + LEFT)
        sigmoid = MathTex("\\sigma(x) = \\frac{1}{1 + e^{-x}}", font_size=20)
        bias = MathTex(
            f"b_j, c_t, \\quad j \\in [1,{n_hidden}], t \\in [1,{n_outputs}]: \\quad \\text{{Bias  Terms}}",
            font_size=20,
        )
        sigmoid.to_corner(UP + RIGHT).shift(DOWN * 0.6)
        bias.next_to(sigmoid, DOWN).shift(DOWN * 0.05, LEFT * 0.7)
        self.add(title)
        self.forward_pass_text = title
        self.add(sigmoid)
        self.add(bias)
        self.SIGMOID = sigmoid
        self.Bias = bias
        self.wait(2)

        # Hide connections and weights while the computations are shown
        self.play(FadeOut(connections), FadeOut(weights))
        self.play(
            LaggedStart(
                *[neuron.animate.set_fill(YELLOW, opacity=0.5) for neuron in input_layer[0]],
                lag_ratio=0.1,
            )
        )
        self.wait(1)

        # Light up every input -> hidden connection
        self.play(
            LaggedStart(
                *[conn.animate.set_color(YELLOW) for conn in self.connections_ih],
                lag_ratio=0.01,
            )
        )
        self.wait(0.5)

        # Weighted-sum formula to the right of each hidden neuron
        computations = VGroup()
        for i, neuron in enumerate(hidden_layer[0]):
            comp = MathTex(f"=\\sigma\\left(\\sum_{{j=1}}^{{{n_inputs}}} w_{{{i+1}j}} x_j + b_{{{i+1}}}\\right)")
            comp.scale(0.5)
            comp.next_to(neuron, RIGHT, buff=0.1)
            computations.add(comp)
        self.play(Write(computations))
        self.wait(1)
        self.play(
            LaggedStart(
                *[neuron.animate.set_fill(YELLOW, opacity=0.5) for neuron in hidden_layer[0]],
                lag_ratio=0.1,
            )
        )
        self.wait(1)

        # Hide computations to avoid overlap with connections
        self.play(FadeOut(computations))
        self.wait(0.5)

        # Light up every hidden -> output connection
        self.play(
            LaggedStart(
                *[conn.animate.set_color(YELLOW) for conn in self.connections_oh],
                lag_ratio=0.01,
            )
        )

        # Weighted-sum formula to the right of each output neuron
        computations_out = VGroup()
        for i, neuron in enumerate(output_layer[0]):
            comp = MathTex(f"=\\sigma\\left(\\sum_{{j=1}}^{{{n_hidden}}} v_{{{i+1}j}} h_j + c_{{{i+1}}}\\right)")
            comp.scale(0.5)
            comp.next_to(neuron, RIGHT, buff=0.1)
            computations_out.add(comp)
        self.play(Write(computations_out))
        self.wait(1)
        self.play(FadeOut(computations_out))

        # Activate output layer neurons
        self.play(
            LaggedStart(
                *[neuron.animate.set_fill(YELLOW, opacity=0.5) for neuron in output_layer[0]],
                lag_ratio=0.1,
            )
        )
        self.wait(1)

        # Show connections and weights again
        self.play(FadeIn(connections), FadeIn(weights), run_time=1)

    def backpropagation(self, input_layer, hidden_layer, output_layer, connections, weights):
        """Animate backpropagation layer by layer, then the gradient-descent update.

        Reads ``forward_pass_text``, ``Bias`` and ``SIGMOID``, which are set by
        ``forward_pass``, and ``layers``, ``layer_labels`` and ``weights_after``,
        which are set by ``construct``.
        """
        n_hidden = len(hidden_layer[0])
        n_outputs = len(output_layer[0])

        # Remove forward-pass captions, then hide the wiring during computations
        self.play(FadeOut(self.forward_pass_text), FadeOut(self.Bias))
        self.play(FadeOut(connections), FadeOut(weights))

        backprop_text = Text("Backpropagation", font_size=30)
        backprop_text.to_corner(DOWN + LEFT)
        cost_function = Text("C = Cost Function", font_size=17)
        cost_function.to_corner(RIGHT).shift(RIGHT * 0.3)
        self.add(backprop_text)
        self.backprop_text = backprop_text
        self.add(cost_function)
        self.COST_FUNCTION = cost_function

        # Indicate error at output layer
        self.play(
            LaggedStart(
                *[Indicate(neuron, color=RED) for neuron in output_layer[0]],
                lag_ratio=0.01,
            )
        )
        self.wait(0.5)

        # Output-layer derivatives, placed to the left of each output neuron
        derivatives_out = VGroup()
        for i, neuron in enumerate(output_layer[0]):
            der = MathTex(f"d_{{z_{{{i+1}}}}} =  \\frac{{\\partial C}}{{\\partial z_{{{i+1}}}}}")
            der.scale(0.5)
            der.next_to(neuron, LEFT, buff=0.5)
            derivatives_out.add(der)
        self.play(Write(derivatives_out))
        self.wait(1)
        self.play(FadeOut(derivatives_out))
        self.animate_wave(self.connections_oh, color=RED, direction='backward', run_time=1)

        self.play(
            LaggedStart(
                *[Indicate(neuron, color=RED) for neuron in hidden_layer[0]],
                lag_ratio=0.1,
            )
        )
        self.wait(0.5)

        # Hidden-layer derivatives (gradient w.r.t. the hidden -> output weights v)
        derivatives_hidden = VGroup()
        for i, neuron in enumerate(hidden_layer[0]):
            der = MathTex(
                f"\\nu_{{{i+1},t}} = \\sum_{{k=1}}^{{{n_outputs}}} \\frac{{\\partial C}}{{\\partial z_{{k}}}} "
                f"\\frac{{\\partial z_{{k}}}}{{\\partial v_{{{i+1},t}}}}"
            )
            der.scale(0.5)
            der.next_to(neuron, LEFT, buff=0.17)
            derivatives_hidden.add(der)
        self.play(Write(derivatives_hidden))
        self.wait(1)
        self.play(FadeOut(derivatives_hidden))
        self.animate_wave(self.connections_ih, color=RED, direction='backward', run_time=1)

        self.play(
            LaggedStart(
                *[Indicate(neuron, color=RED) for neuron in input_layer[0]],
                lag_ratio=0.1,
            )
        )
        self.wait(0.5)

        # Input-layer derivatives (gradient w.r.t. the input -> hidden weights w)
        derivatives_input = VGroup()
        for i, neuron in enumerate(input_layer[0]):
            der = MathTex(
                f"\\omega_{{{i+1},j}} = \\sum_{{m=1}}^{{{n_outputs}}} \\sum_{{k=1}}^{{{n_hidden}}} "
                f"\\frac{{\\partial C}}{{\\partial z_{{m}}}} \\cdot \\frac{{\\partial z_{{m}}}}{{\\partial h_{{k}}}} "
                f"\\frac{{\\partial h_{{k}}}}{{\\partial w_{{{i+1},j }}}}",
                font_size=45,
            )
            der.scale(0.5)
            der.next_to(neuron, LEFT, buff=0.11)
            derivatives_input.add(der)
        self.play(Write(derivatives_input))
        self.wait(1)

        # Clear the stage; SIGMOID is faded exactly once, here
        self.play(FadeOut(derivatives_input))
        self.play(FadeOut(self.COST_FUNCTION), FadeOut(self.SIGMOID), FadeOut(backprop_text))
        self.play(FadeOut(connections), FadeOut(weights), FadeOut(self.layers), FadeOut(self.layer_labels))

        gradient = Text("After  Gradient  Descent On  The  Weights,  We  Update  The  Weights", font_size=30)
        example = MathTex(
            "\\text{For Example}  : \\tilde w_{i,j} = w_{i,j} - \\epsilon \\cdot \\omega_{i,j}", font_size=35
        )
        repeat = MathTex("\\text{Repeat  This  Process  Until  The  Cost  Function  Is  Minimized}", font_size=35)
        gradient.shift(UP * 1)
        repeat.to_edge(DOWN)

        arrow = Arrow(
            LEFT, RIGHT, stroke_width=5, stroke_color=WHITE, fill_color=BLUE, fill_opacity=0.5, buff=0.1
        ).shift(LEFT * 4.5)

        self.play(LaggedStart(FadeIn(gradient)), run_time=1)
        self.play(LaggedStart(FadeIn(example)), run_time=1)
        self.wait(2)
        self.play(FadeOut(gradient), FadeOut(example))

        # Bring the network back, nudged right, labelled with the updated (tilde) weights
        offset = RIGHT * 0.4
        self.play(
            FadeIn(connections.shift(offset)),
            FadeIn(self.weights_after.shift(offset)),
            FadeIn(self.layers.shift(offset)),
            FadeIn(self.layer_labels.shift(offset)),
            run_time=1,
        )
        self.play(LaggedStart(FadeIn(arrow)), run_time=1)
        self.play(LaggedStart(FadeIn(repeat)), run_time=1)

        self.wait(2)