In [1]:
from __future__ import annotations

from pprint import pprint, pformat
import random
import typing as typ

import numpy as np
from tqdm import tqdm

In [2]:
PRECISION = 1e-4
SEED = 42

Event = int | tuple[int, ...] | set[int] | slice

class Probability:
    __slots__ = ("prob")
    prob: np.ndarray #TODO: reimplement this but using vectors instead?

    def __init__(self, prob: np.ndarray | dict[int, float], sample_space_size: int | None = None) -> None:
        match sample_space_size, prob:
            case None, dict():
                sample_space_size = max(prob)
                prob = np.array([prob.get(i, 0.0) for i in range(sample_space_size)])
            case None, np.ndarray():
                sample_space_size = len(prob)
            case int(), dict():
                assert sample_space_size >= max(prob), (sample_space_size, max(prob))
                prob = np.array([prob.get(i, 0.0) for i in range(sample_space_size)])
            case int(), np.ndarray():
                assert sample_space_size >= len(prob), (sample_space_size, len(prob))
                prob = np.array(prob.tolist() + [0.0] * (sample_space_size - len(prob)))
                assert len(prob) == sample_space_size #TODO: remove
        assert (prob.sum() - 1) <= PRECISION
        self.prob = prob

    def __len__(self) -> int:
        return len(self.prob)

    @property
    def full_measure_card(self) -> int:
        return (self.prob > 0).sum()

    def __getitem__(self, event: Event) -> float:
        match event:
            case int():
                return self.prob[event]
            case tuple() | set():
                return sum(self.prob[o] for o in event)
            case slice():
                return self.prob[event].sum()

    def get_event_probs(self, *events: Event) -> dict[Event, float]:
        return {
            (tuple(event) if isinstance(event, set) else event): self[event] for event in events
        }

    def __repr__(self) -> str:
        return pformat(self.prob)
    def __str__(self) -> str:
        return self.__repr__()

    def condition(self, event: set[int]) -> Probability:
        """Update via simple conditioning.""" #TODO: virtual evidence updating and stuff?
        assert max(event) <= len(self)
        mask = np.array([i in event for i in range(len(self))])
        new_prob = self.prob * mask
        new_prob /= new_prob.sum()
        # new_prob = np.array([v if i in event else 0.0 for i, v in enumerate(self.prob)])
        # new_prob /= new_prob.sum()
        return Probability(new_prob)

    def __eq__(self, other: Probability) -> bool:
        return abs(self.prob - other.prob).max() <= PRECISION

def generate_random_probs(
    sample_space_size: int,
    n_probs: int,
    full_measure_cardinality_range: range | None = None, # candidate support sizes; ordinary range semantics (stop excluded)
) -> list[Probability]:
    """Draw ``n_probs`` random distributions over ``range(sample_space_size)``.

    For each distribution a support size is chosen at random from
    ``full_measure_cardinality_range``, that many distinct outcomes are sampled,
    and they receive random masses normalised to sum to 1. Every other outcome
    gets mass 0, and every returned distribution has length
    ``sample_space_size``.

    Both the stdlib ``random`` module (support size and support) and the global
    ``np.random`` state (masses) are used, so seed both for reproducible output.

    Args:
        sample_space_size: Number of outcomes in the sample space.
        n_probs: Number of distributions to generate.
        full_measure_cardinality_range: Support sizes to choose from. Defaults
            to exactly ``sample_space_size`` (full support).

    Raises:
        AssertionError: If the range is empty or contains a size outside
            ``1 .. sample_space_size``.
    """
    if full_measure_cardinality_range is None:
        full_measure_cardinality_range = range(sample_space_size, sample_space_size + 1)
    # Materialise the range itself so that a non-unit step is honoured.
    cardinalities = list(full_measure_cardinality_range)
    assert cardinalities, "full_measure_cardinality_range must not be empty"
    # A distribution needs at least one outcome with positive mass, so 0 is not a valid size.
    assert 1 <= min(cardinalities) and max(cardinalities) <= sample_space_size, (
        min(cardinalities), max(cardinalities), sample_space_size,
    )

    outcomes: range = range(sample_space_size)
    probs: list[Probability] = []
    for _ in range(n_probs):
        support_size = random.choice(cardinalities)
        support = random.sample(outcomes, support_size)
        masses = np.random.rand(support_size)
        masses /= masses.sum()
        sparse = dict(zip(support, masses.tolist()))
        # Pass the size explicitly: inferring it from the largest sampled outcome
        # would give every distribution a different length.
        probs.append(Probability(sparse, sample_space_size))
    return probs

