In [None]:
import random
import typing as typ

import numpy as np
from tqdm import tqdm

In [None]:
SEED = 40
PRECISION = 1e-2
N_FLIPS = 1_000_000
random.seed(SEED)

# Replace with `random.random()` to draw the bias from the seeded RNG instead.
true_bias = .82

# Candidate biases on a uniform grid over the CLOSED interval [0, 1].
# `i / n_steps` gives the correctly rounded grid value (exactly 0.82, not
# 0.8200000000000001) and always includes the endpoint 1.0, which
# `np.arange(0, 1, PRECISION)` with a float step did not.
n_steps = round(1 / PRECISION)
hypotheses = [i / n_steps for i in range(n_steps + 1)]
# Every hypothesis starts with the same weight, 1.0.
weights = [1.] * len(hypotheses)
# History of weight lists; entry 0 is the initial (uniform) weights.
weights_hist = [weights]

def pr(hyp: float, w: float, flip: bool) -> float:
    """Weighted likelihood of a single flip under hypothesis ``hyp``.

    ``hyp`` is treated as the probability that a flip comes out truthy:
    the result is ``w * hyp`` when ``flip`` is truthy, else ``w * (1 - hyp)``.
    """
    outcome_prob = hyp if flip else 1 - hyp
    return w * outcome_prob

def update_hyp_weights(
    hypotheses: list[float], weights: list[float], flip: bool
) -> list[float]:
    """Re-weight every hypothesis after observing a single flip.

    Each pair ``(hyp, w)`` is scored with ``pr(hyp, w, flip)``. Every score is
    then divided by the largest one, so the best-scoring hypothesis gets weight
    1.0 and all weights stay in [0, 1]. This is max-normalisation, not the
    sum-normalisation of a Bayesian posterior.

    Args:
        hypotheses: candidate values, each asserted to lie in [0, 1].
        weights: current weight of each hypothesis, paired positionally with
            ``hypotheses``; each asserted to lie in [0, 1].
        flip: the observed outcome, forwarded to ``pr``.

    Returns:
        The new weights, one per ``(hypothesis, weight)`` pair.

    Raises:
        AssertionError: if any input lies outside [0, 1], or if every score is
            0 (nothing to normalise by).
        ValueError: if the inputs are empty (``max`` of an empty sequence).
    """
    assert all(0 <= x <= 1 for x in hypotheses)
    assert all(0 <= x <= 1 for x in weights)
    # Score each (hypothesis, weight) pair exactly once: O(n) per update.
    scores = [pr(hyp, w, flip) for hyp, w in zip(hypotheses, weights)]
    max_prob = max(scores)
    assert max_prob > 0, f"{max_prob = }"
    return [score / max_prob for score in scores]

# Simulate N_FLIPS flips, each True with probability `true_bias`. After every
# flip the weights are re-scored, and each new weight list is appended to
# `weights_hist`, which already holds the initial weights, for later plotting.
# NOTE(review): at N_FLIPS = 1_000_000 this keeps a million Python lists alive;
# consider a preallocated numpy array if memory becomes a problem.
flips = []

for _ in tqdm(range(N_FLIPS)):
    flip = true_bias > random.random()
    weights = update_hyp_weights(hypotheses, weights, flip)
    flips.append(flip)
    weights_hist.append(weights)

100%|██████████| 1000000/1000000 [47:11<00:00, 353.21it/s]


In [None]:
# Sanity check: the fraction of True flips should track the true bias.
print(f"{true_bias=:.4f}; {np.mean(flips)=:.4f}")

# Highest-weight hypothesis. Tuples compare element-wise, so ties on weight go
# to the larger hypothesis value, then to the larger index.
w, hyp, i = max(
    (weight, h, idx)
    for idx, (h, weight) in enumerate(zip(hypotheses, weights))
)
print(hyp)

true_bias=0.8200; np.mean(flips)=0.8200
0.8200000000000001


In [None]:
import plotly.express as px

# Heatmap of how the weights evolve. Rows of `A` are successive weight lists
# (row 0 is the initial weights); columns are hypotheses. The history is
# averaged over at most `n` consecutive blocks of `l` rows each, so the image
# is roughly square. Trailing rows that do not fill a whole block are dropped.
A = np.array(weights_hist)

n = A.shape[1]
# Use at least one row per block, so a run shorter than the number of
# hypotheses still plots instead of averaging empty slices into NaN.
l = max(A.shape[0] // n, 1)
n_blocks = min(n, A.shape[0] // l)
B = A[: n_blocks * l].reshape(n_blocks, l, n).mean(axis=1)

fig = px.imshow(
    B,
    x=hypotheses,  # label columns by hypothesis value rather than by index
    color_continuous_scale="Viridis",
    labels=dict(x="hypothesis", y=f"block of {l} steps", color="mean weight"),
    title="Hypothesis weights over time",
)

fig.update_layout(
    width=600,
    height=600,
    margin=dict(l=0, r=0, t=30, b=30),
    xaxis=dict(domain=[0.0, 1.0]),   # axes span the full plotting area
    yaxis=dict(domain=[0.0, 1.0]),
)

fig.show()


In [None]:
# Display the best hypothesis and its max-normalised weight.
(hyp, w)

(0.8200000000000001, 1.0)

In [None]:
from pprint import pprint

# Dump the final weight of every hypothesis as (index, (weight, hypothesis)).
pprint([*enumerate(zip(weights, hypotheses))])

[(0, (0.0, 0.0)),
 (1, (0.0, 0.01)),
 (2, (0.0, 0.02)),
 (3, (0.0, 0.03)),
 (4, (0.0, 0.04)),
 (5, (0.0, 0.05)),
 (6, (0.0, 0.06)),
 (7, (0.0, 0.07)),
 (8, (0.0, 0.08)),
 (9, (0.0, 0.09)),
 (10, (0.0, 0.1)),
 (11, (0.0, 0.11)),
 (12, (0.0, 0.12)),
 (13, (0.0, 0.13)),
 (14, (0.0, 0.14)),
 (15, (0.0, 0.15)),
 (16, (0.0, 0.16)),
 (17, (0.0, 0.17)),
 (18, (0.0, 0.18)),
 (19, (0.0, 0.19)),
 (20, (0.0, 0.2)),
 (21, (0.0, 0.21)),
 (22, (0.0, 0.22)),
 (23, (0.0, 0.23)),
 (24, (0.0, 0.24)),
 (25, (0.0, 0.25)),
 (26, (0.0, 0.26)),
 (27, (0.0, 0.27)),
 (28, (0.0, 0.28)),
 (29, (0.0, 0.29)),
 (30, (0.0, 0.3)),
 (31, (0.0, 0.31)),
 (32, (0.0, 0.32)),
 (33, (0.0, 0.33)),
 (34, (0.0, 0.34)),
 (35, (0.0, 0.35000000000000003)),
 (36, (0.0, 0.36)),
 (37, (0.0, 0.37)),
 (38, (0.0, 0.38)),
 (39, (0.0, 0.39)),
 (40, (0.0, 0.4)),
 (41, (0.0, 0.41000000000000003)),
 (42, (0.0, 0.42)),
 (43, (0.0, 0.43)),
 (44, (0.0, 0.44)),
 (45, (0.0, 0.45)),
 (46, (0.0, 0.46)),
 (47, (0.0, 0.47000000000000003)),
 (48, (0.0