In [1]:
import torch
# Three toy token vectors, one per row, each of width d_hidden.
hidden_states = torch.tensor([
    [-0.083, 0.147],
    [0.029, 0.008],
    [-0.204, 0.132],
])
d_hidden = 2

# Bias-free projection layer whose randomly initialised weight is replaced by
# fixed values below so the result is reproducible.
w_query = torch.nn.Linear(d_hidden, d_hidden, bias=False)

# nn.Linear applies x @ weight.T, so the matrix is stored transposed; the layer
# therefore multiplies hidden_states by the matrix exactly as written here.
w_query.weight.data = torch.tensor(
    [
        [-0.381, -0.354],
        [0.407, -0.601],
    ],
    requires_grad=True,
).t()

# Last expression of the cell: displays the projected vectors.
w_query(hidden_states)

tensor([[ 0.0915, -0.0590],
        [-0.0078, -0.0151],
        [ 0.1314, -0.0071]], grad_fn=<MmBackward0>)

## Example

Let's consider a simple example with 3 tokens, and let's work with an embedding dimension of 3. That means each token turns into a 3-dimensional vector.



In [25]:
# Each character of the prompt is treated as one token id; a later cell
# converts every character with int() and looks it up in `tokenizer`.
prompt = "012"
# Disabled scratch snippet: prints random 1x3 float rows for ids 3..9
# (presumably how the embedding rows were generated; kept for reference only).
# for i in range(3,10):
#     print(f"{i}: torch." )
#     print(torch.randint(-3,3,(1,3), dtype=torch.float32))

In [26]:
# Toy embedding table: row index == token id, row values == that token's
# 3-dimensional embedding. NOTE: ids 1 and 9 map to the same vector.
_EMBEDDING_ROWS = (
    ( 1.,  0.,  0.),
    (-2.,  1.,  0.),
    (-1.,  1.,  1.),
    ( 2.,  2.,  2.),
    (-1.,  0.,  0.),
    (-2., -1.,  0.),
    (-1.,  0., -1.),
    ( 2., -1., -3.),
    (-1.,  0., -3.),
    (-2.,  1.,  0.),
)

# Despite its name this maps token id -> float32 embedding tensor.
tokenizer = {
    token_id: torch.tensor(row, dtype=torch.float32)
    for token_id, row in enumerate(_EMBEDDING_ROWS)
}

In [27]:
# Look up the embedding of every prompt character (each character is one
# token id) and stack the vectors row-wise into a single matrix.
embedding_list = [tokenizer[int(input_id)] for input_id in prompt]
hidden_states = torch.stack(embedding_list)
hidden_states

tensor([[ 1.,  0.,  0.],
        [-2.,  1.,  0.],
        [-1.,  1.,  1.]])

In [30]:
import torch
# Three tokens (rows), each embedded as a 3-dimensional vector.
hidden_states = torch.tensor(
    [[1, 0, 0],
     [-2, 1, 0],
     [-1, 1, 1]],
    dtype=torch.float32,
)
# Toy projection matrix, reused for the query, key and value layers.
weights = torch.tensor(
    [[2, 1, 1],
     [4, -6, 0],
     [-2, 7, 2]],
    dtype=torch.float32,
    requires_grad=True)

# Must equal the embedding width of hidden_states / weights (3). It used to be
# 2, which only worked because the weight tensors are overwritten right after
# construction; the layers then reported in_features/out_features of 2 while
# actually holding 3x3 weights.
d_hidden = 3

# nn.Linear applies x @ weight.T, so storing weights.T makes every layer
# compute hidden_states @ weights.
w_query = torch.nn.Linear(in_features=d_hidden, out_features=d_hidden, bias=False)
w_query.weight.data = weights.T

w_key = torch.nn.Linear(in_features=d_hidden, out_features=d_hidden, bias=False)
w_key.weight.data = weights.T

# Defined for completeness; the value projection is not applied in this cell.
w_value = torch.nn.Linear(in_features=d_hidden, out_features=d_hidden, bias=False)
w_value.weight.data = weights.T

queries = w_query(hidden_states)
print(f"queries are \n{queries}")

keys = w_key(hidden_states)
print(f"keys.T are \n{keys.T}")

# Raw (unscaled, un-softmaxed) query-key scores: entry [i, j] is the dot
# product of query i with key j.
attention_map = torch.matmul(queries, keys.T)
print(f"torch.matmul(queries, keys.T) is \n{attention_map}")

queries are 
tensor([[ 2.,  1.,  1.],
        [ 0., -8., -2.],
        [ 0.,  0.,  1.]], grad_fn=<MmBackward0>)
keys.T are 
tensor([[ 2.,  0.,  0.],
        [ 1., -8.,  0.],
        [ 1., -2.,  1.]], grad_fn=<PermuteBackward0>)
torch.matmul(queries, keys.T) is 
tensor([[  6., -10.,   1.],
        [-10.,  68.,  -2.],
        [  1.,  -2.,   1.]], grad_fn=<MmBackward0>)


## Visualising things with Manim
(Claude wrote this)

In [4]:
from manim import *
import os
# Make the MacTeX binaries visible to manim's LaTeX calls.
# NOTE: this is a machine-specific (macOS / MacTeX) location.
# Prepend only when missing so that re-running the cell does not keep growing
# PATH, and tolerate an environment where PATH is not set at all.
_TEXBIN_DIR = '/Library/TeX/texbin'
_current_path = os.environ.get('PATH', '')
if _TEXBIN_DIR not in _current_path.split(os.pathsep):
    os.environ['PATH'] = (
        _TEXBIN_DIR + os.pathsep + _current_path if _current_path else _TEXBIN_DIR
    )

In [5]:
%%manim -qm AttentionMatMul

from manim import *
import numpy as np

class AttentionMatMul(Scene):
    """Animate the notebook's toy attention example step by step.

    The scene shows ``queries = hidden_states @ weights`` followed by
    ``attention_map = queries @ keys.T`` for three 3-dimensional token vectors.
    It relies on manim's ``Matrix`` and ``MathTex`` mobjects, which are rendered
    through LaTeX, so a working LaTeX installation is required.
    """

    def construct(self):
        """Build every mobject and play the full animation."""
        # Same toy inputs as the torch cell of this notebook.
        hidden_states = np.array([
            [1, 0, 0],
            [-2, 1, 0],
            [-1, 1, 1]
        ], dtype=np.float32)

        weights = np.array([
            [2, 1, 1],
            [4, -6, 0],
            [-2, 7, 2]
        ], dtype=np.float32)

        # torch.nn.Linear applies x @ weight.T and the torch cell stores
        # weights.T as the layer weight, so the layer effectively computes
        # hidden_states @ weights. Multiplying by weights.T here (as this scene
        # used to) produced numbers that disagreed with the torch output and
        # with the worked example shown below.
        queries = hidden_states @ weights
        keys = queries  # the key layer uses the same weights as the query layer
        attention_map = queries @ keys.T

        def create_matrix(data, name, color=WHITE):
            """Return VGroup(Matrix, label) for an array of whole numbers."""
            matrix = Matrix(data.astype(int), h_buff=1.5, v_buff=1.5)
            matrix.set_color(color)
            label = Text(name, font_size=24).next_to(matrix, UP)
            return VGroup(matrix, label)

        # Create all matrices up front; index [0] is the Matrix, [1] its label.
        hidden_matrix = create_matrix(hidden_states, "hidden_states", BLUE)
        weights_matrix = create_matrix(weights, "weights", GREEN)
        queries_matrix = create_matrix(queries, "queries", YELLOW)
        keys_T_matrix = create_matrix(keys.T, "keys.T", ORANGE)
        attention_matrix = create_matrix(attention_map, "attention_map", RED)

        # Title
        title = Text("Attention Mechanism Matrix Multiplication", font_size=36)
        title.to_edge(UP)
        self.play(Write(title))
        self.wait()

        # Step 1: Show initial matrices
        step1_text = Text("Step 1: Initial Matrices", font_size=28, color=BLUE_C)
        step1_text.next_to(title, DOWN)
        self.play(Write(step1_text))

        hidden_matrix.move_to(LEFT * 3)
        weights_matrix.move_to(RIGHT * 3)

        self.play(
            FadeIn(hidden_matrix),
            FadeIn(weights_matrix)
        )
        self.wait(2)

        # Step 2: Compute queries
        self.play(FadeOut(step1_text))
        step2_text = Text("Step 2: Compute queries = hidden_states @ weights", font_size=28, color=GREEN_C)
        step2_text.next_to(title, DOWN)
        self.play(Write(step2_text))

        # Show multiplication
        mult_sign = MathTex("@", font_size=48).move_to(ORIGIN)
        self.play(
            hidden_matrix.animate.move_to(LEFT * 2),
            weights_matrix.animate.move_to(RIGHT * 2),
            Write(mult_sign)
        )
        self.wait()

        # Show result. The equals sign is created at its final position and is
        # only the *target* of the transform; animating it separately in the
        # same play() call would drive one mobject from two animations.
        equals_sign = MathTex("=", font_size=48).move_to(RIGHT * 0.5)
        self.play(
            ReplacementTransform(mult_sign, equals_sign),
            hidden_matrix.animate.scale(0.7).move_to(LEFT * 4),
            weights_matrix.animate.scale(0.7).move_to(LEFT * 1.5)
        )

        queries_matrix.move_to(RIGHT * 3)
        self.play(FadeIn(queries_matrix))
        self.wait(2)

        # Highlight the entry queries[0, 0]
        highlight_rect = SurroundingRectangle(
            queries_matrix[0].get_rows()[0][0], color=YELLOW, buff=0.1
        )
        self.play(Create(highlight_rect))

        # Worked example: row 0 of hidden_states times column 0 of weights.
        calc_text = MathTex("queries[0,0] = 1 \\cdot 2 + 0 \\cdot 4 + 0 \\cdot (-2) = 2", font_size=24)
        calc_text.to_edge(DOWN)
        self.play(Write(calc_text))
        self.wait(2)
        self.play(FadeOut(highlight_rect), FadeOut(calc_text))

        # Step 3: Compute attention map
        self.play(
            FadeOut(hidden_matrix),
            FadeOut(weights_matrix),
            FadeOut(equals_sign),
            FadeOut(step2_text)
        )

        step3_text = Text("Step 3: Compute attention_map = queries @ keys.T", font_size=28, color=RED_C)
        step3_text.next_to(title, DOWN)
        self.play(Write(step3_text))

        # Move queries to left and create keys.T
        self.play(queries_matrix.animate.move_to(LEFT * 3))
        keys_T_matrix.move_to(RIGHT * 3)
        self.play(FadeIn(keys_T_matrix))

        # Show multiplication
        mult_sign2 = MathTex("@", font_size=48).move_to(ORIGIN)
        self.play(Write(mult_sign2))
        self.wait()

        # Replace both operands by the resulting attention map
        self.play(
            FadeOut(queries_matrix),
            FadeOut(keys_T_matrix),
            FadeOut(mult_sign2),
            FadeIn(attention_matrix.move_to(ORIGIN))
        )
        self.wait()

        # Final display with all results
        self.play(FadeOut(step3_text))
        final_text = Text("Final Attention Map", font_size=32, color=GOLD)
        final_text.next_to(title, DOWN)
        self.play(Write(final_text))

        # Show the attention pattern
        self.play(attention_matrix.animate.scale(1.2))

        # Highlight diagonal elements (self-attention). A Matrix's direct
        # submobjects are (entries, left bracket, right bracket), so
        # matrix[i][i] is NOT entry (i, i); get_rows() gives row/column access.
        attention_rows = attention_matrix[0].get_rows()
        for i in range(attention_map.shape[0]):
            rect = SurroundingRectangle(attention_rows[i][i], color=GOLD, buff=0.1)
            self.play(Create(rect), run_time=0.5)
            self.play(FadeOut(rect), run_time=0.3)

        self.wait(2)

        # Add interpretation text
        interp_text = Text("Higher values indicate stronger attention", font_size=24, color=GRAY)
        interp_text.to_edge(DOWN)
        self.play(Write(interp_text))
        self.wait(3)

        # Fade out everything
        self.play(*[FadeOut(mob) for mob in self.mobjects])

