In [0]:
# Colab environment setup: upgrade numpy and pin pandas to 0.24.2, the
# pandas version this notebook's code was written against.
# NOTE(review): `-U numpy` is unpinned; a newer numpy may be ABI-incompatible
# with the pandas 0.24.2 wheel — consider pinning a matching numpy version.
!pip install -U numpy
!pip install pandas==0.24.2

In [0]:
# Mount Google Drive at /content/drive so the input CSV/zip files are readable.
# NOTE(review): NOTEBOOK_DIR is a relative path starting with 'drive/My Drive';
# this assumes the working directory is /content (the Colab default) — confirm.
from google.colab import drive
drive.mount('/content/drive')

In [0]:
"""Make competition statistics."""
import zipfile
import io
from collections import Counter
from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.stats import ks_2samp
from sklearn.metrics import (
    accuracy_score, f1_score, jaccard_score, precision_score,
    recall_score)
from sklearn.preprocessing import MultiLabelBinarizer

# Global plot styling: a larger base font, even larger axis labels, and
# seaborn's colorblind-friendly palette for every figure below.
FONT_SIZE = 20
matplotlib.rcParams.update({'font.size': FONT_SIZE, 'axes.labelsize': 28})
sns.set_palette('colorblind')

# Directory (relative to the working directory) holding all input files.
NOTEBOOK_DIR = (
    'drive/My Drive/Cell Profiling/Challenges/'
    'HPA challenge paper-Ouyang et al 2019/notebooks')

# Input files, all located directly inside NOTEBOOK_DIR.
TRUTH_PATH = Path(NOTEBOOK_DIR, 'kaggle_derived_solution.csv')
MASTER_TEST_PATH = Path(NOTEBOOK_DIR, 'scv_test.csv')
MASTER_TRAIN_PATH = Path(NOTEBOOK_DIR, 'kaggle_master_training.csv')
TOP_TEN_TEAMS_PATH = Path(
    NOTEBOOK_DIR, 'kaggle_competition_1-10_best_submissions.zip')
ALL_TEAMS_PATH = Path(
    NOTEBOOK_DIR, 'kaggle_competition_all_best_submissions.zip')
TRAIN_PATH = Path(NOTEBOOK_DIR, 'kaggle_train.csv')
HPA_PATH = Path(NOTEBOOK_DIR, 'hpa_v18_1_all_no_dups_merged.csv')

# Column names of the tidy data frames.
COL_CELL_TYPE, COL_ID, COL_EXPECTED, COL_PREDICTED, COL_TEAM = (
    'cell_type', 'id', 'expected', 'predicted', 'team')

# Positional team ranges (0-based, presumably leaderboard order since team
# folder names carry a rank prefix) for the 1-10, 11-100, 101-500 and
# 501-2137 team groups.
TOP_10_TEAMS = range(10)
TOP_100_TEAMS = range(10, 100)
TOP_500_TEAMS = range(100, 500)
TOP_2137_TEAMS = range(500, 2137)

# Positions of the invited teams (the keys of INVITED_TEAM_NAMES).
INVITED_TEAMS = [*range(5), 7, 9, 15, 38]

# Label id (as a string) -> full subcellular location class name.
# Keys are converted with int() where used to rename per-class score columns.
LABEL_INDEX = {
    '0': 'Nucleoplasm',
    '1': 'Nuclear membrane',
    '10': 'Lysosomes',
    '11': 'Intermediate filaments',
    '12': 'Actin filaments',
    '13': 'Focal adhesion sites',
    '14': 'Microtubules',
    '15': 'Microtubule ends',
    '16': 'Cytokinetic bridge',
    '17': 'Mitotic spindle',
    '18': 'Microtubule organizing center',
    '19': 'Centrosome',
    '2': 'Nucleoli',
    '20': 'Lipid droplets',
    '21': 'Plasma membrane',
    '22': 'Cell Junctions',
    '23': 'Mitochondria',
    '24': 'Aggresome',
    '25': 'Cytosol',
    '26': 'Cytoplasmic bodies',
    '27': 'Rods & Rings',
    '3': 'Nucleoli fibrillar center',
    '4': 'Nuclear speckles',
    '5': 'Nuclear bodies',
    '6': 'Endoplasmic reticulum',
    '7': 'Golgi apparatus',
    '8': 'Peroxisomes',
    '9': 'Endosomes'
}

# Label id (as a string) -> abbreviated location class name, used as short
# axis labels in the per-class plots (keys are converted with int() there).
ABB_LABEL_INDEX = {
    "0": "Nucleoplasm",
    "1": "N. membrane",
    "2": "Nucleoli",
    "3": "N. fibrillar c.",
    "4": "N. speckles",
    "5": "N. bodies",
    "6": "ER",
    "7": "Golgi app.",
    "8": "Peroxisomes",
    "9": "Endosomes",
    "10": "Lysosomes",
    "11": "Int. fil.",
    "12": "Actin fil.",
    "13": "F. a. sites",
    "14": "Microtubules",
    "15": "M. ends",
    "16": "Cyt. bridge",
    "17": "Mitotic spindle",
    "18": "MTOC",
    "19": "Centrosome",
    "20": "Lipid droplets",
    "21": "PM",
    "22": "C. Junctions",
    "23": "Mitochondria",
    "24": "Aggresome",
    "25": "Cytosol",
    "26": "C. bodies",
    "27": "Rods & Rings"
}

