# RQ1: How well can we estimate prevalences from unlabeled deployment data?
This notebook generates Figures 5 and C.11. It assesses how well existing quantification methods estimate class prevalences in our use cases.

In [1]:
import os
import sys

sys.path.insert(0, os.path.abspath('..'))
from pathlib import Path
from typing import Dict

import numpy as np
import pandas as pd
import torch
from progiter import ProgIter

import quapy as qp
from quapy.method.aggregative import CC, ACC, PCC, PACC, EMQ, KDEyCS, KDEyHD, KDEyML, DMy
from quapy.error import nkld

from src.prev.data_loading import get_values, Kind, Split, all_tasks
from src.prev.plotting import plot_aggregate_results, Confidence, box_plot, multiplot
from src.prev.scaling import scale_prevalences_ir
from src.prev.quantification import QuantificationMethod, absolute_error, compute_w_hat_and_mu_hat, IdentityClassifier

# Locate the project folders relative to the working directory: `data` and `results`
# are looked up in the parent of the directory the notebook is executed from.
current_path = os.getcwd()
project_root = Path(current_path).parent
DATA_PATH = project_root / 'data'
RESULT_PATH = project_root / 'results'
# Both folders must already exist; nothing is created here.
assert DATA_PATH.exists()
assert RESULT_PATH.exists()
# Fix the torch random seed.
torch.manual_seed(seed=0)

<torch._C.Generator at 0x7f0191b16290>

In [2]:
# Load the stored values of every task once; `data` maps each task to whatever
# `get_values` returns for it (indexed below as data[task][Kind][Split]).
data = {
    task: get_values(task, DATA_PATH, proj='mic23_predictions_original_0')  # original paper predictions
    for task in ProgIter(all_tasks, desc='Loading data')
}

Loading data 100.00% 30/30... rate=3.15 Hz, eta=0:00:00, total=0:00:09


In [3]:
# Research Question 1: How well can we estimate prevalences from unlabeled deployment data?
# For every imbalance ratio (IR) and every task, the deployment split (APP_TEST) is modified by
# `scale_prevalences_ir`, and its class prevalences are then estimated with BBSE and with the
# QuaPy quantifiers listed in QP_METHODS. One record per (IR, task) is collected; the resulting
# frame is pickled so that the plotting cell can be re-run without repeating this computation.

# Result column name -> QuaPy quantifier class (defined once, it does not change per iteration).
QP_METHODS = {"CC": CC, "ACC": ACC, "PCC": PCC, "PACC": PACC, "EMQ": EMQ, "HDy": DMy,
              'KDEyCS': KDEyCS, 'KDEyHD': KDEyHD, 'KDEyML': KDEyML}

quantification_ir_results = []
IRS = list(np.arange(1, 10.5, 0.5))
for ir in IRS:
    # label the progress bar so the per-IR output lines can be told apart
    for t in ProgIter(all_tasks, desc=f'IR={ir}'):
        # modify APP_TEST according to IR (DEV_CAL and DEV_TEST are used unchanged)
        try:
            app_test_logits, app_test_classes = scale_prevalences_ir(logits=data[t][Kind.LOGITS][Split.APP_TEST],
                                                                     classes=data[t][Kind.LABELS][Split.APP_TEST],
                                                                     ir=ir)
        except Exception:
            # report which task / IR failed, then propagate the original error
            print(f'{t=}, {ir=}')
            raise
        mod_data = {Kind.LOGITS: {Split.DEV_CAL: data[t][Kind.LOGITS][Split.DEV_CAL],
                                  Split.DEV_TEST: data[t][Kind.LOGITS][Split.DEV_TEST],
                                  Split.APP_TEST: app_test_logits},
                    Kind.LABELS: {Split.DEV_CAL: data[t][Kind.LABELS][Split.DEV_CAL],
                                  Split.DEV_TEST: data[t][Kind.LABELS][Split.DEV_TEST],
                                  Split.APP_TEST: app_test_classes}}
        # estimate prevalence using BBSE (hard predictions = argmax over the class dimension)
        _, bbse_prior = compute_w_hat_and_mu_hat(mod_data[Kind.LABELS][Split.DEV_TEST],
                                                 torch.argmax(mod_data[Kind.LOGITS][Split.DEV_TEST], dim=1),
                                                 torch.argmax(mod_data[Kind.LOGITS][Split.APP_TEST], dim=1))
        # True deployment prevalences. `minlength` guarantees one entry per logit column even if
        # the highest class index does not occur in the modified APP_TEST labels; without it the
        # vector (and the classifier size derived from it) would be too short in that case.
        n_classes = app_test_logits.shape[1]
        prior = (torch.bincount(app_test_classes, minlength=n_classes) / len(app_test_classes)).numpy()
        d_size = len(app_test_classes)
        _info = {'ir': ir, 'task': t, "BBSE": bbse_prior, "prior": prior, "d_size": d_size}
        # convert data to qp format: softmax scores act as the instances, labels as the classes
        dev_data = qp.data.LabelledCollection(torch.softmax(mod_data[Kind.LOGITS][Split.DEV_TEST], dim=1),
                                              mod_data[Kind.LABELS][Split.DEV_TEST])
        app_data = qp.data.LabelledCollection(torch.softmax(mod_data[Kind.LOGITS][Split.APP_TEST], dim=1),
                                              mod_data[Kind.LABELS][Split.APP_TEST])
        dset = qp.data.base.Dataset(training=dev_data, test=app_data)
        # compute estimated prevalences with methods from qp
        for method_name, method in QP_METHODS.items():
            # a fresh classifier / quantifier pair per method, fitted on DEV_TEST
            identity_class = IdentityClassifier(len(prior))
            model = method(identity_class)
            model.fit(dset.training)
            estim_prevalence = model.quantify(dset.test.instances)
            _info[method_name] = estim_prevalence
        quantification_ir_results.append(_info)
