In [None]:
# Accent colour assigned to state 3 in the four-state example chains' state_color_map.
CHARM_COLOR = "#ff6666"
class Periodicity(Scene):
    """Scene contrasting a four-state chain's distribution plots with a periodic two-state chain.

    Part one runs a four-state chain twice (default starting distribution, then
    ``[1, 0, 0, 0]``) and plots ``pi_n`` for every state next to the simulated
    users. Part two builds a two-state chain with edges 0 -> 1 and 1 -> 0,
    plots its distribution for two starting distributions and finishes with a
    bulleted intuitive definition of periodicity.

    The helper methods read ``self.state_color_map`` (state -> colour), which
    must be set before they are called; subclasses in this file reuse them.
    """

    def construct(self):
        self.state_color_map = {
            0: PURPLE,
            1: GREEN,
            2: ORANGE,
            3: CHARM_COLOR,
        }

        # Part 1a: four-state chain, no explicit starting distribution.
        markov_chain = MarkovChain(
            4,
            [
                (0, 1),
                (1, 0),
                (0, 2),
                (1, 2),
                (1, 3),
                (2, 0),
                (2, 3),
                (3, 1),
            ],
            transition_matrix=np.array(
                [
                    [0, 0.7, 0.3, 0],
                    [0.2, 0, 0.6, 0.2],
                    [0.6, 0, 0, 0.4],
                    [0, 1, 0, 0],
                ]
            ),
        )

        to_fade, markov_chain_g = self.show_convergence(markov_chain, 20, 5 / 15)
        self.play(FadeOut(to_fade))
        self.wait()

        self.clear()

        # Part 1b: same chain, all probability mass starting in state 0.
        new_markov_chain = MarkovChain(
            4,
            [
                (0, 1),
                (1, 0),
                (0, 2),
                (1, 2),
                (1, 3),
                (2, 0),
                (2, 3),
                (3, 1),
            ],
            transition_matrix=np.array(
                [
                    [0, 0.7, 0.3, 0],
                    [0.2, 0, 0.6, 0.2],
                    [0.6, 0, 0, 0.4],
                    [0, 1, 0, 0],
                ]
            ),
            dist=np.array([1, 0, 0, 0]),
        )

        to_fade, markov_chain_g = self.show_convergence(new_markov_chain, 25, 1 / 15)
        self.play(FadeOut(to_fade), FadeOut(markov_chain_g))
        self.wait()

        # Part 2: two-state chain with edges 0 -> 1 and 1 -> 0.
        self.state_color_map = {
            0: PURPLE,
            1: GREEN,
        }

        periodic_markov_chain = MarkovChain(
            2,
            [
                (0, 1),
                (1, 0),
            ],
        )

        markov_chain_g = MarkovChainGraph(periodic_markov_chain, enable_curved_double_arrows=True, layout="circular")
        markov_chain_g.scale(1.5)
        markov_chain_t_labels = markov_chain_g.get_transition_labels()
        markov_chain_g.clear_updaters()

        markov_chain_group = VGroup(markov_chain_g, markov_chain_t_labels)

        self.play(
            *[Write(markov_chain_g.vertices[state]) for state in periodic_markov_chain.get_states()]
        )
        self.wait()

        self.play(
            *[Write(markov_chain_g.edges[e]) for e in periodic_markov_chain.get_edges()],
            FadeIn(markov_chain_t_labels),
        )
        self.wait()

        self.play(markov_chain_group.animate.shift(UP * 2.5))
        self.wait()

        # Left plot: the chain's own starting distribution.
        markov_chain_sim, left_axes_group, left_state_to_line_segments = self.show_periodicity(
            periodic_markov_chain, markov_chain_group, DOWN * 1.5
        )
        self.play(
            *[FadeOut(u) for u in markov_chain_sim.get_users()],
            left_axes_group.animate.shift(LEFT * 3.1),
            *[line_segs.animate.shift(LEFT * 3.1) for line_segs in left_state_to_line_segments.values()],
        )
        self.wait()

        # Right plot: an asymmetric starting distribution.
        periodic_markov_chain.set_starting_dist(np.array([0.8, 0.2]))

        markov_chain_sim, right_axes_group, right_state_to_line_segments = self.show_periodicity(
            periodic_markov_chain, markov_chain_group, RIGHT * 3.5 + DOWN * 1.5
        )

        periodic_markov_chain_title = Text("Periodic Markov Chains", weight=BOLD).scale(0.7)
        periodic_markov_chain_title.to_edge(LEFT * 1.5).shift(UP * 1)
        self.play(
            *[FadeOut(u) for u in markov_chain_sim.get_users()],
            FadeOut(left_axes_group),
            *[FadeOut(seg) for seg in left_state_to_line_segments.values()],
            Write(periodic_markov_chain_title),
        )
        self.wait()

        intuitive_def = BulletedList(
            "Must be an irreducible Markov chain",
            r"User visits states in regular interval (period) $ > 1 $ (*)",
            r"No guarantee of convergence to stationary distribution",
            r"No such period $> 1$ exists $\rightarrow$ aperiodic Markov chain",
            buff=0.4,
        ).scale(0.55)

        intuitive_def.next_to(periodic_markov_chain_title, DOWN, buff=0.5, aligned_edge=LEFT)
        note = Tex("(*) more rigorous and precise definitions exist").scale(0.5).next_to(
            intuitive_def, DOWN, aligned_edge=LEFT, buff=0.5
        )

        for i, point in enumerate(intuitive_def):
            self.play(FadeIn(point))
            if i == 1:
                # The footnote belongs to the second bullet, the one marked (*).
                self.add(note)
            self.wait()

    def show_periodicity(self, markov_chain, markov_chain_group, axes_position):
        """Simulate 80 users on a two-state chain and plot its distribution for 10 steps.

        Args:
            markov_chain: chain to simulate; the label above the plot shows
                ``pi_0(0)`` and ``pi_0(1)`` of its starting distribution.
            markov_chain_group: ``VGroup(graph, transition_labels)`` already on screen.
            axes_position: centre point for the distribution plot.

        Returns:
            ``(simulator, VGroup(axes, legend, initial_dist_label), state_to_line_segments)``.
            Users and line segments are left on screen for the caller to remove.
        """
        markov_chain_g, _ = markov_chain_group

        markov_chain_sim = MarkovChainSimulator(
            markov_chain, markov_chain_g, num_users=80,
        )

        users = markov_chain_sim.get_users()
        self.play(*[FadeIn(u) for u in users])
        self.wait()

        num_steps = 10
        starting_dist = markov_chain.get_starting_dist()
        axes, state_to_line_segments = self.get_distribution_plot(
            markov_chain, num_steps,
            axes_position=axes_position,
            ax_width=5, ax_height=3.2,
            dist=starting_dist,
        )
        # Anchor the legend near the axes' top-right corner. The inset is given
        # per axis (x by two buffs, y by one) instead of subtracting a scalar
        # from the whole point, which would also have shifted the z coordinate.
        legend = self.get_legend().move_to(
            axes.get_center()
            + RIGHT * (axes.width / 2 - 2 * MED_SMALL_BUFF)
            + UP * (axes.height / 2 - MED_SMALL_BUFF)
        )

        initial_dist = MathTex(
            r"\pi_0(0) = {0} \quad \pi_0(1) = {1}".format(starting_dist[0], starting_dist[1])
        ).scale(0.8).next_to(axes, UP).shift(RIGHT * SMALL_BUFF + UP * SMALL_BUFF * 2)

        self.play(
            Write(axes),
            Write(legend),
            Write(initial_dist),
        )
        self.wait()

        # With equal starting probabilities, thicken every line except state 1's,
        # presumably so overlapping curves stay distinguishable.
        if starting_dist[0] == starting_dist[1]:
            for state, line_segments in state_to_line_segments.items():
                if state == 1:
                    continue
                for seg in line_segments:
                    seg.set_stroke(width=10)

        self.demo_convergence_graph(state_to_line_segments, markov_chain_sim, num_steps, step_threshold=3)

        return markov_chain_sim, VGroup(axes, legend, initial_dist), state_to_line_segments

    def demo_convergence_graph(self, state_to_line_segments, markov_chain_sim, num_steps, short_wait_time=1/15, step_threshold=5):
        """Advance the simulation ``num_steps`` times, drawing one plot segment per state each step.

        Steps before ``step_threshold`` use ``smooth`` easing followed by a
        default ``self.wait()``. Later steps use ``linear`` easing and wait only
        ``short_wait_time``. A final ``self.wait()`` follows the loop.
        """
        for step in range(num_steps):
            is_intro_step = step < step_threshold
            rate_func = smooth if is_intro_step else linear
            transition_animations = markov_chain_sim.get_instant_transition_animations()
            dist_graph_animations = self.get_dist_graph_step_animations(state_to_line_segments, step)
            self.play(*transition_animations, *dist_graph_animations, rate_func=rate_func)
            if is_intro_step:
                self.wait()
            else:
                self.wait(short_wait_time)

        self.wait()

    def show_convergence(self, markov_chain, num_steps, short_wait_time):
        """Show ``markov_chain`` with 100 users and animate its ``pi_n`` plot for ``num_steps`` steps.

        The first 5 steps play slowly; later steps wait ``short_wait_time``.

        Returns:
            ``(VGroup(axes, legend, line_segments, users), VGroup(graph, transition_labels))``.
        """
        markov_chain_g = MarkovChainGraph(
            markov_chain,
            enable_curved_double_arrows=True,
            layout="circular",
            state_color_map=self.state_color_map,
        )

        markov_chain_g.scale(1.5)
        markov_chain_t_labels = markov_chain_g.get_transition_labels()
        self.play(
            FadeIn(markov_chain_g),
            FadeIn(markov_chain_t_labels),
        )
        self.wait()

        markov_chain_g.clear_updaters()
        markov_chain_group = VGroup(markov_chain_g, markov_chain_t_labels)

        # Shrink from the 1.5x display scale to 1.1x and move left for the plot.
        self.play(markov_chain_group.animate.scale(1.1 / 1.5).shift(LEFT * 3.5))
        self.wait()

        markov_chain_sim = MarkovChainSimulator(
            markov_chain, markov_chain_g, num_users=100,
        )

        users = markov_chain_sim.get_users()
        self.play(*[FadeIn(u) for u in users])
        self.wait()

        axes, state_to_line_segments = self.get_distribution_plot(
            markov_chain, num_steps, dist=markov_chain.get_starting_dist()
        )
        legend = self.get_legend().to_edge(RIGHT * 3).shift(UP * 2.5)
        self.play(
            Write(axes),
            Write(legend),
        )
        self.wait()

        self.demo_convergence_graph(
            state_to_line_segments, markov_chain_sim, num_steps,
            short_wait_time=short_wait_time, step_threshold=5,
        )

        return (
            VGroup(axes, legend, VGroup(*state_to_line_segments.values()), VGroup(*users)),
            markov_chain_group,
        )

    def get_dist_graph_step_animations(self, state_to_line_segments, step):
        """Return one ``Create`` animation per state for segment ``step`` (x = step to step + 1)."""
        return [Create(line_segments[step]) for line_segments in state_to_line_segments.values()]

    def get_distribution_plot(self, markov_chain, num_steps, axes_position=RIGHT*3, ax_width=5, ax_height=5, dist=None):
        """Build axes and per-state line segments tracing ``pi_n`` for ``n = 0..num_steps``.

        The sequence is computed on a fresh ``MarkovChain`` with the same
        states, edges and transition matrix, starting from ``dist``. Only that
        copy has ``update_dist()`` called on it; nothing is added to the scene.

        Returns:
            ``(axes, state_to_line_segments)``. Each value is a ``VGroup`` of
            ``num_steps`` ``Line``s coloured with ``self.state_color_map[state]``.
        """
        markov_chain_copy = MarkovChain(
            len(markov_chain.get_states()),
            markov_chain.get_edges(),
            transition_matrix=markov_chain.get_transition_matrix(),
            dist=dist,
        )
        # Snapshot each distribution so every step keeps its own values.
        distribution_sequence = [np.array(markov_chain_copy.get_current_dist())]
        for _ in range(num_steps):
            markov_chain_copy.update_dist()
            distribution_sequence.append(np.array(markov_chain_copy.get_current_dist()))

        distribution_sequence = np.array(distribution_sequence)

        axes = Axes(
            x_range=(0, num_steps, 1),
            y_range=(0, 1, 0.1),
            x_length=ax_width,
            y_length=ax_height,
            tips=False,
            axis_config={"include_numbers": True, "include_ticks": False},
            # Built-in tick numbers are excluded; the end labels are added manually below.
            x_axis_config={"numbers_to_exclude": range(num_steps + 1)},
            y_axis_config={"numbers_to_exclude": np.arange(0.1, 1.05, 0.1)},
        ).move_to(axes_position)

        custom_y_label = Text("1.0").scale(0.4)
        custom_y_label.move_to(axes.y_axis.n2p(1) + LEFT * 0.4)

        y_axis_label = MathTex(r"\pi_n( \cdot )").scale(0.7).next_to(axes.y_axis, LEFT)

        custom_x_label = Text(f"{num_steps}").scale(0.4)
        custom_x_label.move_to(axes.x_axis.n2p(num_steps) + DOWN * MED_SMALL_BUFF)

        x_axis_label = Text("Step/Iteration").scale(0.5).next_to(axes.x_axis, DOWN)

        axes.add(custom_x_label, custom_y_label, x_axis_label, y_axis_label)

        x_values = list(range(num_steps + 1))
        state_to_line_segments = {}
        for state in markov_chain.get_states():
            state_to_line_segments[state] = self.get_line_segments(
                axes, x_values, distribution_sequence[:, state], self.state_color_map[state]
            )

        return axes, state_to_line_segments

    def get_line_segments(self, axes, x_values, y_values, line_color):
        """Return a ``VGroup`` of ``Line``s joining consecutive ``(x, y)`` points on ``axes``."""
        points = [axes.coords_to_point(x, y) for x, y in zip(x_values, y_values)]
        return VGroup(*[
            Line(start, end).set_stroke(color=line_color)
            for start, end in zip(points, points[1:])
        ])

    def get_legend(self):
        """Return a two-row grid of (coloured line, state label) entries from ``self.state_color_map``."""
        legend = VGroup()
        for state, color in self.state_color_map.items():
            label = Text(str(state)).scale(0.4)
            line = Line(LEFT * SMALL_BUFF, RIGHT * SMALL_BUFF).set_stroke(color)
            legend.add(VGroup(line, label).arrange(RIGHT))
        return legend.arrange_in_grid(rows=2)

