/
minipyro.py
629 lines (519 loc) · 22.9 KB
/
minipyro.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
"""
Mini Pyro
---------
This file contains a minimal implementation of the Pyro Probabilistic
Programming Language. The API (method signatures, etc.) matches that of
the full implementation as closely as possible. This file is independent
of the rest of Pyro, with the exception of the :mod:`pyro.distributions`
module.
An accompanying example that makes use of this implementation can be
found at examples/minipyro.py.
"""
import functools
import warnings
import weakref
from collections import OrderedDict, namedtuple
import torch
from pyro.distributions import validation_enabled
from pyro.optim.clipped_adam import ClippedAdam as _ClippedAdam
import funsor
# Funsor represents distributions in a fundamentally different way from
# torch.Distributions and Pyro: funsor distributions are densities whereas
# torch Distributions are samplers. This class is a compatibility wrapper
# between the two. It is used only internally in the sample() function.
class Distribution(object):
    """Adapter that gives a funsor density a sampler-style interface.

    The wrapped funsor must have an input named ``"value"``; the domain of
    that input is exposed as ``self.output``.

    :param funsor_dist: a ``funsor.Funsor`` with a ``"value"`` input.
    :param sample_inputs: optional mapping from input name to a domain whose
        ``dtype`` is an int; that int is used as the size of the input.
    """

    def __init__(self, funsor_dist, sample_inputs=None):
        assert isinstance(funsor_dist, funsor.Funsor)
        if sample_inputs:
            assert all(isinstance(inp.dtype, int) for inp in sample_inputs.values())
        self.funsor_dist = funsor_dist
        self.output = funsor_dist.inputs["value"]
        self.sample_inputs = sample_inputs

    def log_prob(self, value):
        """Evaluate the density at ``value``.

        When ``sample_inputs`` is set, a zero tensor carrying those inputs is
        added so that the result has every sample input among its inputs.
        """
        log_density = self.funsor_dist(value=value)
        if not self.sample_inputs:
            return log_density
        sizes = tuple(domain.dtype for domain in self.sample_inputs.values())
        zeros = funsor.tensor.Tensor(torch.zeros(*sizes), self.sample_inputs)
        return log_density + zeros

    # Draw a sample.
    def __call__(self):
        """Draw a sample and return the point stored in the resulting Delta."""
        with funsor.interpretations.eager:
            density = self.funsor_dist(value="value")
            sampled = density.sample(
                frozenset(["value"]), sample_inputs=self.sample_inputs
            )
        if isinstance(sampled, funsor.cnf.Contraction):
            # Keep only the Delta term of a two-term contraction.
            assert len(sampled.terms) == 2
            deltas = [t for t in sampled.terms if isinstance(t, funsor.delta.Delta)]
            assert deltas
            sampled = deltas[0]
        assert isinstance(sampled, funsor.delta.Delta)
        return sampled.terms[0][1][0]

    # Similar to torch.distributions.Distribution.expand().
    def expand_inputs(self, name, size):
        """Return a Distribution that also has a batch input ``name`` of ``size``.

        If the wrapped funsor already has an input called ``name`` it must
        already have that size, and ``self`` is returned unchanged.
        """
        existing = self.funsor_dist.inputs
        if name in existing:
            assert existing[name] == funsor.Bint[int(size)]
            return self
        # The new input goes first, followed by any inputs already requested.
        expanded = OrderedDict([(name, funsor.Bint[int(size)])])
        if self.sample_inputs:
            expanded.update(self.sample_inputs)
        return Distribution(self.funsor_dist, sample_inputs=expanded)
# Pyro keeps track of two kinds of global state:
# i) The effect handler stack, which enables non-standard interpretations of
# Pyro primitives like sample();
# See http://docs.pyro.ai/en/0.3.1/poutine.html
# ii) Trainable parameters in the Pyro ParamStore;
# See http://docs.pyro.ai/en/0.3.1/parameters.html
# Active effect handlers; Messenger.__enter__ appends and __exit__ pops.
PYRO_STACK = []
# Global parameter store, written by param() on first use of a name.
PARAM_STORE = {} # maps name -> (unconstrained_value, constraint)
def get_param_store():
    """Return the global parameter store dict itself (not a copy)."""
    return PARAM_STORE
