# In [None]:
from manim import *
import random
import numpy as np
from sklearn.svm import SVC

class ScatterPlot(Scene):
    """Animate a toy two-class scatter plot and a k-nearest-neighbour walkthrough.

    The scene draws first-quadrant axes and two seeded random point clusters
    (red and blue). It then shows:

    * a straight separator (perpendicular bisector of the cluster centroids);
    * a curved separator;
    * a white query point and a dashed circle around it.

    The plot then slides left. Dashed lines join the query point to every
    sample, and two tables list the sample-to-query distances (input order,
    then sorted ascending). The scene ends with the captions "K=3" and "K=4".
    """

    # Number of leading (index, distance) columns shown in a distance table
    # before the "..." column and the final entry.
    _TABLE_HEAD = 3

    def construct(self):
        """Build and play the whole animation (manim scene entry point)."""
        # ---------- Axes (first quadrant only) ----------
        axes = Axes(
            x_range=[0, 5, 1],
            y_range=[0, 5, 1],
            x_length=5,
            y_length=5,
            axis_config={"include_numbers": True},
        )
        axes.move_to(ORIGIN)
        self.play(Create(axes))
        self.wait(1)

        # ---------- Random scatter points ----------
        random.seed(42)  # fixed seed so every render shows the same points
        # Red cluster: data coordinates x in [0.5, 4], y in [0.5, 3].
        red_points = self._sample_points(axes, (0.5, 4), (0.5, 3), 5)
        red_dots = VGroup(*[Dot(p, color=RED, radius=0.1) for p in red_points])
        # Blue cluster: data coordinates x, y in [2, 5.5].
        # NOTE(review): 5.5 exceeds the axis range (max 5), so some blue dots may
        # be drawn outside the axes — confirm this is intended.
        blue_points = self._sample_points(axes, (2, 5.5), (2, 5.5), 5)
        blue_dots = VGroup(*[Dot(p, color=BLUE, radius=0.1) for p in blue_points])

        # ---------- Reveal the clusters ----------
        self.play(FadeIn(red_dots, shift=DOWN))
        self.play(FadeIn(blue_dots, shift=UP))
        self.wait(2)

        # ---------- Cluster centroids (scene coordinates) ----------
        red_center = np.mean(red_points, axis=0)
        blue_center = np.mean(blue_points, axis=0)

        # ---------- Straight separator: perpendicular bisector of the centroids ----------
        mid_point = (red_center + blue_center) / 2
        direction = blue_center - red_center
        # Rotate the centroid-to-centroid vector by 90 degrees in the xy-plane.
        perp_dir = np.array([-direction[1], direction[0], 0])
        # Extend far beyond the frame so the segment reads as an infinite line.
        separator = Line(
            mid_point - 10 * perp_dir, mid_point + 10 * perp_dir, color=YELLOW_C
        )
        self.play(Create(separator))
        self.wait(2)

        # ---------- Query point at data coordinates (2.8, 2.8) ----------
        query_dot = Dot(axes.c2p(2.8, 2.8), color=WHITE, radius=0.12)
        self.play(FadeIn(query_dot, scale=0.5))
        self.wait(2)
        self.play(FadeOut(separator))

        # ---------- Curved separator ----------
        # Parabola y = middle_y + offset + a * (x - 2.5 - x_shift)^2. With a < 0 it
        # opens downward; its vertex sits at x = 2.5 + x_shift. middle_y is halfway
        # between the highest red point and the lowest blue point.
        fit_x = np.linspace(0, 5, 300)
        red_xy = np.array([axes.p2c(p)[:2] for p in red_points])
        blue_xy = np.array([axes.p2c(p)[:2] for p in blue_points])
        middle_y = (red_xy[:, 1].max() + blue_xy[:, 1].min()) / 2
        offset = 0.5
        a = -0.4
        x_shift = 0.5
        y_fit = middle_y + offset + a * (fit_x - 2.5 - x_shift) ** 2
        poly_points = [axes.c2p(x, y) for x, y in zip(fit_x, y_fit)]
        poly_curve = VMobject(color=YELLOW_C).set_points_smoothly(poly_points)
        self.play(Create(poly_curve))
        self.wait(2)
        self.play(FadeOut(poly_curve))
        self.wait(2)

        # ---------- Zoom in (axes + points) ----------
        self.play(
            VGroup(axes, red_dots, blue_dots, query_dot)
            .animate.scale(1.2)
            .move_to(ORIGIN)
        )
        self.wait(2)

        # ---------- Dashed circle centred on the query point ----------
        circle = Circle(radius=1.8, color=WHITE)  # radius is a free choice (e.g. 1)
        circle.move_to(query_dot.get_center())
        dashed_circle = DashedVMobject(circle, num_dashes=50, dashed_ratio=0.6)
        self.play(Create(dashed_circle))
        self.wait(2)
        # Presumably marks the query as "blue" from its neighbourhood — confirm intent.
        self.play(query_dot.animate.set_color(BLUE_C))
        self.wait(2)

        # ---------- Slide the plot left to make room for the tables ----------
        # Each mobject is driven by exactly one animation here, so the outcome does
        # not depend on the order in which play() applies overlapping animations.
        # The circle fades out where it is; the query point moves with the plot.
        self.play(
            VGroup(axes, red_dots, blue_dots).animate.shift(LEFT * 3),
            query_dot.animate.set_color(WHITE).shift(LEFT * 3),
            FadeOut(dashed_circle),
        )

        # ---------- Dashed lines from the query point to every sample ----------
        all_dots = [*red_dots, *blue_dots]
        query_center = query_dot.get_center()
        dashed_lines = VGroup(
            *[
                DashedLine(
                    start=query_center,
                    end=dot.get_center(),
                    color=GRAY,
                    dash_length=0.15,
                )
                for dot in all_dots
            ]
        )
        # 1-based index labels next to each sample; added without animation.
        labels = VGroup(
            *[
                Text(str(i), font_size=24, color=YELLOW).next_to(dot, LEFT * 0.2)
                for i, dot in enumerate(all_dots, start=1)
            ]
        )
        self.add(labels)
        self.play(Create(dashed_lines))
        self.wait(2)

        # ---------- Distance from every sample to the query point ----------
        # Measured in data (axis) coordinates so the numbers agree with the tick
        # labels; raw scene-space distances would be inflated by the 1.2x zoom.
        query_xy = np.asarray(axes.p2c(query_center))
        distances = [
            (i, float(np.linalg.norm(np.asarray(axes.p2c(dot.get_center())) - query_xy)))
            for i, dot in enumerate(all_dots, start=1)
        ]

        # ---------- Table 1: distances in label order ----------
        table = self._make_distance_table(distances)
        table.scale(0.6).to_edge(RIGHT).shift(UP * 2)
        self.play(Create(table))
        self.wait(2)

        # ---------- Table 2: distances sorted ascending, placed below table 1 ----------
        sorted_distances = sorted(distances, key=lambda pair: pair[1])
        table_sorted = self._make_distance_table(sorted_distances)
        table_sorted.scale(0.6).next_to(table, DOWN, buff=0.5)
        self.play(Create(table_sorted))
        self.wait(2)

        # ---------- Show K=3 under the tables ----------
        k_text = Text("K=3", font_size=36, color=YELLOW)
        k_text.next_to(table_sorted, DOWN, buff=0.8)
        self.play(FadeIn(k_text))
        self.wait(2)

        # ---------- Then switch it to K=4 ----------
        k_text_new = Text("K=4", font_size=36, color=YELLOW).move_to(k_text.get_center())
        self.play(Transform(k_text, k_text_new))
        self.wait(2)

    @staticmethod
    def _sample_points(axes, x_range, y_range, count):
        """Return `count` scene points with uniformly random data coordinates.

        Coordinates are drawn with the global `random` module, x first, then y,
        for each point, and mapped to the scene via `axes.c2p`.
        """
        return [
            axes.c2p(random.uniform(*x_range), random.uniform(*y_range))
            for _ in range(count)
        ]

    @classmethod
    def _make_distance_table(cls, pairs):
        """Build a two-row Table: sample indices on top, distances (2 dp) below.

        `pairs` is a list of ``(index, distance)`` tuples. If it has more than
        ``_TABLE_HEAD + 1`` entries, only the first ``_TABLE_HEAD`` entries, a
        "..." column with an empty value cell, and the last entry are shown.
        Otherwise every entry is shown.
        """
        head = cls._TABLE_HEAD
        if len(pairs) > head + 1:
            shown = [*pairs[:head], ("...", None), pairs[-1]]
        else:
            shown = list(pairs)
        headers = [str(idx) for idx, _ in shown]
        values = ["" if dist is None else f"{dist:.2f}" for _, dist in shown]
        return Table(
            [headers, values],
            include_outer_lines=True,  # draw the outer border
            line_config={"stroke_width": 1, "color": WHITE},
            element_to_mobject=lambda e: Text(str(e), font_size=28, color=WHITE),
        )