Skip to content

Commit

Permalink
do not require bk_flow_mask_fn for all distributions
Browse files Browse the repository at this point in the history
  • Loading branch information
liuanji committed Apr 29, 2024
1 parent 31d7946 commit fb2f862
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions src/pyjuice/layer/input_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,9 +393,8 @@ def backward(self, data: torch.Tensor, node_flows: torch.Tensor,
)

# Handle the masked input nodes
if missing_mask is not None:
if missing_mask is not None and self.bk_flow_mask_fn is not None:
if not self.provided("_flows_mask_kernel"):
assert self.bk_flow_mask_fn is not None, f"`bk_flow_mask_fn` is not implemented for distribution {type(self.dist)}."
self._flows_mask_kernel = self._compile_triton_kernel(self._flows_kernel_template, flow_fn = self.bk_flow_mask_fn)

self._flows_mask_kernel[grid](
Expand Down

0 comments on commit fb2f862

Please sign in to comment.