Skip to content

Commit 273ef76

Browse files
committed
Merge branch 'feat/beartype-claw-cleanup' into feat/beartype-claw-extend
2 parents da3d7de + bfc7c35 commit 273ef76

6 files changed

Lines changed: 1 addition & 39 deletions

File tree

src/lcm/_beartype_conf.py

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,7 @@
66
77
"""
88

9-
from collections.abc import Callable
10-
11-
from beartype import BeartypeConf, BeartypeStrategy, beartype
9+
from beartype import BeartypeConf, BeartypeStrategy
1210

1311
from lcm.exceptions import (
1412
CategoricalDefinitionError,
@@ -34,24 +32,6 @@ def _conf(exc: type[Exception]) -> BeartypeConf:
3432
)
3533

3634

37-
def beartype_init[C](conf: BeartypeConf) -> Callable[[type[C]], type[C]]:
38-
"""Class decorator that beartype-checks `__init__` only.
39-
40-
Bare `@beartype` on a class wraps every method, which surfaces
41-
annotation drift in helpers like `compute_gridpoints(**kwargs: float)`
42-
where runtime kwargs are actually JAX arrays. Restricting decoration
43-
to `__init__` keeps the perimeter check (parameter types at
44-
construction) without policing every method's runtime types.
45-
46-
"""
47-
48-
def deco(cls: type[C]) -> type[C]:
49-
cls.__init__ = beartype(conf=conf)(cls.__init__) # ty: ignore[invalid-assignment]
50-
return cls
51-
52-
return deco
53-
54-
5535
# Used on `Regime` and `MarkovTransition`.
5636
REGIME_CONF = _conf(RegimeInitializationError)
5737

src/lcm/grids/continuous.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
import jax.numpy as jnp
77

8-
from lcm._beartype_conf import GRID_CONF, beartype_init
98
from lcm.dtypes import canonical_float_dtype
109
from lcm.exceptions import GridInitializationError, format_messages
1110
from lcm.grids import coordinates as grid_coordinates
@@ -100,7 +99,6 @@ def replace(self, **kwargs: float) -> UniformContinuousGrid:
10099
) from e
101100

102101

103-
@beartype_init(GRID_CONF)
104102
class LinSpacedGrid(UniformContinuousGrid):
105103
"""A linearly spaced grid of continuous values.
106104
@@ -126,7 +124,6 @@ def get_coordinate(self, value: FloatND) -> FloatND:
126124
)
127125

128126

129-
@beartype_init(GRID_CONF)
130127
class LogSpacedGrid(UniformContinuousGrid):
131128
"""A logarithmically spaced grid of continuous values.
132129
@@ -209,7 +206,6 @@ def _init_uniform_grid(
209206
object.__setattr__(grid, "distributed", distributed)
210207

211208

212-
@beartype_init(GRID_CONF)
213209
@dataclass(frozen=True, kw_only=True, init=False)
214210
class IrregSpacedGrid(ContinuousGrid):
215211
"""A grid of continuous values at irregular (user-specified) points.

src/lcm/grids/discrete.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
import jax.numpy as jnp
22

3-
from lcm._beartype_conf import GRID_CONF, beartype_init
43
from lcm.grids.base import Grid
54
from lcm.grids.categorical import _validate_discrete_grid
65
from lcm.typing import Int1D
76
from lcm.utils.containers import get_field_names_and_values
87

98

10-
@beartype_init(GRID_CONF)
119
class DiscreteGrid(Grid):
1210
"""A discrete grid defining the outcome space of a categorical variable.
1311

src/lcm/grids/piecewise.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import jax.numpy as jnp
55
import portion
66

7-
from lcm._beartype_conf import GRID_CONF, beartype_init
87
from lcm.exceptions import GridInitializationError, format_messages
98
from lcm.grids import coordinates as grid_coordinates
109
from lcm.grids.continuous import ContinuousGrid
@@ -44,7 +43,6 @@ def __init__(
4443
object.__setattr__(self, "n_points", jnp.int32(n_points))
4544

4645

47-
@beartype_init(GRID_CONF)
4846
@dataclass(frozen=True, kw_only=True)
4947
class PiecewiseLinSpacedGrid(ContinuousGrid):
5048
"""A piecewise linearly spaced grid with multiple segments.
@@ -108,7 +106,6 @@ def get_coordinate(self, value: FloatND) -> FloatND:
108106
return self._cumulative_offsets[piece_idx] + local_coord
109107

