# In [None]:
from manim import *
import numpy as np
import random

class ActivationFunctionScene(Scene):
    """Animated walkthrough of common activation functions and logistic regression.

    The scene introduces ReLU / Sigmoid / Tanh / Softmax (name, formula and a
    small plot for each), narrows the focus to the sigmoid ("logistic")
    function, shows a small prediction table, and then builds up the
    maximum-likelihood objective and the cross-entropy loss.
    """

    # Font used for every Chinese caption.
    # NOTE(review): "Microsoft YaHei" is a Windows font; on other systems the
    # text renderer substitutes a fallback font.
    CN_FONT = "Microsoft YaHei"

    # ------------------------------------------------------------------
    # Layout helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _make_axes(x_range, y_range):
        """Return the 3x3 axes (with small tips, scaled by 0.85) used by every plot."""
        return Axes(
            x_range=x_range,
            y_range=y_range,
            x_length=3,
            y_length=3,
            tips=True,
            axis_config={"tip_length": 0.15, "tip_width": 0.08},
        ).scale(0.85)

    @staticmethod
    def _make_asymptote(axes, x_min, x_max, y):
        """Return a grey dashed horizontal line at height ``y`` and its y-axis label."""
        line = DashedLine(
            start=axes.c2p(x_min, y),
            end=axes.c2p(x_max, y),
            color=GRAY,
        )
        label = MathTex(str(y)).scale(0.5).next_to(axes.c2p(0, y), LEFT, buff=0.1)
        return line, label

    @staticmethod
    def _fit_to_frame(mob, margin=0.3):
        """Shrink ``mob`` (never enlarge it) so it fits the frame width minus ``margin`` per side."""
        max_width = config.frame_width - 2 * margin
        if mob.width > max_width:
            mob.scale_to_fit_width(max_width)
        return mob

    @classmethod
    def _build_relu_plot(cls):
        """Return ``VGroup(axes, curve)`` for ReLU(x) = max(0, x)."""
        axes = cls._make_axes([-3, 3, 1], [-1, 5, 1])
        curve = axes.plot(lambda x: max(0, x), color=BLUE)
        return VGroup(axes, curve)

    @classmethod
    def _build_sigmoid_plot(cls):
        """Return ``VGroup(axes, curve, asymptote, label)`` for the sigmoid, with y = 1 marked."""
        axes = cls._make_axes([-6, 6, 2], [-0.1, 1.6, 0.2])
        curve = axes.plot(lambda x: 1 / (1 + np.exp(-x)), color=RED)
        line, label = cls._make_asymptote(axes, -6, 6, 1)
        return VGroup(axes, curve, line, label)

    @classmethod
    def _build_tanh_plot(cls):
        """Return the tanh plot with dashed lines (and labels) at y = 1 and y = -1."""
        axes = cls._make_axes([-3, 3, 1], [-1.1, 1.6, 0.5])
        curve = axes.plot(lambda x: np.tanh(x), color=YELLOW)
        upper, upper_label = cls._make_asymptote(axes, -3, 3, 1)
        lower, lower_label = cls._make_asymptote(axes, -3, 3, -1)
        return VGroup(axes, curve, upper, lower, upper_label, lower_label)

    @classmethod
    def _build_softmax_plot(cls):
        """Return axes plus a bar chart of softmax([0, 1, 2]); bar i stands at x = i."""
        axes = cls._make_axes([0, 3, 1], [0, 1.1, 0.2])
        x_vals = np.array([0, 1, 2])
        exp_vals = np.exp(x_vals)
        probs = exp_vals / exp_vals.sum()
        base_y = axes.c2p(0, 0)[1]
        bars = VGroup()
        for x, p in zip(x_vals, probs):
            # Measure the height through the axes so the bar top lines up with
            # p on the y-axis: the axes were scaled by 0.85 and their y_range is
            # [0, 1.1], so ``p * y_length`` would overstate every bar.
            height = axes.c2p(0, p)[1] - base_y
            bar = Rectangle(width=0.4, height=height,
                            fill_color=PURPLE, fill_opacity=0.7)
            bar.next_to(axes.c2p(x, 0), UP, buff=0)
            bars.add(bar)
        return VGroup(axes, bars)

    # ------------------------------------------------------------------
    # Scene
    # ------------------------------------------------------------------
    def construct(self):
        # ---------- 1. Title ----------
        activation_text = Text(
            "激活函数",
            font=self.CN_FONT,
            font_size=100,
            color=GREEN,
        ).move_to(ORIGIN)
        self.add(activation_text)
        self.wait(1)

        # ---------- 2. Move the title up and shrink it ----------
        self.play(activation_text.animate.shift(UP * 2.6).scale(0.6))
        self.wait(1)

        # ---------- 3. Activation-function names ----------
        relu_text = Text("ReLU", font_size=40, color=BLUE)
        sigmoid_text = Text("Sigmoid", font_size=40, color=RED)
        tanh_text = Text("Tanh", font_size=40, color=YELLOW)
        softmax_text = Text("Softmax", font_size=40, color=PURPLE)

        activations = VGroup(relu_text, sigmoid_text, tanh_text, softmax_text).arrange(RIGHT, buff=1.5)
        activations.next_to(activation_text, DOWN, buff=0.7)
        self.play(FadeIn(activations))
        self.wait(1)

        # ---------- 4. Formulas ----------
        relu_formula = MathTex(r"\text{ReLU}(x) = \max(0,x)").scale(0.6)
        sigmoid_formula = MathTex(r"\text{Sigmoid}(x) = \frac{1}{1 + e^{-x}}").scale(0.6)
        tanh_formula = MathTex(r"\text{Tanh}(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}}").scale(0.6)
        softmax_formula = MathTex(r"\text{Softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}").scale(0.6)

        formulas = VGroup(relu_formula, sigmoid_formula, tanh_formula, softmax_formula).arrange(RIGHT, buff=0.5).shift(LEFT * 0.3)

        # ---------- 5. Plots ----------
        relu_plot = self._build_relu_plot()
        sigmoid_plot = self._build_sigmoid_plot()
        tanh_plot = self._build_tanh_plot()
        softmax_plot = self._build_softmax_plot()

        plots = VGroup(relu_plot, sigmoid_plot, tanh_plot, softmax_plot).arrange(RIGHT, buff=0.7)

        # ---------- 6. Formula row stacked above the plot row ----------
        all_content = VGroup(formulas, plots).arrange(DOWN, buff=0.7)
        all_content.next_to(activations, DOWN, buff=0.5)
        self.play(FadeIn(all_content, lag_ratio=0.2))
        self.wait(5)

        # ---------- 7. Dashed box around the first three functions (name, formula, plot) ----------
        box_group = VGroup(
            VGroup(relu_text, sigmoid_text, tanh_text),
            VGroup(relu_formula, sigmoid_formula, tanh_formula),
            VGroup(relu_plot, sigmoid_plot, tanh_plot),
        )
        # The stroke width goes on the rectangle: DashedVMobject's dashes are
        # pieces of the wrapped shape and inherit its style, so a stroke_width
        # passed to DashedVMobject itself does not reach them.
        dashed_rect = DashedVMobject(
            SurroundingRectangle(box_group, color=RED, buff=0.3, stroke_width=7),
            num_dashes=58,
        )
        self.play(Create(dashed_rect))
        self.wait(3)

        # ---------- Grey out ReLU, then Tanh ----------
        self.play(
            relu_text.animate.set_color(GRAY),
            relu_formula.animate.set_color(GRAY),
            relu_plot.animate.set_color(GRAY),
        )
        self.wait(5)
        self.play(
            tanh_text.animate.set_color(GRAY),
            tanh_formula.animate.set_color(GRAY),
            tanh_plot.animate.set_color(GRAY),
        )
        self.wait(3)

        # ---------- Remove ReLU and Tanh; slide Sigmoid and Softmax left ----------
        self.play(
            FadeOut(relu_text),
            FadeOut(relu_formula),
            FadeOut(relu_plot),
            FadeOut(tanh_text),
            FadeOut(tanh_formula),
            FadeOut(tanh_plot),
            FadeOut(dashed_rect),
            sigmoid_text.animate.shift(LEFT * 1.5),
            sigmoid_formula.animate.shift(LEFT * 1.5).scale(1.2),
            sigmoid_plot.animate.shift(LEFT * 1.5),
            softmax_text.animate.shift(LEFT * 1.4),
            softmax_formula.animate.shift(LEFT * 1.4).scale(1.2),
            softmax_plot.animate.shift(LEFT * 1.4),
        )
        self.wait(3)

        # ---------- Keep only Sigmoid, enlarged ----------
        self.play(
            activation_text.animate.shift(UP * 0.2).scale(0.9),
            FadeOut(softmax_text),
            FadeOut(softmax_formula),
            FadeOut(softmax_plot),
            sigmoid_text.animate.shift(RIGHT * 3.2 + UP * 0.2).scale(1.3),
            sigmoid_formula.animate.shift(RIGHT * 3.2 + UP * 0.2).scale(1.3),
            sigmoid_plot.animate.shift(RIGHT * 3.2 + UP * 0.2).scale(1.3),
        )
        self.wait(2)

        # "Sigmoid" + "函数" caption, which slides in while fading in
        stext = Text("函数", font=self.CN_FONT, font_size=50, color=RED)
        stext.next_to(sigmoid_text, RIGHT).shift(LEFT * 0.82 + UP * 0.05)
        stext.shift(RIGHT * 0.8)
        stext.set_opacity(0)
        self.play(
            sigmoid_text.animate.shift(LEFT * 0.82),
            stext.animate.shift(LEFT * 0.8).set_opacity(1),
        )
        self.wait(2)

        # "= 逻辑函数" suffix, same slide-and-fade entrance
        ss = Text("= 逻辑函数", font=self.CN_FONT, font_size=50, color=RED)
        ss.next_to(stext, RIGHT).shift(LEFT * 1.4)
        ss.shift(RIGHT * 0.8)
        ss.set_opacity(0)
        self.play(
            sigmoid_text.animate.shift(LEFT * 1.4),
            stext.animate.shift(LEFT * 1.4),
            ss.animate.shift(LEFT * 0.8).set_opacity(1),
        )
        self.wait(2)

        # Replace everything with the model formula f_hat = sigma(z)
        sss1 = MathTex(r"\hat{f} = \sigma(z)", font_size=60, color=RED).scale(1.4)
        sss1.move_to(ORIGIN)
        self.play(
            FadeOut(activation_text),
            FadeOut(sigmoid_text),
            FadeOut(stext),
            FadeOut(ss),
            FadeOut(sigmoid_formula),
            FadeOut(sigmoid_plot),
            Write(sss1),
        )
        self.wait(3)

        # ---------- Prediction table ----------
        # Every cell is typeset with MathTex; plain words are wrapped in \text{}
        # so they render upright instead of as a product of italic variables.
        table_content = [
            ["f", r"\hat{f}", r"\text{judge}"],
            ["1", "0.9", r"\text{good}"],
            ["1", "0.1", r"\text{bad}"],
            ["0", "0.1", r"\text{good}"],
            ["0", "0.8", r"\text{bad}"],
        ]
        table = Table(
            table_content,
            include_outer_lines=True,
            line_config={"stroke_width": 1},
            element_to_mobject=MathTex,
        )
        table.scale(0.8)
        table.move_to(DOWN)
        self.play(
            sss1.animate.shift(UP * 2.55),
            Create(table),
        )
        self.wait(3)

        # ---------- Colour the rows: "good" rows green, "bad" rows red ----------
        rows = table.get_rows()
        self.play(*[rows[r][c].animate.set_color(GREEN) for r in (1, 3) for c in range(3)])
        self.play(*[rows[r][c].animate.set_color(RED) for r in (2, 4) for c in range(3)])
        self.wait(2)

        # Positions below are computed before the table/sss1 move in the next play().
        st = MathTex(r"f \in \{0,1\}", font_size=60, color=GOLD_C)
        st.next_to(table, UP, buff=0.5).shift(LEFT * 3)
        st2 = MathTex(r"p(f=1\mid x) = \sigma(z)", font_size=70, color=PURPLE_C)
        st2.next_to(table, RIGHT).shift(LEFT * 2.65 + UP * 0.7)
        # st5 is a copy of sss1 placed 2 units above it (and 3 to the right);
        # it later slides down 2.3 units. st5_landing marks where it lands so
        # the first arrow can be anchored before the slide happens.
        st5 = sss1.copy().shift(RIGHT * 3.0 + UP * 2)
        st5_landing = st5.copy().shift(DOWN * 2.3)
        st6 = MathTex(r"p(f \mid x; \theta)", font_size=70, color=BLUE)
        st6.next_to(st5, DOWN).shift(DOWN * 5.8)
        arrow = Arrow(
            start=st5_landing.get_bottom(),
            end=st2.get_top(),
            buff=0.2,
            color=WHITE,
        )
        arrow2 = Arrow(
            start=st2.get_bottom(),
            end=st6.get_top(),
            buff=0.2,
            color=WHITE,
        )
        self.play(
            sss1.animate.shift(UP * 2),
            table.animate.shift(LEFT * 3.0),
            Write(st),
        )
        self.wait(2)
        self.play(st5.animate.shift(DOWN * 2.3))
        self.wait(2)
        self.play(
            FadeIn(arrow),
            FadeIn(st2),
        )
        self.play(FadeIn(arrow2), FadeIn(st6))
        self.wait(5)
        self.play(
            FadeOut(table),
            FadeOut(st),
            FadeOut(arrow),
            FadeOut(st2),
            FadeOut(st5),
            FadeOut(arrow2),
            st6.animate.move_to(UP).scale(1.2).shift(UP * 1.6),
        )
        self.wait(2)

        # ---------- Parameters and the maximum-likelihood objective ----------
        q1 = MathTex(r"\theta = (w^T, b)", font_size=60, color=BLUE)
        q1.move_to(UP).shift(UP * 0.5)
        self.play(Write(q1))
        q2 = MathTex(
            r"\hat{\theta} = \arg\max_{\theta} \prod_{i=1}^{N} p(y_i \mid x_i; \theta)",
            font_size=65,
            color=GOLD_C,
        )
        q2.next_to(q1, DOWN, buff=0.7)
        self.play(Write(q2))
        self.wait(5)
        q3 = Text("最大似然估计", font=self.CN_FONT, font_size=45, color=GOLD_D)
        q3.move_to(UP).shift(UP * 1.99)
        q4 = MathTex(
            r"\sum_{i=1}^N a_i",
            font_size=70,
            color=GREEN,
        )
        q5 = MathTex(
            r"\prod_{i=1}^N a_i",
            font_size=70,
            color=RED,
        )
        q4.move_to(UP, DOWN).shift(RIGHT * 2.0 + DOWN * 2.5)
        q5.move_to(UP, DOWN).shift(LEFT * 2.0 + DOWN * 2.5)
        self.play(
            st6.animate.shift(UP * 3.0),
            q1.animate.shift(UP * 3),
            q2.animate.shift(UP * 2.0),
            FadeIn(q3),
        )
        # Log-likelihood of binary labels. The probability inside BOTH logs is
        # the model's p(y_i = 1 | x_i; theta); writing p(y_i | x_i; theta)
        # there would make both terms collapse to log p(y=1 | x).
        log_likelihood = (
            r"\sum_{i=1}^N \Big( y_i \log p(y_i{=}1 \mid x_i; \theta) "
            r"+ (1-y_i)\log\big(1 - p(y_i{=}1 \mid x_i; \theta)\big) \Big)"
        )
        q6 = MathTex(
            r"\ell(\theta) = " + log_likelihood,
            font_size=50,
            color=YELLOW_B,
        )
        q7 = MathTex(
            r"\hat{\theta} = \arg\max_{\theta} " + log_likelihood,
            font_size=45,
            color=YELLOW_B,
        )
        # These formulas are long; shrink them only if they would overflow the frame.
        self._fit_to_frame(q7).move_to(ORIGIN).shift(DOWN * 0.6)
        self._fit_to_frame(q6).move_to(ORIGIN).shift(DOWN * 0.6)
        self.wait(3)
        self.play(Write(q5))
        self.wait(4)
        self.play(Write(q4))
        self.wait(3)
        self.play(
            q4.animate.shift(LEFT * 10),
            q5.animate.shift(LEFT * 6),
            Write(q6),
        )
        self.wait(2)
        self.play(TransformMatchingTex(q6, q7), run_time=2)
        self.wait(2)

        # ---------- Loss = negative log-likelihood ----------
        q8 = MathTex(
            r"L(\theta) = - \ell(\theta)",
            font_size=70,
            color=WHITE,
        ).next_to(q6, DOWN, buff=0.9)
        self.play(Write(q8))
        self.wait(2)
        q9 = MathTex(
            r"L(\theta) = - \sum_{i=1}^N \Big( y_i \log \hat{y}_i "
            r"+ (1-y_i)\log(1-\hat{y}_i) \Big)",
            font_size=60,
            color=WHITE,
        ).next_to(q6, DOWN, buff=0.3)
        self.play(TransformMatchingTex(q8, q9))
        self.wait(2)
        self.play(
            FadeOut(q3),
            FadeOut(q2),
            FadeOut(q7),
            q9.animate.move_to(ORIGIN),
        )
        self.wait(1)
        self.play(q9.animate.shift(UP * 1.5))

        # ---------- Case label = 1 ----------
        y1_text = MathTex(
            r"f=1: L = - \log(\hat{f})",
            font_size=60,
            color=BLUE,
        ).next_to(q9, DOWN, buff=0.8)
        self.play(Write(y1_text))
        self.wait(2)

        # ---------- Case label = 0 ----------
        y0_text = MathTex(
            r"f=0: L = - \log(1-\hat{f})",
            font_size=60,
            color=RED,
        ).next_to(y1_text, DOWN, buff=0.5)
        self.play(Write(y0_text))
        self.wait(3)

        # "Cross-entropy loss" title, which slides in from above while fading in
        t11 = Text("交叉熵损失", font=self.CN_FONT, font_size=70, color=GOLD_D)
        t11.next_to(q9, UP).shift(DOWN * 0.5)
        t11.shift(UP * 0.8)
        t11.set_opacity(0)
        self.play(
            q9.animate.shift(DOWN * 0.8),
            y1_text.animate.shift(DOWN * 0.8),
            y0_text.animate.shift(DOWN * 0.8),
            t11.animate.shift(DOWN * 0.8).set_opacity(1),
        )
        self.wait(2)
        t22 = Text("最大似然估计", font=self.CN_FONT, font_size=35, color=GOLD_D)
        t22.next_to(t11, RIGHT).shift(DOWN * 0.3)
        self.play(Write(t22))
        self.wait(5)
        self.play(
            FadeOut(q9),
            FadeOut(y1_text),
            FadeOut(y0_text),
            FadeOut(t11),
            FadeOut(t22),
        )