Skip to content

Commit

Permalink
Correct multimodal base setting.
Browse files Browse the repository at this point in the history
  • Loading branch information
alicjapolanska committed May 10, 2024
1 parent a13d7d7 commit 321f345
Showing 1 changed file with 27 additions and 14 deletions.
41 changes: 27 additions & 14 deletions harmonic/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,11 @@ def bijector_fn(params: jnp.ndarray):

self.bijector_fn = bijector_fn

self.default_base_centers = [
jnp.zeros(self.n_features),
jnp.full(self.n_features, 2.0),
]

def make_flow(self, temperature: float = 1.0):
"""
Make distrax distribution containing the rational quadratic spline flow.
Expand Down Expand Up @@ -272,25 +277,33 @@ def make_flow(self, temperature: float = 1.0):
)

else:
if self.base_centers is not None:
base_dist = (
if self.base_centers is None:
base_centers = self.default_base_centers
else:
base_centers = self.base_centers

# base_dist = (
# distrax.MultivariateNormalFullCovariance(
# loc=base_centers[0],
# covariance_matrix=jnp.eye(self.n_features) * temperature,
# ),
# )

base_dist = distrax.MultivariateNormalFullCovariance(
loc=base_centers[0],
covariance_matrix=jnp.eye(self.n_features) * temperature,
)

for i in range(1, len(base_centers)):
base_dist = distrax.MixtureOfTwo(
0.5,
distrax.MultivariateNormalFullCovariance(
loc=self.base_centers[0],
loc=base_centers[i],
covariance_matrix=jnp.eye(self.n_features) * temperature,
),
base_dist,
)

for i in range(1, len(self.base_centers)):
gaussian_center = self.base_centers[i]
base_dist = distrax.MixtureOfTwo(
0.5,
distrax.MultivariateNormalFullCovariance(
loc=gaussian_center,
covariance_matrix=jnp.eye(self.n_features) * temperature,
),
base_dist,
)

return base_dist, flow

def __call__(self, x: jnp.array, temperature: float = 1.0) -> jnp.array:
Expand Down

0 comments on commit 321f345

Please sign in to comment.