# Dirichlet multinomial

Evaluate the estimated metric distributions using a Dirichlet-Multinomial model.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import arviz as az
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

import mmur
import mmu
from mmur import DirichletMultinomialConfusionMatrix

In [None]:
%matplotlib inline
plt.style.use('ggplot')
plt.rcParams['text.color'] = 'black'
plt.rcParams['figure.max_open_warning'] = 0
COLORS = [i['color'] for i in plt.rcParams['axes.prop_cycle']]

In [None]:
def plot_metric_distributions(estimated_metrics, gt_metrics, coverage=None):
    """Plot the KDE of each estimated metric next to its reference value.

    One panel is drawn per column of ``estimated_metrics``.

    Parameters
    ----------
    estimated_metrics : pandas.DataFrame
        Samples of the estimated metrics, one column per metric.
    gt_metrics : pandas.DataFrame
        Reference metrics; the first row of each matching column is drawn
        as a dashed vertical line.
    coverage : pandas.DataFrame, optional
        When given, the KDE of each matching column is overlaid and
        labelled 'simulated'.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The created figure.
    axs : numpy.ndarray
        1-D array holding one Axes per metric.
    """
    metric_names = list(estimated_metrics.columns)
    n_metrics = len(metric_names)
    # Size the grid from the data instead of assuming five metrics;
    # squeeze=False keeps `axs` an array even for a single metric.
    fig, axs = plt.subplots(
        ncols=n_metrics, figsize=(5 * n_metrics, 5), squeeze=False
    )
    axs = axs[0]
    for ax, name in zip(axs, metric_names):
        sns.kdeplot(estimated_metrics[name], ax=ax, label='estimated')
        if coverage is not None:
            sns.kdeplot(coverage[name], ax=ax, label='simulated')
        # .iloc takes the first row by position, whatever the index labels.
        ax.axvline(
            gt_metrics[name].iloc[0],
            c='grey', lw=2, ls='--', label='population mean',
        )
    axs[0].legend()
    return fig, axs

In [None]:
# Metrics selected from every metrics table in this notebook, in display order.
target_metrics = [
    'neg.precision',
    'pos.precision',
    'neg.recall',
    'pos.recall',
    'mcc',
]

## Generate data

Generate data from a Logistic process with noise

#### Hold-out set

Validate the model by comparing the credible interval of the samples from the model with unseen data sampled from the data-generating process.

In [None]:
# Generate the train, test and hold-out splits (10000 samples requested for
# each) with noise_sigma=0.3. random_state is fixed, presumably so that
# re-running the notebook reproduces the same data.
generator = mmur.LogisticGenerator()
outp = generator.fit_transform(
    train_samples=10000,
    test_samples=10000,
    holdout_samples=10000,
    noise_sigma=0.3,
    # NOTE(review): `enable_noise` is left disabled here; confirm whether it
    # is still a valid argument or can be removed.
    #enable_noise=True,
    random_state=123456
)

Select the test set

In [None]:
# Labels and scores stored under the 'test' entry of the generator output.
test_split = outp['test']
y_test = test_split['y']
probas_test = test_split['proba']

Compute the confusion matrix on the test-set

In [None]:
# Confusion matrix and metrics of the test set at a classification
# threshold of 0.5.
test_conf_mat, test_metrics = mmu.binary_metrics(
    y_test, scores=probas_test, threshold=0.5
)
# Flatten the confusion matrix; this flattened array is what is later passed
# to the model's fit_predict.
test_conf_mat = test_conf_mat.flatten()
# Keep only the columns listed in target_metrics.
test_metrics = mmu.metrics_to_dataframe(test_metrics)[target_metrics]

In [None]:
mmu.confusion_matrix_to_dataframe(test_conf_mat)

In [None]:
test_metrics

### Ground truth

In [None]:
gt_proba_test = outp['ground_truth']['test']

Compute the ground truth confusion matrix and metrics

In [None]:
# Same labels and threshold as for the test set, but scored with
# outp['ground_truth']['test'] instead of the test-set scores.
gt_conf_mat, gt_metrics = mmu.binary_metrics(
    y_test, scores=gt_proba_test, threshold=0.5
)

In [None]:
mmu.confusion_matrix_to_dataframe(gt_conf_mat)

In [None]:
gt_metrics = mmu.metrics_to_dataframe(gt_metrics)[target_metrics]
gt_metrics

### Hold-out set

In [None]:
# Labels and scores stored under the 'holdout' entry of the generator output.
holdout_split = outp['holdout']
y_holdout = holdout_split['y']
proba_holdout = holdout_split['proba']

Compute metrics on this set

In [None]:
# Hold-out confusion matrices and metrics at the same 0.5 threshold.
# NOTE(review): binary_metrics_runs is used here instead of binary_metrics,
# presumably because the hold-out data holds several runs (the coverage
# cells below count rows of holdout_metrics); confirm against mmu's docs.
holdout_conf_mat, holdout_metrics = mmu.binary_metrics_runs(
    y=y_holdout, scores=proba_holdout, threshold=0.5
)
# Keep only the columns listed in target_metrics.
holdout_metrics = mmu.metrics_to_dataframe(holdout_metrics)[target_metrics]

## Model

In [None]:
dm_model = DirichletMultinomialConfusionMatrix()

In [None]:
# Fit the model on the flattened test-set confusion matrix, requesting
# n_samples=10000.
# NOTE(review): y_hat is not referenced again in the visible notebook; the
# later cells read their results from dm_model directly.
y_hat = dm_model.fit_predict(
    test_conf_mat,
    n_samples=10000
)

### Prior traces

In [None]:
axs = dm_model.plot_prior_trace()

