Skip to content

Commit

Permalink
fix LDS, improve test
Browse files Browse the repository at this point in the history
  • Loading branch information
lostella committed Jun 4, 2019
1 parent 0c6e794 commit dfd50f9
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 19 deletions.
42 changes: 29 additions & 13 deletions src/gluonts/distribution/lds.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,11 @@ def __init__(

@property
def batch_shape(self) -> Tuple:
    """Batch shape of the LDS distribution: (batch_size, seq_length).

    The leading entry is taken from the per-time-step emission
    coefficient tensor (its first axis is the batch axis — the test
    file reads output_dim/latent_dim from axes 1 and 2 of the same
    tensor), and the time axis is appended explicitly, so that
    ``event_shape`` can be just ``(output_dim,)``.
    """
    # NOTE(review): the scraped diff showed a stale pre-commit line
    # `return self.emission_coeff.shape[1]` immediately above this one;
    # it is unreachable dead code from the diff rendering and is removed.
    # `[:1]` keeps only the batch axis of emission_coeff[t]'s shape.
    return self.emission_coeff[0].shape[:1] + (self.seq_length,)

@property
def event_shape(self) -> Tuple:
    """Shape of a single event: (output_dim,).

    One event is the observation vector emitted at a single time step;
    the time axis belongs to ``batch_shape``, not to the event shape.
    """
    # NOTE(review): the scraped diff showed a stale pre-commit line
    # `return self.seq_length, self.output_dim` above this one; it is
    # unreachable dead code from the diff rendering and is removed.
    return (self.output_dim,)

@property
def event_dim(self) -> int:
Expand Down Expand Up @@ -157,7 +157,7 @@ def log_prob(
Returns
-------
Tensor
Log probabilities, shape (batch_size, )
Log probabilities, shape (batch_size, seq_length)
Tensor
Final mean, shape (batch_size, latent_dim)
Tensor
Expand Down Expand Up @@ -188,7 +188,7 @@ def kalman_filter(
Returns
-------
Tensor
Log probabilities, shape (batch_size, )
Log probabilities, shape (batch_size, seq_length)
Tensor
Mean of p(l_T | l_{T-1}), where T is seq_length, with shape
(batch_size, latent_dim)
Expand Down Expand Up @@ -301,15 +301,15 @@ def sample(
samples_eps_obs = (
Gaussian(noise_std.zeros_like(), noise_std)
.sample(num_samples)
.split(axis=2, num_outputs=self.seq_length, squeeze_axis=True)
.split(axis=-3, num_outputs=self.seq_length, squeeze_axis=True)
)

# Sample standard normal for all time steps
# samples_eps_std_normal[t]: (num_samples, batch_size, obs_dim, 1)
samples_std_normal = (
Gaussian(noise_std.zeros_like(), noise_std.ones_like())
.sample(num_samples)
.split(axis=2, num_outputs=self.seq_length, squeeze_axis=True)
.split(axis=-3, num_outputs=self.seq_length, squeeze_axis=True)
)

# Sample the prior state.
Expand All @@ -328,6 +328,8 @@ def sample(
# innovation_coeff_t: (num_samples, batch_size, 1, latent_dim)
emission_coeff_t, transition_coeff_t, innovation_coeff_t = [
_broadcast_param(coeff, axes=[0], sizes=[num_samples])
if num_samples is not None
else coeff
for coeff in [
self.emission_coeff[t],
self.transition_coeff[t],
Expand All @@ -337,18 +339,27 @@ def sample(

# Expand residuals as well
# residual_t: (num_samples, batch_size, obs_dim, 1)
residual_t = _broadcast_param(
self.residuals[t].expand_dims(axis=-1),
axes=[0],
sizes=[num_samples],
residual_t = (
_broadcast_param(
self.residuals[t].expand_dims(axis=-1),
axes=[0],
sizes=[num_samples],
)
if num_samples is not None
else self.residuals[t].expand_dims(axis=-1)
)

# (num_samples, batch_size, 1, obs_dim)
samples_t = (
F.linalg_gemm2(emission_coeff_t, samples_lat_state)
+ residual_t
+ samples_eps_obs[t]
).swapaxes(dim1=2, dim2=3)
)
samples_t = (
samples_t.swapaxes(dim1=2, dim2=3)
if num_samples is not None
else samples_t.swapaxes(dim1=1, dim2=2)
)
samples_seq.append(samples_t)

# sample next state: (num_samples, batch_size, latent_dim, 1)
Expand All @@ -359,11 +370,16 @@ def sample(
)

# (num_samples, batch_size, seq_length, obs_dim)
samples = F.concat(*samples_seq, dim=2)
samples = F.concat(*samples_seq, dim=-2)
return (
samples
if scale is None
else F.broadcast_mul(samples, scale.expand_dims(axis=1))
else F.broadcast_mul(
samples,
scale.expand_dims(axis=1).expand_dims(axis=0)
if num_samples is not None
else scale.expand_dims(axis=1),
)
)

def sample_marginals(
Expand Down
24 changes: 18 additions & 6 deletions test/distribution/test_lds.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ def assert_shape_and_finite(x, shape):
# (coefficients and noise terms) and observations, and the log-density
# of the observations that were computed using pykalman
# (https://pykalman.github.io/).
@pytest.mark.skip
@pytest.mark.parametrize(
"data_filename",
[
Expand Down Expand Up @@ -63,9 +62,12 @@ def test_lds_likelihood(data_filename):
output_dim = lds.emission_coeff[0].shape[1]
latent_dim = lds.emission_coeff[0].shape[2]

assert lds.batch_shape == (batch_size, time_length)
assert lds.event_shape == (output_dim,)

likelihood, final_mean, final_cov = lds.log_prob(targets)

assert_shape_and_finite(likelihood, shape=(batch_size, time_length))
assert_shape_and_finite(likelihood, shape=lds.batch_shape)
assert_shape_and_finite(final_mean, shape=(batch_size, latent_dim))
assert_shape_and_finite(
final_cov, shape=(batch_size, latent_dim, latent_dim)
Expand All @@ -82,16 +84,26 @@ def test_lds_likelihood(data_filename):
f"obtained likelihood = {likelihood_per_item}",
)

samples = lds.sample_marginals(num_samples=100)

assert_shape_and_finite(
samples, shape=(100,) + lds.batch_shape + lds.event_shape
)

sample = lds.sample_marginals()

assert_shape_and_finite(sample, shape=lds.batch_shape + lds.event_shape)

samples = lds.sample(num_samples=100)

assert_shape_and_finite(
samples, shape=(100, batch_size, time_length, output_dim)
samples, shape=(100,) + lds.batch_shape + lds.event_shape
)

sample = lds.sample()

assert_shape_and_finite(sample, lds.batch_shape + lds.event_shape)
assert_shape_and_finite(sample, shape=lds.batch_shape + lds.event_shape)

ll = lds.log_prob(sample)
ll, _, _ = lds.log_prob(sample)

assert_shape_and_finite(ll, lds.batch_shape)
assert_shape_and_finite(ll, shape=lds.batch_shape)

0 comments on commit dfd50f9

Please sign in to comment.