# To render this animation in Jupyter:
# Use %%manim -qm -v WARNING AttentionMatMul at the top of the cell
# Or save to a file and run: manim -pql attention_matmul.py AttentionMatMul

ValueError: latex error converting to dvi. See log output above or the log file: media/Tex/6ecf9f51170c1a70.log

In [6]:
%%manim -qm -v WARNING AttentionMatMulNoLatex

from manim import *
import numpy as np

class AttentionMatMulNoLatex(Scene):
    """Animate the toy attention example using ``Text`` for signs and formulas.

    The scene shows ``queries = hidden_states @ weights`` followed by
    ``attention_map = queries @ keys.T``.

    NOTE(review): only the operators and the worked example were switched from
    ``MathTex`` to ``Text``. The matrices are still manim ``Matrix`` mobjects,
    whose entries and brackets are rendered through LaTeX, so this scene still
    needs a working LaTeX installation despite its name.
    """

    def construct(self):
        """Build every mobject and play the full animation."""
        # Same toy inputs as the torch cell of this notebook.
        hidden_states = np.array([
            [1, 0, 0],
            [-2, 1, 0],
            [-1, 1, 1]
        ], dtype=np.float32)

        weights = np.array([
            [2, 1, 1],
            [4, -6, 0],
            [-2, 7, 2]
        ], dtype=np.float32)

        # torch.nn.Linear applies x @ weight.T and the torch cell stores
        # weights.T as the layer weight, so the layer effectively computes
        # hidden_states @ weights. Multiplying by weights.T here (as this scene
        # used to) produced numbers that disagreed with the torch output and
        # with the worked example shown below.
        queries = hidden_states @ weights
        keys = queries  # the key layer uses the same weights as the query layer
        attention_map = queries @ keys.T

        def create_matrix(data, name, color=WHITE):
            """Return VGroup(Matrix, label) for an array of whole numbers."""
            matrix = Matrix(data.astype(int), h_buff=1.5, v_buff=1.5)
            matrix.set_color(color)
            label = Text(name, font_size=24).next_to(matrix, UP)
            return VGroup(matrix, label)

        # Create all matrices up front; index [0] is the Matrix, [1] its label.
        hidden_matrix = create_matrix(hidden_states, "hidden_states", BLUE)
        weights_matrix = create_matrix(weights, "weights", GREEN)
        queries_matrix = create_matrix(queries, "queries", YELLOW)
        keys_T_matrix = create_matrix(keys.T, "keys.T", ORANGE)
        attention_matrix = create_matrix(attention_map, "attention_map", RED)

        # Title
        title = Text("Attention Mechanism Matrix Multiplication", font_size=36)
        title.to_edge(UP)
        self.play(Write(title))
        self.wait()

        # Step 1: Show initial matrices
        step1_text = Text("Step 1: Initial Matrices", font_size=28, color=BLUE_C)
        step1_text.next_to(title, DOWN)
        self.play(Write(step1_text))

        hidden_matrix.move_to(LEFT * 3)
        weights_matrix.move_to(RIGHT * 3)

        self.play(
            FadeIn(hidden_matrix),
            FadeIn(weights_matrix)
        )
        self.wait(2)

        # Step 2: Compute queries
        self.play(FadeOut(step1_text))
        step2_text = Text("Step 2: Compute queries = hidden_states @ weights", font_size=26, color=GREEN_C)
        step2_text.next_to(title, DOWN)
        self.play(Write(step2_text))

        # Show multiplication using Text instead of MathTex
        mult_sign = Text("@", font_size=48).move_to(ORIGIN)
        self.play(
            hidden_matrix.animate.move_to(LEFT * 2),
            weights_matrix.animate.move_to(RIGHT * 2),
            Write(mult_sign)
        )
        self.wait()

        # Show result. The equals sign is created at its final position and is
        # only the *target* of the transform; animating it separately in the
        # same play() call would drive one mobject from two animations.
        equals_sign = Text("=", font_size=48).move_to(RIGHT * 0.5)
        self.play(
            ReplacementTransform(mult_sign, equals_sign),
            hidden_matrix.animate.scale(0.7).move_to(LEFT * 4),
            weights_matrix.animate.scale(0.7).move_to(LEFT * 1.5)
        )

        queries_matrix.move_to(RIGHT * 3)
        self.play(FadeIn(queries_matrix))
        self.wait(2)

        # Highlight the entry queries[0, 0]
        highlight_rect = SurroundingRectangle(
            queries_matrix[0].get_rows()[0][0], color=YELLOW, buff=0.1
        )
        self.play(Create(highlight_rect))

        # Worked example (Text instead of MathTex): row 0 of hidden_states
        # times column 0 of weights.
        calc_text = Text("queries[0,0] = 1×2 + 0×4 + 0×(-2) = 2", font_size=24)
        calc_text.to_edge(DOWN)
        self.play(Write(calc_text))
        self.wait(2)
        self.play(FadeOut(highlight_rect), FadeOut(calc_text))

        # Step 3: Compute attention map
        self.play(
            FadeOut(hidden_matrix),
            FadeOut(weights_matrix),
            FadeOut(equals_sign),
            FadeOut(step2_text)
        )

        step3_text = Text("Step 3: Compute attention_map = queries @ keys.T", font_size=26, color=RED_C)
        step3_text.next_to(title, DOWN)
        self.play(Write(step3_text))

        # Move queries to left and create keys.T
        self.play(queries_matrix.animate.move_to(LEFT * 3))
        keys_T_matrix.move_to(RIGHT * 3)
        self.play(FadeIn(keys_T_matrix))

        # Show multiplication
        mult_sign2 = Text("@", font_size=48).move_to(ORIGIN)
        self.play(Write(mult_sign2))
        self.wait()

        # Replace both operands by the resulting attention map
        self.play(
            FadeOut(queries_matrix),
            FadeOut(keys_T_matrix),
            FadeOut(mult_sign2),
            FadeIn(attention_matrix.move_to(ORIGIN))
        )
        self.wait()

        # Final display with all results
        self.play(FadeOut(step3_text))
        final_text = Text("Final Attention Map", font_size=32, color=GOLD)
        final_text.next_to(title, DOWN)
        self.play(Write(final_text))

        # Show the attention pattern
        self.play(attention_matrix.animate.scale(1.2))

        # Highlight diagonal elements (self-attention). A Matrix's direct
        # submobjects are (entries, left bracket, right bracket), so
        # matrix[i][i] is NOT entry (i, i); get_rows() gives row/column access.
        attention_rows = attention_matrix[0].get_rows()
        for i in range(attention_map.shape[0]):
            rect = SurroundingRectangle(attention_rows[i][i], color=GOLD, buff=0.1)
            self.play(Create(rect), run_time=0.5)
            self.play(FadeOut(rect), run_time=0.3)

        self.wait(2)

        # Add interpretation text
        interp_text = Text("Higher values indicate stronger attention", font_size=24, color=GRAY)
        interp_text.to_edge(DOWN)
        self.play(Write(interp_text))
        self.wait(3)

        # Fade out everything
        self.play(*[FadeOut(mob) for mob in self.mobjects])

ValueError: latex error converting to dvi. See log output above or the log file: media/Tex/6ecf9f51170c1a70.log

In [8]:
%%manim -qm -v WARNING AttentionMatMulPureText

from manim import *
import numpy as np

class AttentionMatMulPureText(Scene):
    """LaTeX-free animation of the notebook's toy attention example.

    Every matrix is assembled from ``Text`` mobjects (entries and brackets), so
    no ``Matrix``/``MathTex`` object is created. The scene shows
    ``queries = hidden_states @ weights`` followed by
    ``attention_map = queries @ keys.T``.
    """

    def construct(self):
        """Build every mobject and play the full animation."""
        # Same toy inputs as the torch cell of this notebook.
        hidden_states = np.array([
            [1, 0, 0],
            [-2, 1, 0],
            [-1, 1, 1]
        ], dtype=np.float32)

        weights = np.array([
            [2, 1, 1],
            [4, -6, 0],
            [-2, 7, 2]
        ], dtype=np.float32)

        # torch.nn.Linear applies x @ weight.T and the torch cell stores
        # weights.T as the layer weight, so the layer effectively computes
        # hidden_states @ weights. Multiplying by weights.T here (as this scene
        # used to) produced numbers that disagreed with the torch output and
        # with the worked example shown below.
        queries = hidden_states @ weights
        keys = queries  # the key layer uses the same weights as the query layer
        attention_map = queries @ keys.T

        def create_text_matrix(data, name, color=WHITE):
            """Return VGroup(bracketed grid of Text entries, label).

            ``data`` is a 2-D array of whole numbers, ``name`` the caption
            placed above the matrix and ``color`` the colour applied to the
            entries and brackets (the caption keeps the default colour).
            """
            data_int = data.astype(int)  # display whole numbers, not floats

            # Lay the entries out on a fixed 0.8 x 0.6 grid; the group is
            # re-centred afterwards, so only the spacing matters here.
            matrix_group = VGroup()
            for i in range(data_int.shape[0]):
                for j in range(data_int.shape[1]):
                    elem = Text(str(data_int[i, j]), font_size=30)
                    elem.move_to(np.array([j * 0.8 - 0.8, -i * 0.6 + 0.6, 0]))
                    matrix_group.add(elem)
            matrix_group.move_to(ORIGIN)

            # Plain-text brackets on either side of the grid
            left_bracket = Text("[", font_size=60)
            right_bracket = Text("]", font_size=60)
            left_bracket.next_to(matrix_group, LEFT, buff=0.2)
            right_bracket.next_to(matrix_group, RIGHT, buff=0.2)

            full_matrix = VGroup(left_bracket, matrix_group, right_bracket)
            full_matrix.set_color(color)

            label = Text(name, font_size=24).next_to(full_matrix, UP)

            return VGroup(full_matrix, label)

        # Create all matrices
        hidden_matrix = create_text_matrix(hidden_states, "hidden_states", BLUE)
        weights_matrix = create_text_matrix(weights, "weights", GREEN)
        queries_matrix = create_text_matrix(queries, "queries", YELLOW)
        keys_T_matrix = create_text_matrix(keys.T, "keys.T", ORANGE)
        attention_matrix = create_text_matrix(attention_map, "attention_map", RED)

        # Title
        title = Text("Attention Mechanism Matrix Multiplication", font_size=36)
        title.to_edge(UP)
        self.play(Write(title))
        self.wait()

        # Step 1: Show initial matrices
        step1_text = Text("Step 1: Initial Matrices", font_size=28, color=BLUE_C)
        step1_text.next_to(title, DOWN)
        self.play(Write(step1_text))

        hidden_matrix.move_to(LEFT * 3.5)
        weights_matrix.move_to(RIGHT * 3.5)

        self.play(
            FadeIn(hidden_matrix),
            FadeIn(weights_matrix)
        )
        self.wait(2)

        # Step 2: Compute queries
        self.play(FadeOut(step1_text))
        step2_text = Text("Step 2: queries = hidden_states @ weights", font_size=26, color=GREEN_C)
        step2_text.next_to(title, DOWN)
        self.play(Write(step2_text))

        # Show multiplication
        mult_sign = Text("@", font_size=48).move_to(ORIGIN)
        self.play(
            hidden_matrix.animate.move_to(LEFT * 2.5),
            weights_matrix.animate.move_to(RIGHT * 2.5),
            Write(mult_sign)
        )
        self.wait()

        # Show result. The equals sign is created at its final position and is
        # only the *target* of the transform; animating it separately in the
        # same play() call would drive one mobject from two animations.
        equals_sign = Text("=", font_size=48).move_to(RIGHT * 0.5)
        self.play(
            ReplacementTransform(mult_sign, equals_sign),
            hidden_matrix.animate.scale(0.6).move_to(LEFT * 4.5),
            weights_matrix.animate.scale(0.6).move_to(LEFT * 2)
        )

        queries_matrix.move_to(RIGHT * 3)
        self.play(FadeIn(queries_matrix))
        self.wait(2)

        # Worked example: row 0 of hidden_states times column 0 of weights.
        calc_text = Text("Example: queries[0,0] = 1*2 + 0*4 + 0*(-2) = 2", font_size=20)
        calc_text.to_edge(DOWN)
        self.play(Write(calc_text))
        self.wait(2)
        self.play(FadeOut(calc_text))

        # Step 3: Compute attention map
        self.play(
            FadeOut(hidden_matrix),
            FadeOut(weights_matrix),
            FadeOut(equals_sign),
            FadeOut(step2_text)
        )

        step3_text = Text("Step 3: attention_map = queries @ keys.T", font_size=26, color=RED_C)
        step3_text.next_to(title, DOWN)
        self.play(Write(step3_text))

        # Move queries to left and create keys.T
        self.play(queries_matrix.animate.move_to(LEFT * 3.5))
        keys_T_matrix.move_to(RIGHT * 3.5)
        self.play(FadeIn(keys_T_matrix))

        # Show multiplication
        mult_sign2 = Text("@", font_size=48).move_to(ORIGIN)
        self.play(Write(mult_sign2))
        self.wait()

        # Replace both operands by the resulting attention map
        self.play(
            FadeOut(queries_matrix),
            FadeOut(keys_T_matrix),
            FadeOut(mult_sign2),
            FadeIn(attention_matrix.move_to(ORIGIN))
        )
        self.wait()

        # Final display with all results
        self.play(FadeOut(step3_text))
        final_text = Text("Final Attention Map", font_size=32, color=GOLD)
        final_text.next_to(title, DOWN)
        self.play(Write(final_text))

        # Show the attention pattern
        self.play(attention_matrix.animate.scale(1.2))

        # Point out the diagonal (each token scored against itself)
        highlight_text = Text("Diagonal values show self-attention", font_size=22, color=GOLD)
        highlight_text.to_edge(DOWN)
        self.play(Write(highlight_text))
        self.wait(2)

        # Add interpretation text
        self.play(FadeOut(highlight_text))
        interp_text = Text("Higher values = stronger attention", font_size=24, color=GRAY)
        interp_text.to_edge(DOWN)
        self.play(Write(interp_text))
        self.wait(3)

        # Fade out everything
        self.play(*[FadeOut(mob) for mob in self.mobjects])

                                                                                                                               

## Attempt 2: Mechanics of MatMul

Prompting Claude to show the mechanics of the matrix multiplication (which row and column produce each entry) more explicitly.

In [3]:
from manim import *

In [4]:
%%manim -qm -v WARNING AttentionMatMulImproved

from manim import *
import numpy as np

