diff --git a/src/pdl/pdl.py b/src/pdl/pdl.py index 0f0d9b412..d0248c1af 100644 --- a/src/pdl/pdl.py +++ b/src/pdl/pdl.py @@ -60,6 +60,8 @@ class InterpreterConfig(TypedDict, total=False): """ with_resample: bool """Allow the interpreter to raise the `Resample` exception.""" + ignore_factor: bool + """Do not evaluate the expression associated to the `factor` block but use `0` instead (so resample if `with_resample` is true).""" score: float | Ref[float] """Initial value of the score.""" event_loop: AbstractEventLoop diff --git a/src/pdl/pdl_infer.py b/src/pdl/pdl_infer.py index 5b7b4ac66..c779563ae 100644 --- a/src/pdl/pdl_infer.py +++ b/src/pdl/pdl_infer.py @@ -18,6 +18,8 @@ from .pdl_inference import ( infer_importance_sampling, infer_importance_sampling_parallel, + infer_majority_voting, + infer_majority_voting_parallel, infer_rejection_sampling, infer_rejection_sampling_parallel, infer_smc, @@ -32,7 +34,14 @@ class PpdlConfig(TypedDict, total=False): """Configuration parameters of the PDL interpreter.""" algo: Literal[ - "is", "parallel-is", "smc", "parallel-smc", "rejection", "parallel-rejection" + "is", + "parallel-is", + "smc", + "parallel-smc", + "rejection", + "parallel-rejection", + "maj", + "parallel-maj", ] num_particles: int max_workers: int @@ -101,6 +110,20 @@ def exec_program( # pylint: disable=too-many-arguments, too-many-positional-arg num_samples=num_particles, max_workers=max_workers, ) + case "maj": + dist = infer_majority_voting( + prog, config, scope, loc, num_particles=num_particles + ) + case "parallel-maj": + dist = infer_majority_voting_parallel( + prog, + config, + scope, + loc, + num_particles=num_particles, + max_workers=max_workers, + ) + case _: assert False, f"Unexpected algo: {algo}" match output: @@ -183,6 +206,8 @@ def main(): "parallel-smc", "rejection", "parallel-rejection", + "maj", + "parallel-maj", ], help="Choose inference algorithm.", default="smc", diff --git a/src/pdl/pdl_inference.py b/src/pdl/pdl_inference.py index 87bd609f6..2f29ee127 100644 --- a/src/pdl/pdl_inference.py +++ b/src/pdl/pdl_inference.py @@ -260,6 +260,37 @@ def gen(): return Categorical(samples) +def infer_majority_voting( # pylint: disable=too-many-arguments + prog: Program, + config: InterpreterConfig, + scope: Optional[ScopeType | dict[str, Any]], + loc: Optional[PdlLocationType], + # output: Literal["result", "all"], + *, + num_particles: int, +) -> Categorical[T]: + config["ignore_factor"] = True + return infer_importance_sampling( + prog, config, scope, loc, num_particles=num_particles + ) + + +def infer_majority_voting_parallel( # pylint: disable=too-many-arguments + prog: Program, + config: InterpreterConfig, + scope: Optional[ScopeType | dict[str, Any]], + loc: Optional[PdlLocationType], + # output: Literal["result", "all"], + *, + num_particles: int, + max_workers: Optional[int], +) -> Categorical[T]: + config["ignore_factor"] = True + return infer_importance_sampling_parallel( + prog, config, scope, loc, num_particles=num_particles, max_workers=max_workers + ) + + # async def _process_particle_async(state, model, num_particles): # with ImportanceSampling(num_particles) as sampler: # try: diff --git a/src/pdl/pdl_interpreter.py b/src/pdl/pdl_interpreter.py index bd72502ad..94f99fde4 100644 --- a/src/pdl/pdl_interpreter.py +++ b/src/pdl/pdl_interpreter.py @@ -1131,9 +1131,13 @@ def loop_body(iidx, items): case CallBlock(): result, background, scope, trace = process_call(state, scope, block, loc) case FactorBlock(): - weight, trace = process_expr_of( - block, "factor", scope, append(loc, "factor") - ) + if state.ignore_factor: + weight = 0.0 + trace = block.model_copy() + else: + weight, trace = process_expr_of( + block, "factor", scope, append(loc, "factor") + ) state.score.ref += weight result = PdlConst("") background = DependentContext([]) diff --git a/src/pdl/pdl_interpreter_state.py b/src/pdl/pdl_interpreter_state.py index 16869ee87..1295974aa 100644 --- a/src/pdl/pdl_interpreter_state.py +++ b/src/pdl/pdl_interpreter_state.py @@ -32,6 +32,8 @@ class InterpreterState(BaseModel): """Id generator for the UI.""" with_resample: bool = False """Allow the interpreter to raise the `Resample` exception.""" + ignore_factor: bool = False + """Do not evaluate the expression associated to the `factor` block but use `0` instead (so resample if `with_resample` is true).""" # The following are shared variable that should be modified by side effects imported: dict[str, tuple[ScopeType, BlockType]] = {}