
[BUG] FLOPS compute **FAILS** for F.interpolate when using scale_factor #4504

Closed
@xmfbit

Describe the bug

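When `scale_factor` is passed to `F.interpolate` instead of `size`, the profiler's FLOPs compute function (`_upsample_flops_compute`) fails in some cases, for example when `scale_factor` is a tuple. The original implementation does not appear to be well tested.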

To Reproduce
Code to reproduce the bug:

import torch
from torch import nn
from torch.nn import functional as F

# here is a module with F.interpolate
class Module(nn.Module):
    def forward(self, x):
        # x is a 4D Tensor, which will be upsampled by 2X
        return F.interpolate(x, scale_factor=(2,2))
    
x = torch.randn(1, 3, 4, 5)
m = Module()
# output shape: 1, 3, 8, 10
print(f"output shape of m(x): {m(x).shape}")

from deepspeed.profiling.flops_profiler import get_model_profile

#  here the bug occurs!
get_model_profile(m, kwargs={"x": x}, print_profile=True)

The error message:

Traceback (most recent call last):
  File "test_interpolate.py", line 18, in <module>
    get_model_profile(m, kwargs={"x": x}, print_profile=True)
  File "/home/code/miniconda3/envs/deep/lib/python3.8/site-packages/deepspeed/profiling/flops_profiler/profiler.py", line 1221, in get_model_profile
    _ = model(*args, **kwargs)
  File "/home/code/miniconda3/envs/deep/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1208, in _call_impl
    result = forward_call(*input, **kwargs)
  File "test_interpolate.py", line 9, in forward
    return F.interpolate(x, scale_factor=(2,2))
  File "/home/code/miniconda3/envs/deep/lib/python3.8/site-packages/deepspeed/profiling/flops_profiler/profiler.py", line 836, in newFunc
    flops, macs = funcFlopCompute(*args, **kwds)
  File "/home/code/miniconda3/envs/deep/lib/python3.8/site-packages/deepspeed/profiling/flops_profiler/profiler.py", line 727, in _upsample_flops_compute
    flops * scale_factor**len(input)
TypeError: unsupported operand type(s) for ** or pow(): 'tuple' and 'int'
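
The failure can be isolated without DeepSpeed or PyTorch: Python does not define `**` between a tuple and an int. The snippet below is only a sketch of the failing pattern. `naive_scale` is a hypothetical helper written for illustration, not the profiler's actual code.

```python
def naive_scale(flops, scale_factor, exponent):
    # Hypothetical helper: mirrors the shape of the failing expression,
    # which implicitly assumes scale_factor is a scalar.
    return flops * scale_factor ** exponent


# A scalar scale factor works as expected.
print(naive_scale(100, 2, 2))

# A tuple scale factor raises the same kind of TypeError as in the traceback.
try:
    naive_scale(100, (2, 2), 2)
except TypeError as exc:
    print(f"TypeError: {exc}")
```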

The code should handle a tuple or list `scale_factor`, and should use the number of spatial dimensions (`input.ndim - 2`) as the exponent rather than `len(input)`:

# line 724
    if isinstance(scale_factor, (list, tuple)):
        # see the documentation of `F.interpolate`:
        # the spatial dims are the last `n - 2` dims of the tensor
        assert len(scale_factor) == input.ndim - 2
        flops *= int(_prod(scale_factor))
    else:
        flops *= scale_factor**(input.ndim - 2)
    return flops, 0
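
To sanity-check the proposed logic outside the profiler, here is a minimal standalone sketch that works on plain shapes instead of tensors. The function name `upsample_flops_sketch` is hypothetical, and it assumes the count starts from the number of input elements, so the result equals the number of output elements. It illustrates the proposed fix and is not DeepSpeed's implementation.

```python
import math


def upsample_flops_sketch(input_shape, scale_factor):
    """Estimate upsample FLOPs as input element count times the total scale.

    Hypothetical helper: input_shape is (N, C, *spatial_dims).
    """
    n_spatial = len(input_shape) - 2
    # Assumption: start from the number of input elements.
    flops = math.prod(input_shape)
    if isinstance(scale_factor, (list, tuple)):
        # One factor per spatial dim, as described in the F.interpolate docs.
        assert len(scale_factor) == n_spatial
        flops *= math.prod(scale_factor)
    else:
        # A scalar factor applies to every spatial dim.
        flops *= scale_factor ** n_spatial
    return int(flops)


# Shape from the reproduction above: (1, 3, 4, 5) upsampled 2x -> (1, 3, 8, 10)
print(upsample_flops_sketch((1, 3, 4, 5), (2, 2)))  # 240
print(upsample_flops_sketch((1, 3, 4, 5), 2))       # 240
```

Both calls give 240, which matches the element count of the `1, 3, 8, 10` output in the reproduction script.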

Expected behavior

ds_report output

The bug is still present on the main branch. The ds_report output for my environment is:

[2023-10-12 04:10:59,871] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  using untested triton version (2.1.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/home/code/miniconda3/envs/deep/lib/python3.8/site-packages/torch']
torch version .................... 1.13.0+cu117
deepspeed install path ........... ['/home/code/miniconda3/envs/deep/lib/python3.8/site-packages/deepspeed']
deepspeed info ................... 0.11.1, unknown, unknown
torch cuda version ............... 11.7
torch hip version ................ None
nvcc version ..................... 11.8
deepspeed wheel compiled w. ...... torch 2.1, cuda 12.1
shared memory (/dev/shm) size .... 1007.58 GB

Screenshots
NO

Docker context
NO

Additional context
NO
