In [1]:
import numpy as np
import pandas as pd
import dirty_cat as dc
from sklearn.preprocessing import OneHotEncoder

from manim import *
from manim.utils.color import Colors
# Set manim's global background colour (a very dark blue). This is configured
# once here, in the import cell, so the scene cells below do not set it themselves.
config.background_color = "#060d14"

In [2]:
%%manim -qh -v WARNING FirstScreenAnimation

class FirstScreenAnimation(Scene):
    """Title slide: banner and logos at the top, the talk title, then authors and contributors."""

    def construct(self):
        """Lay out the whole title slide, fade it in, hold it for 13 s, then fade it out."""

        def plain(label):
            # Regular-face line used for people's names and the contributors heading
            return Text(label).scale(0.4)

        def note(label):
            # Small italic line (affiliation / role) attached to an author
            return Text(label, slant=ITALIC).scale(0.25)

        def logo(path, side):
            # Both logos share the same height and sit mirrored around the vertical axis
            return ImageMobject(path).scale(0.13).move_to(UP * 1.8 + side * 2.25)

        # --- Talk title: two italic lines stacked vertically ---
        title_lines = [
            Text("dirty_cat : a Python package", slant=ITALIC),
            Text("for Machine Learning on Dirty Categorical Data", slant=ITALIC),
        ]
        title = VGroup(*title_lines)
        title.arrange(DOWN)
        title.scale(0.8)
        title.shift(DOWN * 0.25)

        # --- Pictures: the banner is pushed past the top edge, logos go underneath ---
        banner = ImageMobject("assets/banner.png").scale(0.13)
        banner.to_edge(UP)
        banner.shift(UP * 0.5)
        images = Group(
            banner,
            logo("assets/dirty_cat_outline.png", LEFT),
            logo("assets/inria_logo.png", RIGHT),
        )

        # --- Authors: the first author's name is the vertical reference for the others ---
        first_name = plain("Patricio Cerda")
        first_author = VGroup(
            first_name,
            note("Inria, Parietal"),
            note("Université Paris-Saclay"),
        ).arrange(DOWN).shift(LEFT * 6)

        second_author = VGroup(plain("Gaël Varoquaux"), note("Inria, Parietal"))
        second_author.arrange(DOWN).align_to(first_name, UP).shift(LEFT * 3)

        speaker_name = plain("Lilian Boulard").align_to(first_name, UP)
        speaker_tag = note("speaker")
        speaker = VGroup(
            speaker_name,
            note("Inria, Parietal"),
            note("Université Paris-Saclay"),
            speaker_tag,
        ).arrange(DOWN).align_to(first_name, UP)
        # The tag is arranged with the block first, then moved above the name
        speaker_tag.next_to(speaker_name, UP)

        # --- Contributors: a heading above two columns of names ---
        contributors_head = plain("With contributions from")
        contributors_head.align_to(first_name, UP).shift(RIGHT * 4 + UP * 0.4)

        left_names = ["Pierre Glaser", "Thomas Schmitt", "Léo Grinsztajn", "Jérôme Dockès"]
        right_names = ["Alexis Cvetkov-Iliev", "Amanda Dsouza", "Nicolas Gensollen"]
        left_column = VGroup(*[Text(name) for name in left_names]).arrange(DOWN, buff=0.4)
        right_column = VGroup(
            *[Text(name) for name in right_names],
            Text("and many others !", slant=ITALIC),
        ).arrange(DOWN, buff=0.4)
        contributors = VGroup(left_column, right_column)
        contributors.scale(0.3).arrange(RIGHT).next_to(contributors_head, DOWN)

        # Put each right-hand name on the same line as its left-hand neighbour
        for left_entry, right_entry in zip(left_column, right_column):
            right_entry.align_to(left_entry, UP)

        # Manual nudges for the third line of each column
        left_column[2].shift(DOWN * 0.015)
        right_column[2].shift(DOWN * 0.025)

        authors = VGroup(first_author, second_author, speaker, contributors_head, contributors)
        authors.move_to(DOWN * 2.25)

        # Show everything at once and hold the slide
        self.play(*(FadeIn(part) for part in (title, images, authors)))
        self.wait(13)

        # Clear the screen
        self.play(*map(FadeOut, self.mobjects))

                                                                                                    

In [11]:
%%manim -qh -v WARNING IntroAnimation

