In [2]:
import itertools
import os
from pathlib import Path
# Walk from the current working directory up through its ancestors and switch
# into the first directory that contains a `Pipfile`, so that the project-local
# imports below resolve. If no `Pipfile` is found the working directory is left
# unchanged.
candidate_folders = [Path.cwd(), *Path.cwd().parents]
for folder in candidate_folders:
    if not (folder / 'Pipfile').exists():
        continue
    os.chdir(folder)
    break

import shelve

from matplotlib import pyplot
from toolz import keyfilter

from formatting import h2, format_ci, format_decimal, render_struct_table
from notebooks.heart_transplant.heart_transplant_training_curves import \
    HEART_TRANSPLANT_TRAINING_CURVES_EXPANDING_IDENTIFIER
from utils import transpose_list, transpose_dicts
from visualisation import set_integer_ticks, display_html

# Load the autoreload extension and use mode 2, so edited modules are
# re-imported automatically before each cell runs.
%load_ext autoreload
%autoreload 2
# Ask the inline matplotlib backend to emit high-DPI ("retina") figures.
%config InlineBackend.figure_format = 'retina'

from notebooks.heart_transplant.dependencies.heart_transplant_functions import  format_heart_transplant_method_name

from notebooks.heart_transplant.dependencies.heart_transplant_data import get_reduced_binary_dataset_cached

# Relative path constant; it is not referenced in the visible cells.
# NOTE(review): presumably the destination for published outputs — confirm.
PUBLISH_FOLDER = './data/heart_transplant/publish'

# Select matplotlib's 'default' style sheet for every figure in this notebook.
pyplot.style.use('default')

In [3]:
# Load the reduced binary dataset through the cached loader, called here with
# its default arguments.
# NOTE(review): the `_365` suffix suggests the default horizon is 365 days —
# confirm against `get_reduced_binary_dataset_cached`.
X_365, y_365, dataset_raw = get_reduced_binary_dataset_cached()

# Same loader with survival_days=90; the third return value is discarded.
X_90, y_90, _ = get_reduced_binary_dataset_cached(survival_days=90)

# %%
# NOTE(review): leftover commented-out plotting snippet; it is never executed.
# fill_between(train_sizes, fit_times_mean - fit_times_std,
#                          fit_times_mean + fit_times_std, alpha=0.1)

[Memory]0.8s, 0.0min    : Loading get_reduced_binary_dataset...
[Memory]2.6s, 0.0min    : Loading get_reduced_binary_dataset...


In [4]:
def plot_x_y(_metrics, **kwargs):
    """Plot mean ROC AUC against the number of features, with a CI band.

    Args:
        _metrics: mapping keyed by number of features. Each value must expose
            ``['roc_auc'].mean`` (the plotted y value) and ``['roc_auc'].ci``
            (an iterable of exactly two values: lower and upper bound).
        **kwargs: forwarded unchanged to ``pyplot.plot`` (e.g. ``label``).
    """
    # Sort once so the line and its confidence band share the same x order.
    ordered_n_features = sorted(_metrics.keys())

    x_auc, y_auc = transpose_list([
        [n_features, _metrics[n_features]['roc_auc'].mean]
        for n_features in ordered_n_features
    ])

    x_confidence, y_low_confidence, y_high_confidence = transpose_list([
        [n_features, *_metrics[n_features]['roc_auc'].ci]
        for n_features in ordered_n_features
    ])

    pyplot.plot(x_auc, y_auc, **kwargs)
    pyplot.fill_between(x_confidence, y_low_confidence, y_high_confidence, alpha=0.1)


# Flag 'ru': 'r' opens the shelf read-only; the trailing 'u' is a dbm.gnu-only
# modifier (open without locking). Other dbm backends may reject this flag.
with shelve.open(HEART_TRANSPLANT_TRAINING_CURVES_EXPANDING_IDENTIFIER+'_365_ALL', 'ru') as data:
    # One figure per method stored in the shelf.
    for method_name, item in data.items():
        pyplot.figure()
        pyplot.grid(alpha=0.5, linestyle='--', linewidth=0.75)
        pyplot.title(format_heart_transplant_method_name(method_name))

        # Only the 'test' scores are drawn. The filter must be a one-element
        # tuple: `k in ('test')` would be a substring test against the string
        # 'test' and would also accept keys such as 't' or 'es'.
        for score_type, metrics_item in keyfilter(lambda k: k in ('test',), transpose_dicts(item)).items():
            plot_x_y(metrics_item, label=score_type)

        pyplot.xlabel('n features')
        pyplot.ylabel('ROC AUC')
        # Fixed y-range so the figures of different methods are comparable.
        pyplot.ylim(0.5, 0.7)
        set_integer_ticks()
    # pyplot.legend()

error: db file doesn't exist; use 'c' or 'n' flag to create a new db

In [None]:
# For every method in the shelf, render an HTML table with one row per feature
# count: the test ROC AUC (shown as "previous mean + change" from the second
# row on), its confidence interval, and the features added at that step.
#
# Flag 'ru': 'r' opens the shelf read-only; the trailing 'u' is a dbm.gnu-only
# modifier (open without locking). Other dbm backends may reject this flag.
with shelve.open(HEART_TRANSPLANT_TRAINING_CURVES_EXPANDING_IDENTIFIER+'_365_ALL', 'ru') as data:
    for method_name, series in data.items():
        h2(method_name)
        rows = [['<b>n features</b>', '<b>ROC AUC</b>', '<b>CI</b>', '<b>feature</b>']]
        last_item = None
        for n_features, item in sorted(series.items(), key=lambda i: i[0]):
            current_features = set(item['features'])
            roc_auc = item['test']['roc_auc']

            # Compare against None explicitly so that a falsy-but-present
            # previous entry is never mistaken for "first row".
            if last_item is not None:
                previous_mean = last_item['test']['roc_auc'].mean
                auc_cell = (
                    format_decimal(previous_mean)
                    + ' + '
                    + format_decimal(roc_auc.mean - previous_mean)
                )
                added_features = current_features - set(last_item['features'])
            else:
                auc_cell = format_decimal(roc_auc.mean)
                added_features = current_features

            # Set iteration order is not stable between interpreter runs;
            # sort so the rendered table is reproducible.
            rows.append([
                n_features,
                auc_cell,
                format_ci(roc_auc.ci),
                '+ ' + ",".join(sorted(added_features)),
            ])
            last_item = item
        display_html(render_struct_table(rows))