### Utilizing `@trace_impl` and `@vmap_impl`

When designing a function or method, you may not always consider whether it is `JIT`-compatible. However, this property becomes crucial in specific scenarios, such as solving Hyperparameter Optimization (HPO) problems. For more details on deploying HPO with EvoX, refer to [Efficient HPO with EvoX](#/guide/user/3-hpo).

A typical characteristic of such problems is that only certain parts of the algorithm need modification, for instance its `step` method, so you do not have to rewrite the entire algorithm. In such cases, use the `@trace_impl` or `@vmap_impl` decorator to write a replacement implementation that serves as the trace-JIT-time or vmap-JIT-time proxy for the specified `target` method.

The decorators [`@trace_impl`](#trace_impl) and [`@vmap_impl`](#vmap_impl) accept a single argument: the `target` method, which is the implementation invoked when no JIT tracing is in progress. These decorators can only be applied to member methods of a class decorated with `jit_class`.

Since the annotated function serves as a rewritten version of the target function, it must have the same input/output signature (the same number and types of arguments and return values); otherwise, the resulting behavior is undefined.

If the annotated function is intended for use with `vmap`, it must satisfy three additional constraints:

1. **No In-Place Operations on Attributes:**
   Methods must not modify the algorithm's attributes in place (for example, with `copy_` or augmented assignments such as `+=`); assign a new tensor to the attribute instead.

```python
class ExampleAlgorithm(Algorithm):
    def __init__(self, pop_size: int, dim: int):
        super().__init__()
        self.pop = torch.rand(pop_size, dim)  # Attribute of the algorithm

    def step_in_place(self, pop: torch.Tensor):  # In-place: writes into the existing tensor (not allowed)
        self.pop.copy_(pop)

    def step_out_of_place(self, pop: torch.Tensor):  # Out-of-place: rebinds the attribute (allowed)
        self.pop = pop
```
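To build intuition for constraint 1, consider a plain-Python analogy rather than an actual `vmap` run. Suppose a batched transform ran the method once per batch member on shallow copies of the module. Rebinding the attribute then affects only that copy, while an in-place write lands in storage that every copy shares. `Algo` and `run_per_member` are hypothetical names, and the batching here is a deliberate simplification.

```python
import copy
from typing import List


class Algo:
    def __init__(self, pop: List[float]):
        self.pop = pop

    def step_in_place(self, pop: List[float]) -> None:
        self.pop[:] = pop  # analogue of self.pop.copy_(pop): writes into the existing list

    def step_out_of_place(self, pop: List[float]) -> None:
        self.pop = list(pop)  # analogue of self.pop = pop: rebinds the attribute


def run_per_member(base: Algo, method: str, batch: List[List[float]]) -> List[List[float]]:
    """Toy batching: call `method` once per batch member on a shallow copy of `base`."""
    results = []
    for member_pop in batch:
        clone = copy.copy(base)  # shallow copy: clone.pop is the same list object as base.pop
        getattr(clone, method)(member_pop)
        results.append(clone.pop)
    return results


batch = [[1.0, 1.0], [2.0, 2.0]]

base_a = Algo([0.0, 0.0])
out_of_place = run_per_member(base_a, "step_out_of_place", batch)
# out_of_place == [[1.0, 1.0], [2.0, 2.0]]; base_a.pop is still [0.0, 0.0]

base_b = Algo([0.0, 0.0])
in_place = run_per_member(base_b, "step_in_place", batch)
# in_place == [[2.0, 2.0], [2.0, 2.0]]; base_b.pop is now [2.0, 2.0]
```

In the in-place case, every member's result ends up being the last member's value, and the shared original is modified as well.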

2. **Avoid Python Control Flow:**
   The code logic must not branch or loop with Python control flow (such as `if`/`else` or `while`) on tensor values. To express such control flow, use [`TracingCond`](#TracingCond), [`TracingWhile`](#TracingWhile), and [`TracingSwitch`](#TracingSwitch) instead.

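```python
class ExampleAlgorithm(Algorithm):
    def __init__(self, pop_size: int, dim: int):
        super().__init__()
        self.pop = torch.rand(pop_size, dim)

    def plus(self, pop: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return pop + y  # Out-of-place, per constraint 1

    def minus(self, pop: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return pop - y

    def step_with_python_control_flow(self, y):  # Function with Python control flow
        x = torch.rand(())  # 0-d random tensor
        if x > 0.5:  # Branching on a tensor value: cannot be traced or vmapped
            self.pop = self.plus(self.pop, y)
        else:
            self.pop = self.minus(self.pop, y)

    def step_without_python_control_flow(self, y):  # Function without Python control flow
        x = torch.rand(())
        cond = x > 0.5
        _if_else_ = TracingCond(self.plus, self.minus)
        self.pop = _if_else_.cond(cond, self.pop, y)  # Branch is selected from `cond` at run time
```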
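The root of constraint 2 is that, under batching, a condition is no longer a single boolean but one boolean per batch member. No single Python `if` can pick the right branch for every member at once.

One common workaround is to evaluate both branches and then select a result per member; on tensors, `torch.where` performs this select. The sketch below demonstrates it in pure Python. This is a swapped-in technique for illustration only. `TracingCond` is the EvoX mechanism recommended above, and its internals may differ. `where` and `batched_cond` are hypothetical helpers.

```python
from typing import Callable, List, Sequence

Branch = Callable[[float, float], float]


def plus(pop: float, y: float) -> float:
    return pop + y


def minus(pop: float, y: float) -> float:
    return pop - y


def where(mask: Sequence[bool], a: Sequence[float], b: Sequence[float]) -> List[float]:
    """Per-element select: the list analogue of torch.where(mask, a, b)."""
    return [x if m else y for m, x, y in zip(mask, a, b)]


def batched_cond(mask: Sequence[bool], true_fn: Branch, false_fn: Branch,
                 pops: Sequence[float], ys: Sequence[float]) -> List[float]:
    """Evaluate both branches for every member, then pick one result per member."""
    taken = [true_fn(p, y) for p, y in zip(pops, ys)]
    not_taken = [false_fn(p, y) for p, y in zip(pops, ys)]
    return where(mask, taken, not_taken)


mask = [True, False, True]  # one condition per batch member
pops = [10.0, 10.0, 10.0]
ys = [1.0, 2.0, 3.0]
result = batched_cond(mask, plus, minus, pops, ys)  # [11.0, 8.0, 13.0]
```

Computing both branches costs extra work. In exchange, the computation no longer depends on any single Python-level decision, which is what tracing and batching require.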

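3. **Avoid In-Place Operations on `self`:**
   In-place operations on `self` are not well-defined under tracing and cannot be compiled. In tracing mode, variables that live outside the method (such as attributes of `self` modified inside a branch) may be incorrectly captured as static values, so track them explicitly as state. `prepare_control_flow` collects them into a state before the control flow, and `after_control_flow` writes the updated values back afterwards.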

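```python
@jit_class
class ExampleAlgorithm(Algorithm):
    def __init__(self, pop_size: int, dim: int, lb: torch.Tensor, ub: torch.Tensor, hp: torch.Tensor):
        super().__init__()
        self.pop_size = pop_size
        self.dim = dim
        self.lb = lb  # Lower bounds, shape (dim,)
        self.ub = ub  # Upper bounds, shape (dim,)
        self.hp = hp  # Hyperparameters, e.g. tuned by an outer HPO loop
        self.pop = torch.rand(pop_size, dim)

    def strategy_1(self, pop: torch.Tensor):  # One update strategy
        self.pop = pop * pop

    def strategy_2(self, pop: torch.Tensor):  # Another update strategy
        self.pop = pop + pop

    def step(self):
        pop = torch.rand(self.pop_size, self.dim, dtype=self.lb.dtype, device=self.lb.device)
        pop = pop * (self.ub - self.lb)[None, :] + self.lb[None, :]
        pop = pop * self.hp[0]
        control_number = torch.rand(())
        if control_number < 0.5:  # Python control flow: fine eagerly, not traceable
            self.strategy_1(pop)
        else:
            self.strategy_2(pop)

    @trace_impl(step)  # Trace-time replacement for `step` with the same signature
    def trace_step_without_operations_to_self(self):
        pop = torch.rand(self.pop_size, self.dim, dtype=self.lb.dtype, device=self.lb.device)
        pop = pop * (self.ub - self.lb)[None, :] + self.lb[None, :]
        pop = pop * self.hp[0]
        cond = torch.rand(()) < 0.5
        branches = (self.strategy_1, self.strategy_2)
        state, names = self.prepare_control_flow(*branches)  # Utilize state to track self.pop
        _if_else_ = TracingCond(*branches)
        state = _if_else_.cond(state, cond, pop)  # Run the selected branch on the explicit state
        self.after_control_flow(state, *names)  # Write the updated state back to self

    # Counter-example, deliberately not registered with @trace_impl: the branches
    # mutate self.pop without going through a state, so during tracing the
    # update may be captured as a static value and is not tracked.
    def trace_step_with_operations_to_self(self):
        pop = torch.rand(self.pop_size, self.dim, dtype=self.lb.dtype, device=self.lb.device)
        pop = pop * (self.ub - self.lb)[None, :] + self.lb[None, :]
        pop = pop * self.hp[0]
        cond = torch.rand(()) < 0.5
        _if_else_ = TracingCond(self.strategy_1, self.strategy_2)
        _if_else_.cond(cond, pop)
```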
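The `prepare_control_flow` / `after_control_flow` pattern in `trace_step_without_operations_to_self` works by threading state explicitly. Snapshot the attributes the branches touch, run each branch as a function from state to new state, and then write the result back. The pure-Python model below captures only that idea.

`prepare_state`, `commit_state`, and `cond_with_state` are hypothetical analogues, not EvoX functions. In particular, the branch selection here is a plain Python conditional; in EvoX, `TracingCond` plays that role.

```python
from typing import Any, Callable, Dict, Iterable

State = Dict[str, Any]


class Algo:
    def __init__(self) -> None:
        self.pop = 2.0

    # Branches as pure functions: read from `state`, return a new state
    @staticmethod
    def strategy_1(state: State, x: float) -> State:
        return {**state, "pop": state["pop"] * x}

    @staticmethod
    def strategy_2(state: State, x: float) -> State:
        return {**state, "pop": state["pop"] + x}


def prepare_state(obj: Any, names: Iterable[str]) -> State:
    """Hypothetical analogue of prepare_control_flow: snapshot the tracked attributes."""
    return {name: getattr(obj, name) for name in names}


def commit_state(obj: Any, state: State) -> None:
    """Hypothetical analogue of after_control_flow: write the new values back."""
    for name, value in state.items():
        setattr(obj, name, value)


def cond_with_state(pred: bool, true_fn: Callable[..., State],
                    false_fn: Callable[..., State], state: State, *args: Any) -> State:
    """Stand-in for a traced conditional; both branches share a (state, *args) -> state signature."""
    branch = true_fn if pred else false_fn
    return branch(state, *args)


algo = Algo()
state = prepare_state(algo, ["pop"])
state = cond_with_state(True, Algo.strategy_1, Algo.strategy_2, state, 3.0)
commit_state(algo, state)  # algo.pop == 6.0
```

Because every branch receives and returns the same state, the object itself is only touched once, at the explicit commit.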

### Utilizing `use_state`

[`use_state`](#use_state) transforms a given stateful function (one that performs in-place alterations on `nn.Module`s) into a pure-functional version. The new version receives an additional `state` parameter (of type `Dict[str, torch.Tensor]`) and returns the altered state together with the original return value.

Its `func` argument is either the stateful function to transform or a generator function that returns it (e.g., a lambda that returns the stateful function). `is_generator`, which defaults to `True`, specifies which of the two was passed.

Here is a simple example:

```python
@jit_class
class Example(ModuleBase):
    def __init__(self, threshold=0.5):
        super().__init__()
        self.threshold = threshold
        self.sub_mod = nn.Module()
        self.sub_mod.buf = nn.Buffer(torch.zeros(()))

    def h(self, q: torch.Tensor) -> torch.Tensor:
        if q.flatten()[0] > self.threshold:  # Python control flow: eager path only
            x = torch.sin(q)
        else:
            x = torch.tan(q)
        x += self.g(x).abs()
        x *= x.shape[1]
        self.sub_mod.buf = x.sum()  # Stateful update of a buffer
        return x

    @trace_impl(h)
    def th(self, q: torch.Tensor) -> torch.Tensor:
        # Trace-time version of `h`: the Python branch becomes a tensor-level select
        x = torch.where(q.flatten()[0] > self.threshold, torch.sin(q), torch.tan(q))
        x += self.g(x).abs()
        x *= x.shape[1]
        self.sub_mod.buf = x.sum()
        return x

    def g(self, p: torch.Tensor) -> torch.Tensor:
        x = torch.cos(p)
        return x * p.shape[0]

t = Example()
fn = use_state(lambda: t.h, is_generator=True)
jit_fn = jit(fn, trace=True, lazy=True)
results = jit_fn(fn.init_state(), torch.rand(10, 1))
print(results)  # e.g. ({"self.sub_mod.buf": tensor(5.6)}, tensor([[0.56], ...])); values depend on the random input

# IN-PLACE update all relevant variables using the given state
fn.set_state(results[0])
```
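The contract of `use_state` can be summarized as "stateful method in, `(state, *args) -> (new_state, output)` out". The toy model below sketches that contract in pure Python. It mirrors the `init_state` / `set_state` calls used above, but `use_state_sketch` and `Accumulator` are hypothetical and say nothing about how EvoX implements `use_state` internally.

```python
from typing import Any, Callable, Dict, Iterable, Tuple

State = Dict[str, Any]


class Accumulator:
    def __init__(self) -> None:
        self.total = 0.0

    def add(self, x: float) -> float:
        self.total = self.total + x  # stateful: changes self
        return self.total * 2


def use_state_sketch(obj: Any, method_name: str, names: Iterable[str]) -> Callable[..., Tuple[State, Any]]:
    """Hypothetical model: turn obj.<method_name>(*args) into fn(state, *args) -> (new_state, output)."""
    names = list(names)
    method = getattr(obj, method_name)

    def pure_fn(state: State, *args: Any) -> Tuple[State, Any]:
        saved = {n: getattr(obj, n) for n in names}  # live values before the call
        try:
            for n, v in state.items():
                setattr(obj, n, v)  # load the caller's state into the object
            out = method(*args)
            new_state = {n: getattr(obj, n) for n in names}
        finally:
            for n, v in saved.items():
                setattr(obj, n, v)  # restore, so calling pure_fn leaves obj untouched
        return new_state, out

    def init_state() -> State:
        return {n: getattr(obj, n) for n in names}

    def set_state(state: State) -> None:
        for n, v in state.items():
            setattr(obj, n, v)  # explicit, in-place commit of a returned state

    pure_fn.init_state = init_state
    pure_fn.set_state = set_state
    return pure_fn


acc = Accumulator()
fn = use_state_sketch(acc, "add", ["total"])
state, out = fn(fn.init_state(), 1.5)  # state == {"total": 1.5}, out == 3.0
# acc.total is still 0.0 here: the update lives only in `state`
fn.set_state(state)                    # now acc.total == 1.5
```

Keeping the call itself side-effect free is what makes the function safe to trace or batch. The single, explicit `set_state` is where the in-place update finally happens.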

### Utilizing `core._vmap_fix`

The module [`_vmap_fix`](#_vmap_fix) provides fixes and helpers for using `vmap` together with tracing. Once imported (which happens automatically), `_vmap_fix` enables `torch.vmap` to be traced correctly by `torch.jit.trace` and resolves issues, such as random number generation, that could not be traced properly inside `vmap`. It also provides a `debug_print` function that prints tensor values dynamically during both `vmap` and tracing. Other useful functions include:

- [`batched_random`](#batched_random) generates a batched tensor of random values by applying the given function to the size extended with the current vmap batch size.
- [`align_vmap_tensor`](#align_vmap_tensor) aligns a tensor with the batching dimensions of a current batched tensor.

Detailed information can be found in the [`_vmap_fix`](#_vmap_fix) documentation.
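The shape contract of `batched_random`, "apply the given function to the size extended with the current vmap batch size", can be sketched with nested lists. The point is that each batch member receives its own draw instead of one unbatched sample copied to every member. PyTorch's own `torch.func.vmap` exposes a `randomness` argument (`"error"`, `"different"`, `"same"`) for this class of problem.

This is a toy model, not the code in `_vmap_fix`. `make_uniform`, `batched_random_sketch`, and `broadcast_unbatched` are hypothetical.

```python
import random
from typing import Any, Callable, Tuple

Shape = Tuple[int, ...]


def make_uniform(seed: int) -> Callable[[Shape], Any]:
    """Return a rand-like function: shape -> nested lists of uniform floats in [0, 1)."""
    rng = random.Random(seed)

    def uniform(shape: Shape) -> Any:
        if not shape:
            return rng.random()
        return [uniform(shape[1:]) for _ in range(shape[0])]

    return uniform


def batched_random_sketch(rand_fn: Callable[[Shape], Any], batch_size: int, *size: int) -> Any:
    """Hypothetical model: call rand_fn on `size` extended with the batch size."""
    return rand_fn((batch_size, *size))


def broadcast_unbatched(rand_fn: Callable[[Shape], Any], batch_size: int, *size: int) -> Any:
    """The pitfall: one unbatched draw reused by every member, so all members match."""
    sample = rand_fn(size)
    return [sample for _ in range(batch_size)]


uniform = make_uniform(seed=0)
batched = batched_random_sketch(uniform, 2, 3)  # 2 members x 3 values, independent draws
shared = broadcast_unbatched(uniform, 2, 3)     # 2 identical members
```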