class AttentionMatMulImproved(Scene):
    """LaTeX-free animation that spells out individual row-times-column products.

    The scene shows ``queries = hidden_states @ weights`` followed by
    ``attention_map = queries @ keys.T``. For a few entries the contributing
    row and column are highlighted and the dot product is written out; the
    remaining entries are then revealed together.
    """

    def construct(self):
        """Build every mobject and play the full animation."""
        # Same toy inputs as the torch cell of this notebook.
        hidden_states = np.array([
            [1, 0, 0],
            [-2, 1, 0],
            [-1, 1, 1]
        ], dtype=np.float32)

        weights = np.array([
            [2, 1, 1],
            [4, -6, 0],
            [-2, 7, 2]
        ], dtype=np.float32)

        # torch.nn.Linear applies x @ weight.T and the torch cell stores
        # weights.T as the layer weight, so the layer effectively computes
        # hidden_states @ weights. Multiplying by weights.T here (as this scene
        # used to) made the highlighted column and the displayed results
        # disagree with the spelled-out products and with the torch output.
        queries = hidden_states @ weights
        keys = queries  # the key layer uses the same weights as the query layer
        attention_map = queries @ keys.T
        size = hidden_states.shape[0]  # all matrices in this example are size x size

        def create_text_matrix(data, name, color=WHITE):
            """Return VGroup(bracketed grid of Text entries, label).

            The first item of the returned group (the bracketed matrix, itself
            VGroup(left bracket, entries, right bracket)) carries two extra
            attributes: ``elements`` (entries as a list of rows) and
            ``positions`` (flat list of ``(row, col, entry)`` tuples).
            """
            data_int = data.astype(int)  # display whole numbers, not floats

            # One Text mobject per entry, kept as a list of rows
            elements = [
                [Text(str(data_int[i, j]), font_size=30) for j in range(data_int.shape[1])]
                for i in range(data_int.shape[0])
            ]

            # Arrange on a 0.8 x 0.6 grid centred on the origin
            matrix_group = VGroup()
            positions = []
            for i, row in enumerate(elements):
                for j, elem in enumerate(row):
                    pos = np.array([j * 0.8 - 0.8 * (len(row)-1)/2, -i * 0.6 + 0.6 * (len(elements)-1)/2, 0])
                    elem.move_to(pos)
                    matrix_group.add(elem)
                    positions.append((i, j, elem))

            matrix_group.move_to(ORIGIN)

            # Brackets stretched to span the full height of the grid
            height = len(elements) * 0.6 + 0.4
            left_bracket = Text("[", font_size=int(60 * (height/2)))
            right_bracket = Text("]", font_size=int(60 * (height/2)))
            left_bracket.stretch_to_fit_height(height)
            right_bracket.stretch_to_fit_height(height)
            left_bracket.next_to(matrix_group, LEFT, buff=0.2)
            right_bracket.next_to(matrix_group, RIGHT, buff=0.2)

            full_matrix = VGroup(left_bracket, matrix_group, right_bracket)
            full_matrix.set_color(color)

            label = Text(name, font_size=24).next_to(full_matrix, UP)

            # Store element references for easy access
            full_matrix.elements = elements
            full_matrix.positions = positions

            return VGroup(full_matrix, label)

        def dot_product_text(row, col):
            """Spell out a dot product, e.g. ``1×2 + 0×4 + 0×(-2) = 2``.

            Deriving the string from the arrays keeps the on-screen arithmetic
            in sync with the matrices that are actually displayed.
            """
            def fmt(value):
                value = int(value)
                return f"({value})" if value < 0 else str(value)

            terms = " + ".join(f"{fmt(a)}×{fmt(b)}" for a, b in zip(row, col))
            return f"{terms} = {int(np.dot(row, col))}"

        # Title
        title = Text("Attention Mechanism Matrix Multiplication", font_size=36)
        title.to_edge(UP)
        self.play(Write(title))
        self.wait()

        # Step 1: Show initial matrices
        step1_text = Text("Step 1: Initial Matrices", font_size=28, color=BLUE_C)
        step1_text.next_to(title, DOWN)
        self.play(Write(step1_text))

        hidden_matrix = create_text_matrix(hidden_states, "hidden_states", BLUE)
        weights_matrix = create_text_matrix(weights, "weights", GREEN)

        hidden_matrix.move_to(LEFT * 3.5)
        weights_matrix.move_to(RIGHT * 3.5)

        self.play(
            FadeIn(hidden_matrix),
            FadeIn(weights_matrix)
        )
        self.wait(2)

        # Step 2: Compute queries with visualization
        self.play(FadeOut(step1_text))
        step2_text = Text("Step 2: queries = hidden_states @ weights", font_size=26, color=GREEN_C)
        step2_text.next_to(title, DOWN)
        self.play(Write(step2_text))

        # Position matrices for multiplication
        mult_sign = Text("@", font_size=48).move_to(ORIGIN)
        self.play(
            hidden_matrix.animate.move_to(LEFT * 2.5),
            weights_matrix.animate.move_to(RIGHT * 2.5),
            Write(mult_sign)
        )
        self.wait()

        # Result matrix starts as a placeholder. Only the entries (index 1 of
        # the bracketed matrix) are dimmed: the entries are restored one by
        # one below, whereas dimmed brackets would never be restored.
        queries_matrix = create_text_matrix(queries, "queries", YELLOW)
        queries_matrix.move_to(DOWN * 2.5)
        queries_matrix[0][1].set_opacity(0.2)
        self.play(FadeIn(queries_matrix))

        # Entry (0, 0): row 0 of hidden_states times column 0 of weights
        row_highlight = SurroundingRectangle(
            VGroup(*[hidden_matrix[0].elements[0][j] for j in range(size)]),
            color=YELLOW, buff=0.1
        )
        col_highlight = SurroundingRectangle(
            VGroup(*[weights_matrix[0].elements[j][0] for j in range(size)]),
            color=YELLOW, buff=0.1
        )

        self.play(
            Create(row_highlight),
            Create(col_highlight)
        )

        # Show calculation
        calc_text = Text(dot_product_text(hidden_states[0], weights[:, 0]), font_size=20, color=YELLOW)
        calc_text.next_to(queries_matrix, DOWN)
        self.play(Write(calc_text))
        self.play(
            queries_matrix[0].elements[0][0].animate.set_opacity(1).set_color(YELLOW),
            run_time=0.5
        )
        self.wait(0.5)

        # Entry (1, 1): row 1 of hidden_states times column 1 of weights (faster)
        self.play(
            row_highlight.animate.move_to(VGroup(*[hidden_matrix[0].elements[1][j] for j in range(size)])),
            col_highlight.animate.move_to(VGroup(*[weights_matrix[0].elements[j][1] for j in range(size)])),
            Transform(
                calc_text,
                Text(dot_product_text(hidden_states[1], weights[:, 1]), font_size=20, color=YELLOW).next_to(queries_matrix, DOWN)
            ),
            run_time=0.5
        )
        self.play(
            queries_matrix[0].elements[1][1].animate.set_opacity(1).set_color(YELLOW),
            run_time=0.3
        )

        # Fill in remaining elements quickly
        already_shown = {(0, 0), (1, 1)}
        self.play(
            FadeOut(row_highlight),
            FadeOut(col_highlight),
            FadeOut(calc_text),
            *[queries_matrix[0].elements[i][j].animate.set_opacity(1).set_color(YELLOW)
              for i in range(size) for j in range(size) if (i, j) not in already_shown],
            run_time=1
        )

        # Move to final position
        self.play(
            FadeOut(mult_sign),
            FadeOut(hidden_matrix),
            FadeOut(weights_matrix),
            queries_matrix.animate.move_to(LEFT * 3.5).scale(0.9),
            run_time=1
        )

        # Step 3: Compute attention map with visualization
        self.play(FadeOut(step2_text))
        step3_text = Text("Step 3: attention_map = queries @ keys.T", font_size=26, color=RED_C)
        step3_text.next_to(title, DOWN)
        self.play(Write(step3_text))

        # Create keys.T
        keys_T_matrix = create_text_matrix(keys.T, "keys.T", ORANGE)
        keys_T_matrix.move_to(RIGHT * 3.5).scale(0.9)
        self.play(FadeIn(keys_T_matrix))

        # Show multiplication
        mult_sign2 = Text("@", font_size=40).move_to(ORIGIN)
        self.play(Write(mult_sign2))
        self.wait()

        # Attention matrix placeholder (entries dimmed, brackets left visible)
        attention_matrix = create_text_matrix(attention_map, "attention_map", RED)
        attention_matrix.move_to(DOWN * 2.5)
        attention_matrix[0][1].set_opacity(0.2)
        self.play(FadeIn(attention_matrix))

        # Entry (0, 0): row 0 of queries times column 0 of keys.T
        row_highlight2 = SurroundingRectangle(
            VGroup(*[queries_matrix[0].elements[0][j] for j in range(size)]),
            color=RED, buff=0.1
        )
        col_highlight2 = SurroundingRectangle(
            VGroup(*[keys_T_matrix[0].elements[j][0] for j in range(size)]),
            color=RED, buff=0.1
        )

        self.play(
            Create(row_highlight2),
            Create(col_highlight2),
            run_time=0.5
        )

        # Show calculation briefly
        calc_text2 = Text(dot_product_text(queries[0], keys.T[:, 0]), font_size=20, color=RED)
        calc_text2.next_to(attention_matrix, DOWN)
        self.play(
            Write(calc_text2),
            attention_matrix[0].elements[0][0].animate.set_opacity(1).set_color(RED),
            run_time=0.5
        )

        # Fill in remaining elements
        self.play(
            FadeOut(row_highlight2),
            FadeOut(col_highlight2),
            FadeOut(calc_text2),
            *[attention_matrix[0].elements[i][j].animate.set_opacity(1).set_color(RED)
              for i in range(size) for j in range(size) if (i, j) != (0, 0)],
            run_time=1.5
        )

        # Final arrangement
        self.play(
            FadeOut(mult_sign2),
            FadeOut(queries_matrix),
            FadeOut(keys_T_matrix),
            attention_matrix.animate.move_to(ORIGIN).scale(1.3),
            run_time=1
        )

        # Final display
        self.play(FadeOut(step3_text))
        final_text = Text("Final Attention Map", font_size=32, color=GOLD)
        final_text.next_to(title, DOWN)
        self.play(Write(final_text))

        # Highlight diagonal elements
        for i in range(size):
            rect = SurroundingRectangle(attention_matrix[0].elements[i][i], color=GOLD, buff=0.1)
            self.play(Create(rect), run_time=0.3)
            self.play(FadeOut(rect), run_time=0.2)

        # Add interpretation text
        interp_text = Text("Diagonal: self-attention scores", font_size=24, color=GRAY)
        interp_text.to_edge(DOWN)
        self.play(Write(interp_text))
        self.wait(3)

        # Fade out everything
        self.play(*[FadeOut(mob) for mob in self.mobjects])

                                                                                                                               

## Attempt 3

Trying to animate the full matrix multiplication, with a slower animation for the first couple of entries and a faster one for the remaining entries.

In [7]:
%%manim -qm -v WARNING AttentionMatMulImproved

from manim import *
import numpy as np

