forked from pyro-ppl/pyro
-
Notifications
You must be signed in to change notification settings - Fork 0
/
strategies.py
216 lines (167 loc) · 7.31 KB
/
strategies.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0
"""
These reparametrization strategies are registered with
:func:`~pyro.poutine.reparam_messenger.register_reparam_strategy` and are
accessed by name via ``poutine.reparam(config=name_of_strategy)`` .
See :func:`~pyro.poutine.handlers.reparam` for usage.
"""
from abc import ABC, abstractmethod
from typing import Callable, Dict, Optional, Union
import torch
from torch.distributions import constraints
import pyro.distributions as dist
import pyro.poutine as poutine
from .loc_scale import LocScaleReparam
from .projected_normal import ProjectedNormalReparam
from .reparam import Reparam
from .softmax import GumbelSoftmaxReparam
from .stable import LatentStableReparam, StableReparam, SymmetricStableReparam
from .transform import TransformReparam
class Strategy(ABC):
    """
    Abstract base class for reparametrizer configuration strategies.

    Subclasses must implement :meth:`configure`, which decides for a single
    sample site whether (and how) it should be reparametrized.

    :ivar dict config: A dictionary configuration mapping site names to either
        a :class:`~pyro.infer.reparam.reparam.Reparam` instance or None. It is
        filled in lazily the first time the model is run; afterwards it can be
        passed as ``poutine.reparam(config=___)``.
    """

    # TODO(https://github.com/pyro-ppl/pyro/issues/2831) As part of refactoring
    # Reparam objects to be nn.Modules:
    # - make Strategy inherit from torch.nn.Module
    # - make self.config a torch.nn.ModuleDict

    def __init__(self):
        # TODO(#2831) Make this a torch.nn.ModuleDict.
        self.config: Dict[str, Optional[Reparam]] = {}
        super().__init__()

    @abstractmethod
    def configure(self, msg: dict) -> Optional[Reparam]:
        """
        Decide how to reparametrize one sample site.

        Called at most once per site name: the result (including None) is
        cached in ``self.config`` and reused on later model executions.

        :param dict msg: A sample site to possibly reparametrize.
        :returns: A :class:`~pyro.infer.reparam.reparam.Reparam` instance,
            or None to leave the site untouched.
        """
        raise NotImplementedError

    def __call__(self, msg_or_fn: Union[dict, Callable]):
        """
        Reparametrize a model, or (internally) configure a single site.

        Strategies can be used as decorators to reparametrize a model.

        :param msg_or_fn: Public use: a model to be decorated. (Internal use: a
            site to be configured for reparametrization).
        """
        if not isinstance(msg_or_fn, dict):
            # Public use as a decorator or handler.
            return poutine.reparam(msg_or_fn, self)

        # Internal use during configuration: memoize per site name, so that
        # configure() runs only on the first execution of each site.
        site_name = msg_or_fn["name"]
        if site_name not in self.config:
            self.config[site_name] = self.configure(msg_or_fn)
        return self.config[site_name]
class MinimalReparam(Strategy):
    """
    Minimal reparametrization strategy that reparametrizes only those sites
    that would otherwise lead to error, e.g.
    :class:`~pyro.distributions.Stable` and
    :class:`~pyro.distributions.ProjectedNormal` random variables.

    Example::

        @MinimalReparam()
        def model(...):
            ...

    which is equivalent to::

        @poutine.reparam(config=MinimalReparam())
        def model(...):
            ...
    """

    def configure(self, msg: dict) -> Optional[Reparam]:
        # Whether the site is observed matters (e.g. for Stable sites), so
        # both the distribution and the observed flag are forwarded.
        site_fn = msg["fn"]
        observed = msg["is_observed"]
        return _minimal_reparam(site_fn, observed)
def _minimal_reparam(fn, is_observed):
    """
    Return the minimal reparametrizer required by ``fn``, or None.

    Wrappers whose ``base_dist`` is a distribution (Independent, Masked, ...)
    are unwrapped first. A TransformedDistribution receives a
    :class:`TransformReparam` only if its base distribution would itself
    need a minimal reparametrizer; otherwise None is returned for it.

    :param fn: The site's distribution.
    :param bool is_observed: Whether the site is observed.
    :returns: A reparametrizer instance, or None if none is needed.
    """
    # Unwrap Independent, Masked, Transformed etc.
    while isinstance(getattr(fn, "base_dist", None), dist.Distribution):
        if isinstance(fn, torch.distributions.TransformedDistribution):
            inner = _minimal_reparam(fn.base_dist, is_observed)
            # Reparametrize only if needed; new sites are then reparametrized.
            return None if inner is None else TransformReparam()
        fn = fn.base_dist

    if isinstance(fn, dist.Stable):
        if not is_observed:
            return LatentStableReparam()
        if fn.skew.requires_grad or fn.skew.any():
            return StableReparam()
        return SymmetricStableReparam()

    if isinstance(fn, dist.ProjectedNormal):
        return ProjectedNormalReparam()

    # TODO apply CircularReparam for VonMises
    return None
