In [1]:
import numpy as np
import matplotlib.pyplot as plt

In [2]:
def proba_of_flipping(n, pA, pB, p11, num_sampling, rng=None):
    """
    Estimates the probability of flipping, i.e. the probability that method B is in fact
    better than method A even though A was observed to be more accurate (pA > pB).

    The n paired test cases are split into three categories:
      * only A correct:                          x1 = n * (pA - p11)
      * only B correct:                          x2 = n * (pB - p11)
      * A and B agree (both correct or both wrong): n - x1 - x2
    Each count + 1 is used as a Dirichlet concentration parameter (equivalent to a
    uniform Dirichlet(1, 1, 1) prior). The flip probability is the fraction of draws
    in which the "only A" share is smaller than the "only B" share.

    Args:
    n (int): Number of samples in the test set.
    pA (float): Accuracy of method A.
    pB (float): Accuracy of method B.
    p11 (float): The assumed fraction of cases for which both methods are correct (plays
        the same role as the correlation for the Dice). Must not exceed min(pA, pB), and the
        implied "agree" count n - x1 - x2 must be non-negative.
        NOTE: for a fully consistent joint distribution the both-wrong fraction
        1 - pA - pB + p11 should also be >= 0 (i.e. p11 >= pA + pB - 1); this is not enforced.
    num_sampling (int): Number of samples to draw from the Dirichlet distribution (>= 1).
    rng (numpy.random.Generator, optional): Source of randomness. If None (default), the
        legacy global ``np.random`` state is used, so ``np.random.seed`` keeps working.

    Returns:
    tuple: (samples, p1_samples, p2_samples, proba_flip) where
        samples is the (num_sampling, 3) array of Dirichlet draws,
        p1_samples / p2_samples are its first two columns ("only A" / "only B" shares),
        proba_flip is the estimated probability of flipping (fraction of draws with p1 < p2).

    Raises:
    ValueError: if num_sampling < 1 or the inputs imply a negative category count.
    """
    if num_sampling < 1:
        raise ValueError(f"num_sampling must be a positive integer, got {num_sampling!r}")

    n11 = p11 * n  # cases where both methods are correct

    # Category counts of the paired outcomes.
    x1 = n * pA - n11  # Occurrences of A but not B
    x2 = n * pB - n11  # Occurrences of B but not A
    x3 = n - x1 - x2   # A and B agree (both correct or both wrong)

    # Tolerate tiny negative counts from float rounding; anything larger means the
    # inputs describe an impossible split and would yield meaningless alphas.
    tol = 1e-9 * max(abs(n), 1)
    if min(x1, x2, x3) < -tol:
        raise ValueError(
            f"Inconsistent inputs: category counts (only A={x1}, only B={x2}, agree={x3}) "
            "must be non-negative; check that p11 <= min(pA, pB)."
        )

    # Dirichlet concentration = observed count + 1.
    alphas = [x1 + 1, x2 + 1, x3 + 1]

    sampler = np.random if rng is None else rng
    samples = sampler.dirichlet(alphas, num_sampling)

    p1_samples = samples[:, 0]  # share of "only A correct"
    p2_samples = samples[:, 1]  # share of "only B correct"

    # Both accuracies contain the same "both correct" mass, so p1 < p2 is the same as
    # accuracy(A) < accuracy(B) within a draw.
    proba_flip = np.mean(p1_samples < p2_samples)

    return samples, p1_samples, p2_samples, proba_flip

In [12]:
# NOTE(review): with pA=0.9 and pB=0.89, at least 0.9 + 0.89 - 1 = 79% of cases must be
# "both correct", so p11=0.5 is not a consistent joint split — confirm the intended value.
proba_of_flipping(n=100, pA=0.9, pB=0.89, p11=0.5, num_sampling=10_000)

(array([[0.40812127, 0.38556315, 0.20631559],
        [0.37962746, 0.33942649, 0.28094605],
        [0.36583308, 0.47857451, 0.15559242],
        ...,
        [0.4143294 , 0.36302164, 0.22264896],
        [0.3882837 , 0.39046857, 0.22124773],
        [0.43717454, 0.32341637, 0.2394091 ]]),
 array([0.40812127, 0.37962746, 0.36583308, ..., 0.4143294 , 0.3882837 ,
        0.43717454]),
 array([0.38556315, 0.33942649, 0.47857451, ..., 0.36302164, 0.39046857,
        0.32341637]),
 0.4567)