[Bug] Repeated inference with dynamic shape leads to out of memory error #8233

Open
dvhg opened this issue Jun 10, 2021 · 4 comments
Labels
frontend:pytorch python/tvm/relay/frontend/torch

Comments

@dvhg
Contributor

dvhg commented Jun 10, 2021

I'm trying to run PyTorch MaskRCNN on GPU and have been running into GPU memory issues when running repeated inferences on different inputs. The error messages vary, but this is the most common one:

terminate called after throwing an instance of 'dmlc::Error'
  what():  [20:11:56] /home/ubuntu/tvm/include/tvm/runtime/device_api.h:260: unknown type =0

When monitoring GPU memory usage with nvidia-smi, I see usage increase over time until the test crashes as it nears the maximum. I'm running this on Ubuntu 18.04 with a T4 GPU (16 GB of GPU memory).
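
To log this growth numerically instead of watching nvidia-smi interactively, a small helper could be called once per loop iteration. This is a minimal sketch, not part of the original report: `gpu_memory_used_mib` and `parse_memory_used_mib` are hypothetical names, and it assumes `nvidia-smi` is on the PATH. Note that `memory.used` counts memory held by all processes on that GPU, not only the Python process running the VM.

```python
import subprocess
from typing import List


def parse_memory_used_mib(csv_text: str) -> List[int]:
    """Parse `nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits`.

    Each non-empty line is the used memory, in MiB, of one GPU.
    """
    return [int(line) for line in csv_text.splitlines() if line.strip()]


def gpu_memory_used_mib(gpu_index: int = 0) -> int:
    """Return the memory currently in use on one GPU, in MiB (all processes)."""
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,noheader,nounits"],
        capture_output=True,
        text=True,
        check=True,
    )
    return parse_memory_used_mib(result.stdout)[gpu_index]
```

For example, calling `print(i, gpu_memory_used_mib())` right after `vm.run()` prints one reading per inference, which makes the per-input growth easy to plot.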

Following the form of the unit test in test_tensorrt.py, the script below should reproduce the problem I'm seeing (using the COCO dataset). It differs from the unit test in two ways:

  1. The VM is run on GPU instead of CPU:
        ctx = tvm.gpu(0)
        vm = VirtualMachine(vm_exec, ctx)
  2. Inference is run on many different inputs (from the COCO dataset) rather than a single input.

@masahi, I heard you've been working on PyTorch MaskRCNN. Have you seen this issue in your testing, or is there a problem in my script? Thank you!

import tvm
from tvm import relay
from tvm.runtime.vm import VirtualMachine
from tvm.contrib.download import download

import numpy as np
import cv2

import torch
import torchvision

in_size = 300

input_shape = (1, 3, in_size, in_size)


def do_trace(model, inp):
    model_trace = torch.jit.trace(model, inp)
    model_trace.eval()
    return model_trace


def dict_to_tuple(out_dict):
    if "masks" in out_dict.keys():
        return out_dict["boxes"], out_dict["scores"], out_dict["labels"], out_dict["masks"]
    return out_dict["boxes"], out_dict["scores"], out_dict["labels"]


class TraceWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, inp):
        out = self.model(inp)
        return dict_to_tuple(out[0])


model_func = torchvision.models.detection.maskrcnn_resnet50_fpn
model = TraceWrapper(model_func(pretrained=True))

model.eval()
inp = torch.Tensor(np.random.uniform(0.0, 250.0, size=(1, 3, in_size, in_size)))

with torch.no_grad():
    out = model(inp)
    script_module = do_trace(model, inp)

input_name = "input0"
shape_list = [(input_name, input_shape)]
mod, params = relay.frontend.from_pytorch(script_module, shape_list)
target = "cuda"

with tvm.transform.PassContext(opt_level=3, disabled_pass=["FoldScaleAxis"]):
    vm_exec = relay.vm.compile(mod, target=target, params=params)

ctx = tvm.gpu(0)
vm = VirtualMachine(vm_exec, ctx)

import os

img_dirpath = 'data/COCO_2017/subset/val2017/'
i = 0
for root, dirs, files in os.walk(img_dirpath):
    for f in files:
        print(i)
        i += 1
        imgname = os.path.join(root, f)
        img = cv2.imread(imgname)
        img = cv2.resize(img, (in_size, in_size))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = np.transpose(img / 255.0, [2, 0, 1])
        img = np.expand_dims(img, axis=0).astype('float32')
        vm.set_input("main", **{input_name: img})
        tvm_res = vm.run()

@masahi
Member

masahi commented Jun 10, 2021

Is this specific to MaskRCNN? What happens if the target is CPU?

@dvhg
Contributor Author

dvhg commented Jun 10, 2021

Good point; I hadn't thought to check memory on CPU targets. With the llvm target, I also see memory usage increase with each inference. After about 300 inferences, the Python process consumes ~25% of my 128 GB of physical RAM. The rate of increase seems to slow down over time, but it varies a lot depending on the input.

I've also seen this happen with FasterRCNN.
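
One way to quantify the CPU-side growth per inference is to record the process's peak resident set size after each call. This is a minimal sketch using the standard `resource` module (Unix only); `peak_rss_mib` and `track_peak_rss` are hypothetical helper names, and `step` is assumed to wrap one `vm.set_input(...)` plus `vm.run()` call. Because it reports the peak rather than the current RSS, the recorded series never decreases: it shows growth but hides memory that is later released.

```python
import resource
import sys
from typing import Callable, List


def peak_rss_mib() -> float:
    """Peak resident set size of the current process so far, in MiB."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is reported in bytes on macOS but in kilobytes on Linux.
    if sys.platform == "darwin":
        return peak / (1024 * 1024)
    return peak / 1024


def track_peak_rss(step: Callable[[int], None], n_iters: int) -> List[float]:
    """Run step(i) n_iters times and record peak RSS (MiB) after each call."""
    history = []
    for i in range(n_iters):
        step(i)  # e.g. one inference on input i (assumed)
        history.append(peak_rss_mib())
    return history
```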

@trevor-m
Contributor

trevor-m commented Jun 16, 2021

Hi @dvhg @masahi
I was able to reproduce this with the simple script below. It looks like the issue may affect all ops with dynamic shapes.

On my T4 GPU with 16 GB of GPU memory, using the pooled allocator, I run out of memory on the 31st iteration.
If I switch to the naive allocator, I can run the script indefinitely, and at certain points GPU memory usage decreases. The maximum memory usage I ever see is only 3 GB, so it looks like we can use the naive allocator to work around this issue for the moment.

Perhaps the pooled allocator is allocating too much memory, or doing something else unexpected?

import tvm
from tvm import relay
import numpy as np

in_size = 500
input_shape = (relay.Any(), 3, in_size, in_size)
weight_shape = (32, 3, 3, 3)

x = relay.var("input", shape=input_shape, dtype="float32")
w = relay.var("weight", shape=weight_shape, dtype="float32")
y = relay.nn.conv2d(x, w, channels=32, kernel_size=(3, 3))
mod = tvm.IRModule()
mod["main"] = relay.Function([x, w], y)
params = {"weight": np.random.randn(*weight_shape).astype("float32")}

with tvm.transform.PassContext(opt_level=3, disabled_pass=["FoldScaleAxis"]):
    vm_exec = relay.vm.compile(mod, target="cuda", params=params)

ctx = tvm.gpu(0)
vm = tvm.runtime.vm.VirtualMachine(vm_exec, ctx)

for i in range(1000):
    print("Iteration: ", i)
    batch_size = i % 100
    x = np.random.randn(batch_size, 3, in_size, in_size).astype("float32")
    tvm_res = vm.run(x)

cc @zhiics
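
Trevor's observations are consistent with a pool that caches freed buffers by size: every new batch size produces buffers of a size not seen before, so nothing cached can be reused and the pool keeps growing. The toy model below illustrates this under explicit assumptions; it is not TVM's implementation. It assumes the pool reuses a cached buffer only on an exact match of the page-rounded size (4096-byte pages assumed) and never hands cached memory back to the device, and it tracks only the conv2d output buffer, shape (batch, 32, 498, 498) in float32, ignoring the input, workspace, and any other intermediates. `PooledAllocatorModel`, `NaiveAllocatorModel`, `conv_output_bytes`, and `simulate` are hypothetical names.

```python
from collections import defaultdict
from typing import List

PAGE_SIZE = 4096  # assumed rounding granularity, in bytes


def round_up(nbytes: int, page: int = PAGE_SIZE) -> int:
    return (nbytes + page - 1) // page * page


class NaiveAllocatorModel:
    """Requests memory on every alloc and returns it to the device on free."""

    def __init__(self) -> None:
        self.held = 0  # bytes currently obtained from the device

    def alloc(self, nbytes: int) -> int:
        self.held += nbytes
        return nbytes  # in this toy model a buffer is represented by its size

    def free(self, buf: int) -> None:
        self.held -= buf


class PooledAllocatorModel:
    """Caches freed buffers by page-rounded size and never releases them."""

    def __init__(self) -> None:
        self.held = 0
        self.pool = defaultdict(list)  # size -> cached free buffers of that size

    def alloc(self, nbytes: int) -> int:
        size = round_up(nbytes)
        if self.pool[size]:
            return self.pool[size].pop()  # reuse only on an exact size match
        self.held += size  # otherwise obtain new memory from the device
        return size

    def free(self, buf: int) -> None:
        self.pool[buf].append(buf)  # kept for reuse; self.held is unchanged


def conv_output_bytes(batch: int) -> int:
    # (batch, 3, 500, 500) input, 32 filters of 3x3, no padding, stride 1
    # -> (batch, 32, 498, 498) float32 output
    return batch * 32 * 498 * 498 * 4


def simulate(allocator, n_iters: int) -> List[int]:
    """Mimic the repro loop: one output buffer per run, freed afterwards."""
    history = []
    for i in range(n_iters):
        buf = allocator.alloc(conv_output_bytes(i % 100))
        allocator.free(buf)
        history.append(allocator.held)
    return history
```

With `batch_size = i % 100`, the pooled model holds at least 465 × 31,744,512 bytes (about 14.8 GB) after iteration 30 from output buffers alone, the same order of magnitude as a 16 GB T4, while the naive model returns every buffer and holds nothing between iterations. Once every batch size has been seen (after iteration 99), later iterations hit the cache and the total stops growing. Assuming a TVM version whose `VirtualMachine` constructor accepts a `memory_cfg` argument, the naive-allocator workaround corresponds to `memory_cfg="naive"`; check this against the TVM version in use.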

@masahi
Member

masahi commented Aug 20, 2021

I've just hit this problem when evaluating PyTorch MaskRCNN on the COCO dataset. I'd like to take a look at this issue.

@masahi masahi self-assigned this Aug 20, 2021
@masahi masahi changed the title [Bug] PyTorch MaskRCNN GPU OOM error [Bug] Repeated inference with dynamic shape leads to out of memory error Jan 9, 2022
@areusch areusch added the needs-triage PRs or issues that need to be investigated by maintainers to find the right assignees to address it label Oct 19, 2022
@hpanda-naut hpanda-naut added frontend:pytorch python/tvm/relay/frontend/torch and removed needs-triage PRs or issues that need to be investigated by maintainers to find the right assignees to address it labels Nov 21, 2022
Development

No branches or pull requests

5 participants