[AOTI] align data_size of the constants (pytorch#127610)
pytorch#124272 set the alignment on `consts_o`, but if any tensor in `consts_o` has a `data_size` that is not divisible by the alignment, the tensors that follow it are no longer aligned, resulting in poor performance on CPU.
This PR aligns the `data_size` as well and pads the serialized bytes accordingly. Since the tensor's `size`, rather than its `data_size`, is used when creating the tensor from the serialized bytes ([link](https://github.com/pytorch/pytorch/blob/f4d7cdc5e63c786b1f6588eafa53bbc6d33c3826/torch/csrc/inductor/aoti_runtime/model.h#L236-L259)), there is no correctness issue; `data_size` is only used to record the [bytes_read](https://github.com/pytorch/pytorch/blob/f4d7cdc5e63c786b1f6588eafa53bbc6d33c3826/torch/csrc/inductor/aoti_runtime/model.h#L217).
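
To make the failure mode concrete, here is a toy sketch (illustration only, not code from this PR) of how one unpadded `data_size` shifts every following constant off the 64-byte boundary, and how rounding each `data_size` up restores alignment:
```
# Hypothetical data_size values of three consecutive constants in consts_o.
sizes = [40, 400, 64]

offsets, cursor = [], 0
for n in sizes:
    offsets.append(cursor)
    cursor += n                  # unpadded: constants are packed back to back
print(offsets)                   # [0, 40, 440] -> the 2nd and 3rd start misaligned

cursor = 0
for n in sizes:
    assert cursor % 64 == 0      # padded: every constant starts 64-byte aligned
    cursor += (n + 63) & -64     # round each data_size up to a multiple of 64
```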

This PR improves CPU performance for 4 models in HF, 7 models in TIMM, and 1 model in Torchbench.

For the unit test, I add a bias whose original `data_size` (40 bytes) is not divisible by the alignment, to verify correctness:
```
constants_info_[0].dtype = static_cast<int32_t>(at::kFloat);
constants_info_[0].data_size = 64; // was 40 before this PR
constants_info_[0].shape = {10};

constants_info_[1].dtype = static_cast<int32_t>(at::kFloat);
......
```
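
As a sanity check on those numbers (a standalone sketch assuming float32 and 64-byte alignment): the shape `{10}` still describes 10 valid elements, i.e. 40 bytes of real data, and only `data_size` is rounded up:
```
num_elements, itemsize, align_bytes = 10, 4, 64
raw = num_elements * itemsize                     # 40 bytes of actual bias data
padded = (raw + align_bytes - 1) & -align_bytes   # 64 bytes recorded as data_size
print(raw, padded)                                # 40 64
```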

Pull Request resolved: pytorch#127610
Approved by: https://github.com/jgong5, https://github.com/desertfire
chunyuan-w authored and TharinduRusira committed Jun 14, 2024
1 parent 553de1d commit f7aeb49
Showing 5 changed files with 85 additions and 59 deletions.
3 changes: 2 additions & 1 deletion test/inductor/test_aot_inductor.py
@@ -422,9 +422,10 @@ class LinearModel(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.weight = torch.randn(10, 10, device=device).to(dtype)
self.bias = torch.randn(10, device=device).to(dtype)

def forward(self, y):
return torch.nn.functional.linear(y, self.weight)
return torch.nn.functional.linear(y, self.weight, self.bias)

example_inputs = (torch.randn(10, 10, device=self.device).to(dtype),)

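For reference, a quick standalone check (runnable in any recent PyTorch) that the new bias constant starts out with a storage size that is not a multiple of the 64-byte alignment:
```
import torch

bias = torch.randn(10, dtype=torch.float32)
print(bias.untyped_storage().nbytes())  # 40 -> recorded as data_size = 64 after padding
```
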
44 changes: 30 additions & 14 deletions torch/_inductor/codecache.py
@@ -56,7 +56,7 @@
_reload_python_module_in_subproc,
)
from torch._inductor.runtime.runtime_utils import cache_dir
from torch._inductor.utils import clear_on_fresh_inductor_cache, is_linux
from torch._inductor.utils import ALIGN_BYTES, clear_on_fresh_inductor_cache, is_linux

from torch._logging import trace_structured
from torch._subclasses.fake_tensor import (
@@ -2059,10 +2059,14 @@ def _compile_consts_linux(consts: bytes) -> str:
# as read-only (i.e. .lrodata) which could accommodate larger size of data
# to be linked.
rename_data = " .data=.lrodata,alloc,load,readonly,data,contents"

assert (
ALIGN_BYTES & (ALIGN_BYTES - 1)
) == 0 and ALIGN_BYTES >= 64, "must be power of 2 and >= 64"
cmd = (
f"{objcopy_command} --rename-section"
f"{rename_data}"
" --set-section-alignment .data=64" # following the gAlignment of CPU in c10/core/alignment.h
f" --set-section-alignment .data={ALIGN_BYTES}" # following the gAlignment of CPU in c10/core/alignment.h
f" {consts_o} {consts_o}"
)
log.debug("aot constant rename section command: %s", cmd)
@@ -2186,7 +2190,14 @@ def _compile_consts_darwin(consts: bytes) -> str:
else:
run_command_and_check(compile_cmd)

def _to_bytes(t: torch.Tensor) -> bytes:
def _to_bytes(t: torch.Tensor, all_cuda: bool) -> bytes:
def _pad_to_alignment(raw_bytes):
padded_bytes = raw_bytes.ljust(
(len(raw_bytes) + ALIGN_BYTES - 1) // ALIGN_BYTES * ALIGN_BYTES,
b"\x00",
)
return padded_bytes

# This serializes the tensor's untyped_storage to bytes by accessing
# the raw data of the underlying structure.
import ctypes
@@ -2195,22 +2206,27 @@ def _to_bytes(t: torch.Tensor) -> bytes:
return b""

if t.is_mkldnn:
raw_array = ctypes.cast(
torch.ops.mkldnn.data_ptr(t),
ctypes.POINTER(ctypes.c_ubyte * torch.ops.mkldnn._nbytes(t)),
)
return bytes(raw_array.contents)
data_ptr = torch.ops.mkldnn.data_ptr(t)
nbytes = torch.ops.mkldnn._nbytes(t)
else:
t_cpu = t.untyped_storage().cpu()
data_ptr = t_cpu.data_ptr()
nbytes = t_cpu.nbytes()

t_cpu = t.untyped_storage().cpu()
raw_array = ctypes.cast(
t_cpu.data_ptr(),
ctypes.POINTER(ctypes.c_ubyte * t_cpu.nbytes()),
data_ptr,
ctypes.POINTER(ctypes.c_ubyte * nbytes),
)
raw_bytes = bytes(raw_array.contents)
return raw_bytes if all_cuda else _pad_to_alignment(raw_bytes)

return bytes(raw_array.contents)

all_cuda = all(
graph.get_original_value_of_constant(name).is_cuda
for name in graph.constants.keys()
if name not in graph.folded_constants
)
serialized_weights = b"".join(
_to_bytes(graph.get_original_value_of_constant(name))
_to_bytes(graph.get_original_value_of_constant(name), all_cuda)
for name in graph.constants.keys()
if name not in graph.folded_constants
)
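A standalone sketch of the padding behavior added to `_to_bytes` above (constant and helper copied from the diff, assuming `ALIGN_BYTES == 64`): CPU constants are zero-padded up to the next alignment boundary, while an all-CUDA constant blob is left untouched:
```
ALIGN_BYTES = 64

def _pad_to_alignment(raw_bytes):
    # ljust pads with zero bytes up to the next multiple of ALIGN_BYTES.
    return raw_bytes.ljust(
        (len(raw_bytes) + ALIGN_BYTES - 1) // ALIGN_BYTES * ALIGN_BYTES, b"\x00"
    )

print(len(_pad_to_alignment(b"x" * 40)))   # 64: padded up to the boundary
print(len(_pad_to_alignment(b"x" * 64)))   # 64: already aligned, unchanged
print(len(_pad_to_alignment(b"")))         # 0: an empty buffer stays empty
```
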
30 changes: 19 additions & 11 deletions torch/_inductor/codegen/cpp_wrapper_cpu.py
@@ -15,7 +15,7 @@
from torch.fx.experimental.symbolic_shapes import ConvertIntKey, DivideByKey
from .. import config, ir
from ..codecache import CudaKernelParamCache
from ..utils import cache_on_self, sympy_product
from ..utils import _align, ALIGN_BYTES, cache_on_self, sympy_product
from ..virtualized import V
from .aoti_hipify_utils import maybe_hipify_code_wrapper
from .common import IndentedBuffer
@@ -239,8 +239,6 @@ class RAIIPyObject {
"""
)

from .memory_planning import ALIGN_BYTES

# Round up to the nearest multiple of ALIGN_BYTES
# ALIGN_BYTES must be a power of 2
self.header.splice(
@@ -721,6 +719,11 @@ def codegen_model_constructor(self):
), f"input {name=} cannot be symbolic"
self.write_input_output_info("inputs_info_", idx, name)

all_cuda = all(
V.graph.get_original_value_of_constant(name).is_cuda
for name in V.graph.constants.keys()
if name not in V.graph.folded_constants
)
for idx, name in enumerate(V.graph.constants.keys()):
tensor = V.graph.get_original_value_of_constant(name)
assert isinstance(tensor, torch.Tensor)
@@ -731,14 +734,19 @@
self.prefix.writeline(
f"constants_info_[{idx}].offset = {tensor.storage_offset()};"
)
if tensor.is_mkldnn:
self.prefix.writeline(
f"constants_info_[{idx}].data_size = {torch.ops.mkldnn._nbytes(tensor)};"
)
else:
self.prefix.writeline(
f"constants_info_[{idx}].data_size = {tensor.untyped_storage().nbytes()};"
)

# If the constants to serialize contain CPU tensors, we always align their data_size to 64.
# When loading the constants, the valid data depends on size,
# not data_size, so there is no correctness issue.
data_size = (
torch.ops.mkldnn._nbytes(tensor)
if tensor.is_mkldnn
else tensor.untyped_storage().nbytes()
)
self.prefix.writeline(
f"constants_info_[{idx}].data_size = {data_size if all_cuda else _align(data_size)};"
)

from_folded = "true" if name in V.graph.folded_constants else "false"
self.prefix.writeline(
f"constants_info_[{idx}].from_folded = {from_folded};"
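The generated `constants_info_` entries mirror that layout. Below is a rough sketch of the size selection (`recorded_data_size` is a hypothetical helper, not the actual codegen API; it assumes a PyTorch build containing this PR so `_align` is importable from `torch/_inductor/utils.py`):
```
import torch
from torch._inductor.utils import _align

def recorded_data_size(tensor, all_cuda):
    # MKLDNN tensors report their bytes via the mkldnn op; everything else via
    # untyped_storage(). Unless all constants are CUDA, the recorded size is aligned.
    nbytes = (
        torch.ops.mkldnn._nbytes(tensor)
        if tensor.is_mkldnn
        else tensor.untyped_storage().nbytes()
    )
    return nbytes if all_cuda else _align(nbytes)

bias = torch.randn(10)                            # 40 bytes of float32 storage
print(recorded_data_size(bias, all_cuda=False))   # 64
print(recorded_data_size(bias, all_cuda=True))    # 40
```
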
32 changes: 1 addition & 31 deletions torch/_inductor/codegen/memory_planning.py
@@ -10,7 +10,7 @@

import torch
from .. import config, ir
from ..utils import cache_on_self, CachedMethod, IndentedBuffer
from ..utils import _align, align, cache_on_self, CachedMethod, IndentedBuffer
from ..virtualized import V

from .wrapper import (
@@ -22,36 +22,6 @@
)


ALIGN_BYTES = 64
assert (ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0 and ALIGN_BYTES >= 8, "must be power of 2"


def _align(nbytes):
"""Round up to the nearest multiple of ALIGN_BYTES"""
return (nbytes + ALIGN_BYTES - 1) & -ALIGN_BYTES


def _is_aligned(v: sympy.Expr):
"""v can be statically proven to be a multiple of ALIGN_BYTES"""
if isinstance(v, (sympy.Add, sympy.Max)):
return all(map(_is_aligned, v.args))
return isinstance(v, align) or sympy.gcd(v, ALIGN_BYTES) == ALIGN_BYTES


class align(sympy.Function):
"""Symbolically round up to the nearest multiple of ALIGN_BYTES"""

nargs = (1,)
is_integer = True

@classmethod
def eval(cls, value):
if isinstance(value, (int, sympy.Integer)):
return _align(int(value))
if _is_aligned(value):
return value


@dataclasses.dataclass
class LiveRange:
"""
35 changes: 33 additions & 2 deletions torch/_inductor/utils.py
@@ -63,7 +63,36 @@
_T = TypeVar("_T")
VarRanges = Dict[sympy.Expr, sympy.Expr]

ALIGNMENT = 16
GPU_ALIGN_BYTES = 16

ALIGN_BYTES = 64
assert (ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0 and ALIGN_BYTES >= 8, "must be power of 2"


def _align(nbytes):
"""Round up to the nearest multiple of ALIGN_BYTES"""
return (nbytes + ALIGN_BYTES - 1) & -ALIGN_BYTES


def _is_aligned(v: sympy.Expr):
"""v can be statically proven to be a multiple of ALIGN_BYTES"""
if isinstance(v, (sympy.Add, sympy.Max)):
return all(map(_is_aligned, v.args))
return isinstance(v, align) or sympy.gcd(v, ALIGN_BYTES) == ALIGN_BYTES


class align(sympy.Function):
"""Symbolically round up to the nearest multiple of ALIGN_BYTES"""

nargs = (1,)
is_integer = True

@classmethod
def eval(cls, value):
if isinstance(value, (int, sympy.Integer)):
return _align(int(value))
if _is_aligned(value):
return value


def do_bench_using_profiling(fn: Callable[[], Any], warmup=25, rep=100) -> float:
@@ -1548,7 +1577,9 @@ def tensor_is_aligned(tensor: torch.Tensor):
# but symbolic storage_offsets are. For consistency, we suppress guard creation
# upon performing this check: that ensures that we don't add recompiles when we
# add this logic.
return (tensor.storage_offset() * get_dtype_size(tensor.dtype)) % ALIGNMENT == 0
return (
tensor.storage_offset() * get_dtype_size(tensor.dtype)
) % GPU_ALIGN_BYTES == 0


def should_assume_input_aligned(example_input: torch.Tensor):
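With the helpers relocated to `torch/_inductor/utils.py`, other modules now import them from there. A brief usage sketch (assumes a PyTorch build containing this change and sympy installed):
```
import sympy
from torch._inductor.utils import ALIGN_BYTES, _align, align

s = sympy.Symbol("s", integer=True, positive=True)
print(ALIGN_BYTES)             # 64
print(_align(40))              # 64: concrete byte counts round up
print(align(ALIGN_BYTES * s))  # 64*s: provably aligned expressions pass through
print(align(s))                # align(s): unknown sizes stay symbolic
```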
