# MLP Animation
Here I am prompting Claude to write animation code for me!

## Attempt 1
Just the default prompt, with no extra guidance.

In [2]:
from manim import *

In [6]:
%%manim -qm -v WARNING MLPForwardPass

from manim import *
import numpy as np

class MLPForwardPass(Scene):
    """Manim scene animating the forward pass of a small two-layer MLP.

    The scene draws a fully connected input -> hidden -> output network, feeds
    a fixed input vector through ``W1``, applies ReLU, feeds the result through
    ``W2`` and finally shows the output vector. Values are displayed rounded to
    one decimal place (two for the final output).
    """

    NODE_RADIUS = 0.3    # radius of every neuron circle
    NODE_SPACING = 1.5   # vertical distance between neurons in one layer

    def construct(self):
        """Build and play the whole animation."""
        # Network size; input_vec below must have input_dim entries.
        input_dim = 3
        hidden_dim = 3
        output_dim = 3

        # Sample weights and input; the global seed keeps renders reproducible.
        np.random.seed(42)
        W1 = np.random.randn(hidden_dim, input_dim) * 0.5
        W2 = np.random.randn(output_dim, hidden_dim) * 0.5
        input_vec = np.array([1.2, -0.8, 2.1])

        # Neural network diagram
        self.show_network_structure(input_dim, hidden_dim, output_dim)
        self.wait(2)

        # Input values flowing into the first layer
        self.show_input(input_vec)
        self.wait(2)

        # First matrix multiplication (input -> hidden pre-activations)
        hidden_activations = self.animate_matrix_multiplication(
            W1, input_vec, "Input to Hidden Layer", position=LEFT * 3,
            target_nodes=self.hidden_nodes, node_color=GREEN,
        )
        self.wait(2)

        # Elementwise ReLU on the hidden layer
        hidden_activated = np.maximum(0, hidden_activations)
        self.show_activation_function(hidden_activations, hidden_activated)
        self.wait(2)

        # Second matrix multiplication (hidden -> output)
        output_activations = self.animate_matrix_multiplication(
            W2, hidden_activated, "Hidden to Output Layer", position=RIGHT * 3,
            target_nodes=self.output_nodes, node_color=RED,
        )
        self.wait(2)

        # Final result
        self.show_final_output(output_activations)
        self.wait(3)

    def _value_texts(self):
        """Return the lazily created map ``id(node) -> Text`` of values shown on nodes."""
        # Created on first use so each helper also works when called on its own.
        if not hasattr(self, "_node_value_texts"):
            self._node_value_texts = {}
        return self._node_value_texts

    def create_matrix_display(self, matrix_data, color=WHITE):
        """Lay out ``matrix_data`` as a bracketed grid of Text objects.

        Args:
            matrix_data: sequence of rows; each entry is rendered with ``str``.
                The column count is taken from the first row.
            color: text colour of the entries.

        Returns:
            ``(matrix_group, entries)``. ``matrix_group`` is a VGroup holding all
            entries followed by the two bracket glyphs (brackets are omitted for
            an empty matrix). ``entries[i][j]`` is the Text for cell (i, j).
        """
        rows = len(matrix_data)
        cols = len(matrix_data[0]) if rows > 0 else 0

        entries = [
            [Text(str(row[j]), font_size=20, color=color) for j in range(cols)]
            for row in matrix_data
        ]
        cells = [entry for row in entries for entry in row]

        # Uniform cell size: largest entry plus padding, so columns line up.
        spacing_x = max((cell.width for cell in cells), default=0) + 0.3
        spacing_y = max((cell.height for cell in cells), default=0) + 0.2

        # Centre the grid on the origin; the caller moves the whole group later.
        matrix_group = VGroup()
        for i, row in enumerate(entries):
            for j, entry in enumerate(row):
                x_pos = (j - (cols - 1) / 2) * spacing_x
                y_pos = ((rows - 1) / 2 - i) * spacing_y
                entry.move_to([x_pos, y_pos, 0])
                matrix_group.add(entry)

        if rows > 0 and cols > 0:
            # Heuristic: scale the bracket glyph's font size with the grid height.
            bracket_size = int(rows * spacing_y * 30)
            left_bracket = Text("[", font_size=bracket_size).move_to(matrix_group.get_left() + LEFT * 0.3)
            right_bracket = Text("]", font_size=bracket_size).move_to(matrix_group.get_right() + RIGHT * 0.3)
            matrix_group.add(left_bracket, right_bracket)

        return matrix_group, entries

    def create_vector_display(self, vector_data, color=WHITE):
        """Render ``vector_data`` as a column vector.

        Strings are shown verbatim; numbers are formatted with one decimal place.
        Returns the same ``(group, entries)`` pair as ``create_matrix_display``.
        """
        column = [[val] if isinstance(val, str) else [f"{val:.1f}"] for val in vector_data]
        return self.create_matrix_display(column, color)

    def _make_layer(self, count, x_offset, color):
        """Create ``count`` neuron circles in a column centred vertically at ``x_offset``."""
        # For count == 3 this yields y = 1.5, 0, -1.5.
        return [
            Circle(radius=self.NODE_RADIUS, color=color, fill_opacity=0.7).move_to(
                x_offset + UP * ((count - 1) / 2 - i) * self.NODE_SPACING
            )
            for i in range(count)
        ]

    def _connect(self, sources, targets):
        """Return grey edges between every source and target node, stopping at the circles' rims."""
        # buff trims each end by the node radius so lines are not drawn over the filled circles.
        return [
            Line(src.get_center(), dst.get_center(), buff=self.NODE_RADIUS, stroke_width=1, color=GRAY)
            for src in sources
            for dst in targets
        ]

    def show_network_structure(self, input_dim, hidden_dim, output_dim):
        """Draw the title, the three neuron columns, all edges and the layer labels.

        Each dimension must be at least 1. Stores ``self.title``,
        ``self.input_nodes``, ``self.hidden_nodes``, ``self.output_nodes``,
        ``self.edges`` and ``self.network_labels`` for later steps.
        """
        self.title = Text("Multi-Layer Perceptron Forward Pass", font_size=36)
        self.title.to_edge(UP)
        self.play(Write(self.title))

        self.input_nodes = self._make_layer(input_dim, LEFT * 4, BLUE)
        self.hidden_nodes = self._make_layer(hidden_dim, ORIGIN, GREEN)
        self.output_nodes = self._make_layer(output_dim, RIGHT * 4, RED)

        self.edges = (
            self._connect(self.input_nodes, self.hidden_nodes)
            + self._connect(self.hidden_nodes, self.output_nodes)
        )

        # Labels go below the *last* node of each column so they never cover a neuron.
        input_label = Text("Input\nLayer", font_size=24).next_to(self.input_nodes[-1], DOWN)
        hidden_label = Text("Hidden\nLayer", font_size=24).next_to(self.hidden_nodes[-1], DOWN)
        output_label = Text("Output\nLayer", font_size=24).next_to(self.output_nodes[-1], DOWN)

        all_nodes = self.input_nodes + self.hidden_nodes + self.output_nodes
        self.play(*[Create(node) for node in all_nodes])
        self.play(*[Create(edge) for edge in self.edges])
        self.play(Write(input_label), Write(hidden_label), Write(output_label))

        self.network_labels = [input_label, hidden_label, output_label]

    def show_input(self, input_vec):
        """Show the input vector left of the input layer and write each value into its node."""
        input_display, _ = self.create_vector_display(input_vec, BLUE)
        input_display.next_to(VGroup(*self.input_nodes), LEFT, buff=0.5)

        input_text = Text("Input Vector", font_size=20).next_to(input_display, UP)
        self.play(Write(input_display), Write(input_text))

        registry = self._value_texts()
        input_values = []
        for node, value in zip(self.input_nodes, input_vec):
            self.play(node.animate.set_fill(YELLOW, opacity=0.9), run_time=0.5)
            value_text = Text(f"{value:.1f}", font_size=16, color=BLACK).move_to(node.get_center())
            input_values.append(value_text)
            registry[id(node)] = value_text
            self.play(Write(value_text), run_time=0.3)

        self.input_display = input_display
        self.input_text = input_text
        self.input_values = input_values

    def _default_targets(self, title):
        """Infer the nodes and fill colour to update from a multiplication title (legacy behaviour)."""
        if title == "Input to Hidden Layer":
            return getattr(self, "hidden_nodes", []), GREEN
        if title == "Hidden to Output Layer":
            return getattr(self, "output_nodes", []), RED
        return [], YELLOW

    def animate_matrix_multiplication(self, weight_matrix, input_vector, title, position,
                                      target_nodes=None, node_color=None):
        """Animate ``weight_matrix @ input_vector`` one dot product at a time.

        Args:
            weight_matrix: 2-D numpy array of shape (rows, cols).
            input_vector: 1-D sequence with ``cols`` numeric entries.
            title: caption shown above the weight matrix.
            position: anchor point; the matrix is drawn at ``position + UP*2``.
            target_nodes: network nodes that receive result entry ``i`` (fill
                colour plus value label). When omitted they are inferred from
                ``title``: "Input to Hidden Layer" -> hidden nodes, "Hidden to
                Output Layer" -> output nodes, anything else -> none.
            node_color: fill colour for updated nodes; inferred from ``title``
                when omitted (GREEN / RED, YELLOW otherwise).

        Returns:
            ``np.dot(weight_matrix, input_vector)``.
        """
        inferred_nodes, inferred_color = self._default_targets(title)
        if target_nodes is None:
            target_nodes = inferred_nodes
        if node_color is None:
            node_color = inferred_color

        n_rows, n_cols = weight_matrix.shape
        weight_data = [[f"{weight_matrix[i, j]:.1f}" for j in range(n_cols)] for i in range(n_rows)]
        W_display, W_entries = self.create_matrix_display(weight_data, ORANGE)
        W_display.move_to(position + UP * 2)

        # Copy of the operand vector placed right of the matrix.
        input_display, _ = self.create_vector_display(input_vector, BLUE)
        input_display.next_to(W_display, RIGHT, buff=0.5)

        title_text = Text(title, font_size=24).next_to(W_display, UP)
        self.play(Write(W_display), Write(input_display), Write(title_text))

        result = np.dot(weight_matrix, input_vector)

        result_display, result_entries = self.create_vector_display(result, GREEN)
        result_display.next_to(input_display, RIGHT, buff=0.8)
        equals_sign = Text("=", font_size=36).next_to(input_display, RIGHT, buff=0.3)

        registry = self._value_texts()
        for i, value in enumerate(result):
            if not W_entries[i]:
                continue  # zero-width matrix: nothing to highlight

            # Highlight only the current row and the operand vector.
            row_highlight = SurroundingRectangle(VGroup(*W_entries[i]), color=YELLOW, buff=0.1)
            col_highlight = SurroundingRectangle(input_display, color=YELLOW, buff=0.1)
            self.play(Create(row_highlight), Create(col_highlight))

            # Spell out the dot product at the bottom of the frame.
            terms = " + ".join(
                f"{weight_matrix[i, j]:.1f} × {input_vector[j]:.1f}" for j in range(len(input_vector))
            )
            calc_text = Text(f"{terms} = {value:.1f}", font_size=18).to_edge(DOWN)
            self.play(Write(calc_text))
            self.wait(1)

            if i == 0:
                # The whole result vector appears together with the first product.
                self.play(Write(equals_sign), Write(result_display))
            else:
                result_highlight = SurroundingRectangle(result_entries[i][0], color=GREEN, buff=0.05)
                self.play(Create(result_highlight))
                self.wait(0.5)
                self.play(FadeOut(result_highlight))

            # Mirror the value on the corresponding network node, remembering the label.
            if i < len(target_nodes):
                node = target_nodes[i]
                self.play(node.animate.set_fill(node_color, opacity=0.9))
                value_text = Text(f"{value:.1f}", font_size=14, color=BLACK).move_to(node.get_center())
                self.play(Write(value_text))
                registry[id(node)] = value_text

            # Clear this step's highlights so they do not pile up across rows.
            self.play(FadeOut(calc_text), FadeOut(row_highlight), FadeOut(col_highlight))
            self.wait(0.5)

        self.play(FadeOut(W_display), FadeOut(input_display), FadeOut(result_display),
                  FadeOut(equals_sign), FadeOut(title_text))

        return result

    def show_activation_function(self, pre_activation, post_activation):
        """Apply ReLU visually: flash each hidden node whose value changed and update its label."""
        activation_text = Text("Apply ReLU: max(0, x)", font_size=24)
        activation_text.to_edge(DOWN)
        self.play(Write(activation_text))

        registry = self._value_texts()
        for node, pre_val, post_val in zip(self.hidden_nodes, pre_activation, post_activation):
            if abs(pre_val - post_val) <= 0.001:
                continue  # ReLU left this value unchanged

            # Flash the node to signal the change.
            self.play(node.animate.set_fill(PURPLE, opacity=0.9), run_time=0.3)
            self.play(node.animate.set_fill(GREEN, opacity=0.9), run_time=0.3)

            # Update the label that was written on this node (looked up directly,
            # not by screen proximity, so nearby layer labels are never touched).
            new_text = Text(f"{post_val:.1f}", font_size=14, color=BLACK).move_to(node.get_center())
            old_text = registry.get(id(node))
            if old_text is None:
                self.play(Write(new_text))
                registry[id(node)] = new_text
            else:
                self.play(Transform(old_text, new_text))

        self.wait(1)
        self.play(FadeOut(activation_text))

    def show_final_output(self, output_values):
        """Show the output vector beside the output layer and a completion banner at the bottom."""
        output_display, _ = self.create_vector_display(
            [f"{val:.2f}" for val in output_values], RED
        )
        # Right of the output column (free space once the multiplication views are gone);
        # the previous bottom-edge stacking pushed these off-screen.
        output_display.next_to(VGroup(*self.output_nodes), RIGHT, buff=0.6)

        final_text = Text("Final Output", font_size=24, color=RED)
        final_text.next_to(output_display, UP)

        self.play(Write(final_text), Write(output_display))

        for node in self.output_nodes:
            self.play(node.animate.set_stroke(YELLOW, width=4), run_time=0.3)

        completion_text = Text("Forward Pass Complete!", font_size=28, color=GREEN)
        completion_text.to_edge(DOWN)
        self.play(Write(completion_text))

                                                                                                                