# Invited team position (0-based) -> submission folder name. The folder names
# carry the matching 1-based number as a zero-padded prefix.
INVITED_TEAM_NAMES = {
    0: '0001_bestfitting',
    1: '0002_wair',
    2: '0003_pudae',
    3: '0004_wienerschnitzelgemeinschaft',
    4: '0005_vpp',
    7: '0008_one_more_layer_of_stacking',
    9: '0010_conv_is_all_u_need',
    15: '0016_ntu_mira',
    38: '0039_random_walk',
}

# Submission folder name -> display name "Team NN", NN being the 1-based
# position padded to two digits.
INVITED_TEAM_SHORT_NAMES = {
    folder: f'Team {position + 1:02d}'
    for position, folder in INVITED_TEAM_NAMES.items()
}

COLORS = sns.color_palette()

# Invited team folder name -> index into the seaborn palette. Index 7 is not
# used by any invited team (the precision/recall scatter plot uses palette
# entry 7 for all other teams).
INVITED_TEAM_COLORS = {
    '0001_bestfitting': 0,
    '0002_wair': 1,
    '0003_pudae': 2,
    '0004_wienerschnitzelgemeinschaft': 3,
    '0005_vpp': 4,
    '0008_one_more_layer_of_stacking': 5,
    '0010_conv_is_all_u_need': 6,
    '0016_ntu_mira': 8,
    '0039_random_walk': 9,
}

# Same mapping rendered as 'rgb(r, g, b)' strings (older color format).
INVITED_TEAM_COLORS_OLD = {
    folder: f'rgb{COLORS[palette_idx]}'
    for folder, palette_idx in INVITED_TEAM_COLORS.items()
}

# Same palette indices keyed by the short "Team NN" display names.
INVITED_TEAM_SHORT_NAMES_COLORS = {
    INVITED_TEAM_SHORT_NAMES[folder]: palette_idx
    for folder, palette_idx in INVITED_TEAM_COLORS.items()
}


def get_score(truth_path, prediction_path, master_path, train_path, hpa_path):
    """Load all data sets and return a Score for the private test rows.

    Args:
        truth_path: CSV with the ground truth; needs an ``id`` column, an
            ``expected`` column of space-separated label ids and a ``usage``
            column.
        prediction_path: Zip archive with the team submissions.
        master_path: CSV with ``id``, ``atlas_name`` and
            ``annotated_cell_cycle`` columns (added as cell type and scv).
        train_path: CSV of the training set with a ``target`` column.
        hpa_path: CSV of the HPA v18 data with ``Unnamed: 0`` and
            ``location`` columns.

    Returns:
        Score built from the rows whose ``usage`` is ``'Private'``.
    """
    truth = load_data(truth_path)
    master = load_data(master_path)
    truth = (
        truth
        .pipe(tidy_cols)
        # Sort by id so truth rows line up with the sorted predictions below.
        .sort_values(by=[COL_ID])
        .reset_index(drop=True)
        .pipe(add_cell_type_scv, master)
        .assign(expected=lambda df: (df.expected
                                       .str.split(' ')
                                       .apply(pd.to_numeric)))
        .set_index(COL_ID)
    )
    teams = load_teams_data(prediction_path)
    # NOTE(review): the team frames are concatenated side by side on row
    # position and all rows are sorted by the id column of a single team, so
    # this assumes every submission lists the ids in the same order.
    teams = (
        teams
        .sort_values([(teams.columns.levels[0][0], COL_ID)])
        .loc[:, (slice(None), COL_PREDICTED)]
        .stack()
        .pipe(tidy_cols)
        .sort_index(axis=1)
        .reset_index(drop=True)
        .set_index(truth.index)
        # Empty predictions become empty label lists.
        .fillna('')
        .apply(
            lambda column: column.str.split(' ').apply(
                lambda labels: [
                    int(float(label)) for label in labels if label != '']))
    )

    train = load_data(train_path)
    train = train.pipe(parse_data, 'target')

    hpa = (load_data(hpa_path)
           .drop(columns='Unnamed: 0')
           .rename(columns={'location': 'hpa'})
           # Drop rows whose location contains '-1'; rows without a location
           # are dropped too (na=True) instead of making `~` fail on NaN.
           .loc[lambda df: ~df.hpa.str.contains('-1', regex=False, na=True)])
    hpa = hpa.pipe(parse_data, 'hpa')

    private_slice = truth.usage == 'Private'
    truth = truth.loc[private_slice]
    teams = teams.loc[private_slice]
    return Score(truth, teams, train, hpa)


def parse_data(data, data_column):
    """Return a tidied copy of ``data`` sorted by id with parsed labels.

    Column names are normalized with ``tidy_cols``, rows are sorted by the
    id column and renumbered from 0, and every entry of ``data_column`` (a
    string of space-separated label ids) is split and converted with
    ``pd.to_numeric``.
    """
    def split_labels(frame):
        return frame[data_column].str.split(' ').apply(pd.to_numeric)

    tidy = tidy_cols(data).sort_values(by=[COL_ID]).reset_index(drop=True)
    return tidy.assign(**{data_column: split_labels})