### Posterior traces

In [None]:
axs = dm_model.plot_posterior_trace()

### Generative posterior

In [None]:
axs = dm_model.plot_posterior()

### Estimated metrics

In [None]:
# Metrics computed by the fitted model for the target metrics, converted to
# a DataFrame.
estimated_metrics = dm_model.compute_metrics(metrics=target_metrics)
mtr = mmu.metrics_to_dataframe(estimated_metrics, target_metrics)

In [None]:
_ = plot_metric_distributions(mtr, gt_metrics)

In [None]:
_ = sns.pairplot(mtr, diag_kind='kde')

### Compute Highest Density Interval (HDI)

#### Predictive samples from Confusion Matrix

In [None]:
dm_model.posterior_predictive_hdi()

In [None]:
_ = dm_model.plot_hdi_predictive_posterior()

#### Metrics based on Confusion Matrix

In [None]:
# NOTE(review): the metric names here are abbreviated ('pos.prec',
# 'pos.rec') while target_metrics spells them 'pos.precision' and
# 'pos.recall'; confirm that plot_hdi accepts the short forms.
_ = dm_model.plot_hdi(metrics=['pos.prec', 'pos.rec'])

In [None]:
fig, ax = dm_model.plot_hdi()

### Hold-out summary statistics

In [None]:
# Per-metric summary of the hold-out metrics: min, max and mean next to the
# HDI returned by mmur.stats.compute_hdi.
# String aggregator names keep the resulting labels independent of the
# installed NumPy/pandas versions (the label derived from np.min / np.max is
# not stable across releases). The rename pins the 'amin' / 'amax' labels
# that the comparison plot at the end of the notebook looks up.
holdout_metrics_moments = pd.concat(
    (
        holdout_metrics.agg(['min', 'max', 'mean']).T.rename(
            columns={'min': 'amin', 'max': 'amax'}
        ),
        mmur.stats.compute_hdi(holdout_metrics)
    ), axis=1
)
holdout_metrics_moments

HDI estimates

In [None]:
# HDI of the estimated metrics, plus the mean of each metric's samples.
hdi_estimates = mmur.stats.compute_hdi(mtr)
# mtr.mean() is indexed by metric name, so the assignment aligns on the
# labels of hdi_estimates instead of relying on matching row order.
# skipna=False keeps NaN propagation identical to a plain array mean.
hdi_estimates['mu'] = mtr.mean(skipna=False)

In [None]:
hdi_estimates

## Coverage

In [None]:
# Count, per metric, the hold-out rows that fall outside the estimated HDI.
n_below = ((holdout_metrics - hdi_estimates['lb']) < 0.0).sum()
n_above = ((holdout_metrics - hdi_estimates['ub']) > 0.0).sum()
coverage_counts = pd.DataFrame({'<lb': n_below, '>ub': n_above})

# Rows missed on either side, as a count and as a percentage of all rows.
coverage_counts['under_coverage'] = coverage_counts.sum(axis=1)
n_rows = holdout_metrics.shape[0]
coverage_counts['under_coverage_perc'] = (
    coverage_counts['under_coverage'] / n_rows
) * 100

In [None]:
coverage_counts

In [None]:
fig, ax = mmur.viz.dists.plot_hdis_violin(hdi_estimates, holdout_metrics)


In [None]:
# One panel per metric: the estimated density with its HDI shaded, compared
# against the HDI and the observed range of the hold-out metrics.
metric_names = list(hdi_estimates.index)
n_metrics = len(metric_names)
# Panel count and figure width follow the number of metrics (10 inches per
# panel); squeeze=False keeps `axs` indexable when there is a single metric.
fig, axs = plt.subplots(
    figsize=(10 * n_metrics, 10), ncols=n_metrics, sharey=True, squeeze=False
)
axs = axs[0]
for ax, idx in zip(axs, metric_names):
    samples = mtr[idx]
    sns.kdeplot(
        samples,
        clip=(samples.min(), samples.max()),
        ax=ax,
        label='estimated',
        color=COLORS[0]
    )
    # The KDE curve is the first line on the axes; shade the part of it that
    # lies strictly inside the estimated HDI.
    x, y = ax.get_lines()[0].get_data()
    shade_idx = (
        (x > hdi_estimates.loc[idx, 'lb'])
        & (x < hdi_estimates.loc[idx, 'ub'])
    )
    ax.fill_between(
        x=x[shade_idx],
        y1=y[shade_idx],
        alpha=0.3,
        label='HDI estimate',
        color=COLORS[0]
    )

    # Only the first line of each pair carries a label so that the legend
    # shows a single entry per pair.
    ax.axvline(
        x=holdout_metrics_moments.loc[idx, 'lb'],
        color=COLORS[1],
        ls='--',
        label='HDI hold-out'
    )
    ax.axvline(
        x=holdout_metrics_moments.loc[idx, 'ub'], color=COLORS[1], ls='--'
    )
    ax.axvline(
        x=holdout_metrics_moments.loc[idx, 'amin'],
        color=COLORS[3],
        ls='dotted',
        lw=3,
        label='range hold-out'
    )
    ax.axvline(
        x=holdout_metrics_moments.loc[idx, 'amax'],
        color=COLORS[3],
        ls='dotted',
        lw=3
    )
    ax.set_ylabel('density', fontsize=16)
    ax.set_xlabel(idx, fontsize=18)
    ax.tick_params(labelsize=14)
    ax.legend(fontsize=16)

# The title belongs to the figure, so it is set once after the loop; the
# result is bound to `_` to keep the cell output clean.
_ = fig.suptitle('Estimated vs observed out-of-sample performance', fontsize=20)