In [3]:
# generate_random_probs draws supports with the stdlib `random` module and
# masses with the global `np.random` state, so both must be seeded for this
# cell to be reproducible.
random.seed(SEED)
np.random.seed(SEED)

sample_space_cardinality: int = 128
sample_space: set[int] = set(range(sample_space_cardinality))

probs = generate_random_probs(sample_space_cardinality, 10, range(5, 50))
p = probs[0].condition(set(range(50)))
# The slice 0:10 and the set {0, ..., 9} describe the same event, so the first
# and last printed values should agree.
print(p[0:10])
pprint(p.get_event_probs(*range(10)))
print(p[set(range(10))])

0.0740671845284474
{0: np.float64(0.0006104863238909937),
 1: np.float64(0.0),
 2: np.float64(0.0),
 3: np.float64(0.043567937919940944),
 4: np.float64(0.029888760284615467),
 5: np.float64(0.0),
 6: np.float64(0.0),
 7: np.float64(0.0),
 8: np.float64(0.0),
 9: np.float64(0.0)}
0.0740671845284474


In [4]:
def filter_nonzero_prob(prob: Probability) -> dict[int, float]:
    """Return ``{outcome: mass}`` for every outcome whose mass is strictly positive."""
    support: dict[int, float] = {}
    for outcome, mass in enumerate(prob.prob):
        if mass > 0:
            support[outcome] = mass
    return support

class WeightedProbabilitySet:
    """A non-empty list of probability measures, each paired with a weight in ``[0, 1]``.

    All measures share one sample space and the largest weight is exactly 1.
    """
    __slots__ = ("probs", "weights")
    probs: list[Probability]
    weights: list[float]

    def __init__(self, probs: list[Probability], weights: list[float]) -> None:
        """Store the measures and weights after validating them.

        Raises:
            AssertionError: If the lists are empty or differ in length, the
                measures differ in sample-space size, a weight lies outside
                ``[0, 1]``, or no weight equals 1.
        """
        # Check non-emptiness first so empty input fails this assertion rather
        # than raising IndexError on probs[0].
        assert len(probs) == len(weights) != 0, (len(probs), len(weights))
        sample_space_size = len(probs[0])
        assert all(len(p) == sample_space_size for p in probs[1:])
        assert all(0 <= w <= 1 for w in weights), weights
        assert max(weights) == 1
        self.probs = probs
        self.weights = weights

    def condition(self, event: set[int]) -> WeightedProbabilitySet:
        """Condition every measure on ``event`` and re-weight.

        Each measure's unnormalised new weight is ``weight * P(event)``.
        Measures whose conditioned versions compare equal share the largest
        such value among them. All values are then divided by the overall
        maximum, so the largest new weight is 1.

        Raises:
            AssertionError: If ``event`` contains an outcome outside the sample
                space, or some measure assigns ``event`` probability zero
                (its conditional distribution is undefined).
        """
        # Valid outcomes are 0 .. len(self) - 1, hence the strict bound.
        assert max(event) < len(self), (max(event), len(self))
        event_probs = [p[event] for p in self.probs]
        assert all(q > 0 for q in event_probs), (
            "every measure must assign the event positive probability"
        )
        # Computed once per measure; the grouping step below reuses these values.
        likelihoods = [w * q for w, q in zip(self.weights, event_probs)]
        max_w_prob = max(likelihoods)
        assert max_w_prob > 0, f"{max_w_prob = }"
        new_probs = [p.condition(event) for p in self.probs]
        new_weights = [
            max(
                lik
                for lik, other in zip(likelihoods, new_probs)
                if other == new_p
            ) / max_w_prob
            for new_p in new_probs
        ]
        return WeightedProbabilitySet(new_probs, new_weights)

    def __len__(self) -> int:
        """Return the size of the shared sample space (not the number of measures)."""
        return len(self.probs[0])

    def print(self) -> None:
        """Pretty-print ``(index, (weight, {outcome: mass}))`` per measure, listing only non-zero masses."""
        pprint(
            list(
                enumerate(
                    zip(
                        self.weights,
                        list(map(filter_nonzero_prob, self.probs))
                    )
                )
            )
        )