110108

111-
@beartype_init(GRID_CONF)
112109
@dataclass(frozen=True, kw_only=True)
113110
class PiecewiseLogSpacedGrid(ContinuousGrid):
114111
"""A piecewise logarithmically spaced grid with multiple segments.

src/lcm/shocks/ar1.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import jax.numpy as jnp
88
from jax.scipy.stats.norm import cdf
99

10-
from lcm._beartype_conf import GRID_CONF, beartype_init
1110
from lcm.shocks._base import (
1211
_gauss_hermite_normal,
1312
_mixture_cdf,
@@ -30,7 +29,6 @@ def draw_shock(
3029
) -> ScalarFloat: ...
3130

3231

33-
@beartype_init(GRID_CONF)
3432
@dataclass(frozen=True, kw_only=True)
3533
class Tauchen(_ShockGridAR1):
3634
r"""AR(1) shock discretized via Tauchen (1986).
@@ -130,7 +128,6 @@ def draw_shock(
130128
)
131129

132130

133-
@beartype_init(GRID_CONF)
134131
@dataclass(frozen=True, kw_only=True)
135132
class Rouwenhorst(_ShockGridAR1):
136133
r"""AR(1) shock discretized via Rouwenhorst (1995).
@@ -199,7 +196,6 @@ def draw_shock(
199196
)
200197

201198

202-
@beartype_init(GRID_CONF)
203199
@dataclass(frozen=True, kw_only=True)
204200
class TauchenNormalMixture(_ShockGridAR1):
205201
r"""AR(1) shock with mixture-of-normals innovations, discretized via Tauchen.

src/lcm/shocks/iid.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import jax.numpy as jnp
77
from jax.scipy.stats.norm import cdf
88

9-
from lcm._beartype_conf import GRID_CONF, beartype_init
109
from lcm.shocks._base import (
1110
_gauss_hermite_normal,
1211
_mixture_cdf,
@@ -28,7 +27,6 @@ def draw_shock(
2827
) -> ScalarFloat: ...
2928

3029

31-
@beartype_init(GRID_CONF)
3230
@dataclass(frozen=True, kw_only=True)
3331
class Uniform(_ShockGridIID):
3432
r"""Discretized iid uniform shock: $U(\text{start}, \text{stop})$.
@@ -63,7 +61,6 @@ def draw_shock(
6361
)
6462

6563

66-
@beartype_init(GRID_CONF)
6764
@dataclass(frozen=True, kw_only=True)
6865
class Normal(_ShockGridIID):
6966
r"""Discretized iid normal shock: $N(\mu_\varepsilon, \sigma_\varepsilon^2)$.
@@ -138,7 +135,6 @@ def draw_shock(
138135
return params["mu"] + params["sigma"] * jax.random.normal(key=key)
139136

140137

141-
@beartype_init(GRID_CONF)
142138
@dataclass(frozen=True, kw_only=True)
143139
class LogNormal(_ShockGridIID):
144140
r"""Discretized iid log-normal shock: $\ln X \sim N(\mu, \sigma^2)$."""
@@ -204,7 +200,6 @@ def draw_shock(
204200
return jnp.exp(params["mu"] + params["sigma"] * jax.random.normal(key=key))
205201

206202

207-
@beartype_init(GRID_CONF)
208203
@dataclass(frozen=True, kw_only=True)
209204
class NormalMixture(_ShockGridIID):
210205
r"""Discretized IID normal-mixture shock.

0 commit comments

Comments
 (0)