In [1]:
import torch
# Three token embeddings: one row per token, two features per token.
hidden_states = torch.tensor(
    [[-0.083, 0.147],
     [0.029, 0.008],
     [-0.204, 0.132]],
)
# Size the layer from the data so the two cannot drift apart.
d_hidden = hidden_states.shape[-1]
w_query = torch.nn.Linear(in_features=d_hidden, out_features=d_hidden, bias=False)

# Projection matrix written in "hidden_states @ matrix" orientation.
query_matrix = torch.tensor(
    [[-0.381, -0.354],
     [ 0.407, -0.601]])
# nn.Linear evaluates x @ weight.T, so the transpose is what has to be stored.
# Copying in place under no_grad keeps the layer's own Parameter (contiguous,
# still requiring grad) instead of swapping its storage through `.data`.
with torch.no_grad():
    w_query.weight.copy_(query_matrix.T)
w_query(hidden_states)

tensor([[ 0.0915, -0.0590],
        [-0.0078, -0.0151],
        [ 0.1314, -0.0071]], grad_fn=<MmBackward0>)

## Example

Let's consider a simple example with 3 tokens, and let's work with an embedding dimension of 3. That means each token turns into a 3-dimensional vector.



In [25]:
# Toy prompt: each character is later converted with int() and used as a key into `tokenizer`.
prompt = "012"
# Disabled scratch code: it printed random integer-valued 1x3 float32 tensors for ids 3-9.
# for i in range(3,10):
#     print(f"{i}: torch." )
#     print(torch.randint(-3,3,(1,3), dtype=torch.float32))

In [26]:
# Toy embedding table: token id (0-9) -> 3-dimensional float32 vector.
# The position of a row in the list below is its token id.
# NOTE(review): ids 1 and 9 map to the same vector -- confirm that is intended.
_embedding_rows = [
    [ 1.,  0.,  0.],
    [-2.,  1.,  0.],
    [-1.,  1.,  1.],
    [ 2.,  2.,  2.],
    [-1.,  0.,  0.],
    [-2., -1.,  0.],
    [-1.,  0., -1.],
    [ 2., -1., -3.],
    [-1.,  0., -3.],
    [-2.,  1.,  0.],
]
tokenizer = {
    token_id: torch.tensor(row, dtype=torch.float32)
    for token_id, row in enumerate(_embedding_rows)
}

In [27]:
# Look up one embedding per prompt character (each character is a token id),
# then stack the vectors into a (tokens x features) matrix.
embedding_list = [tokenizer[int(input_id)] for input_id in prompt]
hidden_states = torch.stack(embedding_list)
hidden_states

tensor([[ 1.,  0.,  0.],
        [-2.,  1.,  0.],
        [-1.,  1.,  1.]])

In [30]:
import torch
# Embeddings of the three prompt tokens: one row per token, three features each.
hidden_states = torch.tensor(
    [[1, 0, 0],
     [-2, 1, 0],
     [-1, 1, 1]],
    dtype=torch.float32,
)
# Projection matrix in "hidden_states @ weights" orientation; the same values
# are used for the query, key and value layers in this toy example.
weights = torch.tensor(
    [[2, 1, 1],
     [4, -6, 0],
     [-2, 7, 2]],
    dtype=torch.float32,
    requires_grad=True)

# The embeddings are 3-dimensional, so the layers must be 3 -> 3. Deriving the
# size from the data keeps in_features/out_features truthful.
d_hidden = hidden_states.shape[-1]

w_query = torch.nn.Linear(in_features=d_hidden, out_features=d_hidden, bias=False)
w_key = torch.nn.Linear(in_features=d_hidden, out_features=d_hidden, bias=False)
w_value = torch.nn.Linear(in_features=d_hidden, out_features=d_hidden, bias=False)

# nn.Linear evaluates x @ weight.T, so storing weights.T yields x @ weights.
# copy_ gives every layer its own storage (no aliasing of `weights` or of each
# other) and raises if the layer shape does not match the matrix.
with torch.no_grad():
    for projection in (w_query, w_key, w_value):
        projection.weight.copy_(weights.T)

queries = w_query(hidden_states)
print(f"queries are \n{queries}")

keys = w_key(hidden_states)
print(f"keys.T are \n{keys.T}")

# Raw (unscaled, pre-softmax) attention scores: entry [i, j] is query_i . key_j.
attention_map = torch.matmul(queries, keys.T)
print(f"torch.matmul(queries, keys.T) is \n{attention_map}")

