
Commit f6823cd

Improve the handling of td.MixtureSameFamily
When MixtureProjectionNetwork is used for action distribution, we need to correctly
handle get_mode() and summarize_distribution() for it.
emailweixu committed May 1, 2024
1 parent 98ec37d commit f6823cd
Showing 2 changed files with 43 additions and 1 deletion.
11 changes: 10 additions & 1 deletion alf/utils/dist_utils.py
@@ -1128,10 +1128,19 @@ def get_mode(dist):
         # approach to compute the mode, by using the mode of the component
         # distribution that has the highest component probability.
         # [B]
+        batch_shape = dist.batch_shape
         ind = get_mode(dist.mixture_distribution)
         # [B, num_component, d]
         component_mode = get_mode(dist.component_distribution)
-        mode = component_mode[torch.arange(component_mode.shape[0]), ind]
+        if len(batch_shape) == 1:
+            mode = component_mode[torch.arange(batch_shape[0]), ind]
+        elif ind.ndim == 2:
+            d0, d1 = batch_shape
+            mode = component_mode[torch.arange(d0).unsqueeze(-1),
+                                  torch.arange(d1), ind]
+        else:
+            raise NotImplementedError(
+                "Batch shape %s is not supported" % batch_shape)
     elif isinstance(dist, StableCauchy):
         mode = dist.loc
     elif isinstance(dist, td.Independent):
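The indexing added above can be checked with a small standalone sketch (not part of this commit) that builds a td.MixtureSameFamily with a two-dimensional batch shape such as [T, B] and picks the mode of the most probable component; T, B, K and d below are arbitrary illustration values. The exact mode of a mixture has no closed form, so this highest-weight-component mode is the same heuristic the commit uses.

import torch
import torch.distributions as td

T, B, K, d = 3, 4, 5, 2      # illustration values: time, batch, components, action dim
logits = torch.randn(T, B, K)
loc = torch.randn(T, B, K, d)
scale = torch.rand(T, B, K, d) + 0.1
dist = td.MixtureSameFamily(
    td.Categorical(logits=logits),
    td.Independent(td.Normal(loc, scale), 1))

# Mode of the mixture weights: index of the most probable component, shape [T, B].
ind = dist.mixture_distribution.probs.argmax(-1)
# Mode of each Normal component is its loc, shape [T, B, K, d].
component_mode = dist.component_distribution.base_dist.loc
# Same advanced indexing as the 2-D branch above: result shape [T, B, d].
mode = component_mode[torch.arange(T).unsqueeze(-1), torch.arange(B), ind]
assert mode.shape == (T, B, d)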
33 changes: 33 additions & 0 deletions alf/utils/summary_utils.py
@@ -298,6 +298,8 @@ def summarize_distribution(name, distributions):
     * Normal, StableCauchy, Beta: mean and std of each dimension will be summarized
     * Above distribution wrapped by Independent and TransformedDistribution:
       the base distribution is summarized
+    * MixtureSameFamily: the mixture weights and the most likely component distribution
+      are summarized
     * Tensor: each dimension dist[..., a] will be summarized

     Note that unsupported distributions will be ignored (no error reported).
@@ -316,6 +318,24 @@ def _summarize_one(path, dist):
                 add_mean_hist_summary("%s_loc/%s/%s" % (name, path, a),
                                       dist[..., a])
         else:
+            ind = None
+            if isinstance(dist, td.MixtureSameFamily):
+                probs = dist.mixture_distribution.probs
+                n = probs.shape[-1]
+                if n <= 10:  # 10 is arbitrarily chosen to avoid too many summaries
+                    for i in range(n):
+                        add_mean_hist_summary(
+                            "%s_probs/%s/%s" % (name, path, i), probs[..., i])
+                else:
+                    entropy = -torch.xlogy(probs, probs).sum(-1)
+                    add_mean_hist_summary("%s_cond_entropy/%s" % (name, path),
+                                          entropy)
+                    probs = probs.reshape(-1, probs.shape[-1]).mean(0)
+                    entropy = -torch.xlogy(probs, probs).sum()
+                    alf.summary.scalar("%s_entropy/%s" % (name, path), entropy)
+
+                ind = dist_utils.get_mode(dist.mixture_distribution)
+                dist = dist.component_distribution
             dist = dist_utils.get_base_dist(dist)
             if isinstance(dist, (td.Normal, dist_utils.StableCauchy,
                                  dist_utils.TruncatedDistribution)):
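When there are many components, the hunk above summarizes two entropies of the mixture weights instead of one histogram per component. A minimal sketch outside of ALF (random weights, hypothetical shapes) of those two quantities:

import torch

# Hypothetical mixture weights for a batch of 32 samples and 16 components
# (16 > 10, so the branch above would take the entropy path).
probs = torch.softmax(torch.randn(32, 16), dim=-1)

# Per-sample entropy of the mixture weights ("cond_entropy" above);
# torch.xlogy(p, p) is p * log(p) with 0 * log(0) treated as 0.
cond_entropy = -torch.xlogy(probs, probs).sum(-1)        # shape [32]

# Entropy of the batch-averaged weights ("entropy" above). If it is much larger
# than the average cond_entropy, different samples prefer different components.
avg_probs = probs.reshape(-1, probs.shape[-1]).mean(0)
entropy = -torch.xlogy(avg_probs, avg_probs).sum()       # scalar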
@@ -327,6 +347,19 @@ def _summarize_one(path, dist):
             else:
                 return

+            if ind is not None:
+                if len(ind.shape) == 1:
+                    i0 = torch.arange(ind.shape[0])
+                    loc = loc[i0, ind]
+                    log_scale = log_scale[i0, ind]
+                elif len(ind.shape) == 2:
+                    i0 = torch.arange(ind.shape[0]).unsqueeze(-1)
+                    i1 = torch.arange(ind.shape[1])
+                    loc = loc[i0, i1, ind]
+                    log_scale = log_scale[i0, i1, ind]
+                else:
+                    return
+
             action_dim = loc.shape[-1]
             for a in range(action_dim):
                 add_mean_hist_summary("%s_log_scale/%s/%s" % (name, path, a),
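For the loc and log_scale summaries, only the parameters of the most likely component are kept. A standalone sketch of the one-dimensional batch case, assuming a mixture of Independent Normal components (the two-dimensional case uses the same advanced indexing with an extra unsqueeze, as in the dist_utils.py hunk):

import torch
import torch.distributions as td

B, K, d = 8, 3, 4                                  # illustration values
mix = td.Categorical(logits=torch.randn(B, K))
comp = td.Independent(td.Normal(torch.randn(B, K, d), torch.rand(B, K, d) + 0.1), 1)

ind = mix.probs.argmax(-1)                         # [B], most likely component per sample
loc = comp.base_dist.loc                           # [B, K, d]
log_scale = comp.base_dist.scale.log()             # [B, K, d]

i0 = torch.arange(B)
loc = loc[i0, ind]                                 # [B, d]
log_scale = log_scale[i0, ind]                     # [B, d]
assert loc.shape == (B, d) and log_scale.shape == (B, d)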
