In [None]:
from datasets import load_dataset

# Raw SQuAD DatasetDict; its "train" and "validation" splits are used below.
dataset = load_dataset(path="squad")

In [None]:
# Handles to the two splits processed in the rest of the notebook.
train_dataset, val_dataset = dataset["train"], dataset["validation"]
train_dataset

In [None]:
# tqdm.auto picks the notebook progress widget when running under Jupyter and
# falls back to the plain text bar otherwise.
from tqdm.auto import tqdm
from transformers import BertTokenizer

# Tokenizer for the bert-base-uncased checkpoint; used below to measure how many
# tokens each "context [SEP] answer" string occupies.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [None]:
train_dataset[0]  # Inspect one raw example, e.g. the 'context' and 'answers' fields used below.

In [None]:
# Find the longest tokenized "context [SEP] answer" sequence in a dataset split.
def search_max_len(dataset, tok=None):
    """Return the maximum token length of ``f"{context} [SEP] {answer}"`` over ``dataset``.

    Each example must expose ``'context'`` (str) and ``'answers'['text']`` (list of
    str). Only the first answer text is used. An example whose answer list is empty
    is measured with an empty answer instead of raising ``IndexError``.

    Args:
        dataset: an iterable, sized collection of examples (e.g. one split of a
            ``datasets.DatasetDict``, or a list of dicts).
        tok: callable tokenizer that returns a mapping with an ``'input_ids'``
            sequence. Defaults to the notebook-level ``tokenizer``.

    Returns:
        The largest ``len(input_ids)`` seen, or 0 when ``dataset`` is empty.
    """
    if tok is None:
        tok = tokenizer  # notebook-level tokenizer created in an earlier cell
    max_len = 0
    for data in tqdm(dataset):
        answer_texts = data['answers']['text']
        # Guard against examples without any answer text.
        answer = answer_texts[0] if answer_texts else ""
        marked_text = f"{data['context']} [SEP] {answer}"
        max_len = max(max_len, len(tok(marked_text)['input_ids']))
    return max_len

In [16]:
def remove_long_dataset(batch, max_length=512, tok=None):
    """Flag which examples in a batch fit within ``max_length`` tokens.

    Meant for ``Dataset.map(remove_long_dataset, batched=True)``. Despite its name,
    it removes nothing itself. It only produces a boolean ``is_short_enough``
    column, which a later ``filter`` uses to drop the long examples.

    Args:
        batch: batched columns. ``batch['context']`` is a list of str, and
            ``batch['answers']`` is a list of dicts with a ``'text'`` list. Only the
            first answer text is used; an empty answer list counts as an empty answer.
        max_length: inclusive token limit. Defaults to 512, the maximum sequence
            length the tokenizer reports for this model.
        tok: callable tokenizer that accepts a list of str and returns a mapping
            with ``'input_ids'`` (one sequence per input). Defaults to the
            notebook-level ``tokenizer``.

    Returns:
        ``{"is_short_enough": [bool, ...]}``, aligned with the rows of ``batch``.
    """
    if tok is None:
        tok = tokenizer  # notebook-level tokenizer created in an earlier cell
    contexts = batch['context']
    # First answer text of each example; empty string when an example has none.
    answers = [ans['text'][0] if ans['text'] else "" for ans in batch['answers']]

    # Build the batch of "context [SEP] answer" strings.
    marked_texts = [f"{context} [SEP] {answer}" for context, answer in zip(contexts, answers)]
    if not marked_texts:
        return {"is_short_enough": []}

    # Tokenize the whole batch in a single call.
    input_ids = tok(marked_texts)['input_ids']

    # An example is kept only if its token count does not exceed max_length.
    is_short_enough = [len(tokens) <= max_length for tokens in input_ids]

    return {"is_short_enough": is_short_enough}

In [17]:
# Add the "is_short_enough" flag to both splits (train first, then validation).
new_train_dataset, new_val_dataset = (
    split.map(remove_long_dataset, batched=True)
    for split in (train_dataset, val_dataset)
)

Map:   2%|▏         | 2000/87599 [00:04<03:14, 439.54 examples/s]Token indices sequence length is longer than the specified maximum sequence length for this model (516 > 512). Running this sequence through the model will result in indexing errors
Map: 100%|██████████| 87599/87599 [02:59<00:00, 487.15 examples/s]
Map: 100%|██████████| 10570/10570 [00:22<00:00, 468.80 examples/s]


In [18]:
def _keep_if_short(example):
    # Keep a row only when the map step marked it as fitting the token limit.
    return example['is_short_enough']

new_train_dataset = new_train_dataset.filter(_keep_if_short)
new_val_dataset = new_val_dataset.filter(_keep_if_short)

Filter: 100%|██████████| 87599/87599 [00:01<00:00, 64914.45 examples/s]
Filter: 100%|██████████| 10570/10570 [00:00<00:00, 55966.77 examples/s]


In [19]:
# Drop the helper "is_short_enough" column now that filtering is done.
# Each removal is guarded so re-running this cell is a no-op instead of an error
# (remove_columns rejects a column that is no longer present).
if "is_short_enough" in new_train_dataset.column_names:
    new_train_dataset = new_train_dataset.remove_columns("is_short_enough")
if "is_short_enough" in new_val_dataset.column_names:
    new_val_dataset = new_val_dataset.remove_columns("is_short_enough")

In [21]:
# Bundle the filtered splits into one DatasetDict and publish it to the Hugging Face Hub.
from datasets import DatasetDict

splits = {'train': new_train_dataset, 'validation': new_val_dataset}
dataset_to_upload = DatasetDict(splits)

dataset_to_upload.push_to_hub('squad_combined_bert_512')

Creating parquet from Arrow format: 100%|██████████| 88/88 [00:04<00:00, 19.76ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:06<00:00,  6.96s/it]
Creating parquet from Arrow format: 100%|██████████| 11/11 [00:00<00:00, 19.72ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:01<00:00,  1.67s/it]


CommitInfo(commit_url='https://huggingface.co/datasets/OneFly7/squad_combined_bert_512/commit/e7c3383524299f8ad1e9daf95ae26ab12ed26d49', commit_message='Upload dataset', commit_description='', oid='e7c3383524299f8ad1e9daf95ae26ab12ed26d49', pr_url=None, pr_revision=None, pr_num=None)