In [0]:
# Colab environment setup: upgrade numpy, then install pandas pinned to 0.24.2.
# NOTE(review): numpy is upgraded without a version pin while pandas is pinned;
# a later numpy release may not be binary-compatible with this pandas build.
# Consider pinning numpy too, and using `%pip` so the install targets the
# running kernel's environment.
!pip install -U numpy
!pip install pandas==0.24.2

In [0]:
# Mount Google Drive at /content/drive so the input CSVs can be read.
# NOTE(review): NOTEBOOK_DIR is the relative path 'drive/My Drive/...', which
# presumably resolves only while the working directory is /content — confirm.
from google.colab import drive
drive.mount('/content/drive')

In [0]:
from collections import Counter
from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Global plot styling for every figure in this notebook: base font size,
# a larger axis-label font, and seaborn's colour-blind-friendly palette.
FONT_SIZE = 20
matplotlib.rc('font', size=FONT_SIZE)
matplotlib.rc('axes', labelsize=28)  # axis labels larger than FONT_SIZE
sns.set_palette('colorblind')

# Drive folder (relative path) that holds every input CSV.
NOTEBOOK_DIR = ('drive/My Drive/Cell Profiling/Challenges/'
                'HPA challenge paper-Ouyang et al 2019/notebooks')

# Input files, all located directly inside NOTEBOOK_DIR.
TRUTH_PATH = Path(NOTEBOOK_DIR, 'kaggle_derived_solution.csv')
MASTER_TEST_PATH = Path(NOTEBOOK_DIR, 'kaggle_master_test.csv')
MASTER_TRAIN_PATH = Path(NOTEBOOK_DIR, 'kaggle_master_training.csv')
TRAIN_PATH = Path(NOTEBOOK_DIR, 'kaggle_train.csv')
HPA_PATH = Path(NOTEBOOK_DIR, 'hpa_v18_1_all_no_dups_merged.csv')

# Column names used in the tidied tables.
COL_CELL_TYPE, COL_ID = 'cell_type', 'id'
COL_EXPECTED, COL_PREDICTED = 'expected', 'predicted'
COL_TEAM = 'team'

# Abbreviated display names for the 28 location classes, keyed by the class
# id as a string ("0" .. "27"). The list position is the class id.
ABB_LABEL_INDEX = {
    str(class_id): short_name
    for class_id, short_name in enumerate([
        "Nucleoplasm",      # 0
        "N. membrane",      # 1
        "Nucleoli",         # 2
        "N. fibrillar c.",  # 3
        "N. speckles",      # 4
        "N. bodies",        # 5
        "ER",               # 6
        "Golgi app.",       # 7
        "Peroxisomes",      # 8
        "Endosomes",        # 9
        "Lysosomes",        # 10
        "Int. fil.",        # 11
        "Actin fil.",       # 12
        "F. a. sites",      # 13
        "Microtubules",     # 14
        "M. ends",          # 15
        "Cyt. bridge",      # 16
        "Mitotic spindle",  # 17
        "MTOC",             # 18
        "Centrosome",       # 19
        "Lipid droplets",   # 20
        "PM",               # 21
        "C. Junctions",     # 22
        "Mitochondria",     # 23
        "Aggresome",        # 24
        "Cytosol",          # 25
        "C. bodies",        # 26
        "Rods & Rings",     # 27
    ])
}


def load_data(path):
    """Read a CSV file into a DataFrame, keeping every column as ``str``.

    The file is first parsed uncompressed. If it cannot be decoded as text
    (a ``UnicodeError``), it is parsed again as gzip-compressed data,
    presumably because some inputs are gzipped despite their file name.
    """
    read_options = {'header': 0, 'dtype': str}
    try:
        return pd.read_csv(path, compression=None, **read_options)
    except UnicodeError:
        return pd.read_csv(path, compression='gzip', **read_options)


def tidy_cols(data):
    """Return a copy of ``data`` with normalised column names.

    Each name is stripped, lower-cased, has spaces replaced by underscores,
    and has parentheses removed, e.g. ``' Atlas Name (x) '`` becomes
    ``'atlas_name_x'``. ``data`` itself is not modified.
    """
    _data = data.copy()
    # regex=False: treat ' ', '(' and ')' as literal text. pandas < 2.0
    # defaulted to regex=True and only special-cased single characters
    # (with a FutureWarning in 1.x); being explicit is version-independent.
    _data.columns = (_data.columns
                     .str.strip()
                     .str.lower()
                     .str.replace(' ', '_', regex=False)
                     .str.replace('(', '', regex=False)
                     .str.replace(')', '', regex=False))
    return _data


