In [7]:
import numpy as np
import multiprocessing as mp
import os
from typing import List, Tuple
from tqdm import notebook

In [8]:
N_bootstraps: int = 10 ** 4

def poisson_bootstrap_tp_fp_fn_tn(
    bundle: Tuple[float, List[Tuple[float, float, float, int]]],

    ) -> List[np.ndarray]:
    threshold, data = bundle
    TP = np.zeros((N_bootstraps))
    FP = np.zeros((N_bootstraps))
    FN = np.zeros((N_bootstraps))
    TN = np.zeros((N_bootstraps))
    for current_label, current_predict, weight, index in data:
        np.random.seed(index)
        current_predict += np.random.normal(0, 0.0125, 1)
        current_predict = int(np.clip(current_predict, 0, 1) >= threshold)
        p_sample = np.random.poisson(1, N_bootstraps) * weight
        
        if current_label == 1 and current_label == 1:
            TP += p_sample
        if current_label == 1 and current_label == 0:
            FN += p_sample        
        if current_label == 0 and current_label == 1:
            FP += p_sample        
        if current_label == 0 and current_label == 0:
            TN += p_sample
    return [TP, FP, FN, TN]

In [9]:
N = 10 ** 6
labels = np.random.randint(0, 2, N)
predicts = np.clip(np.random.normal(0.5, 1, N), 0, 1)
weights = np.array([1 for _ in range(N)])

In [10]:
chunk_size = N
threshold = 0.81
generator = (
    (
        threshold,
        [
            (labels[x + y],
             predicts[x + y],
             weights[x + y],
             x + y,)
        for x in range(chunk_size)
        if x + y < N
        ],
    )
    for y in range(0, N, chunk_size)
)

In [11]:
# cpu_to_use = np.max([os.cpu_count() - 3, 1])
# print(cpu_to_use)

# with mp.Pool(cpu_to_use) as pool:
#     stat_list = list(notebook.tqdm(pool.imap(poisson_bootstrap_tp_fp_fn_tn, generator, chunksize=chunk_size),
#                                    total=N // chunk_size)
sample = next(generator)
result = poisson_bootstrap_tp_fp_fn_tn(sample)

KeyboardInterrupt: 

In [None]:
result

[array([512., 519., 526., 458., 483., 482., 533., 482., 527., 501., 514.,
        494., 468., 528., 478., 497., 499., 503., 544., 482., 547., 525.,
        504., 511., 507., 520., 535., 489., 496., 534., 523., 524., 510.,
        512., 503., 509., 534., 537., 469., 520., 511., 477., 507., 501.,
        509., 483., 518., 466., 525., 487., 497., 474., 552., 533., 500.,
        532., 509., 500., 478., 538., 499., 536., 498., 456., 536., 519.,
        523., 531., 527., 490., 522., 561., 539., 520., 495., 493., 553.,
        510., 497., 519., 522., 511., 504., 504., 508., 524., 512., 527.,
        503., 476., 537., 484., 522., 497., 490., 486., 520., 475., 561.,
        523.]),
 array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.