In [1]:
from __future__ import annotations

from pprint import pprint, pformat
import random
import typing as typ

import numpy as np
from tqdm import tqdm

In [2]:
PRECISION = 1e-4
SEED = 42

class Probability:
    __slots__ = ("sample_space", "prob")
    sample_space: set[int] # discrete sigma-algebra
    prob: dict[int, float] #TODO: reimplement this but using vectors instead?

    def __init__(self, prob: dict[int, float], sample_space: set[int] | None = None) -> None:
        if sample_space is None:
            sample_space = set(prob)
        else:
            # assert set(prob) <= sample_space
            prob |= {i: 0.0 for i in sample_space-set(prob)}
            assert set(prob) == sample_space
        assert (sum(prob.values()) - 1) <= PRECISION
        self.sample_space = sample_space
        self.prob = prob

    def __len__(self) -> int:
        return len(self.sample_space)

    @property
    def full_measure_card(self) -> int:
        return len([v for v in self.prob.values() if v])

    def __getitem__(self, outcomes: int | tuple[int, ...] | set[int] | slice[int, int, None]) -> float:
        match outcomes:
            case int():
                return self.prob[outcomes]
            case tuple() | set():
                return sum(self.prob[o] for o in outcomes)
            case slice():
                return sum(self.prob[o] for o in range(outcomes.start, outcomes.stop))

    def __repr__(self) -> str:
        return pformat(self.prob)
    def __str__(self) -> str:
        return self.__repr__()

    def condition(self, event: set[int]) -> Probability:
        """Update via simple conditioning.""" #TODO: virtual evidence updating and stuff?
        assert event <= self.sample_space
        new_prob = {o: p if o in event else 0.0 for o, p in self.prob.items()}
        new_prob = {o: p/sum(new_prob.values()) for o, p in new_prob.items()}
        return Probability(new_prob, self.sample_space)

    def __eq__(self, other: Probability) -> bool:
        return (
            self.sample_space == other.sample_space
            and max(self[i]-other[i] for i in self.sample_space) <= PRECISION
        )

def generate_random_probs(
    sample_space: set[int],
    n_probs: int,
    full_measure_cardinality_range: tuple[int, int]
) -> list[Probability]:
    """Generate ``n_probs`` random probability measures on ``sample_space``.

    For each measure a support size is drawn uniformly from the inclusive
    range ``full_measure_cardinality_range``, that many outcomes are sampled
    without replacement, and they receive normalised uniform random masses.
    All other outcomes get probability 0.

    Randomness comes from both the ``random`` module (support size and
    support) and the global ``numpy.random`` state (masses); seed both for
    reproducible output.

    Raises:
        AssertionError: unless ``1 <= low <= high <= len(sample_space)``.
    """
    low, high = full_measure_cardinality_range
    # The lower bound is 1, not 0: an empty support would give an all-zero
    # "measure" that is not a probability and cannot be conditioned.
    assert 1 <= low <= high <= len(sample_space)
    ss_list: list[int] = list(sample_space)
    probs: list[Probability] = []
    cardinality_range = list(range(low, high + 1))
    for _ in range(n_probs):
        full_measure_cardinality = random.choice(cardinality_range)
        full_measure_set = random.sample(ss_list, full_measure_cardinality)
        prob_vec = np.random.rand(full_measure_cardinality)
        prob_vec /= prob_vec.sum()
        prob = dict(zip(full_measure_set, prob_vec.tolist()))
        probs.append(Probability(prob, sample_space))
    return probs

In [3]:
# generate_random_probs draws from both `random` and the global numpy RNG,
# so both must be seeded for a reproducible Restart & Run All.
random.seed(SEED)
np.random.seed(SEED)

# sample_space_cardinality: int = 128
# sample_space: set[int] = set(range(sample_space_cardinality))

# probs = generate_random_probs(sample_space, 10, (5, 50))
# p = probs[0].condition(set(range(50)))
# print(p[0:10])
# pprint({i: p.prob[i] for i in range(0, 10)})
# print(p[set(range(10))])

In [4]:
def filter_nonzero_prob(prob: dict[int, float]) -> dict[int, float]:
    """Return a new dict with only the outcomes whose probability is strictly positive."""
    kept: dict[int, float] = {}
    for outcome, mass in prob.items():
        if mass > 0:
            kept[outcome] = mass
    return kept