def _load_nested_team(zip_handle, zip_info, teams, exc, name):
    """Re-read one zip member as a nested zip archive belonging to ``name``.

    Used when a member could not be parsed as (optionally gzipped) CSV. The
    member is loaded into memory and read with ``_load_teams`` using
    ``team_name=name``.

    Returns:
        The DataFrame read for ``name``, or None when nothing valid was read.
        An entry already stored in ``teams`` under ``name`` is kept.
    """
    with zip_handle.open(zip_info) as member:
        nested = io.BytesIO(member.read())
    # Set aside an earlier entry for this team so a failed retry is not
    # mistaken for a successful one.
    previous = teams.pop(name, None)
    try:
        _load_teams(
            nested, teams=teams, error=exc, compression='infer',
            team_name=name)
    except zipfile.BadZipFile:
        print('Failed to read file:', name)
    data = teams.pop(name, None)
    if previous is not None:
        teams[name] = previous
    return data


def _load_teams(
        path, teams=None, error=None, compression=None, team_name=None):
    """Return a dict mapping team name to that team's submission DataFrame.

    Args:
        path: Path or binary file object of a zip archive. Unless
            ``team_name`` is given, each member's parent directory name is
            used as the team name.
        teams: Dict to add the results to; a new dict when None.
        error: Exception of an outer read attempt. When set, this call is a
            nested-zip retry and read failures are reported and skipped
            instead of retried again.
        compression: ``compression`` for the first ``pd.read_csv`` attempt.
        team_name: Team name to use for every member instead of deriving it
            from the member path.

    Returns:
        ``teams`` with one DataFrame (columns ``COL_ID``, ``COL_PREDICTED``)
        per successfully read team. Unreadable or malformed members are
        reported with ``print`` and skipped.
    """
    if teams is None:
        teams = {}
    with zipfile.ZipFile(path) as zip_handle:
        for zip_info in zip_handle.infolist():
            if zip_info.is_dir():
                continue
            if team_name is None:
                # The team name is the parent directory of the member; a
                # member at the archive root has none.
                parts = zip_info.filename.split('/')
                name = parts[-2] if len(parts) > 1 else ''
                if len(name) < 6:
                    print('Incorrect name format found for team', name)
                    continue
            else:
                name = team_name
            try:
                with zip_handle.open(zip_info) as csv_zip:
                    data = pd.read_csv(
                        csv_zip, compression=compression, header=0,
                        dtype=str)
            except UnicodeError:
                # Not plain text: give up on nested retries, otherwise
                # retry the member as gzip.
                if compression == 'infer':
                    print('Failed to read file:', name)
                    continue
                try:
                    with zip_handle.open(zip_info) as csv_zip:
                        data = pd.read_csv(
                            csv_zip, compression='gzip', header=0,
                            dtype=str)
                except OSError as exc:
                    if error:
                        print('Failed to read file:', name)
                        continue
                    # Not gzip either: try the member as a nested zip.
                    data = _load_nested_team(
                        zip_handle, zip_info, teams, exc, name)
                    if data is None:
                        continue
            except pd.errors.ParserError as exc:
                if error:
                    print('Failed to read file:', name)
                    continue
                data = _load_nested_team(
                    zip_handle, zip_info, teams, exc, name)
                if data is None:
                    continue

            data = data.pipe(tidy_cols)
            # Compare as lists: an element-wise Index comparison raises a
            # ValueError when the number of columns is not two.
            if list(data.columns) != [COL_ID, COL_PREDICTED]:
                print('Incorrect data format found for team', name)
                continue
            teams[name] = data
    return teams


def load_teams_data(path):
    """Load all team submissions from the zip at ``path`` into one frame.

    The result has two-level columns: the team name on top (in load order)
    and that team's own columns below. The number of loaded teams is printed.
    """
    submissions = _load_teams(path)
    print(len(submissions), 'teams')
    team_names = [team for team in submissions]
    frames = [frame for frame in submissions.values()]
    return pd.concat(frames, axis=1, keys=team_names)


def add_cell_type_scv(data, master):
    """Return ``data`` with ``cell_type`` and ``scv`` columns from ``master``.

    ``master`` rows are matched on their ``id`` column against
    ``data['id']``; ``atlas_name`` becomes ``cell_type`` and
    ``annotated_cell_cycle`` becomes ``scv``. Rows of ``data`` without a
    match get NaN (left join).
    """
    addition = (master
                .rename(columns={
                    'atlas_name': 'cell_type', 'annotated_cell_cycle': 'scv'})
                .set_index('id')
                .loc[:, ['cell_type', 'scv']])
    return data.join(addition, on='id')


def load_data(path):
    """Read a CSV file, plain or gzip-compressed, with every column as str.

    A plain read is tried first; if the content cannot be decoded as text
    (``UnicodeError``), the file is read again as gzip.
    """
    read_options = dict(header=0, dtype=str)
    try:
        return pd.read_csv(path, compression=None, **read_options)
    except UnicodeError:
        return pd.read_csv(path, compression='gzip', **read_options)


def tidy_cols(data):
    """Return a copy of ``data`` with normalized column names.

    Names are stripped, lower-cased, have spaces replaced by underscores and
    parentheses removed, e.g. ``' Score (F1)'`` -> ``'score_f1'``. The input
    frame is not modified.
    """
    _data = data.copy()
    # Literal (non-regex) replacements: '(' and ')' are regex metacharacters
    # and pandas' default for `regex` differs between versions.
    _data.columns = (_data.columns
                     .str.strip()
                     .str.lower()
                     .str.replace(' ', '_', regex=False)
                     .str.replace('(', '', regex=False)
                     .str.replace(')', '', regex=False))
    return _data