class AttentionMatMulImproved(Scene):
    """Animate a toy attention computation on small integer-valued matrices.

    The scene first shows ``queries = hidden_states @ weights.T`` and then
    ``attention_map = queries @ keys.T`` (``keys`` is the same array as
    ``queries``), revealing the result entries one at a time.  The first
    ``DETAILED_ENTRIES`` entries of each product are played slowly together
    with their written-out dot product; the remaining entries are played
    faster and without the caption.
    """

    # Number of leading entries of each product that get the slow, captioned treatment.
    DETAILED_ENTRIES = 2
    # (highlight, reveal, pause) durations in seconds.
    SLOW_TIMING = (0.7, 0.5, 0.5)
    FAST_TIMING = (0.5, 0.3, 0.2)
    # Opacity of result entries whose value has not been "computed" yet.
    PENDING_OPACITY = 0.2

    def construct(self):
        """Build and play the whole animation."""
        # Define the matrices
        hidden_states = np.array([
            [1, 0, 0],
            [-2, 1, 0],
            [-1, 1, 1]
        ], dtype=np.float32)

        weights = np.array([
            [2, 1, 1],
            [4, -6, 0],
            [-2, 7, 2]
        ], dtype=np.float32)

        weights_T = weights.T

        # Compute results
        queries = hidden_states @ weights_T
        keys = queries  # the keys reuse the query projection in this example
        attention_map = queries @ keys.T

        # Title
        title = Text("Attention Mechanism Matrix Multiplication", font_size=36)
        title.to_edge(UP)
        self.play(Write(title))
        self.wait()

        # Step 1: Show initial matrices
        step1_text = Text("Step 1: Initial Matrices", font_size=28, color=BLUE_C)
        step1_text.next_to(title, DOWN)
        self.play(Write(step1_text))

        hidden_matrix = self._create_text_matrix(hidden_states, "hidden_states", BLUE)
        weights_T_matrix = self._create_text_matrix(weights_T, "weights.T", GREEN)

        hidden_matrix.move_to(LEFT * 3.5)
        weights_T_matrix.move_to(RIGHT * 3.5)

        self.play(
            FadeIn(hidden_matrix),
            FadeIn(weights_T_matrix)
        )
        self.wait(2)

        # Step 2: Compute queries with visualization
        self.play(FadeOut(step1_text))
        step2_text = Text("Step 2: queries = hidden_states @ weights.T", font_size=26, color=GREEN_C)
        step2_text.next_to(title, DOWN)
        self.play(Write(step2_text))

        # Position matrices for multiplication
        mult_sign = Text("@", font_size=48).move_to(ORIGIN)
        self.play(
            hidden_matrix.animate.move_to(LEFT * 2.5),
            weights_T_matrix.animate.move_to(RIGHT * 2.5),
            Write(mult_sign)
        )
        self.wait()

        # Result placeholder: only the entries are dimmed, so the brackets keep
        # their full opacity once every entry has been revealed.
        queries_matrix = self._create_text_matrix(queries, "queries", YELLOW)
        queries_matrix.move_to(DOWN * 2.5)
        queries_matrix[0].entries.set_opacity(self.PENDING_OPACITY)
        self.play(FadeIn(queries_matrix))

        self._animate_product(
            hidden_matrix, weights_T_matrix, queries_matrix,
            hidden_states, weights_T,
            color=YELLOW, caption_anchor=step2_text, pause_after_fast=True,
        )

        # Move to final position
        self.play(
            FadeOut(mult_sign),
            FadeOut(hidden_matrix),
            FadeOut(weights_T_matrix),
            queries_matrix.animate.move_to(LEFT * 3.5).scale(0.9),
            run_time=1
        )

        # Step 3: Compute attention map with visualization
        self.play(FadeOut(step2_text))
        step3_text = Text("Step 3: attention_map = queries @ keys.T", font_size=26, color=RED_C)
        step3_text.next_to(title, DOWN)
        self.play(Write(step3_text))

        # Create keys.T
        keys_T_matrix = self._create_text_matrix(keys.T, "keys.T", ORANGE)
        keys_T_matrix.move_to(RIGHT * 3.5).scale(0.9)
        self.play(FadeIn(keys_T_matrix))

        # Show multiplication
        mult_sign2 = Text("@", font_size=40).move_to(ORIGIN)
        self.play(Write(mult_sign2))
        self.wait()

        # Create attention matrix placeholder
        attention_matrix = self._create_text_matrix(attention_map, "attention_map", RED)
        attention_matrix.move_to(DOWN * 2.5)
        attention_matrix[0].entries.set_opacity(self.PENDING_OPACITY)
        self.play(FadeIn(attention_matrix))

        # Same walk-through as for the queries, but without the pause after
        # the fast entries.
        self._animate_product(
            queries_matrix, keys_T_matrix, attention_matrix,
            queries, keys.T,
            color=RED, caption_anchor=step3_text, pause_after_fast=False,
        )

        # Final arrangement
        self.play(
            FadeOut(mult_sign2),
            FadeOut(queries_matrix),
            FadeOut(keys_T_matrix),
            attention_matrix.animate.move_to(ORIGIN).scale(1.3),
            run_time=1
        )

        # Final display
        self.play(FadeOut(step3_text))
        final_text = Text("Final Attention Map", font_size=32, color=GOLD)
        final_text.next_to(title, DOWN)
        self.play(Write(final_text))

        # Highlight diagonal elements
        for i in range(attention_map.shape[0]):
            rect = SurroundingRectangle(attention_matrix[0].elements[i][i], color=GOLD, buff=0.1)
            self.play(Create(rect), run_time=0.3)
            self.play(FadeOut(rect), run_time=0.2)

        # Add interpretation text
        interp_text = Text("Diagonal: self-attention scores", font_size=24, color=GRAY)
        interp_text.to_edge(DOWN)
        self.play(Write(interp_text))
        self.wait(3)

        # Fade out everything
        self.play(*[FadeOut(mob) for mob in self.mobjects])

    @staticmethod
    def _create_text_matrix(data, name, color=WHITE):
        """Build a bracketed grid of ``Text`` entries with a label above it.

        Args:
            data: 2-D numpy array; values are converted with ``astype(int)``
                before being displayed.
            name: label text placed above the matrix.
            color: colour applied to the brackets and the entries.

        Returns:
            ``VGroup(full_matrix, label)``.  ``full_matrix`` carries two extra
            attributes: ``elements`` (nested list of the entry mobjects,
            indexed ``[row][column]``) and ``entries`` (a ``VGroup`` holding
            the same mobjects).
        """
        cell_width, cell_height = 0.8, 0.6
        values = data.astype(int)
        n_rows, n_cols = values.shape

        elements = []
        entries = VGroup()
        for i in range(n_rows):
            row = []
            for j in range(n_cols):
                cell = Text(str(values[i, j]), font_size=30)
                cell.move_to(np.array([
                    j * cell_width - cell_width * (n_cols - 1) / 2,
                    -i * cell_height + cell_height * (n_rows - 1) / 2,
                    0,
                ]))
                row.append(cell)
                entries.add(cell)
            elements.append(row)
        entries.move_to(ORIGIN)

        # Brackets are plain "[" / "]" glyphs stretched to the grid height.
        height = n_rows * cell_height + 0.4
        left_bracket = Text("[", font_size=int(60 * (height / 2)))
        right_bracket = Text("]", font_size=int(60 * (height / 2)))
        left_bracket.stretch_to_fit_height(height)
        right_bracket.stretch_to_fit_height(height)
        left_bracket.next_to(entries, LEFT, buff=0.2)
        right_bracket.next_to(entries, RIGHT, buff=0.2)

        full_matrix = VGroup(left_bracket, entries, right_bracket)
        full_matrix.set_color(color)

        label = Text(name, font_size=24).next_to(full_matrix, UP)

        # Keep direct references so callers can address single entries.
        full_matrix.elements = elements
        full_matrix.entries = entries

        return VGroup(full_matrix, label)

    @staticmethod
    def _dot_product_text(row, column):
        """Return the written-out dot product, e.g. ``"1×2 + 0×1 + 0×1 = 2"``."""
        terms = " + ".join(f"{int(a)}×{int(b)}" for a, b in zip(row, column))
        return f"{terms} = {int(row @ column)}"

    def _animate_product(self, left, right, result, left_values, right_values,
                         color, caption_anchor, pause_after_fast):
        """Reveal ``result`` entry by entry as ``left @ right``.

        Args:
            left, right, result: groups returned by ``_create_text_matrix``.
            left_values, right_values: the numpy arrays displayed by ``left``
                and ``right``; used for the written-out dot product.
            color: colour of the highlights, the caption and revealed entries.
            caption_anchor: mobject below which the dot-product caption is
                placed.
            pause_after_fast: whether to wait after each non-captioned entry.
        """
        left_cells = left[0].elements
        right_cells = right[0].elements
        result_cells = result[0].elements
        n_rows, n_cols = left_values.shape[0], right_values.shape[1]

        # np.ndindex walks the result in row-major order.
        for entry_index, (i, j) in enumerate(np.ndindex(n_rows, n_cols)):
            row_highlight = SurroundingRectangle(
                VGroup(*left_cells[i]), color=color, buff=0.1
            )
            col_highlight = SurroundingRectangle(
                VGroup(*[cells[j] for cells in right_cells]), color=color, buff=0.1
            )

            detailed = entry_index < self.DETAILED_ENTRIES
            highlight_time, reveal_time, pause = (
                self.SLOW_TIMING if detailed else self.FAST_TIMING
            )

            if detailed:
                # The caption goes directly under the step heading: the result
                # matrix sits at DOWN * 2.5, so a caption below it would fall
                # outside the default 8-unit-tall frame.
                caption = Text(
                    self._dot_product_text(left_values[i], right_values[:, j]),
                    font_size=20, color=color,
                )
                caption.next_to(caption_anchor, DOWN)
                self.play(
                    Create(row_highlight),
                    Create(col_highlight),
                    Write(caption),
                    run_time=highlight_time
                )
                self.play(
                    result_cells[i][j].animate.set_opacity(1).set_color(color),
                    run_time=reveal_time
                )
                self.wait(pause)
                self.play(
                    FadeOut(row_highlight),
                    FadeOut(col_highlight),
                    FadeOut(caption),
                    run_time=reveal_time
                )
            else:
                self.play(
                    Create(row_highlight),
                    Create(col_highlight),
                    run_time=highlight_time
                )
                self.play(
                    result_cells[i][j].animate.set_opacity(1).set_color(color),
                    FadeOut(row_highlight),
                    FadeOut(col_highlight),
                    run_time=reveal_time
                )
                if pause_after_fast:
                    self.wait(pause)

                                                                                                                               

# Now we're animating https://e2eml.school/transformers.html

In [1]:
from manim import *

In [2]:
%%manim -qm -v WARNING TransformerTutorial

from manim import *
import numpy as np

