# In [None]:
from manim import *

class DecisionTreeDemo(Scene):
    """Decision-tree pruning demo followed by ensemble-learning title cards.

    ``construct`` plays two sections back to back:

    1. A four-level tree of filled rectangles is grown level by level, pruned
       (level 4, then level 3), fully restored, and partially pruned again.
    2. A full-frame black rectangle fades in over the tree. Text cards then show
       "Weak Learner" morphing into "Base Learner", an arrow to "Strong
       Learner", and finally the "集成学习" title with "Bagging" and
       "Boosting" beneath it.
    """

    @staticmethod
    def _make_node(width, height, color, opacity):
        """Return a rectangle whose stroke and fill both use *color*."""
        return Rectangle(
            width=width, height=height,
            color=color, fill_color=color, fill_opacity=opacity,
        )

    @staticmethod
    def _make_edge(parent, child):
        """Return a white line from the bottom of *parent* to the top of *child*."""
        return Line(parent.get_bottom(), child.get_top(), color=WHITE)

    def _grow_pairs(self, *pairs):
        """Play one animation that grows every ``(node, edge)`` pair together.

        Animations are handed to ``play`` in pair order: each node's
        GrowFromCenter followed by its edge's Create.
        """
        animations = []
        for node, edge in pairs:
            animations.extend((GrowFromCenter(node), Create(edge)))
        self.play(*animations)

    def construct(self):
        self.camera.background_color = "#282c34"
        self._play_tree_section()
        self._play_title_section()

    def _play_tree_section(self):
        """Grow the tree, prune it twice, restore it, then prune part of it again."""
        # ---------- Level 1: root (dark red) ----------
        root = self._make_node(2.5, 1, RED_E, 0.9)
        root.move_to(UP * 2.5)
        self.play(FadeIn(root))

        # NOTE(review): next_to is called with scaled, non-unit directions
        # below. As I understand Manim, only the sign of each component picks
        # the aligned edges, and the magnitude scales the buff offset. These
        # numbers therefore act as hand-tuned spacing; confirm before retuning.

        # ---------- Level 2 ----------
        left = self._make_node(2, 0.8, RED_C, 0.8)
        left.next_to(root, DOWN * 3.5 + LEFT * 2.7)
        right = self._make_node(2, 0.8, RED_C, 0.8)
        right.next_to(root, DOWN * 3.5 + RIGHT * 2.8)
        edge_left = self._make_edge(root, left)
        edge_right = self._make_edge(root, right)
        self._grow_pairs((left, edge_left), (right, edge_right))

        # ---------- Level 3: two children under `left`, three under `right` ----------
        left_child1 = self._make_node(1.5, 0.7, RED_A, 0.8)
        left_child1.next_to(left, DOWN * 2.5 + LEFT * 0.01)
        left_child2 = self._make_node(1.5, 0.7, RED_A, 0.8)
        left_child2.next_to(left, DOWN * 2.5 + RIGHT * 0.01)
        edge_l1 = self._make_edge(left, left_child1)
        edge_l2 = self._make_edge(left, left_child2)
        right_child1 = self._make_node(1.5, 0.7, RED_A, 0.8)
        right_child2 = self._make_node(1.5, 0.7, RED_A, 0.8)
        right_child3 = self._make_node(1.5, 0.7, RED_A, 0.8)
        right_children = VGroup(right_child1, right_child2, right_child3).arrange(RIGHT, buff=0.6)
        right_children.next_to(right, DOWN * 2.5)
        edge_r1 = self._make_edge(right, right_child1)
        edge_r2 = self._make_edge(right, right_child2)
        edge_r3 = self._make_edge(right, right_child3)
        self._grow_pairs(
            (left_child1, edge_l1), (left_child2, edge_l2),
            (right_child1, edge_r1), (right_child2, edge_r2), (right_child3, edge_r3),
        )

        # ---------- Level 4: white leaves under left_child1 and right_child2 ----------
        ll_child1 = self._make_node(1.2, 0.6, WHITE, 0.9)
        ll_child2 = self._make_node(1.2, 0.6, WHITE, 0.9)
        ll_children = VGroup(ll_child1, ll_child2).arrange(RIGHT, buff=0.8)
        ll_children.next_to(left_child1, DOWN * 2)
        edge_ll1 = self._make_edge(left_child1, ll_child1)
        edge_ll2 = self._make_edge(left_child1, ll_child2)
        rm_child1 = self._make_node(1.2, 0.6, WHITE, 0.9)
        rm_child2 = self._make_node(1.2, 0.6, WHITE, 0.9)
        rm_children = VGroup(rm_child1, rm_child2).arrange(RIGHT, buff=0.8)
        rm_children.next_to(right_child2, DOWN * 2)
        edge_rm1 = self._make_edge(right_child2, rm_child1)
        edge_rm2 = self._make_edge(right_child2, rm_child2)
        self._grow_pairs(
            (ll_child1, edge_ll1), (ll_child2, edge_ll2),
            (rm_child1, edge_rm1), (rm_child2, edge_rm2),
        )
        self.wait(2)

        # Each level lists its nodes first, then its edges. The restore step
        # re-adds mobjects in this same order.
        level3 = [
            left_child1, left_child2, right_child1, right_child2, right_child3,
            edge_l1, edge_l2, edge_r1, edge_r2, edge_r3,
        ]
        level4 = [
            ll_child1, ll_child2, rm_child1, rm_child2,
            edge_ll1, edge_ll2, edge_rm1, edge_rm2,
        ]

        # ---------- First pruning: remove level 4 ----------
        self.play(*(FadeOut(mob) for mob in level4))
        self.wait(1)
        # ---------- Second pruning: also remove level 3 ----------
        self.play(*(FadeOut(mob) for mob in level3))
        self.wait(1)
        # ---------- Restore the full tree (levels 3 and 4 both return) ----------
        self.play(*(FadeIn(mob) for mob in level3 + level4))
        self.wait(1)
        # ---------- Final pruning: drop right_child2's two leaves AND left_child2 ----------
        # left_child2 has no children of its own (level 4 hangs off
        # left_child1), so removing it leaves no orphaned nodes.
        self.play(*(FadeOut(mob) for mob in (
            rm_child1, rm_child2, edge_rm1, edge_rm2, left_child2, edge_l2,
        )))
        self.wait(3)

    def _play_title_section(self):
        """Black out the frame and play the learner / ensemble-learning title cards."""
        # ---------- Black out the frame ----------
        # The tree is never faded out. It stays in the scene underneath this
        # opaque, frame-sized rectangle.
        black_bg = Rectangle(
            width=config.frame_width, height=config.frame_height,
            fill_color=BLACK, fill_opacity=1, stroke_width=0,
        ).move_to(ORIGIN)

        # "Ensemble learning" title starts invisible at UP*1.5. Its two
        # branches are placed relative to that starting position.
        t9 = Text("集成学习", font="Microsoft YaHei", font_size=80, color=GOLD_C)
        t9.move_to(ORIGIN).shift(UP * 1.5).set_opacity(0)
        t8 = Text("Bagging", font="Microsoft YaHei", font_size=70, color=GOLD_A)
        t7 = Text("Boosting", font="Microsoft YaHei", font_size=70, color=GOLD_A)
        t7.next_to(t9, DOWN, buff=1).shift(RIGHT * 3)
        t8.next_to(t9, DOWN, buff=1).shift(LEFT * 3)

        # Weak learner: English label with its Chinese caption below.
        e1 = Text("Weak Learner", font="Arial", font_size=72, color=GOLD_C)
        e1.move_to(UP * 0.5)
        c1 = Text("弱学习器", font="Microsoft YaHei", font_size=60, color=GOLD_A)
        c1.next_to(e1, DOWN, buff=0.5)

        # Base learner: the morph targets for e1 / c1. The caption is anchored
        # to its own English label, e2.
        e2 = Text("Base Learner", font="Arial", font_size=72, color=GOLD_C)
        e2.move_to(UP * 0.5)
        c2 = Text("基学习器", font="Microsoft YaHei", font_size=60, color=GOLD_A)
        c2.next_to(e2, DOWN, buff=0.5)

        # Strong learner: shown on the right. It is pre-scaled by 0.8 to match
        # the 0.8 scale applied to e1 / c1 when they slide left.
        e3 = Text("Strong Learner", font="Arial", font_size=72, color=GOLD_C)
        e3.move_to(UP * 0.43).shift(RIGHT * 3.3).scale(0.8)
        c3 = Text("强学习器", font="Microsoft YaHei", font_size=60, color=GOLD_A)
        c3.next_to(e3, DOWN, buff=0.4).scale(0.8)

        # The arrow starts where c1's right edge will be after the later
        # shift(LEFT*3.3). It is built now, before c1 moves.
        arrow = Arrow(
            c1.get_right() + LEFT * 3.3, c3.get_left(),
            buff=0.4, color=WHITE, stroke_width=6,
        ).shift(UP * 0.4)

        self.play(FadeIn(black_bg))
        self.wait(5)
        self.play(Write(e1))
        self.play(Write(c1))
        self.wait(2)
        # Transform morphs e1 / c1 in place; e2 / c2 are never added themselves.
        self.play(
            Transform(e1, e2),
            Transform(c1, c2),
        )
        self.wait(3)
        self.play(
            e1.animate.shift(LEFT * 3.3).scale(0.8),
            c1.animate.shift(LEFT * 3.3).scale(0.8),
            FadeIn(c3),
            FadeIn(e3),
            GrowArrow(arrow),
        )
        self.wait(2)

        learner_cards = VGroup(e1, e3, c1, c3, arrow)
        # Slide the learner cards off the bottom while the title fades in.
        self.play(
            learner_cards.animate.shift(DOWN * 7),
            t9.animate.shift(DOWN * 1.3).set_opacity(1),
        )
        self.wait(1)
        self.play(
            t9.animate.shift(UP * 1.2),
            Write(t7),
            Write(t8),
        )
        self.wait(3)
        # Exit: the title leaves upward and its branches leave downward.
        self.play(
            t9.animate.shift(UP * 4),
            t7.animate.shift(DOWN * 4),
            t8.animate.shift(DOWN * 4),
        )