# The base effect handler class (called Messenger here for consistency with Pyro).
class Messenger(object):
    """Base effect handler.

    A Messenger can be used as a context manager, or can wrap a callable
    ``fn`` and be called itself, in which case ``fn`` runs while the handler
    is active.

    :param fn: optional callable to run inside the handler.
    """

    def __init__(self, fn=None):
        self.fn = fn

    # Effect handlers push themselves onto the end of PYRO_STACK.
    # apply_stack runs process_message starting from the end of the stack
    # (the most recently entered handler) and postprocess_message in the
    # opposite direction.
    def __enter__(self):
        PYRO_STACK.append(self)

    def __exit__(self, *args, **kwargs):
        # Handlers must be exited in the reverse order they were entered.
        assert PYRO_STACK[-1] is self
        PYRO_STACK.pop()

    def process_message(self, msg):
        """Hook run before the site's value is computed; no-op by default."""
        pass

    def postprocess_message(self, msg):
        """Hook run after the site's value is computed; no-op by default."""
        pass

    def __call__(self, *args, **kwargs):
        """Run ``self.fn`` with this handler active and return its result."""
        with self:
            return self.fn(*args, **kwargs)
# A first useful example of an effect handler.
# trace records the inputs and outputs of any primitive site it encloses,
# and returns a dictionary containing that data to the user.
class trace(Messenger):
    """Effect handler that records a copy of every message it sees.

    Entering the handler starts a fresh ``OrderedDict`` keyed by site name
    and returns it, so ``with trace() as tr`` binds the record itself.
    """

    def __enter__(self):
        super().__enter__()
        self.trace = OrderedDict()
        return self.trace

    # trace illustrates why we need postprocess_message in addition to
    # process_message: the value is recorded only after every other handler
    # has had the chance to modify the message.
    def postprocess_message(self, msg):
        site_name = msg["name"]
        duplicate_sample = msg["type"] == "sample" and site_name in self.trace
        assert not duplicate_sample, "sample sites must have unique names"
        self.trace[site_name] = msg.copy()

    def get_trace(self, *args, **kwargs):
        """Run the wrapped function with these arguments and return the record."""
        self(*args, **kwargs)
        return self.trace
# A second example of an effect handler for setting the value at a sample site.
# This illustrates why effect handlers are a useful PPL implementation technique:
# We can compose trace and replay to replace values but preserve distributions,
# allowing us to compute the joint probability density of samples under a model.
# See the definition of elbo(...) below for an example of this pattern.
class replay(Messenger):
    """Effect handler that overrides site values with those of a given trace.

    :param fn: callable to run under the handler.
    :param guide_trace: mapping from site name to a message dict that has a
        ``"value"`` entry.
    """

    def __init__(self, fn, guide_trace):
        self.guide_trace = guide_trace
        super().__init__(fn)

    def process_message(self, msg):
        # Sites that are absent from the trace are left untouched.
        site_name = msg["name"]
        if site_name in self.guide_trace:
            msg["value"] = self.guide_trace[site_name]["value"]
# block allows the selective application of effect handlers to different parts of a model.
# Sites hidden by block will only have the handlers below block on the PYRO_STACK applied,
# allowing inference or other effectful computations to be nested inside models.
class block(Messenger):
    """Effect handler that hides selected sites from handlers above it.

    :param fn: optional callable to run under the handler.
    :param hide_fn: predicate on the message; when it returns a truthy value
        the message is marked with ``msg["stop"] = True``. The default hides
        every site.
    """

    def __init__(self, fn=None, hide_fn=lambda msg: True):
        self.hide_fn = hide_fn
        super().__init__(fn)

    def process_message(self, msg):
        should_hide = self.hide_fn(msg)
        if should_hide:
            # apply_stack stops walking the handler stack once this is set.
            msg["stop"] = True