class AutoReparam(Strategy):
    """
    Applies a recommended set of reparametrizers. These currently include:
    :class:`MinimalReparam`,
    :class:`~pyro.infer.reparam.transform.TransformReparam`, a fully-learnable
    :class:`~pyro.infer.reparam.loc_scale.LocScaleReparam`, and
    :class:`~pyro.infer.reparam.softmax.GumbelSoftmaxReparam`.

    Example::

        @AutoReparam()
        def model(...):
            ...

    which is equivalent to::

        @poutine.reparam(config=AutoReparam())
        def model(...):
            ...

    .. warning:: This strategy may change behavior across Pyro releases.
        To inspect or save a given behavior, extract the ``.config`` dict after
        running the model at least once.

    :param centered: Optional centering parameter for
        :class:`~pyro.infer.reparam.loc_scale.LocScaleReparam` reparametrizers.
        If None (default), centering will be learned. If a number in
        ``[0.0, 1.0]``, then a fixed centering (ints such as ``0`` or ``1``
        are accepted and converted to float). To completely decenter (e.g. in
        MCMC), set to 0.0.
    :raises TypeError: If ``centered`` is neither None nor a real number.
    :raises ValueError: If ``centered`` lies outside ``[0.0, 1.0]``.
    """

    def __init__(self, *, centered: Optional[float] = None):
        if centered is not None:
            # Validate explicitly rather than via ``assert``, which is removed
            # under ``python -O``. bool is a subclass of int but is almost
            # certainly a caller mistake here, so it is rejected.
            if isinstance(centered, bool) or not isinstance(centered, (int, float)):
                raise TypeError(
                    "Expected centered to be None or a float, "
                    f"but got {type(centered).__name__}"
                )
            centered = float(centered)
            # Chained comparison is also False for NaN, so NaN is rejected.
            if not 0.0 <= centered <= 1.0:
                raise ValueError(
                    f"Expected centered to lie in [0.0, 1.0], but got {centered}"
                )
        super().__init__()
        # None => learned centering; float => fixed centering.
        self.centered = centered

    def configure(self, msg: dict) -> Optional[Reparam]:
        """
        Choose a reparametrizer for a single sample site.

        For latent sites, the candidates are tried in this order:

        - TransformReparam, for the first TransformedDistribution found while
          unwrapping;
        - GumbelSoftmaxReparam, for RelaxedOneHotCategorical;
        - a learnable LocScaleReparam, when applicable.

        Every site then falls back to the minimal strategy.

        :param dict msg: A sample site to possibly reparametrize.
        :returns: An optional reparametrizer instance.
        """
        # Focus on tricks for latent sites.
        fn = msg["fn"]
        if not msg["is_observed"]:
            # Unwrap Independent, Masked, Transformed etc.
            while isinstance(getattr(fn, "base_dist", None), dist.Distribution):
                if isinstance(fn, torch.distributions.TransformedDistribution):
                    return TransformReparam()  # Then reparametrize new sites.
                fn = fn.base_dist

            # Try to apply a GumbelSoftmaxReparam.
            if isinstance(fn, torch.distributions.RelaxedOneHotCategorical):
                return GumbelSoftmaxReparam()

            # Apply a learnable LocScaleReparam.
            result = _loc_scale_reparam(msg["name"], fn, self.centered)
            if result is not None:
                return result

        # Apply minimal reparametrizers (observed sites go straight here).
        return _minimal_reparam(fn, msg["is_observed"])
def _loc_scale_reparam(name, fn, centered):
if "_decentered" in name:
return # Avoid infinite recursion.
# Check for location-scale families.
params = set(fn.arg_constraints)
if not {"loc", "scale"}.issubset(params):
return
# Check for unconstrained support.
if not _is_unconstrained(fn.support):
return
# TODO reparametrize only if parameters are variable. We might guess
# based on whether parameters are differentiable, .requires_grad. See
# https://github.com/pyro-ppl/pyro/pull/2824
# Create an elementwise-learnable reparametrizer.
shape_params = sorted(params - {"loc", "scale"})
return LocScaleReparam(centered=centered, shape_params=shape_params)
def _is_unconstrained(constraint):
# Unwrap constraints.independent.
while hasattr(constraint, "base_constraint"):
constraint = constraint.base_constraint
return constraint == constraints.real