# Reinforcement Learning: Agent-Environment Interaction Visualization

This notebook creates a Manim animation showing the complete RL loop:
- One agent, many parallel environments, one shared learner
- State flows, action sampling, transitions, and learning updates

In [None]:
!sudo apt update
!sudo apt install libcairo2-dev \
    texlive texlive-latex-extra texlive-fonts-extra \
    texlive-latex-recommended texlive-science \
    tipa libpango1.0-dev
!pip install manim
!pip install IPython==8.21.0

In [None]:
from manim import *
import numpy as np

## Color Palette
- **Green**: Policy/Agent
- **Red**: Environment dynamics
- **Blue**: Learning/Training
- **Orange**: States
- **Purple**: Actions
- **Gold**: Rewards

In [None]:
# Shared theme: hex color strings used by every helper class and scene below.

# --- The three actors in the RL loop ---
AGENT_COLOR = "#2ECC71"       # green: the policy / agent
ENV_COLOR = "#E74C3C"         # red: environment dynamics
TRAINING_COLOR = "#3498DB"    # blue: the learner

# --- Data that flows between the actors ---
STATE_COLOR = "#F39C12"       # orange: observed states
ACTION_COLOR = "#9B59B6"      # purple: sampled actions
REWARD_COLOR = "#F1C40F"      # gold: rewards

## Helper Classes

In [None]:
class NeuralNetworkGrid(VGroup):
    """A rows x cols grid of circles that stands in for the agent's neural network.

    The grid is centred on the group's origin. ``self.nodes[i][j]`` is the node
    in row ``i`` (row 0 is the bottom row, since y grows with ``i``) and column
    ``j`` (column 0 is the leftmost column).

    Args:
        rows: Number of node rows.
        cols: Number of node columns.
        node_radius: Radius of each node circle.
        spacing: Centre-to-centre distance between neighbouring nodes.
        **kwargs: Forwarded to ``VGroup``.
    """

    def __init__(self, rows=3, cols=4, node_radius=0.08, spacing=0.3, **kwargs):
        super().__init__(**kwargs)
        self.nodes = []

        # Offsets that centre the whole grid on the origin.
        x_offset = (cols - 1) * spacing / 2
        y_offset = (rows - 1) * spacing / 2

        for i in range(rows):
            row_nodes = []
            for j in range(cols):
                node = Circle(
                    radius=node_radius,
                    fill_opacity=0.3,
                    stroke_width=1,
                    color=WHITE
                )
                node.move_to([j * spacing - x_offset, i * spacing - y_offset, 0])
                row_nodes.append(node)
                self.add(node)
            self.nodes.append(row_nodes)

    def pulse_wave(self, scene, direction="forward", run_time=0.5):
        """Light the grid up one column at a time, then reset every node.

        Each column is filled with ``AGENT_COLOR`` at opacity 0.8 in its own
        ``scene.play`` call. Afterwards all nodes return to white at opacity 0.3
        in a final 0.2 s animation.

        Args:
            scene: The Scene whose ``play`` runs the animations.
            direction: ``"forward"`` sweeps left to right; any other value
                sweeps right to left.
            run_time: Duration of the column sweep, split evenly across
                columns (the closing reset is extra).
        """
        if not any(self.nodes):
            # Empty grid: there is nothing to animate, and Scene.play rejects an
            # empty animation list, so bail out before the reset below.
            return

        cols = len(self.nodes[0])
        step_time = run_time / cols if cols else run_time
        col_range = range(cols) if direction == "forward" else range(cols - 1, -1, -1)

        for col_idx in col_range:
            anims = [
                row[col_idx].animate.set_fill(AGENT_COLOR, opacity=0.8)
                for row in self.nodes
                if col_idx < len(row)  # tolerate ragged rows
            ]
            if anims:
                scene.play(*anims, run_time=step_time)

        # Reset every node to its resting look.
        scene.play(*[node.animate.set_fill(WHITE, opacity=0.3)
                     for row in self.nodes for node in row], run_time=0.2)

In [None]:
from manim import *
from manim.utils.color.core import ManimColor, interpolate_color
class TransitionPacket(VGroup):
    """Small glyph standing for one (s, a, r, s') transition.

    Submobjects, in the order they are added: a rounded container box drawn in
    ``env_color``, a state dot, a short action arrow and a reward star along
    the top row, and a paler next-state dot below them.

    Args:
        env_color: Color of the container box.
        **kwargs: Forwarded to ``VGroup``.
    """

    def __init__(self, env_color=ENV_COLOR, **kwargs):
        super().__init__(**kwargs)

        top_row = UP * 0.05

        container = RoundedRectangle(
            width=0.6,
            height=0.4,
            corner_radius=0.05,
            fill_opacity=0.2,
            stroke_width=1.5,
            color=env_color,
        )

        state = Dot(radius=0.05, color=STATE_COLOR)
        state.shift(LEFT * 0.2 + top_row)

        action = Arrow(
            start=LEFT * 0.05,
            end=RIGHT * 0.1,
            buff=0,
            stroke_width=2,
            color=ACTION_COLOR,
            max_tip_length_to_length_ratio=0.5,
        )
        action.shift(top_row)

        reward = Star(
            n=5,
            outer_radius=0.06,
            inner_radius=0.03,
            fill_opacity=1,
            color=REWARD_COLOR,
        )
        reward.shift(RIGHT * 0.15 + top_row)

        # interpolate_color needs ManimColor objects, not raw hex strings.
        faded_state_color = interpolate_color(ManimColor(STATE_COLOR), ManimColor(WHITE), 0.3)
        next_state = Dot(radius=0.05, color=STATE_COLOR)
        next_state.shift(DOWN * 0.08)
        next_state.set_color(faded_state_color)

        self.add(container, state, action, reward, next_state)

In [None]:
class ProbabilityCloud(VGroup):
    """A small pie whose wedge angles form a random probability distribution.

    Wedge angles are one draw from a flat Dirichlet distribution (all
    concentrations 1) scaled by TAU, so together they make a full turn.
    Wedges are laid out counter-clockwise from angle 0 and colored by cycling
    through ``colors``.

    Args:
        num_wedges: Number of wedges (outcomes) in the distribution.
        radius: Outer radius of the pie.
        colors: Optional non-empty sequence of colors cycled over the wedges.
            Defaults to RED, ORANGE, YELLOW, GREEN, BLUE.
        rng: Optional ``numpy.random.Generator`` for the Dirichlet draw, so
            renders can be made reproducible. Defaults to NumPy's legacy
            global generator (``np.random``), as before.
        **kwargs: Forwarded to ``VGroup``.
    """

    def __init__(self, num_wedges=5, radius=0.25, colors=None, rng=None, **kwargs):
        super().__init__(**kwargs)

        if colors is None:
            colors = [RED, ORANGE, YELLOW, GREEN, BLUE]

        # Generator.dirichlet and the legacy np.random.dirichlet share a signature.
        sampler = np.random if rng is None else rng
        angles = sampler.dirichlet(np.ones(num_wedges)) * TAU

        start_angle = 0
        for i, angle in enumerate(angles):
            wedge = AnnularSector(
                inner_radius=0,
                outer_radius=radius,
                angle=angle,
                start_angle=start_angle,
                fill_opacity=0.6,
                stroke_width=0.5,
                color=colors[i % len(colors)]
            )
            self.add(wedge)
            start_angle += angle  # next wedge starts where this one ends

## Main Scene: Complete RL Loop Visualization

In [None]:
%%manim -qm -v WARNING RLLoopVisualization

class RLLoopVisualization(Scene):
    """The complete RL loop in nine short scenes.

    Environments send states to the agent. The agent samples actions, and the
    environments react and emit transitions. The transitions flow into the
    training area, a learning update flows back to the agent, and the loop
    repeats.

    Layout (built in ``setup_cast``): agent on the left, a vertical stack of
    ``num_envs`` environment boxes in the middle, training area on the right.
    ``construct`` plays the scene methods in order. Later scenes read
    attributes created by earlier ones (``prob_clouds``, ``action_arrows``,
    ``transition_packets``), so they are not meant to run out of order.

    Within any single ``self.play`` call, each mobject gets at most one
    animation. Two animations on the same mobject in one ``play`` overwrite
    each other, so combined effects are chained on a single ``.animate``.
    """

    def construct(self):
        """Build the cast, then play scenes 1-9 in order."""
        self.num_envs = 4
        self.setup_cast()

        # Run all scenes
        self.scene_1_cast_appears()
        self.scene_2_state_flows()
        self.scene_3_agent_computes()
        self.scene_4_actions_sampled()
        self.scene_5_actions_hit_envs()
        self.scene_6_transitions_produced()
        self.scene_7_transitions_to_training()
        self.scene_8_learning_happens()
        self.scene_9_loop_closes()

    def setup_cast(self):
        """Create (but do not add) the agent, environment and training mobjects."""

        # === AGENT (Left) ===
        self.agent_box = RoundedRectangle(
            width=2.5, height=3.5, corner_radius=0.3,
            fill_opacity=0.15, stroke_width=3,
            color=AGENT_COLOR
        ).shift(LEFT * 4.5)

        self.agent_label = Text("Agent", font_size=24, color=AGENT_COLOR)
        self.agent_label.next_to(self.agent_box, UP, buff=0.2)

        self.neural_net = NeuralNetworkGrid(rows=4, cols=5, node_radius=0.06, spacing=0.35)
        self.neural_net.move_to(self.agent_box.get_center())

        self.agent_group = VGroup(self.agent_box, self.agent_label, self.neural_net)

        # === ENVIRONMENTS (Center-Right, vertical stack centred on y = 0) ===
        self.env_boxes = VGroup()
        env_height = 0.6
        env_spacing = 0.8

        for i in range(self.num_envs):
            env = RoundedRectangle(
                width=1.8, height=env_height, corner_radius=0.1,
                fill_opacity=0.2, stroke_width=2,
                color=ENV_COLOR
            )
            y_pos = (self.num_envs - 1) / 2 * env_spacing - i * env_spacing
            env.shift(RIGHT * 0.5 + UP * y_pos)
            self.env_boxes.add(env)

        self.env_label = Text("Environments", font_size=20, color=ENV_COLOR)
        self.env_label.next_to(self.env_boxes, UP, buff=0.3)

        # === TRAINING AREA (Far Right) ===
        self.training_box = RoundedRectangle(
            width=2, height=3.5, corner_radius=0.2,
            fill_opacity=0.1, stroke_width=2,
            color=TRAINING_COLOR
        ).shift(RIGHT * 4.5)

        self.training_label = Text("Training", font_size=20, color=TRAINING_COLOR)
        self.training_label.next_to(self.training_box, UP, buff=0.2)

        # Populated by the scene methods that create them.
        self.prob_clouds = []
        self.action_arrows = []
        self.transition_packets = []

    # ========== SCENE 1: Cast Appears ==========
    def scene_1_cast_appears(self):
        """Fade in agent, environments and training area, then let the envs breathe.

        The breath is a 5% scale played there-and-back, so every environment
        box ends at its original size.
        """
        self.play(
            FadeIn(self.agent_box),
            FadeIn(self.agent_label),
            FadeIn(self.neural_net),
            run_time=1.5
        )

        self.play(
            *[FadeIn(env) for env in self.env_boxes],
            FadeIn(self.env_label),
            run_time=1
        )

        self.play(
            FadeIn(self.training_box),
            FadeIn(self.training_label),
            run_time=1
        )

        self.play(
            *[env.animate.scale(1.05) for env in self.env_boxes],
            run_time=0.5,
            rate_func=there_and_back
        )
        self.wait(0.5)

    # ========== SCENE 2: State Flows ==========
    def scene_2_state_flows(self):
        """Send a stream of state dots from every environment to the agent.

        Each environment emits ``dots_per_env`` dots from its left edge. Within
        a stream the dots leave one after another (``LaggedStart``) along
        slightly curved paths, so they read as a stream rather than a single
        dot. The agent box brightens while the streams are in flight.
        """
        dots_per_env = 5
        state_streams = VGroup()
        stream_anims = []

        for i, env in enumerate(self.env_boxes):
            start_pos = env.get_left()
            # Same landing heights as the action arrows in scene 4.
            end_pos = self.agent_box.get_right() + UP * (0.3 - i * 0.25)

            dot_anims = []
            for j in range(dots_per_env):
                dot = Dot(radius=0.04, color=STATE_COLOR).move_to(start_pos)
                state_streams.add(dot)

                # Per-dot curvature gives the stream its wiggle; a zero angle
                # is drawn as a straight line instead of a flat arc.
                wiggle = 0.3 * np.sin(j * 0.5 + i * 0.3)
                if abs(wiggle) > 1e-6:
                    path = ArcBetweenPoints(start_pos, end_pos, angle=wiggle)
                else:
                    path = Line(start_pos, end_pos)
                dot_anims.append(MoveAlongPath(dot, path, rate_func=linear))

            stream_anims.append(LaggedStart(*dot_anims, lag_ratio=0.15))

        self.play(FadeIn(state_streams), run_time=0.3)

        # Agent brightens as streams arrive
        self.play(
            *stream_anims,
            self.agent_box.animate.set_fill(opacity=0.3),
            run_time=1.5
        )

        self.play(FadeOut(state_streams), run_time=0.3)
        self.state_streams = state_streams

    # ========== SCENE 3: Agent Computes ==========
    def scene_3_agent_computes(self):
        """Pulse the agent's network, then fade in three spinning probability clouds."""
        self.neural_net.pulse_wave(self, direction="forward", run_time=0.8)

        self.prob_clouds = []
        for i in range(3):
            cloud = ProbabilityCloud(num_wedges=4, radius=0.2)
            cloud.move_to(self.agent_box.get_center() +
                          RIGHT * 0.5 + UP * (0.4 - i * 0.4))
            cloud.set_opacity(0)  # added invisible; revealed below
            self.prob_clouds.append(cloud)
            self.add(cloud)

        # Fade-in and rotation are one chained animation per cloud. A separate
        # Rotate would override the opacity change and keep the clouds invisible.
        # path_arc makes the points travel on a circle, like Rotate does.
        self.play(
            *[cloud.animate(path_arc=PI / 4).rotate(PI / 4).set_opacity(0.7)
              for cloud in self.prob_clouds],
            run_time=0.8
        )

        self.wait(0.3)

    # ========== SCENE 4: Actions Sampled ==========
    def scene_4_actions_sampled(self):
        """Grow one colored action arrow per environment, then collapse the clouds."""
        action_colors = [RED, BLUE, GREEN, ORANGE]
        self.action_arrows = [
            Arrow(
                start=self.agent_box.get_right() + UP * (0.3 - i * 0.25),
                end=env.get_left(),
                buff=0.1,
                stroke_width=4,
                color=action_colors[i % len(action_colors)],
                max_tip_length_to_length_ratio=0.15
            )
            for i, env in enumerate(self.env_boxes)
        ]

        # Arrows emerge one at a time
        for arrow in self.action_arrows:
            self.play(GrowArrow(arrow), run_time=0.15)

        # Probability clouds collapse/vanish
        self.play(
            *[cloud.animate.scale(0.1).set_opacity(0) for cloud in self.prob_clouds],
            run_time=0.4
        )

        for cloud in self.prob_clouds:
            self.remove(cloud)

        self.wait(0.2)

    # ========== SCENE 5: Actions Hit Environments ==========
    def scene_5_actions_hit_envs(self):
        """Deform and flash each environment on impact, restore it, drop the arrows.

        Environments cycle through three reactions by index: squash, slight
        rotation, stretch.
        """
        # Snapshot every env so the reactions can be undone with restore().
        for env in self.env_boxes:
            env.save_state()

        impact_anims = []
        for i, env in enumerate(self.env_boxes):
            if i % 3 == 0:
                reaction = env.animate.stretch(1.2, 0).stretch(0.85, 1)
            elif i % 3 == 1:
                reaction = env.animate.rotate(0.1)
            else:
                reaction = env.animate.stretch(0.9, 0).stretch(1.15, 1)

            # The flash is chained onto the same animation; a second .animate on
            # this env would overwrite the deformation.
            impact_anims.append(reaction.set_fill(opacity=0.5))

        self.play(*impact_anims, run_time=0.3)

        self.play(
            *[env.animate.restore() for env in self.env_boxes],
            run_time=0.3
        )

        self.play(
            *[FadeOut(arrow) for arrow in self.action_arrows],
            run_time=0.3
        )

    # ========== SCENE 6: Transitions Produced ==========
    def scene_6_transitions_produced(self):
        """Pop one TransitionPacket out of each environment and let it hover.

        Packets start at 1% size on the env centre, scale back up 100x while
        drifting right/up, then bob once (there-and-back).
        """
        self.transition_packets = []

        for env in self.env_boxes:
            packet = TransitionPacket()
            packet.move_to(env.get_center())
            packet.scale(0.01)  # Start tiny; scaled back up by 100x below
            self.transition_packets.append(packet)
            self.add(packet)

        self.play(
            *[packet.animate.scale(100).shift(RIGHT * 0.5 + UP * 0.1)
              for packet in self.transition_packets],
            run_time=0.6
        )

        # Hover effect
        self.play(
            *[packet.animate.shift(UP * 0.1) for packet in self.transition_packets],
            run_time=0.3,
            rate_func=there_and_back
        )

        self.wait(0.3)

    # ========== SCENE 7: Transitions to Training ==========
    def scene_7_transitions_to_training(self):
        """Move the packets into the training box, compress them, pulse a glow."""
        training_center = self.training_box.get_center()

        # Fan the packets out inside the training box.
        self.play(
            *[packet.animate.move_to(training_center +
                                     UP * (0.3 - i * 0.2) +
                                     LEFT * 0.1 * (i % 2))
              for i, packet in enumerate(self.transition_packets)],
            run_time=1
        )

        # Compress into dense cluster
        self.play(
            *[packet.animate.scale(0.6).move_to(training_center)
              for packet in self.transition_packets],
            self.training_box.animate.set_fill(opacity=0.25),
            run_time=0.5
        )

        # Glow: a thick-stroked copy that expands and fades out. It is added
        # directly rather than via FadeIn(glow), which would fight glow.animate
        # for the same mobject and start the fade from a transparent frame.
        glow = self.training_box.copy()
        glow.set_stroke(width=6, opacity=0.6)
        self.add(glow)
        self.play(
            glow.animate.scale(1.1).set_opacity(0),
            run_time=0.5
        )
        self.remove(glow)

        # Agent remains idle
        self.play(
            self.agent_box.animate.set_fill(opacity=0.1),
            run_time=0.3
        )

    # ========== SCENE 8: Learning Happens ==========
    def scene_8_learning_happens(self):
        """Send an update dot from training to the agent and perturb the agent's nodes."""
        wave_dot = Dot(radius=0.15, color=TRAINING_COLOR)
        wave_dot.move_to(self.training_box.get_left())
        self.add(wave_dot)

        # Soft trail that records where the dot has been.
        glow_trail = TracedPath(
            wave_dot.get_center,
            stroke_color=TRAINING_COLOR,
            stroke_width=4,
            stroke_opacity=0.5
        )
        self.add(glow_trail)

        self.play(
            wave_dot.animate.move_to(self.agent_box.get_right()),
            run_time=1.2,
            rate_func=smooth
        )

        # Agent "learns": each node shifts slightly and gets a new fill opacity.
        # Both changes are chained on one .animate per node so neither is lost.
        node_anims = []
        for row in self.neural_net.nodes:
            for node in row:
                offset = np.array([np.random.uniform(-0.05, 0.05),
                                   np.random.uniform(-0.05, 0.05), 0])
                node_anims.append(
                    node.animate.shift(offset).set_fill(opacity=np.random.uniform(0.2, 0.6))
                )

        self.play(
            *node_anims,
            self.agent_box.animate.set_fill(opacity=0.2),
            run_time=0.8
        )

        # Clean up
        self.play(
            FadeOut(wave_dot),
            FadeOut(glow_trail),
            *[FadeOut(packet) for packet in self.transition_packets],
            run_time=0.4
        )

    # ========== SCENE 9: Loop Closes ==========
    def scene_9_loop_closes(self):
        """Reset the training box, breathe and pulse, show the loop caption, fade out."""
        # Back to the fill opacity the training box was created with in setup_cast.
        self.play(
            self.training_box.animate.set_fill(opacity=0.1),
            run_time=0.5
        )

        # Environments breathe again
        self.play(
            *[env.animate.scale(1.05) for env in self.env_boxes],
            run_time=0.3,
            rate_func=there_and_back
        )

        # Agent pulses with new shape
        self.play(
            self.agent_box.animate.set_stroke(width=5),
            run_time=0.2
        )
        self.play(
            self.agent_box.animate.set_stroke(width=3),
            run_time=0.2
        )

        # Concept caption: Act → Collect → Learn → Act again
        loop_text = Text("Act → Collect → Learn → Repeat",
                         font_size=28, color=WHITE)
        loop_text.to_edge(DOWN, buff=0.5)

        self.play(FadeIn(loop_text), run_time=0.5)
        self.wait(1)

        # Final fade
        self.play(
            *[FadeOut(mob) for mob in self.mobjects],
            run_time=1.5
        )


## Individual Scene Classes (for testing separately)

In [None]:
%%manim -ql -v WARNING Scene1CastAppears

class Scene1CastAppears(Scene):
    """Standalone preview of scene 1: the cast appears (static universe).

    Rebuilds the layout of ``RLLoopVisualization.setup_cast``, including its
    NeuralNetworkGrid node radius/spacing and its centred env stack, so this
    quick preview matches the opening of the full animation. Unlike the full
    scene, the environments breathe twice here.
    """

    def construct(self):
        num_envs = 4
        env_spacing = 0.8

        # Agent (Left)
        agent_box = RoundedRectangle(
            width=2.5, height=3.5, corner_radius=0.3,
            fill_opacity=0.15, stroke_width=3,
            color=AGENT_COLOR
        ).shift(LEFT * 4.5)

        agent_label = Text("Agent", font_size=24, color=AGENT_COLOR)
        agent_label.next_to(agent_box, UP, buff=0.2)

        neural_net = NeuralNetworkGrid(rows=4, cols=5, node_radius=0.06, spacing=0.35)
        neural_net.move_to(agent_box.get_center())

        # Environments (Center-right), stacked and centred on y = 0
        env_boxes = VGroup()
        for i in range(num_envs):
            env = RoundedRectangle(
                width=1.8, height=0.6, corner_radius=0.1,
                fill_opacity=0.2, stroke_width=2,
                color=ENV_COLOR
            )
            y_pos = (num_envs - 1) / 2 * env_spacing - i * env_spacing
            env.shift(RIGHT * 0.5 + UP * y_pos)
            env_boxes.add(env)

        env_label = Text("Environments", font_size=20, color=ENV_COLOR)
        env_label.next_to(env_boxes, UP, buff=0.3)

        # Training (Far right)
        training_box = RoundedRectangle(
            width=2, height=3.5, corner_radius=0.2,
            fill_opacity=0.1, stroke_width=2,
            color=TRAINING_COLOR
        ).shift(RIGHT * 4.5)

        training_label = Text("Training", font_size=20, color=TRAINING_COLOR)
        training_label.next_to(training_box, UP, buff=0.2)

        # Animate
        self.play(
            FadeIn(agent_box), FadeIn(agent_label), FadeIn(neural_net),
            run_time=1.5
        )
        self.play(
            *[FadeIn(env) for env in env_boxes],
            FadeIn(env_label),
            run_time=1
        )
        self.play(
            FadeIn(training_box), FadeIn(training_label),
            run_time=1
        )

        # Breathing effect: a 5% scale, there-and-back, so sizes end unchanged.
        for _ in range(2):
            self.play(
                *[env.animate.scale(1.05) for env in env_boxes],
                run_time=0.5,
                rate_func=there_and_back
            )

        self.wait()

## Optional: Scene with Camera Movement

In [None]:
%%manim -qm -v WARNING RLLoopWithCamera

class RLLoopWithCamera(MovingCameraScene):
    """Condensed RL loop with camera pans.

    Sequence: the cast fades in, the camera pans right for an abbreviated
    rollout, pans back left for an abbreviated learning update, and finally
    the saved camera frame is restored.
    """

    def construct(self):
        self.num_envs = 4
        self.setup_cast()

        # Initial view, restored at the end.
        self.camera.frame.save_state()

        # Scene 1
        self.scene_1_cast_appears()

        # Pan right during rollout (scenes 2-6)
        self.play(
            self.camera.frame.animate.shift(RIGHT * 1.5),
            run_time=2
        )

        # Abbreviated action sequence
        self.quick_rollout()

        # Pan left during learning (scenes 7-8)
        self.play(
            self.camera.frame.animate.shift(LEFT * 1.5),
            run_time=1.5
        )

        self.quick_learning()

        # Reset camera
        self.play(Restore(self.camera.frame), run_time=1)
        self.wait()

    def setup_cast(self):
        """Build the same layout as ``RLLoopVisualization.setup_cast``.

        This covers the NeuralNetworkGrid node radius/spacing and an env stack
        of ``self.num_envs`` boxes centred on y = 0.
        """
        self.agent_box = RoundedRectangle(
            width=2.5, height=3.5, corner_radius=0.3,
            fill_opacity=0.15, stroke_width=3,
            color=AGENT_COLOR
        ).shift(LEFT * 4.5)

        self.agent_label = Text("Agent", font_size=24, color=AGENT_COLOR)
        self.agent_label.next_to(self.agent_box, UP, buff=0.2)

        self.neural_net = NeuralNetworkGrid(rows=4, cols=5, node_radius=0.06, spacing=0.35)
        self.neural_net.move_to(self.agent_box.get_center())

        env_spacing = 0.8
        self.env_boxes = VGroup()
        for i in range(self.num_envs):
            env = RoundedRectangle(
                width=1.8, height=0.6, corner_radius=0.1,
                fill_opacity=0.2, stroke_width=2,
                color=ENV_COLOR
            )
            y_pos = (self.num_envs - 1) / 2 * env_spacing - i * env_spacing
            env.shift(RIGHT * 0.5 + UP * y_pos)
            self.env_boxes.add(env)

        self.env_label = Text("Environments", font_size=20, color=ENV_COLOR)
        self.env_label.next_to(self.env_boxes, UP, buff=0.3)

        self.training_box = RoundedRectangle(
            width=2, height=3.5, corner_radius=0.2,
            fill_opacity=0.1, stroke_width=2,
            color=TRAINING_COLOR
        ).shift(RIGHT * 4.5)

        self.training_label = Text("Training", font_size=20, color=TRAINING_COLOR)
        self.training_label.next_to(self.training_box, UP, buff=0.2)

    def scene_1_cast_appears(self):
        """Fade in the whole cast in a single 2-second animation."""
        self.play(
            FadeIn(self.agent_box),
            FadeIn(self.agent_label),
            FadeIn(self.neural_net),
            *[FadeIn(env) for env in self.env_boxes],
            FadeIn(self.env_label),
            FadeIn(self.training_box),
            FadeIn(self.training_label),
            run_time=2
        )

    def quick_rollout(self):
        """Abbreviated rollout: action arrows to every env, then packets to training."""
        action_colors = [RED, BLUE, GREEN, ORANGE]

        # Quick action arrows (colors cycle, so any num_envs works)
        arrows = VGroup()
        for i, env in enumerate(self.env_boxes):
            arrow = Arrow(
                start=self.agent_box.get_right(),
                end=env.get_left(),
                buff=0.1, stroke_width=3,
                color=action_colors[i % len(action_colors)]
            )
            arrows.add(arrow)

        self.play(*[GrowArrow(a) for a in arrows], run_time=0.5)
        self.play(*[FadeOut(a) for a in arrows], run_time=0.3)

        # Quick packets: one per env, just right of it
        packets = VGroup()
        for env in self.env_boxes:
            p = TransitionPacket()
            p.move_to(env.get_right() + RIGHT * 0.5)
            packets.add(p)

        self.play(FadeIn(packets), run_time=0.4)
        # Moves the group as a whole, so the packets keep their relative stacking.
        self.play(
            packets.animate.move_to(self.training_box.get_center()),
            run_time=0.8
        )
        self.play(FadeOut(packets), run_time=0.3)

    def quick_learning(self):
        """Abbreviated learning: one update dot travels from training to the agent."""
        wave = Dot(radius=0.15, color=TRAINING_COLOR)
        wave.move_to(self.training_box.get_left())

        self.play(
            wave.animate.move_to(self.agent_box.get_right()),
            run_time=1
        )
        self.play(FadeOut(wave), run_time=0.2)

---

## Usage Notes

1. **Run the main scene**: First run the setup, import, color-palette, and helper-class cells, then execute the `RLLoopVisualization` cell to generate the full animation
2. **Quality settings**:
   - `-ql` = low quality (480p, fast preview)
   - `-qm` = medium quality (720p)
   - `-qh` = high quality (1080p)
   - `-qk` = 4K quality
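3. **Test individual scenes**: Run the `Scene1CastAppears` cell (rendered at low quality, `-ql`) for quick iteration on the opening scene; previews of other scenes can be added following the same pattern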
4. **Camera version**: `RLLoopWithCamera` adds subtle camera panning for polish