queries are 
tensor([[ 2.,  1.,  1.],
        [ 0., -8., -2.],
        [ 0.,  0.,  1.]], grad_fn=<MmBackward0>)
keys.T are 
tensor([[ 2.,  0.,  0.],
        [ 1., -8.,  0.],
        [ 1., -2.,  1.]], grad_fn=<PermuteBackward0>)
torch.matmul(queries, keys.T) is 
tensor([[  6., -10.,   1.],
        [-10.,  68.,  -2.],
        [  1.,  -2.,   1.]], grad_fn=<MmBackward0>)


## Visualising things with Manim
(Claude wrote this)

In [4]:
from manim import *
import os
# Put the TeX binaries on PATH so the LaTeX calls made by manim can find them.
# NOTE(review): this looks like a macOS TeX install location -- adjust on other systems.
TEX_BIN_DIR = '/Library/TeX/texbin'
_existing_path = os.environ.get('PATH', '')
# Prepend only once, so re-running this cell does not keep growing PATH.
if TEX_BIN_DIR not in _existing_path.split(os.pathsep):
    os.environ['PATH'] = os.pathsep.join(filter(None, [TEX_BIN_DIR, _existing_path]))

In [5]:
%%manim -qm AttentionMatMul

from manim import *
import numpy as np

class AttentionMatMul(Scene):
    """Animate the query projection and the raw attention scores of the 3-token example.

    The scene shows ``queries = hidden_states @ weights`` followed by
    ``attention_map = queries @ keys.T`` (keys equal queries here because the
    torch cell loads the same weights into both layers).

    Uses ``MathTex`` and ``Matrix``. The recorded run of this cell ended in a
    LaTeX-to-DVI conversion error, so a working LaTeX install appears to be
    required.
    """

    def construct(self):
        # Define the matrices
        hidden_states = np.array([
            [1, 0, 0],
            [-2, 1, 0],
            [-1, 1, 1]
        ], dtype=np.float32)

        weights = np.array([
            [2, 1, 1],
            [4, -6, 0],
            [-2, 7, 2]
        ], dtype=np.float32)

        # Compute results.
        # The torch cell stores `weights.T` in nn.Linear, and nn.Linear applies
        # x @ weight.T, so the effective projection is hidden_states @ weights.
        queries = hidden_states @ weights
        keys = queries  # the key layer holds the same weights as the query layer
        attention_map = queries @ keys.T

        def create_matrix(data, name, color=WHITE):
            """Return VGroup(Matrix, label); index 0 is the Matrix, index 1 the label."""
            matrix = Matrix(data.astype(int), h_buff=1.5, v_buff=1.5)
            matrix.set_color(color)
            label = Text(name, font_size=24).next_to(matrix, UP)
            return VGroup(matrix, label)

        # Create all matrices
        hidden_matrix = create_matrix(hidden_states, "hidden_states", BLUE)
        weights_matrix = create_matrix(weights, "weights", GREEN)
        queries_matrix = create_matrix(queries, "queries", YELLOW)
        keys_T_matrix = create_matrix(keys.T, "keys.T", ORANGE)
        attention_matrix = create_matrix(attention_map, "attention_map", RED)

        # Title
        title = Text("Attention Mechanism Matrix Multiplication", font_size=36)
        title.to_edge(UP)
        self.play(Write(title))
        self.wait()

        # Step 1: Show initial matrices
        step1_text = Text("Step 1: Initial Matrices", font_size=28, color=BLUE_C)
        step1_text.next_to(title, DOWN)
        self.play(Write(step1_text))

        hidden_matrix.move_to(LEFT * 3)
        weights_matrix.move_to(RIGHT * 3)

        self.play(
            FadeIn(hidden_matrix),
            FadeIn(weights_matrix)
        )
        self.wait(2)

        # Step 2: Compute queries
        self.play(FadeOut(step1_text))
        step2_text = Text("Step 2: Compute queries = hidden_states @ weights", font_size=28, color=GREEN_C)
        step2_text.next_to(title, DOWN)
        self.play(Write(step2_text))

        # Show multiplication
        mult_sign = MathTex("@", font_size=48).move_to(ORIGIN)
        self.play(
            hidden_matrix.animate.move_to(LEFT * 2),
            weights_matrix.animate.move_to(RIGHT * 2),
            Write(mult_sign)
        )
        self.wait()

        # Show result. The "=" is created at its final position so that it is
        # driven by a single animation (the transform) rather than two at once.
        equals_sign = MathTex("=", font_size=48).move_to(RIGHT * 0.5)
        self.play(
            ReplacementTransform(mult_sign, equals_sign),
            hidden_matrix.animate.scale(0.7).move_to(LEFT * 4),
            weights_matrix.animate.scale(0.7).move_to(LEFT * 1.5)
        )

        queries_matrix.move_to(RIGHT * 3)
        self.play(FadeIn(queries_matrix))
        self.wait(2)

        # Highlight computation for one element (row 0, column 0 of queries)
        first_query_entry = queries_matrix[0].get_rows()[0][0]
        highlight_rect = SurroundingRectangle(first_query_entry, color=YELLOW, buff=0.1)
        self.play(Create(highlight_rect))

        # queries[0,0] = hidden_states[0] . weights[:, 0] = 1*2 + 0*4 + 0*(-2) = 2
        calc_text = MathTex("queries[0,0] = 1 \\cdot 2 + 0 \\cdot 4 + 0 \\cdot (-2) = 2", font_size=24)
        calc_text.to_edge(DOWN)
        self.play(Write(calc_text))
        self.wait(2)
        self.play(FadeOut(highlight_rect), FadeOut(calc_text))

        # Step 3: Compute attention map
        self.play(
            FadeOut(hidden_matrix),
            FadeOut(weights_matrix),
            FadeOut(equals_sign),
            FadeOut(step2_text)
        )

        step3_text = Text("Step 3: Compute attention_map = queries @ keys.T", font_size=28, color=RED_C)
        step3_text.next_to(title, DOWN)
        self.play(Write(step3_text))

        # Move queries to left and create keys.T
        self.play(queries_matrix.animate.move_to(LEFT * 3))
        keys_T_matrix.move_to(RIGHT * 3)
        self.play(FadeIn(keys_T_matrix))

        # Show multiplication
        mult_sign2 = MathTex("@", font_size=48).move_to(ORIGIN)
        self.play(Write(mult_sign2))
        self.wait()

        # Transform to attention map
        self.play(
            FadeOut(queries_matrix),
            FadeOut(keys_T_matrix),
            FadeOut(mult_sign2),
            FadeIn(attention_matrix.move_to(ORIGIN))
        )
        self.wait()

        # Final display with all results
        self.play(FadeOut(step3_text))
        final_text = Text("Final Attention Map", font_size=32, color=GOLD)
        final_text.next_to(title, DOWN)
        self.play(Write(final_text))

        # Show the attention pattern
        self.play(attention_matrix.animate.scale(1.2))

        # Highlight diagonal elements (self-attention). Entries are addressed
        # through get_rows(): plain integer indexing on the Matrix would also
        # walk over its bracket submobjects.
        attention_rows = attention_matrix[0].get_rows()
        for i in range(attention_map.shape[0]):
            rect = SurroundingRectangle(attention_rows[i][i], color=GOLD, buff=0.1)
            self.play(Create(rect), run_time=0.5)
            self.play(FadeOut(rect), run_time=0.3)

        self.wait(2)

        # Add interpretation text
        interp_text = Text("Higher values indicate stronger attention", font_size=24, color=GRAY)
        interp_text.to_edge(DOWN)
        self.play(Write(interp_text))
        self.wait(3)

        # Fade out everything
        self.play(*[FadeOut(mob) for mob in self.mobjects])

