
2x performance drop using pytorch depending on how input data is fed into model #198

Closed
stas-sl opened this issue Aug 14, 2023 · 10 comments

@stas-sl

stas-sl commented Aug 14, 2023

Hi!

I'm not sure if this is a good place to get help/report an issue - if not, let me know what would be a better way.

First of all, thanks for the great library! However, while trying it out I ran into some performance behaviour that I can't fully understand. Here is an example script. It is quite simple: I'm using a standard resnet50 model and feeding it some data in a loop. What I noticed is that small changes in how the data is fed can cause a drastic 2x performance drop.

import time
import torch
from torchvision.models import resnet50

n_warmup = 5
n_bench = 50

model = resnet50()
model.eval()
model = torch.compile(model, backend='aio', options={'modelname': 'resnet50'})

data = torch.randn(1, 3, 224, 224)

for i in range(n_bench + n_warmup):
    if i == n_warmup:
        start = time.time()
    x = data                              # Latency: ~100 ms
    # x = torch.stack([data[0]])          # Latency: ~200 ms
    # x = torch.randn(1, 3, 224, 224)     # Latency: ~100 ms
    # x = data.clone()                    # Latency: ~200 ms 
    model(x)

duration = time.time() - start
latency = duration / n_bench

print(f'Latency: {latency * 1000:.0f} ms')

So, if you reuse the same tensor or create random data on each iteration, it works fast. However, if you use torch.stack or just .clone(), performance drops. TBH, I don't understand why such small changes would matter.
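For what it's worth, the only difference I can see at the tensor level is whether a fresh buffer gets allocated on every iteration as an extra op outside the compiled model - a minimal check, nothing aio-specific assumed:

import torch

data = torch.randn(1, 3, 224, 224)

x = data
print(x.data_ptr() == data.data_ptr())        # True  - the same storage is reused every iteration

x = data.clone()
print(x.data_ptr() == data.data_ptr())        # False - a new buffer is allocated and filled

x = torch.stack([data[0]])
print(x.data_ptr() == data.data_ptr())        # False - stack also copies into a new buffer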

I tried torch.jit.script and torch.jit.trace instead of torch.compile, but the results weren't any better; in fact, with torch.jit.script the latency was ~200 ms even in the first case (same data on each iteration). torch.jit.trace and torch.compile behaved very similarly.

For comparison, without compilation/scripting/tracing the latency is ~290 ms no matter how the data is fed.

What I also noticed is a difference in CPU usage. In both cases both CPU cores are 100% utilized, but when it is slow, more time is spent in kernel threads (the red portion).

Fast: [screenshot: CPU usage, mostly user time]

Slow: [screenshot: CPU usage with a larger red/kernel portion]

I'm also attaching logs obtained with AIO_DEBUG_MODE=5

log_fast.txt
log_slow.txt

I'm using your latest docker image amperecomputingai/pytorch:1.7.0

Is there something obvious that I'm missing?

@jan-grzybek-ampere
Member

jan-grzybek-ampere commented Aug 14, 2023

Thanks for reaching out!

You will need to play with the AIO_SKIP_MASTER_THREAD env variable (possible values are 0 and 1, default is 0) to get the best performance, i.e.:

AIO_SKIP_MASTER_THREAD=1 AIO_NUM_THREADS=2 numactl -C 0-1 python3 example.py

x = data                         # Latency: ~250 ms
x = torch.stack([data[0]])       # Latency: ~100 ms
x = torch.randn(1, 3, 224, 224)  # Latency: ~400 ms
x = data.clone()                 # Latency: ~100 ms

AIO_SKIP_MASTER_THREAD=0 AIO_NUM_THREADS=2 numactl -C 0-1 python3 example.py

x = data                         # Latency: ~100 ms
x = torch.stack([data[0]])       # Latency: ~320 ms
x = torch.randn(1, 3, 224, 224)  # Latency: ~110 ms
x = data.clone()                 # Latency: ~320 ms

As you can see, ~100 ms latency is achievable on your 2-threaded Ampere Altra VM in each of the cases listed in your example, provided the proper value of the env variable is set.

FYI, we are working on a solution that relieves the user from the need to adjust this parameter.
For the time being please refer to: https://ampereaidevelopus.s3.amazonaws.com/releases/1.7.0/Ampere+Optimized+PyTorch+Documentation+v1.7.0.pdf

Btw, you will get even better performance by auto-casting to fp16:

AIO_IMPLICIT_FP16_TRANSFORM_FILTER=".*" AIO_SKIP_MASTER_THREAD=0 AIO_NUM_THREADS=2 numactl -C 0-1 python3 test2.py
Latency: 50 ms

:)

@stas-sl
Author

stas-sl commented Aug 14, 2023

Thanks a lot for the fast and helpful reply )

Indeed, AIO_SKIP_MASTER_THREAD=1 helped. For some reason it now works equally fast for all 4 cases (~100-110 ms). I don't mind, but it's a bit confusing why it differs from your results.

So the documentation says:

If the model contains nodes not supported by Ampere Optimized Pytorch we recommend setting following
environmental variable: AIO_SKIP_MASTER_THREAD=1

Does it mean that torch.stack and torch.clone are not currently supported, but once they are, it should work equally fast without skipping the master thread?

As for auto-casting to fp16 - it looks like magic ) Indeed, it works 2x faster. I assumed PyTorch did not support fp16 on CPU, as mentioned in #152 - I actually tried it myself before I found that issue. So setting AIO_IMPLICIT_FP16_TRANSFORM_FILTER=".*" is a kind of workaround to bypass that PyTorch limitation? I'm not sure yet if I will use this mode because of the potential precision loss - I have to check how my model performs - but it's great to know that it actually works!
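In case it helps anyone else, this is roughly how I plan to check the precision impact - a minimal sketch (assuming AIO_IMPLICIT_FP16_TRANSFORM_FILTER is set when the script is launched; the script itself is hypothetical, not part of my benchmark), comparing the compiled output against the eager fp32 output on the same input:

import torch
from torchvision.models import resnet50

model = resnet50()
model.eval()
compiled = torch.compile(model, backend='aio', options={'modelname': 'resnet50'})

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    ref = model(x)       # eager fp32 reference (torch.compile leaves the original module untouched)
    out = compiled(x)    # aio-compiled path, implicitly cast to fp16 if the filter matches

print('max abs diff:', (ref - out).abs().max().item())
print('same top-1 class:', ref.argmax(1).equal(out.argmax(1)))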

I'm closing the issue, as you've already answered it, but I would appreciate another reply )


Not related, but do you have plans to release your packages so they can be used outside Docker, e.g. in my own custom container?

@stas-sl stas-sl closed this as completed Aug 14, 2023
@kkontny
Contributor

kkontny commented Aug 16, 2023

Does it mean that torch.stack and torch.clone are not currently supported, but once they are, it should work equally fast without skipping the master thread?

Yes, torch.stack and torch.clone are currently not supported. There is also a different issue here: these functions are used outside of the optimised model (the code runs outside the model optimised by torch.compile / torch.jit.trace).
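Just to illustrate what "outside of the optimised model" means (a rough sketch only, with a hypothetical wrapper module - whether our backend then handles the captured op any faster is a separate question): if the extra op is placed inside the module passed to torch.compile, it gets captured by the compiled graph instead of executing eagerly on every iteration:

import torch
from torchvision.models import resnet50

class WrappedModel(torch.nn.Module):  # hypothetical wrapper, just to move the op inside the compiled graph
    def __init__(self):
        super().__init__()
        self.backbone = resnet50()

    def forward(self, x):
        x = x.clone()              # now traced/compiled together with the backbone
        return self.backbone(x)

model = WrappedModel()
model.eval()
model = torch.compile(model, backend='aio', options={'modelname': 'resnet50'})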

As for auto-casting to fp16 - it looks like magic ) Indeed, it works 2x faster. I assumed PyTorch did not support fp16 on CPU, as mentioned in #152 - I actually tried it myself before I found that issue. So setting AIO_IMPLICIT_FP16_TRANSFORM_FILTER=".*" is a kind of workaround to bypass that PyTorch limitation? I'm not sure yet if I will use this mode because of the potential precision loss - I have to check how my model performs - but it's great to know that it actually works!

Since x86 CPUs didn't support FP16 before the very recent AVX-512 extensions, it seems that nobody really cared about CPU support for FP16 in PyTorch. Currently PyTorch's FP16 support on CPU is very limited, but there is some work going on in the master branch of PyTorch, e.g.:
pytorch/pytorch#98493
pytorch/pytorch#98819

And yes, AIO_IMPLICIT_FP16_TRANSFORM_FILTER=".*" is a kind of workaround for this issue - it brings up FP16 support in Ampere Optimized PyTorch even though the framework itself doesn't support it yet.

@jan-grzybek-ampere
Member

Not related, but do you have plans to release your packages so they can be used outside Docker, e.g. in my own custom container?

Please contact us at ai-support@amperecomputing.com and we should be able to get you a working .deb installer.

@stas-sl
Author

stas-sl commented Oct 9, 2023

Hi, it's me again 🙈.

It seems like a similar issue, though a bit different. This time it depends on the input size: for some (smaller) inputs it works fast, but past some threshold it suddenly slows down 2-3x. Here I'm using a vision transformer from the timm library. It basically reshapes an image from a 2D grid into a 1D sequence of patches and runs a fairly standard transformer on it. So for img_size=110 and patch_size=10 the sequence length will be 11 * 11 = 121, and if you increase img_size to 120, the sequence length will be 12 * 12 = 144.

import torch
import timm
import time

# img_size = 110 # latency: 10ms
img_size = 120 # latency: 33ms

model = timm.models.VisionTransformer(
    img_size=img_size,
    patch_size=10,
    embed_dim=128,
    num_heads=8,
    depth=12
)
model.eval()
model = torch.compile(model, backend='aio', options={'modelname': 'vit'})

data = torch.rand(1, 3, img_size, img_size)

n_warmup = 5
n = 100

with torch.no_grad():
    for i in range(n + n_warmup):
        if i == n_warmup:
            start = time.time()
        model(data)

duration = time.time() - start
latency = duration / n * 1000
cps = 1000 / latency

print(f'Latency: {round(latency)}ms, rate: {round(cps)} per second')
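For reference, the patch/token arithmetic behind these sizes (timm's ViT also prepends a class token, so the actual sequence is one token longer than the patch count):

# rough patch/token arithmetic for the two sizes above (patch_size = 10)
for img_size in (110, 120):
    n_patches = (img_size // 10) ** 2
    print(img_size, n_patches, n_patches + 1)   # 110 -> 121 patches / 122 tokens, 120 -> 144 patches / 145 tokens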

With AIO_SKIP_MASTER_THREAD=1 it works a bit faster, though there is still the same slowdown when changing the input size.

Should I provide logs, or do you maybe have ideas about what could be wrong without them?

@kkontny
Contributor

kkontny commented Oct 11, 2023

Hi, I've tried to run your script with:
AIO_SKIP_MASTER_THREAD=1 AIO_NUM_THREADS=4 OMP_NUM_THREADS=4 python test.py
for img_size = 110
I'm getting:
Latency: 13ms, rate: 77 per second
for img_size = 120
Latency: 15ms, rate: 66 per second

There is some difference, but not that big.

How do you run the script? What number of threads are you using?

@stas-sl
Author

stas-sl commented Oct 11, 2023

Thanks for looking into this. I'm using 8 threads. I'm attaching debug logs - hope they contain all the necessary information.

log_fast.txt
log_slow.txt

Also, I'm testing on your previous docker image version (amperecomputingai/pytorch:1.7.0); I see there is already a newer one. I'm not sure if it makes a difference, but I can try testing on it.

When comparing the logs side by side, I don't see any major difference besides slightly larger tensor shapes, which match my calculations (sequence length 122 for img_size=110 and 145 for img_size=120). But that shouldn't affect performance this much, I believe. The only difference besides the shapes that I see is this:
[screenshot: side-by-side diff of log_fast.txt and log_slow.txt]

Could it be related to cache sizes somehow? It doesn't necessarily depend on img_size - for example, if I change embed_dim, the drop occurs at a different img_size threshold.

@stas-sl
Author

stas-sl commented Oct 11, 2023

Hmm.... looks like I found the reason.

There are the following lines in timm's Attention module implementation:

if self.fused_attn:
    x = F.scaled_dot_product_attention(
        q, k, v,
        dropout_p=self.attn_drop.p if self.training else 0.,
    )
else:
    q = q * self.scale
    attn = q @ k.transpose(-2, -1)
    attn = attn.softmax(dim=-1)
    attn = self.attn_drop(attn)
    x = attn @ v

I actually wanted to ask if you have optimizations for transformers/attention. PyTorch has this scaled_dot_product_attention function, which I guess is an optimized version of the code in the else branch. It is possible to choose which attention to use via the env variable TIMM_FUSED_ATTN=0/1; if it's not set explicitly, timm checks whether scaled_dot_product_attention is available and, if so, uses it.

If I disable fused attention explicitly by setting TIMM_FUSED_ATTN=0, it actually works a bit faster and there is no performance drop. So I guess scaled_dot_product_attention is just not implemented in Ampere Optimized PyTorch, and the plain matrix multiplication path should be used instead.
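If anyone wants to do the same from Python rather than from the shell, this is what I'd try - a minimal sketch, assuming timm reads the variable at import time, so it has to be set before importing timm:

import os
os.environ['TIMM_FUSED_ATTN'] = '0'   # must happen before timm is imported

import timm  # with the flag off, timm's Attention uses the explicit q @ k.T / softmax / @ v path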

@kkontny
Contributor

kkontny commented Oct 13, 2023

Hi,
a few things to mention:

  1. My results were measured with the model containing the scaled_dot_product_attention op, which is not handled by our kernels but by the regular PyTorch implementation. So there is some problem, but I don't yet know where. I was also testing our newest implementation, which may affect the performance.
  2. Indeed, we don't support scaled_dot_product_attention yet, so it may be beneficial to turn it off and fall back to the plain PyTorch attention code. We have started some effort to support it, but I wouldn't expect it to be available this year. A model without this op should be handled entirely by our kernels, which overall I think may still be better than PyTorch's implementation of scaled_dot_product_attention.
  3. The difference you saw in the logs means that our software chose a different implementation of that operation for the bigger size; normally that choice should be the better one.

@stas-sl
Author

stas-sl commented Oct 13, 2023

Thanks for clarification!

I've actually tried it on a completely new VM with the latest docker image (1.8.0), and it looks like there is still the same issue for me when scaled_dot_product_attention is used. I tried all combinations of TIMM_FUSED_ATTN x AIO_SKIP_MASTER_THREAD x img_size.

As you can see below, with TIMM_FUSED_ATTN=1 there is a perf drop when increasing img_size from 110 to 120, for either value of AIO_SKIP_MASTER_THREAD.

[screenshot: latency table for all TIMM_FUSED_ATTN / AIO_SKIP_MASTER_THREAD / img_size combinations]

I'm not sure if this gives any clue, but when the perf drop occurs I see a similar increase of red/kernel CPU usage, as if AIO_SKIP_MASTER_THREAD=1 were not set, although without it the red portion is even larger. However, for img_size=110 the CPU bars are completely green.

[screenshot: per-core CPU usage bars]

This is not critical for me at the moment, as there is a workaround which is even faster for now, but if at some point you get an idea of what could be wrong, I'd be interested to know.
