Description
Use torch.jit.script to compile the HardConcrete sampling code for better performance.
Motivation
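HardConcrete sampling is called frequently during training. Each sample is a short chain of elementwise operations (uniform noise, log, sigmoid, stretch, clamp), so JIT compilation could reduce per-call Python overhead and provide speedups, especially for large models.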
Proposed Implementation
import torch
from torch import Tensor

@torch.jit.script
def sample_hard_concrete(logits: Tensor, temp: float,
                         gamma: float, zeta: float) -> Tensor:
    # Optimized sampling code
    ...
Tasks