
RuntimeError: CUDA error: device-side assert triggered - is_global_attn = is_index_global_attn.flatten().any().item() #99

Closed
zarandioon opened this issue Aug 20, 2020 · 13 comments

@zarandioon

I'm trying to train a new model from scratch with a sequence length of 1024 (using the Hugging Face implementation of Longformer), but I get the following exception at a recently added line:

--> 150         is_global_attn = is_index_global_attn.flatten().any().item()
    151 
    152         hidden_states = hidden_states.transpose(0, 1)

RuntimeError: CUDA error: device-side assert triggered

I tried Reformer and it worked as expected. The Longformer config is as follows:

LongformerConfig {
  "attention_probs_dropout_prob": 0.1,
  "attention_window": 64,
  "bos_token_id": 0,
  "eos_token_id": 2,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 1026,
  "model_type": "longformer",
  "num_attention_heads": 12,
  "num_hidden_layers": 6,
  "pad_token_id": 257,
  "sep_token_id": 258,
  "type_vocab_size": 2,
  "vocab_size": 261
}

Any idea what the issue is?

@ibeltagy
Collaborator

This line is unlikely to be the reason for the error. Try export CUDA_LAUNCH_BLOCKING=1 to disable asynchronous CUDA execution, then run your code again. This will point you to the actual line that's causing the problem.
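For example, a minimal way to do this from inside a notebook (the variable has to be set before any CUDA work runs, e.g. before the model is built):

import os
# force synchronous kernel launches so the failing op raises at the real line
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"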

@zarandioon
Author

For some reason it's failing when trying to compute the token type embeddings. I'm not passing token_type_ids and, according to the BERT code, it should create a zero tensor for them. Any idea what could cause this?

~/SageMaker/langmodel/transformers/src/transformers/modeling_bert.py in forward(self, input_ids, token_type_ids, position_ids, inputs_embeds)
    209             inputs_embeds = self.word_embeddings(input_ids)
    210         position_embeddings = self.position_embeddings(position_ids)
--> 211         token_type_embeddings = self.token_type_embeddings(token_type_ids)
    212 
    213         embeddings = inputs_embeds + position_embeddings + token_type_embeddings

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    530             result = self._slow_forward(*input, **kwargs)
    531         else:
--> 532             result = self.forward(*input, **kwargs)
    533         for hook in self._forward_hooks.values():
    534             hook_result = hook(self, input, result)

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/sparse.py in forward(self, input)
    112         return F.embedding(
    113             input, self.weight, self.padding_idx, self.max_norm,
--> 114             self.norm_type, self.scale_grad_by_freq, self.sparse)
    115 
    116     def extra_repr(self):

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/functional.py in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   1482         # remove once script supports set_grad_enabled
   1483         _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 1484     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
   1485 
   1486 

RuntimeError: CUDA error: device-side assert triggered

@ibeltagy
Collaborator

You probably need to set type_vocab_size to 0 or 1 in your config. To debug, put a breakpoint before the self.token_type_embeddings line and make sure that self.token_type_embeddings and token_type_ids are compatible (something like the check sketched below).
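A minimal sketch of such a check, assuming you pause inside BertEmbeddings.forward (the names match the traceback above):

# every token type id must be a valid row index in the embedding table
num_types = self.token_type_embeddings.num_embeddings  # equals config.type_vocab_size
assert token_type_ids.min().item() >= 0
assert token_type_ids.max().item() < num_types, (
    f"max token_type_id {token_type_ids.max().item()} >= type_vocab_size {num_types}"
)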

@zarandioon
Author

I tried type_vocab_size 0 and 1 last night, but that did not help. Let me check the shape of the tensor.

LongformerConfig {
  "attention_probs_dropout_prob": 0.1,
  "attention_window": 64,
  "bos_token_id": 256,
  "eos_token_id": 258,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 1026,
  "model_type": "longformer",
  "num_attention_heads": 12,
  "num_hidden_layers": 6,
  "pad_token_id": 257,
  "sep_token_id": 258,
  "type_vocab_size": 0,
  "vocab_size": 261
}

@ibeltagy
Collaborator

Not just the shapes. Also, make sure values in token_type_ids are all zeros.

@zarandioon
Author

zarandioon commented Aug 21, 2020

token_type_embeddings is constructed with config.type_vocab_size: 0 and config.hidden_size: 768.

This is the shape of token_type_ids that is passed to token_type_embeddings when it fails:
token_type_ids.size(): torch.Size([64, 1024])

I have explicitly set token_type_ids to zero by removing the if condition:

https://github.com/huggingface/transformers/blob/master/src/transformers/modeling_bert.py#L205-L206

        #if token_type_ids is None:
        token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

The config at the point BertEmbeddings is being initialized:

https://github.com/huggingface/transformers/blob/master/src/transformers/modeling_bert.py#L180

config: LongformerConfig {
  "attention_probs_dropout_prob": 0.1,
  "attention_window": [
    64,
    64,
    64,
    64,
    64,
    64
  ],
  "bos_token_id": 256,
  "eos_token_id": 258,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 1026,
  "model_type": "longformer",
  "num_attention_heads": 12,
  "num_hidden_layers": 6,
  "pad_token_id": 257,
  "sep_token_id": 258,
  "type_vocab_size": 0,
  "vocab_size": 261
}

token_type_embeddings config.type_vocab_size: 0
token_type_embeddings config.hidden_size: 768

@zarandioon
Author

I tried with type_vocab_size set to 1 and 2, and I get the same error when it tries to embed the types!

I'm thinking about removing the type embedding altogether, as I do not think it applies to RoBERTa ...

config: LongformerConfig {
  "attention_probs_dropout_prob": 0.1,
  "attention_window": [
    64,
    64,
    64,
    64,
    64,
    64
  ],
  "bos_token_id": 256,
  "eos_token_id": 258,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 1026,
  "model_type": "longformer",
  "num_attention_heads": 12,
  "num_hidden_layers": 6,
  "pad_token_id": 257,
  "sep_token_id": 258,
  "type_vocab_size": 1,
  "vocab_size": 261
}

token_type_embeddings config.type_vocab_size: 1
token_type_embeddings config.hidden_size: 768
config: LongformerConfig {
  "attention_probs_dropout_prob": 0.1,
  "attention_window": [
    64,
    64,
    64,
    64,
    64,
    64
  ],
  "bos_token_id": 256,
  "eos_token_id": 258,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 1026,
  "model_type": "longformer",
  "num_attention_heads": 12,
  "num_hidden_layers": 6,
  "pad_token_id": 257,
  "sep_token_id": 258,
  "type_vocab_size": 2,
  "vocab_size": 261
}

token_type_embeddings config.type_vocab_size: 2
token_type_embeddings config.hidden_size: 768

@zarandioon
Author

I removed the token_type_embeddings logic, and now it fails with the same error when it tries to add inputs_embeds and position_embeddings! The sizes seem to match:

inputs_embeds.size(): torch.Size([64, 1024, 768])
position_embeddings.size(): torch.Size([64, 1024, 768])
    219         print(f"inputs_embeds.size(): {inputs_embeds.size()}")
    220         print(f"position_embeddings.size(): {position_embeddings.size()}")
--> 221         embeddings = inputs_embeds + position_embeddings # + token_type_embeddings
    222         embeddings = self.LayerNorm(embeddings)
    223         embeddings = self.dropout(embeddings)

RuntimeError: CUDA error: device-side assert triggered
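One way to narrow this down (a sketch, again assuming you can pause inside BertEmbeddings.forward) is to compare each index tensor against the size of the embedding table it feeds; a device-side assert from an embedding usually means one of these lookups received an out-of-range index, even if the error only surfaces at a later line:

# check each index tensor against the size of its embedding table
for name, ids, table in [
    ("input_ids", input_ids, self.word_embeddings),
    ("position_ids", position_ids, self.position_embeddings),
]:
    print(name, "max id:", ids.detach().cpu().max().item(),
          "table rows:", table.num_embeddings)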

@ibeltagy
Collaborator

Do you still have export CUDA_LAUNCH_BLOCKING=1?

Do you have a small code snippet to reproduce the error?

@zarandioon
Author

Yeah, let me start with a clean workspace and come up with minimal code to reproduce this. It might help me pinpoint the issue as well.

@zarandioon
Author

Using the following notebook cells and any line-by-line text data stored in ./data/train.txt, you should be able to reproduce this issue.

# install transformers
!git clone https://github.com/huggingface/transformers
!(cd transformers && pip install .)
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["WANDB_DISABLED"] = "true"
from torch.utils.data.dataset import Dataset
import torch
import os

max_len=1024

class LineByLineTextDataset(Dataset):
    def __init__(self, file_path: str):
        assert os.path.isfile(file_path), f"file {file_path} does not exist"

        with open(file_path, encoding="utf-8") as f:
            self.examples = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i) -> torch.Tensor:
        return self.examples[i]

train_dataset = LineByLineTextDataset(file_path="./data/train.txt")
from transformers import LongformerConfig

CLS = 256 + 0
SEP = 256 + 2
PAD = 256 + 1
MASK = 256 + 4

config = LongformerConfig(
    bos_token_id=CLS,
    eos_token_id=SEP,
    attention_window=64,
    vocab_size=256 + 5,
    max_position_embeddings=max_len+2,
    pad_token_id=PAD,
    sep_token_id=SEP,
    num_hidden_layers=6,
    type_vocab_size=2,
)
from transformers import LongformerForMaskedLM
model = LongformerForMaskedLM(config=config)
import itertools
import numpy as np

def masked_inputs_labels(batch, mlm_probability=0.15):
    # never mask the special tokens
    esp_mask = np.isin(batch, [CLS, SEP, PAD])
    mask_prob = np.full(batch.shape, mlm_probability)
    mask_prob[esp_mask] = 0
    mask = np.random.binomial(1, mask_prob).astype(bool)

    inputs = np.copy(batch)
    labels = batch

    # mask labels: compute loss only on masked tokens
    labels[~mask] = -100

    # 80% of the time, replace masked input tokens with the mask token ([MASK])
    mask_replaced = np.random.binomial(1, np.full(inputs.shape, 0.8)).astype(bool) & mask
    inputs[mask_replaced] = MASK

    # 10% of the time, replace masked input tokens with a random word
    mask_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype(bool) & mask & ~mask_replaced
    random_words = np.random.randint(256, size=inputs.shape)
    inputs[mask_random] = random_words[mask_random]

    return inputs, labels
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch

def text_to_input(text, seq_len):
    bs = text.rstrip().encode('utf-8')
    bs = list(bs[:min(len(bs), max_len-2)])
    bs = [CLS] + bs + [SEP]
    
    bs = bs + [PAD] * (max_len - len(bs))

    return bs

class DataCollator:
    def __call__(self, texts) -> Dict[str, torch.Tensor]:
        token_ids = [text_to_input(text, seq_len=max_len) for text in texts]
        token_ids = np.array(list(itertools.zip_longest(*token_ids, fillvalue=PAD))).T
        inputs, labels = masked_inputs_labels(token_ids)
        print(f"inputs.shape: {inputs.shape}, labels.shape, {labels.shape}")
        return {"input_ids": torch.Tensor(inputs).long(), "labels": torch.Tensor(labels).long()}
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./model",
    overwrite_output_dir=True,
    num_train_epochs=20,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    save_steps=10_000,
    save_total_limit=10,
    local_rank=-1,
    evaluate_during_training=True
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=DataCollator(),
    prediction_loss_only=True,
)

trainer.train()

@zarandioon
Author

In the RoBERTa model, for some reason (potentially backwards compatibility), the pad token id is added to the position ids. In my case I was using a PAD id that was not zero, which pushed the position ids beyond the embedding index limit and caused this error. Thanks @ibeltagy for your help on this.

https://github.com/huggingface/transformers/blob/master/src/transformers/modeling_roberta.py#L804
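Roughly, the position ids are built like this (a simplified sketch of create_position_ids_from_input_ids in modeling_roberta.py; the numbers in the comment use the config above):

import torch

def create_position_ids_from_input_ids(input_ids, padding_idx):
    # non-pad tokens get positions padding_idx + 1, padding_idx + 2, ...
    # with padding_idx = 257 and a 1024-token sequence, the largest position id is
    # 257 + 1024 = 1281, well past max_position_embeddings = 1026, hence the assert
    mask = input_ids.ne(padding_idx).int()
    incremental_indices = torch.cumsum(mask, dim=1) * mask
    return incremental_indices.long() + padding_idx

So the position embedding table has to be sized for the max sequence length plus the pad id offset, or the pad id has to stay small (RoBERTa's default is 1, with max_position_embeddings = 514 for 512 tokens).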

@ibeltagy
Collaborator

Glad you figured it out.