# To render this animation in Jupyter:
# Use %%manim -qm -v WARNING AttentionMatMul at the top of the cell
# Or save to a file and run: manim -pql attention_matmul.py AttentionMatMul

ValueError: latex error converting to dvi. See log output above or the log file: media/Tex/6ecf9f51170c1a70.log

In [6]:
%%manim -qm -v WARNING AttentionMatMulNoLatex

from manim import *
import numpy as np

class AttentionMatMulNoLatex(Scene):
    """Variant of the attention animation that uses ``Text`` for symbols and formulas.

    The scene shows ``queries = hidden_states @ weights`` followed by
    ``attention_map = queries @ keys.T`` (keys equal queries here because the
    torch cell loads the same weights into both layers).

    NOTE(review): the recorded run of this cell still failed with a
    LaTeX-to-DVI error, presumably because ``Matrix`` itself renders through
    LaTeX -- confirm before relying on this variant without a TeX install.
    """

    def construct(self):
        # Define the matrices
        hidden_states = np.array([
            [1, 0, 0],
            [-2, 1, 0],
            [-1, 1, 1]
        ], dtype=np.float32)

        weights = np.array([
            [2, 1, 1],
            [4, -6, 0],
            [-2, 7, 2]
        ], dtype=np.float32)

        # Compute results.
        # The torch cell stores `weights.T` in nn.Linear, and nn.Linear applies
        # x @ weight.T, so the effective projection is hidden_states @ weights.
        queries = hidden_states @ weights
        keys = queries  # the key layer holds the same weights as the query layer
        attention_map = queries @ keys.T

        def create_matrix(data, name, color=WHITE):
            """Return VGroup(Matrix, label); index 0 is the Matrix, index 1 the label."""
            matrix = Matrix(data.astype(int), h_buff=1.5, v_buff=1.5)
            matrix.set_color(color)
            label = Text(name, font_size=24).next_to(matrix, UP)
            return VGroup(matrix, label)

        # Create all matrices
        hidden_matrix = create_matrix(hidden_states, "hidden_states", BLUE)
        weights_matrix = create_matrix(weights, "weights", GREEN)
        queries_matrix = create_matrix(queries, "queries", YELLOW)
        keys_T_matrix = create_matrix(keys.T, "keys.T", ORANGE)
        attention_matrix = create_matrix(attention_map, "attention_map", RED)

        # Title
        title = Text("Attention Mechanism Matrix Multiplication", font_size=36)
        title.to_edge(UP)
        self.play(Write(title))
        self.wait()

        # Step 1: Show initial matrices
        step1_text = Text("Step 1: Initial Matrices", font_size=28, color=BLUE_C)
        step1_text.next_to(title, DOWN)
        self.play(Write(step1_text))

        hidden_matrix.move_to(LEFT * 3)
        weights_matrix.move_to(RIGHT * 3)

        self.play(
            FadeIn(hidden_matrix),
            FadeIn(weights_matrix)
        )
        self.wait(2)

        # Step 2: Compute queries
        self.play(FadeOut(step1_text))
        step2_text = Text("Step 2: Compute queries = hidden_states @ weights", font_size=26, color=GREEN_C)
        step2_text.next_to(title, DOWN)
        self.play(Write(step2_text))

        # Show multiplication using Text instead of MathTex
        mult_sign = Text("@", font_size=48).move_to(ORIGIN)
        self.play(
            hidden_matrix.animate.move_to(LEFT * 2),
            weights_matrix.animate.move_to(RIGHT * 2),
            Write(mult_sign)
        )
        self.wait()

        # Show result. The "=" is created at its final position so that it is
        # driven by a single animation (the transform) rather than two at once.
        equals_sign = Text("=", font_size=48).move_to(RIGHT * 0.5)
        self.play(
            ReplacementTransform(mult_sign, equals_sign),
            hidden_matrix.animate.scale(0.7).move_to(LEFT * 4),
            weights_matrix.animate.scale(0.7).move_to(LEFT * 1.5)
        )

        queries_matrix.move_to(RIGHT * 3)
        self.play(FadeIn(queries_matrix))
        self.wait(2)

        # Highlight computation for one element (row 0, column 0 of queries)
        first_query_entry = queries_matrix[0].get_rows()[0][0]
        highlight_rect = SurroundingRectangle(first_query_entry, color=YELLOW, buff=0.1)
        self.play(Create(highlight_rect))

        # Show calculation using Text instead of MathTex:
        # queries[0,0] = hidden_states[0] . weights[:, 0]
        calc_text = Text("queries[0,0] = 1×2 + 0×4 + 0×(-2) = 2", font_size=24)
        calc_text.to_edge(DOWN)
        self.play(Write(calc_text))
        self.wait(2)
        self.play(FadeOut(highlight_rect), FadeOut(calc_text))

        # Step 3: Compute attention map
        self.play(
            FadeOut(hidden_matrix),
            FadeOut(weights_matrix),
            FadeOut(equals_sign),
            FadeOut(step2_text)
        )

        step3_text = Text("Step 3: Compute attention_map = queries @ keys.T", font_size=26, color=RED_C)
        step3_text.next_to(title, DOWN)
        self.play(Write(step3_text))

        # Move queries to left and create keys.T
        self.play(queries_matrix.animate.move_to(LEFT * 3))
        keys_T_matrix.move_to(RIGHT * 3)
        self.play(FadeIn(keys_T_matrix))

        # Show multiplication
        mult_sign2 = Text("@", font_size=48).move_to(ORIGIN)
        self.play(Write(mult_sign2))
        self.wait()

        # Transform to attention map
        self.play(
            FadeOut(queries_matrix),
            FadeOut(keys_T_matrix),
            FadeOut(mult_sign2),
            FadeIn(attention_matrix.move_to(ORIGIN))
        )
        self.wait()

        # Final display with all results
        self.play(FadeOut(step3_text))
        final_text = Text("Final Attention Map", font_size=32, color=GOLD)
        final_text.next_to(title, DOWN)
        self.play(Write(final_text))

        # Show the attention pattern
        self.play(attention_matrix.animate.scale(1.2))

        # Highlight diagonal elements (self-attention). Entries are addressed
        # through get_rows(): plain integer indexing on the Matrix would also
        # walk over its bracket submobjects.
        attention_rows = attention_matrix[0].get_rows()
        for i in range(attention_map.shape[0]):
            rect = SurroundingRectangle(attention_rows[i][i], color=GOLD, buff=0.1)
            self.play(Create(rect), run_time=0.5)
            self.play(FadeOut(rect), run_time=0.3)

        self.wait(2)

        # Add interpretation text
        interp_text = Text("Higher values indicate stronger attention", font_size=24, color=GRAY)
        interp_text.to_edge(DOWN)
        self.play(Write(interp_text))
        self.wait(3)

        # Fade out everything
        self.play(*[FadeOut(mob) for mob in self.mobjects])

