Add JIT compilation for HardConcrete sampling #4

@MaxGhenis

Description

Use torch.jit.script to compile the HardConcrete sampling code for better performance.

Motivation

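HardConcrete sampling is called frequently during training, so its per-call cost adds up. JIT compilation could provide speedups, especially for large models, for example by letting TorchScript fuse the sampler's chain of elementwise operations into fewer kernels.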

Proposed Implementation

import torch
from torch import Tensor

@torch.jit.script
def sample_hard_concrete(logits: Tensor, temp: float,
                         gamma: float, zeta: float) -> Tensor:
    # Optimized sampling code
    ...
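
A minimal sketch of what the body could contain, assuming the standard HardConcrete construction from Louizos et al. (2018): `logits` plays the role of log α, `temp` is the temperature β, and `gamma` < 0 < 1 < `zeta` define the stretch interval. The `eps` clamp on the uniform noise and the example settings (β = 2/3, γ = -0.1, ζ = 1.1) are assumptions, not part of this issue.

```python
import torch
from torch import Tensor


@torch.jit.script
def sample_hard_concrete(logits: Tensor, temp: float,
                         gamma: float, zeta: float) -> Tensor:
    # Assumption: standard HardConcrete sampling (Louizos et al., 2018).
    # Uniform noise u ~ U(0, 1), clamped away from 0 and 1 so logs stay finite.
    eps: float = 1e-6
    u = torch.rand_like(logits).clamp(eps, 1.0 - eps)
    # Binary Concrete sample: sigmoid((log u - log(1 - u) + logits) / temp).
    s = torch.sigmoid((torch.log(u) - torch.log1p(-u) + logits) / temp)
    # Stretch from (0, 1) to (gamma, zeta), then clip to [0, 1] so that
    # exact zeros and ones occur with nonzero probability.
    s_bar = s * (zeta - gamma) + gamma
    return s_bar.clamp(0.0, 1.0)


# Example: gates for 1,000 parameters, all starting at logits = 0.
logits = torch.zeros(1000, requires_grad=True)
z = sample_hard_concrete(logits, 2.0 / 3.0, -0.1, 1.1)
z.sum().backward()  # gradients reach logits through the unclipped entries
```

Because the noise is reparameterized, gradients flow back to `logits` through the sigmoid wherever the stretched sample lies strictly inside (0, 1); entries clipped to exactly 0 or 1 contribute zero gradient.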

Tasks

  • Profile current sampling performance
  • Add JIT decorators to hot paths
  • Ensure compatibility with autograd
  • Benchmark improvements
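
The profiling, autograd, and benchmarking tasks could start from a small harness such as the one below. It is a sketch under assumptions: `hard_concrete_eager`, `grads_match`, and `time_both` are hypothetical names, the sampler body assumes the standard HardConcrete construction (Louizos et al., 2018), and the tensor sizes are arbitrary. It scripts an eager function with `torch.jit.script`, checks that both versions give the same samples and gradients under the same RNG seed, and times both with `torch.utils.benchmark.Timer`.

```python
import torch
import torch.utils.benchmark as benchmark
from torch import Tensor

# Hypothetical settings (the values used by Louizos et al., 2018).
TEMP, GAMMA, ZETA = 2.0 / 3.0, -0.1, 1.1


def hard_concrete_eager(logits: Tensor, temp: float,
                        gamma: float, zeta: float) -> Tensor:
    # Assumed HardConcrete sampler, run in eager mode as the baseline.
    u = torch.rand_like(logits).clamp(1e-6, 1.0 - 1e-6)
    s = torch.sigmoid((torch.log(u) - torch.log1p(-u) + logits) / temp)
    return (s * (zeta - gamma) + gamma).clamp(0.0, 1.0)


# Compile the same Python function with TorchScript.
hard_concrete_scripted = torch.jit.script(hard_concrete_eager)


def grads_match(n: int = 4096) -> bool:
    """Autograd check: with the same RNG seed, the eager and scripted
    versions should produce the same samples and the same logits.grad."""
    results = []
    for fn in (hard_concrete_eager, hard_concrete_scripted):
        gen = torch.Generator().manual_seed(0)
        logits = torch.randn(n, generator=gen).requires_grad_()
        torch.manual_seed(1)  # same uniform noise for both versions
        z = fn(logits, TEMP, GAMMA, ZETA)
        z.sum().backward()
        results.append((z.detach(), logits.grad))
    (z0, g0), (z1, g1) = results
    return torch.allclose(z0, z1) and torch.allclose(g0, g1)


def time_both(n: int = 1_000_000) -> dict:
    """Median seconds per call for the eager and scripted samplers."""
    logits = torch.randn(n)
    timings = {}
    for name, fn in (("eager", hard_concrete_eager),
                     ("scripted", hard_concrete_scripted)):
        for _ in range(3):  # warm-up so TorchScript's profiling runs are not timed
            fn(logits, TEMP, GAMMA, ZETA)
        timer = benchmark.Timer(
            stmt="fn(x, temp, gamma, zeta)",
            globals={"fn": fn, "x": logits, "temp": TEMP,
                     "gamma": GAMMA, "zeta": ZETA},
        )
        timings[name] = timer.blocked_autorange(min_run_time=0.5).median
    return timings


if __name__ == "__main__":
    print("samples and gradients match:", grads_match())
    print(time_both())
```

No timings are claimed here; results depend heavily on device and tensor size. The warm-up calls matter because TorchScript's default profiling executor typically optimizes a scripted graph only after its first few runs.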

Metadata

Assignees: No one assigned
Labels: No labels
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests