In [None]:
from manim import *
from math import log2
import locale
from manim.utils.tex import TexTemplate
import numpy as np

# Best effort: switch the process locale to en_US.UTF-8. Systems without that
# locale installed make setlocale raise locale.Error; in that case the default
# locale is kept. Only locale.Error is caught so that KeyboardInterrupt,
# SystemExit and genuine bugs are not silently swallowed.
try:
    locale.setlocale(locale.LC_ALL, 'en_US.UTF-8')
except locale.Error:
    pass

# LaTeX template used as `tex_template` by the Tex/MathTex objects below:
# ctex for Chinese typesetting, plus amsmath and bm (bold math symbols).
chinese_template = TexTemplate()
chinese_template.add_to_preamble(
    "\n".join(
        [
            "",
            r"\usepackage[UTF8]{ctex}",
            r"\usepackage{amsmath}",
            r"\usepackage{bm}",
            "",
        ]
    )
)

# Scene frame size in manim scene units (width first, then height).
# NOTE(review): 22:8 differs from manim's default 16:9 pixel aspect ratio; the
# camera may rescale one dimension -- confirm the rendered framing is intended.
config.frame_width, config.frame_height = 22, 8

class EconomicConditionalEntropy(Scene):
    """Animate the conditional entropy of churn given economic level.

    The scene shows the sample table, tints its rows by economic level
    (高 / 中 / 低), splits the rows into one sub-table per level, derives the
    binary entropy of the churn column for each level, and finally combines
    them into the weighted sum H(Y|X) = sum_x P(x) H(Y|X=x).
    """

    @staticmethod
    def _entropy(group_data):
        """Return ``(H, c0, c1)`` for the churn column (index 2) of ``group_data``.

        ``c0`` / ``c1`` count rows labelled "0" / "1"; ``H`` is the binary
        Shannon entropy in bits with the convention 0 * log2(0) = 0, so a pure
        group -- or an empty one, which would otherwise divide by zero --
        yields 0.
        """
        c0 = sum(1 for row in group_data if row[2] == "0")
        c1 = sum(1 for row in group_data if row[2] == "1")
        count = len(group_data)
        if count == 0 or c0 == 0 or c1 == 0:
            return 0, c0, c1
        p0, p1 = c0 / count, c1 / count
        return -(p0 * log2(p0) + p1 * log2(p1)), c0, c1

    @staticmethod
    def _group_table(headers, rows, color):
        """Build the scaled-down sub-table for one level, gridded in ``color``."""
        return Table(
            [headers] + rows,
            include_outer_lines=True,
            line_config={"stroke_width": 2, "color": color},
            element_to_mobject_config={"font_size": 32, "weight": "BOLD"},
        ).scale(0.4)

    @staticmethod
    def _highlight_churn_column(table, rows, color):
        """Colour the churn cell (column 3) of every data row of ``table``.

        Churned rows ("1") get a ``color`` background (opacity 0.7) with white
        text; the others get a translucent white background with black text.
        ``add_highlighted_cell`` attaches the background to the table itself
        (``get_cell`` alone only returns a detached polygon that is never
        drawn), so the highlight is rendered together with the table.
        """
        for r, row in enumerate(rows, start=2):  # table row 1 holds the headers
            pos = (r, 3)
            if row[2] == "1":
                table.add_highlighted_cell(pos, color=color, fill_opacity=0.7)
                table.get_entries(pos).set_color(WHITE)
            else:
                table.add_highlighted_cell(pos, color=WHITE, fill_opacity=0.5)
                table.get_entries(pos).set_color(BLACK)

    def construct(self):
        data = [
            ["男", "高", "0"], ["女", "中", "0"], ["男", "低", "1"], ["女", "高", "0"], ["男", "高", "0"],
            ["男", "中", "0"], ["男", "中", "1"], ["女", "中", "0"], ["女", "低", "1"], ["女", "中", "0"],
            ["女", "高", "0"], ["男", "低", "1"], ["女", "低", "1"], ["男", "高", "0"], ["男", "高", "0"]
        ]
        headers = ["性别", "经济水平", "是否流失"]
        table_data = [headers] + data

        # Main sample table.
        main_table = Table(
            table_data,
            include_outer_lines=True,
            line_config={"stroke_width": 3, "color": WHITE},
            element_to_mobject_config={"font_size": 36, "weight": "BOLD"},
        ).scale(0.55).to_edge(LEFT, buff=1.0)

        final_formula = MathTex(
            r"H(Y|X) = \sum_{x \in X} P(x) H(Y|X=x)",
            tex_template=chinese_template,
            font_size=60
        ).move_to(np.array([2.9, 4.5, 0]))

        # The general formula is on screen from the first frame (no fade-in).
        self.add(final_formula)

        econ_title = Text("经济水平条件熵计算:", font="SimHei", font_size=50, color=GREEN, weight="BOLD")
        econ_title.next_to(final_formula, DOWN, buff=0.5).shift(RIGHT * 0.2)

        total = len(data)
        table_spacing = 5.0
        # (level, colour, horizontal offset of its sub-table), in display order.
        levels = [
            ("高", "#4682B4", -table_spacing),
            ("中", "#8FBC8F", 0.0),
            ("低", "#CD853F", table_spacing),
        ]
        groups = {level: [row for row in data if row[1] == level] for level, _, _ in levels}

        # "高: n 个 \quad 中: n 个 \quad 低: n 个"
        econ_stats = Tex(
            " \\quad ".join(f"{level}: {len(groups[level])} 个" for level, _, _ in levels),
            tex_template=chinese_template,
            font_size=48
        ).next_to(econ_title, DOWN, buff=0.6)

        # Row tints for the main table: stroke-less cell polygons (get_cell would
        # otherwise use Polygon's default blue stroke) that start fully
        # transparent. They are sent behind the table before being animated so
        # the grid lines and text stay on top.
        level_cells = {level: [] for level, _, _ in levels}
        for r, row in enumerate(data, start=2):  # table row 1 holds the headers
            cells = level_cells.get(row[1])
            if cells is None:
                continue  # unknown level: not highlighted, also absent from the sub-tables
            cells.extend(main_table.get_cell((r, c), stroke_width=0) for c in (1, 2, 3))

        sub_tables, details, values, labels, entropies = [], [], [], [], []
        for level, color, offset in levels:
            rows = groups[level]
            n = len(rows)
            h, c0, c1 = self._entropy(rows)
            entropies.append(h)

            table = self._group_table(headers, rows, color)
            table.next_to(econ_stats, DOWN, buff=0.8).shift(RIGHT * offset)
            self._highlight_churn_column(table, rows, color)
            sub_tables.append(table)

            # Expanded binary-entropy formula. For a pure group one term reads
            # 0/n log2(0/n); the numeric value uses the 0*log2(0) = 0 convention.
            detail = MathTex(
                f"H(\\text{{{level}}}) = -({c0}/{n} \\log_2\\frac{{{c0}}}{{{n}}} + {c1}/{n} \\log_2\\frac{{{c1}}}{{{n}}})",
                tex_template=chinese_template, font_size=28, color=color
            ).next_to(table, DOWN, buff=0.3)
            details.append(detail)

            values.append(
                MathTex(f"= {h:.3f}", tex_template=chinese_template, font_size=28, color=color
                        ).next_to(detail, DOWN, buff=0.2)
            )
            # Compact per-level result that the numeric value morphs into.
            labels.append(
                MathTex(f"H(\\text{{{level}}}) = {h:.3f}", tex_template=chinese_template,
                        font_size=55, color=color).next_to(econ_stats, DOWN, buff=0.5).shift(RIGHT * offset)
            )

        weighted = [(len(groups[level]), h) for (level, _, _), h in zip(levels, entropies)]
        h_cond = sum((count / total) * h for count, h in weighted)
        terms = " + ".join(f"\\frac{{{count}}}{{{total}}}\\times{h:.3f}" for count, h in weighted)
        final_entropy = MathTex(
            f"H(Y|X=\\text{{经济水平}}) = {terms} = {h_cond:.3f}",
            tex_template=chinese_template, font_size=45, color=YELLOW
        ).next_to(VGroup(*labels), DOWN, buff=0.5)

        # ====== Animation sequence ======
        self.play(Create(main_table))
        self.play(FadeIn(econ_title, shift=DOWN))
        self.wait(0.5)

        self.bring_to_back(*(cell for cells in level_cells.values() for cell in cells))
        for level, color, _ in levels:
            cells = level_cells[level]
            if cells:  # Scene.play() raises when given no animations
                self.play(*(cell.animate.set_fill(color, opacity=0.3) for cell in cells), run_time=1.5)
        self.wait(1)

        self.play(Write(econ_stats))
        self.wait(1)

        self.play(*(Create(table) for table in sub_tables), run_time=1.5)
        self.wait(1)

        self.play(*(Write(detail) for detail in details), run_time=1.5)
        self.wait(1)

        self.play(*(Write(value) for value in values), run_time=1.5)
        self.wait(2)

        self.play(
            *(FadeOut(table) for table in sub_tables),
            *(FadeOut(detail) for detail in details),
            *(ReplacementTransform(value, label) for value, label in zip(values, labels)),
            run_time=1.5
        )
        self.wait(1)

        # final_formula has been on screen since the start, so only the result is written.
        self.play(Write(final_entropy))
        self.wait(2)