[ONNX] Reduce exporter memory usage by removing intermediate values (pytorch#101148)

This commit reduces the exporter's memory usage by as much as 50%. During the shape inference step, the exporter caches the values of intermediate tensors in a `ConstantValueMap`. This can use as much memory as the model itself, or even more. For example, model weight tensors are often fed to a Transpose layer, and the output of that layer is the same size as the weights. This commit fixes the issue by removing intermediate tensor values once all of their consumers have used them.

The cached values are only used for shape inference, so removing them after their last use should be safe. `ConstantValueMap` is cleared anyway once shape inference is complete for the entire graph.
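To make the eviction rule concrete, here is a minimal, self-contained Python sketch (illustrative only, not the exporter's actual code; the real C++ pass, `RemoveProcessedInputs` in the diff below, decides that a value is dead by checking that every one of its consumer nodes has already been shape-inferred, which this sketch approximates with an explicit remaining-consumer count):
```python
# Illustrative sketch only: cache an intermediate tensor during shape
# inference and drop it once its last consumer has been processed.
cache = {}             # stands in for ConstantValueMap's tensor values
remaining_uses = {}    # value name -> number of consumer nodes not yet processed

def set_value(name, tensor, num_consumers):
    cache[name] = tensor
    remaining_uses[name] = num_consumers

def consume(name):
    # Called once per consumer node after that node has been shape-inferred.
    remaining_uses[name] -= 1
    if remaining_uses[name] == 0:
        cache.pop(name, None)  # last consumer done: free the intermediate tensor
```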

As an example, here is the model from issue pytorch#61263:
```python
import torch
import math

# Size in GB
tensor_size = 1
model_size = 8

layers_num = model_size // tensor_size
kB = 1024
MB = kB * kB
GB = MB * kB
precision_size = 4 # bytes per float
activation_size = math.floor(math.sqrt(tensor_size * GB / precision_size))
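# Each Linear weight holds activation_size**2 floats (~tensor_size GB at 4
# bytes per float), so the layers_num layers add up to roughly model_size GB.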

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        for i in range(layers_num):
            name = "fc_%d" % i
            linear = torch.nn.Linear(activation_size, activation_size)
            setattr(self, name, linear)
    def forward(self, x):
        for i in range(layers_num):
            name = "fc_%d" % i
            linear = getattr(self, name)
            x = linear(x)
        return x

model = Net().cuda()
input = torch.zeros(activation_size, requires_grad=True).cuda()
with torch.no_grad():
    torch.onnx.export(model, (input, ), './model_large.onnx', do_constant_folding=False, opset_version=13)
```
The model is just several large linear layers stacked together. Before this commit, my max GPU usage during export was about 16.7 GB, twice the model size. With this commit in combination with pytorch#101134, it was only about 9.5 GB.
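For reference, one way to observe peak GPU memory from inside the process is sketched below, reusing `model` and `input` from the script above. The numbers quoted here were not necessarily measured this way, and `torch.cuda.max_memory_allocated` only counts memory that goes through PyTorch's caching allocator:
```python
# Sketch: report peak CUDA memory allocated by PyTorch during the export.
# Assumes `model` and `input` from the script above are already defined.
torch.cuda.reset_peak_memory_stats()
with torch.no_grad():
    torch.onnx.export(model, (input,), './model_large.onnx',
                      do_constant_folding=False, opset_version=13)
peak_gb = torch.cuda.max_memory_allocated() / 1024**3
print(f"Peak GPU memory allocated during export: {peak_gb:.1f} GB")
```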

Together with pytorch#101134, fixes issue pytorch#61263

Pull Request resolved: pytorch#101148
Approved by: https://github.com/BowenBao
ilyasher authored and alimoezzi committed Jun 3, 2023
1 parent be9295f commit d537e2e
Showing 3 changed files with 40 additions and 0 deletions.
4 changes: 4 additions & 0 deletions torch/csrc/jit/passes/onnx/constant_map.cpp
@@ -78,6 +78,10 @@ c10::optional<at::Tensor> ConstantValueMap::GetValue(
  return ConstantValueMap::getInstance().tensorValueMap[tensorName];
}

void ConstantValueMap::EraseValue(const std::string& tensorName) {
  ConstantValueMap::getInstance().tensorValueMap.erase(tensorName);
}

std::vector<int64_t> ConstantValueMap::GetCompleteShapeInto1DInt64Vector(
    const c10::SymbolicShape& shape) {
  TORCH_INTERNAL_ASSERT(shape.isComplete());
1 change: 1 addition & 0 deletions torch/csrc/jit/passes/onnx/constant_map.h
@@ -36,6 +36,7 @@ class ConstantValueMap {
  static void SetValue(const std::string& tensorName, const at::Tensor& value);
  static bool HasValue(const std::string& tensorName);
  static c10::optional<at::Tensor> GetValue(const std::string& tensorName);
  static void EraseValue(const std::string& tensorName);

  static std::vector<int64_t> GetCompleteShapeInto1DInt64Vector(
      const c10::SymbolicShape& shape);
35 changes: 35 additions & 0 deletions torch/csrc/jit/passes/onnx/shape_type_inference.cpp
@@ -1821,6 +1821,40 @@ void FetchBlockInputMetadataFromParent(Block* b) {
}
}

void RemoveProcessedInputs(const Node* n) {
  // After processing a node for shape inference, remove intermediate tensors
  // that are stored in ConstantValueMap to reduce memory usage.
  // This will only remove tensors that are no longer needed by any other node.

  // Returns whether a node was already processed for shape inference.
  const auto isNodeProcessed = [](const Node* node) {
    const auto& outputs = node->outputs();
    return std::any_of(outputs.begin(), outputs.end(), [](const Value* output) {
      // Assumes shape inference can at least determine the rank of the outputs.
      // If this assumption is wrong, some intermediate tensors will only be
      // deleted once shape inference is completed for the entire graph.
      return ConstantValueMap::HasRank(output->debugName());
    });
  };

  // An input value is no longer needed if all of its consumer nodes
  // have already been processed.
  const auto isValueNoLongerNeeded = [isNodeProcessed](const Value* input) {
    const auto& uses = input->uses();
    return std::all_of(
        uses.begin(), uses.end(), [isNodeProcessed](const Use& use) {
          return isNodeProcessed(use.user);
        });
  };

  for (const auto* input : n->inputs()) {
    if (ConstantValueMap::HasValue(input->debugName()) &&
        isValueNoLongerNeeded(input)) {
      ConstantValueMap::EraseValue(input->debugName());
    }
  }
}

void ONNXShapeTypeInference(
    Block* b,
    const ParamMap& params_dict,
@@ -1850,6 +1884,7 @@ ONNXShapeTypeInference(
      ONNXShapeTypeInference(subblock, params_dict, opset_version);
    }
    ONNXShapeTypeInference(n, params_dict, opset_version);
    RemoveProcessedInputs(n);
  }
}