In [None]:
class PeriodicityUsersMovement(Scene):
    """Animate a single user hopping 10 times on the two-state 0 <-> 1 chain.

    The graph itself is positioned (so the user sits on its vertices) but is
    never added to the scene; only the user is shown.
    """

    def construct(self):
        chain = MarkovChain(2, [(0, 1), (1, 0)])

        graph = MarkovChainGraph(chain, enable_curved_double_arrows=True, layout="circular")
        graph.scale(1.5)
        labels = graph.get_transition_labels()

        # Grouping only to move graph and labels together toward the top.
        VGroup(graph, labels).shift(UP * 2.5)

        simulator = MarkovChainSimulator(chain, graph, num_users=1)
        user = simulator.get_users()[0]
        user.scale(1.2)

        self.play(FadeIn(user))
        self.wait()

        for _ in range(10):
            self.play(*simulator.get_instant_transition_animations())

In [None]:
class IntroduceBigTheorem1(Periodicity):
    """Run the four-state example chain for 40 steps from its default starting distribution.

    ``make_convergence_scene`` is shared with IntroduceBigTheorem2-4, which
    differ only in the chain's starting distribution.
    """

    def construct(self):
        self.state_color_map = {
            0: PURPLE,
            1: GREEN,
            2: ORANGE,
            3: CHARM_COLOR,
        }

        markov_chain_1 = MarkovChain(
            4,
            [
                (0, 1),
                (1, 0),
                (0, 2),
                (1, 2),
                (1, 3),
                (2, 0),
                (2, 3),
                (3, 1),
            ],
            transition_matrix=np.array(
                [
                    [0, 0.7, 0.3, 0],
                    [0.2, 0, 0.6, 0.2],
                    [0.6, 0, 0, 0.4],
                    [0, 1, 0, 0],
                ]
            ),
        )

        self.make_convergence_scene(markov_chain_1, 40)

    def make_convergence_scene(self, markov_chain, num_steps, short_wait_time=1/15):
        """Show the chain, 100 simulated users and its ``pi_n`` plot for ``num_steps`` steps.

        The first five steps use ``smooth`` easing and are followed by waits
        that start at 1 and shrink by a factor of 0.8 each time. Later steps use
        ``linear`` easing and play back to back. Afterwards every mobject in the
        scene is animated with ``fade(0.6)``.

        Args:
            markov_chain: chain to display; its starting distribution is shown
                as ``pi_0(s) = ...`` for every state along the bottom.
            num_steps: number of simulation / plot steps.
            short_wait_time: accepted for signature compatibility; currently unused.

        Returns:
            ``(VGroup(axes, legend, line_segments, users), VGroup(graph, transition_labels))``.
        """
        markov_chain_g = MarkovChainGraph(
            markov_chain,
            enable_curved_double_arrows=True,
            layout="circular",
            state_color_map=self.state_color_map,
        )

        markov_chain_t_labels = markov_chain_g.get_transition_labels()

        markov_chain_g.clear_updaters()
        markov_chain_group = VGroup(markov_chain_g, markov_chain_t_labels)

        markov_chain_group.scale(1.1).shift(LEFT * 3.5 + UP * 0.5)

        markov_chain_sim = MarkovChainSimulator(
            markov_chain, markov_chain_g, num_users=100,
        )

        users = markov_chain_sim.get_users()
        self.play(
            FadeIn(markov_chain_group),
            *[FadeIn(u) for u in users],
        )
        self.wait()

        starting_dist = markov_chain.get_starting_dist()
        axes, state_to_line_segments = self.get_distribution_plot(
            markov_chain, num_steps, dist=starting_dist, axes_position=RIGHT * 3 + UP * 0.5
        )
        legend = self.get_legend().to_edge(RIGHT * 3).shift(UP * 2.5)

        # "\pi_0(0) = a \quad \pi_0(1) = b \quad ..." for however many states there are.
        initial_dist = MathTex(
            r" \quad ".join(rf"\pi_0({state}) = {prob}" for state, prob in enumerate(starting_dist))
        ).scale(0.8).move_to(DOWN * 3.2)
        self.play(
            Write(axes),
            Write(legend),
            FadeIn(initial_dist),
        )
        self.wait()

        wait_time = 1
        for step in range(num_steps):
            is_intro_step = step < 5
            rate_func = smooth if is_intro_step else linear

            transition_animations = markov_chain_sim.get_instant_transition_animations()
            dist_graph_animations = self.get_dist_graph_step_animations(state_to_line_segments, step)

            self.play(*transition_animations, *dist_graph_animations, rate_func=rate_func)
            if is_intro_step:
                self.wait(wait_time)
                wait_time *= 0.8

        self.wait()

        self.play(*[mob.animate.fade(0.6) for mob in self.mobjects])
        return (
            VGroup(axes, legend, VGroup(*state_to_line_segments.values()), VGroup(*users)),
            markov_chain_group,
        )