# seed is used to fix the RNG state when calling a model.
class seed(Messenger):
    """Handler that runs code under a fixed torch RNG seed.

    On entry the current torch RNG state is saved and the RNG is seeded with
    ``rng_seed``; on exit the saved state is restored. This handler overrides
    ``__enter__``/``__exit__`` without calling the base versions, so it never
    places itself on ``PYRO_STACK``.

    :param fn: optional callable to run under the handler.
    :param rng_seed: value passed to ``torch.manual_seed`` on entry.
    """

    def __init__(self, fn=None, rng_seed=None):
        super().__init__(fn)
        self.rng_seed = rng_seed

    def __enter__(self):
        self.old_rng_state = torch.get_rng_state()
        torch.manual_seed(self.rng_seed)

    def __exit__(self, type, value, traceback):
        torch.set_rng_state(self.old_rng_state)
# Conditional independence is recorded as a plate context at each site.
# Each frame stores the plate's name, its size, and its (negative) dim;
# PlateMessenger stores frames in msg["cond_indep_stack"] keyed by dim.
CondIndepStackFrame = namedtuple("CondIndepStackFrame", ["name", "size", "dim"])
# This implementation of vectorized PlateMessenger broadcasts and
# records a cond_indep_stack which is later used to convert
# torch.Tensors to funsor.tensor.Tensors.
class PlateMessenger(Messenger):
    """Vectorized plate: registers its frame on every enclosed site.

    :param fn: optional callable to run under the handler.
    :param name: plate name, used as the funsor input name.
    :param size: plate size.
    :param dim: batch dimension of the plate; must be negative.
    """

    def __init__(self, fn, name, size, dim):
        assert dim < 0
        self.frame = CondIndepStackFrame(name, size, dim)
        super().__init__(fn)

    def process_message(self, msg):
        site_type = msg["type"]
        if site_type not in ("sample", "param"):
            return
        frame = self.frame
        stack = msg["cond_indep_stack"]
        # Two plates may not claim the same dim at one site.
        assert frame.dim not in stack
        stack[frame.dim] = frame
        if site_type == "sample":
            # Sample sites additionally get the plate as a batch input.
            msg["fn"] = msg["fn"].expand_inputs(frame.name, frame.size)
# This converts raw tensor.Tensors to funsor.Funsors with .inputs and .output
# based on information in msg["cond_indep_stack"] and msg["fn"].
def tensor_to_funsor(value, cond_indep_stack, output):
    """Wrap a ``torch.Tensor`` as a ``funsor.tensor.Tensor``.

    The trailing ``len(output.shape)`` dimensions of ``value`` are treated as
    the event shape and everything to their left as batch dimensions. Batch
    dimensions of size 1 are squeezed away; each remaining one becomes a
    named input taken from the plate frame registered for that dimension.

    :param value: the tensor to convert.
    :param cond_indep_stack: mapping from negative batch dim to
        ``CondIndepStackFrame``.
    :param output: funsor domain the result must have.
    :return: a ``funsor.tensor.Tensor`` whose ``.output`` equals ``output``.
    """
    assert isinstance(value, torch.Tensor)
    total_dims = value.dim()
    batch_shape = value.shape[: total_dims - len(output.shape)]
    if torch._C._get_tracing_state():
        # While jit tracing, convert the sizes to plain ints.
        with funsor.tensor.ignore_jit_warnings():
            batch_shape = tuple(map(int, batch_shape))
    n_batch = len(batch_shape)

    inputs = OrderedDict()
    data = value
    for position, size in enumerate(batch_shape):
        if size == 1:
            # Index from the right so that earlier squeezes do not shift it.
            data = data.squeeze(position - total_dims)
            continue
        frame = cond_indep_stack[position - n_batch]
        assert size == frame.size, (size, frame)
        inputs[frame.name] = funsor.Bint[int(size)]

    value = funsor.tensor.Tensor(data, inputs, output.dtype)
    assert value.output == output
    return value