In [5]:
sample_space_size: int = 128
n_probs = 16
# Passing None requests supports of size sample_space_size; every measure starts with weight 1.
probs = generate_random_probs(sample_space_size, n_probs, None) #TODO: ~patch so that the weight of P is 0 if P(E)==0
weights = [1.0 for _ in range(n_probs)]
wps = WeightedProbabilitySet(probs, weights)

true_event = {42}
events = [
    set(range(lo, hi))
    for lo, hi in ((20, 90), (40, 120), (41, 50), (42, 120))
]

# wpss[k] holds the weighted set after conditioning on the first k events.
wpss = [wps]
for i, event in enumerate(events):
    print(i)
    wps = wpss[-1].condition(event)
    wpss.append(wps)

final_wps = wps

0
1
2
3


In [6]:
# Mass each initial measure assigns to outcomes 42..49, i.e. the intersection
# of the ranges in `events`.
initial_event_masses = {i: wpss[0].probs[i][42:50] for i in range(n_probs)}
# Hoisted out of the comprehension: the maximum is the same for every entry.
largest_event_mass = max(initial_event_masses.values())
# Masses rescaled so that the largest becomes 1.
final_set_pre_weights = {
    i: mass / largest_event_mass for i, mass in initial_event_masses.items()
}
pprint(final_set_pre_weights)

{0: np.float64(1.0),
 1: np.float64(0.904396854765036),
 2: np.float64(0.847596440736716),
 3: np.float64(0.7734154441955019),
 4: np.float64(0.8995787894627736),
 5: np.float64(0.8295282517294996),
 6: np.float64(0.7149151343651895),
 7: np.float64(0.6218132176361868),
 8: np.float64(0.667845750219964),
 9: np.float64(0.6544621740637392),
 10: np.float64(0.7948172908461891),
 11: np.float64(0.9907814273056288),
 12: np.float64(0.8760379177099173),
 13: np.float64(0.8777375842057072),
 14: np.float64(0.8221194996180214),
 15: np.float64(0.8249148204437192)}


In [7]:
# Show each measure's final weight next to its outcomes with non-zero mass.
final_wps.print()

[(0,
  (np.float64(1.0),
   {42: np.float64(0.19062567462406527),
    43: np.float64(0.07846114917221775),
    44: np.float64(0.04530709738195018),
    45: np.float64(0.12329468661506182),
    46: np.float64(0.0724296250187454),
    47: np.float64(0.10191046520332736),
    48: np.float64(0.18152028382559077),
    49: np.float64(0.2064510181590415)})),
 (1,
  (np.float64(0.9043968547650363),
   {42: np.float64(0.16417391176644597),
    43: np.float64(0.13680962529679988),
    44: np.float64(0.25086072096851864),
    45: np.float64(0.04275840657937089),
    46: np.float64(0.10143917928183609),
    47: np.float64(0.010299474205268413),
    48: np.float64(0.04241230063965294),
    49: np.float64(0.25124638126210713)})),
 (2,
  (np.float64(0.8475964407367165),
   {42: np.float64(0.20901631821001887),
    43: np.float64(0.03961361003734532),
    44: np.float64(0.10202305575777552),
    45: np.float64(0.08750982486926345),
    46: np.float64(0.06211040753227306),
    47: np.float64(0.21968879