class TransformerTutorial(Scene):
    def construct(self):
        """Play the title card, then run every tutorial section in order."""
        heading = Text("How Transformers Work", font_size=72, color=BLUE)
        tagline = Text("A Complete Visual Guide", font_size=36, color=WHITE).next_to(heading, DOWN)

        # Title sequence: write both lines, hold, then clear them together.
        for card in (heading, tagline):
            self.play(Write(card))
        self.wait(2)
        self.play(FadeOut(heading), FadeOut(tagline))

        # Sections are looked up by name one at a time, right before they run.
        section_names = (
            "one_hot_encoding",       # Section 1: One-hot Encoding
            "dot_product",            # Section 2: Dot Product
            "matrix_multiplication",  # Section 3: Matrix Multiplication
            "matrix_lookup",          # Section 4: Matrix as Lookup
            "first_order_model",      # Section 5: First Order Markov Model
            "second_order_model",     # Section 6: Second Order Model
            "second_order_skips",     # Section 7: Second Order with Skips
            "masking_attention",      # Section 8: Masking and Attention
            "embeddings",             # Section 9: Embeddings
            "positional_encoding",    # Section 10: Positional Encoding
            "multihead_attention",    # Section 11: Multi-head Attention
            "complete_architecture",  # Section 12: Complete Architecture
            "conclusion",             # Conclusion
        )
        for section_name in section_names:
            getattr(self, section_name)()

    def one_hot_encoding(self):
        """Section 1: encode the sentence "Find my files" as one-hot vectors.

        Shows the vocabulary, the word-to-number mapping, the resulting number
        sequence, one labelled column vector per word and finally the same
        vectors stacked as the rows of a matrix.  Ends by calling
        ``self.clear_scene()``.
        """
        section_title = Text("One-Hot Encoding", font_size=48, color=YELLOW)
        self.play(Write(section_title))
        self.wait(1)
        self.play(section_title.animate.to_edge(UP))

        # Show vocabulary
        vocab_text = Text("Vocabulary: files, find, my", font_size=24)
        vocab_text.to_edge(UP).shift(DOWN * 0.8)
        self.play(Write(vocab_text))

        # Show sentence
        sentence = Text("Find my files", font_size=32, color=GREEN)
        sentence.shift(UP * 2)
        self.play(Write(sentence))

        # Show word-to-number mapping
        mapping = VGroup(
            Text("files = 1", font_size=24),
            Text("find = 2", font_size=24),
            Text("my = 3", font_size=24)
        ).arrange(DOWN, aligned_edge=LEFT)
        mapping.shift(LEFT * 4 + UP * 0.5)
        self.play(Write(mapping))

        # Show number sequence
        number_seq = Text("[2, 3, 1]", font_size=28, color=ORANGE)
        number_seq.next_to(sentence, DOWN)
        self.play(Write(number_seq))

        self.wait(1)

        # Transform to one-hot
        one_hot_title = Text("One-Hot Representation:", font_size=24)
        one_hot_title.shift(DOWN * 0.5)
        self.play(Write(one_hot_title))

        # Create one-hot vectors (in sentence order: find, my, files)
        find_vector = self.create_vector([0, 1, 0], "find")
        my_vector = self.create_vector([0, 0, 1], "my")
        files_vector = self.create_vector([1, 0, 0], "files")

        vectors_group = VGroup(find_vector, my_vector, files_vector)
        vectors_group.arrange(RIGHT, buff=1).shift(DOWN * 2)

        self.play(Create(vectors_group))
        self.wait(2)

        # Show as matrix.  The vectors already fill the lower part of the
        # frame, and a matrix placed underneath them would land below the
        # bottom edge of the default 8-unit-tall frame.  The vectors and their
        # heading therefore make room first and the matrix takes their place.
        matrix_text = Text("As a matrix:", font_size=24)
        matrix_text.move_to(one_hot_title)

        matrix = Matrix([
            [0, 1, 0],
            [0, 0, 1],
            [1, 0, 0]
        ]).scale(0.8)
        matrix.next_to(matrix_text, DOWN)

        self.play(FadeOut(one_hot_title), FadeOut(vectors_group))
        self.play(Write(matrix_text))
        self.play(Create(matrix))

        self.wait(2)
        self.clear_scene()

    def create_vector(self, values, label):
        """Return a labelled column vector.

        Args:
            values: sequence of entries, one per row of the column.
            label: caption text placed above the column.

        Returns:
            ``VGroup(column, caption)`` with the column scaled to 0.6.
        """
        column = Matrix([[entry] for entry in values])
        column.scale(0.6)
        caption = Text(label, font_size=16).next_to(column, UP)
        return VGroup(column, caption)

    def dot_product(self):
        """Section 2: illustrate the dot product with two identical one-hot vectors.

        Ends by calling ``self.clear_scene()``.
        """
        heading = Text("Dot Product", font_size=48, color=YELLOW)
        self.play(Write(heading))
        self.wait(1)
        self.play(heading.animate.to_edge(UP))

        # One-line definition of the operation.
        rule_text = Text("Multiply corresponding elements, then add", font_size=24).shift(UP * 2.5)
        self.play(Write(rule_text))

        # Two copies of the same one-hot column, side by side.
        def make_column():
            return Matrix([[1], [0], [0]]).scale(0.8)

        left_vec = make_column().shift(LEFT * 3 + UP * 1)
        right_vec = make_column().shift(LEFT * 1 + UP * 1)

        # The operator sits halfway between the facing edges of the two vectors.
        operator = Text("•", font_size=32)
        operator.move_to(0.5 * (left_vec.get_right() + right_vec.get_left()))

        self.play(Create(left_vec), Create(right_vec), Write(operator))

        # Written-out calculation to the right of the second vector.
        worked_example = Text("1×1 + 0×0 + 0×0 = 1", font_size=24).next_to(right_vec, RIGHT, buff=1)
        self.play(Write(worked_example))

        # Takeaways, stacked under the example.
        takeaways = VGroup(
            Text("One-hot vector • itself = 1", font_size=20, color=GREEN),
            Text("Different one-hot vectors • = 0", font_size=20, color=RED),
        )
        takeaways.arrange(DOWN)
        takeaways.shift(DOWN * 1.5)
        self.play(Write(takeaways))

        self.wait(3)
        self.clear_scene()

    def matrix_multiplication(self):
        """Section 3: show a row-times-column product, then identity times a 3x2 matrix.

        Ends by calling ``self.clear_scene()``.
        """
        heading = Text("Matrix Multiplication", font_size=48, color=YELLOW)
        self.play(Write(heading))
        self.wait(1)
        self.play(heading.animate.to_edge(UP))

        # --- First example: (1x3) times (3x1) ---------------------------------
        row_vec = Matrix([[1, 0, 0]]).scale(0.7).shift(LEFT * 2)
        col_vec = Matrix([[2], [3], [1]]).scale(0.7)  # stays at the origin

        times_small = Text("×", font_size=32)
        times_small.move_to(0.5 * (row_vec.get_right() + col_vec.get_left()))

        self.play(Create(row_vec), Create(col_vec), Write(times_small))

        # Result, chained to the right of the column vector.
        equals_small = Text("=", font_size=32).next_to(col_vec, RIGHT, buff=0.5)
        product_small = Matrix([[2]]).scale(0.7).next_to(equals_small, RIGHT, buff=0.5)

        self.play(Write(equals_small), Create(product_small))

        # Shape rule.
        shape_rule = Text("Columns in A must equal rows in B", font_size=20, color=ORANGE)
        shape_rule.shift(DOWN * 1.5)
        self.play(Write(shape_rule))

        # Clear the first example before showing the larger one.
        self.wait(1)
        first_example = VGroup(
            row_vec, col_vec, times_small, equals_small, product_small, shape_rule
        )
        self.play(FadeOut(first_example))

        # --- Second example: (3x3 identity) times (3x2) -----------------------
        identity = Matrix([
            [1, 0, 0],
            [0, 1, 0],
            [0, 0, 1],
        ]).scale(0.6).shift(LEFT * 2.5)
        table = Matrix([
            [2, 5],
            [3, 6],
            [1, 4],
        ]).scale(0.6).shift(RIGHT * 0.5)

        times_large = Text("×", font_size=24)
        times_large.move_to(0.5 * (identity.get_right() + table.get_left()))

        self.play(Create(identity), Create(table), Write(times_large))

        equals_large = Text("=", font_size=24).next_to(table, RIGHT, buff=0.3)
        product_large = Matrix([
            [2, 5],
            [3, 6],
            [1, 4],
        ]).scale(0.6).next_to(equals_large, RIGHT, buff=0.3)

        self.play(Write(equals_large), Create(product_large))

        self.wait(3)
        self.clear_scene()

    def matrix_lookup(self):
        """Section 4: a one-hot matrix times a table selects rows of the table.

        Ends by calling ``self.clear_scene()``.
        """
        heading = Text("Matrix Multiplication as Table Lookup", font_size=40, color=YELLOW)
        self.play(Write(heading))
        self.wait(1)
        self.play(heading.animate.to_edge(UP))

        # Selector: each row has a single 1 marking the table row to pick.
        selector = Matrix([
            [1, 0, 0, 0],
            [0, 0, 0, 1],
            [0, 0, 1, 0],
        ]).scale(0.6).shift(LEFT * 3)

        # Table being indexed.
        table = Matrix([
            ["show", "me"],
            ["my", "files"],
            ["please", "find"],
            ["photos", "dirs"],
        ]).scale(0.5).shift(RIGHT * 1)

        times_sign = Text("×", font_size=24)
        times_sign.move_to(0.5 * (selector.get_right() + table.get_left()))

        self.play(Create(selector), Write(times_sign), Create(table))

        # Selected rows, chained to the right of the table.
        equals_sign = Text("=", font_size=24).next_to(table, RIGHT, buff=0.3)
        picked_rows = Matrix([
            ["show", "me"],
            ["photos", "dirs"],
            ["please", "find"],
        ]).scale(0.5).next_to(equals_sign, RIGHT, buff=0.3)

        self.play(Write(equals_sign), Create(picked_rows))

        # Two closing captions, the second directly under the first.
        lookup_caption = Text(
            "One-hot vectors 'pull out' specific rows!", font_size=24, color=GREEN
        ).shift(DOWN * 2.5)
        self.play(Write(lookup_caption))

        core_caption = Text(
            "This is the core of how transformers work", font_size=20, color=RED
        ).next_to(lookup_caption, DOWN)
        self.play(Write(core_caption))

        self.wait(3)
        self.clear_scene()

    def first_order_model(self):
        """Section 5: first order sequence model shown as a transition matrix.

        Displays three example sentences and the transition matrix over their
        vocabulary, with every row and column labelled by its word.  Each row
        holds the next-word probabilities for the word labelling it.  Ends by
        calling ``self.clear_scene()``.
        """
        section_title = Text("First Order Sequence Model", font_size=40, color=YELLOW)
        self.play(Write(section_title))
        self.wait(1)
        self.play(section_title.animate.to_edge(UP))

        # Show example sentences
        sentences = VGroup(
            Text("Show me my directories please", font_size=20),
            Text("Show me my files please", font_size=20),
            Text("Show me my photos please", font_size=20)
        ).arrange(DOWN)
        sentences.shift(UP * 2)
        self.play(Write(sentences))

        # Show transition matrix concept (placed high enough to leave room for
        # the column labels above the matrix)
        matrix_title = Text("Transition Matrix (Markov Chain):", font_size=24)
        matrix_title.shift(UP * 0.9)
        self.play(Write(matrix_title))

        # Every word of the example sentences gets a row and a column, so each
        # non-terminal row is a proper probability distribution (sums to 1).
        vocabulary = ["show", "me", "my", "dirs", "files", "photos", "please"]
        transition_matrix = Matrix([
            [0, 1, 0, 0, 0, 0, 0],        # show -> me
            [0, 0, 1, 0, 0, 0, 0],        # me -> my
            [0, 0, 0, 0.2, 0.3, 0.5, 0],  # my -> dirs / files / photos
            [0, 0, 0, 0, 0, 0, 1],        # dirs -> please
            [0, 0, 0, 0, 0, 0, 1],        # files -> please
            [0, 0, 0, 0, 0, 0, 1],        # photos -> please
            [0, 0, 0, 0, 0, 0, 0]         # please ends the sentence
        ]).scale(0.5)
        transition_matrix.shift(DOWN * 1.1)
        self.play(Create(transition_matrix))

        # Label rows and columns.  Each label is aligned with the row/column
        # it names instead of relying on a hand-tuned spacing.
        row_labels = VGroup()
        for word, row in zip(vocabulary, transition_matrix.get_rows()):
            label = Text(word, font_size=12)
            label.next_to(transition_matrix, LEFT, buff=0.2)
            label.set_y(row.get_y())
            row_labels.add(label)

        column_labels = VGroup()
        for word, column in zip(vocabulary, transition_matrix.get_columns()):
            label = Text(word, font_size=12)
            label.next_to(transition_matrix, UP, buff=0.1)
            label.set_x(column.get_x())
            column_labels.add(label)

        self.play(Write(row_labels), Write(column_labels))

        # Key insight
        insight = Text("Each row shows probability of next word",
                      font_size=20, color=GREEN)
        insight.shift(DOWN * 3)
        self.play(Write(insight))

        self.wait(3)
        self.clear_scene()

    def second_order_model(self):
        """Section 6: motivate conditioning on two words instead of one.

        Writes a stack of captions from top to bottom, one per ``play`` call,
        then calls ``self.clear_scene()``.
        """
        heading = Text("Second Order Model", font_size=40, color=YELLOW)
        self.play(Write(heading))
        self.wait(1)
        self.play(heading.animate.to_edge(UP))

        # Two sentences that share a prefix but continue differently.
        ambiguous_sentences = VGroup(
            Text("Check whether the battery ran down please", font_size=18),
            Text("Check whether the program ran please", font_size=18),
        )
        ambiguous_sentences.arrange(DOWN).shift(UP * 2.5)

        # Observation that resolves the ambiguity.
        two_word_note = Text(
            "Looking at TWO words gives more context!", font_size=24, color=GREEN
        ).shift(UP * 1.5)

        # The word pairs and the word each one predicts.
        word_pairs = VGroup(
            Text("'battery ran' → 'down'", font_size=20, color=BLUE),
            Text("'program ran' → 'please'", font_size=20, color=BLUE),
        )
        word_pairs.arrange(DOWN).shift(UP * 0.5)

        # Cost and payoff.
        size_note = Text(
            "Second order matrix has N² rows!", font_size=20, color=RED
        ).shift(DOWN * 0.5)
        payoff_note = Text(
            "More confidence: more 1s, fewer fractions", font_size=20, color=GREEN
        ).shift(DOWN * 1.5)

        for caption in (ambiguous_sentences, two_word_note, word_pairs, size_note, payoff_note):
            self.play(Write(caption))

        self.wait(3)
        self.clear_scene()

    def second_order_skips(self):
        """Section 7: second order model with skips and the voting idea.

        Writes a stack of captions from top to bottom, one per ``play`` call,
        then calls ``self.clear_scene()``.
        """
        heading = Text("Second Order with Skips", font_size=40, color=YELLOW)
        self.play(Write(heading))
        self.wait(1)
        self.play(heading.animate.to_edge(UP))

        # Long-range dependency example and why a plain higher-order model fails.
        example_sentence = Text(
            "Check the program log and find out whether it ran please", font_size=16
        ).shift(UP * 2.5)
        blow_up_note = Text(
            "Need to look back 8 words! N⁸ rows = impossible", font_size=20, color=RED
        ).shift(UP * 2)

        # Proposed way out.
        fix_heading = Text("Solution: Skip connections", font_size=24, color=GREEN).shift(UP * 1.2)
        fix_detail = Text(
            "Pair most recent word with each previous word", font_size=18
        ).shift(UP * 0.8)

        # The informative pairs and what they vote for.
        informative_pairs = VGroup(
            Text("'program, ran' → votes for 'please'", font_size=16, color=BLUE),
            Text("'battery, ran' → votes for 'down'", font_size=16, color=BLUE),
        )
        informative_pairs.arrange(DOWN).shift(UP * 0.2)

        # Votes and how they are combined.
        votes_note = Text(
            "Features become VOTES, not probabilities", font_size=20, color=ORANGE
        ).shift(DOWN * 0.5)
        sum_note = Text(
            "Sum all votes to make final prediction", font_size=18
        ).shift(DOWN * 1.2)

        captions = (
            example_sentence,
            blow_up_note,
            fix_heading,
            fix_detail,
            informative_pairs,
            votes_note,
            sum_note,
        )
        for caption in captions:
            self.play(Write(caption))

        self.wait(3)
        self.clear_scene()

    def masking_attention(self):
        """Section 8: element-wise masking of a feature vector as attention.

        Shows ``feature_vector ⊙ mask_vector = result_vector`` laid out as a
        single horizontal equation with a label above each operand, followed
        by two closing captions.  Ends by calling ``self.clear_scene()``.
        """
        section_title = Text("Masking = Attention!", font_size=48, color=YELLOW)
        self.play(Write(section_title))
        self.wait(1)
        self.play(section_title.animate.to_edge(UP))

        # Show the problem with uninformative features
        problem = Text("Most word pairs give uninformative votes (0.5)",
                      font_size=20, color=RED)
        problem.shift(UP * 2.5)
        self.play(Write(problem))

        # Lay the whole equation out up front so that all three vectors share
        # one baseline and the operators sit exactly between them.  Positioning
        # each piece relative to the previous label instead would leave the
        # mask vector hanging lower than the feature vector.
        feature_vector = Matrix([[1], [1], [1], [1], [1]]).scale(0.6)
        multiply_symbol = Text("⊙", font_size=24)
        mask_vector = Matrix([[0], [1], [0], [1], [0]]).scale(0.6)
        equals = Text("=", font_size=24)
        result_vector = Matrix([[0], [1], [0], [1], [0]]).scale(0.6)

        equation = VGroup(feature_vector, multiply_symbol, mask_vector, equals, result_vector)
        equation.arrange(RIGHT, buff=1.0)
        equation.move_to(UP * 0.3)

        feature_text = Text("Feature Vector:", font_size=20)
        feature_text.next_to(feature_vector, UP)
        mask_text = Text("Attention Mask:", font_size=20)
        mask_text.next_to(mask_vector, UP)

        # Show feature vector
        self.play(Write(feature_text))
        self.play(Create(feature_vector))

        # Show mask
        self.play(Write(mask_text))
        self.play(Create(mask_vector))

        # Show element-wise multiplication
        self.play(Write(multiply_symbol))

        # Show result
        self.play(Write(equals), Create(result_vector))

        # The big reveal
        reveal = Text("This selective masking IS attention!",
                     font_size=24, color=GREEN)
        reveal.shift(DOWN * 1.5)
        self.play(Write(reveal))

        # Connection to paper
        paper_connection = Text("QK^T in the attention equation = mask lookup",
                               font_size=18, color=BLUE)
        paper_connection.shift(DOWN * 2.2)
        self.play(Write(paper_connection))

        self.wait(3)
        self.clear_scene()

    def embeddings(self):
        """Section 9: project one-hot vectors into a low-dimensional embedding space.

        Shows a grid of dots standing in for the one-hot space, an arrow for
        the projection, a 2-D axes with four labelled word points and a
        closing caption.  Ends by calling ``self.clear_scene()``.
        """
        section_title = Text("Embeddings", font_size=48, color=YELLOW)
        self.play(Write(section_title))
        self.wait(1)
        self.play(section_title.animate.to_edge(UP))

        # Show the scaling problem
        problem = Text("50,000 words → 100+ trillion matrix elements!",
                      font_size=24, color=RED)
        problem.shift(UP * 2.5)
        self.play(Write(problem))

        # Show high-dimensional space
        high_dim_text = Text("High-dimensional one-hot space:", font_size=20)
        high_dim_text.shift(UP * 1.8)
        self.play(Write(high_dim_text))

        # Create an 8x3 grid of dots representing high-dim space
        high_dim_dots = VGroup()
        for i in range(8):
            for j in range(3):
                dot = Dot(radius=0.05, color=WHITE)
                dot.move_to([i*0.5 - 2, j*0.5 + 0.5, 0])
                high_dim_dots.add(dot)
        self.play(Create(high_dim_dots))

        # Show projection arrow, between the dot grid and the axes
        arrow = Arrow(start=UP*0.3, end=DOWN*0.5, color=YELLOW)
        self.play(Create(arrow))

        projection_text = Text("Project to low-dimensional space", font_size=18)
        projection_text.next_to(arrow, RIGHT)
        self.play(Write(projection_text))

        # Show 2D embedding space.  The axes are kept compact and high enough
        # that the closing caption still fits inside the default 8-unit-tall
        # frame underneath them.
        axes = Axes(x_range=[-3, 3], y_range=[-2, 2],
                   x_length=4, y_length=2.4)
        axes.shift(DOWN * 2.0)
        self.play(Create(axes))

        # Add word points in 2D space
        word_points = VGroup(
            Dot(axes.coords_to_point(-1, 1), color=BLUE).scale(1.5),
            Dot(axes.coords_to_point(1, 1), color=BLUE).scale(1.5),
            Dot(axes.coords_to_point(0, -1), color=GREEN).scale(1.5),
            Dot(axes.coords_to_point(-2, 0), color=RED).scale(1.5)
        )

        word_labels = VGroup(
            Text("battery", font_size=12).next_to(word_points[0], UP),
            Text("program", font_size=12).next_to(word_points[1], UP),
            Text("ran", font_size=12).next_to(word_points[2], DOWN),
            Text("check", font_size=12).next_to(word_points[3], LEFT)
        )

        self.play(Create(word_points), Write(word_labels))

        # Show clustering benefit, anchored to the bottom edge so it cannot
        # end up outside the frame.
        benefit = Text("Similar words cluster together!",
                      font_size=20, color=GREEN)
        benefit.to_edge(DOWN, buff=0.3)
        self.play(Write(benefit))

        self.wait(3)
        self.clear_scene()

    def positional_encoding(self):
        """Animate the "Positional Encoding" section.

        Plays, in order: the section heading (then moves it to the top edge),
        the problem statement, the proposed solution, a small pair of axes, a
        circle whose radius is one axis unit, six labelled dots spaced
        ``PI / 3`` apart on that circle, two explanatory lines and a closing
        note. Finishes by waiting three seconds and calling
        ``self.clear_scene()``.
        """
        section_title = Text("Positional Encoding", font_size=48, color=YELLOW)
        self.play(Write(section_title))
        self.wait(1)
        self.play(section_title.animate.to_edge(UP))

        # Show the problem
        problem = Text("Embeddings lose word order information!",
                       font_size=24, color=RED)
        problem.shift(UP * 2.5)
        self.play(Write(problem))

        # Show solution concept
        solution = Text("Add circular 'wiggles' based on position",
                        font_size=20, color=GREEN)
        solution.shift(UP * 2)
        self.play(Write(solution))

        # Create circular pattern visualization
        axes = Axes(x_range=[-2, 2], y_range=[-2, 2],
                    x_length=3, y_length=3)
        axes.shift(LEFT * 2 + DOWN * 0.5)
        self.play(Create(axes))

        # The dots below are positioned in axis coordinates, so the circle has
        # to be sized in axis units too. These axes are 3 long over a range of
        # 4, i.e. one axis unit is 0.75 scene units; a hard-coded scene radius
        # of 1 would leave every dot floating inside the circle. Both axes use
        # the same length and range, so one radius fits x and y alike.
        origin = axes.get_origin()
        unit_radius = float(np.linalg.norm(axes.coords_to_point(1, 0) - origin))
        circle = Circle(radius=unit_radius, color=BLUE)
        circle.move_to(origin)
        self.play(Create(circle))

        # Show points on the circle for different positions: one every PI / 3.
        positions = []
        for i in range(6):
            angle = i * PI / 3
            point = Dot(axes.coords_to_point(np.cos(angle), np.sin(angle)),
                        color=YELLOW)
            positions.append(point)

        position_group = VGroup(*positions)
        self.play(Create(position_group))

        # Label positions: the first three labels go above their dot, the
        # remaining three below.
        pos_labels = VGroup()
        for i, pos in enumerate(positions):
            label = Text(f"pos{i}", font_size=10)
            label.next_to(pos, UP if i < 3 else DOWN, buff=0.1)
            pos_labels.add(label)
        self.play(Write(pos_labels))

        # Show the math concept
        math_text = VGroup(
            Text("Different frequencies for different dimensions", font_size=16),
            Text("sin/cos functions create smooth patterns", font_size=16)
        ).arrange(DOWN)
        math_text.shift(RIGHT * 2.5 + DOWN * 0.5)
        self.play(Write(math_text))

        # Show addition to embeddings
        addition_text = Text("Added to word embeddings", font_size=18, color=GREEN)
        addition_text.shift(DOWN * 2.5)
        self.play(Write(addition_text))

        self.wait(3)
        self.clear_scene()

    def multihead_attention(self):
        """Animate the "Multi-Head Attention" section.

        Plays, in order: the section heading (then moves it to the top edge),
        the single-softmax problem and an example weight vector, the
        multi-head solution, three "Head N" titles each with a column of
        attention weights directly beneath it, the concatenate-and-project
        step, the reduced-dimension note and the dimensions quoted from the
        paper. Finishes by waiting three seconds and calling
        ``self.clear_scene()``.
        """
        section_title = Text("Multi-Head Attention", font_size=40, color=YELLOW)
        self.play(Write(section_title))
        self.wait(1)
        self.play(section_title.animate.to_edge(UP))

        # Show the softmax problem
        problem = Text("Softmax focuses on single element", font_size=20, color=RED)
        problem.shift(UP * 2.5)
        self.play(Write(problem))

        single_att = Text("Single attention: [0.1, 0.8, 0.1]", font_size=16)
        single_att.shift(UP * 2)
        self.play(Write(single_att))

        # Show solution
        solution = Text("Solution: Multiple attention heads!", font_size=20, color=GREEN)
        solution.shift(UP * 1.4)
        self.play(Write(solution))

        # Create multiple attention heads visualization. The titles are kept
        # in a plain list as well so each weight column can be anchored to
        # its own title below.
        titles = [
            Text("Head 1", font_size=14),
            Text("Head 2", font_size=14),
            Text("Head 3", font_size=14),
        ]
        head_titles = VGroup(*titles).arrange(RIGHT, buff=2)
        head_titles.shift(UP * 0.5)
        self.play(Write(head_titles))

        # Show different attention patterns. Each column is placed under its
        # own title instead of arranging the columns as a second, separately
        # spaced row: two rows with different buffs only line up when the
        # element widths happen to compensate, which leaves the outer columns
        # off-centre under "Head 1" and "Head 3".
        head_weights = (
            [["0.7"], ["0.2"], ["0.1"]],
            [["0.1"], ["0.7"], ["0.2"]],
            [["0.2"], ["0.1"], ["0.7"]],
        )
        att_patterns = VGroup(*[
            Matrix(weights).scale(0.4).next_to(title, DOWN)
            for title, weights in zip(titles, head_weights)
        ])
        self.play(Create(att_patterns))

        # Show concatenation
        concat_text = Text("Concatenate and project back", font_size=16)
        concat_text.shift(DOWN * 1)
        self.play(Write(concat_text))

        # Show efficiency trick
        efficiency = Text("Use smaller dimensions (d_k, d_v) for efficiency",
                          font_size=16, color=ORANGE)
        efficiency.shift(DOWN * 1.7)
        self.play(Write(efficiency))

        # Show dimensions
        dims = VGroup(
            Text("d_model = 512 (paper)", font_size=14),
            Text("d_k = d_v = 64 (paper)", font_size=14),
            Text("h = 8 heads (paper)", font_size=14)
        ).arrange(DOWN, aligned_edge=LEFT)
        dims.shift(DOWN * 2.8 + LEFT * 3)
        self.play(Write(dims))

        self.wait(3)
        self.clear_scene()

    def complete_architecture(self):
        """Animate the "Complete Transformer Architecture" overview.

        Draws a three-block encoder column on the left and a three-block
        decoder column on the right, adds three yellow arrows between them
        with a "Cross-Attention" label, captions each column with its
        sequence, lists the key components, then waits three seconds and
        calls ``self.clear_scene()``.
        """
        heading = Text("Complete Transformer Architecture", font_size=36, color=YELLOW)
        self.play(Write(heading))
        self.wait(1)
        self.play(heading.animate.to_edge(UP))

        def stacked_blocks(side, tint):
            # Three rectangles, four units out along `side`, starting at
            # y = 1 and stepping down 0.9 per row.
            rows = [
                Rectangle(width=1.5, height=0.8, color=tint).shift(
                    side * 4 + UP * (1 - row * 0.9)
                )
                for row in range(3)
            ]
            return VGroup(*rows)

        # Encoder column (left).
        encoder_caption = Text("Encoder", font_size=20, color=BLUE)
        encoder_caption.shift(LEFT * 4 + UP * 2)
        self.play(Write(encoder_caption))

        encoder_blocks = stacked_blocks(LEFT, BLUE)
        self.play(Create(encoder_blocks))

        # Decoder column (right).
        decoder_caption = Text("Decoder", font_size=20, color=GREEN)
        decoder_caption.shift(RIGHT * 4 + UP * 2)
        self.play(Write(decoder_caption))

        decoder_blocks = stacked_blocks(RIGHT, GREEN)
        self.play(Create(decoder_blocks))

        # One arrow per row: it starts at the right edge of the top encoder
        # block moved down 0.9 per row, and ends at the left edge of the
        # decoder block in that row.
        links = VGroup(*[
            Arrow(
                start=encoder_blocks[0].get_right() + UP * (0 - row * 0.9),
                end=decoder_blocks[row].get_left(),
                color=YELLOW,
            )
            for row in range(3)
        ])
        self.play(Create(links))

        link_caption = Text("Cross-Attention", font_size=16, color=YELLOW)
        link_caption.shift(UP * 0.5)
        self.play(Write(link_caption))

        # Captions under each column naming the data it consumes.
        source_caption = Text("Source\nSequence", font_size=14)
        source_caption.next_to(encoder_blocks, DOWN)
        self.play(Write(source_caption))

        target_caption = Text("Target\nSequence", font_size=14)
        target_caption.next_to(decoder_blocks, DOWN)
        self.play(Write(target_caption))

        # Bullet list of the building blocks, left-aligned.
        bullet_lines = [
            "• Multi-head attention",
            "• Feed forward networks",
            "• Skip connections",
            "• Layer normalization",
            "• Positional encoding",
        ]
        bullets = VGroup(*[Text(line, font_size=12) for line in bullet_lines])
        bullets.arrange(DOWN, aligned_edge=LEFT)
        bullets.shift(DOWN * 2.5)
        self.play(Write(bullets))

        self.wait(3)
        self.clear_scene()

    def conclusion(self):
        """Play the closing section: title, five takeaways and a final message.

        The title is written at the centre and then moved to the top edge,
        the takeaways are stacked directly beneath it, and the final message
        is pinned to the bottom edge, so the three elements cannot overlap
        and all stay inside the frame. Unlike the other sections this one
        does not clear the scene at the end.
        """
        # Final title
        title = Text("Transformers Demystified!", font_size=56, color=GOLD)
        self.play(Write(title))
        self.wait(1)
        # Move the title out of the centre, as the other sections do with
        # their headings; left in place it is drawn over by the first
        # takeaway line.
        self.play(title.animate.to_edge(UP))

        # Key takeaways
        takeaways = VGroup(
            Text("✓ Attention is selective masking", font_size=20, color=GREEN),
            Text("✓ Matrix multiplication as lookup", font_size=20, color=GREEN),
            Text("✓ Embeddings enable scaling", font_size=20, color=GREEN),
            Text("✓ Multiple heads provide redundancy", font_size=20, color=GREEN),
            Text("✓ Skip connections ensure robustness", font_size=20, color=GREEN)
        ).arrange(DOWN, buff=0.4)
        # Anchor to the title rather than to a fixed offset so the gap holds
        # whatever the rendered text heights turn out to be.
        takeaways.next_to(title, DOWN, buff=0.6)

        self.play(Write(takeaways))
        self.wait(2)

        # Final message
        final_msg = Text("From basic math to state-of-the-art AI!",
                         font_size=24, color=BLUE)
        # A fixed shift of DOWN * 4 would centre the text on the bottom
        # border of Manim's default 8-unit-tall frame and clip its lower
        # half; to_edge keeps the whole line inside the frame.
        final_msg.to_edge(DOWN)
        self.play(Write(final_msg))

        self.wait(3)

    def clear_scene(self):
        """Fade out everything currently on screen, then pause half a second.

        If the scene holds no mobjects the fade is skipped, because Manim's
        ``FadeOut`` raises ``ValueError`` when called without any mobject.
        The pause is kept in that case too so section pacing is unchanged.
        """
        # Copy the list so the arguments handed to FadeOut do not alias the
        # scene's own mobject list.
        remaining = list(self.mobjects)
        if remaining:
            self.play(FadeOut(*remaining))
        self.wait(0.5)

# To render: manim -pql transformer_animation.py TransformerTutorial

                                                                                                           

ValueError: latex error converting to dvi. See log output above or the log file: media/Tex/66e1bc57a83e0f07.log