Skip to content
2 changes: 1 addition & 1 deletion effectful/handlers/jax/_terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def jax_at_set(arr, index_key, val):


@defdata.register(jax.Array)
def _embed_array(op, *args, **kwargs):
def _embed_array(ty, op, *args, **kwargs):
if (
op is jax_getitem
and not isinstance(args[0], Term)
Expand Down
80 changes: 40 additions & 40 deletions effectful/handlers/numpyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def Cauchy(loc=0.0, scale=1.0, **kwargs) -> dist.Cauchy:

@defdata.register(dist.Cauchy)
class CauchyTerm(_DistributionTerm):
def __init__(self, op, loc, scale, **kwargs):
def __init__(self, ty, op, loc, scale, **kwargs):
super().__init__(dist.Cauchy, op, loc, scale, **kwargs)
self.loc = loc
self.scale = scale
Expand All @@ -413,7 +413,7 @@ def Gumbel(loc=0.0, scale=1.0, **kwargs) -> dist.Gumbel:

@defdata.register(dist.Gumbel)
class GumbelTerm(_DistributionTerm):
def __init__(self, op, loc, scale, **kwargs):
def __init__(self, ty, op, loc, scale, **kwargs):
super().__init__(dist.Gumbel, op, loc, scale, **kwargs)
self.loc = loc
self.scale = scale
Expand All @@ -431,7 +431,7 @@ def Laplace(loc=0.0, scale=1.0, **kwargs) -> dist.Laplace:

@defdata.register(dist.Laplace)
class LaplaceTerm(_DistributionTerm):
def __init__(self, op, loc, scale, **kwargs):
def __init__(self, ty, op, loc, scale, **kwargs):
super().__init__(dist.Laplace, op, loc, scale, **kwargs)
self.loc = loc
self.scale = scale
Expand All @@ -449,7 +449,7 @@ def LogNormal(loc=0.0, scale=1.0, **kwargs) -> dist.LogNormal:

@defdata.register(dist.LogNormal)
class LogNormalTerm(_DistributionTerm):
def __init__(self, op, loc, scale, **kwargs):
def __init__(self, ty, op, loc, scale, **kwargs):
super().__init__(dist.LogNormal, op, loc, scale, **kwargs)
self.loc = loc
self.scale = scale
Expand All @@ -467,7 +467,7 @@ def Logistic(loc=0.0, scale=1.0, **kwargs) -> dist.Logistic:

@defdata.register(dist.Logistic)
class LogisticTerm(_DistributionTerm):
def __init__(self, op, loc, scale, **kwargs):
def __init__(self, ty, op, loc, scale, **kwargs):
super().__init__(dist.Logistic, op, loc, scale, **kwargs)
self.loc = loc
self.scale = scale
Expand All @@ -485,7 +485,7 @@ def Normal(loc=0.0, scale=1.0, **kwargs) -> dist.Normal:

@defdata.register(dist.Normal)
class NormalTerm(_DistributionTerm):
def __init__(self, op, loc, scale, **kwargs):
def __init__(self, ty, op, loc, scale, **kwargs):
super().__init__(dist.Normal, op, loc, scale, **kwargs)
self.loc = loc
self.scale = scale
Expand All @@ -503,7 +503,7 @@ def StudentT(df, loc=0.0, scale=1.0, **kwargs) -> dist.StudentT:

@defdata.register(dist.StudentT)
class StudentTTerm(_DistributionTerm):
def __init__(self, op, df, loc, scale, **kwargs):
def __init__(self, ty, op, df, loc, scale, **kwargs):
super().__init__(dist.StudentT, op, df, loc, scale, **kwargs)
self.df = df
self.loc = loc
Expand All @@ -522,7 +522,7 @@ def BernoulliProbs(probs, **kwargs) -> dist.BernoulliProbs:

@defdata.register(dist.BernoulliProbs)
class BernoulliProbsTerm(_DistributionTerm):
def __init__(self, op, probs, **kwargs):
def __init__(self, ty, op, probs, **kwargs):
super().__init__(dist.BernoulliProbs, op, probs, **kwargs)
self.probs = probs

Expand All @@ -539,7 +539,7 @@ def CategoricalProbs(probs, **kwargs) -> dist.CategoricalProbs:

@defdata.register(dist.CategoricalProbs)
class CategoricalProbsTerm(_DistributionTerm):
def __init__(self, op, probs, **kwargs):
def __init__(self, ty, op, probs, **kwargs):
super().__init__(dist.CategoricalProbs, op, probs, **kwargs)
self.probs = probs

Expand All @@ -556,7 +556,7 @@ def GeometricProbs(probs, **kwargs) -> dist.GeometricProbs:

@defdata.register(dist.GeometricProbs)
class GeometricProbsTerm(_DistributionTerm):
def __init__(self, op, probs, **kwargs):
def __init__(self, ty, op, probs, **kwargs):
super().__init__(dist.GeometricProbs, op, probs, **kwargs)
self.probs = probs

Expand All @@ -573,7 +573,7 @@ def BernoulliLogits(logits, **kwargs) -> dist.BernoulliLogits:

@defdata.register(dist.BernoulliLogits)
class BernoulliLogitsTerm(_DistributionTerm):
def __init__(self, op, logits, **kwargs):
def __init__(self, ty, op, logits, **kwargs):
super().__init__(dist.BernoulliLogits, op, logits, **kwargs)
self.logits = logits

Expand All @@ -590,7 +590,7 @@ def CategoricalLogits(logits, **kwargs) -> dist.CategoricalLogits:

@defdata.register(dist.CategoricalLogits)
class CategoricalLogitsTerm(_DistributionTerm):
def __init__(self, op, logits, **kwargs):
def __init__(self, ty, op, logits, **kwargs):
super().__init__(dist.CategoricalLogits, op, logits, **kwargs)
self.logits = logits

Expand All @@ -607,7 +607,7 @@ def GeometricLogits(logits, **kwargs) -> dist.GeometricLogits:

@defdata.register(dist.GeometricLogits)
class GeometricLogitsTerm(_DistributionTerm):
def __init__(self, op, logits, **kwargs):
def __init__(self, ty, op, logits, **kwargs):
super().__init__(dist.GeometricLogits, op, logits, **kwargs)
self.logits = logits

Expand All @@ -624,7 +624,7 @@ def Beta(concentration1, concentration0, **kwargs) -> dist.Beta:

@defdata.register(dist.Beta)
class BetaTerm(_DistributionTerm):
def __init__(self, op, concentration1, concentration0, **kwargs):
def __init__(self, ty, op, concentration1, concentration0, **kwargs):
super().__init__(dist.Beta, op, concentration1, concentration0, **kwargs)
self.concentration1 = concentration1
self.concentration0 = concentration0
Expand All @@ -642,7 +642,7 @@ def Kumaraswamy(concentration1, concentration0, **kwargs) -> dist.Kumaraswamy:

@defdata.register(dist.Kumaraswamy)
class KumaraswamyTerm(_DistributionTerm):
def __init__(self, op, concentration1, concentration0, **kwargs):
def __init__(self, ty, op, concentration1, concentration0, **kwargs):
super().__init__(dist.Kumaraswamy, op, concentration1, concentration0, **kwargs)
self.concentration1 = concentration1
self.concentration0 = concentration0
Expand All @@ -660,7 +660,7 @@ def BinomialProbs(probs, total_count=1, **kwargs) -> dist.BinomialProbs:

@defdata.register(dist.BinomialProbs)
class BinomialProbsTerm(_DistributionTerm):
def __init__(self, op, probs, total_count, **kwargs):
def __init__(self, ty, op, probs, total_count, **kwargs):
super().__init__(dist.BinomialProbs, op, probs, total_count, **kwargs)
self.probs = probs
self.total_count = total_count
Expand All @@ -678,7 +678,7 @@ def NegativeBinomialProbs(total_count, probs, **kwargs) -> dist.NegativeBinomial

@defdata.register(dist.NegativeBinomialProbs)
class NegativeBinomialProbsTerm(_DistributionTerm):
def __init__(self, op, total_count, probs, **kwargs):
def __init__(self, ty, op, total_count, probs, **kwargs):
super().__init__(dist.NegativeBinomialProbs, op, total_count, probs, **kwargs)
self.total_count = total_count
self.probs = probs
Expand All @@ -698,7 +698,7 @@ def MultinomialProbs(probs, total_count=1, **kwargs) -> dist.MultinomialProbs:

@defdata.register(dist.MultinomialProbs)
class MultinomialProbsTerm(_DistributionTerm):
def __init__(self, op, probs, total_count, **kwargs):
def __init__(self, ty, op, probs, total_count, **kwargs):
super().__init__(dist.MultinomialProbs, op, probs, total_count, **kwargs)
self.probs = probs
self.total_count = total_count
Expand All @@ -716,7 +716,7 @@ def BinomialLogits(logits, total_count=1, **kwargs) -> dist.BinomialLogits:

@defdata.register(dist.BinomialLogits)
class BinomialLogitsTerm(_DistributionTerm):
def __init__(self, op, logits, total_count, **kwargs):
def __init__(self, ty, op, logits, total_count, **kwargs):
super().__init__(dist.BinomialLogits, op, logits, total_count, **kwargs)
self.logits = logits
self.total_count = total_count
Expand All @@ -736,7 +736,7 @@ def NegativeBinomialLogits(

@defdata.register(dist.NegativeBinomialLogits)
class NegativeBinomialLogitsTerm(_DistributionTerm):
def __init__(self, op, total_count, logits, **kwargs):
def __init__(self, ty, op, total_count, logits, **kwargs):
super().__init__(dist.NegativeBinomialLogits, op, total_count, logits, **kwargs)
self.total_count = total_count
self.logits = logits
Expand All @@ -756,7 +756,7 @@ def MultinomialLogits(logits, total_count=1, **kwargs) -> dist.MultinomialLogits

@defdata.register(dist.MultinomialLogits)
class MultinomialLogitsTerm(_DistributionTerm):
def __init__(self, op, logits, total_count, **kwargs):
def __init__(self, ty, op, logits, total_count, **kwargs):
super().__init__(dist.MultinomialLogits, op, logits, total_count, **kwargs)
self.logits = logits
self.total_count = total_count
Expand All @@ -774,7 +774,7 @@ def Chi2(df, **kwargs) -> dist.Chi2:

@defdata.register(dist.Chi2)
class Chi2Term(_DistributionTerm):
def __init__(self, op, df, **kwargs):
def __init__(self, ty, op, df, **kwargs):
super().__init__(dist.Chi2, op, df, **kwargs)
self.df = df

Expand All @@ -791,7 +791,7 @@ def Dirichlet(concentration, **kwargs) -> dist.Dirichlet:

@defdata.register(dist.Dirichlet)
class DirichletTerm(_DistributionTerm):
def __init__(self, op, concentration, **kwargs):
def __init__(self, ty, op, concentration, **kwargs):
super().__init__(dist.Dirichlet, op, concentration, **kwargs)
self.concentration = concentration

Expand All @@ -810,7 +810,7 @@ def DirichletMultinomial(

@defdata.register(dist.DirichletMultinomial)
class DirichletMultinomialTerm(_DistributionTerm):
def __init__(self, op, concentration, total_count, **kwargs):
def __init__(self, ty, op, concentration, total_count, **kwargs):
super().__init__(
dist.DirichletMultinomial, op, concentration, total_count, **kwargs
)
Expand All @@ -832,7 +832,7 @@ def Exponential(rate=1.0, **kwargs) -> dist.Exponential:

@defdata.register(dist.Exponential)
class ExponentialTerm(_DistributionTerm):
def __init__(self, op, rate, **kwargs):
def __init__(self, ty, op, rate, **kwargs):
super().__init__(dist.Exponential, op, rate, **kwargs)
self.rate = rate

Expand All @@ -849,7 +849,7 @@ def Poisson(rate, **kwargs) -> dist.Poisson:

@defdata.register(dist.Poisson)
class PoissonTerm(_DistributionTerm):
def __init__(self, op, rate, **kwargs):
def __init__(self, ty, op, rate, **kwargs):
super().__init__(dist.Poisson, op, rate, **kwargs)
self.rate = rate

Expand All @@ -866,7 +866,7 @@ def Gamma(concentration, rate=1.0, **kwargs) -> dist.Gamma:

@defdata.register(dist.Gamma)
class GammaTerm(_DistributionTerm):
def __init__(self, op, concentration, rate, **kwargs):
def __init__(self, ty, op, concentration, rate, **kwargs):
super().__init__(dist.Gamma, op, concentration, rate, **kwargs)
self.concentration = concentration
self.rate = rate
Expand All @@ -884,7 +884,7 @@ def HalfCauchy(scale=1.0, **kwargs) -> dist.HalfCauchy:

@defdata.register(dist.HalfCauchy)
class HalfCauchyTerm(_DistributionTerm):
def __init__(self, op, scale, **kwargs):
def __init__(self, ty, op, scale, **kwargs):
super().__init__(dist.HalfCauchy, op, scale, **kwargs)
self.scale = scale

Expand All @@ -901,7 +901,7 @@ def HalfNormal(scale=1.0, **kwargs) -> dist.HalfNormal:

@defdata.register(dist.HalfNormal)
class HalfNormalTerm(_DistributionTerm):
def __init__(self, op, scale, **kwargs):
def __init__(self, ty, op, scale, **kwargs):
super().__init__(dist.HalfNormal, op, scale, **kwargs)
self.scale = scale

Expand All @@ -918,7 +918,7 @@ def LKJCholesky(dim, concentration=1.0, **kwargs) -> dist.LKJCholesky:

@defdata.register(dist.LKJCholesky)
class LKJCholeskyTerm(_DistributionTerm):
def __init__(self, op, dim, concentration, **kwargs):
def __init__(self, ty, op, dim, concentration, **kwargs):
super().__init__(dist.LKJCholesky, op, dim, concentration, **kwargs)
self.dim = dim
self.concentration = concentration
Expand All @@ -939,7 +939,7 @@ def MultivariateNormal(
@defdata.register(dist.MultivariateNormal)
class MultivariateNormalTerm(_DistributionTerm):
def __init__(
self, op, loc, covariance_matrix, precision_matrix, scale_tril, **kwargs
self, ty, op, loc, covariance_matrix, precision_matrix, scale_tril, **kwargs
):
super().__init__(
dist.MultivariateNormal,
Expand Down Expand Up @@ -972,7 +972,7 @@ def Pareto(scale, alpha, **kwargs) -> dist.Pareto:

@defdata.register(dist.Pareto)
class ParetoTerm(_DistributionTerm):
def __init__(self, op, scale, alpha, **kwargs):
def __init__(self, ty, op, scale, alpha, **kwargs):
super().__init__(dist.Pareto, op, scale, alpha, **kwargs)
self.scale = scale
self.alpha = alpha
Expand All @@ -990,7 +990,7 @@ def Uniform(low=0.0, high=1.0, **kwargs) -> dist.Uniform:

@defdata.register(dist.Uniform)
class UniformTerm(_DistributionTerm):
def __init__(self, op, low, high, **kwargs):
def __init__(self, ty, op, low, high, **kwargs):
super().__init__(dist.Uniform, op, low, high, **kwargs)
self.low = low
self.high = high
Expand All @@ -1008,7 +1008,7 @@ def VonMises(loc, concentration, **kwargs) -> dist.VonMises:

@defdata.register(dist.VonMises)
class VonMisesTerm(_DistributionTerm):
def __init__(self, op, loc, concentration, **kwargs):
def __init__(self, ty, op, loc, concentration, **kwargs):
super().__init__(dist.VonMises, op, loc, concentration, **kwargs)
self.loc = loc
self.concentration = concentration
Expand All @@ -1026,7 +1026,7 @@ def Weibull(scale, concentration, **kwargs) -> dist.Weibull:

@defdata.register(dist.Weibull)
class WeibullTerm(_DistributionTerm):
def __init__(self, op, scale, concentration, **kwargs):
def __init__(self, ty, op, scale, concentration, **kwargs):
super().__init__(dist.Weibull, op, scale, concentration, **kwargs)
self.scale = scale
self.concentration = concentration
Expand All @@ -1044,7 +1044,7 @@ def Wishart(df, scale_tril, **kwargs) -> dist.Wishart:

@defdata.register(dist.Wishart)
class WishartTerm(_DistributionTerm):
def __init__(self, op, df, scale_tril, **kwargs):
def __init__(self, ty, op, df, scale_tril, **kwargs):
super().__init__(dist.Wishart, op, df, scale_tril, **kwargs)
self.df = df
self.scale_tril = scale_tril
Expand All @@ -1062,7 +1062,7 @@ def Delta(v=0.0, log_density=0.0, event_dim=0, **kwargs) -> dist.Delta:

@defdata.register(dist.Delta)
class DeltaTerm(_DistributionTerm):
def __init__(self, op, v, log_density, event_dim, **kwargs):
def __init__(self, ty, op, v, log_density, event_dim, **kwargs):
super().__init__(dist.Delta, op, v, log_density, event_dim, **kwargs)
self.v = v
self.log_density = log_density
Expand All @@ -1082,7 +1082,7 @@ def LowRankMultivariateNormal(

@defdata.register(dist.LowRankMultivariateNormal)
class LowRankMultivariateNormalTerm(_DistributionTerm):
def __init__(self, op, loc, cov_factor, cov_diag, **kwargs):
def __init__(self, ty, op, loc, cov_factor, cov_diag, **kwargs):
super().__init__(
dist.LowRankMultivariateNormal, op, loc, cov_factor, cov_diag, **kwargs
)
Expand All @@ -1107,7 +1107,7 @@ def RelaxedBernoulliLogits(

@defdata.register(dist.RelaxedBernoulliLogits)
class RelaxedBernoulliLogitsTerm(_DistributionTerm):
def __init__(self, op, temperature, logits, **kwargs):
def __init__(self, ty, op, temperature, logits, **kwargs):
super().__init__(dist.RelaxedBernoulliLogits, op, temperature, logits, **kwargs)
self.temperature = temperature
self.logits = logits
Expand All @@ -1127,7 +1127,7 @@ def Independent(base_dist, reinterpreted_batch_ndims, **kwargs) -> dist.Independ

@defdata.register(dist.Independent)
class IndependentTerm(_DistributionTerm):
def __init__(self, op, base_dist, reinterpreted_batch_ndims, **kwargs):
def __init__(self, ty, op, base_dist, reinterpreted_batch_ndims, **kwargs):
super().__init__(
dist.Independent, op, base_dist, reinterpreted_batch_ndims, **kwargs
)
Expand Down
4 changes: 3 additions & 1 deletion effectful/handlers/pyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,9 @@ class _DistributionTerm(Term[TorchDistribution], TorchDistribution):
_args: tuple
_kwargs: dict

def __init__(self, op: Operation[Any, TorchDistribution], *args, **kwargs):
def __init__(
self, ty: type, op: Operation[Any, TorchDistribution], *args, **kwargs
):
self._op = op
self._args = args
self._kwargs = kwargs
Expand Down
2 changes: 1 addition & 1 deletion effectful/handlers/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ def torch_getitem(x: torch.Tensor, key: tuple[IndexElement, ...]) -> torch.Tenso


@defdata.register(torch.Tensor)
def _embed_tensor(op, *args, **kwargs):
def _embed_tensor(ty, op, *args, **kwargs):
if (
op is torch_getitem
and not isinstance(args[0], Term)
Expand Down
Loading
Loading