class Count:
    """Compute and plot label counts for the HPAv18 and Kaggle data sets.

    An instance holds five DataFrames:

    * ``hpa`` -- HPAv18 images; the ``hpa`` column holds numeric label
      arrays per image.
    * ``kaggle_test`` -- Kaggle test set; ``expected`` label arrays and a
      ``usage`` column compared against 'Private', 'Public' and 'Ignored'.
    * ``kaggle_train`` -- Kaggle training set; ``target`` label arrays.
    * ``master_test`` / ``master_train`` -- master tables whose
      ``atlas_name`` column holds comma-separated cell types.

    ``calc_*`` methods return count tables; ``plot_*`` methods draw them.
    """

    def __init__(
            self, hpa, kaggle_test, kaggle_train, master_test, master_train):
        """Store the already-parsed data sets (see ``create``)."""
        self.hpa = hpa
        self.kaggle_test = kaggle_test
        self.kaggle_train = kaggle_train
        self.master_test = master_test
        self.master_train = master_train

    @classmethod
    def create(
            cls, hpa_path, test_path, train_path, master_test_path,
            master_train_path):
        """Load every CSV and return a ready-to-use instance.

        The HPAv18 table loses its ``Unnamed: 0`` column, has ``location``
        renamed to ``hpa``, and drops rows whose location contains ``-1``.
        The label columns (``hpa``, ``expected``, ``target``) are parsed with
        ``parse_data``. The master tables only get tidied column names.
        """
        hpa = (load_data(hpa_path)
               .drop(columns='Unnamed: 0')
               .rename(columns={'location': 'hpa'})
               .loc[lambda df: ~df.hpa.str.contains('-1')])
        hpa = cls.parse_data(hpa, 'hpa')
        test = cls.parse_data(load_data(test_path), 'expected')
        train = cls.parse_data(load_data(train_path), 'target')
        master_test = load_data(master_test_path).pipe(tidy_cols)
        master_train = load_data(master_train_path).pipe(tidy_cols)
        return cls(hpa, test, train, master_test, master_train)

    @staticmethod
    def parse_data(data, data_column):
        """Return tidy data sorted by id, with ``data_column`` parsed.

        Each value of ``data_column`` is split on single spaces and converted
        with ``pd.to_numeric``, e.g. ``'0 25'`` becomes ``array([0, 25])``.
        """
        return (data
                .pipe(tidy_cols)
                .sort_values(by=[COL_ID])
                .reset_index(drop=True)
                .assign(**{data_column: lambda df: (df
                                                     .loc[:, data_column]
                                                     .str.split(' ')
                                                     .apply(pd.to_numeric))}))

    @staticmethod
    def _set_plot_defaults(figsize=(16.5, 11.7), **kwargs):
        """Create and return ``(fig, ax)`` with the given figure size."""
        fig, ax = plt.subplots(figsize=figsize, **kwargs)
        return fig, ax

    @staticmethod
    def _save_fig(path):
        """Save the current figure to ``path`` at 300 dpi, tightly cropped."""
        plt.savefig(path, dpi=300, bbox_inches='tight')

    @staticmethod
    def _rows(row_slice):
        """Return ``row_slice``, or a slice selecting every row if None."""
        return slice(None) if row_slice is None else row_slice

    @staticmethod
    def calc_count(data, data_type):
        """Count how often each label occurs across a Series of label lists.

        Parameters
        ----------
        data : pandas.Series
            Each element is an iterable of labels (list or array).
        data_type : str
            Name of the single output column.

        Returns
        -------
        pandas.DataFrame
            Indexed by label, with one ``data_type`` column of counts sorted
            in descending order. If ``data`` is empty, the frame is empty but
            still has the ``data_type`` column.
        """
        # A running Counter also works for empty input; Series.sum() of
        # Counters returns 0 there, which cannot be turned into a dict.
        totals = Counter()
        for labels in data:
            totals.update(labels)
        count = pd.DataFrame(
            {data_type: list(totals.values())}, index=list(totals.keys()))
        return count.sort_values(data_type, ascending=False)

    @staticmethod
    def calc_label_per_image(data, data_type):
        """Count images by how many labels each one carries.

        Returns a DataFrame indexed by the number of labels, with one
        ``data_type`` column holding how many images have that many labels.
        """
        count = data.apply(len).value_counts()
        return pd.DataFrame(dict(count), index=[data_type]).T

    @staticmethod
    def rename_labels(data):
        """Replace integer class ids in the index with abbreviated names.

        Uses ABB_LABEL_INDEX. Index values that are not class ids (e.g.
        already-renamed labels) are left unchanged.
        """
        mapping = {int(idx): label for idx, label in ABB_LABEL_INDEX.items()}
        return data.rename(index=mapping)

    def calc_test(self, row_slice=None, data_type='test'):
        """Return renamed label counts for (a slice of) the Kaggle test set."""
        test = self.kaggle_test.loc[self._rows(row_slice)]
        return self.rename_labels(self.calc_count(test.expected, data_type))

    def calc_train(self, row_slice=None, method='calc_count', rename=True):
        """Return counts for (a slice of) the Kaggle training set.

        ``method`` names the counter to apply (``'calc_count'`` or
        ``'calc_label_per_image'``). ``rename`` maps class ids to names.
        """
        train = self.kaggle_train.loc[self._rows(row_slice)]
        count = getattr(self, method)(train.target, 'training')
        return self.rename_labels(count) if rename else count

    def calc_test_cell_type(self, row_slice=None):
        """Return cell type counts for the Kaggle test set."""
        test = self.master_test.loc[self._rows(row_slice)]
        return self.calc_count(
            test.atlas_name.str.split(','), 'test cell type')

    def calc_train_cell_type(self, row_slice=None):
        """Return cell type counts for the Kaggle training set."""
        train = self.master_train.loc[self._rows(row_slice)]
        return self.calc_count(
            train.atlas_name.str.split(','), 'train cell type')

    def _calc_usage(self, usage, data_type):
        """Return renamed label counts for test rows with this ``usage``."""
        # calc_test already renames; no second rename_labels is needed.
        return self.calc_test(
            row_slice=self.kaggle_test.usage == usage, data_type=data_type)

    def calc_private_lb(self):
        """Return label counts for the Kaggle private leaderboard."""
        return self._calc_usage('Private', 'test_private')

    def calc_public_lb(self):
        """Return label counts for the Kaggle public leaderboard."""
        return self._calc_usage('Public', 'validation_public')

    def calc_ignored_lb(self):
        """Return label counts for the Kaggle ignored samples."""
        return self._calc_usage('Ignored', 'ignored')

    def calc_hpa(self, row_slice=None, method='calc_count', rename=True):
        """Return counts for (a slice of) the HPAv18 set.

        ``method`` and ``rename`` work as in ``calc_train``.
        """
        hpa = self.hpa.loc[self._rows(row_slice)]
        count = getattr(self, method)(hpa.hpa, 'HPAv18')
        return self.rename_labels(count) if rename else count

    def calc_hpa_image_per_ab(self):
        """Return how many HPAv18 images there are per antibody id.

        The antibody id is the part of ``id`` before the first underscore.
        """
        return self.hpa.id.str.split('_').str[0].value_counts()

    def plot_fig2a_class_distribution(self, save_path=None):
        """Return the vertical, linear-scale bar plot for figure 2a."""
        return self.plot_kaggle(
            orient='v', scale='linear', hpa=True, save_path=save_path)

    def _kaggle_sets(self, hpa):
        """Return (per-class count table, HPAv18+training totals or None).

        With ``hpa`` the table includes the HPAv18 column and is sorted by
        the combined HPAv18 + training count, in descending order.
        """
        private = self.calc_private_lb()
        public = self.calc_public_lb()
        train = self.calc_train()
        if not hpa:
            return pd.concat([train, public, private], axis=1, sort=False), None
        hpa_set = self.calc_hpa()
        all_sets = (pd
                    .concat([hpa_set, train, public, private],
                            axis=1, sort=False)
                    .assign(hpa_train=lambda df: (df
                                                  .loc[:, ['HPAv18', 'training']]
                                                  .sum(axis=1)))
                    .sort_values(by='hpa_train', ascending=False))
        hpa_train = all_sets.pop('hpa_train')
        return all_sets, hpa_train

    def plot_kaggle(
            self, orient='h', scale='linear', hpa=True, save_path=None):
        """Draw stacked per-class bars for the Kaggle (and HPAv18) sets.

        Parameters
        ----------
        orient : str
            'h' draws horizontal bars (classes on the y-axis). Anything
            else draws vertical bars with rotated class labels.
        scale : str
            'log' puts the count axis on a log scale.
        hpa : bool
            Include the HPAv18 set plus black '+' markers showing the
            combined HPAv18 + training count of each class.
        save_path : str or Path, optional
            Directory to write ``fig2a_class_distribution_all.svg`` into.

        Returns
        -------
        matplotlib.axes.Axes
            The axes holding the bar plot.
        """
        all_sets, hpa_train = self._kaggle_sets(hpa)
        a3_dims = (16.5, 11.7)
        _, ax = self._set_plot_defaults(figsize=a3_dims)
        palette = sns.color_palette(n_colors=len(all_sets.columns))
        marker_style = dict(
            marker='+', linestyle='None', color='black',
            label='HPAv18 + Training sets')
        if orient == 'h':
            # Horizontal bars: counts run along x, classes along y.
            ax.set_xlabel("Images")
            ax.set_ylabel("Location class")
            if hpa:
                # Plot markers as (count, row position) so they line up with
                # barh, which places row i at y position i.
                ax.plot(hpa_train.values, range(len(hpa_train)),
                        **marker_style)
            if scale == 'log':
                ax.set_xscale('log')
            bar = all_sets.plot.barh(ax=ax, stacked=True, color=palette)
            if hpa:
                # Rebuild the legend from every labelled artist so the marker
                # line (drawn with plain matplotlib) is listed with the bars.
                ax.legend()
        else:
            ax.set_ylabel("Images")
            ax.set_xlabel("Location class")
            if hpa:
                hpa_train.plot(ax=ax, legend=True, **marker_style)
            if scale == 'log':
                ax.set_yscale('log')
            bar = all_sets.plot.bar(ax=ax, stacked=True, color=palette)
            bar.set_xticklabels(
                labels=bar.get_xticklabels(), rotation=70, ha='right')
        if save_path:
            self._save_fig(
                Path(save_path) / 'fig2a_class_distribution_all.svg')
        return bar

    def plot_hpa(self, orient='h', scale='log', save_path=None):
        """Draw per-class bars for the HPAv18 set alone.

        Writes ``class_distribution_hpa.svg`` into ``save_path`` if given.
        """
        hpa_set = self.calc_hpa()
        a3_dims = (16.5, 11.7)
        _, ax = self._set_plot_defaults(figsize=a3_dims)
        if orient == 'h':
            if scale == 'log':
                ax.set_xscale('log')
            bar = hpa_set.plot.barh(ax=ax)
        else:
            if scale == 'log':
                ax.set_yscale('log')
            bar = hpa_set.plot.bar(ax=ax)
            bar.set_xticklabels(labels=bar.get_xticklabels(), rotation=70)
        if save_path:
            self._save_fig(Path(save_path) / 'class_distribution_hpa.svg')
        return bar

    def _plot_pie(self, data, column, show_labels, fig_name, save_path):
        """Draw a pie of ``data[column]``, labelling only ``show_labels``.

        Prints the table sorted in descending order. If ``save_path`` is
        given, writes the table to ``<fig_name>.csv`` and the figure to
        ``<fig_name>.svg`` in that directory.
        """
        stats = data.sort_values(column, ascending=False)
        print(stats)
        _, ax = self._set_plot_defaults()
        labels = [label if label in show_labels else ''
                  for label in stats.index]
        pie = stats.plot.pie(y=column, ax=ax, labels=labels)
        # Drop the legend; only the selected slices carry text labels.
        ax.get_legend().remove()
        if save_path:
            stats.to_csv(Path(save_path) / f'{fig_name}.csv')
            self._save_fig(Path(save_path) / f'{fig_name}.svg')
        return pie

    def plot_pie_per_class(self, save_path=None):
        """Pie chart of HPAv18 + training images per location class.

        A class missing from one of the two sets counts as 0 there, rather
        than making its total NaN.
        """
        train = self.calc_train().rename(columns={'training': 'Images'})
        hpa = self.calc_hpa().rename(columns={'HPAv18': 'Images'})
        return self._plot_pie(
            train.add(hpa, fill_value=0), 'Images',
            ('Nucleoplasm', 'Cytosol', 'Rods & Rings'),
            'fig1_class_distribution_pie', save_path)

    def plot_pie_per_label_count(self, save_path=None):
        """Pie chart of HPAv18 + training images by labels per image."""
        train = (self.calc_train(method='calc_label_per_image', rename=False)
                 .rename(columns={'training': 'Labels'}))
        hpa = (self.calc_hpa(method='calc_label_per_image', rename=False)
               .rename(columns={'HPAv18': 'Labels'}))
        return self._plot_pie(
            train.add(hpa, fill_value=0), 'Labels', (1, 2, 3, 4),
            'fig1_label_distribution_pie', save_path)


In [0]:
# Load and parse every CSV once; the figure cells below reuse this instance.
count = Count.create(
    hpa_path=HPA_PATH,
    test_path=TRUTH_PATH,
    train_path=TRAIN_PATH,
    master_test_path=MASTER_TEST_PATH,
    master_train_path=MASTER_TRAIN_PATH,
)

In [0]:
# Trailing ';' hides the repr of the returned Axes object.
count.plot_pie_per_label_count(save_path=Path(NOTEBOOK_DIR) / 'figures');

In [0]:
# Trailing ';' hides the repr of the returned Axes object.
count.plot_pie_per_class(save_path=Path(NOTEBOOK_DIR) / 'figures');

In [0]:
# Trailing ';' hides the repr of the returned Axes object.
count.plot_fig2a_class_distribution(save_path=Path(NOTEBOOK_DIR) / 'figures');

In [0]:
# Per-class HPAv18 image counts, indexed by abbreviated class name.
count.calc_hpa(method='calc_count', rename=True)