In [1]:
import numpy as np
import pandas as pd

import seaborn
import matplotlib.pyplot as plt

from sklearn.model_selection import cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.compose import make_column_transformer
from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder

from dirty_cat.datasets import fetch_employee_salaries
from dirty_cat import SimilarityEncoder, MinHashEncoder, GapEncoder

from manim import *
from manim.utils.color import Colors
# Background colour (a very dark blue) assigned on the `config` object that
# the star-import from manim above brings into scope.
config.background_color = "#060d14"

In [2]:
def benchmark():
    """Benchmark several categorical encoders on the employee-salaries data.

    The dirty ``employee_position_title`` column is encoded in turn with each
    encoder under test, while ``gender``, ``department_name`` and
    ``assignment_category`` are always one-hot encoded and the hiring year is
    passed through unchanged. Each resulting pipeline (column transformer +
    ``HistGradientBoostingRegressor``) is evaluated with ``cross_val_score``
    using its default settings.

    Side effects:
        * prints the mean and standard deviation of the scores per encoder;
        * saves a horizontal box plot of all scores to
          ``assets/benchmark_results.png`` (the ``assets`` directory must
          already exist).

    Returns:
        None
    """
    employee_salaries = fetch_employee_salaries()
    # Work on a copy so the frame held by the fetched dataset object is never
    # modified; this keeps the function safe to call more than once.
    X = employee_salaries.X.copy()
    y = employee_salaries.y

    X['date_first_hired'] = pd.to_datetime(X['date_first_hired'])
    # Vectorised accessor instead of a per-row Python lambda
    X['year_first_hired'] = X['date_first_hired'].dt.year
    # Get mask of rows with missing values in gender
    mask = X['gender'].isna()
    # And remove the lines accordingly, keeping X and y aligned
    X = X.dropna(subset=['gender'])
    y = y[~mask]

    one_hot = OneHotEncoder(handle_unknown='ignore', sparse=False)

    encoders = {
        'Similarity': SimilarityEncoder(similarity='ngram'),
        'Gamma-Poisson': GapEncoder(n_components=100),
        'One-Hot': one_hot,
        'Min-Hash': MinHashEncoder(n_components=100),
        'Ordinal': OrdinalEncoder(handle_unknown='use_encoded_value', unknown_value=np.nan),
    }

    all_scores = dict()

    for name, method in encoders.items():
        encoder = make_column_transformer(
            (one_hot, ['gender', 'department_name', 'assignment_category']),
            ('passthrough', ['year_first_hired']),
            # Last but not least, our dirty column
            (method, ['employee_position_title']),
            remainder='drop',
        )

        pipeline = make_pipeline(encoder, HistGradientBoostingRegressor())
        scores = cross_val_score(pipeline, X, y)
        print(f'{name} encoding')
        print(f'r2 score:  mean: {np.mean(scores):.3f}; '
              f'std: {np.std(scores):.3f}\n')
        all_scores[name] = scores

    # Create the figure exactly once, after all scores are collected.
    # (Opening it inside the loop left one empty figure per encoder.)
    fig, ax = plt.subplots(figsize=(6, 3))
    seaborn.boxplot(data=pd.DataFrame(all_scores), orient='h', ax=ax)
    ax.set_ylabel('Encoding\n', size=20)
    ax.set_xlabel('Prediction accuracy', size=20)
    ax.tick_params(axis='y', labelsize=20)
    fig.tight_layout()
    fig.savefig('assets/benchmark_results.png')

# Uncomment to re-run the benchmark and regenerate assets/benchmark_results.png
#benchmark()

In [5]:
%%manim -qh -v WARNING BenchmarkAnimation

class BenchmarkAnimation(Scene):
    """Scene that presents the encoder benchmark, then a short usage example."""

    def show_title(self):
        """Draw the 'Benchmarks' heading, move it to the top edge, return it."""
        heading = Text("Benchmarks")
        self.play(Create(heading))
        self.play(heading.animate.to_edge(UP))
        return heading

    def construct(self):
        """Play the whole sequence.

        Order of events: title, list of compared encoders (left side),
        benchmark image (right side), takeaway sentence, then the title is
        transformed into "Usage" and two example images joined by an arrow
        are shown before everything fades out.
        """
        heading = self.show_title()
        self.wait(2)

        # Left-hand column: an intro sentence followed by one bullet per encoder
        captions = (
            "We will compare the performances of",
            "• SimilarityEncoder",
            "• GapEncoder (Gamma-Poisson)",
            "• OneHotEncoder",
            "• MinHashEncoder (from dirty_cat)",
            "• OrdinalEncoder (a classic encoder)",
        )
        methods = VGroup(*[Text(caption) for caption in captions])
        methods.scale(0.5)
        methods.arrange(DOWN)
        methods.move_to(LEFT * 3.5 + UP * 0.25)

        intro = methods[0]
        self.play(Create(intro))
        # Each bullet is left-aligned on the intro line, nudged right, then drawn
        for bullet in methods[1:]:
            bullet.align_to(intro, LEFT)
            bullet.shift(RIGHT * 0.25)
            self.play(Create(bullet))
            self.wait(0.1)
        self.wait(0.2)

        # Right-hand side: the pre-rendered benchmark figure
        results_image = ImageMobject("assets/benchmark_results_inverted.png")
        results_image.scale(2)
        results_image.move_to(RIGHT * 3.75 + UP * 0.25)
        self.play(FadeIn(results_image))
        self.wait(0.2)

        # Bottom: the one-line conclusion
        conclusion = Text("Methods designed for dirty data outperform classical ones !")
        conclusion.scale(0.6)
        conclusion.move_to(DOWN * 2.5)
        self.play(Create(conclusion))
        self.wait(3)

        # Clear the benchmark part; the heading stays on screen
        self.play(*[FadeOut(mob) for mob in (methods, results_image, conclusion)])
        self.wait()

        # Usage part: two example images with an arrow between them
        left_example = ImageMobject("assets/ex_ohe.png").move_to(LEFT * 3.5)
        link = Arrow(start=LEFT, end=RIGHT)
        right_example = ImageMobject("assets/ex_sim.png").move_to(RIGHT * 3.5)
        # Grouping is only used to shift the three objects down together
        Group(left_example, link, right_example).shift(DOWN * 0.5)

        self.play(Transform(heading, Text("Usage").to_edge(UP)))
        self.wait()
        self.play(FadeIn(left_example), FadeIn(right_example), Create(link))
        self.wait(3)

        # Final clean-up: fade out whatever is still in the scene
        self.play(*map(FadeOut, self.mobjects))

                                                                                                                                