class Score:
    """Compute and plot competition scores of the teams.

    As built by ``get_score``:

    * ``true``: ground truth indexed by id, with ``expected`` (list of label
      ids), ``cell_type``, ``scv`` and ``usage`` columns.
    * ``teams_pred``: predictions indexed like ``true``, one column per team
      (sorted by team name), each cell a list of label ids.
    * ``train`` / ``hpa``: parsed training and HPA v18 data with list-valued
      ``target`` / ``hpa`` label columns.
    """

    # Ensemble method of each team in INVITED_TEAMS, in that order.
    _ENSEMBLE_METHODS = (
        'metric learning', 'average', 'average', 'hill climbing',
        'weighted', 'per class voting', 'per class linear', 'average',
        'average')

    def __init__(self, true, prediction, train, hpa):
        """Store ground truth, team predictions, training and HPA data."""
        self.hpa = hpa
        self.teams_pred = prediction
        self.true = true
        self.train = train

    @staticmethod
    def _set_plot_defaults(**kwargs):
        """Create a figure with a 0-1 F1 score y axis; return (fig, ax).

        ``kwargs`` are passed to ``plt.subplots``.
        """
        fig, ax = plt.subplots(**kwargs)
        ax.set_ylim(bottom=0.0, top=1.0)
        # 11 ticks so that the top of the axis (1.0) is labelled as well.
        ax.yaxis.set_ticks(np.linspace(0.0, 1.0, 11))
        ax.set_ylabel("F1 score")
        ax.set_xlabel("Location class")
        return fig, ax

    @staticmethod
    def _save_fig(path):
        """Save the current matplotlib figure to path at 300 dpi."""
        plt.savefig(path, dpi=300, bbox_inches='tight')

    @staticmethod
    def rename_labels(data):
        """Rename integer label columns to abbreviated class names."""
        data = data.rename(
            index=str,
            columns={
                int(idx): label for idx, label in ABB_LABEL_INDEX.items()})
        return data

    @classmethod
    def _ensemble_series(cls, index):
        """Return the invited teams' ensemble methods as a Series on index."""
        return pd.Series(
            list(cls._ENSEMBLE_METHODS), index=index, name='ensemble')

    def _calc(
            self, method, group=COL_TEAM, row_slice=None, team_slice=None,
            **kwargs):
        """Calculate a multi-label score per team or per finer group.

        Args:
            method: Metric called as ``method(y_true, y_pred, **kwargs)`` on
                binarized label matrices (e.g. an sklearn score function).
            group: Column(s) to group by. With ``COL_TEAM`` the result has
                one score per team; other groupings (e.g.
                ``[COL_TEAM, COL_CELL_TYPE]``) score each group separately.
            row_slice: Optional ``.loc`` row selector for truth and
                predictions.
            team_slice: Optional positional (``.iloc``) team selector.
            **kwargs: Passed through to ``method``.

        Returns:
            pd.Series of scores (arrays when ``method`` returns per-class
            values) indexed by the group key(s).
        """
        if row_slice is None:
            row_slice = slice(None)
        if team_slice is None:
            team_slice = slice(None)
        pred = self.teams_pred
        true = self.true.loc[row_slice, [COL_CELL_TYPE, COL_EXPECTED]]

        grouped = (
            pred
            .iloc[:, team_slice]
            .loc[row_slice]
            # Long format: one row per (id, team) prediction.
            .stack()
            .reset_index(1)
            .rename(columns={'level_1': COL_TEAM, 0: COL_PREDICTED})
            .join(true)
            .loc[:, [COL_TEAM, COL_CELL_TYPE, COL_EXPECTED, COL_PREDICTED]]
            .groupby(group)
        )

        # Label classes are learned from the truth only; predicted labels
        # outside them are ignored by MultiLabelBinarizer (with a warning).
        mlb = MultiLabelBinarizer()
        mlb.fit(true.expected)
        form = mlb.transform

        if group == COL_TEAM:
            # Each team's rows are compared with the full truth column in row
            # order, so the truth matrix is built once for all teams.
            expected = form(true.expected)
            return grouped[COL_PREDICTED].apply(
                lambda team_pred: method(expected, form(team_pred), **kwargs))

        true = grouped[COL_EXPECTED].apply(form)
        pred = grouped[COL_PREDICTED].apply(form)
        return (
            pd.concat([true.rename('true'), pred.rename('pred')], axis=1)
            .apply(lambda df: method(df.true, df.pred, **kwargs), axis=1)
        )

    @staticmethod
    def calc_count(data, data_type):
        """Count label occurrences in a Series of label lists.

        Returns a one-column DataFrame (column ``data_type``) indexed by
        label and sorted by count in descending order.
        """
        count = data.apply(Counter).sum()
        count = (pd.DataFrame(dict(count), index=[data_type])
                 .T
                 .sort_values(data_type, ascending=False))
        return count

    def calc_f1_all(self, group=COL_TEAM, team_slice=None):
        """Calculate overall macro F1 for all teams (column 'all')."""
        result = self._calc(
            f1_score, group=group, team_slice=team_slice, average='macro')
        return pd.DataFrame(result).rename(columns={COL_PREDICTED: 'all'})

    def calc_accuracy_all(self, group=COL_TEAM, team_slice=None):
        """Calculate overall (subset) accuracy for all teams (column 'all')."""
        result = self._calc(
            accuracy_score, group=group, team_slice=team_slice)
        return pd.DataFrame(result).rename(columns={COL_PREDICTED: 'all'})

    def calc_precision_all(self, group=COL_TEAM, team_slice=None):
        """Calculate overall macro precision for all teams (column 'all')."""
        result = self._calc(
            precision_score, group=group, team_slice=team_slice,
            average='macro')
        return pd.DataFrame(result).rename(columns={COL_PREDICTED: 'all'})

    def calc_recall_all(self, group=COL_TEAM, team_slice=None):
        """Calculate overall macro recall for all teams (column 'all')."""
        result = self._calc(
            recall_score, group=group, team_slice=team_slice,
            average='macro')
        return pd.DataFrame(result).rename(columns={COL_PREDICTED: 'all'})

    def calc_f1_acc_prec_rec_all(self, team_slice=None):
        """Calculate f1, accuracy, precision and recall score for all teams."""
        f1 = self.calc_f1_all(team_slice=team_slice)
        acc = self.calc_accuracy_all(team_slice=team_slice)
        prec = self.calc_precision_all(team_slice=team_slice)
        rec = self.calc_recall_all(team_slice=team_slice)
        result = (f1
                  .rename(columns={'all': 'f1'})
                  .assign(accuracy=acc.loc[:, 'all'])
                  .assign(precision=prec.loc[:, 'all'])
                  .assign(recall=rec.loc[:, 'all'])
                 )
        return result

    def calc_f1_per_class(self, team_slice=None):
        """Calculate per class F1 for all teams (teams x classes)."""
        result = self._calc(
            f1_score, team_slice=team_slice, average=None)
        return pd.DataFrame(dict(result)).T

    def calc_jaccard_per_class(self, team_slice=None):
        """Calculate per class jaccard score for all teams."""
        result = self._calc(jaccard_score, team_slice=team_slice, average=None)
        return pd.DataFrame(dict(result)).T

    def calc_precision_per_class(self, team_slice=None):
        """Calculate per class precision for all teams."""
        result = self._calc(
            precision_score, team_slice=team_slice, average=None)
        return pd.DataFrame(dict(result)).T

    def calc_recall_per_class(self, team_slice=None):
        """Calculate per class recall for all teams."""
        result = self._calc(
            recall_score, team_slice=team_slice, average=None)
        return pd.DataFrame(dict(result)).T

    def calc_f1_multi(self, team_slice=None):
        """Calculate macro F1 on samples with more than one label."""
        multi_slice = self.true.expected.apply(lambda x: len(x) > 1)
        result = self._calc(
            f1_score, row_slice=multi_slice, team_slice=team_slice,
            average='macro')
        return pd.DataFrame(result).rename(columns={COL_PREDICTED: 'multi'})

    def calc_f1_single(self, team_slice=None):
        """Calculate macro F1 on samples with at most one label."""
        single_slice = self.true.expected.apply(lambda x: len(x) < 2)
        result = self._calc(
            f1_score, row_slice=single_slice, team_slice=team_slice,
            average='macro')
        return pd.DataFrame(result).rename(columns={COL_PREDICTED: 'single'})

    def calc_f1_scv(self, team_slice=None):
        """Calculate macro F1 on samples whose scv flag is '1'."""
        scv_slice = self.true.scv == '1'
        result = self._calc(
            f1_score, row_slice=scv_slice, team_slice=team_slice,
            average='macro')
        return pd.DataFrame(result).rename(columns={COL_PREDICTED: 'scv'})

    def calc_all_local(self):
        """Calculate all/single/multi macro F1 for each team group.

        Returns a frame with top-level column keys per team group
        (``top_1-10`` ... ``top_501-2137``), each holding the 'all',
        'single' and 'multi' columns.
        """
        data = {}
        f1_all = self.calc_f1_all()
        f1_single = self.calc_f1_single()
        f1_multi = self.calc_f1_multi()
        for name, team_slice in (
                ('top_1-10', TOP_10_TEAMS), ('top_11-100', TOP_100_TEAMS),
                ('top_101-500', TOP_500_TEAMS),
                ('top_501-2137', TOP_2137_TEAMS)):
            data[name] = pd.concat([
                f1_all.iloc[team_slice],
                f1_single.iloc[team_slice],
                f1_multi.iloc[team_slice]], axis=1)
        return pd.concat(
            [d for d in data.values()],
            keys=[name for name in data], axis=1, sort=False)

    def plot_fig2_f1_per_class(self, save_path=None):
        """Return a plot for figure 2 of f1 per class for top 10 teams."""
        return self.plot_violin_f1_per_class(
            team_slice=TOP_10_TEAMS, save_path=save_path)

    def plot_violin_f1_per_class(
            self, team_slice=TOP_10_TEAMS, save_path=None):
        """Plot per-class F1 violins, classes ordered by HPA + train count.

        When ``save_path`` is given, the describe() stats are written to
        ``fig2c_f1_per_class.csv`` and the figure to ``.svg`` there.
        """
        a3_dims = (16.5, 11.7)
        fig, ax = self._set_plot_defaults(figsize=a3_dims)
        data = self.calc_f1_per_class(team_slice=team_slice)
        # Calculate sort order by label count of hpa + training sets.
        hpa_set = self.calc_count(self.hpa.hpa, 'HPAv18')
        train = self.calc_count(self.train.target, 'training')
        hpa_train = [hpa_set, train]
        hpa_train = (pd
                     .concat(hpa_train, axis=1, sort=False)
                     .assign(**{'hpa_train': lambda df: (df
                                         .loc[:, ['HPAv18', 'training']]
                                         .sum(axis=1))
                               })
                     .sort_values(by='hpa_train', ascending=False)
                    )
        hpa_train = hpa_train.pop('hpa_train')
        data = data.reindex(hpa_train.index, axis=1).pipe(self.rename_labels)
        stats = data.describe()
        print(stats)
        self.print_fig_stats(stats)
        violin = sns.violinplot(
            ax=ax, data=data, orient='v', scale='area', width=1.2,
            color=sns.color_palette()[2])
        ax.xaxis.grid(True)
        violin.set_xticklabels(
            labels=violin.get_xticklabels(), rotation=70, ha='right')
        if save_path:
            fig_name = 'fig2c_f1_per_class'
            stats.to_csv(Path(save_path) / f'{fig_name}.csv')
            self._save_fig(Path(save_path) / f'{fig_name}.svg')
        return violin

    def plot_sup_f1_per_cell_type(self, save_path=None):
        """Return a violin plot for each cell type for top 10 teams."""
        return self.plot_violin_f1_per_cell_type(
            team_slice=TOP_10_TEAMS, save_path=save_path)

    def plot_violin_f1_per_cell_type(
            self, team_slice=TOP_10_TEAMS, save_path=None):
        """Plot macro F1 violins per cell type, ordered by median score."""
        data = (self.calc_f1_all(
            group=[COL_TEAM, COL_CELL_TYPE], team_slice=team_slice)
                .unstack().stack(0))
        order = data.median().sort_values(ascending=False).index
        a3_dims = (16.5, 11.7)
        fig, ax = self._set_plot_defaults(figsize=a3_dims)
        stats = data.describe()
        print(stats)
        self.print_fig_stats(stats)
        violin = sns.violinplot(
            ax=ax, data=data, order=order, orient='v', scale='area', width=1.2,
            color=sns.color_palette()[2])
        ax.set_xlabel("Cell type")
        ax.xaxis.grid(True)
        violin.set_xticklabels(
            labels=violin.get_xticklabels(), rotation=70, ha='right')
        if save_path:
            fig_name = 'sup_f1_per_cell_type'
            stats.to_csv(Path(save_path) / f'{fig_name}.csv')
            self._save_fig(Path(save_path) / f'{fig_name}.png')
            self._save_fig(Path(save_path) / f'{fig_name}.jpeg')
            self._save_fig(Path(save_path) / f'{fig_name}.svg')
        return violin

    def plot_f1_all_single_multi(
            self, plot_type='violin', team_slice=TOP_10_TEAMS):
        """Plot violin (or box) plots for all, single and multi samples."""
        a3_dims = (16.5, 11.7)
        fig, ax = self._set_plot_defaults(figsize=a3_dims)
        ax.set_ylabel("Macro F1 score")
        f1_all = self.calc_f1_all(team_slice=team_slice)
        f1_single = self.calc_f1_single(team_slice=team_slice)
        f1_multi = self.calc_f1_multi(team_slice=team_slice)
        data = pd.concat([f1_all, f1_single, f1_multi], axis=1)
        stats = data.describe()
        print(stats)
        self.print_fig_stats(stats)
        if plot_type == 'box':
            plot = sns.boxplot(ax=ax, data=data, orient='v')
        else:
            plot = sns.violinplot(ax=ax, data=data, orient='v')
        return plot

    def plot_f1_all_scv(
            self, plot_type='violin', team_slice=TOP_10_TEAMS, save_path=None):
        """Plot violin (or box) plots for all vs scv samples."""
        a3_dims = (16.5, 16.5)
        fig, ax = self._set_plot_defaults(figsize=a3_dims)
        ax.set_xlabel("Image group")
        f1_all = self.calc_f1_all(team_slice=team_slice)
        f1_scv = self.calc_f1_scv(team_slice=team_slice)
        data = pd.concat([f1_all, f1_scv], axis=1)
        stats = data.describe()
        print(stats)
        self.print_fig_stats(stats)
        if plot_type == 'box':
            plot = sns.boxplot(ax=ax, data=data, orient='v')
        else:
            plot = sns.violinplot(ax=ax, data=data, orient='v')
        if save_path:
            fig_name = 'review_f1_all_vs_scv'
            stats.to_csv(Path(save_path) / f'{fig_name}.csv')
            self._save_fig(Path(save_path) / f'{fig_name}.png')
        return plot

    def plot_bar_rare_classes(self, team_slice=TOP_10_TEAMS):
        """Plot per-team F1 bars of two rare classes."""
        fig, ax = self._set_plot_defaults()
        data = self.calc_f1_per_class(team_slice=team_slice)
        data = (data
                .rename(
                    index=str,
                    columns={int(idx): label
                             for idx, label in LABEL_INDEX.items()})
                .loc[:, ['Microtubule ends', 'Rods & Rings']])
        return data.plot(ax=ax, kind='bar')

    def plot_fig2_f1_team_distribution(self, save_path=None):
        """Return a violin plot of f1 team distribution for figure 2."""
        return self.plot_all_f1_all_single_multi(
            plot_type='violin', save_path=save_path)

    def calc_long_f1_all_single_multi(self, save_path=None):
        """Calculate long-form F1 scores for all/single/multi locals.

        Returns columns team, local, Teams (group like '1-10') and score.
        """
        data = self.calc_all_local()
        stats = data.describe()
        print(stats)
        self.print_fig_stats(stats)
        if save_path:
            fig_name = 'fig2d_f1_all_single_multi'
            stats.to_csv(Path(save_path) / f'{fig_name}.csv')
        data = (
            data
            .stack()
            .reset_index()
            .rename(index=str, columns={'level_0': 'team', 'level_1': 'local'})
            .pipe(
                # Raw string: '\d' is an invalid escape in a normal literal.
                pd.wide_to_long, ['top_'], i=['team', 'local'], j='Teams',
                suffix=r'\d+-\d+')
            .stack()
            .reset_index()
            .rename(columns={0: 'score'})
            .drop(columns='level_3'))

        return data

    def ks_test_single_multi(self, data=None):
        """Print KS tests of single vs multi F1 for the 4 team groups."""
        if data is None:
            data = self.calc_long_f1_all_single_multi()

        for r in ['1-10', '11-100', '101-500', '501-2137']:
            single_score = data[
                (data['local'] == 'single') & (data['Teams'] == r)]
            multi_score = data[
                (data['local'] == 'multi') & (data['Teams'] == r)]
            m = multi_score['score'].to_numpy()
            s = single_score['score'].to_numpy()
            print('Top-'+str(r), ks_2samp(m, s))

    def plot_all_f1_all_single_multi(self, plot_type='violin', save_path=None):
        """Plot all team groups for all localization groups."""
        data = self.calc_long_f1_all_single_multi(save_path=save_path)

        self.ks_test_single_multi(data=data)

        a3_dims = (16.5, 16.5)
        fig, ax = self._set_plot_defaults(figsize=a3_dims)
        if plot_type == 'box':
            plot = sns.boxplot(
                ax=ax, x='Teams', hue='local', y='score', data=data)
        else:
            plot = sns.violinplot(
                ax=ax, x='Teams', hue='local', y='score', data=data)
        ax.set_ylabel("Macro F1 score")
        if save_path:
            fig_name = 'fig2d_f1_all_single_multi'
            self._save_fig(Path(save_path) / f'{fig_name}.svg')
        return plot

    def plot_sup_table_1(self, save_path=None):
        """Return the table for table 1 in supplemental part."""
        return self.plot_f1_table(save_path=save_path)

    def calc_stats_table(self):
        """Calculate a dataframe with performance scores for selected teams.

        Per-class f1, jaccard, precision and recall of the invited teams,
        rounded down to 3 decimals; rows are classes, columns are
        (value, team, score).
        """
        f1_per_class = self.calc_f1_per_class(
            team_slice=INVITED_TEAMS).pipe(round_down, 3).assign(score='f1')
        jacc_per_class = (self
                          .calc_jaccard_per_class(team_slice=INVITED_TEAMS)
                          .pipe(round_down, 3)
                          .assign(score='jaccard'))
        prec_per_class = (self
                          .calc_precision_per_class(team_slice=INVITED_TEAMS)
                          .pipe(round_down, 3)
                          .assign(score='precision'))
        rec_per_class = (self
                         .calc_recall_per_class(team_slice=INVITED_TEAMS)
                         .pipe(round_down, 3)
                         .assign(score='recall'))
        df_table = pd.concat([
            f1_per_class, jacc_per_class, prec_per_class, rec_per_class])
        df_table = (df_table
                    .reset_index()
                    .rename(columns={'index': 'team'})
                    .melt(id_vars=['team', 'score'], var_name='class')
                    .pipe(
                        pd.pivot_table, index='class',
                        columns=['team', 'score'])
                   )

        return df_table

    def calc_f1_table(self):
        """Calculate a dataframe with F1 scores for selected teams."""
        f1_per_class = self.calc_f1_per_class(
            team_slice=INVITED_TEAMS).pipe(round_down, 3)
        f1_single = self.calc_f1_single(
            team_slice=INVITED_TEAMS).pipe(round_down, 3)
        f1_multi = self.calc_f1_multi(
            team_slice=INVITED_TEAMS).pipe(round_down, 3)
        f1_all = self.calc_f1_all(team_slice=INVITED_TEAMS).pipe(round_down, 3)
        models = self._ensemble_series(f1_all.index)
        df_table = pd.concat(
            [f1_per_class, f1_single, f1_multi, f1_all, models], axis=1).T
        return df_table

    def calc_jaccard_table(self):
        """Calculate a dataframe with jaccard scores for selected teams."""
        jacc_per_class = self.calc_jaccard_per_class(
            team_slice=INVITED_TEAMS).pipe(round_down, 3)
        models = self._ensemble_series(jacc_per_class.index)
        df_table = pd.concat([jacc_per_class, models], axis=1).T
        return df_table

    def calc_precision_table(self):
        """Calculate a dataframe with precision scores for selected teams."""
        prec_per_class = self.calc_precision_per_class(
            team_slice=INVITED_TEAMS).pipe(round_down, 3)
        models = self._ensemble_series(prec_per_class.index)
        df_table = pd.concat([prec_per_class, models], axis=1).T
        return df_table

    def calc_recall_table(self):
        """Calculate a dataframe with recall scores for selected teams."""
        rec_per_class = self.calc_recall_per_class(
            team_slice=INVITED_TEAMS).pipe(round_down, 3)
        models = self._ensemble_series(rec_per_class.index)
        df_table = pd.concat([rec_per_class, models], axis=1).T
        return df_table

    def plot_f1_table(self, save_path=None):
        """Plot a table with F1 scores for selected teams."""
        df_table = self.calc_f1_table()
        a3_dims = (16.5, 11.7)
        fig, ax = plt.subplots(figsize=a3_dims)
        ax.axis('off')

        table = ax.table(
            cellText=df_table.values, rowLabels=df_table.index,
            colLabels=df_table.columns, cellLoc='center', rowLoc='center',
            loc='center')
        table.auto_set_font_size(False)
        table.set_fontsize(5)
        table.scale(1, 1.41)
        if save_path:
            self._save_fig(Path(save_path) / 'sup_f1_table.pdf')
        return table

    def plot_prec_rec_scatter(
            self, legend='full', team_slice=None, save_path=None):
        """Return a scatter plot of macro precision vs recall (figure 2b).

        Invited teams get their own colors, all other teams are grouped as
        'Other', and a single 'Experts' point is added as a star.
        """
        if team_slice is None:
            team_slice = range(2137)
        prec = self.calc_precision_all(team_slice=team_slice)
        rec = self.calc_recall_all(team_slice=team_slice)

        teams = []
        for idx in team_slice:
            team = {}
            try:
                invited = INVITED_TEAM_NAMES[idx]
            except KeyError:
                team['size'] = 1
                team['name'] = 'Other'
                team['style'] = 'spot'
            else:
                team['size'] = 2
                team['name'] = INVITED_TEAM_SHORT_NAMES[invited]
                team['style'] = 'spot'
            teams.append(team)
        teams = pd.DataFrame(teams)

        experts = {
            "Precision": [0.74],
            "Recall": [0.6909524],
            "size": 4,
            "team": "Experts",
            'style': 'star',
        }
        experts = pd.DataFrame(experts)

        data = (pd.DataFrame()
                .assign(Precision=prec.loc[:, 'all'])
                .reset_index()
                .assign(Recall=rec.reset_index().loc[:, 'all'])
                .rename(columns={'team': 'full_team'})
                .assign(team=teams.loc[:, 'name'])
                .assign(size=teams.loc[:, 'size'])
                .assign(style=teams.loc[:, 'style'])
                # DataFrame.append was removed in pandas 2.0; concat is the
                # equivalent call and also works on older pandas.
                .pipe(lambda df: pd.concat([df, experts], sort=True))
                .sort_values(by=['team'])
               )
        sizes = (40, 840)
        markers = {'spot': 'o', 'star': '*'}
        current_palette = sns.color_palette()
        palette = {
            name: current_palette[idx]
            for name, idx in INVITED_TEAM_SHORT_NAMES_COLORS.items()
        }
        palette['Experts'] = 'black'
        palette['Other'] = current_palette[7]
        a3_dims = (16.5, 16.5)
        fig, ax = self._set_plot_defaults(figsize=a3_dims)
        ax.set_xlim(0, 1.0)
        ax.xaxis.set_ticks(np.linspace(0.0, 1.0, 11))
        # Positional flag: the keyword was renamed from `b` to `visible`.
        ax.grid(True)
        plot = sns.scatterplot(
            ax=ax, x="Precision", y="Recall", data=data, size='size',
            sizes=sizes, hue='team', legend=legend, style='style',
            markers=markers, palette=palette)
        if save_path:
            self._save_fig(Path(
                save_path) / f'fig2b_precision_recall_legend_{legend}.svg')
        return plot

    @staticmethod
    def print_fig_stats(data):
        """Print one line summarizing pd.DataFrame.describe() data.

        Each column becomes "name (mean: .., min: .., 25th P: .., 50th P:
        .., 75th P: .., max: ..), " with values rounded to 2 decimals.
        """
        data_d = data.round(2).to_dict()
        text = ''
        for class_, stats in data_d.items():
            class_text = (
                f"{class_} (mean: {stats['mean']}, min: {stats['min']}, "
                f"25th P: {stats['25%']}, 50th P: {stats['50%']}, "
                f"75th P: {stats['75%']}, max: {stats['max']}), ")
            text += class_text

        print(text)

    
