# In [None]:
from manim import *

class GD3(Scene):
    """Animate the intuition behind momentum in gradient descent.

    Sequence played by :meth:`construct`:

    1. The title "动量" (momentum) fades in at the top-left.
    2. Left panel: four stacked wavy curves and five slanted lines are drawn,
       then a red start dot and a red zig-zag chain of four arrows.
    3. Right panel (the left panel shifted by ``RIGHT * 6``): a copy of the
       curves and lines, the first red arrow, a blue arrow continuing it, the
       second red arrow and a thick yellow arrow.
    4. A copy of the blue arrow morphs into "惯性" (inertia) next to the title,
       followed by the caption "累积历史梯度 → 减少振荡 → 加速优化"
       (accumulate past gradients → reduce oscillation → speed up optimisation).
    5. Everything fades out.
    """

    @staticmethod
    def _wavy_concave(x, waviness):
        """Return the height of the damped, wavy bowl profile at ``x``.

        With ``u = x / 2.5`` the value is
        ``(3 * (u**2 - 0.3) + waviness * sin(3 * pi * u)) * exp(-u**2 / 2)``:
        a parabola with a sine ripple, flattened away from the origin by a
        Gaussian envelope.
        """
        x_norm = x / 2.5
        concave_core = 3 * (x_norm**2 - 0.3)
        wave = waviness * np.sin(3 * x_norm * PI)
        decay_factor = np.exp(-0.5 * (x_norm**2))
        return (concave_core + wave) * decay_factor

    def _build_curves(self, colors, curve_spacing, horizontal_shift, waviness):
        """Return a VGroup with one wavy curve per entry of ``colors``.

        Curve ``i`` samples the profile over ``t`` in [-3, 3]. It is shifted
        right by ``i * horizontal_shift`` and up by ``i * curve_spacing``, and
        its stroke width grows by 0.3 per level.
        """
        curves = VGroup()
        for i, color in enumerate(colors):
            # `i=i` freezes the current level; a plain closure would see the last i.
            curve = ParametricFunction(
                lambda t, i=i: np.array([
                    t + i * horizontal_shift,
                    self._wavy_concave(t, waviness) + i * curve_spacing,
                    0,
                ]),
                t_range=[-3, 3],
                color=color,
                stroke_width=4 + i * 0.3,
            )
            # ParametricFunction already smooths its path (use_smoothing=True by
            # default); the former `set_smoothness(0.95)` only stored an unused
            # attribute via Mobject's generic set_* fallback, so it was removed.
            curves.add(curve)
        return curves

    def _build_slanted_lines(self, x_positions, n_curves, curve_spacing,
                             horizontal_shift, waviness, line_slope, line_length):
        """Return a VGroup of light-green slanted lines, one per x position.

        Each line is centred on the mean of the points where the vertical
        ``x = const`` meets the profiles of the ``n_curves`` stacked curves.
        ``line_slope`` is horizontal run per unit of vertical rise, and
        ``line_length`` is the line's vertical extent. Its actual length is
        ``line_length * sqrt(1 + line_slope**2)``.
        """
        half_length = line_length / 2
        half_span = np.array([half_length * line_slope, half_length, 0])
        lines = VGroup()
        for x in x_positions:
            hits = [
                np.array([
                    x,
                    self._wavy_concave(x - i * horizontal_shift, waviness) + i * curve_spacing,
                    0,
                ])
                for i in range(n_curves)
            ]
            center = sum(hits) / len(hits)
            lines.add(Line(
                start=center - half_span,
                end=center + half_span,
                stroke_width=3.0,
                stroke_opacity=0.8,
                color=GREEN_A,
            ))
        return lines

    @staticmethod
    def _build_zigzag_arrows(start, segments_count, horiz_length, vert_step, alternation):
        """Return a VGroup of red arrows forming a connected zig-zag path.

        Segment ``i`` moves ``horiz_length * alternation`` to the right for
        even ``i`` and to the left for odd ``i``, plus ``vert_step``
        vertically. Each arrow begins where the previous one ended, starting
        at ``start``.
        """
        arrows = VGroup()
        current = start
        for i in range(segments_count):
            direction = (-1) ** i
            nxt = current + np.array([horiz_length * direction * alternation, vert_step, 0])
            arrows.add(Arrow(
                start=current,
                end=nxt,
                buff=0.1,
                stroke_width=5,
                color=RED,
                tip_length=0.15,
            ))
            current = nxt
        return arrows

    def construct(self):
        """Play the full momentum animation described in the class docstring."""
        self.camera.background_color = BLACK

        # --- Title -----------------------------------------------------------
        t1 = Text("动量", font="Microsoft YaHei", font_size=44, color=RED).to_edge(UP + LEFT).shift(DOWN*0.2+RIGHT*0.5)
        # A LaggedStart wrapping one animation adds nothing; play it directly.
        self.play(FadeIn(t1, shift=RIGHT))
        self.wait(1)

        # --- Configuration ---------------------------------------------------
        curve_spacing = 1.2        # vertical gap between stacked curves
        horizontal_shift = 0.4     # extra rightward offset per curve level
        green_colors = [GREEN_A, GREEN_B, GREEN_C, GREEN_E]  # one per curve, bottom to top
        waviness = 0.3             # amplitude of the sine ripple
        line_slope = 0.335         # horizontal run per unit rise of the slanted lines
        line_length = 5.0          # vertical extent of each slanted line
        x_positions = [-1.24, -0.7, 0.41, 1.0, 1.85]
        panel_offset = RIGHT * 6.0  # left panel -> right panel

        # --- Left panel: curves and slanted lines ------------------------------
        curves = self._build_curves(green_colors, curve_spacing, horizontal_shift, waviness)
        slanted_lines = self._build_slanted_lines(
            x_positions, len(green_colors), curve_spacing, horizontal_shift,
            waviness, line_slope, line_length,
        )
        all_elements = VGroup(curves, slanted_lines)
        all_elements.scale(0.7)
        all_elements.shift(LEFT*3.8 + DOWN*2.6)

        # reversed(): the highest curve (largest vertical offset) is drawn first.
        self.play(
            LaggedStart(
                *[Create(curve) for curve in reversed(curves)],
                lag_ratio=0.4,
                run_time=3,
            )
        )
        self.wait(0.5)
        self.play(
            AnimationGroup(*[Create(line) for line in slanted_lines], lag_ratio=0.15),
            run_time=2,
        )
        self.wait(2)

        # --- Left panel: zig-zag path ----------------------------------------
        path_start = curves[0].get_end() + LEFT * 3 + UP * 1
        start_dot1 = Dot(point=path_start, color=RED, radius=0.1)
        self.play(FadeIn(start_dot1))

        arrow_lines = self._build_zigzag_arrows(
            path_start,
            segments_count=4,
            horiz_length=2.3,
            vert_step=-0.4,     # negative: each segment steps downward
            alternation=1.0,
        )
        self.play(
            LaggedStart(
                *[GrowArrow(arrow) for arrow in arrow_lines],
                lag_ratio=0.25,
                run_time=4.5,
            )
        )
        self.wait(5)

        # --- Right panel: copied curves and lines ------------------------------
        right_elements = all_elements.copy()
        right_elements.shift(panel_offset)
        right_curves, right_lines = right_elements.submobjects
        self.play(
            LaggedStart(
                *[Create(curve) for curve in reversed(right_curves)],
                lag_ratio=0.4,
                run_time=1.0,
            )
        )
        self.play(
            AnimationGroup(*[Create(line) for line in right_lines], lag_ratio=0.02),
            run_time=2,
        )
        self.wait(1)

        # --- Right panel: first step, inertia and correction -------------------
        right_arrows = arrow_lines.copy()
        right_arrows.shift(panel_offset)
        # Mirror the left start dot exactly. Arrow.get_start() lies `buff` (0.1)
        # along the arrow from the requested start, so using it would offset
        # the dot from its left-panel counterpart.
        start_dot_right = Dot(point=path_start + panel_offset, color=RED, radius=0.1)
        self.play(FadeIn(right_arrows[0]), FadeIn(start_dot_right))

        first_arrow = right_arrows[0]
        first_tip = first_arrow.get_end()
        # Blue "inertia" arrow: continues the first step's direction at 60% length.
        blue_arrow = Arrow(
            start=first_tip,
            end=first_tip + (first_tip - first_arrow.get_start()) * 0.6,
            buff=0.1,
            stroke_width=5,
            color=BLUE,
            tip_length=0.12,
        )
        self.play(GrowArrow(blue_arrow), run_time=1.5)
        self.wait(0.5)
        self.play(FadeIn(right_arrows[1]))
        self.wait(0.5)

        # Thick yellow arrow from the same tip, pointing down-left.
        # NOTE(review): Manim's Arrow caps stroke width relative to its length
        # (max_stroke_width_to_length_ratio), so stroke_width=40 likely renders
        # much thinner than requested — confirm the intended thickness.
        yellow_arrow = Arrow(
            start=first_tip,
            end=first_tip + LEFT*0.9 + DOWN*0.7,
            buff=0.1,
            stroke_width=40,
            color=YELLOW,
            tip_length=0.2,
        )
        self.play(GrowArrow(yellow_arrow), run_time=1.5)
        self.wait(2)

        # The blue arrow is already on screen; growing it again re-emphasises it.
        self.play(GrowArrow(blue_arrow), run_time=1.5)
        self.wait(0.5)

        # --- Blue arrow morphs into the word "惯性" (inertia) --------------------
        blue_arrow_copy = blue_arrow.copy()
        inertia_text = Text("惯性", font="Microsoft YaHei", font_size=50, color=BLUE)
        inertia_text.next_to(t1, RIGHT, buff=0.8)
        self.play(blue_arrow_copy.animate.scale(1.2), run_time=0.5)
        # inertia_text itself is never added; it only supplies the target shape
        # and the anchor position for the caption below.
        self.play(Transform(blue_arrow_copy, inertia_text), run_time=1.5)

        t7 = Text("累积历史梯度 → 减少振荡 → 加速优化", font="Microsoft YaHei", font_size=32, color=BLUE)
        t7.next_to(inertia_text, RIGHT, buff=0.5)
        self.play(FadeIn(t7))
        self.wait(1.5)

        # --- Clear the stage -------------------------------------------------
        self.play(FadeOut(Group(*self.mobjects)), run_time=2)
        self.wait(1)