Tokenization issue when training: OverflowError: can't convert negative int to unsigned #2645

Open
ZHAOFEGNSHUN opened this issue May 14, 2024 · 3 comments

Comments

@ZHAOFEGNSHUN

ZHAOFEGNSHUN commented May 14, 2024

Traceback (most recent call last):
  File "/nfs/XLNet/XLNet/xlnet_similarity.py", line 49, in <module>
    model.fit(
  File "/home/luban/.conda/envs/my/lib/python3.9/site-packages/sentence_transformers/SentenceTransformer.py", line 1075, in fit
    data = next(data_iterator)
  File "/home/luban/.conda/envs/my/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 681, in __next__
    data = self._next_data()
  File "/home/luban/.conda/envs/my/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 721, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/home/luban/.conda/envs/my/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    return self.collate_fn(data)
  File "/home/luban/.conda/envs/my/lib/python3.9/site-packages/sentence_transformers/SentenceTransformer.py", line 908, in smart_batching_collate
    sentence_features = [self.tokenize(sentence) for sentence in zip(*texts)]
  File "/home/luban/.conda/envs/my/lib/python3.9/site-packages/sentence_transformers/SentenceTransformer.py", line 908, in <listcomp>
    sentence_features = [self.tokenize(sentence) for sentence in zip(*texts)]
  File "/home/luban/.conda/envs/my/lib/python3.9/site-packages/sentence_transformers/SentenceTransformer.py", line 592, in tokenize
    return self._first_module().tokenize(texts, **kwargs)
  File "/home/luban/.conda/envs/my/lib/python3.9/site-packages/sentence_transformers/models/Transformer.py", line 146, in tokenize
    self.tokenizer(
  File "/home/luban/.conda/envs/my/lib/python3.9/site-packages/transformers/tokenization_utils_base.py", line 2858, in __call__
    encodings = self._call_one(text=text, text_pair=text_pair, **all_kwargs)
  File "/home/luban/.conda/envs/my/lib/python3.9/site-packages/transformers/tokenization_utils_base.py", line 2944, in _call_one
    return self.batch_encode_plus(
  File "/home/luban/.conda/envs/my/lib/python3.9/site-packages/transformers/tokenization_utils_base.py", line 3135, in batch_encode_plus
    return self._batch_encode_plus(
  File "/home/luban/.conda/envs/my/lib/python3.9/site-packages/transformers/tokenization_utils_fast.py", line 496, in _batch_encode_plus
    self.set_truncation_and_padding(
  File "/home/luban/.conda/envs/my/lib/python3.9/site-packages/transformers/tokenization_utils_fast.py", line 451, in set_truncation_and_padding
    self._tokenizer.enable_truncation(**target)
OverflowError: can't convert negative int to unsigned

What happened? How do I solve this?

@tomaarsen
Collaborator

Hello!

That is a bit of an odd error. It's trying to do tokenization via transformers (and tokenizers?), but failing with an error that I've not seen before. I see someone else has gotten a similar issue once: aub-mind/arabert#129

You can try the following:

from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer

model_name = "my_model"
model = SentenceTransformer(model_name)
model[0].tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

model.fit("...")

This no longer uses tokenizers (which I suspect is where the issue originates), but a Python-based tokenizer instead.
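Another thing worth trying, since the traceback shows enable_truncation failing on a negative value: check which maximum sequence length was auto-detected for your checkpoint and set it to a positive number yourself before training. This is only a sketch (with a placeholder model name, same as above), not a confirmed fix:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("my_model")  # placeholder model name
print(model.max_seq_length)              # see what was auto-detected for this checkpoint
model.max_seq_length = 512               # force a sensible positive value before calling fit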

  • Tom Aarsen

@tomaarsen changed the title from "issues" to "Tokenization issue when training: OverflowError: can't convert negative int to unsigned" on May 14, 2024
@ZHAOFEGNSHUN
Author

We need to remove 11 to truncate the input but the first sequence has a length 8.
We need to remove 14 to truncate the input but the first sequence has a length 11.
We need to remove 14 to truncate the input but the first sequence has a length 11.
We need to remove 7 to truncate the input but the first sequence has a length 4.
We need to remove 14 to truncate the input but the first sequence has a length 11.
We need to remove 12 to truncate the input but the first sequence has a length 9.
We need to remove 14 to truncate the input but the first sequence has a length 11.
We need to remove 13 to truncate the input but the first sequence has a length 10.
We need to remove 10 to truncate the input but the first sequence has a length 7.
We need to remove 23 to truncate the input but the first sequence has a length 20.
We need to remove 36 to truncate the input but the first sequence has a length 33.
Iteration: 100% | 6/6 [00:01<00:00, 5.39it/s]
Epoch: 75% | 3/4 [00:03<00:01, 1.21s/it]
We need to remove 7 to truncate the input but the first sequence has a length 4.
We need to remove 6 to truncate the input but the first sequence has a length 3.
We need to remove 9 to truncate the input but the first sequence has a length 6.
We need to remove 9 to truncate the input but the first sequence has a length 6.
We need to remove 13 to truncate the input but the first sequence has a length 10.
We need to remove 12 to truncate the input but the first sequence has a length 9.
We need to remove 14 to truncate the input but the first sequence has a length 11.
We need to remove 11 to truncate the input but the first sequence has a length 8.
We need to remove 11 to truncate the input but the first sequence has a length 8.
We need to remove 13 to truncate the input but the first sequence has a length 10.
We need to remove 11 to truncate the input but the first sequence has a length 8.
We need to remove 12 to truncate the input but the first sequence has a length 9.
We need to remove 10 to truncate the input but the first sequence has a length 7.
We need to remove 17 to truncate the input but the first sequence has a length 14.
We need to remove 14 to truncate the input but the first sequence has a length 11.
We need to remove 23 to truncate the input but the first sequence has a length 20.
We need to remove 14 to truncate the input but the first sequence has a length 11.
We need to remove 13 to truncate the input but the first sequence has a length 10.
We need to remove 10 to truncate the input but the first sequence has a length 7.
We need to remove 10 to truncate the input but the first sequence has a length 7.
We need to remove 36 to truncate the input but the first sequence has a length 33.
We need to remove 13 to truncate the input but the first sequence has a length 10.
We need to remove 16 to truncate the input but the first sequence has a length 13.
We need to remove 11 to truncate the input but the first sequence has a length 8.
We need to remove 6 to truncate the input but the first sequence has a length 3.
We need to remove 6 to truncate the input but the first sequence has a length 3.
We need to remove 13 to truncate the input but the first sequence has a length 10.
We need to remove 6 to truncate the input but the first sequence has a length 3.
We need to remove 15 to truncate the input but the first sequence has a length 12.
We need to remove 11 to truncate the input but the first sequence has a length 8.
We need to remove 11 to truncate the input but the first sequence has a length 8.
We need to remove 14 to truncate the input but the first sequence has a length 11.
We need to remove 14 to truncate the input but the first sequence has a length 11.
We need to remove 11 to truncate the input but the first sequence has a length 8.
We need to remove 12 to truncate the input but the first sequence has a length 9.
We need to remove 12 to truncate the input but the first sequence has a length 9.
We need to remove 10 to truncate the input but the first sequence has a length 7.
We need to remove 13 to truncate the input but the first sequence has a length 10.
We need to remove 11 to truncate the input but the first sequence has a length 8.
We need to remove 10 to truncate the input but the first sequence has a length 7.
We need to remove 6 to truncate the input but the first sequence has a length 3.
We need to remove 13 to truncate the input but the first sequence has a length 10.
We need to remove 14 to truncate the input but the first sequence has a length 11.
We need to remove 6 to truncate the input but the first sequence has a length 3.
We need to remove 42 to truncate the input but the first sequence has a length 39.
We need to remove 11 to truncate the input but the first sequence has a length 8.
We need to remove 18 to truncate the input but the first sequence has a length 15.
We need to remove 11 to truncate the input but the first sequence has a length 8.
Iteration: 100% | 6/6 [00:01<00:00, 5.00it/s]
Epoch: 100% | 4/4 [00:04<00:00, 1.23s/it]
Wow! I think it may be caused by my sentences being too short, so should the tokenization strategy be changed? I am using XLNet.
Here is my data format (a short sketch of how pairs like these become training samples follows the listing):
Total training samples: 24
[' The passenger thought he was using his balance on the app', 'The passenger thought she had put the payment option on her card and when she finished the ride it was in cash and she did not have any cash at the time.'] 1.0
[' Passenger thought he had put it on his card.', 'Passenger asked because the person who asked her said it was already paid online'] 1.0
['Passenger thought he still had a balance on the app', ' Customer said that this amount was passed on to her'] 1.0
['The passenger thought it was on the card', 'He informed that he had already paid'] 1.0
[' Passenger did not have the rest', 'You only paid 11 reais on pix'] 1.0
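
For reference, pairs in this format are wrapped into InputExample objects before they go into model.fit, roughly like this (a minimal sketch; the pairs list and batch size are only illustrative):

from torch.utils.data import DataLoader
from sentence_transformers.readers import InputExample

# Illustrative only: "pairs" is a hypothetical list holding entries like the ones above.
pairs = [
    (["The passenger thought it was on the card", "He informed that he had already paid"], 1.0),
    ([" Passenger did not have the rest", "You only paid 11 reais on pix"], 1.0),
]

train_samples = [InputExample(texts=texts, label=score) for texts, score in pairs]
train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=16)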

@ZHAOFEGNSHUN
Author

ZHAOFEGNSHUN commented May 14, 2024

"""
This example trains BERT (or any other transformer model like RoBERTa, DistilBERT, etc.) on the STSbenchmark from scratch. It generates sentence embeddings
that can be compared using cosine-similarity to measure the similarity.

Usage:
python training_nli.py

OR
python training_nli.py pretrained_transformer_model_name
"""

from torch.utils.data import DataLoader
import math
from sentence_transformers import SentenceTransformer, LoggingHandler, losses, models, util
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from sentence_transformers.readers import InputExample
import logging
from datetime import datetime
import sys
import os
import gzip
import csv
from transformers import XLNetTokenizer, XLNetModel
from sentence_transformers.models import Transformer, Pooling, Normalize

#### Just some code to print debug information to stdout
logging.basicConfig(
    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()]
)
#### /print debug information to stdout


# Check if dataset exists. If not, download and extract  it
sts_dataset_path = "/nfs/XLNet/XLNet/stsbenchmark.tsv.gz"

if not os.path.exists(sts_dataset_path):
    util.http_get("https://sbert.net/datasets/stsbenchmark.tsv.gz", sts_dataset_path)


# You can specify any huggingface/transformers pre-trained model here, for example, bert-base-uncased, roberta-base, xlm-roberta-base
model_name = '/nfs/XLNet/XLNet/xlnet-base-cased'
model = SentenceTransformer(model_name)
model[0].tokenizer = XLNetTokenizer.from_pretrained(model_name, use_fast=False)
# model_name = XLNetModel.from_pretrained("/nfs/XLNet/XLNet/xlnet-base-cased", local_files_only=True)
# tokenizer = XLNetTokenizer.from_pretrained("/nfs/XLNet/XLNet/xlnet-base-cased")

# Read the dataset
train_batch_size = 32
num_epochs = 4
model_save_path = (
    "output/training_stsbenchmark_"
)

# Use Huggingface/transformers model (like BERT, RoBERTa, XLNet, XLM-R) for mapping tokens to embeddings
word_embedding_model = models.Transformer(model_name)

# Apply mean pooling to get one fixed sized sentence vector
pooling_model = models.Pooling(
    word_embedding_model.get_word_embedding_dimension(),
    pooling_mode_mean_tokens=True,
    pooling_mode_cls_token=False,
    pooling_mode_max_tokens=False,
)

# model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

# Convert the dataset to a DataLoader ready for training
logging.info("Read STSbenchmark train dataset")

train_samples = []
dev_samples = []
test_samples = []
with gzip.open(sts_dataset_path, "rt", encoding="utf8") as fIn:
    reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE)
    for row in reader:
        score = float(row["score"]) / 5.0  # Normalize score to range 0 ... 1
        inp_example = InputExample(texts=[row["sentence1"], row["sentence2"]], label=score)

        if row["split"] == "dev":
            dev_samples.append(inp_example)
        elif row["split"] == "test":
            test_samples.append(inp_example)
        else:
            train_samples.append(inp_example)


train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)
train_loss = losses.CosineSimilarityLoss(model=model)


logging.info("Read STSbenchmark dev dataset")
evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name="sts-dev")


# Configure the training. We skip evaluation in this example
warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1)  # 10% of train data for warm-up
logging.info("Warmup-steps: {}".format(warmup_steps))


# Train the model
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=evaluator,
    epochs=num_epochs,
    evaluation_steps=1000,
    warmup_steps=warmup_steps,
    output_path=model_save_path,
)


##############################################################################
#
# Load the stored model and evaluate its performance on STS benchmark dataset
#
##############################################################################

model = SentenceTransformer(model_save_path)
test_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_samples, name="sts-test")
test_evaluator(model, output_path=model_save_path)

This is my code. Is there something wrong with it?

Thank you very much for your reply!!!❤️
