# 信息增益决策树动画
使用 Manim 展示信息增益计算中“经济水平”优于“性别”的划分效果。

In [None]:
from manim import *
from manim.utils.tex import TexTemplate

In [None]:
class EconomicGainTree(Scene):
    """Animate why "经济水平" is a better root split than "性别".

    The scene shows the information gain of both attributes, computed from the
    dataset defined in ``construct`` so the headline numbers always agree with
    the tables shown later. It then grows a decision tree rooted at
    "经济水平", displays the subset of rows reaching each branch, and finally
    splits the "中" branch on "性别".
    """

    # Column positions inside each dataset row: [性别, 经济水平, label].
    GENDER = 0
    ECON = 1
    LABEL = 2

    @staticmethod
    def _entropy(labels):
        """Return the base-2 Shannon entropy of *labels* (0.0 for an empty sequence)."""
        from collections import Counter
        from math import log2

        total = len(labels)
        if total == 0:
            return 0.0
        return -sum(
            count / total * log2(count / total) for count in Counter(labels).values()
        )

    @classmethod
    def _information_gain(cls, rows, feature_index, label_index=LABEL):
        """Return IG(D, A) = H(D) - sum_v |D_v| / |D| * H(D_v).

        *rows* is the dataset D, *feature_index* the column of attribute A and
        *label_index* the column holding the class label. An empty dataset
        yields 0.0.
        """
        from collections import defaultdict

        if not rows:
            return 0.0
        labels_by_value = defaultdict(list)
        for row in rows:
            labels_by_value[row[feature_index]].append(row[label_index])
        conditional = sum(
            len(subset) / len(rows) * cls._entropy(subset)
            for subset in labels_by_value.values()
        )
        return cls._entropy([row[label_index] for row in rows]) - conditional

    @classmethod
    def _leaf_caption(cls, rows):
        """Return "label=<v>" listing every distinct label in *rows*, e.g. "label=0/1"."""
        return "label=" + "/".join(sorted({row[cls.LABEL] for row in rows}))

    @staticmethod
    def _make_table(rows):
        """Build the data table for *rows*, scaled and placed on the right of the frame."""
        header = [Text(name, font="SimHei") for name in ("性别", "经济", "label")]
        return Table(
            rows,
            include_outer_lines=True,
            col_labels=header,
            # Render the Chinese cell values with the same CJK font as the header
            # instead of relying on the system default-font fallback.
            element_to_mobject_config={"font": "SimHei"},
        ).scale(0.4).move_to(RIGHT * 3.5 + DOWN * 0.5)

    @staticmethod
    def _make_node(name, center):
        """Return (ellipse, caption) for an internal tree node centred at *center*."""
        oval = Ellipse(width=2.4, height=1.0, color=BLUE_C).move_to(center)
        caption = Text(name, font="SimHei", font_size=32, color=BLUE).move_to(oval.get_center())
        return oval, caption

    @staticmethod
    def _make_edges(node, x_offsets):
        """Return one white edge per x-offset, starting just below *node* and dropping 1 unit."""
        start = node.get_bottom() + DOWN * 0.1
        return VGroup(
            *(Line(start, start + DOWN * 1.0 + RIGHT * dx, color=WHITE) for dx in x_offsets)
        )

    @staticmethod
    def _make_leaf(caption, anchor, width=1.5, gap=0.5):
        """Return a green boxed *caption* centred *gap* units below *anchor*."""
        box = Rectangle(width=width, height=0.7, color=GREEN)
        text = Text(caption, font="Microsoft YaHei", font_size=24, color=GREEN, weight=BOLD)
        text.move_to(box.get_center())
        return VGroup(box, text).move_to(anchor + DOWN * gap)

    def _reveal_branch(self, edge, table, previous_table=None):
        """Turn *edge* red, fade out *previous_table* (if any), draw *table*, then pause."""
        self.play(edge.animate.set_color(RED), run_time=0.5)
        if previous_table is not None:
            self.play(FadeOut(previous_table))
        self.play(Create(table))
        self.wait(1)

    def construct(self):
        # Training set, one row per sample: [性别, 经济水平, label].
        data = [
            ["男", "高", "0"], ["女", "中", "0"], ["男", "低", "1"], ["女", "高", "0"], ["男", "高", "0"],
            ["男", "中", "0"], ["男", "中", "1"], ["女", "中", "0"], ["女", "低", "1"], ["女", "中", "0"],
            ["女", "高", "0"], ["男", "低", "1"], ["女", "低", "1"], ["男", "高", "0"], ["男", "高", "0"]
        ]

        # ctex is meant to be compiled with XeLaTeX (this mirrors manim's own
        # TexTemplateLibrary.ctex); the default latex/.dvi pipeline commonly
        # fails to find CJK fonts.
        chinese_template = TexTemplate(tex_compiler="xelatex", output_format=".xdv")
        chinese_template.add_to_preamble(r"""
        \usepackage[UTF8]{ctex}
        \usepackage{amsmath}
        """)

        # Derive the gains from the data so the headline agrees with the tables.
        ig_gender = self._information_gain(data, self.GENDER)
        ig_econ = self._information_gain(data, self.ECON)

        final_gain_gender = MathTex(
            rf"IG(\text{{性别}}) = {ig_gender:.3f}",
            tex_template=chinese_template,
            color=PINK,
        ).scale(0.8)
        final_gain_econ = MathTex(
            rf"IG(\text{{经济水平}}) = {ig_econ:.3f}",
            tex_template=chinese_template,
            color=BLUE_C,
        ).scale(0.8)

        final_group = VGroup(final_gain_gender, final_gain_econ).arrange(RIGHT, buff=1.5).to_edge(UP, buff=1)
        self.play(Write(final_group))

        # Emphasise the winning attribute (it is already BLUE_C; only size/stroke change).
        self.play(
            final_gain_econ.animate.scale(1.3).set_stroke(width=2),
            final_gain_gender.animate.scale(0.9),
            run_time=1,
        )

        # Root node "经济水平" with one edge per value, left to right: 高 / 中 / 低.
        oval, oval_text = self._make_node("经济水平", UP * 1.5)
        self.play(Create(oval), FadeIn(oval_text))
        self.wait(0.5)

        lines = self._make_edges(oval, [-2.5, 0, 2.5])
        self.play(Create(lines))
        self.wait()

        # Slide the tree left to free the right half of the frame for the tables.
        tree_group = VGroup(oval, oval_text, lines)
        self.play(
            tree_group.animate.shift(LEFT * 3.0),
            final_gain_econ.animate.shift(LEFT * 5.4),
            FadeOut(final_gain_gender),
            run_time=1,
        )
        self.wait(1)

        high_rows = [row for row in data if row[self.ECON] == "高"]
        mid_rows = [row for row in data if row[self.ECON] == "中"]
        low_rows = [row for row in data if row[self.ECON] == "低"]

        # Branch 高 ends in a leaf.
        table_high = self._make_table(high_rows)
        self._reveal_branch(lines[0], table_high)
        high_leaf = self._make_leaf(self._leaf_caption(high_rows), lines[0].get_end())
        self.play(Create(high_leaf))
        self.wait(1)

        # Branch 中 is split further on 性别.
        table_mid = self._make_table(mid_rows)
        self._reveal_branch(lines[1], table_mid, table_high)
        oval2, oval2_text = self._make_node("性别", lines[1].get_end() + DOWN * 0.5)
        self.play(Create(oval2), FadeIn(oval2_text))
        self.wait(1)

        # Branch 低 ends in a leaf.
        table_low = self._make_table(low_rows)
        self._reveal_branch(lines[2], table_low, table_mid)
        low_leaf = self._make_leaf(self._leaf_caption(low_rows), lines[2].get_end())
        self.play(Create(low_leaf))
        self.wait(2)

        self.play(FadeOut(table_low))

        # Re-show the 中 subset higher up as the parent table of the 性别 split.
        table_mid_clone = self._make_table(mid_rows).shift(UP * 1.7)
        self.play(Create(table_mid_clone))

        title_text = Text("性别总表", font="Microsoft YaHei", font_size=24, color=YELLOW).next_to(
            table_mid_clone, UP, buff=0.3
        )
        self.play(FadeIn(title_text))

        gender_lines = self._make_edges(oval2, [-1.5, 1.5])
        self.play(Create(gender_lines))
        self.wait(1)

        # Edges from the parent table down to where the two child tables will sit.
        total_table_bottom = table_mid_clone.get_bottom() + DOWN * 0.2
        final_branch_lines = VGroup(
            Line(total_table_bottom, RIGHT * 2.2 + DOWN * 1.4, color=WHITE),
            Line(total_table_bottom, RIGHT * 4.8 + DOWN * 1.4, color=WHITE),
        )
        self.play(Create(final_branch_lines))
        self.wait(1)

        # 中 and 男: the caption lists every label present (wider box for "0/1").
        male_rows = [row for row in mid_rows if row[self.GENDER] == "男"]
        self.play(gender_lines[0].animate.set_color(RED), run_time=0.5)
        table_male_mid = self._make_table(male_rows).scale(0.80).move_to(RIGHT * 2.1 + DOWN * 2.2)
        self.play(Create(table_male_mid))
        male_leaf = self._make_leaf(
            self._leaf_caption(male_rows), gender_lines[0].get_end(), width=1.8, gap=0.6
        )
        self.play(Create(male_leaf))
        self.wait(1)

        # 中 and 女.
        female_rows = [row for row in mid_rows if row[self.GENDER] == "女"]
        self.play(gender_lines[1].animate.set_color(RED), run_time=0.5)
        table_female_mid = self._make_table(female_rows).scale(0.80).move_to(RIGHT * 4.9 + DOWN * 2.45)
        self.play(Create(table_female_mid))
        female_leaf = self._make_leaf(
            self._leaf_caption(female_rows), gender_lines[1].get_end(), gap=0.6
        )
        self.play(Create(female_leaf))
        self.wait(2)