# In [None]:
from manim import *
import random

class SVMintro(Scene):
    """Animated comparison of three ways of looking at a two-class scatter plot.

    The scene plays, in order:

    1. A first-quadrant plot with two dot clusters split by a straight yellow
       line, an arrow captioned "决策边界", one extra white dot, and the
       title "逻辑回归".
    2. A second plot whose blue/red dots lie above/below a sine-shaped yellow
       curve, titled "支持向量机".
    3. Both plots shrink to the left while a third, larger plot fades in.  Its
       dots are coloured at random, with the probability of blue given by a
       softmax over the negative distances to two fixed centres.  Two
       translucent ellipses are then laid over it.
    4. The captions "概率", "论" and "贝叶斯理论" appear one after another,
       and the last one morphs into "朴素贝叶斯" while everything else slides
       off to the left.

    All random positions come from locally seeded generators, so every render
    produces the same picture and the global ``random`` / ``np.random`` state
    is left untouched.
    """

    # Font applied to every Chinese caption rendered with ``Text``.
    CAPTION_FONT = "Microsoft YaHei"

    @staticmethod
    def _xy_labels(axes):
        """Build the "X" and "Y" axis labels for *axes* (x label first)."""
        x_label = axes.get_x_axis_label(Tex("X"))
        y_label = axes.get_y_axis_label(Tex("Y"))
        return x_label, y_label

    def construct(self):
        """Build every mobject and play the full animation sequence."""
        font = self.CAPTION_FONT

        # ---------- Left axes (first quadrant only) ----------
        axes_left = Axes(
            x_range=[0, 5, 1],
            y_range=[0, 5, 1],
            x_length=5,
            y_length=5,
            tips=True,
        ).move_to(DOWN)
        x_label_left, y_label_left = self._xy_labels(axes_left)

        # ---------- Two clusters for the left plot ----------
        cluster1_left = VGroup()
        cluster2_left = VGroup()
        # Local generator: same stream as random.seed(42) without touching
        # the module-level state.  Draw order (x1, y1, x2, y2) fixes the layout.
        left_rng = random.Random(42)
        for _ in range(10):
            x1 = 1.5 + left_rng.uniform(-0.4, 0.5)
            y1 = 1.5 + left_rng.uniform(-0.5, 0.9)
            cluster1_left.add(Dot(axes_left.coords_to_point(x1, y1), color=BLUE))
            x2 = 3.5 + left_rng.uniform(-0.5, 0.8)
            y2 = 3.5 + left_rng.uniform(-0.6, 0.5)
            cluster2_left.add(Dot(axes_left.coords_to_point(x2, y2), color=RED))

        # ---------- Straight separating line ----------
        # Endpoints in axis coordinates; reused below to find the midpoint.
        line_start = (1.3, 4.5)
        line_end = (4, 0)
        line_left = Line(
            axes_left.coords_to_point(*line_start),
            axes_left.coords_to_point(*line_end),
            color=YELLOW_D,
            stroke_width=3,
        )

        # Left title.  It is positioned against the axes *before* the plot is
        # moved, so the LEFT offset compensates for the later shifts.
        title_left = Text("逻辑回归", font=font, font_size=60, color=GREEN_C)
        title_left.next_to(axes_left, UP, buff=0.5).shift(LEFT*3.6)

        # ---------- Assemble the left plot and slide it in from below ----------
        scatter_left = VGroup(
            axes_left, x_label_left, y_label_left,
            cluster1_left, cluster2_left, line_left,
        )
        scatter_left.shift(DOWN*5).scale(0.9)
        self.play(scatter_left.animate.shift(UP*4.85), run_time=1.5)

        # ---------- Decision-boundary arrow and caption ----------
        # The midpoint is evaluated after the plot has moved, so it is in the
        # plot's current screen position.
        mid_x = (line_start[0] + line_end[0]) / 2
        mid_y = (line_start[1] + line_end[1]) / 2
        mid_point = axes_left.coords_to_point(mid_x, mid_y)
        arrow = Arrow(
            start=mid_point + LEFT*0.5 + DOWN*0.5,
            end=mid_point + RIGHT*0.5 + UP*0.5,
            color=YELLOW_D,
            buff=0,
        ).shift(RIGHT*1.5 + DOWN*1.0)
        label = Text("决策边界", font=font, font_size=28, color=YELLOW_D)
        label.next_to(arrow, RIGHT, buff=0.2).shift(UP*1.0 + LEFT*1.0)
        self.play(FadeIn(arrow, label))

        # ---------- Extra (white) point ----------
        new_dot = Dot(axes_left.coords_to_point(3, 1.2), color=WHITE, radius=0.1)
        self.play(FadeIn(new_dot))
        gg1 = VGroup(scatter_left, arrow, label, new_dot)
        self.play(gg1.animate.shift(LEFT*3.9))
        self.play(FadeIn(title_left))
        self.wait(1)

        # ---------- Right plot (more spread-out points) ----------
        axes_right = axes_left.copy().next_to(axes_left, RIGHT, buff=2.5)
        x_label_right, y_label_right = self._xy_labels(axes_right)
        # Right title
        title_right = Text("支持向量机", font=font, font_size=60, color=RED_C)
        title_right.next_to(axes_right, UP, buff=0.5).shift(UP*0.3 + RIGHT*0.2)

        # ---------- Wavy boundary ----------
        num_waves = 3  # sine argument sweeps 0..3*pi over x in [0, 5]

        def wave_height(x_coord):
            """Height of the boundary at *x_coord* (amplitude 1.5 around 2.5)."""
            return 2.5 + 1.5 * np.sin(num_waves * np.pi * x_coord / 5)

        wave_curve = VMobject(color=YELLOW_D, stroke_width=3)
        wave_points = []
        for i in range(0, 51):
            x = i / 10  # x runs from 0 to 5 in steps of 0.1
            wave_points.append(axes_right.coords_to_point(x, wave_height(x)))
        wave_curve.set_points_smoothly(wave_points)

        # ---------- Points on either side of the wave ----------
        cluster1_wave = VGroup()
        cluster2_wave = VGroup()
        # Same stream as np.random.seed(300), kept local to this scene.
        wave_rng = np.random.RandomState(300)
        points_per_class = 40
        for _ in range(points_per_class):
            x_coord = wave_rng.uniform(0.1, 4.9)
            y_wave = wave_height(x_coord)
            # Blue dot 0.3-0.7 above the wave
            y_above = y_wave + wave_rng.uniform(0.3, 0.7)
            cluster1_wave.add(
                Dot(axes_right.coords_to_point(x_coord, y_above), color=BLUE)
            )
            # Red dot 0.3-0.7 below the wave
            y_below = y_wave - wave_rng.uniform(0.3, 0.7)
            cluster2_wave.add(
                Dot(axes_right.coords_to_point(x_coord, y_below), color=RED)
            )

        # Right-hand group
        scatter_right = VGroup(
            axes_right, x_label_right, y_label_right,
            cluster1_wave, cluster2_wave, wave_curve,
        )
        self.play(FadeIn(scatter_right), FadeIn(title_right))
        self.wait(3)

        # Shrink both plots and stack them on the left; drop the titles.
        gg2 = VGroup(scatter_right)
        self.play(
            FadeOut(title_left, title_right),
            gg1.animate.scale(0.7).move_to(UL*2).shift(LEFT*2),
            gg2.animate.scale(0.7).move_to(DL*2).shift(LEFT*2),
        )

        # ---------- Third axes, copied from the (now shrunken) left axes ----------
        axes_far_right = axes_left.copy().next_to(axes_right, RIGHT, buff=2.5)
        x_label_far, y_label_far = self._xy_labels(axes_far_right)

        # ---------- Points coloured by distance to two centres ----------
        cluster1_far = VGroup()
        cluster2_far = VGroup()
        # Same stream as np.random.seed(123), kept local to this scene.
        far_rng = np.random.RandomState(123)
        total_points = 120

        center_blue = np.array([3.5, 3.5])
        center_red = np.array([0.5, 0.5])
        # Softmax temperature: larger values mix the colours more.
        # The original author suggests tuning it between 1.5 and 3.0.
        temperature = 2.5

        for _ in range(total_points):
            x = far_rng.uniform(0.3, 4.8)
            y = far_rng.uniform(0.3, 4.8)
            point = np.array([x, y])

            # Distances to the two centres
            d_blue = np.linalg.norm(point - center_blue)
            d_red = np.linalg.norm(point - center_red)

            # Probability of colouring this point blue
            p_blue = np.exp(-d_blue / temperature)
            p_red = np.exp(-d_red / temperature)
            prob_blue = p_blue / (p_blue + p_red)

            if far_rng.rand() < prob_blue:
                cluster1_far.add(
                    Dot(axes_far_right.coords_to_point(x, y), color=BLUE)
                )
            else:
                cluster2_far.add(
                    Dot(axes_far_right.coords_to_point(x, y), color=RED)
                )

        scatter_far = VGroup(
            axes_far_right, x_label_far, y_label_far, cluster1_far, cluster2_far
        )

        # ---------- Show the third plot, enlarged ----------
        self.play(FadeIn(scatter_far.scale(1.8).shift(UP*1.9 + RIGHT*0.8)))
        self.wait(3)

        # ---------- Translucent ellipses over the two regions ----------
        def density_ellipse(color, center_xy):
            """Filled, tilted ellipse centred at axis coordinates *center_xy*."""
            blob = Ellipse(
                width=2.8, height=1.8,  # base size, enlarged by the scale below
                color=color, fill_color=color, fill_opacity=0.6, stroke_width=0,
            ).scale(1.75)
            blob.move_to(axes_far_right.coords_to_point(*center_xy))
            blob.rotate(-PI/4)
            return blob

        ellipse_blue = density_ellipse(BLUE, (3.1, 3.1))
        ellipse_red = density_ellipse(RED, (1.7, 1.7))

        # ---------- Add them to the scene ----------
        self.play(FadeIn(ellipse_blue), FadeIn(ellipse_red))
        self.wait(5)
        gg3 = VGroup(scatter_far, ellipse_blue, ellipse_red)
        self.play(
            gg1.animate.shift(LEFT*5),
            gg2.animate.shift(LEFT*5),
            gg3.animate.shift(LEFT*5),
        )

        # ---------- Captions ----------
        # Both captions start fully transparent and fade in while sliding left.
        t1 = Text("概率", font=font, font_size=80, color=GOLD_C)
        t1.move_to(UR*1.5).shift(RIGHT*2.8).set_opacity(0)
        t2 = Text("论", font=font, font_size=80, color=GOLD_C)
        t2.next_to(t1, RIGHT).shift(LEFT*0.6).set_opacity(0)
        self.play(
            t1.animate.shift(LEFT*0.8).set_opacity(1),
            run_time=1.5,
        )
        self.wait(5)
        # Keep t1 fully opaque for the next animation (the preceding play
        # already animates it to opacity 1).
        t1.set_opacity(1)
        self.play(
            t1.animate.shift(LEFT*0.8),
            t2.animate.shift(LEFT*1.2).set_opacity(1),
            run_time=1.5,
        )
        self.wait(2)
        t3 = Text("贝叶斯理论", font=font, font_size=60, color=GOLD_C)
        t3.next_to(t1, DOWN, buff=1).set_opacity(0).shift(RIGHT*1.9)
        self.play(
            t3.animate.shift(LEFT*0.9).set_opacity(1),
        )
        self.wait(1)

        # Final caption: everything else slides out to the left while t3
        # morphs into it.
        t4 = Text("朴素贝叶斯", font=font, font_size=75, color=GOLD_C)
        t4.move_to(UP*2)
        self.play(
            gg3.animate.shift(LEFT*9),
            t1.animate.shift(LEFT*15),
            t2.animate.shift(LEFT*15),
            Transform(t3, t4),
            run_time=1.9,
        )
        self.wait(5)