# The log_joint messenger is the main way of recording log probabilities.
# This is roughly the Funsor equivalent to pyro.poutine.trace.
class log_joint(Messenger):
    """Effect handler that collects one log-probability factor per sample site.

    After the ``with`` block, ``log_factors`` maps site name to the site's
    log_prob and ``plates`` holds the names of all plates seen at those sites.
    """

    def __enter__(self):
        super().__enter__()
        self.log_factors = OrderedDict()  # maps site name to log_prob factor
        self.plates = set()
        return self

    def process_message(self, msg):
        if msg["type"] != "sample" or msg["value"] is not None:
            return
        # Create a delayed sample: a free variable named after the site.
        msg["value"] = funsor.Variable(msg["name"], msg["fn"].output)

    def postprocess_message(self, msg):
        if msg["type"] != "sample":
            return
        site_name = msg["name"]
        assert site_name not in self.log_factors, "all sites must have unique names"
        self.log_factors[site_name] = msg["fn"].log_prob(msg["value"])
        for frame in msg["cond_indep_stack"].values():
            self.plates.add(frame.name)
# apply_stack is called by pyro.sample and pyro.param.
# It is responsible for applying each Messenger to each effectful operation.
def apply_stack(msg):
    """Run every active handler on ``msg`` and compute the site's value.

    ``process_message`` is called on handlers from the end of ``PYRO_STACK``
    backwards until one sets ``msg["stop"]``. If no value has been set by
    then, ``msg["fn"](*msg["args"])`` supplies it, and a raw ``torch.Tensor``
    value is converted with ``tensor_to_funsor``. Finally
    ``postprocess_message`` is called on exactly the handlers that processed
    the message, in stack order.

    :param msg: message dict with at least the keys ``"value"``, ``"fn"`` and
        ``"args"``; ``"cond_indep_stack"`` and ``"output"`` are read when the
        value is a raw tensor.
    :return: the same ``msg`` dict, mutated in place.
    """
    # -1 means "no handler ran"; this keeps the function usable when
    # PYRO_STACK is empty instead of failing on an unbound loop variable.
    pointer = -1
    for pointer, handler in enumerate(reversed(PYRO_STACK)):
        handler.process_message(msg)
        # When a Messenger sets the "stop" field of a message,
        # it prevents any Messengers above it on the stack from being applied.
        if msg.get("stop"):
            break

    if msg["value"] is None:
        msg["value"] = msg["fn"](*msg["args"])
    if isinstance(msg["value"], torch.Tensor):
        msg["value"] = tensor_to_funsor(
            msg["value"], msg["cond_indep_stack"], msg["output"]
        )

    # A Messenger that sets msg["stop"] == True also prevents application
    # of postprocess_message by Messengers above it on the stack:
    # only the last (pointer + 1) handlers are visited here.
    first_applied = len(PYRO_STACK) - pointer - 1
    for handler in PYRO_STACK[first_applied:]:
        handler.postprocess_message(msg)
    return msg
# sample is an effectful version of Distribution.sample(...)
# When any effect handlers are active, it constructs an initial message and calls apply_stack.
def sample(name, fn, obs=None, infer=None):
    """Effectful sampling primitive.

    :param name: site name.
    :param fn: a funsor distribution; it is wrapped in ``Distribution``.
    :param obs: optional observed value used as the site's initial value.
    :param infer: optional dict stored under the message's ``"infer"`` key.
    :return: a draw from the distribution when no handlers are active;
        otherwise the value produced by the handler stack, which is asserted
        to be a ``funsor.Funsor``.
    """
    # Wrap the funsor distribution in a Pyro-compatible way.
    dist = Distribution(fn)

    # With no active Messengers, just draw a sample and return it.
    if not PYRO_STACK:
        return dist()

    # Otherwise build the initial message and send it through the handlers.
    msg = apply_stack(
        dict(
            type="sample",
            name=name,
            fn=dist,
            args=(),
            value=obs,
            cond_indep_stack={},  # maps dim to CondIndepStackFrame
            output=dist.output,
            infer={} if infer is None else infer,
        )
    )
    result = msg["value"]
    assert isinstance(result, funsor.Funsor)
    return result