ValueError: latex error converting to dvi. See log output above or the log file: media/Tex/6ecf9f51170c1a70.log

In [8]:
%%manim -qm -v WARNING AttentionMatMulPureText

from manim import *
import numpy as np

class AttentionMatMulPureText(Scene):
    """Attention animation built only from ``Text`` mobjects (no ``Matrix``/``MathTex``).

    The scene shows ``queries = hidden_states @ weights`` followed by
    ``attention_map = queries @ keys.T`` (keys equal queries here because the
    torch cell loads the same weights into both layers).
    """

    def construct(self):
        # Define the matrices
        hidden_states = np.array([
            [1, 0, 0],
            [-2, 1, 0],
            [-1, 1, 1]
        ], dtype=np.float32)

        weights = np.array([
            [2, 1, 1],
            [4, -6, 0],
            [-2, 7, 2]
        ], dtype=np.float32)

        # Compute results.
        # The torch cell stores `weights.T` in nn.Linear, and nn.Linear applies
        # x @ weight.T, so the effective projection is hidden_states @ weights.
        queries = hidden_states @ weights
        keys = queries  # the key layer holds the same weights as the query layer
        attention_map = queries @ keys.T

        def create_text_matrix(data, name, color=WHITE):
            """Build a bracketed grid of Text entries with a label above it.

            Returns VGroup(VGroup(left_bracket, entries, right_bracket), label).
            Values are cast to int for display.
            """
            data_int = data.astype(int)

            # One Text per entry, laid out on a fixed 0.8 x 0.6 grid (row-major).
            matrix_group = VGroup()
            for i in range(data_int.shape[0]):
                for j in range(data_int.shape[1]):
                    elem = Text(str(data_int[i, j]), font_size=30)
                    elem.move_to(np.array([j * 0.8 - 0.8, -i * 0.6 + 0.6, 0]))
                    matrix_group.add(elem)

            # Add brackets manually
            left_bracket = Text("[", font_size=60)
            right_bracket = Text("]", font_size=60)

            matrix_group.move_to(ORIGIN)
            left_bracket.next_to(matrix_group, LEFT, buff=0.2)
            right_bracket.next_to(matrix_group, RIGHT, buff=0.2)

            # Combine everything
            full_matrix = VGroup(left_bracket, matrix_group, right_bracket)
            full_matrix.set_color(color)

            # Add label (kept in the default colour, not the matrix colour)
            label = Text(name, font_size=24).next_to(full_matrix, UP)

            return VGroup(full_matrix, label)

        # Create all matrices
        hidden_matrix = create_text_matrix(hidden_states, "hidden_states", BLUE)
        weights_matrix = create_text_matrix(weights, "weights", GREEN)
        queries_matrix = create_text_matrix(queries, "queries", YELLOW)
        keys_T_matrix = create_text_matrix(keys.T, "keys.T", ORANGE)
        attention_matrix = create_text_matrix(attention_map, "attention_map", RED)

        # Title
        title = Text("Attention Mechanism Matrix Multiplication", font_size=36)
        title.to_edge(UP)
        self.play(Write(title))
        self.wait()

        # Step 1: Show initial matrices
        step1_text = Text("Step 1: Initial Matrices", font_size=28, color=BLUE_C)
        step1_text.next_to(title, DOWN)
        self.play(Write(step1_text))

        hidden_matrix.move_to(LEFT * 3.5)
        weights_matrix.move_to(RIGHT * 3.5)

        self.play(
            FadeIn(hidden_matrix),
            FadeIn(weights_matrix)
        )
        self.wait(2)

        # Step 2: Compute queries
        self.play(FadeOut(step1_text))
        step2_text = Text("Step 2: queries = hidden_states @ weights", font_size=26, color=GREEN_C)
        step2_text.next_to(title, DOWN)
        self.play(Write(step2_text))

        # Show multiplication
        mult_sign = Text("@", font_size=48).move_to(ORIGIN)
        self.play(
            hidden_matrix.animate.move_to(LEFT * 2.5),
            weights_matrix.animate.move_to(RIGHT * 2.5),
            Write(mult_sign)
        )
        self.wait()

        # Show result. The "=" is created at its final position so that it is
        # driven by a single animation (the transform) rather than two at once.
        equals_sign = Text("=", font_size=48).move_to(RIGHT * 0.5)
        self.play(
            ReplacementTransform(mult_sign, equals_sign),
            hidden_matrix.animate.scale(0.6).move_to(LEFT * 4.5),
            weights_matrix.animate.scale(0.6).move_to(LEFT * 2)
        )

        queries_matrix.move_to(RIGHT * 3)
        self.play(FadeIn(queries_matrix))
        self.wait(2)

        # Show calculation example: queries[0,0] = hidden_states[0] . weights[:, 0]
        calc_text = Text("Example: queries[0,0] = 1*2 + 0*4 + 0*(-2) = 2", font_size=20)
        calc_text.to_edge(DOWN)
        self.play(Write(calc_text))
        self.wait(2)
        self.play(FadeOut(calc_text))

        # Step 3: Compute attention map
        self.play(
            FadeOut(hidden_matrix),
            FadeOut(weights_matrix),
            FadeOut(equals_sign),
            FadeOut(step2_text)
        )

        step3_text = Text("Step 3: attention_map = queries @ keys.T", font_size=26, color=RED_C)
        step3_text.next_to(title, DOWN)
        self.play(Write(step3_text))

        # Move queries to left and create keys.T
        self.play(queries_matrix.animate.move_to(LEFT * 3.5))
        keys_T_matrix.move_to(RIGHT * 3.5)
        self.play(FadeIn(keys_T_matrix))

        # Show multiplication
        mult_sign2 = Text("@", font_size=48).move_to(ORIGIN)
        self.play(Write(mult_sign2))
        self.wait()

        # Transform to attention map
        self.play(
            FadeOut(queries_matrix),
            FadeOut(keys_T_matrix),
            FadeOut(mult_sign2),
            FadeIn(attention_matrix.move_to(ORIGIN))
        )
        self.wait()

        # Final display with all results
        self.play(FadeOut(step3_text))
        final_text = Text("Final Attention Map", font_size=32, color=GOLD)
        final_text.next_to(title, DOWN)
        self.play(Write(final_text))

        # Show the attention pattern
        self.play(attention_matrix.animate.scale(1.2))

        # Point out the diagonal with a caption instead of per-entry highlights
        highlight_text = Text("Diagonal values show self-attention", font_size=22, color=GOLD)
        highlight_text.to_edge(DOWN)
        self.play(Write(highlight_text))
        self.wait(2)

        # Add interpretation text
        self.play(FadeOut(highlight_text))
        interp_text = Text("Higher values = stronger attention", font_size=24, color=GRAY)
        interp_text.to_edge(DOWN)
        self.play(Write(interp_text))
        self.wait(3)

        # Fade out everything
        self.play(*[FadeOut(mob) for mob in self.mobjects])

                                                                                                                               