class IntroduceBigTheorem2(IntroduceBigTheorem1):
    """Convergence demo for the four-state chain with all mass starting in state 1."""

    def construct(self):
        self.state_color_map = dict(enumerate((PURPLE, GREEN, ORANGE, CHARM_COLOR)))

        edges = [
            (0, 1), (1, 0), (0, 2), (1, 2),
            (1, 3), (2, 0), (2, 3), (3, 1),
        ]
        transitions = np.array([
            [0, 0.7, 0.3, 0],
            [0.2, 0, 0.6, 0.2],
            [0.6, 0, 0, 0.4],
            [0, 1, 0, 0],
        ])
        chain = MarkovChain(4, edges, transition_matrix=transitions, dist=[0, 1, 0, 0])

        self.make_convergence_scene(chain, 40)

class IntroduceBigTheorem3(IntroduceBigTheorem1):
    """Convergence demo for the four-state chain starting from [0.5, 0.1, 0.2, 0.2]."""

    def construct(self):
        self.state_color_map = dict(enumerate((PURPLE, GREEN, ORANGE, CHARM_COLOR)))

        edges = [
            (0, 1), (1, 0), (0, 2), (1, 2),
            (1, 3), (2, 0), (2, 3), (3, 1),
        ]
        transitions = np.array([
            [0, 0.7, 0.3, 0],
            [0.2, 0, 0.6, 0.2],
            [0.6, 0, 0, 0.4],
            [0, 1, 0, 0],
        ])
        chain = MarkovChain(4, edges, transition_matrix=transitions, dist=[0.5, 0.1, 0.2, 0.2])

        self.make_convergence_scene(chain, 40)