ir_df = pd.DataFrame(quantification_ir_results)
ir_df.to_pickle(RESULT_PATH / '24_prev_estimation_df.pkl')

 100.00% 30/30... rate=0.24 Hz, eta=0:00:00, total=0:02:06
 100.00% 30/30... rate=0.21 Hz, eta=0:00:00, total=0:02:22
 100.00% 30/30... rate=0.20 Hz, eta=0:00:00, total=0:02:26
 100.00% 30/30... rate=0.20 Hz, eta=0:00:00, total=0:02:30
 100.00% 30/30... rate=0.19 Hz, eta=0:00:00, total=0:02:35
 100.00% 30/30... rate=0.19 Hz, eta=0:00:00, total=0:02:34
 100.00% 30/30... rate=0.20 Hz, eta=0:00:00, total=0:02:33
 100.00% 30/30... rate=0.20 Hz, eta=0:00:00, total=0:02:31
 100.00% 30/30... rate=0.20 Hz, eta=0:00:00, total=0:02:30
 100.00% 30/30... rate=0.20 Hz, eta=0:00:00, total=0:02:28
 100.00% 30/30... rate=0.21 Hz, eta=0:00:00, total=0:02:25
 100.00% 30/30... rate=0.20 Hz, eta=0:00:00, total=0:02:28
 100.00% 30/30... rate=0.21 Hz, eta=0:00:00, total=0:02:24
 100.00% 30/30... rate=0.21 Hz, eta=0:00:00, total=0:02:24
 100.00% 30/30... rate=0.21 Hz, eta=0:00:00, total=0:02:21
 100.00% 30/30... rate=0.21 Hz, eta=0:00:00, total=0:02:21
 100.00% 30/30... rate=0.21 Hz, eta=0:00:00, total=0:02:

# Generate Figure 5 and Figure C.11

In [8]:
# Quantification methods shown in the figures (this order is passed on to the plotting helpers).
display_methods = [
    QuantificationMethod.CC,
    QuantificationMethod.EMQ,
    QuantificationMethod.PACC,
    QuantificationMethod.KDEyCS,
    QuantificationMethod.BBSE,
    QuantificationMethod.ACC,
    QuantificationMethod.DMy,
    QuantificationMethod.KDEyML,
    QuantificationMethod.KDEyHD,
]
# Error measure per figure, and the y-axis range used for the line-plot panel of that figure.
metrics = {"Absolute error": absolute_error, "Normalized KLD": nkld}
limits = {"Absolute error": [0, 0.55], "Normalized KLD": [0, 0.17]}
# Columns that hold bookkeeping information rather than estimated prevalences.
info_columns = ["ir", "task", "d_size", 'prior']

for metric_name, metric in metrics.items():
    # Reload the stored estimates for every metric, because the estimate columns
    # are overwritten with error values below.
    ir_df = pd.read_pickle(RESULT_PATH / '24_prev_estimation_df.pkl')
    estimate_columns = [c for c in ir_df.columns if c not in info_columns]
    for col in estimate_columns:
        if metric_name == "Normalized KLD":
            # NKLD additionally receives a smoothing factor derived from the sample size
            score_row = lambda row: metric(row['prior'], row[col], eps=1 / row['d_size'])
        else:
            score_row = lambda row: metric(row['prior'], row[col])
        ir_df[col] = ir_df.apply(score_row, axis=1)

    # one row per task; every remaining column becomes the list of its values for that task
    list_columns = [c for c in ir_df.columns if c not in ['task', 'd_size', 'prior']]
    fin_df = ir_df.groupby('task').aggregate({c: list for c in list_columns})
    fin_df = fin_df.reset_index()

    # select values at imbalance ratio 10
    shown_columns = [q.value for q in display_methods]
    at_ir_10 = ir_df['ir'] == 10
    ir_10_df = ir_df.loc[at_ir_10][['task', *shown_columns]]

    # create line plot
    subplt = plot_aggregate_results(
        fin_df,
        line_ids=display_methods,
        file=None,
        delta=False,
        ci=Confidence.STD,
        y_axis_title=f"<b>{metric_name}</b>",
        title=None,
        bound=[0, 0.2],
        opacity=0.15,
    )
    # create box plot
    box1 = box_plot(ir_10_df, line_ids=display_methods)
    # create final figure: line plot on the left, box plot on the right
    fig = multiplot(
        rows=1,
        cols=2,
        subplts=[subplt, box1],
        horizontal_spacing=0.04,
        legend_index=1,
        y_title=f"<b>{metric_name}</b>",
        sub_x_axis_titles={0: "Imbalance ratio"},
        sub_y_ranges={0: limits[metric_name]},
        shared_yaxes=True,
        vertical_spacing=0.1,
        ir_axes=[2],
        ir_values=[10],
        little_guys=True,
        icon_size=0.14,
        icon_y_adjustment=0.06,
    )
    # fig.show()
    name = f"24_estimating_prevalence_{metric_name}"
    fig.write_image(RESULT_PATH / f"{name}.png")
    # other export formats, enable as needed:
    # fig.write_image(RESULT_PATH / f"{name}.svg")
    # fig.write_image(RESULT_PATH / f"{name}.pdf")
    # fig.write_html(RESULT_PATH / f"{name}.html")