# param is an effectful version of PARAM_STORE.setdefault that also handles constraints.
# When any effect handlers are active, it constructs an initial message and calls apply_stack.
def param(
    name,
    init_value=None,
    constraint=torch.distributions.constraints.real,
    event_dim=None,
):
    """Effectful read of a parameter from ``PARAM_STORE``, creating it if needed.

    :param name: parameter name.
    :param init_value: constrained initial tensor; required the first time a
        name is used.
    :param constraint: constraint used when the parameter is first created;
        once the name is stored, the stored constraint is used instead.
    :param event_dim: number of rightmost dims of ``init_value`` that form the
        event shape; defaults to all of them.
    :return: the constrained value as a ``funsor.Funsor``.
    """
    # This dict is shared with the message, so plates that register
    # themselves on the message are also visible to the metadata saved below.
    cond_indep_stack = {}
    output = None
    if init_value is not None:
        total_dims = init_value.dim()
        n_event_dims = total_dims if event_dim is None else event_dim
        output = funsor.Reals[init_value.shape[total_dims - n_event_dims :]]

    def fn(init_value, constraint):
        if name in PARAM_STORE:
            unconstrained_value, constraint = PARAM_STORE[name]
        else:
            # First use: map the constrained initial value to unconstrained
            # space and register it as a leaf that requires grad.
            assert init_value is not None
            with torch.no_grad():
                detached_init = init_value.detach()
                to_constrained = torch.distributions.transform_to(constraint)
                unconstrained_value = to_constrained.inv(detached_init)
            unconstrained_value.requires_grad_()
            unconstrained_value._funsor_metadata = (cond_indep_stack, output)
            PARAM_STORE[name] = (unconstrained_value, constraint)

        # Transform from unconstrained space to constrained space.
        constrained_value = torch.distributions.transform_to(constraint)(
            unconstrained_value
        )
        # Weak back-reference so callers can recover the stored tensor.
        constrained_value.unconstrained = weakref.ref(unconstrained_value)
        return tensor_to_funsor(
            constrained_value, *unconstrained_value._funsor_metadata
        )

    # With no active Messengers, read (or create) the parameter directly.
    if not PYRO_STACK:
        return fn(init_value, constraint)

    # Otherwise build the initial message and send it through the handlers.
    msg = apply_stack(
        dict(
            type="param",
            name=name,
            fn=fn,
            args=(init_value, constraint),
            value=None,
            cond_indep_stack=cond_indep_stack,  # maps dim to CondIndepStackFrame
            output=output,
        )
    )
    result = msg["value"]
    assert isinstance(result, funsor.Funsor)
    return result
# boilerplate to match the syntax of actual pyro.plate:
def plate(name, size, dim):
    """Return a ``PlateMessenger`` for ``name``/``size``/``dim`` with no wrapped fn."""
    return PlateMessenger(None, name, size, dim)
# This is a thin wrapper around the `torch.optim.Optimizer` class that
# dynamically generates optimizers for dynamically generated parameters.
# See http://docs.pyro.ai/en/0.3.1/optimization.html
class PyroOptim(object):
    """Lazily creates one torch optimizer per parameter and steps it.

    Subclasses set the class attribute ``TorchOptimizer`` to the optimizer
    class to instantiate.

    :param optim_args: keyword arguments passed to ``TorchOptimizer`` each
        time an optimizer is constructed.
    """

    def __init__(self, optim_args):
        self.optim_args = optim_args
        # Each parameter gets its own optimizer; this dict is keyed on the
        # parameter objects themselves.
        self.optim_objs = {}

    def __call__(self, params):
        """Take one optimizer step for every parameter in ``params``."""
        for tensor in params:
            optimizer = self.optim_objs.get(tensor)
            if optimizer is None:
                # First time this parameter is seen: build its optimizer.
                optimizer = self.TorchOptimizer([tensor], **self.optim_args)
                self.optim_objs[tensor] = optimizer
            # Take a gradient step for this parameter.
            optimizer.step()
# We wrap some commonly used PyTorch optimizers.
class Adam(PyroOptim):
    """``PyroOptim`` that builds ``torch.optim.Adam`` optimizers."""

    TorchOptimizer = torch.optim.Adam
