### Playground

_We'll use this notebook to quickly prototype video generation._

In [1]:
from manim import *

# Only log warnings and errors, which keeps notebook cell output quiet.
config.verbosity = "WARNING"
# Display width for media embedded in the notebook by the %%manim magic.
config.media_width = "75%"

In [4]:
%%manim -qm AttentionInLLMs

class AttentionInLLMs(Scene):
    """Prototype scene walking through attention in large language models.

    Stages: title, a word-level attention example on "The cat sat on the
    mat", a small self-attention matrix, a row of multi-head attention
    "heads", and a closing sentence.
    """

    @staticmethod
    def _word_slices(sentence):
        """Return one ``slice`` per whitespace-separated word of *sentence*.

        The slices index a ``Text`` mobject built from the same string.
        Offsets skip whitespace because manim's ``Text`` creates no glyph
        submobjects for spaces. For example, "cat" in "The cat ..." is
        ``slice(3, 6)``, not the string offset ``4:7``.
        NOTE: this relies on that manim (Community) behaviour; re-check it
        if the manim version changes.
        """
        slices = []
        start = 0
        for word in sentence.split():
            slices.append(slice(start, start + len(word)))
            start += len(word)
        return slices

    @staticmethod
    def _attention_arrow(source, target, color):
        """Return a curved arrow from *source* to *target* that runs under both.

        *source* and *target* are mobjects (here, word groups of a ``Text``).
        The arrow joins points just below their bottom edges.
        """
        start = source.get_bottom() + 0.15 * DOWN
        end = target.get_bottom() + 0.15 * DOWN
        # A positive arc angle bends a left-to-right arrow downward but a
        # right-to-left one upward. Flip the sign for leftward arrows so every
        # arc stays below the sentence instead of crossing the glyphs.
        angle = TAU / 4 if end[0] >= start[0] else -TAU / 4
        return CurvedArrow(start, end, angle=angle, color=color)

    def construct(self):
        # Title
        title = Text("Attention in Large Language Models", font_size=40)
        self.play(Write(title))
        self.play(title.animate.to_edge(UP))
        self.wait()

        # Explain the basic concept
        basic_concept = Text("Attention helps the model focus on relevant parts of the input", 
                             font_size=24)
        basic_concept.next_to(title, DOWN, buff=0.5)
        self.play(Write(basic_concept))
        self.wait(2)

        # Create input sequence. ReplacementTransform swaps basic_concept out
        # for input_text, so input_text is the mobject actually on screen.
        sentence = "The cat sat on the mat"
        input_text = Text(sentence, font_size=24)
        input_text.to_edge(LEFT).shift(UP)
        self.play(ReplacementTransform(basic_concept, input_text))
        self.wait()

        # Glyph groups, one per word of the sentence.
        tokens = sentence.split()
        words = [input_text[span] for span in self._word_slices(sentence)]
        cat_index, mat_index = tokens.index("cat"), tokens.index("mat")
        cat, mat = words[cat_index], words[mat_index]

        # Highlight words
        highlights = [
            SurroundingRectangle(cat, color=YELLOW, buff=0.1),
            SurroundingRectangle(mat, color=GREEN, buff=0.1),
        ]
        self.play(*[Create(highlight) for highlight in highlights])
        self.wait()

        # Show attention weights: one arrow from every other word into "mat",
        # coloured along a red-to-green gradient. (Previously all arrows shared
        # the same endpoints and hid each other.)
        sources = [word for i, word in enumerate(words) if i != mat_index]
        steps = max(len(sources) - 1, 1)
        weights = VGroup(*[
            self._attention_arrow(word, mat, interpolate_color(RED, GREEN, i / steps))
            for i, word in enumerate(sources)
        ])
        self.play(LaggedStart(*[Create(weight) for weight in weights], lag_ratio=0.2))
        self.wait()

        # Clear previous elements
        self.play(*[FadeOut(mob) for mob in self.mobjects if mob is not title])

        # Explain self-attention
        self_attention = Text("Self-Attention: Each word attends to all words", font_size=24)
        self_attention.next_to(title, DOWN, buff=0.5)
        self.play(Write(self_attention))
        self.wait()

        # Create a simple matrix. The values are illustrative only: symmetric
        # with a unit diagonal, and rows do not sum to 1.
        matrix = Matrix([
            ["1", "0.2", "0.1"],
            ["0.2", "1", "0.3"],
            ["0.1", "0.3", "1"]
        ])
        matrix.scale(0.8)
        matrix.next_to(self_attention, DOWN, buff=0.5)
        self.play(Create(matrix))
        self.wait()

        # Label matrix
        matrix_label = Text("Attention Matrix", font_size=20)
        matrix_label.next_to(matrix, DOWN)
        self.play(Write(matrix_label))
        self.wait()

        # Show multi-head attention
        multi_head = Text("Multi-Head Attention: Multiple attention mechanisms in parallel", 
                          font_size=24)
        multi_head.next_to(matrix_label, DOWN, buff=0.5)
        self.play(Write(multi_head))
        self.wait()

        # Create multiple attention heads
        heads = VGroup(*[
            Circle(radius=0.2, color=BLUE, fill_opacity=0.5)
            for _ in range(4)
        ]).arrange(RIGHT, buff=0.3)
        heads.next_to(multi_head, DOWN, buff=0.5)
        self.play(Create(heads))
        self.wait()

        # Conclusion
        conclusion = Text("Attention mechanisms enable LLMs to process and understand context", 
                          font_size=24)
        conclusion.next_to(heads, DOWN, buff=0.5)
        self.play(Write(conclusion))
        self.wait(2)

        # Fade out everything
        self.play(*[FadeOut(mob) for mob in self.mobjects])
        self.wait()

                                                                                                                                                 

