In [1]:
import numpy as np
import pandas as pd
import dirty_cat as dc
from sklearn.preprocessing import OneHotEncoder

from manim import *
from manim.utils.color import Colors
# Global manim render setting: use a near-black blue background.
# The scene below also reads this value back to hide mobjects by painting
# them in the background colour.
config.background_color = "#060d14"

In [3]:
%%manim -qh -v WARNING GammaPoissonAnimation

class GammaPoissonAnimation(Scene):
    """Animated walkthrough of the Gamma-Poisson encoding.

    The scene uses two example samples, "police officer" and
    "senior police officer", and proceeds in this order:

    1. title and subtitle,
    2. each sample described as a sum of topics,
    3. the same information morphed into a sample/topic table,
    4. the factorisation ``F ≈ X Λ`` with placeholder, random and final
       example matrices.
    """

    # Colours used for the individual words of the samples
    police_color = Colors.blue.value
    officer_color = Colors.green.value
    senior_color = Colors.gold.value
    # Colour used for zero entries that correspond to no word
    empty_color = Colors.light_gray.value
    # Colours used for the two topics
    police_officer_topic_color = Colors.teal.value
    senior_topic_color = Colors.maroon.value

    def show_title(self):
        """Draw the title, move it to the top edge and return it."""
        title = Text("Gamma-Poisson encoding")
        self.play(Create(title))
        self.play(title.animate.to_edge(UP))
        return title

    def show_subtitle(self):
        """Draw the one-line introduction below the title and return it."""
        intro = Text("Tries to describe each sample as a combination of topics").scale(0.5).shift(UP * 2.5)
        self.play(Create(intro))
        return intro

    def show_samples(self):
        """Draw the two samples and, next to them, their topic descriptions.

        Returns:
            A ``(samples, topic_desc)`` tuple.  ``samples`` holds the two
            sample texts.  ``topic_desc[0]`` is the description of the first
            sample (count, filler text, topic); ``topic_desc[1]`` holds two
            such groups for the second sample, the second one starting with
            a ``'+ '`` text.
        """
        samples = VGroup(
            Text("police officer"),
            Text("senior police officer"),
        ).scale(0.4).arrange(DOWN).shift(DOWN)
        samples[1].align_to(samples[0], LEFT)

        topic_desc = VGroup(
            VGroup(
                Text('1'),
                Text(' time the topic '),
                Text('is a police officer', color=self.police_officer_topic_color),
            ).arrange(RIGHT),
            VGroup(
                VGroup(
                    Text('1'),
                    Text(' time the topic '),
                    Text('is a police officer', color=self.police_officer_topic_color),
                ).arrange(RIGHT),
                VGroup(
                    Text('+ '),
                    Text('1'),
                    Text(' time the topic '),
                    Text('is senior', color=self.senior_topic_color),
                ).arrange(RIGHT),
            ).arrange(RIGHT),
        ).arrange(DOWN).scale(0.4).next_to(samples, RIGHT).shift(LEFT * 1.5)
        topic_desc[1].align_to(topic_desc[0], LEFT)

        # First description: every element is aligned individually with the
        # top of the first sample.
        for elem in topic_desc[0]:
            elem.align_to(samples[0], UP)
        # Second description: each half is aligned as a whole group with the
        # top of the second sample.  One call per group is enough; repeating
        # it has no further effect.
        for line in topic_desc[1]:
            line.align_to(samples[1], UP)
        # Hand-tuned vertical nudges applied on top of the alignment above.
        topic_desc[0][0].shift(DOWN * 0.01)
        topic_desc[1][0][0].shift(UP * 0.025)
        topic_desc[1][1][1].shift(UP * 0.025)
        topic_desc[1][1][3].shift(UP * 0.025)

        # Center both parts
        VGroup(samples, topic_desc).center()

        self.play(Create(samples))
        self.wait(2)
        self.play(Create(topic_desc[0]))
        self.wait(4)
        self.play(Create(topic_desc[1][0]))
        self.wait(2)
        self.play(Create(topic_desc[1][1]))

        return samples, topic_desc

    def samples_to_table(self, samples, topic_desc):
        """Morph the samples and their topic descriptions into a table.

        Args:
            samples: The sample texts returned by :meth:`show_samples`.
            topic_desc: The topic descriptions returned by
                :meth:`show_samples`.

        Returns:
            A ``VGroup`` containing every mobject that was transformed into a
            table entry plus the table parts created afterwards (lines and
            the remaining entry), so the caller can remove them together.
        """
        samples_table = Table(
            [
                ["1", "0"],
                ["1", "1"],
            ],
            h_buff=0.8, v_buff=0.5,
            include_outer_lines=False,
            col_labels=[
                Text("Is a police officer", weight=BOLD, color=self.police_officer_topic_color),
                Text("Is senior", weight=BOLD, color=self.senior_topic_color),
            ],
            row_labels=[
                Text("police officer"),
                Text("senior police officer"),
            ],
            element_to_mobject=Text,
            line_config={"stroke_width": 1}
        ).scale(0.4).shift(DOWN * 2, LEFT)

        # The mobjects that stay on screen after being transformed below
        transformed_values = VGroup(
            topic_desc[0][2], topic_desc[1][0][2],
            topic_desc[1][1][3],
            samples[0], samples[1],
            topic_desc[0][0],
            topic_desc[1][0][0],
            topic_desc[1][1][1],
        )
        self.play(
            # Column labels
            Transform(topic_desc[0][2], samples_table.get_entries((1, 2))),
            Transform(topic_desc[1][0][2], samples_table.get_entries((1, 2))),
            Transform(topic_desc[1][1][3], samples_table.get_entries((1, 3))),
            # Row labels
            Transform(samples[0], samples_table.get_entries((2, 1))),
            Transform(samples[1], samples_table.get_entries((3, 1))),
            # Matrix
            Transform(topic_desc[0][0], samples_table.get_entries((2, 2))),
            Transform(topic_desc[1][0][0], samples_table.get_entries((3, 2))),
            Transform(topic_desc[1][1][1], samples_table.get_entries((3, 3))),
            # Fade everything else
            FadeOut(topic_desc[0][1]),
            FadeOut(topic_desc[1][0][1]),
            FadeOut(topic_desc[1][1][0]), FadeOut(topic_desc[1][1][2]),
            run_time=2,
        )
        self.wait()

        # Create the rest of the table: the lines and the one entry, (2, 3),
        # that no existing mobject was transformed into.
        table_remains = VGroup(
            samples_table.get_vertical_lines(),
            samples_table.get_horizontal_lines(),
            samples_table.get_entries((2, 3)),
        )
        self.play(Create(table_remains))

        return VGroup(transformed_values, table_remains)

    def _flash_rectangle(self, target, hold):
        """Draw a blue rectangle around ``target``, hold it, then erase it.

        Args:
            target: The mobject to surround.
            hold: Seconds to wait while the rectangle is fully drawn.
        """
        rect = SurroundingRectangle(target, color=BLUE)
        self.play(Create(rect), run_time=0.5)
        self.wait(hold)
        # NOTE(review): the rotate + invert pair presumably controls the
        # direction in which Uncreate erases the outline - confirm visually.
        rect.rotate(180 * DEGREES)
        rect.invert()
        self.play(Uncreate(rect), run_time=0.5)

    def construct(self):
        """Build and play the whole animation."""
        self.show_title()
        self.wait()
        self.show_subtitle()
        self.wait(2)
        samples, topic_desc = self.show_samples()
        self.wait()
        table = self.samples_to_table(samples, topic_desc)
        self.wait()

        self.play(FadeOut(table))
        self.wait()

        equation = (
            VGroup(MathTex('F'), MathTex(r'\approx'), MathTex('X'), MathTex(r'\Lambda'))
            .scale(1.5)
            .arrange(RIGHT, buff=0.75)
        )
        shape_F = MathTex(r"n \times v").next_to(equation[0], DOWN).scale(0.75)  # We'll shift it later
        # X's and Lambda's shapes are placed directly at their final height
        # (UP * 1.25), the same shift F's shape receives in the "F" step.
        shape_X = MathTex(r"n \times t").next_to(equation[2], DOWN).scale(0.75).align_to(shape_F, DOWN).shift(UP * 1.25)
        shape_L = MathTex(r"t \times v").next_to(equation[3], DOWN).scale(0.75).align_to(shape_F, DOWN).shift(UP * 1.25)
        equation_w_shapes = VGroup(equation, shape_F, shape_X, shape_L)
        self.play(Create(equation), run_time=2)
        self.wait(3)

        # Initial values
        F_exr = MathTex(r"\begin{bmatrix} 1 & 1 & 0 \\ 1 & 1 & 1 \end{bmatrix}").scale(0.75)
        X_exr = MathTex(r"\begin{bmatrix} ? & ? \\ ? & ? \end{bmatrix}").scale(0.75)
        L_exr = MathTex(r"\begin{bmatrix} ? & ? & ? \\ ? & ? & ? \end{bmatrix}").scale(0.75)
        equation_exr = VGroup(F_exr, MathTex(r"\approx"), X_exr, L_exr).arrange(RIGHT, buff=0.5).shift(DOWN * 0.25)

        # Set colors in F
        VGroup(F_exr[0][1], F_exr[0][4]).set_color(self.police_color)
        VGroup(F_exr[0][2], F_exr[0][5]).set_color(self.officer_color)
        VGroup(F_exr[0][6]).set_color(self.senior_color)
        VGroup(F_exr[0][3]).set_color(self.empty_color)

        # Random (intermediary) values
        X_exi = MathTex(r"\begin{bmatrix} 0.5 & 0.8 \\ 0.6 & 0.3 \end{bmatrix}").scale(0.75)
        L_exi = MathTex(r"\begin{bmatrix} 0.5 & 0.4 & 0.7 \\ 0.3 & 0.6 & 0.1 \end{bmatrix}").scale(0.75)
        # The group is built only to position X_exi and L_exi; it is not
        # needed afterwards.
        VGroup(X_exi, L_exi).arrange(RIGHT, buff=0.5).shift(DOWN * 0.25 + LEFT * 2.5)

        # Final values
        F_exf = F_exr.copy()
        X_exf = MathTex(r"\begin{bmatrix} 1 & 0 \\ 1 & 1 \end{bmatrix}").scale(0.75)
        L_exf = MathTex(r"\begin{bmatrix} 1 & 1 & 0 \\ 0 & 0 & 1 \end{bmatrix}").scale(0.75)
        equation_exf = VGroup(F_exf, MathTex(r"\approx"), X_exf, L_exf).arrange(RIGHT, buff=0.5).shift(DOWN * 0.25)

        # Set colors in L
        VGroup(L_exf[0][1]).set_color(self.police_color)
        VGroup(L_exf[0][2]).set_color(self.officer_color)
        VGroup(L_exf[0][6]).set_color(self.senior_color)
        VGroup(L_exf[0][3], L_exf[0][4], L_exf[0][5]).set_color(self.empty_color)

        with_txt = Text("With a word-based decomposition").scale(0.5).to_edge(LEFT).shift(DOWN * 1.75)

        # Init the lower portion (n, v and t).  Each quantity has a verbose
        # form shown first and a short form it is later replaced by.
        n_text = VGroup(
            VGroup(
                Text("police", color=self.police_color),
                Text("officer", color=self.officer_color),
            ).arrange(RIGHT),
            VGroup(
                Text("senior", color=self.senior_color),
                Text("police", color=self.police_color),
                Text("officer", color=self.officer_color),
            ).arrange(RIGHT),
        ).scale(0.5).arrange(DOWN)
        n_brace = Brace(n_text, sharpness=1.0, stroke_width=2).scale(0.3).next_to(n_text, LEFT).rotate(270 * DEGREES).shift(RIGHT)
        n_eq = MathTex(r"n = ").scale(0.7).next_to(n_brace, LEFT)
        n_verbose = VGroup(n_eq, n_brace, n_text).shift(DOWN * 3 + LEFT * 3)
        n = MathTex("n = 2").scale(0.5).move_to(RIGHT * 4 + UP * 1.5)
        v_verbose = VGroup(
            MathTex(r"v =", r"\{"),
            MathTex(r"\text{police}", color=self.police_color),
            MathTex(r"\text{officer}", color=self.officer_color),
            MathTex(r"\text{senior}", color=self.senior_color),
            MathTex(r"\}"),
        ).scale(0.7).arrange(RIGHT).shift(DOWN * 3 + RIGHT * 3)
        v = MathTex("v = 3").scale(0.5).next_to(n, DOWN).align_to(n, RIGHT)
        t_verbose = MathTex(r"\text{(arbitrary)} \;\; t = 2").scale(0.7).move_to(DOWN * 1.25)
        t = MathTex("t = 2").scale(0.5).next_to(v, DOWN).align_to(n, RIGHT)

        # Align everything for better render
        for group in n_text:
            for elem in group[1:]:
                elem.align_to(group[0], UP)
        n_text[1].align_to(n_text[0], LEFT)
        v_verbose[2].align_to(v_verbose[1], UP)
        v_verbose[3].align_to(v_verbose[1], UP)

        takeaway = VGroup(
            Text("• F counts how many times each decomposition appears in each sample", t2g={'F': BLUE}),
            Text("• X describes how much each topic appears in each sample", t2g={'X': BLUE}),
            Text("• Λ indicates which words belong to which topic", t2g={'Λ': BLUE}),
        ).scale(0.5).arrange(DOWN).move_to(DOWN * 2.5)

        # F
        self.play(Create(shape_F))
        self.play(
            # We don't shift the whole equation because that would display X's and L's shapes
            equation.animate.shift(UP * 1.25),
            shape_F.animate.shift(UP * 1.25),
            Create(with_txt),
            Create(n_verbose),
            Create(v_verbose),
        )
        self.wait(5)

        # Smaller, right-aligned copy of the samples used as row labels for F
        n_text2 = n_text.copy().scale(0.75)
        n_text2[1].align_to(n_text2[0], RIGHT)
        n_text2.next_to(F_exr, LEFT)
        self.play(
            Create(F_exr),
            Create(n_text2),
            FadeOut(with_txt),
            ReplacementTransform(n_verbose, n),
            ReplacementTransform(v_verbose, v),
        )
        self.play(Create(takeaway[0]))
        self.wait(4)

        # X
        self.play(Create(shape_X), Create(t_verbose))
        self.wait(0.5)
        self.play(Create(equation_exr[1]), Create(X_exr))
        self.play(Create(takeaway[1]))
        self.wait(4)

        # Create a rectangle around the topics
        self._flash_rectangle(t_verbose, hold=10)
        self.play(ReplacementTransform(t_verbose, t))
        self.wait(1)

        topics = VGroup(
            Text("Is a police officer", color=self.police_officer_topic_color),
            Text("Is senior", color=self.senior_topic_color),
        ).scale(0.4).arrange(DOWN).next_to(L_exr, RIGHT)
        topics[1].align_to(topics[0], LEFT)

        # Lambda
        self.play(Create(shape_L))
        self.wait(0.5)
        self.play(Create(L_exr), Create(topics))
        self.play(Create(takeaway[2]))
        self.wait(2)

        # Create a rectangle around F
        self._flash_rectangle(F_exr, hold=1)
        self.wait(1)

        # Add random values.  The current state is saved so that it can be
        # restored once the optimisation results are shown.
        equation_w_shapes.save_state()
        equation_w_shapes.generate_target()
        equation_w_shapes.target[1].set_color(config.background_color)  # Fade out F's shape
        equation_w_shapes.target[0][0].set_color(config.background_color)  # Fade out F
        equation_w_shapes.target[0][1].set_color(config.background_color)  # Fade out \approx
        equation_w_shapes.target.shift(LEFT * 4)
        self.play(FadeOut(n_text2), FadeOut(equation_exr[1]), FadeOut(F_exr), run_time=0.5)
        self.play(
            ReplacementTransform(X_exr, X_exi),
            ReplacementTransform(L_exr, L_exi),
            MoveToTarget(equation_w_shapes),
            FadeOut(topics, shift=LEFT),
            FadeOut(n),
            FadeOut(v),
            FadeOut(t),
            FadeOut(takeaway),
        )
        self.wait(3)

        # Image loaded from a path relative to the working directory
        gap_algo = ImageMobject("assets/algo_transparent.png").scale(0.8).move_to(RIGHT * 3.5 + DOWN)
        self.play(FadeIn(gap_algo))
        self.wait(2)

        # Optimization results
        self.play(FadeOut(gap_algo))
        self.play(
            Restore(equation_w_shapes),
            ReplacementTransform(X_exi, X_exf),
            ReplacementTransform(L_exi, L_exf),
            FadeIn(equation_exf[1], shift=RIGHT),
            FadeIn(F_exf, shift=RIGHT),
            FadeIn(n_text2, shift=RIGHT),
            FadeIn(topics.next_to(L_exf, RIGHT), shift=RIGHT),
            FadeIn(takeaway),
        )
        self.wait(2)

        # Highlight the final X; this rectangle stays until the final fade.
        self.play(
            Create(SurroundingRectangle(X_exf, color=BLUE)),
        )
        self.wait(15)

        # Fade everything out
        self.play(*[FadeOut(mob) for mob in self.mobjects])

                                                                                                                                                      