class IntroAnimation(Scene):
    """Motivating example for the talk.

    The scene plays, in order:
      1. a small salary-survey table revealed column by column, plus a new hire
         whose salary is unknown;
      2. the same table presented as a machine-learning problem (features / target);
      3. the "Job title" column next to its one-hot encoding, colour-coded row by row;
      4. captions explaining what one-hot encoding misses, illustrated with
         typo and spelling-variation tables.
    """

    @staticmethod
    def _text_table(rows, col_labels=None, include_outer_lines=False):
        """Build a ``Table`` of ``Text`` cells with the spacing and thin lines shared by this scene.

        Parameters
        ----------
        rows : list of list of str
            Cell contents, one inner list per row.
        col_labels : list of Mobject, optional
            Header cells; no header row when ``None``.
        include_outer_lines : bool
            Whether the table is drawn with its outer border.
        """
        return Table(
            rows,
            h_buff=0.8, v_buff=0.5,
            include_outer_lines=include_outer_lines,
            col_labels=col_labels,
            element_to_mobject=Text,
            # A fresh dict per table, so no two tables share the same config object
            line_config={"stroke_width": 1},
        )

    @staticmethod
    def _bold_labels(*labels):
        """Return one bold ``Text`` per label, for use as table column headers."""
        return [Text(label, weight=BOLD) for label in labels]

    def construct(self):
        """Build every table up front, then play the four parts of the scene in sequence."""
        table = self._text_table(
            [
                ["Starks-Bey, Prince", "Police Officer III", "2005", "89 620"],
                ["Dolan, Thomas", "Master Police Officer", "1986", "97 392"],
                ["Copas, Robert", "Correctional Officer III", "2009", "58 720"],
                ["Blinkhorn, Russell", "Fire/Rescue Captain", "1998", "110 229"],
                ["Gaston, Birdie", "Correctional Officer III", "2001", "77 328"],
                ["Jang, Jiwoo", "Police Officer I", "2022", "?"],
            ],
            col_labels=self._bold_labels("Full name", "Job title", "Year hired", "Annual salary ($)"),
        ).scale(0.4)

        # Stand-alone copy of the "Job title" column; the column of `table` is
        # later transformed into this one.
        job_titles = self._text_table(
            [
                ["Police Officer III"],
                ["Master Police Officer"],
                ["Correctional Officer III"],
                ["Fire/Rescue Captain"],
                ["Correctional Officer III"],
                ["Police Officer I"],
            ],
            col_labels=self._bold_labels("Job title"),
            include_outer_lines=True,
        ).scale(0.4).shift(LEFT * 3)

        # One row per job title above, one column per distinct title
        ohe_table = self._text_table(
            [
                ["1", "0", "0", "0", "0"],
                ["0", "1", "0", "0", "0"],
                ["0", "0", "1", "0", "0"],
                ["0", "0", "0", "1", "0"],
                ["0", "0", "1", "0", "0"],
                ["0", "0", "0", "0", "1"],
            ],
            include_outer_lines=True,
        ).scale(0.4).align_to(job_titles, DOWN).shift(RIGHT * 3)
        ohe_arrow = Arrow(
            start=DOWN * 0.21 + LEFT * 1,
            end=DOWN * 0.21 + RIGHT * 1.5,
        )
        ohe_caption = Text("One-hot").scale(0.4).next_to(ohe_table, UP)

        # Unpack columns, and filtering out the last row
        col_names, col_jobs, col_year, col_pay = map(lambda col: col[:-1], table.get_columns())
        # Get the last row and its line
        new_hired_row = table.get_rows()[-1]
        first_vline = table.get_vertical_lines()[0]
        *first_lines, last_line = table.get_horizontal_lines()
        first_lines = VGroup(*first_lines)  # Convert list of lines to vgroup

        survey_title = Text("Let's conduct a survey on salaries !").to_edge(UP)
        ml_title = Text("Machine learning !").center().to_edge(UP)
        dataset_info = Text("Permanent employees of Montgomery County, MD, 2016",
                            slant=ITALIC).scale(0.2).to_edge(DOWN + RIGHT)

        # --- Part 1: reveal the survey table column by column ---
        self.play(Create(first_lines), Create(col_names), Create(survey_title))
        self.play(Create(dataset_info))
        self.wait(4.5)
        self.play(Create(col_jobs))
        self.play(Create(col_year))
        self.play(Create(col_pay))
        self.wait(2)

        # The new hire (last row) arrives with its separating line
        self.play(Create(last_line), Create(new_hired_row))
        self.wait(3)

        # Briefly outline the surveyed employees (all data rows but the last).
        # The rotate / invert / scale(-1) calls are kept exactly as authored.
        surveyed = SurroundingRectangle(table.get_rows()[1:-1], color=BLUE)
        surveyed.rotate(180 * DEGREES)
        self.play(Create(surveyed))
        surveyed.invert()
        self.play(Uncreate(surveyed.scale(-1)))
        self.wait(1)

        # --- Part 2: the same table seen as a machine-learning problem ---
        # Before shifting the table, remove the vertical lines
        # (otherwise they are drawn)
        table.remove(*table.get_vertical_lines())
        # Also shift the orphan line
        first_vline.shift(DOWN)
        self.play(FadeOut(survey_title))
        self.play(table.animate.shift(DOWN), Create(ml_title))
        self.wait(3.5)

        # Get column groups
        name_col, *feature_cols, target_col = table.get_columns()
        feature_cols = VGroup(*feature_cols)  # Convert list of columns to vgroup
        # Create braces
        features_brace = Brace(feature_cols, sharpness=1.0).next_to(feature_cols, UP).rotate(180 * DEGREES)
        target_brace = Brace(target_col, sharpness=1.0).next_to(target_col, UP).rotate(180 * DEGREES)
        # Create text
        feat_txt = Text("Features", slant=ITALIC).next_to(features_brace, UP).scale(0.5)
        trgt_txt = Text("Target", slant=ITALIC).next_to(target_brace, UP).scale(0.5).align_to(feat_txt, UP)
        # Get missing target value
        missing = table.get_columns()[-1][-1]
        missing_rec = SurroundingRectangle(missing, color=BLUE)
        # Display
        self.play(Create(features_brace), Create(feat_txt), Create(first_vline))
        self.play(Create(target_brace), Create(trgt_txt),
                  Create(missing_rec))
        self.wait(10)

        enc_title = VGroup(
            Text("We must encode our categorical features "),
            Text("into numerical features !"),
        ).scale(0.8).arrange(DOWN).to_edge(UP)

        # Highlight the job titles (header excluded) in red
        self.play(table.get_columns()[1][1:].animate.set_color(RED), ReplacementTransform(ml_title, enc_title))
        self.wait(2)

        # --- Part 3: one-hot encoding of the job titles ---
        self.play(
            # Fade everything out except the job title column and the title
            *[FadeOut(mob) for mob in [
                name_col, feature_cols[1], target_col,
                features_brace, target_brace,
                feat_txt, trgt_txt,
                missing_rec,
                first_vline, table.get_horizontal_lines(),
                dataset_info,
            ]],
        )
        self.wait()

        # Move the job titles to the left
        self.play(Transform(feature_cols[0], job_titles.get_columns()[0]), run_time=0.75)
        # Add the one-hot table
        self.play(FadeIn(job_titles.get_horizontal_lines()), FadeIn(job_titles.get_vertical_lines()), run_time=0.25)
        self.play(Create(ohe_arrow), Create(ohe_table), Create(ohe_caption))
        self.wait()

        # Declare colors: one per data row; rows 3 and 5 hold the same job
        # title and therefore share a colour
        line_to_color = [
            Colors.blue.value,
            Colors.green.value,
            Colors.gold.value,
            Colors.red.value,
            Colors.gold.value,
            Colors.maroon.value,
        ]
        zeros_color = Colors.light_gray.value

        # Construct the color data
        table_background_cells = VGroup()
        ohe_zero_texts = VGroup()
        ohe_background_cells = VGroup()
        for i, line_ohe in enumerate(ohe_table.get_rows()):
            # `job_titles` has a header row, hence the +2 (cell positions are 1-based)
            table_background_cells.add(job_titles.get_highlighted_cell((i + 2, 1), color=line_to_color[i]))
            for j, txt in enumerate(line_ohe):
                if txt.text == '1':
                    ohe_background_cells.add(ohe_table.get_highlighted_cell((i + 1, j + 1), color=line_to_color[i]))
                elif txt.text == '0':
                    ohe_zero_texts.add(txt)
                else:
                    print('OHE table value is neither 1 or 0')

        # Add all the colors
        feature_cols[0].add_to_back(table_background_cells)
        ohe_table.add_to_back(ohe_background_cells)
        self.play(
            # Add the background cells
            FadeIn(table_background_cells), FadeIn(ohe_background_cells),
            # And color the zeros
            ohe_zero_texts.animate.set_color(zeros_color),
        )
        self.wait(7)

        # --- Part 4: what one-hot encoding misses ---
        below_text = VGroup(
            Text("One-hot doesn't consider similarities !"),
            VGroup(
                Text("We can assume a "),
                Text("Police Officer I", color=BLUE),
                Text(" earns less than a "),
                Text("Police Officer III", color=BLUE),
            ).arrange(RIGHT),
            Text("(there is a relation between the two, and we want the model to understand that)", slant=ITALIC),
            Text("We have the same issue when there are typos and variations in the feature"),
        ).scale(0.4).arrange(DOWN).move_to(DOWN * 3)

        self.play(Create(below_text[0]))
        self.wait(5)
        self.play(Create(below_text[1]))
        self.wait(0.5)
        self.play(Create(below_text[2]))
        self.wait(3)

        self.play(
            FadeOut(ohe_table),
            FadeOut(job_titles),
            FadeOut(feature_cols[0]),
            FadeOut(table_background_cells),
            FadeOut(ohe_background_cells),
            FadeOut(ohe_arrow),
            FadeOut(ohe_caption),
        )
        self.wait()

        # Two small examples (typos on top, spelling variations below), each
        # with its own one-hot encoding on the right
        typos_table = self._text_table(
            [
                ["Officer"],
                ["Offficer"],
                ["Off."],
            ],
            col_labels=self._bold_labels("Typos"),
        ).scale(0.4).move_to(LEFT * 3 + UP)

        typos_ohe = self._text_table(
            [
                ["1", "0", "0"],
                ["0", "1", "0"],
                ["0", "0", "1"],
            ],
            col_labels=self._bold_labels("Officer", "Offficer", "Off."),
        ).scale(0.4).move_to(RIGHT * 3 + UP)

        variations_table = self._text_table(
            [
                ["Policier"],
                ["Policière"],
                ["Policier.ère"],
            ],
            col_labels=self._bold_labels("Variations"),
        ).scale(0.4).move_to(LEFT * 3 + DOWN)

        variations_ohe = self._text_table(
            [
                ["1", "0", "0"],
                ["0", "1", "0"],
                ["0", "0", "1"],
            ],
            col_labels=self._bold_labels("Policier", "Policière", "Policier.ère"),
        ).scale(0.4).move_to(RIGHT * 3 + DOWN)

        self.play(Create(typos_table))
        self.play(Create(variations_table))
        self.wait()
        self.play(
            Create(typos_ohe),
            # Reuse the arrow from part 3, nudged up and to the left
            FadeIn(ohe_arrow.shift(UP * 0.25 + LEFT)),
            Create(variations_ohe),
        )
        self.play(Create(below_text[3]))
        self.wait(4)

        # Fade everything out
        self.play(*[FadeOut(mob) for mob in self.mobjects])

                                                                                                                                                              