class IntroduceBigTheorem4(IntroduceBigTheorem1):
    """Convergence demo for the four-state chain starting from [0.1, 0.1, 0.7, 0.1]."""

    def construct(self):
        self.state_color_map = dict(enumerate((PURPLE, GREEN, ORANGE, CHARM_COLOR)))

        edges = [
            (0, 1), (1, 0), (0, 2), (1, 2),
            (1, 3), (2, 0), (2, 3), (3, 1),
        ]
        transitions = np.array([
            [0, 0.7, 0.3, 0],
            [0.2, 0, 0.6, 0.2],
            [0.6, 0, 0, 0.4],
            [0, 1, 0, 0],
        ])
        chain = MarkovChain(4, edges, transition_matrix=transitions, dist=[0.1, 0.1, 0.7, 0.1])

        self.make_convergence_scene(chain, 40)

In [None]:
class BruteForceMethod(TransitionMatrixCorrected3):
    """Approximate the stationary distribution by repeatedly applying pi_{n+1} = pi_n P.

    Simulates 50 users on a four-state chain. After each step it shows
    ``D(pi_n, pi_{n-1}) = dist(current, last)`` of the simulator's user
    distribution. It stops once that value is <= 0.001, or after iteration 99.
    """

    def construct(self):
        frame = self.camera.frame
        markov_ch = MarkovChain(
            4,
            edges=[
                (2, 0),
                (2, 3),
                (0, 3),
                (3, 1),
                (2, 1),
                (1, 2),
            ],
            dist=[0.2, 0.5, 0.2, 0.1],
        )

        markov_ch_mob = MarkovChainGraph(
            markov_ch,
            curved_edge_config={"radius": 2, "tip_length": 0.1},
            straight_edge_config={"max_tip_length_to_length_ratio": 0.08},
            layout="circular",
        )

        markov_ch_sim = MarkovChainSimulator(markov_ch, markov_ch_mob, num_users=50)
        users = markov_ch_sim.get_users()

        count_labels = self.get_current_count_mobs(
            markov_chain_g=markov_ch_mob, markov_chain_sim=markov_ch_sim, use_dist=True
        )

        # Raw string: "\p" is not a valid Python escape sequence.
        # Glyph slices used below (presumably): [:4] = pi_{n+1}, [:5] adds "=",
        # [5:7] = pi_n, [-1] = P.
        stationary_dist_tex = (
            MathTex(r"\pi_{n+1} = \pi_{n} P")
            .scale(1.3)
            .next_to(markov_ch_mob, RIGHT, buff=6, aligned_edge=LEFT)
            .shift(UP * 2)
        )
        ############### ANIMATIONS

        self.play(Write(markov_ch_mob))
        self.play(
            LaggedStart(*[FadeIn(u) for u in users]),
            LaggedStart(
                *[FadeIn(l) for l in count_labels.values()],
            ),
            run_time=0.5,
        )

        self.play(frame.animate.shift(RIGHT * 4 + UP * 0.5).scale(1.2))

        title = (
            Text("Brute Force Method", weight=BOLD)
            .scale(1)
            .move_to(frame.get_top())
            .shift(DOWN * 0.9)
        )
        self.play(FadeIn(title))
        self.wait()

        self.play(Write(stationary_dist_tex[0][-1]))
        self.play(Write(stationary_dist_tex[0][5:7]))
        self.play(Write(stationary_dist_tex[0][:5]))

        # Snapshot values as lists so later simulator steps cannot alias them.
        last_dist = list(markov_ch_sim.get_user_dist().values())
        last_dist_mob = (
            self.vector_to_mob(last_dist)
            .scale_to_fit_width(stationary_dist_tex[0][5:7].width)
            .next_to(stationary_dist_tex[0][5:7], DOWN, buff=0.4)
        )
        self.play(FadeIn(last_dist_mob))
        self.wait()

        # First iteration (n = 1) uses the lagged smooth transitions.
        transition_map = markov_ch_sim.get_lagged_smooth_transition_animations()
        count_labels, count_transforms = self.update_count_labels(
            count_labels, markov_ch_mob, markov_ch_sim, use_dist=True
        )

        current_dist = list(markov_ch_sim.get_user_dist().values())
        current_dist_mob = (
            self.vector_to_mob(current_dist)
            .scale_to_fit_width(last_dist_mob.width)
            .next_to(stationary_dist_tex[0][:4], DOWN, buff=0.4)
        )
        self.play(
            *[LaggedStart(*transition_map[i]) for i in markov_ch.get_states()],
            *count_transforms,
            FadeIn(current_dist_mob),
        )

        distance = dist(current_dist, last_dist)
        distance_definition = (
            MathTex(r"D(\pi_{n+1}, \pi_{n}) =  ||\pi_{n+1} - \pi_{n}||_2")
            .scale(0.7)
            .next_to(stationary_dist_tex, DOWN, buff=2.5, aligned_edge=LEFT)
        )
        # NOTE(review): only this first row scales its number by 0.6; the rows
        # built inside the loop do not. Kept to preserve the rendered output.
        distance_mob = self._get_distance_mob(stationary_dist_tex, 1, distance, number_scale=0.6)

        tolerance = 0.001
        tolerance_mob = (
            Text("Threshold = " + str(tolerance))
            .scale(0.4)
            .next_to(distance_mob, DOWN, buff=0.2, aligned_edge=LEFT)
        )

        self.play(FadeIn(distance_definition))
        self.wait()
        self.play(
            FadeOut(distance_definition, shift=UP * 0.3),
            FadeIn(distance_mob, shift=UP * 0.3),
        )
        self.wait()

        self.play(FadeIn(tolerance_mob, shift=UP * 0.3))

        tick = (
            ImageMobject("images.png")
            .scale(0.1)
            .set_color(PURE_GREEN)
            .next_to(tolerance_mob, RIGHT, buff=0.3)
        )

        self.wait()

        # Stays None when the threshold is never reached inside the loop.
        found_iteration = None
        for i in range(2, 100):
            transition_animations = markov_ch_sim.get_instant_transition_animations()

            count_labels, count_transforms = self.update_count_labels(
                count_labels, markov_ch_mob, markov_ch_sim, use_dist=True
            )

            last_dist = current_dist
            current_dist = list(markov_ch_sim.get_user_dist().values())

            distance = dist(current_dist, last_dist)
            new_distance_mob = self._get_distance_mob(stationary_dist_tex, i, distance)

            run_time = 0.8 if i < 6 else 1 / i

            if i < 6:
                # Slide the current vector into the "last" slot while the old one fades.
                current_to_last_shift = current_dist_mob.animate.move_to(last_dist_mob)
                fade_last_dist = FadeOut(last_dist_mob)
                last_dist_mob = current_dist_mob

                current_dist_mob = (
                    self.vector_to_mob(current_dist)
                    .scale_to_fit_width(last_dist_mob.width)
                    .next_to(stationary_dist_tex[0][:4], DOWN, buff=0.4)
                )

                self.play(
                    *transition_animations,
                    *count_transforms,
                    current_to_last_shift,
                    fade_last_dist,
                    FadeIn(current_dist_mob),
                    FadeTransform(distance_mob, new_distance_mob),
                    run_time=run_time,
                )
            else:
                # Later iterations swap the vectors instantly to keep the pace up.
                self.remove(last_dist_mob)
                last_dist_mob = current_dist_mob.move_to(last_dist_mob)

                current_dist_mob = (
                    self.vector_to_mob(current_dist)
                    .scale_to_fit_width(last_dist_mob.width)
                    .next_to(stationary_dist_tex[0][:4], DOWN, buff=0.4)
                )

                self.add(current_dist_mob)

                self.play(
                    *transition_animations,
                    *count_transforms,
                    FadeTransform(distance_mob, new_distance_mob),
                    run_time=run_time,
                )
            distance_mob = new_distance_mob

            if distance <= tolerance:
                found_iteration = (
                    Text(f"iteration: {i}")
                    .scale(0.3)
                    .next_to(tick, RIGHT, buff=0.1)
                )
                self.play(
                    FadeIn(tick, shift=UP * 0.3),
                    FadeIn(found_iteration, shift=UP * 0.3),
                )
                break

        self.wait()

        ### the final distribution is:
        # The tick and iteration label are only on screen if the threshold was met.
        fade_targets = [distance_mob, tolerance_mob]
        if found_iteration is not None:
            fade_targets += [found_iteration, tick]
        fade_targets.append(last_dist_mob)
        self.play(
            *[FadeOut(mob) for mob in fade_targets],
            current_dist_mob.animate.next_to(stationary_dist_tex, DOWN, buff=1.5).scale(2),
        )
        self.wait()
        vertices_down = (
            VGroup(*[dot.copy().scale(0.8) for dot in markov_ch_mob.vertices.values()])
            .arrange(DOWN, buff=0.3)
            .next_to(current_dist_mob.copy().shift(RIGHT * 0.25), LEFT, buff=0.2)
        )
        self.play(FadeIn(vertices_down), current_dist_mob.animate.shift(RIGHT * 0.25))

    def _get_distance_mob(self, anchor, n, distance, number_scale=None):
        """Build the ``D(pi_n, pi_{n-1}) = <distance>`` row, placed below ``anchor``.

        ``number_scale``, when given, rescales the numeric ``Text`` before the
        row is arranged and scaled by 0.7.
        """
        number = Text(f"{distance:.5f}")
        if number_scale is not None:
            number.scale(number_scale)
        return (
            VGroup(
                MathTex(r"D(\pi_{" + str(n) + r"}, \pi_{" + str(n - 1) + "})"),
                MathTex("="),
                number,
            )
            .arrange(RIGHT, buff=0.2)
            .scale(0.7)
            .next_to(anchor, DOWN, buff=2.5, aligned_edge=LEFT)
        )

    def vector_to_mob(self, vector: Iterable):
        """Render ``vector`` as a bracketed column ``Matrix`` of 2-decimal ``Text`` entries."""
        str_repr = np.array([f"{a:.2f}" for a in vector]).reshape(-1, 1)
        return Matrix(
            str_repr,
            left_bracket="[",
            right_bracket="]",
            element_to_mobject=Text,
            h_buff=2.3,
            v_buff=1.3,
        )