class WeightedProbabilitySet:
    """A list of probability measures on one sample space, each with a weight in [0, 1].

    ``probs[k]`` carries weight ``weights[k]``; the largest weight is exactly 1.
    """
    __slots__ = ("sample_space", "probs", "weights")
    sample_space: set[int]
    probs: list[Probability]
    weights: list[float]

    def __init__(self, probs: list[Probability], weights: list[float]) -> None:
        """Store the measures and their weights after validating them.

        Raises:
            AssertionError: if the measures do not share one sample space,
                the two lists differ in length, a weight lies outside [0, 1],
                or the largest weight is not 1.
        """
        sample_space = probs[0].sample_space
        assert all(p.sample_space == sample_space for p in probs[1:])
        assert len(probs) == len(weights)
        assert all(0<=w<=1 for w in weights), weights
        assert max(weights) == 1
        self.sample_space = sample_space
        self.probs = probs
        self.weights = weights

    def condition(self, event: set[int]) -> WeightedProbabilitySet:
        """Condition every measure on ``event`` and re-weight by likelihood.

        Each measure with ``P(event) > 0`` is replaced by its conditional.
        The new weight of a conditional ``Q`` is the largest ``w * P(event)``
        among the measures whose conditional equals ``Q``, divided by the
        largest ``w * P(event)`` over all measures (so the top weight is 1).

        Measures with ``P(event) == 0`` have no conditional and a likelihood
        weight of zero, so they are left out of the result; the returned set
        can therefore be shorter than this one.

        Raises:
            AssertionError: if ``event`` is not a subset of the sample space
                or every measure has ``w * P(event) == 0``.
        """
        assert event <= self.sample_space
        # P(event) and w * P(event) are computed once per measure instead of
        # once per pair inside the comparison loop below.
        likelihoods = [p[event] for p in self.probs]
        scores = [w * lik for w, lik in zip(self.weights, likelihoods)]
        max_w_prob = max(scores)
        assert max_w_prob > 0, f"{max_w_prob = }"
        # Only measures giving the event positive probability can be conditioned.
        survivors = [
            (p.condition(event), score)
            for p, lik, score in zip(self.probs, likelihoods, scores)
            if lik > 0
        ]
        new_probs = [posterior for posterior, _ in survivors]
        # Equality of measures is tolerance-based, so equal posteriors are
        # found by pairwise comparison rather than by hashing.
        new_weights = [
            max(
                score_
                for posterior_, score_ in survivors
                if posterior_ == posterior
            ) / max_w_prob
            for posterior, _ in survivors
        ]
        return WeightedProbabilitySet(new_probs, new_weights)

    def __len__(self) -> int:
        """Number of outcomes in the sample space (not the number of measures)."""
        return len(self.sample_space)

    def print(self) -> None:
        """Pretty-print ``(index, (weight, nonzero probabilities))`` for every measure."""
        pprint(
            list(
                enumerate(
                    zip(
                        self.weights,
                        [filter_nonzero_prob(p.prob) for p in self.probs]
                    )
                )
            )
        )



In [5]:
# Prior: n_probs random measures on {0, ..., 127}, every one with weight 1.
sample_space_cardinality: int = 128
sample_space: set[int] = set(range(sample_space_cardinality))
n_probs = 16
#TODO: ~patch so that the weight of P is 0 if P(E)==0
probs = generate_random_probs(
    sample_space,
    n_probs,
    (127, 128),  # presumably near-full support so that no event gets probability 0 -- see TODO
)
weights = [1.0 for _ in range(n_probs)]
wps = WeightedProbabilitySet(probs, weights)

true_event = {42}  # not referenced again in the cells visible here
events = [
    set(range(low, high))
    for low, high in ((20, 90), (40, 120), (41, 50), (42, 120))
]

# wpss[k] is the weighted set after conditioning on the first k events.
wpss = [wps]
for i, event in enumerate(events):
    print(i)
    wps = wpss[-1].condition(event)
    wpss.append(wps)

final_wps = wpss[-1]

0
1
2
3


In [6]:
# Mass each prior measure puts on outcomes 42..49 (the intersection of the
# conditioning events), rescaled so that the largest value is 1.
prior_mass_on_final_set = {
    i: wpss[0].probs[i][42:50]
    for i in range(n_probs)
}
largest_prior_mass = max(prior_mass_on_final_set.values())
final_set_pre_weights = {
    i: p / largest_prior_mass
    for i, p in prior_mass_on_final_set.items()
}
pprint(final_set_pre_weights)
# wpss[0].probs[1][42:50]

{0: 0.7045419278930906,
 1: 0.8895088249704213,
 2: 0.8997739049840343,
 3: 0.880611888486982,
 4: 0.7389514589386774,
 5: 0.9549307946497739,
 6: 1.0,
 7: 0.4565196548648335,
 8: 0.612266128662657,
 9: 0.9761453831564684,
 10: 0.7518461926322993,
 11: 0.7590641253637334,
 12: 0.9074133677256787,
 13: 0.41847676485104074,
 14: 0.6962064490424879,
 15: 0.7503839093625884}


In [7]:
# Pretty-print, per measure of the final set: (index, (weight, nonzero outcome probabilities)).
final_wps.print()

[(0,
  (0.7045419278930906,
   {42: 0.18030712638081486,
    43: 0.14193732051944974,
    44: 0.1096140002727213,
    45: 0.035075922220441265,
    46: 0.16724946696503232,
    47: 0.11324604253807947,
    48: 0.18538055166509101,
    49: 0.06718956943837001})),
 (1,
  (0.8895088249704215,
   {42: 0.12617625187567047,
    43: 0.1819254310844089,
    44: 0.09827429359449594,
    45: 0.13844685294050402,
    46: 0.16532114436736056,
    47: 0.10884098283713879,
    48: 0.018807506948177163,
    49: 0.16220753635224414})),
 (2,
  (0.8997739049840343,
   {42: 0.12714760503764905,
    43: 0.05860660169625595,
    44: 0.16944724537974823,
    45: 0.045344107754217225,
    46: 0.1690286275202723,
    47: 0.11566318748785566,
    48: 0.13651400747377854,
    49: 0.17824861765022307})),
 (3,
  (0.8806118884869821,
   {42: 0.08784138311277728,
    43: 0.11736961956937862,
    44: 0.12109395243637618,
    45: 0.023589175213281216,
    46: 0.16850653158928036,
    47: 0.10465376392500698,
    48: 