In [6]:
%%manim -qm AttentionInLLMs

class AttentionInLLMs(Scene):
    """Prototype scene: attention, then self-attention over a sentence.

    Stages: title, an explanation line, the sentence "The cat sat on the
    mat" with "cat" attending to the other words, a 5x5 self-attention
    matrix, and a closing sentence.
    """

    @staticmethod
    def _word_slices(sentence):
        """Return one ``slice`` per whitespace-separated word of *sentence*.

        The slices index a ``Text`` mobject built from the same string.
        Offsets skip whitespace because manim's ``Text`` creates no glyph
        submobjects for spaces. For example, "cat" in "The cat ..." is
        ``slice(3, 6)``, and ``Text(...)[1]`` is the single glyph "h".
        NOTE: this relies on that manim (Community) behaviour; re-check it
        if the manim version changes.
        """
        slices = []
        start = 0
        for word in sentence.split():
            slices.append(slice(start, start + len(word)))
            start += len(word)
        return slices

    @staticmethod
    def _attention_arrow(source, target, color):
        """Return a curved arrow from *source* to *target* that runs under both.

        *source* and *target* are mobjects (here, word groups of a ``Text``).
        The arrow joins points just below their bottom edges.
        """
        start = source.get_bottom() + 0.15 * DOWN
        end = target.get_bottom() + 0.15 * DOWN
        # A positive arc angle bends a left-to-right arrow downward but a
        # right-to-left one upward. Flip the sign for leftward arrows so every
        # arc stays below the sentence instead of crossing the glyphs.
        angle = TAU / 4 if end[0] >= start[0] else -TAU / 4
        return CurvedArrow(start, end, angle=angle, color=color)

    def construct(self):
        # Title
        title = Text("Attention Mechanism in Large Language Models", font_size=40)
        self.play(Write(title))
        self.play(title.animate.to_edge(UP))
        self.wait()

        # Explain attention mechanism
        attention_text = Text(
            "Attention allows the model to focus on relevant parts of the input",
            font_size=24
        )
        attention_text.next_to(title, DOWN, buff=0.5)
        self.play(Write(attention_text))
        self.wait(2)

        # Input sentence. ReplacementTransform puts `sentence` itself on
        # screen, so later code can reference and remove it directly.
        sentence_str = "The cat sat on the mat"
        sentence = Text(sentence_str, font_size=24)
        sentence.to_edge(LEFT).shift(UP)
        self.play(ReplacementTransform(attention_text, sentence))
        self.wait()

        # Glyph groups, one per word of the sentence.
        tokens = sentence_str.split()
        words = [sentence[span] for span in self._word_slices(sentence_str)]
        cat_index, mat_index = tokens.index("cat"), tokens.index("mat")
        cat, mat = words[cat_index], words[mat_index]

        # Highlight key words
        cat_highlight = SurroundingRectangle(cat, color=YELLOW, buff=0.1)
        mat_highlight = SurroundingRectangle(mat, color=GREEN, buff=0.1)
        self.play(Create(cat_highlight), Create(mat_highlight))
        self.wait()

        # Show attention arrows: "cat" attends to every other word. The
        # red-to-green gradient runs left to right and ends on "mat".
        # (Previously all arrows shared the same endpoints and hid each other.)
        targets = [word for i, word in enumerate(words) if i != cat_index]
        steps = max(len(targets) - 1, 1)
        arrows = VGroup(*[
            self._attention_arrow(cat, target, interpolate_color(RED, GREEN, i / steps))
            for i, target in enumerate(targets)
        ])
        self.play(LaggedStart(*[Create(arrow) for arrow in arrows], lag_ratio=0.1))
        self.wait()

        # Self-Attention explanation. Clear the word-level highlights and
        # arrows in the same step. Otherwise they would stay on screen around
        # empty space once the sentence morphs into this heading.
        self_attention = Text("Self-Attention: Each word attends to all words", font_size=24)
        self_attention.next_to(title, DOWN, buff=0.5)
        self.play(
            FadeOut(cat_highlight),
            FadeOut(mat_highlight),
            FadeOut(arrows),
            ReplacementTransform(sentence, self_attention),
        )
        self.wait()

        # Attention matrix visualization. The values are illustrative only:
        # symmetric with a unit diagonal, and rows do not sum to 1.
        matrix = Matrix([
            ["1", "0.2", "0.1", "0.3", "0.4"],
            ["0.2", "1", "0.3", "0.5", "0.2"],
            ["0.1", "0.3", "1", "0.2", "0.1"],
            ["0.3", "0.5", "0.2", "1", "0.3"],
            ["0.4", "0.2", "0.1", "0.3", "1"]
        ])
        matrix.scale(0.8)
        matrix.next_to(self_attention, DOWN, buff=0.5)
        self.play(Create(matrix))
        self.wait()

        # Highlight matrix
        matrix_highlight = SurroundingRectangle(matrix, color=BLUE, buff=0.1)
        self.play(Create(matrix_highlight))
        self.wait()

        # Conclusion
        conclusion = Text(
            "Attention mechanisms enable LLMs to understand context effectively",
            font_size=24
        )
        conclusion.next_to(matrix, DOWN, buff=0.5)
        self.play(Write(conclusion))
        self.wait(2)

        # Fade out
        self.play(*[FadeOut(mob) for mob in self.mobjects])
        self.wait()

                                                                                                                                                    