In [None]:
from manim import *
import numpy as np

class ClassificationVsClustering(Scene):
    """Contrast classification (left panel) with clustering (right panel).

    Left: red and green labelled points with a short white line separating
    the two classes (the perpendicular bisector of their centroids).
    Right: unlabelled white points sampled around three offsets.
    The scene ends by covering the frame with a mostly opaque dark-gray
    rectangle and writing "K-means" and "DBSCAN" on top of it.
    """

    def construct(self):
        self.camera.background_color = BLACK  # black background

        # Two identical coordinate systems, mainly showing the first quadrant.
        left_axes = self._make_axes(LEFT)
        right_axes = self._make_axes(RIGHT)

        # Seed once so the layout is reproducible. The draw order below
        # (red, green, then clusters 1..3) fixes every point position.
        np.random.seed(42)

        # ---- Left panel: labelled data (red / green) ----
        red_points = self._sample_square([0.5, 2.0])    # towards the top
        green_points = self._sample_square([2.0, 0.5])  # towards the right

        red_dots = self._make_dots(left_axes, red_points, RED)
        green_dots = self._make_dots(left_axes, green_points, GREEN)

        # Decision boundary, kept short so it only spans the area near the
        # two classes.
        line_start, line_end = self._bisector_segment(
            red_points.mean(axis=0),
            green_points.mean(axis=0),
            half_length=1.5,
        )
        line = Line(
            start=left_axes.c2p(*line_start),
            end=left_axes.c2p(*line_end),
            color=WHITE,
        )

        # ---- Right panel: unlabelled clustered data (white) ----
        # Each offset is the lower-left corner of a 1.5 x 1.5 sample square;
        # all squares must stay inside the axes' 0..4 window, otherwise c2p
        # extrapolates and dots are drawn outside the plotted area.
        cluster_offsets = ([0.5, 2.5], [2.4, 2.4], [1.5, 0.5])
        cluster_dots = VGroup(*[
            self._make_dots(right_axes, self._sample_square(offset), WHITE)
            for offset in cluster_offsets
        ])

        # ===== Animation sequence =====
        self.play(Create(left_axes))
        self.play(FadeIn(red_dots), FadeIn(green_dots))
        self.wait(1)
        self.play(Create(line))
        self.wait(3)

        self.play(Create(right_axes))
        self.play(FadeIn(cluster_dots))
        self.wait(2)

        # Full-frame gray overlay that dims everything drawn so far.
        gray_screen = Rectangle(
            width=self.camera.frame_width,
            height=self.camera.frame_height,
            fill_color=DARK_GRAY,
            fill_opacity=0.9,
            stroke_width=0,
            z_index=0,
        )
        self.play(FadeIn(gray_screen))

        # Show the names of two clustering algorithms.
        clustering_text = VGroup(
            Text("K-means", font="Arial", font_size=80, color=GREEN_D),
            Text("DBSCAN", font="Arial", font_size=80, color=GREEN_D),
        )
        clustering_text.arrange(DOWN, buff=1)
        clustering_text.move_to(ORIGIN)
        self.play(Write(clustering_text))
        self.wait(3)

    @staticmethod
    def _make_axes(edge):
        """Build a 0..4 x 0..4 blue axes pinned to ``edge`` and nudged down."""
        # Fresh range lists per call so the two Axes never share a list.
        return Axes(
            x_range=[0, 4, 1], y_range=[0, 4, 1],
            x_length=5, y_length=5,
            axis_config={"color": BLUE},
        ).to_edge(edge, buff=1).shift(DOWN * 0.5)

    @staticmethod
    def _sample_square(offset, count=10, size=1.5):
        """Draw ``count`` uniform 2-D points from the square [offset, offset+size)."""
        return np.random.rand(count, 2) * size + np.asarray(offset, dtype=float)

    @staticmethod
    def _make_dots(axes, points, color):
        """Return a VGroup with one Dot per (x, y) data point, placed on ``axes``."""
        return VGroup(*[Dot(axes.c2p(x, y), color=color) for x, y in points])

    @staticmethod
    def _bisector_segment(center_a, center_b, half_length):
        """Return (start, end) of the perpendicular bisector of two centres.

        The segment is centred on the midpoint of ``center_a`` and
        ``center_b`` and extends ``half_length`` data units to each side.

        Raises:
            ValueError: if the two centres coincide, in which case the
                boundary direction is undefined.
        """
        direction = center_b - center_a
        distance = np.linalg.norm(direction)
        if distance == 0:
            raise ValueError("class centres coincide; boundary direction is undefined")
        # Rotating by 90 degrees keeps the length, so dividing by the original
        # distance normalises the perpendicular vector.
        unit_perp = np.array([-direction[1], direction[0]]) / distance
        midpoint = (center_a + center_b) / 2
        return midpoint - unit_perp * half_length, midpoint + unit_perp * half_length