class ClippedAdam(PyroOptim):
    """``PyroOptim`` that builds Pyro's ``ClippedAdam`` optimizers."""

    TorchOptimizer = _ClippedAdam
# This is a unified interface for stochastic variational inference in Pyro.
# The actual construction of the loss is taken care of by `loss`.
# See http://docs.pyro.ai/en/0.3.1/inference_algos.html
class SVI(object):
    """Unified interface for stochastic variational inference.

    :param model: the model callable.
    :param guide: the guide callable.
    :param optim: callable that takes a list of parameters and steps them
        (for example a ``PyroOptim`` instance).
    :param loss: callable ``loss(model, guide, *args, **kwargs)`` returning
        the loss as a funsor.
    """

    def __init__(self, model, guide, optim, loss):
        self.model = model
        self.guide = guide
        self.optim = optim
        self.loss = loss

    # This method handles running the model and guide, constructing the loss
    # function, and taking a gradient step.
    def step(self, *args, **kwargs):
        """Run one optimization step and return the loss via ``loss.item()``.

        ``args`` and ``kwargs`` are forwarded to the loss after the model and
        guide.
        """
        # Wrap the loss evaluation in a trace to record every parameter that
        # is encountered; block hides all non-param sites from that trace.
        # Further tracing occurs inside of `loss`.
        with trace() as param_capture:
            with block(hide_fn=lambda msg: msg["type"] != "param"):
                loss = self.loss(self.model, self.guide, *args, **kwargs)
        # Differentiate the loss.
        funsor.to_data(loss).backward()
        # Recover the unconstrained tensors behind the recorded parameters.
        params = [site["value"].data.unconstrained() for site in param_capture.values()]
        # Take a step w.r.t. each parameter in params.
        self.optim(params)
        # Zero out the gradients so that they don't accumulate. A parameter
        # that did not contribute to the loss has no gradient (None) and is
        # skipped rather than passed to zeros_like, which would raise.
        for p in params:
            if p.grad is not None:
                p.grad = torch.zeros_like(p.grad)
        return loss.item()
# TODO(eb8680) Replace this with funsor.Expectation.
def Expectation(log_probs, costs, sum_vars, prod_vars):
    """Sum, over all costs, the integral of each cost against the log_probs.

    For each cost, the ``log_probs`` factors are combined with
    ``sum_product``, eliminating every sum/product variable that the cost
    does not depend on. The cost is then integrated against the result over
    the sum variables it depends on and added up over the plates it depends on.

    :param log_probs: factors passed to ``funsor.sum_product.sum_product``.
    :param costs: iterable of cost funsors.
    :param sum_vars: frozenset of variable names to integrate out.
    :param prod_vars: frozenset of plate names.
    :return: the accumulated total; the int ``0`` when ``costs`` is empty.
    """
    result = 0
    for cost in costs:
        cost_vars = frozenset(cost.inputs)
        log_prob = funsor.sum_product.sum_product(
            sum_op=funsor.ops.logaddexp,
            prod_op=funsor.ops.add,
            factors=log_probs,
            plates=prod_vars,
            eliminate=(prod_vars | sum_vars) - cost_vars,
        )
        integral = funsor.Integrate(log_prob, cost, sum_vars & cost_vars)
        result += integral.reduce(funsor.ops.add, prod_vars & cost_vars)
    return result