In [10]:
%%manim -qh -v WARNING MethodsIntroAnimation

class MethodsIntroAnimation(Scene):
    """Slide introducing dirty_cat: a header with logos, two paper titles and three author portraits."""

    def construct(self):
        """Draw the header, move it to the top edge, then reveal the papers and the authors."""

        def citation(text):
            # Paper titles are set in small italics
            return Text(text, slant=ITALIC).scale(0.5)

        def portrait(image_path, zoom, caption):
            # A picture stacked above its italic caption
            picture = ImageMobject(image_path).scale(zoom)
            label = Text(caption, slant=ITALIC).scale(0.4)
            return Group(picture, label).arrange(DOWN)

        title = Text("dirty_cat", slant=ITALIC).scale(0.8)

        # The second paper hangs below the first one, left edges aligned
        first_paper = citation(
            "Similarity encoding for learning with dirty categorical variables, 2018"
        ).move_to(UP * 1.5 + LEFT)
        second_paper = citation(
            "Encoding high-cardinality string categorical variables, 2019"
        ).next_to(first_paper, DOWN).align_to(first_paper, LEFT)
        papers = VGroup(first_paper, second_paper)

        # Each picture has its own scale factor; the left portrait is the
        # vertical reference for the other two
        left_author = portrait("assets/patricio_cerda.jpg", 0.63, "Patricio Cerda").shift(LEFT * 4)
        middle_author = portrait("assets/gael_varoquaux.jpg", 0.765, "Gaël Varoquaux").align_to(left_author, UP)
        right_author = portrait("assets/balazs_kegl.jpg", 0.43, "Balázs Kégl").align_to(left_author, UP).shift(RIGHT * 4)
        authors = Group(left_author, middle_author, right_author).shift(DOWN * 1.5)

        # Logos on either side of the title
        github_logo = ImageMobject("assets/github_logo_outline.png").scale(0.3).shift(LEFT * 3 + UP * 0.2)
        python_logo = ImageMobject("assets/python_logo.png").scale(0.2).shift(RIGHT * 3 + UP * 0.1)

        header = Group(title, github_logo, python_logo)

        # The header first appears in the middle of the screen...
        self.play(Create(title), *(FadeIn(picture) for picture in (github_logo, python_logo)))
        # ...then moves to the top edge, nudged slightly further up
        header.generate_target()
        header.target.to_edge(UP).shift(UP * 0.1)
        self.play(MoveToTarget(header))
        self.wait()
        self.play(Create(papers), run_time=1.5)
        self.play(FadeIn(authors))
        self.wait(10)

        # Clear the screen, then hold briefly on the empty background
        self.play(*map(FadeOut, self.mobjects))
        self.wait(0.5)

                                                                                             