In [None]:
class SystemOfEquationsMethod(BruteForceMethod):
    """Solve for the stationary distribution from ``pi = pi P`` plus ``sum(pi) = 1``.

    Shows the balance equations for a four-state chain and flags that they
    alone have infinitely many solutions. It then replaces the last equation
    with the normalisation constraint and displays the solved distribution.
    """

    def construct(self):
        frame = self.camera.frame
        markov_ch = MarkovChain(
            4,
            edges=[
                (2, 0),
                (2, 3),
                (0, 3),
                (3, 1),
                (2, 1),
                (1, 2),
            ],
            dist=[0.2, 0.5, 0.2, 0.1],
        )

        markov_ch_mob = MarkovChainGraph(
            markov_ch,
            curved_edge_config={"radius": 2, "tip_length": 0.1},
            straight_edge_config={"max_tip_length_to_length_ratio": 0.08},
            layout="circular",
        )

        # NOTE(review): never animated in this scene; kept in case constructing
        # the simulator affects the graph — confirm before removing.
        markov_ch_sim = MarkovChainSimulator(markov_ch, markov_ch_mob, num_users=50)

        equations_mob = (
            self.get_balance_equations(markov_chain=markov_ch)
            .scale(1)
            .next_to(markov_ch_mob, RIGHT, buff=2.5)
        )
        # Glyphs of the final balance equation; index 38 depends on the exact TeX rendered.
        last_equation = equations_mob[0][38:]

        pi_dists = []
        for s in markov_ch.get_states():
            state = markov_ch_mob.vertices[s]
            # Put each label on the side of its vertex facing away from the graph's centre.
            label_direction = normalize(state.get_center() - markov_ch_mob.get_center())
            pi_dists.append(
                MathTex(rf"\pi({s})")
                .scale(0.8)
                .next_to(state, label_direction, buff=0.1)
            )

        pi_dists_vg = VGroup(*pi_dists)

        self.play(Write(markov_ch_mob))
        self.play(Write(pi_dists_vg))
        self.play(frame.animate.shift(RIGHT * 3.3 + UP * 0.8).scale(1.2))

        title = (
            Text("System of Equations Method", weight=BOLD)
            .scale(1)
            .move_to(frame.get_top())
            .shift(DOWN * 0.9)
        )
        self.play(Write(title))

        add_to_one = (
            MathTex("1 = " + "+".join([rf"\pi({s})" for s in markov_ch.get_states()]))
            .scale(0.9)
            .next_to(equations_mob, DOWN, aligned_edge=LEFT)
        )

        stationary_def = MathTex(r"\pi = \pi P ").scale(2.5).move_to(equations_mob)

        self.play(FadeIn(stationary_def, shift=UP * 0.3))

        self.wait()

        self.play(
            FadeIn(equations_mob),
            stationary_def.animate.next_to(equations_mob, UP, buff=0.5).scale(0.6),
        )
        self.wait()

        infinite_solutions = (
            Text("Infinite solutions!", weight=BOLD)
            .scale(0.3)
            .move_to(equations_mob, UP + LEFT)
            .rotate(15 * DEGREES)
            .shift(LEFT * 1.6 + UP * 0.3)
        )

        self.play(FadeIn(infinite_solutions, shift=UP * 0.3))

        # Blink the label twice: 3 frames off, 3 frames on, 3 frames hold.
        blink_time = 3 / config.frame_rate
        for _ in range(2):
            self.play(
                infinite_solutions.animate.set_opacity(0),
                run_time=blink_time,
            )
            self.play(
                infinite_solutions.animate.set_opacity(1),
                run_time=blink_time,
            )
            self.wait(blink_time)

        self.wait()
        self.play(
            FadeIn(add_to_one, shift=UP * 0.3),
            FadeOut(infinite_solutions, shift=UP * 0.3),
        )
        self.wait()

        self.play(
            FadeOut(last_equation, shift=UP * 0.3),
            add_to_one.animate.move_to(last_equation, aligned_edge=LEFT),
        )

        stationary_distribution = self.solve_system(markov_ch)

        tex_strings = [rf"\pi({i}) &= {s:.3f}" for i, s in enumerate(stationary_distribution)]
        stationary_dist_mob = MathTex("\\\\".join(tex_strings)).move_to(equations_mob)

        self.play(
            FadeOut(equations_mob[0][:38], shift=UP * 0.3),
            FadeOut(add_to_one, shift=UP * 0.3),
            FadeIn(stationary_dist_mob, shift=UP * 0.3),
        )

        line = (
            Line()
            .set_stroke(width=2)
            .stretch_to_fit_width(stationary_dist_mob.width * 1.3)
            .next_to(stationary_dist_mob, DOWN, buff=-0.1)
        )
        total = (
            Text("Total = 1.000", weight=BOLD)
            .scale_to_fit_width(stationary_dist_mob.width)
            .next_to(line, DOWN, buff=0.3)
        )
        self.wait()
        self.play(stationary_dist_mob.animate.shift(UP * 0.4))

        self.play(Write(line), Write(total))

    def solve_system(self, markov_chain: MarkovChain):
        """Return the stationary distribution of ``markov_chain`` as a NumPy array.

        Solves ``(P^T - I) pi^T = 0`` with the last balance equation replaced by
        the constraint ``sum(pi) = 1``. The chain's transition matrix is left
        unmodified.

        Raises:
            numpy.linalg.LinAlgError: if the resulting system is singular.
        """
        # Work on a float copy. ``P.T`` is a view, so editing it in place would
        # overwrite the chain's own transition matrix.
        system = np.array(markov_chain.get_transition_matrix(), dtype=float).T
        num_states = system.shape[0]

        # pi = pi P  <=>  (P^T - I) pi^T = 0: move every term to the left side.
        system -= np.eye(num_states)

        # The balance equations are linearly dependent, so swap the last one for
        # the probability constraint pi(0) + ... + pi(n-1) = 1.
        system[-1] = 1.0
        right_side = np.zeros(num_states)
        right_side[-1] = 1.0

        return np.linalg.solve(system, right_side)

    def get_balance_equations(self, markov_chain: MarkovChain):
        """Return a ``MathTex`` with one aligned balance equation per state.

        Row ``s`` reads ``pi(s) &= sum_j P(j, s) pi(j)``. Each coefficient is
        written via ``Fraction(...).limit_denominator()``; zero coefficients are
        omitted, and a coefficient of exactly 1 is written without a fraction.
        """
        trans_matrix_T = markov_chain.get_transition_matrix().T
        state_names = markov_chain.get_states()

        tex_strings = []
        for state, column in zip(state_names, trans_matrix_T):
            terms = []
            for other_state, prob in zip(state_names, column):
                coeff = Fraction(prob).limit_denominator()
                state_term = rf"\pi({other_state})"
                if coeff == 1:
                    terms.append(state_term)
                elif coeff != 0:
                    fraction = r"\frac{" + str(coeff.numerator) + "}{" + str(coeff.denominator) + "}"
                    terms.append(fraction + state_term)

            tex_strings.append(rf"\pi({state})" + "&=" + "+".join(terms))

        return MathTex("\\\\".join(tex_strings))