# This is a basic implementation of the Evidence Lower Bound, which is the
# fundamental objective in Variational Inference.
# See http://pyro.ai/examples/svi_part_i.html for details.
# This implementation uses a Dice estimator similar to TraceEnum_ELBO.
def elbo(model, guide, *args, **kwargs):
    """Return the negative ELBO of ``model`` and ``guide`` as a funsor.

    The guide and then the model are each run under their own ``log_joint``
    handler. Variables that appear in only one of them (auxiliary variables)
    are contracted out, and the expected cost is computed with
    ``Expectation``. This uses a Dice estimator similar to TraceEnum_ELBO.

    :param model: model callable, invoked as ``model(*args, **kwargs)``.
    :param guide: guide callable, invoked as ``guide(*args, **kwargs)``.
    :return: the loss (``-elbo``), asserted to have no free inputs.
    """
    with log_joint() as guide_log_joint:
        guide(*args, **kwargs)
    with log_joint() as model_log_joint:
        model(*args, **kwargs)

    def contract_aux_vars(own, other):
        # Auxiliary variables are inputs of own's factors that are neither
        # plates of own nor sample sites of other.
        factors = list(own.log_factors.values())
        aux_vars = (
            frozenset().union(*(f.inputs for f in factors))
            - frozenset(own.plates)
            - frozenset(other.log_factors)
        )
        if aux_vars:
            factors = funsor.sum_product.partial_sum_product(
                funsor.ops.logaddexp,
                funsor.ops.add,
                factors,
                plates=frozenset(own.plates),
                eliminate=aux_vars,
            )
        return factors, aux_vars

    # contract out auxiliary variables in the guide, then in the model
    guide_log_probs, guide_aux_vars = contract_aux_vars(
        guide_log_joint, model_log_joint
    )
    model_log_probs, model_aux_vars = contract_aux_vars(
        model_log_joint, guide_log_joint
    )

    # compute remaining plates and sum_dims
    model_plates = frozenset().union(
        *(model_log_joint.plates.intersection(f.inputs) for f in model_log_probs)
    )
    guide_plates = frozenset().union(
        *(guide_log_joint.plates.intersection(f.inputs) for f in guide_log_probs)
    )
    plates = model_plates | guide_plates
    all_sites = frozenset().union(
        model_log_joint.log_factors, guide_log_joint.log_factors
    )
    sum_vars = all_sites - frozenset(model_aux_vars | guide_aux_vars)

    # Accumulate costs from model and guide and log_probs from guide.
    # Cf. pyro.infer.traceenum_elbo._compute_dice_elbo()
    # https://github.com/pyro-ppl/pyro/blob/0.3.0/pyro/infer/traceenum_elbo.py#L119
    costs = list(model_log_probs)
    costs.extend(-q for q in guide_log_probs)
    log_probs = list(guide_log_probs)

    # Compute expected cost.
    # Cf. pyro.infer.util.Dice.compute_expectation()
    # https://github.com/pyro-ppl/pyro/blob/0.3.0/pyro/infer/util.py#L212
    elbo_value = Expectation(
        tuple(log_probs), tuple(costs), sum_vars=sum_vars, prod_vars=plates
    )

    loss = -elbo_value
    assert not loss.inputs
    return loss
# Base class for elbo implementations.
class ELBO(object):
    """Base class for elbo implementations.

    Keyword arguments given to the constructor are stored in ``self.options``.
    Calling an instance evaluates ``elbo(model, guide, *args, **kwargs)``.
    """

    def __init__(self, **kwargs):
        self.options = kwargs

    def __call__(self, model, guide, *args, **kwargs):
        loss = elbo(model, guide, *args, **kwargs)
        return loss
# This is a wrapper for compatibility with full Pyro.
class Trace_ELBO(ELBO):
    """ELBO evaluated inside a ``funsor.montecarlo.MonteCarlo()`` context."""

    def __call__(self, model, guide, *args, **kwargs):
        with funsor.montecarlo.MonteCarlo():
            loss = elbo(model, guide, *args, **kwargs)
        return loss
class TraceMeanField_ELBO(ELBO):
    """Compatibility alias that currently behaves exactly like ``ELBO``."""

    # TODO Use exact KLs where possible.
class TraceEnum_ELBO(ELBO):
    """ELBO with an optional funsor optimizer pass.

    When the ``optimize`` option is truthy, the elbo expression is built
    under ``funsor.optimizer.optimize`` and then evaluated with
    ``funsor.reinterpret``; otherwise ``elbo`` is called directly.
    """

    # TODO allow mixing of sampling and exact integration
    def __call__(self, model, guide, *args, **kwargs):
        if not self.options.get("optimize", None):
            return elbo(model, guide, *args, **kwargs)
        with funsor.optimizer.optimize:
            elbo_expr = elbo(model, guide, *args, **kwargs)
        # Reinterpret outside the optimize context.
        return funsor.reinterpret(elbo_expr)
