
segmentation fault illegal instruction #71

Closed
ProfXGiter opened this issue Jun 24, 2020 · 13 comments

Comments

@ProfXGiter

ProfXGiter commented Jun 24, 2020

setup

ubuntu 16.04
tvm 0.7 dev1
pytorch 1.4.0
transformers 2.11.0
everything else as in requirements.txt

issue

I uncommented this line in diagonaled_mm_tvm.py:
DiagonaledMM._get_function('float32', 'cuda')

After that, when I run the code, it shows
Loading tvm binary from :./longformer/lib/lib_diagonaled_mm_float32_cuda.so
...
segmentation fault (core dump)
or
Loading tvm binary from :./longformer/lib/lib_diagonaled_mm_float32_cuda.so
...
illegal instruction (core dump)

other

I tested tvm, tensorflow, and pytorch, and they all work fine.
I also followed scripts/cheatsheet.txt to regenerate lib_diagonaled_mm_float32_cuda.so, and it builds successfully.

Any idea or suggestion?
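For what it's worth, the load step can be exercised on its own to see whether the crash happens while loading the binary or later while calling the kernel (a minimal sketch, assuming the TVM 0.6-era tvm.module.load API and the repository's default kernel path):

import tvm

# Load only the compiled kernel; if this already segfaults, the problem is in the
# runtime/binary combination rather than in the model code.
lib = tvm.module.load('./longformer/lib/lib_diagonaled_mm_float32_cuda.so')
print(lib)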

The full reproduction code is below:

import torch
from longformer.longformer import Longformer, LongformerConfig
from longformer.sliding_chunks import pad_to_window_size
from transformers import RobertaTokenizer

config = LongformerConfig.from_pretrained('longformer-base-4096/') 
# choose the attention mode 'n2', 'tvm' or 'sliding_chunks'
# 'n2': for regular n2 attention
# 'tvm': a custom CUDA kernel implementation of our sliding window attention
# 'sliding_chunks': a PyTorch implementation of our sliding window attention
config.attention_mode = 'tvm'

model = Longformer.from_pretrained('longformer-base-4096/', config=config)
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
tokenizer.model_max_length = model.config.max_position_embeddings

SAMPLE_TEXT = ' '.join(['Hello world! '] * 1000)  # long input document

input_ids = torch.tensor(tokenizer.encode(SAMPLE_TEXT)).unsqueeze(0)  # batch of size 1

# TVM code doesn't work on CPU. Uncomment this if `config.attention_mode = 'tvm'`
model = model.cuda(); input_ids = input_ids.cuda()

# Attention mask values -- 0: no attention, 1: local attention, 2: global attention
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device) # initialize to local attention
attention_mask[:, [1, 4, 21,]] =  2  # Set global attention based on the task. For example,
                                     # classification: the <s> token
                                     # QA: question tokens

# padding seqlen to the nearest multiple of 512. Needed for the 'sliding_chunks' attention
input_ids, attention_mask = pad_to_window_size(
        input_ids, attention_mask, config.attention_window[0], tokenizer.pad_token_id)

output = model(input_ids, attention_mask=attention_mask)[0]
@ibeltagy
Collaborator

which CUDA version are you using?

@ProfXGiter
Author

ProfXGiter commented Jun 25, 2020

which CUDA version are you using?

cuda 10.0
python 3.6.10

@ibeltagy
Collaborator

can you try to uninstall tvm 0.7 dev1 and just rely on the tvm runtime we have here https://github.com/allenai/longformer/tree/master/tvm? This runtime is from tvm 0.6, and I don't think the tvm 0.6 binaries would work with tvm 0.7

@ProfXGiter
Author

can you try to uninstall tvm 0.7 dev1 and just rely on the tvm runtime we have here https://github.com/allenai/longformer/tree/master/tvm? This runtime is from tvm 0.6, and I don't think the tvm 0.6 binaries would work with tvm 0.7

Thanks, I'll try that.

@ProfXGiter
Author

can you try to uninstall tvm 0.7 dev1 and just rely on the tvm runtime we have here https://github.com/allenai/longformer/tree/master/tvm? This runtime is from tvm 0.6, and I don't think the tvm 0.6 binaries would work with tvm 0.7

Still not working.

@ibeltagy
Collaborator

@safooray, any ideas here? did you run into similar issues?

@ibeltagy
Collaborator

ibeltagy commented Jun 26, 2020

Can you make sure that import tvm is importing our tvm directory and not another tvm installation?

Also, the binaries we have are for torch==1.2.0. Can you try that instead of 1.4.0?
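A quick way to check both points might look like this (a minimal sketch; the paths mentioned in the comments are only illustrative):

import torch
import tvm

print(tvm.__file__)       # should resolve inside the repo's ./tvm directory, not site-packages
print(torch.__version__)  # the prebuilt kernels were compiled against torch 1.2.0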

@ProfXGiter
Author

Can you make sure that import tvm is importing our tvm directory and not another tvm installation?

Also, the binaries we have are for torch==1.2.0. Can you try that instead of 1.4.0?

Sorry for the late reply.
Yeah, I switched to the GPU build of tvm==0.6.0, torch==1.2.0, and transformers==2.2.0, and I'm using the tvm directory rather than the pip-installed tvm, but it still happens.

@ibeltagy
Collaborator

Another suggestion: can you try running it from inside the docker container that we use to compile the cuda kernel?
Follow the instructions here: https://github.com/allenai/longformer/blob/master/scripts/cheatsheet.txt#L6 to build and run the docker image, then try to run it. You don't need to recompile the binaries, it is enough to load the existing one.

I am curious, what are you using it for, and would the sliding_chunks implementation be enough for your use case?
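For reference, switching to the pure-PyTorch mode sidesteps the compiled TVM binary entirely (a minimal sketch based on the reproduction script above; the rest of the script stays the same):

from longformer.longformer import Longformer, LongformerConfig

# 'sliding_chunks' is the PyTorch implementation of the sliding window attention,
# so it never loads lib_diagonaled_mm_float32_cuda.so.
config = LongformerConfig.from_pretrained('longformer-base-4096/')
config.attention_mode = 'sliding_chunks'
model = Longformer.from_pretrained('longformer-base-4096/', config=config)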

@ProfXGiter
Author

Another suggestion: can you try running it from inside the docker container that we use to compile the cuda kernel?
Follow the instructions here: https://github.com/allenai/longformer/blob/master/scripts/cheatsheet.txt#L6 to build and run the docker image, then try to run it. You don't need to recompile the binaries, it is enough to load the existing one.

I am curious, what are you using it for, and would the sliding_chunks implementation be enough for your use case?

Thanks for the suggestion, I'll try it.

I am planning to do research on using TVM or AutoTVM to improve Transformer inference time. When I looked around GitHub, I found your repo excellent and worth studying.

@ibeltagy
Collaborator

Very interesting. Maybe a fused self-attention function or something. I will be curious to see how this goes.

Depending on how familiar you are with TVM, you might find the following discussions useful,
https://discuss.tvm.ai/t/optimizing-matrix-multiplication-for-gpu/4212/24
https://discuss.tvm.ai/t/competitive-gemm-matmul-example/5478
https://discuss.tvm.ai/t/developing-a-faster-schedule-for-longformers-kernel/6367
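As a starting point for those threads, a plain matmul can be declared and lowered with the 0.6-era TVM API (a sketch; scheduling it efficiently for GPU is exactly what the discussions above cover):

import tvm

# Declare C = A @ B with an explicit reduction axis, then lower the default schedule.
n = tvm.var('n')
A = tvm.placeholder((n, n), name='A', dtype='float32')
B = tvm.placeholder((n, n), name='B', dtype='float32')
k = tvm.reduce_axis((0, n), name='k')
C = tvm.compute((n, n), lambda i, j: tvm.sum(A[i, k] * B[k, j], axis=k), name='C')

s = tvm.create_schedule(C.op)
print(tvm.lower(s, [A, B, C], simple_mode=True))  # inspect the generated IR before optimizing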

@ProfXGiter
Author

Very interesting. Maybe a fused self-attention function or something. I will be curious to see how this goes.

Depending on how familiar you are with TVM, you might find the following discussions useful,
https://discuss.tvm.ai/t/optimizing-matrix-multiplication-for-gpu/4212/24
https://discuss.tvm.ai/t/competitive-gemm-matmul-example/5478
https://discuss.tvm.ai/t/developing-a-faster-schedule-for-longformers-kernel/6367

Thanks : )

@ibeltagy
Collaborator

Closing. Please feel free to reopen if needed.
