In [1]:
import numpy as np
import pandas as pd
import dirty_cat as dc
from sklearn.preprocessing import OneHotEncoder

from manim import *
from manim.utils.color import Colors
# Background colour for rendered scenes. SimilarityAnimation.construct reads
# this value back (in its `norm` helper) as the base colour for cell shading.
config.background_color = "#060d14"

In [16]:
%%manim -qh -v WARNING SimilarityAnimation

class SimilarityAnimation(Scene):
    """Scene contrasting one-hot encoding with n-gram similarity encoding.

    The scene draws a placeholder table, morphs it into the one-hot encoding
    of a few job titles, morphs that into the similarity encoding, and then
    walks through the 3-gram decomposition of two titles to illustrate where
    one of the similarity values comes from.
    """

    @staticmethod
    def _labelled_table(rows, row_names, col_names, corner_text):
        """Build one of the large tables with row/column labels.

        Args:
            rows: Cell contents, as a list of rows of strings.
            row_names: Strings turned into the row labels.
            col_names: Strings turned into the column labels.
            corner_text: Text shown (in bold) in the top-left corner.

        Returns:
            The table, scaled and positioned slightly below the centre.
        """
        return Table(
            rows,
            h_buff=0.5, v_buff=2,
            include_outer_lines=False,
            # Fresh Text mobjects per table: labels are not shared.
            row_labels=list(map(Text, row_names)),
            col_labels=list(map(Text, col_names)),
            top_left_entry=Text(corner_text, weight=BOLD),
            element_to_mobject=Text,
            line_config={"stroke_width": 1}
        ).scale(0.3).center().shift(DOWN * 0.5)

    @staticmethod
    def _cell_backgrounds(table, colors):
        """Return a VGroup with one highlighted cell per entry of `colors`.

        Args:
            table: Labelled table whose cells are highlighted.
            colors: List of rows of colour strings, one per data cell.
        """
        return VGroup(*[
            # +2: table positions are 1-based and the first row/column of a
            # labelled table holds the labels.
            table.get_highlighted_cell((i + 2, j + 2), color=value)
            for i, color_row in enumerate(colors)
            for j, value in enumerate(color_row)
        ])

    @staticmethod
    def _ngram_table(rows):
        """Build a small unlabelled table of n-gram strings (not yet placed)."""
        return Table(
            rows,
            h_buff=0.5, v_buff=0.5,
            include_outer_lines=True,
            element_to_mobject=Text,
            line_config={"stroke_width": 1}
        ).scale(0.3)

    @staticmethod
    def _align_ngrams(*tables):
        """Align every non-blank n-gram to the bottom of its row's first entry."""
        for _table in tables:
            for _row in _table.get_rows():
                for _ngram in _row[1:]:
                    # Padding cells hold only spaces; a plain truthiness test
                    # would treat them as non-empty, hence the strip().
                    if _ngram.text.strip():
                        _ngram.align_to(_row[0], DOWN)

    def show_title(self):
        """Draw the title, move it to the top edge, and return it."""
        title = Text("Similarity encoding")
        self.play(Create(title))
        self.play(title.animate.to_edge(UP))
        return title

    def construct(self):
        """Build and play the whole animation."""
        # Imported locally: the annotations on the nested helpers below are
        # evaluated when those helpers are defined, and the file's top-level
        # imports do not bring these names in explicitly.
        from typing import List, Tuple

        def get_rows(arr: np.ndarray, i: int = 5) -> List[List[str]]:
            """Stringify each value of a 2D array, keeping the first `i` characters.

            Values are truncated, not rounded.
            """
            return [[str(val)[:i] for val in row] for row in arr.tolist()]

        def rgb_to_hex(rgb: Tuple[int, int, int]) -> str:
            """Format an (r, g, b) tuple as a '#rrggbb' string."""
            return '#%02x%02x%02x' % rgb

        def hex_to_rgb(hx) -> Tuple[int, int, int]:
            """Parse the string form of `hx` ('#rrggbb') into an (r, g, b) tuple."""
            hx = str(hx).lstrip("#")
            return tuple(int(hx[i:i + 2], 16) for i in (0, 2, 4))

        def norm(array: List[List[str]]) -> List[List[str]]:
            """Map weights (numeric strings) to hex colours.

            Each weight is scaled by `255 - max(background channels)` and
            added to every background channel, so a weight of 0 gives the
            background colour and a weight of 1 brings the brightest
            background channel to 255. Weights are expected in [0, 1].
            """
            back_rgb = hex_to_rgb(config.background_color)
            normalize_factor = 255 - max(back_rgb)
            return [
                [
                    rgb_to_hex(tuple(
                        int(float(value) * normalize_factor + channel)
                        for channel in back_rgb
                    ))
                    for value in sub_array
                ]
                for sub_array in array
            ]

        categories = [
            "Police Officer III", "Master Police Officer", "Correctional Officer III",
            "Fire/Rescue Captain", "Correctional Officer III", "Police Officer I"
        ]
        arr = np.array(categories).reshape(-1, 1)

        # Encode the same column with both encoders.
        ohe = OneHotEncoder()
        sim = dc.SimilarityEncoder(ngram_range=(3, 3))
        ohe_arr = ohe.fit_transform(arr)
        sim_arr = sim.fit_transform(arr)
        ohe_categories = ohe.categories_[0].tolist()
        sim_categories = sim.categories_[0].tolist()
        # Both tables reuse the same cell positions, so the columns must match.
        assert ohe_categories == sim_categories

        # Placeholder table: one row per sample, one column per category.
        table = self._labelled_table(
            [["0"] * len(ohe_categories) for _ in categories],
            categories, ohe_categories, "Empty",
        )

        ohe_rows = get_rows(ohe_arr.toarray(), i=1)
        ohe_table = self._labelled_table(
            ohe_rows, categories, ohe_categories, "One-hot",
        ).set_z_index(table.z_index - 1)
        ohe_table_colors = self._cell_backgrounds(
            ohe_table, norm(ohe_rows),
        ).set_z_index(ohe_table.z_index - 2)

        sim_rows = get_rows(sim_arr)
        sim_table = self._labelled_table(
            sim_rows, categories, sim_categories, "Similarity (ngram)",
        ).set_z_index(ohe_table.z_index - 1)
        # Backgrounds are taken from sim_table itself (the original code
        # used ohe_table here, a copy-paste slip).
        sim_table_colors = self._cell_backgrounds(
            sim_table, norm(sim_rows),
        ).set_z_index(sim_table.z_index - 2)

        # Fix alignment of label
        for _labelled in (table, ohe_table, sim_table):
            _labelled.get_labels()[2].shift(DOWN * 0.025)

        self.show_title()

        self.play(Create(table))
        self.wait(0.25)

        self.play(FadeIn(ohe_table_colors), ReplacementTransform(table, ohe_table))
        self.wait(2)

        self.play(FadeOut(ohe_table_colors), FadeIn(sim_table_colors), ReplacementTransform(ohe_table, sim_table))
        self.wait(2)

        final_sim_table = VGroup(sim_table, sim_table_colors)

        # Pick a couple of samples for an example:
        # the first row label and the first column label.
        sample1 = sim_table.get_rows()[1][0].copy()
        sample2 = sim_table.get_rows()[0][1].copy()

        # Move the whole table on the side (state saved for the final Restore)
        final_sim_table.save_state()
        final_sim_table.generate_target()
        final_sim_table.target.scale(0.6)
        final_sim_table.target.shift(LEFT * 3)
        self.play(
            MoveToTarget(final_sim_table),
            sample1.animate.move_to(RIGHT * 4 + UP * 1),
            sample2.animate.move_to(RIGHT * 4 + DOWN * 1),
        )
        # Bold both, so the decomposition looks better
        sample1_bold = VGroup(sample1.copy(), sample1[1:-1].copy(), sample1[2:-2].copy())
        sample2_bold = VGroup(sample2.copy(), sample2[1:-1].copy(), sample2[2:-2].copy())
        self.play(FadeIn(sample1_bold), FadeIn(sample2_bold))
        self.wait(5)

        # Glyph slices of each sample, one per n-gram cell below.
        # NOTE(review): the two-glyph slices line up with the n-grams that
        # contain a space, presumably because the Text mobject holds no glyph
        # for spaces - confirm.
        sample1_groups = [
            sample1[0:3], sample1[1:4], sample1[2:5], sample1[3:6],
            sample1[4:6], sample1[5:7], sample1[6:8], sample1[6:9],
            sample1[7:10], sample1[8:11], sample1[9:12], sample1[10:13],
            sample1[11:13], sample1[12:14], sample1[13:15], sample1[13:16],
        ]
        sample1_3grams = self._ngram_table(
            [
                ['pol', 'oli', 'lic', 'ice'],
                ['ce ', 'e o', ' of', 'off'],
                ['ffi', 'fic', 'ice', 'cer'],
                ['er ', 'r I', ' II', 'III'],
            ]
        ).next_to(sample1, DOWN)

        sample2_groups = [
            sample2[0:3], sample2[1:4], sample2[2:5], sample2[3:6],
            sample2[4:7], sample2[5:8], sample2[6:9], sample2[7:10],
            sample2[8:11], sample2[9:12], sample2[10:12], sample2[11:13],
            sample2[12:14], sample2[12:15], sample2[13:16], sample2[14:17],
            sample2[15:18], sample2[16:19], sample2[17:19], sample2[18:20],
            sample2[19:21], sample2[19:22],
        ]
        sample2_3grams = self._ngram_table(
            [
                ['cor', 'orr', 'rre', 'rec'],
                ['ect', 'cti', 'tio', 'ion'],
                ['ona', 'nal', 'al ', 'l o'],
                [' of', 'off', 'ffi', 'fic'],
                ['ice', 'cer', 'er ', 'r I'],
                [' II', 'III', '   ', '   '],
            ]
        ).next_to(sample2, DOWN)

        # Align ngrams
        self._align_ngrams(sample1_3grams, sample2_3grams)

        self.play(Create(sample1_3grams), Create(sample2_3grams))

        # Animate ngrams to their individual positions
        # (index 0 of each table is used as the sequence of its cell entries).
        group_1 = [
            ReplacementTransform(_group.copy(), sample1_3grams[0][i])
            for i, _group in enumerate(sample1_groups)
        ]
        group_2 = [
            ReplacementTransform(_group.copy(), sample2_3grams[0][i])
            for i, _group in enumerate(sample2_groups)
        ]

        self.remove(sample1_bold, sample2_bold)
        self.play(LaggedStart(*group_1), LaggedStart(*group_2), run_time=4)
        self.wait(1)

        decomposition = VGroup(sample1, sample1_3grams, sample2, sample2_3grams)
        self.play(decomposition.animate.shift(LEFT * 1.5))

        common_3grams_head = Text("Common 3-grams").scale(0.3).move_to(RIGHT * 5.5 + UP * 1)
        common_3grams = self._ngram_table(
            [
                [' of', 'off', 'ffi', 'fic'],
                ['ice', 'cer', 'er ', 'r I'],
                [' II', 'III', '   ', '   '],
            ]
        ).next_to(common_3grams_head, DOWN)

        all_3grams_head = Text("All 3-grams").scale(0.3).move_to(RIGHT * 5.5 + DOWN * 1)
        all_3grams = self._ngram_table(
            [
                ['pol', 'oli', 'lic', 'ice'],
                ['ce ', 'e o', 'cor', 'orr'],
                ['rre', 'rec', 'ect', 'cti'],
                ['tio', 'ion', 'ona', 'nal'],
                ['al ', 'l o', ' of', 'off'],
                ['ffi', 'fic', 'ice', 'cer'],
                ['er ', 'r I', ' II', 'III'],
            ]
        ).next_to(all_3grams_head, DOWN)

        # Align ngrams
        self._align_ngrams(common_3grams, all_3grams)

        self.play(
            Create(common_3grams_head),
            Create(common_3grams),
            Create(all_3grams_head),
            Create(all_3grams),
        )
        self.wait(0.5)

        # Construct a custom equation: 10 common 3-grams over 28 in total.
        # NOTE(review): 10 / 28 is about 0.357; the 0.354 shown presumably
        # matches the (truncated) value displayed in the table - confirm.
        ratio = VGroup(
            VGroup(  # Frac
                MathTex("10"),
                Line(stroke_width=1.5).scale(0.25),
                MathTex("28"),
            ).arrange(DOWN),
            MathTex(r"\approx 0.354"),
        ).scale(0.8).arrange(RIGHT).move_to(RIGHT * 5 + DOWN * 1)

        self.play(
            FadeOut(common_3grams_head),
            ReplacementTransform(common_3grams, ratio[0][0]),
            Create(ratio[0][1]),
            FadeOut(all_3grams_head),
            ReplacementTransform(all_3grams, ratio[0][2]),
        )

        self.play(Create(ratio[1]))

        # Send the result into the table cell for (first sample, first category).
        self.play(
            FadeOut(ratio[0]),
            FadeOut(decomposition),
            ReplacementTransform(ratio[1], sim_table.get_rows()[1][1]),
        )
        self.wait(0.25)

        self.play(Restore(final_sim_table))
        self.wait(1)

        # Fade everything out
        self.play(*[FadeOut(mob) for mob in self.mobjects])

                                                                                                         