# This is a PyTorch jit wrapper that (1) delays tracing until the first
# invocation, and (2) registers pyro.param() statements with torch.jit.trace.
# This version does not support variable number of args or non-tensor kwargs.
class Jit(object):
    """``torch.jit.trace`` wrapper that delays tracing until the first call.

    Parameters read through ``param()`` are discovered on the first call and
    passed to the traced function as explicit leading tensor arguments.
    This version does not support a variable number of args or kwargs.

    :param fn: callable returning a funsor with no inputs and ``Real`` output.
    :param ignore_jit_warnings: keyword option; when truthy,
        ``torch.jit.TracerWarning`` is filtered out while tracing.
    """

    def __init__(self, fn, **kwargs):
        self.fn = fn
        self.ignore_jit_warnings = kwargs.get("ignore_jit_warnings", False)
        self._compiled = None
        self._param_trace = None

    def __call__(self, *args):
        # On first call, initialize params and save their names. The outer
        # block hides this run from enclosing handlers; the inner block
        # leaves only param sites visible to the trace.
        if self._param_trace is None:
            with block(), trace() as tr, block(hide_fn=lambda m: m["type"] != "param"):
                self.fn(*args)
            self._param_trace = tr

        # Augment args with reads from the global param store.
        leaves = tuple(
            param(site_name).data.unconstrained() for site_name in self._param_trace
        )
        params_and_args = leaves + args

        # On first call, create a compiled elbo.
        if self._compiled is None:

            def compiled(*params_and_args):
                n_params = len(self._param_trace)
                unconstrained_params = params_and_args[:n_params]
                args = params_and_args[n_params:]
                for site_name, unconstrained_param in zip(
                    self._param_trace, unconstrained_params
                ):
                    constrained_param = param(site_name)  # assume param has been initialized
                    assert constrained_param.data.unconstrained() is unconstrained_param
                    self._param_trace[site_name]["value"] = constrained_param
                # Replay the recorded params so fn reads the values set above.
                result = replay(self.fn, guide_trace=self._param_trace)(*args)
                assert not result.inputs
                assert result.output == funsor.Real
                return funsor.to_data(result)

            with validation_enabled(False), warnings.catch_warnings():
                if self.ignore_jit_warnings:
                    warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
                self._compiled = torch.jit.trace(
                    compiled, params_and_args, check_trace=False
                )

        data = self._compiled(*params_and_args)
        return funsor.tensor.Tensor(data)
# This is a jit wrapper for ELBO implementations.
class Jit_ELBO(ELBO):
    """Jit wrapper for ELBO implementations.

    :param elbo: an ELBO class; it is instantiated with ``**kwargs``.
    :param kwargs: options stored on this object, passed to the wrapped ELBO
        and to every ``Jit`` instance created.
    """

    def __init__(self, elbo, **kwargs):
        super().__init__(**kwargs)
        self._elbo = elbo(**kwargs)
        self._compiled = {}  # maps (model,guide) -> Jit instances

    def __call__(self, model, guide, *args):
        key = (model, guide)
        if key not in self._compiled:
            # One Jit instance is built per (model, guide) pair and reused.
            bound_elbo = functools.partial(self._elbo, model, guide)
            self._compiled[key] = Jit(bound_elbo, **self.options)
        return self._compiled[key](*args)
def JitTrace_ELBO(**kwargs):
    """Return a ``Jit_ELBO`` wrapping ``Trace_ELBO`` built with ``kwargs``."""
    return Jit_ELBO(Trace_ELBO, **kwargs)
def JitTraceMeanField_ELBO(**kwargs):
    """Return a ``Jit_ELBO`` wrapping ``TraceMeanField_ELBO`` built with ``kwargs``."""
    return Jit_ELBO(TraceMeanField_ELBO, **kwargs)
def JitTraceEnum_ELBO(**kwargs):
    """Return a ``Jit_ELBO`` wrapping ``TraceEnum_ELBO`` built with ``kwargs``."""
    return Jit_ELBO(TraceEnum_ELBO, **kwargs)