Could torch.einsum gain speed boost ? #394

Closed
fyubang opened this issue Jul 16, 2019 · 15 comments

@fyubang fyubang commented Jul 16, 2019

I am trying to fine-tune XLNet with fp16 and found that memory usage was halved, but training was slower than with fp32 (even when I doubled the batch size).

Environment: V100, CUDA 10.0, torch 1.1

The environment itself is fine, because I tried BERT + fp16 and it was much faster than fp32.
I suspect the problem is torch.einsum, but I am not sure.

Collaborator

@ptrblck ptrblck commented Jul 16, 2019

Hi @fyubang,

could you post a link to the repo you are using so that we can have a look?

Author

@fyubang fyubang commented Jul 16, 2019

> Hi @fyubang,
>
> could you post a link to the repo you are using so that we can have a look?

Sorry for forgetting the link. I used the code here:
https://github.com/huggingface/pytorch-transformers/blob/master/examples/run_squad.py

Author

@fyubang fyubang commented Jul 17, 2019

Hi @ptrblck,
I tried the new Hugging Face repo, and it did not work either.
https://github.com/huggingface/pytorch-transformers/blob/master/examples/run_squad.py

Collaborator

@ptrblck ptrblck commented Jul 17, 2019

Thanks for the link, @fyubang.
We'll take a look at it.

Collaborator

@ptrblck ptrblck commented Jul 17, 2019

We tried to compare the performance between a FP32 run and an amp run using opt_level='O1'.
For this, we've cloned the current repo from @huggingface and used the command as given here for the FP32 run:

python -m torch.distributed.launch --nproc_per_node=8 ./examples/run_squad.py \
    --model_type bert \
    --model_name_or_path bert-large-uncased-whole-word-masking \
    --do_train \
    --do_eval \
    --do_lower_case \
    --train_file $SQUAD_DIR/train-v1.1.json \
    --predict_file $SQUAD_DIR/dev-v1.1.json \
    --learning_rate 3e-5 \
    --num_train_epochs 2 \
    --max_seq_length 384 \
    --doc_stride 128 \
    --output_dir ../models/wwm_uncased_finetuned_squad/ \
    --per_gpu_eval_batch_size=3   \
    --per_gpu_train_batch_size=3 

Using 8 V100 GPUs (each with 32GB), we could achieve a mean speed of ~2.65 iterations/second.

However, when adding the --fp16 argument to the same command, apex raises an error, since DDP is initialized before amp.initialize is called.
Did you observe the same error?

After changing the order of initialization, we could successfully run the script on the same machine achieving ~3.70 iterations/second, which seems reasonable.

By "it did not work either", are you referring to the raised error or to a slower run using amp?

CC @huggingface
Is this a known issue and would you be interested in a fix?
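
The reordering mentioned above (calling amp.initialize before wrapping the model in DDP) might look roughly like the sketch below. The helper name `wrap_for_mixed_precision` and its parameters are hypothetical and not taken from run_squad.py; the sketch assumes the model already lives on the target GPU and that the process group has been initialized.

```python
def wrap_for_mixed_precision(model, optimizer, local_rank, opt_level="O1"):
    """Hypothetical helper: run amp.initialize first, then wrap in DDP."""
    # Imported lazily so this sketch can be loaded without apex/torch installed.
    import torch
    from apex import amp

    # 1) amp.initialize has to see the bare (unwrapped) model and optimizer.
    model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)

    # 2) Only afterwards is the model wrapped for distributed training.
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[local_rank], output_device=local_rank
    )
    return model, optimizer
```

This only illustrates the ordering; the actual change made for the numbers above is not shown in this thread.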

Collaborator

@ptrblck ptrblck commented Jul 17, 2019

I reran the test using XLNet:

python -m torch.distributed.launch --nproc_per_node=8 ./examples/run_squad.py \
    --model_type xlnet \
    --model_name_or_path xlnet-large-cased \
    --do_train \
    --do_eval \
    --do_lower_case \
    --train_file $SQUAD_DIR/train-v1.1.json \
    --predict_file $SQUAD_DIR/dev-v1.1.json \
    --learning_rate 3e-5 \
    --num_train_epochs 2 \
    --max_seq_length 384 \
    --doc_stride 128 \
    --output_dir ../models/wwm_uncased_finetuned_squad/ \
    --per_gpu_eval_batch_size=3   \
    --per_gpu_train_batch_size=3 

and got the following numbers:
FP32: ~1.35 iterations/second
AMP O1: ~1.44 iterations/second

The performance benefit is indeed smaller and worth having a closer look at.

Author

@fyubang fyubang commented Jul 18, 2019

@ptrblck
Thanks for your reply. I got results similar to yours.
I think the reason is that the author makes heavy use of torch.einsum, for example:
torch.einsum('ibnd,jbnd->ijbn', a, b)

I tried replacing it with:

a_tmp = a.permute(1,2,0,3)
b_tmp = b.permute(1,2,3,0)
res = a_tmp.matmul(b_tmp)
res = res.permute(2,3,0,1)

but it became even slower than torch.einsum.
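
As a sanity check that the permute/matmul sequence above computes the same contraction as 'ibnd,jbnd->ijbn', one can trace the dimension labels through each step. This is a minimal pure-Python sketch that only tracks labels, not tensor data; the helpers `permute` and `matmul` are illustrative stand-ins, not the torch functions.

```python
def permute(labels, order):
    """Reorder dimension labels the way Tensor.permute reorders axes."""
    return "".join(labels[i] for i in order)


def matmul(lhs, rhs):
    """Labels of a batched matmul: (..., m, k) @ (..., k, n) -> (..., m, n)."""
    assert lhs[:-2] == rhs[:-2], "batch dimensions must match"
    assert lhs[-1] == rhs[-2], "contracted dimension must match"
    return lhs[:-1] + rhs[-1]


a, b = "ibnd", "jbnd"              # operand labels from 'ibnd,jbnd->ijbn'
a_tmp = permute(a, (1, 2, 0, 3))   # batch dims b, n first; then i, d
b_tmp = permute(b, (1, 2, 3, 0))   # batch dims b, n first; then d, j
res = matmul(a_tmp, b_tmp)         # contracts over d
res = permute(res, (2, 3, 0, 1))   # move i, j to the front
print(res)
```

The labels end up in the einsum's output order, so the rewrite is index-equivalent. Note that permute returns non-contiguous views, so the matmul may have to copy its inputs first; whether that explains the slowdown is an assumption that was not measured here.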

Author

@fyubang fyubang commented Jul 18, 2019

@ptrblck
I measured the speed of matmul when the input shapes are (a,b,c,d) and (a,b,d,e).
I found that fp16 is much slower than fp32 (a ratio of about 1:20). This may be the reason why fp16 was slower.

Collaborator

@ptrblck ptrblck commented Jul 19, 2019

@fyubang
Note that the shapes for GEMMs should be multiples of 8 as explained in our pinned topic.

Here is a small benchmark using 1) shapes that are multiples of 8 and 2) shapes that narrowly miss this condition:

import time

import torch

# 1) All dimensions are multiples of 8
I, J, K = 64, 1024, 1024
A = torch.randn(I, J, device='cuda', dtype=torch.half)
B = torch.randn(J, K, device='cuda', dtype=torch.half)

nb_iters = 1000
torch.cuda.synchronize()
t0 = time.time()
for _ in range(nb_iters):
    C = torch.matmul(A, B)
torch.cuda.synchronize()
t1 = time.time()
print('{:.3f}us per iteration'.format((t1 - t0) / nb_iters * 1e6))

> 16.043us per iteration

# 2) Dimensions narrowly miss multiples of 8
I, J, K = 63, 1023, 1023
A = torch.randn(I, J, device='cuda', dtype=torch.half)
B = torch.randn(J, K, device='cuda', dtype=torch.half)

nb_iters = 1000
torch.cuda.synchronize()
t0 = time.time()
for _ in range(nb_iters):
    C = torch.matmul(A, B)
torch.cuda.synchronize()
t1 = time.time()
print('{:.3f}us per iteration'.format((t1 - t0) / nb_iters * 1e6))

> 39.476us per iteration

Could this also be the reason for the minor speedup in XLNet?

Author

@fyubang fyubang commented Jul 21, 2019

@ptrblck
Thanks for your reply.
In fact, when I tried (i, j) matmul (j, k), there was always a speed boost.
However, when I tried (a, b, c) matmul (a, c, d), it did not get accelerated.
In addition, here is the XLNet config:

{
  "attn_type": "bi",
  "bi_data": false,
  "clamp_len": -1,
  "d_head": 64,
  "d_inner": 4096,
  "d_model": 1024,
  "dropatt": 0.1,
  "dropout": 0.1,
  "ff_activation": "gelu",
  "init": "normal",
  "init_range": 0.1,
  "init_std": 0.02,
  "initializer_range": 0.02,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "mem_len": null,
  "n_head": 16,
  "n_layer": 24,
  "n_token": 32000,
  "reuse_len": null,
  "same_length": false,
  "untie_r": true
}

Since these are all multiples of 8, I don't think the "multiples of 8" requirement is the problem.
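
The claim above can be checked mechanically. The sketch below copies the size-like entries from the config and lists any that are not a multiple of 8; the helper name `not_multiple_of` is made up for illustration.

```python
# Size-like entries copied from the XLNet config above.
config_dims = {
    "d_head": 64,
    "d_inner": 4096,
    "d_model": 1024,
    "max_position_embeddings": 512,
    "n_head": 16,
    "n_layer": 24,
    "n_token": 32000,
}


def not_multiple_of(dims, factor=8):
    """Return the entries whose value is not divisible by `factor`."""
    return {name: value for name, value in dims.items() if value % factor != 0}


offenders = not_multiple_of(config_dims)
print(offenders)  # empty dict: every listed size is a multiple of 8
```

This only covers the static config. Run-time sizes such as sequence length and batch size are not part of it and would need to be checked separately.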

Collaborator

@ptrblck ptrblck commented Jul 21, 2019

By "In fact, when I tried (i, j) matmul (j, k), there was always a speed boost", do you mean that each FP16 matmul of this form is faster than the corresponding FP32 matmul regardless of the input shapes?
This sounds strange to me, as I get FP16 timings similar to FP32 for shapes that are not multiples of 8, while FP16 matmuls with multiple-of-8 shapes yield a speedup.
I've also added another dimension and still get a speedup for multiple-of-8 FP16 shapes.

Could you try to add some warmup iterations before the actual timings?
The first measured time might be a bit biased.

Thanks for the information about xlnet. We'll look into it.
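
The warmup suggestion above could be sketched as a small timing helper. The name `benchmark` and its parameters are hypothetical; the `sync` hook is where a GPU synchronization call would go, and the demo below uses a trivial CPU-only workload so that it runs anywhere.

```python
import time


def benchmark(fn, warmup=10, iters=1000, sync=None):
    """Return mean seconds per call of fn(), excluding warmup calls."""
    sync = sync or (lambda: None)

    # Warmup: absorbs one-time costs such as lazy initialization.
    for _ in range(warmup):
        fn()
    sync()

    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    sync()  # wait for outstanding asynchronous work before stopping the clock
    return (time.perf_counter() - t0) / iters


# CPU-only stand-in workload, just to show the call pattern.
calls = []
mean_s = benchmark(lambda: calls.append(1), warmup=5, iters=20)
print(len(calls), mean_s)
```

For a CUDA measurement one would pass something like `fn=lambda: torch.matmul(c, d)` together with `sync=torch.cuda.synchronize`.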

Author

@fyubang fyubang commented Jul 22, 2019

@ptrblck
Thanks for your reply.
For the first question:
Yes, that is what I meant, although I may have used sizes like 20 or 60 rather than 63.

For the second question:
Could you try this code (all shapes are multiples of 8) and check whether you still get a speedup for fp16?

import os
os.environ['CUDA_VISIBLE_DEVICES']='0'
import torch
from time import time
# 1) fp32
a = torch.empty(24,32,40,48, dtype=torch.float32).to('cuda')
b = torch.empty(64,32,40,48, dtype=torch.float32).to('cuda')
c = torch.empty(40,80,24, dtype=torch.float32).to('cuda')
d = torch.empty(40,24,16, dtype=torch.float32).to('cuda')

torch.cuda.synchronize()
st = time()
for _ in range(1000):
    c.matmul(d)
torch.cuda.synchronize()
print(time()-st)

torch.cuda.synchronize()
st = time()
for _ in range(1000):
    torch.einsum('ibnd,jbnd->ijbn', a, b)
torch.cuda.synchronize()
print(time()-st)

# 2) fp16
a = torch.empty(24,32,40,48, dtype=torch.float16).to('cuda')
b = torch.empty(64,32,40,48, dtype=torch.float16).to('cuda')
c = torch.empty(40,80,24, dtype=torch.float16).to('cuda')
d = torch.empty(40,24,16, dtype=torch.float16).to('cuda')

torch.cuda.synchronize()
st = time()
for _ in range(1000):
    torch.matmul(c,d)
torch.cuda.synchronize()
print(time()-st)

torch.cuda.synchronize()
st = time()
for _ in range(1000):
    torch.einsum('ibnd,jbnd->ijbn', a, b)
torch.cuda.synchronize()
print(time()-st)

My results, in print order (seconds for 1000 iterations each):

fp32 matmul: 0.028162240982055664
fp32 einsum: 0.10057997703552246
fp16 matmul: 0.38828039169311523
fp16 einsum: 11.749611377716064

Collaborator

@ptrblck ptrblck commented Jul 22, 2019

Here are my results for your benchmark on a TITAN V, in the same print order (seconds for 1000 iterations each):

fp32 matmul: 0.017162799835205078
fp32 einsum: 0.09859037399291992
fp16 matmul: 0.015858173370361328
fp16 einsum: 0.042925119400024414

Author

@fyubang fyubang commented Jul 22, 2019

@ptrblck Thanks for your reply.
It seems torch.einsum does have a speedup.
I will double check again.

Contributor

@ngimel ngimel commented Jul 22, 2019

Closed in favor of pytorch/pytorch#23061, since this does not seem to be amp-specific.

@ngimel ngimel closed this Jul 22, 2019