In [None]:
class EigenValueMethodFixed2(Scene):
    def construct(self):
        """Play the full scene.

        Shows a three-state chain, derives the transposed transition
        equation next to it, then runs the eigen-concept and example
        segments in turn.
        """
        chain = MarkovChain(
            3,
            [(0, 1), (1, 2), (1, 0), (0, 2), (2, 1)]
        )

        chain_graph = MarkovChainGraph(chain)
        # Remove any updaters on the graph before scaling and positioning it.
        chain_graph.clear_updaters()
        chain_graph.scale(0.8)
        chain_graph.shift(UP * 2)

        self.play(FadeIn(chain_graph))
        self.wait()

        transposed_eq = self.show_transition_equation(chain_graph, chain)
        self.show_eigen_concept(transposed_eq)
        self.show_example()

    def show_transition_equation(self, markov_chain_g, markov_chain):
        """Animate rewriting pi_{n+1} = pi_n P as pi_{n+1} = P^T pi_n.

        Displays the row-vector and column-vector forms of the one-step
        update for a 3-state chain. It then highlights how each column of P
        becomes a row of P^T.

        Args:
            markov_chain_g: graph mobject already on screen. It is shifted
                left and finally faded out.
            markov_chain: the illustrated chain (not read by this method).

        Returns:
            The MathTex for pi_{n+1} = P^T pi_n. It is left on screen after
            being moved to ``UP * 3.5``.
        """
        row_eq_label = MathTex(r"\pi_{n + 1} = \pi_n P").next_to(markov_chain_g, RIGHT, buff=0.5)

        # Row-vector form: pi_{n+1} = pi_n P
        next_row = Matrix(
            [[r"\pi_{n + 1}(0)", r"\pi_{n+1}(1)", r"\pi_{n + 1}(2)"]],
            h_buff=2,
        ).scale(0.7)
        equals_sign = MathTex("=")
        current_row = Matrix(
            [[r"\pi_n(0)", r"\pi_n(1)", r"\pi_n(2)"]],
            h_buff=1.7,
        ).scale(0.7)
        p_mob = Matrix(
            [[f"P({r}, {c})" for c in range(3)] for r in range(3)],
            h_buff=2
        ).scale(0.7)
        row_form = VGroup(next_row, equals_sign, current_row, p_mob).arrange(RIGHT)

        # Column-vector form: pi_{n+1} = P^T pi_n
        col_scale = 0.7
        next_col = Matrix(
            [[r"\pi_{n + 1}(0)"], [r"\pi_{n+1}(1)"], [r"\pi_{n + 1}(2)"]],
            v_buff=1,
        ).scale(col_scale)
        current_col = Matrix(
            [[r"\pi_{n}(0)"], [r"\pi_{n}(1)"], [r"\pi_{n}(2)"]],
            v_buff=1,
        ).scale(col_scale)
        # Row r of P^T lists P(0, r), P(1, r), P(2, r).
        p_t_mob = Matrix(
            [[f"P({c}, {r})" for c in range(3)] for r in range(3)],
            h_buff=2,
            v_buff=1,
        ).scale(col_scale)
        col_form = VGroup(next_col, equals_sign.copy(), p_t_mob, current_col).arrange(RIGHT)

        # Positioned relative to the graph *before* the graph shifts left.
        both_forms = VGroup(row_form, col_form).arrange(DOWN, buff=0.7).scale(0.8).next_to(markov_chain_g, DOWN)

        self.play(
            FadeIn(row_eq_label),
            markov_chain_g.animate.shift(LEFT * 2)
        )
        self.wait()

        for form in (row_form, col_form):
            self.play(FadeIn(form))
            self.wait()

        # Placed under the label's current position; the label then moves up
        # by the same amount, so the two lines end up stacked.
        col_eq_label = MathTex(r"\pi_{n + 1} = P^T \pi_n").next_to(row_eq_label, DOWN, aligned_edge=LEFT)
        col_eq_label.shift(UP * 0.5)

        self.play(row_eq_label.animate.shift(UP * 0.5))
        self.play(Write(col_eq_label))
        self.wait()

        p_cells = p_mob[0]
        p_t_cells = p_t_mob[0]
        # Row side: the two vectors plus each column k of P (cells k, k+3, k+6).
        row_side = [
            SurroundingRectangle(next_row[0], color=PURPLE, buff=SMALL_BUFF),
            SurroundingRectangle(current_row[0], color=GREEN_E, buff=SMALL_BUFF),
        ] + [
            SurroundingRectangle(
                VGroup(*[p_cells[idx] for idx in range(k, 9, 3)]),
                color=YELLOW,
                buff=SMALL_BUFF,
            )
            for k in range(3)
        ]
        # Column side: the two vectors plus each row k of P^T, which is column k of P.
        col_side = [
            SurroundingRectangle(next_col[0], color=PURPLE, buff=SMALL_BUFF),
            SurroundingRectangle(current_col[0], color=GREEN_E, buff=SMALL_BUFF),
        ] + [
            SurroundingRectangle(p_t_cells[3 * k:3 * k + 3], color=YELLOW, buff=SMALL_BUFF / 1.5)
            for k in range(3)
        ]

        self.play(*[FadeIn(rect) for rect in row_side])
        self.wait()

        self.play(*[TransformFromCopy(src, dst) for src, dst in zip(row_side, col_side)])
        self.wait()

        self.play(
            FadeOut(markov_chain_g),
            FadeOut(both_forms),
            FadeOut(row_eq_label),
            *[FadeOut(rect) for rect in row_side + col_side],
            col_eq_label.animate.move_to(UP * 3.5)
        )
        self.wait()

        return col_eq_label

    def show_eigen_concept(self, transpose_transition_eq):
        dist_between_nodes = 3
        transition_matrix = np.array([[0.3, 0.7], [0.4, 0.6]])
        markov_chain = MarkovChain(
            2,
            [(0, 1), (1, 0)],
            transition_matrix=transition_matrix,
            dist=np.array([0.9, 0.1]),
        )

        markov_chain_g = MarkovChainGraph(
            markov_chain,
            layout={
                0: LEFT * dist_between_nodes / 2,
                1: RIGHT * dist_between_nodes / 2,
            },
        )
        markov_chain_g.scale(1).shift(UP * 2.5)
        markov_chain_t_labels = markov_chain_g.get_transition_labels()

        self_edges = self.get_edges(markov_chain_g)
        labels = [self.get_label(self_edges[(u, v)], transition_matrix[u][v]) for u, v in self_edges]
        self_edges_group = VGroup(*[obj for obj in list(self_edges.values()) + labels])
        markov_chain_group = VGroup(markov_chain_g, markov_chain_t_labels, self_edges_group)
        self.play(
            FadeIn(markov_chain_group)
        )
        self.wait()

        markov_chain_sim = MarkovChainSimulator(markov_chain, markov_chain_g, num_users=50)
        users = markov_chain_sim.get_users()

        purple_plane = NumberPlane(
            x_range=[0, 1, 0.25],
            y_range=[0, 1, 0.25],
            x_length=7,
            y_length=4.5,
            background_line_style={
                "stroke_color": PURPLE,
                "stroke_width": 3,
                "stroke_opacity": 0.5,
            },
            # faded_line_ratio=4,
            axis_config={"stroke_color": PURPLE, "stroke_width": 0, "include_numbers": True, "numbers_to_exclude": [0.25, 0.75]},
        ).move_to(DOWN * 1)

        surround_plane = Polygon(
            purple_plane.coords_to_point(0, 0),
            purple_plane.coords_to_point(0, 1),
            purple_plane.coords_to_point(1, 1),
            purple_plane.coords_to_point(1, 0),
        ).set_stroke(color=PURPLE)

        self.play(
            FadeIn(purple_plane),
            FadeIn(surround_plane)
        )
        self.wait()

        current_dist = markov_chain.get_current_dist()
        current_vector = self.get_vector(
            current_dist,
            purple_plane,
            r"\pi_0",
            max_tip_length_to_length_ratio=0.1)
        num_steps = 5

        self.play(
             *[FadeIn(u) for u in users],
            FadeIn(current_vector),
        )
        self.wait()

        for i in range(1, num_steps + 1):
            transition_animations = markov_chain_sim.get_instant_transition_animations()
            self.play(
                current_vector.animate.become(
                    self.get_vector(
                        markov_chain.get_current_dist(),
                        purple_plane,
                        r"\pi_{0}".format(i),
                        max_tip_length_to_length_ratio=0.1
                    )
                ),
                *transition_animations,
            )
            self.wait()

        stationary_dist_def = MathTex(r"\pi = P^T \pi").move_to(transpose_transition_eq.get_center())

        self.play(
            ReplacementTransform(transpose_transition_eq, stationary_dist_def)
        )
        self.wait()