# In [None]:
from manim import *
import numpy as np

class KMeansAnimation(Scene):
    """Animated walkthrough of k-means clustering on a small 2-D data set.

    The scene plays a title card, scatters three groups of points on a pair of
    axes, recolours them to illustrate "2-means" / "3-means" / "k-means", and
    then steps through a hand-driven 2-means run on nine points (with a
    coordinate table and two distance formulas) before ending on a
    "k-means++" card.
    """

    # 0-based indices (into the nine demo points) of the two points that act
    # as the fixed red / blue seeds; they are shown as points No. 1 and No. 6.
    RED_SEED_INDEX = 0
    BLUE_SEED_INDEX = 5

    @staticmethod
    def _make_dots(axes, coords, radius):
        """Return a VGroup of white dots placed at *coords* (axes coordinates)."""
        return VGroup(
            *[Dot(axes.c2p(x, y), color=WHITE, radius=radius) for x, y in coords]
        )

    @staticmethod
    def _distance_formula(tag, p, q, dist):
        """Return the MathTex for the Euclidean distance between *p* and *q*.

        The last sub-mobject is the numeric result, so callers can colour or
        copy it independently of the rest of the formula.
        """
        return MathTex(
            rf"d_{{{tag}}} = \sqrt{{(", f"{p[0]:.1f}", r"-", f"{q[0]:.1f}",
            r")^2 + (", f"{p[1]:.1f}", r"-", f"{q[1]:.1f}", r")^2} = ",
            f"{dist:.2f}"
        ).scale(0.7)

    @staticmethod
    def _closer_to_red(point, red_ref, blue_ref):
        """Return True when *point* is strictly closer to *red_ref* than to *blue_ref*."""
        dist_to_red = np.linalg.norm([point[0] - red_ref[0], point[1] - red_ref[1]])
        dist_to_blue = np.linalg.norm([point[0] - blue_ref[0], point[1] - blue_ref[1]])
        return bool(dist_to_red < dist_to_blue)

    @staticmethod
    def _centroids(points, is_red):
        """Return ``(red_center, blue_center)``: the mean of each cluster's points."""
        return points[is_red].mean(axis=0), points[~is_red].mean(axis=0)

    @staticmethod
    def _iteration_label(axes, number):
        """Return the yellow "Iteration <number>" caption placed above *axes*."""
        return Text(
            f"Iteration {number}", font="Arial", font_size=48, color=YELLOW
        ).next_to(axes, UP, buff=0.5)

    def _paint(self, dot, table, idx, color, **play_kwargs):
        """Play one animation that gives *dot* and table row *idx* the same colour."""
        self.play(
            dot.animate.set_color(color),
            table.get_rows()[idx].animate.set_color(color),
            **play_kwargs
        )

    def _assign_points(self, dots, table, points, is_red, red_ref, blue_ref, skip):
        """Recolour every point not in *skip* by its nearer reference position.

        *is_red* is updated in place so that cluster membership is tracked
        explicitly instead of being read back from the dots' colours.
        One animation is played per point.
        """
        for idx, dot in enumerate(dots):
            if idx in skip:
                continue
            is_red[idx] = self._closer_to_red(points[idx], red_ref, blue_ref)
            self._paint(dot, table, idx, RED if is_red[idx] else BLUE)

    def construct(self):
        """Build and play the whole animation, one numbered step at a time."""
        # Step 1: show the English title "K-means".
        title_en = Text("K-means", font="Arial", font_size=96, color=BLUE)
        self.play(Write(title_en))
        self.wait(1)

        # Step 2: morph it into the Chinese title.
        title_cn = Text("K-均值", font="Microsoft YaHei", font_size=96, color=BLUE)
        self.play(Transform(title_en, title_cn))
        self.wait(1)

        # Step 3: fade to a black screen.
        self.play(FadeOut(title_en))
        self.wait(0.5)
        self.camera.background_color = BLACK

        # Step 4: draw the coordinate axes (blue, with arrow tips).
        axes = Axes(
            x_range=[0, 10, 1],
            y_range=[0, 6, 1],
            axis_config={"include_numbers": True, "color": BLUE},
            x_length=8,
            y_length=4,
            tips=True
        ).scale(0.8).move_to(ORIGIN).shift(DOWN*1.3)
        self.play(Create(axes))
        self.wait(0.5)

        # Step 5: add the data points (three groups of six).
        np.random.seed(1)
        cluster1 = np.random.normal(loc=(2, 3), scale=0.5, size=(6, 2))
        cluster2 = np.random.normal(loc=(3, 5), scale=0.5, size=(6, 2))
        cluster3 = np.random.normal(loc=(8, 3), scale=0.6, size=(6, 2))
        dots1 = self._make_dots(axes, cluster1, 0.08)
        dots2 = self._make_dots(axes, cluster2, 0.08)
        dots3 = self._make_dots(axes, cluster3, 0.08)
        self.play(LaggedStartMap(FadeIn, VGroup(dots1, dots2, dots3), lag_ratio=0.1))
        self.wait(2)

        # Step 6: groups 1 and 2 turn red, group 3 turns blue, then "2-means" is written.
        text_2means = Text("2-means", font="Arial", font_size=64, color=GREEN_C).move_to(ORIGIN).shift(UP*2)
        self.play(
            *[dot.animate.set_color(RED) for dot in dots1],
            *[dot.animate.set_color(RED) for dot in dots2],
            *[dot.animate.set_color(BLUE) for dot in dots3],
        )
        self.wait(1)
        self.play(Write(text_2means))

        # Step 7: group 1 red, group 2 green, group 3 blue; "2-means" becomes "3-means".
        text_3means = Text("3-means", font="Arial", font_size=64, color=GREEN_C).move_to(ORIGIN).shift(UP*2)
        self.play(
            *[dot.animate.set_color(RED) for dot in dots1],
            *[dot.animate.set_color(GREEN) for dot in dots2],
            *[dot.animate.set_color(BLUE) for dot in dots3],
        )
        self.play(Transform(text_2means, text_3means))
        self.wait(2)

        # Step 8: "3-means" becomes "k-means", flanked by the words for
        # "cluster" and "centroid".
        text_kmeans = Text("k-means", font="Arial", font_size=64, color=GREEN_C).move_to(ORIGIN).shift(UP*2)
        self.play(Transform(text_2means, text_kmeans))
        self.wait(3)
        # Same CJK font as the Chinese title above (the original font name
        # "Software Yahei" was a typo).
        t1 = Text("簇", font="Microsoft YaHei", font_size=84, color=RED).next_to(text_kmeans, LEFT, buff=1.3)
        t2 = Text("质心", font="Microsoft YaHei", font_size=84, color=RED).next_to(text_kmeans, RIGHT, buff=1.3)
        self.play(Write(t1))
        self.wait(1)
        self.play(Write(t2))
        self.wait(1)

        # Step 9: remove the captions and swap in a smaller data set (3 x 3 points).
        # text_2means is the mobject on screen (it was transformed in place).
        self.play(FadeOut(text_2means, t1, t2))
        cluster1_small = np.random.normal(loc=(2, 2.5), scale=0.5, size=(3, 2))
        cluster2_small = np.random.normal(loc=(3, 4.5), scale=0.5, size=(3, 2))
        cluster3_small = np.random.normal(loc=(8, 3), scale=0.6, size=(3, 2))
        all_points = np.vstack([cluster1_small, cluster2_small, cluster3_small])
        new_dots1 = self._make_dots(axes, cluster1_small, 0.1)
        new_dots2 = self._make_dots(axes, cluster2_small, 0.1)
        new_dots3 = self._make_dots(axes, cluster3_small, 0.1)
        # all_dots[i] is the dot for all_points[i] and for table row i.
        all_dots = [*new_dots1, *new_dots2, *new_dots3]
        self.play(FadeIn(new_dots1, new_dots2, new_dots3),
                  FadeOut(dots1), FadeOut(dots2), FadeOut(dots3))
        # Group axes and points so they move and scale together.
        axes_and_points = VGroup(axes, new_dots1, new_dots2, new_dots3)
        self.play(
            axes_and_points.animate.to_corner(DL).scale(0.8).shift(DOWN*0.5),
            run_time=2
        )
        self.wait(1)

        # Step 10: two-column table (point number, "(x, y)"), one row per point.
        table_data = [
            [str(number), f"({x:.1f}, {y:.1f})"]
            for number, (x, y) in enumerate(all_points, start=1)
        ]
        table = Table(
            table_data,
            col_labels=None,
            include_outer_lines=True,
            line_config={"stroke_width": 1, "color": WHITE},
            v_buff=0.3,
            h_buff=1.0,
        ).scale(0.5).to_corner(UL).shift(RIGHT*2.0+UP*0.3)
        self.play(Create(table))
        self.wait(2)

        # Step 11: write "2-means" near the upper-right corner.
        text_2means_final = Text("2-means", font="Arial", font_size=64, color=GREEN_C)
        text_2means_final.to_corner(UR).shift(LEFT*2.3+DOWN*0.5)
        self.play(Write(text_2means_final))

        # Step 12: pin point No. 1 as the red seed and point No. 6 as the blue seed.
        red_seed = self.RED_SEED_INDEX
        blue_seed = self.BLUE_SEED_INDEX
        seeds = (red_seed, blue_seed)
        # Cluster membership per point: True = red, False = blue. Entries are
        # only meaningful once the corresponding point has been assigned.
        is_red = np.zeros(len(all_points), dtype=bool)
        is_red[red_seed] = True
        self.play(
            all_dots[red_seed].animate.set_color(RED),
            all_dots[blue_seed].animate.set_color(BLUE),
            table.get_rows()[red_seed].animate.set_color(RED),
            table.get_rows()[blue_seed].animate.set_color(BLUE)
        )
        self.wait(0.5)

        # Step 13: show the Euclidean distances from point No. 2 to each seed
        # (formula in white, result in the seed's colour).
        ref_red = all_points[red_seed]
        ref_blue = all_points[blue_seed]
        second = all_points[1]
        dist1 = np.linalg.norm(ref_red - second)
        dist2 = np.linalg.norm(second - ref_blue)
        formula1 = self._distance_formula("1,2", ref_red, second, dist1)
        formula2 = self._distance_formula("2,6", second, ref_blue, dist2)
        for formula, result_color in ((formula1, RED), (formula2, BLUE)):
            for part in formula[:-1]:
                part.set_color(WHITE)
            formula[-1].set_color(result_color)
        formula1.next_to(text_2means_final, DOWN, buff=1)
        formula2.next_to(formula1, DOWN, buff=0.3)
        self.play(Write(formula1), Write(formula2))
        self.wait(1)

        # Step 14: copy the two results below the formulas and compare them.
        result1 = formula1[-1].copy().move_to(RIGHT*1.1 + DOWN*0.7).scale(2)
        result2 = formula2[-1].copy().move_to(RIGHT*3.5 + DOWN*0.7).scale(2)
        # The symbol is derived from the values so it always matches the
        # colour chosen in the next step.
        if dist1 < dist2:
            relation = "<"
        elif dist1 > dist2:
            relation = ">"
        else:
            relation = "="
        compare_symbol = MathTex(relation).scale(2).move_to(DOWN*0.7+RIGHT*2.3)
        self.play(FadeIn(result1), FadeIn(result2), FadeIn(compare_symbol))
        self.wait(1)

        # Step 15: colour point No. 2 (and its table row) after the nearer seed.
        is_red[1] = bool(dist1 < dist2)
        self._paint(all_dots[1], table, 1, RED if is_red[1] else BLUE)
        self.wait(2.5)

        # Step 16: assign every remaining point to the nearer seed.
        # The seeds and point No. 2 have already been handled.
        self._assign_points(
            all_dots, table, all_points, is_red, ref_red, ref_blue, skip=(*seeds, 1)
        )

        # Step 17: remove formulas / results / comparison symbol and rearrange the layout.
        self.play(
            FadeOut(formula1, formula2, result1, result2, compare_symbol),
            axes_and_points.animate.to_edge(RIGHT, buff=1).scale(1.4).shift(UP*1+LEFT*1.3),
            table.animate.to_edge(LEFT, buff=1).scale(1.3).shift(DOWN*2.3+RIGHT*0.3),
            text_2means_final.animate.to_edge(UP, buff=0.5).move_to(UP*3)
        )
        # Caption above the axes.
        iteration_text = self._iteration_label(axes, 1)
        self.play(Write(iteration_text))
        self.wait(1.5)

        # Step 18: iteration 2 - mark the centre of the red and of the blue points.
        self.play(Transform(iteration_text, self._iteration_label(axes, 2)))
        # Centres are computed from the known data coordinates and the tracked
        # membership rather than read back from the on-screen dots.
        red_center, blue_center = self._centroids(all_points, is_red)
        red_cross = Cross(stroke_color=RED, stroke_width=3).scale(0.1).move_to(axes.c2p(*red_center))
        blue_cross = Cross(stroke_color=BLUE, stroke_width=3).scale(0.1).move_to(axes.c2p(*blue_center))
        self.play(FadeIn(red_cross), FadeIn(blue_cross))
        self.wait(2)

        # Step 19: re-assign every point (except the two pinned seeds) to the nearer centre.
        self._assign_points(
            all_dots, table, all_points, is_red, red_center, blue_center, skip=seeds
        )

        # Step 20: iteration 3 - move the crosses to the updated centres.
        self.play(Transform(iteration_text, self._iteration_label(axes, 3)))
        red_center, blue_center = self._centroids(all_points, is_red)
        self.play(
            red_cross.animate.move_to(axes.c2p(*red_center)),
            blue_cross.animate.move_to(axes.c2p(*blue_center))
        )
        self.wait(1)

        # Step 21: re-assign once more. The pinned seeds are points No. 1 and
        # No. 6 (indices 0 and 5); the original skipped index 4 here by mistake.
        self._assign_points(
            all_dots, table, all_points, is_red, red_center, blue_center, skip=seeds
        )

        # Step 22: iteration 4 - move the crosses to the updated centres.
        self.play(Transform(iteration_text, self._iteration_label(axes, 4)))
        red_center, blue_center = self._centroids(all_points, is_red)
        self.play(
            red_cross.animate.move_to(axes.c2p(*red_center)),
            blue_cross.animate.move_to(axes.c2p(*blue_center))
        )
        self.wait(3)

        # Step 23: clean up - remove the crosses, reset the caption and all colours.
        self.play(
            FadeOut(red_cross),
            FadeOut(blue_cross),
            Transform(iteration_text, self._iteration_label(axes, 1))
        )
        for idx, dot in enumerate(all_dots):
            self._paint(dot, table, idx, WHITE, run_time=0.1)
        self.wait(2)
        # Restore the two seeds to red / blue.
        self.play(
            all_dots[red_seed].animate.set_color(RED),
            all_dots[blue_seed].animate.set_color(BLUE),
            table.get_rows()[red_seed].animate.set_color(RED),
            table.get_rows()[blue_seed].animate.set_color(BLUE)
        )
        self.wait(2)

        # Step 24: cover the frame with a translucent grey panel and show "k-means++".
        gray_screen = Rectangle(
            width=self.camera.frame_width,
            height=self.camera.frame_height,
            fill_color=DARK_GRAY,
            fill_opacity=0.9,
            stroke_width=0,
            z_index=0
        )
        self.play(FadeIn(gray_screen))
        kmeans_pp_text = Text("k-means++", font="Arial", font_size=100, color=GREEN_D, z_index=1)
        self.play(Write(kmeans_pp_text))
        self.wait(3)