# Optimizing Development for EvoX via PyTorch Advanced Techniques

## Basic Optimization Support for Functions in PyTorch

PyTorch provides basic optimization support for functions through two mechanisms: vectorizing map (`torch.vmap`), which enables efficient batch processing, and Just-In-Time (JIT) compilation, which improves execution performance. The following sections introduce each in turn.

### Batch Processing Support through Vectorizing Map in PyTorch

Vectorizing map, implemented in PyTorch as `torch.vmap`, takes a callable and returns a new function that applies the original over a batch dimension of its inputs. The `in_dims` argument specifies which dimension of each input is mapped (`None` means the input is shared across the batch), so a function written for a single sample can process a whole batch efficiently. In EvoX, for example, this feature plays a crucial role in population-based evolutionary processes.

In [1]:
import torch


def dummy_evaluation(pop_x: torch.Tensor, y: torch.Tensor):
    return pop_x * y


# Map over dim 0 of `pop_x` (one row per individual); share `y` across the batch.
batched_dummy_evaluation = torch.vmap(dummy_evaluation, (0, None))

population_size = 3
individual_vector_size = 9
pop_x = torch.arange(individual_vector_size).repeat(population_size, 1)
y = torch.arange(individual_vector_size)

batched_dummy_evaluation(pop_x, y)

tensor([[ 0,  1,  4,  9, 16, 25, 36, 49, 64],
        [ 0,  1,  4,  9, 16, 25, 36, 49, 64],
        [ 0,  1,  4,  9, 16, 25, 36, 49, 64]])
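To make the meaning of the `(0, None)` argument concrete, the following minimal sketch compares the `torch.vmap` result with an explicit Python loop over the population. The loop-based version is only an illustration of the semantics, not how EvoX evaluates populations.

```python
import torch


def dummy_evaluation(pop_x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return pop_x * y


pop_x = torch.arange(9).repeat(3, 1)  # 3 individuals, 9 entries each
y = torch.arange(9)                   # shared by every individual

# in_dims=(0, None): map over dim 0 of `pop_x`, pass `y` unchanged to each call.
batched = torch.vmap(dummy_evaluation, in_dims=(0, None))
vmap_out = batched(pop_x, y)

# Reference semantics: call the function once per individual and stack the results.
loop_out = torch.stack([dummy_evaluation(individual, y) for individual in pop_x])

same = torch.equal(vmap_out, loop_out)
print(same)
```

Both versions produce the same tensor; the `torch.vmap` version simply avoids the Python-level loop.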

### Just-In-Time (JIT) Support in PyTorch

In PyTorch, `torch.jit.trace` and `torch.jit.script` provide two distinct types of JIT tools, supporting function performance optimization through tracing and scripting, respectively.

The `torch.jit.trace` method runs the function once on example inputs and records the tensor operations that are executed. Tracing is fast and broadly compatible, for instance with `torch.vmap` operations. However, because only the executed path is recorded, it works well for simple functions but is not suitable for complex tasks involving dynamic if-else branches and loop control flows.
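The control-flow limitation can be illustrated with a small sketch. The function `scale_by_sign` is a hypothetical example, not part of EvoX or PyTorch; it contains a branch that depends on the values of its input.

```python
import warnings

import torch


def scale_by_sign(x: torch.Tensor) -> torch.Tensor:
    # Data-dependent branch: which path runs depends on the values in `x`.
    if x.sum() > 0:
        return x * 2
    return x * -2


with warnings.catch_warnings():
    # Tracing a data-dependent branch emits a TracerWarning; silence it here.
    warnings.simplefilter("ignore")
    # The example input is positive, so only the first branch is recorded.
    traced = torch.jit.trace(scale_by_sign, torch.ones(3))

negative = -torch.ones(3)
eager_out = scale_by_sign(negative)  # takes the second branch
traced_out = traced(negative)        # replays the recorded first branch

print(eager_out)
print(traced_out)
```

For the negative input, the eager function and the traced function disagree, because the trace contains only the branch taken by the example input.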

In [2]:
import functools


@functools.partial(torch.vmap, in_dims=(0, None))
def vmap_sample_func(x: torch.Tensor, y: torch.Tensor):
    return x.sum() + y

In the example below, tracing the `vmap`-wrapped function yields the correct code representation, in which the sum is taken per individual (along dimension 1):

In [3]:
traced_vmap_func = torch.jit.trace(
    vmap_sample_func,
    example_inputs=(pop_x, y),
)
print(traced_vmap_func.code)

def vmap_sample_func(x: Tensor,
    y: Tensor) -> Tensor:
  _0 = torch.add(torch.view(torch.sum(x, [1]), [3, 1]), y)
  return _0
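As a quick sanity check, the traced function can also be compared numerically with the untraced `torch.vmap` version. This is a minimal sketch that redefines the same function under the illustrative name `sample_func`.

```python
import torch


def sample_func(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x.sum() + y


# Batch over dim 0 of `x`; share `y` across the batch.
vmapped = torch.vmap(sample_func, in_dims=(0, None))

pop_x = torch.arange(9).repeat(3, 1)
y = torch.arange(9)

traced = torch.jit.trace(vmapped, example_inputs=(pop_x, y))

# The traced function should reproduce the eager vmap result on the same inputs.
matches = torch.equal(traced(pop_x, y), vmapped(pop_x, y))
print(matches)
```

Note that the printed code above contains the literal shape `[3, 1]`, which suggests that the trace is specialized to the population size of the example inputs. This reading is an inference from the output, so it is worth re-tracing or testing before reusing a traced function with a different population size.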



In contrast, the `torch.jit.script` method compiles the function's source code directly, so it preserves dynamic control flow and is better suited to complex tasks. Its compatibility with other PyTorch features is more limited, however.

In this example, scripting the same `vmap_sample_func` function produces an **incorrect** code representation: the batching added by `torch.vmap` is lost, and `x` is summed over all elements rather than per individual:

In [4]:
scripted_vmap_func = torch.jit.script(
    vmap_sample_func,
    example_inputs=(pop_x, y),
)
print(scripted_vmap_func.code)

def vmap_sample_func(x: Tensor,
    y: Tensor) -> Tensor:
  return torch.add(torch.sum(x), y)
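Where scripting does help is with data-dependent control flow. The sketch below scripts a hypothetical function, `scale_by_sign`, whose branch depends on the input values; both branches survive compilation, so each input follows the correct path.

```python
import torch


@torch.jit.script
def scale_by_sign(x: torch.Tensor) -> torch.Tensor:
    # The branch is compiled as real control flow, not fixed at compile time.
    if bool(x.sum() > 0):
        out = x * 2
    else:
        out = x * -2
    return out


pos_out = scale_by_sign(torch.ones(3))   # first branch
neg_out = scale_by_sign(-torch.ones(3))  # second branch

print(pos_out)
print(neg_out)
```

Both calls return a tensor of twos, which shows that the scripted function evaluates the condition at run time for each input.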





### Combined Usage of JIT and Vectorizing Map in PyTorch

As the examples above show, `torch.jit.trace` and `torch.jit.script` differ in their compatibility with `torch.vmap`, so combining them requires some coordination.

The figure below illustrates the relationship between `torch.jit.script`, `torch.jit.trace`, and `torch.vmap`, highlighting which of them can invoke which. In the figure, "A invokes B" means that B can be called from within A.

<div style="text-align: center;">
    <img src="../../_static/jit_vmap.png" alt="jit introduction" style="width: 400px"/>
</div>

For detailed usage of JIT and vectorizing map in PyTorch, please refer to the official PyTorch documentation for [TorchScript](https://pytorch.org/docs/stable/jit.html) and [`torch.vmap`](https://pytorch.org/docs/stable/generated/torch.vmap.html).

## Specific Optimization Support in EvoX

Within EvoX, most functions are defined inside classes, particularly subclasses of `ModuleBase`. To provide more comprehensive optimization support for such functions, EvoX offers specific enhancements.