# In [None]:
from manim import *
from manim.utils.color import interpolate_color
import random

class LogisticRegressionSteps(Scene):
    """Manim scene that walks through four captioned steps of logistic regression.

    The animation has four parts, played in order by ``construct``:

    1. The four step captions slide in one at a time, followed by a
       "training finished" banner.
    2. The captions fly off the top while a scatter plot with two point
       clusters and a straight separating line slides in; the line is then
       annotated and a single extra point is shown.
    3. The scatter plot moves to the left and a sigmoid plot fades in on the
       right.
    4. A black rectangle fades in over everything, the first two captions are
       revealed again, and a large title is shown.
    """

    # Captions for the numbered list (used at the start and in the recap).
    _STEPS = (
        "1. 用线性回归思想初始化一条线性打分函数",
        "2. 用逻辑函数 sigmoid 套打分函数得出的分数转换为概率",
        "3. 将概率套到交叉熵，得到交叉熵损失函数",
        "4. 交叉熵损失函数做梯度下降，迭代优化打分函数参数",
    )
    _FONT = "Microsoft YaHei"
    # Endpoints of the separating line, in scatter-plot axis coordinates.
    _BOUNDARY_START = (1.3, 4.5)
    _BOUNDARY_END = (4, 0)

    def construct(self):
        """Build and play the whole animation."""
        # ---------- Part 1: reveal the four captions ----------
        step_texts = self._build_step_texts()
        step_group = VGroup(*step_texts)
        self.add(step_group)
        self.play(step_group.animate.shift(DOWN * 2), run_time=1)
        self._reveal_steps(step_texts, pause=2)
        self.wait(1)

        # Banner starts below the captions and is moved up together with them.
        done_text = Text("训练结束", font=self._FONT, font_size=60, color=RED_E)
        done_text.move_to(DOWN).shift(DOWN * 4.0)
        self.play(
            step_group.animate.shift(UP * 0.8),
            done_text.animate.shift(UP * 2.5),
            run_time=1.4,
        )
        self.wait(1)

        # Each caption is shifted up by a different amount: 10, 8.5, 7, 5.5.
        fly_out = [
            txt.animate.shift(UP * (10 - 1.5 * k))
            for k, txt in enumerate(step_texts)
        ]

        # ---------- Part 2: scatter plot with separating line ----------
        axes, scatter_group = self._build_scatter_plot()
        self.play(
            *fly_out,
            done_text.animate.shift(UP * 5.2),
            scatter_group.animate.shift(UP * 4.85),
            run_time=1.5,
        )
        self.wait(2)

        # Annotate the line at its final position. The anchor is computed
        # only now, after the axes have been moved by the animation above.
        (x0, y0), (x1, y1) = self._BOUNDARY_START, self._BOUNDARY_END
        mid_point = axes.coords_to_point((x0 + x1) / 2, (y0 + y1) / 2)
        # Arrow points up and to the right, then is nudged away from the line.
        arrow = Arrow(
            start=mid_point + LEFT * 0.5 + DOWN * 0.5,
            end=mid_point + RIGHT * 0.5 + UP * 0.5,
            color=YELLOW_D,
            buff=0,
        ).shift(RIGHT * 1.5 + DOWN * 1.0)
        boundary_label = Text(
            "决策边界", font=self._FONT, font_size=28, color=YELLOW_D
        )
        boundary_label.next_to(arrow, RIGHT, buff=0.2).shift(UP * 1.0 + LEFT * 1.0)
        self.play(FadeIn(arrow, boundary_label))
        self.wait(1)

        # Extra point drawn at axis coordinates (3, 1.2).
        new_dot = Dot(axes.coords_to_point(3, 1.2), color=WHITE, radius=0.1)
        self.play(FadeIn(new_dot))
        left_panel = VGroup(scatter_group, arrow, boundary_label, new_dot)
        self.play(left_panel.animate.shift(LEFT * 3.5))

        # ---------- Part 3: sigmoid plot ----------
        sigmoid_group = self._build_sigmoid_plot()
        self.play(FadeIn(sigmoid_group))
        self.wait(3)

        # ---------- Part 4: black out and recap ----------
        # Frame-sized opaque black rectangle faded in over the current content.
        b_screen = Rectangle(
            width=self.camera.frame_width,
            height=self.camera.frame_height,
            fill_color=BLACK,
            fill_opacity=1,
            stroke_width=0,
            z_index=0,
        )
        self.play(FadeIn(b_screen))

        recap_texts = self._build_step_texts()
        recap_group = VGroup(*recap_texts)
        self.add(recap_group)
        self.play(recap_group.animate.shift(DOWN * 3), run_time=1)
        # Only the first two captions are revealed in the recap.
        self._reveal_steps(recap_texts[:2], pause=1)

        title = Text("逻辑回归", font=self._FONT, font_size=80, color=ORANGE)
        title.move_to(UP)
        self.play(
            recap_group.animate.shift(DOWN * 0.8),
            FadeIn(title),
        )
        self.wait(2)

    def _build_step_texts(self):
        """Return one hidden ``Text`` per caption, coloured WHITE -> RED_D.

        Each caption is placed at the top edge, one unit lower per index,
        shifted 1.5 units to the left and given opacity 0, so that
        ``_reveal_steps`` can slide it right while fading it in.
        """
        color_start = WHITE  # colour of the first caption
        color_end = RED_D    # colour of the last caption
        last_index = len(self._STEPS) - 1

        texts = []
        for i, step in enumerate(self._STEPS):
            txt = Text(step, font=self._FONT, font_size=36)
            txt.to_edge(UP).shift(DOWN * i)
            txt.shift(LEFT * 1.5)
            txt.set_opacity(0)
            # Gradient position in [0, 1]; a single caption keeps the start
            # colour instead of dividing by zero.
            t = i / last_index if last_index > 0 else 0
            txt.set_color(interpolate_color(color_start, color_end, t))
            texts.append(txt)
        return texts

    def _reveal_steps(self, texts, pause):
        """Slide each text 1.5 units right to full opacity, waiting ``pause`` s after each."""
        for txt in texts:
            self.play(txt.animate.shift(RIGHT * 1.5).set_opacity(1), run_time=1)
            self.wait(pause)

    def _build_scatter_plot(self):
        """Build the first-quadrant scatter plot.

        Returns:
            A tuple ``(axes, group)`` where ``group`` is a ``VGroup`` holding
            the axes, both axis labels, the two point clusters and the
            separating line, already shifted down by 5 and scaled to 0.9.
        """
        # ---------- Axes (first quadrant only) ----------
        axes = Axes(
            x_range=[0, 5, 1],
            y_range=[0, 5, 1],
            x_length=5,
            y_length=5,
            tips=True,
        ).move_to(DOWN)
        x_label = axes.get_x_axis_label(Tex("X"))
        y_label = axes.get_y_axis_label(Tex("Y"))

        # ---------- Two clusters of ten points each ----------
        # A private generator seeded with 42 produces the same values as
        # seeding the module-level RNG, without reseeding global state.
        rng = random.Random(42)
        cluster1 = VGroup()
        cluster2 = VGroup()
        for _ in range(10):
            # Draw order (x1, y1, x2, y2) fixes which values each dot gets.
            x1 = 1.5 + rng.uniform(-0.4, 0.5)
            y1 = 1.5 + rng.uniform(-0.5, 0.9)
            cluster1.add(Dot(axes.coords_to_point(x1, y1), color=BLUE))
            x2 = 3.5 + rng.uniform(-0.5, 0.8)
            y2 = 3.5 + rng.uniform(-0.6, 0.5)
            cluster2.add(Dot(axes.coords_to_point(x2, y2), color=RED))

        # ---------- Separating line ----------
        p1 = axes.coords_to_point(*self._BOUNDARY_START)
        p2 = axes.coords_to_point(*self._BOUNDARY_END)
        line = Line(p1, p2, color=YELLOW_D, stroke_width=3)

        group = VGroup(axes, x_label, y_label, cluster1, cluster2, line)
        group.shift(DOWN * 5).scale(0.9)
        return axes, group

    def _build_sigmoid_plot(self):
        """Build the sigmoid plot and return it as a single ``VGroup``.

        The group holds the axes, axis labels, the sigmoid curve over
        [-6, 6], markers and labels for y = 0, 0.5 and 1, a dashed horizontal
        line at y = 0.5 and a dot at (-1, sigmoid(-1)). It is shifted right
        by 3 and scaled to 0.8.
        """
        def sigmoid(x):
            return 1 / (1 + np.exp(-x))

        # ---------- Axes ----------
        axes2 = Axes(
            x_range=[-6, 6, 1],
            y_range=[0, 1.2, 0.2],
            x_length=7,
            y_length=4,
            tips=True,
        ).to_edge(DOWN)
        x_label2 = axes2.get_x_axis_label(Tex("x"))
        y_label2 = axes2.get_y_axis_label(Tex("y"))

        # ---------- Sigmoid curve ----------
        sigmoid_curve = axes2.plot(sigmoid, x_range=[-6, 6], color=BLUE)

        # ---------- Y-axis markers at 0, 0.5 and 1 ----------
        y_ticks = VGroup()
        for y_val in [0, 0.5, 1]:
            tick = axes2.get_horizontal_line(
                axes2.coords_to_point(0, y_val), color=GRAY
            )
            tick_label = MathTex(str(y_val)).next_to(
                axes2.coords_to_point(0, y_val), LEFT, buff=0.2
            )
            y_ticks.add(tick, tick_label)

        # ---------- Dashed horizontal line at y = 0.5 ----------
        half_line = DashedLine(
            start=axes2.coords_to_point(-6, 0.5),
            end=axes2.coords_to_point(6, 0.5),
            color=YELLOW,
            stroke_width=2,
        )

        # ---------- Point at (-1, sigmoid(-1)) ----------
        x_pt = -1
        y_pt = sigmoid(x_pt)
        point = Dot(axes2.coords_to_point(x_pt, y_pt), color=RED, radius=0.1)

        return VGroup(
            axes2, x_label2, y_label2,
            sigmoid_curve, y_ticks, half_line, point,
        ).shift(RIGHT * 3).scale(0.8)