def round_down(n, decimals=0):
    """Round ``n`` down (towards -inf) to ``decimals`` decimal places.

    Works on scalars, numpy arrays and pandas objects. The scaled value is
    first rounded to 9 decimals so representation error (0.29 * 100 ==
    28.999999999999996) does not push an exact value below its floor.
    """
    multiplier = 10 ** decimals
    return np.floor(np.round(n * multiplier, 9)) / multiplier

In [0]:
# Run this when you want to change score code without reloading data.
#score = Score(score.true, score.teams_pred, score.train, score.hpa)

In [0]:
score.plot_sup_f1_per_cell_type(save_path=Path(NOTEBOOK_DIR) / 'figures')

In [0]:
score.plot_prec_rec_scatter(
    legend=False, team_slice=None, save_path=Path(NOTEBOOK_DIR) / 'figures')

In [0]:
score.plot_prec_rec_scatter(
    legend='full', team_slice=None, save_path=Path(NOTEBOOK_DIR) / 'figures')

In [0]:
score.plot_fig2_f1_per_class(save_path=Path(NOTEBOOK_DIR) / 'figures')

In [0]:
score.plot_fig2_f1_team_distribution(save_path=Path(NOTEBOOK_DIR) / 'figures')

In [0]:
score.plot_f1_all_scv(team_slice=None, save_path=Path(NOTEBOOK_DIR) / 'figures')

In [0]:
# Run this before plotting score plots. It's going to take a while.
#score = get_score(TRUTH_PATH, ALL_TEAMS_PATH, MASTER_TEST_PATH, TRAIN_PATH, HPA_PATH)