In [1]:
import transformers
from transformers import AutoTokenizer
from datasets import load_dataset, load_metric
from transformers import AutoModelForQuestionAnswering, TrainingArguments, Trainer
from transformers import default_data_collator
import torch
from torch.utils.data import DataLoader
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Directory that will receive the exported ONNX file.
output_dir = "./onnx"
# Hub identifier of the checkpoint; used for the tokenizer and to name the Trainer output dir.
model_checkpoint = "distilbert-base-uncased-distilled-squad"

## Load the SQuAD dataset (导入 squad 数据集)

In [3]:
# Per-device train/eval batch size handed to TrainingArguments below.
batch_size = 1
# When True, load the "squad_v2" dataset instead of "squad".
squad_v2 = False
datasets = load_dataset("squad_v2" if squad_v2 else "squad")

Found cached dataset squad (/root/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453)
100%|██████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 120.85it/s]


## Load the tokenizer and model (导入 tokenizer 和 model)

In [4]:
# Load the tokenizer from the Hub checkpoint name, and the QA model weights from a
# local copy under CKPT_ROOT (default "/work/ckpts", the location this notebook was
# originally run with). If that directory does not exist, fall back to the Hub
# checkpoint instead of failing on a hard-coded absolute path.
ckpt_root = os.environ.get("CKPT_ROOT", "/work/ckpts")
local_model_dir = os.path.join(ckpt_root, model_checkpoint)

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForQuestionAnswering.from_pretrained(
    local_model_dir if os.path.isdir(local_model_dir) else model_checkpoint
)

## SQuAD data preprocessing (squad 数据处理)

In [5]:
# Maximum number of tokens per tokenized feature (question + context).
max_length = 384
# Number of overlapping tokens between consecutive windows when a long context is split.
doc_stride = 128
# The context is usually placed after the question; if the tokenizer pads on the
# left, the order is flipped so the context comes first.
pad_on_right = tokenizer.padding_side == "right"

### eval_dataset

In [6]:
def prepare_validation_features(examples):
    """Turn a batch of SQuAD examples into tokenized evaluation features.

    Long contexts are split into several overlapping windows (``stride=doc_stride``),
    so one example can yield several features. Each feature records the ``id`` of
    the example it was cut from, and its offset mapping keeps character spans only
    for context tokens; every other position is set to ``None``.

    Uses the module-level ``tokenizer``, ``pad_on_right``, ``max_length`` and
    ``doc_stride``.

    Args:
        examples: a batch dict (as passed by ``Dataset.map(batched=True)``) with
            list-valued "question", "context" and "id" entries. Note that
            ``examples["question"]`` is replaced by the left-stripped questions.

    Returns:
        The tokenizer output with ``overflow_to_sample_mapping`` removed, the
        ``offset_mapping`` masked as described above, and an added ``example_id``
        list aligned with the features.
    """
    # Leading whitespace in a question only wastes tokens and can make truncation
    # of the context fail, so strip it before tokenizing.
    examples["question"] = [text.lstrip() for text in examples["question"]]

    # The context sits on the padding side: second when padding on the right,
    # first otherwise. Only the context part is allowed to be truncated.
    if pad_on_right:
        first_texts, second_texts = examples["question"], examples["context"]
        truncation, context_index = "only_second", 1
    else:
        first_texts, second_texts = examples["context"], examples["question"]
        truncation, context_index = "only_first", 0

    tokenized_examples = tokenizer(
        first_texts,
        second_texts,
        truncation=truncation,
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # feature index -> index (within this batch) of the example it came from
    feature_to_example = tokenized_examples.pop("overflow_to_sample_mapping")
    offsets_per_feature = tokenized_examples["offset_mapping"]

    example_ids = []
    for feature_idx in range(len(tokenized_examples["input_ids"])):
        example_ids.append(examples["id"][feature_to_example[feature_idx]])
        # Which input sequence each token belongs to; only the context's offsets are kept.
        seq_ids = tokenized_examples.sequence_ids(feature_idx)
        offsets_per_feature[feature_idx] = [
            offset if seq_ids[token_idx] == context_index else None
            for token_idx, offset in enumerate(offsets_per_feature[feature_idx])
        ]

    tokenized_examples["example_id"] = example_ids
    return tokenized_examples

In [7]:
# Tokenize the validation split into evaluation features; the original SQuAD
# columns are removed, leaving the tokenizer outputs plus example_id.
eval_dataset = datasets["validation"].map(
    prepare_validation_features,
    batched=True,
    remove_columns=datasets["validation"].column_names,
)

# Below, the Trainer is only used to build the evaluation dataloader; the
# training hyper-parameters are not exercised by this notebook.
model_name = model_checkpoint.rsplit("/", 1)[-1]
args = TrainingArguments(
    f"{model_name}-squad",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    weight_decay=0.01,
)

data_collator = default_data_collator
trainer = Trainer(
    model=model,
    args=args,
    data_collator=data_collator,
    train_dataset=None,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
)

Loading cached processed dataset at /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453/cache-ef22020a0f067642.arrow


In [8]:
# Build the evaluation dataloader through the Trainer so batching/collation match
# its evaluation loop; with default TrainingArguments the Trainer also drops
# dataset columns that the model's forward() does not accept.
eval_dataset = trainer.eval_dataset
eval_dataloader = trainer.get_eval_dataloader(eval_dataset=eval_dataset)

## Export to ONNX (导出 ONNX)

In [19]:
# The first evaluation batch provides the example inputs used to trace the export.
batch = next(iter(eval_dataloader))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Keep only the first sample of every tensor (slicing keeps the leading batch
# dimension) and move it to the export device. The tuple follows the batch's
# key order.
input_tuple = tuple(tensor[:1].to(device) for tensor in batch.values())

# Same device as the inputs, inference mode, fp32 weights.
model = model.to(device)
model.eval()
model.float()

# Unwrap a wrapper that exposes the real model as `.module` (as torch.nn.DataParallel does).
model_to_save = getattr(model, "module", model)

In [20]:
# Export the model to ONNX with dynamic batch and sequence dimensions.
# torch.onnx.export will not create a missing output directory.
os.makedirs(output_dir, exist_ok=True)
output_model_file = os.path.join(output_dir, "distilbert-squad.onnx")
print(f"exporting model to {output_model_file}")

# Name the graph inputs after the tensors actually passed in `input_tuple`, which
# was built from `batch` in key order. A hard-coded list that includes
# "token_type_ids" does not fit DistilBERT (its forward() takes no
# token_type_ids), so names and dynamic axes would not line up with the traced inputs.
input_names = list(batch.keys())
assert len(input_names) == len(input_tuple), "ONNX input names out of sync with traced inputs"
output_names = ["output_start_logits", "output_end_logits"]

axes = {0: "batch_size", 1: "seq_len"}
torch.onnx.export(
    model_to_save,
    input_tuple,
    output_model_file,
    export_params=True,
    opset_version=13,
    do_constant_folding=True,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes={name: axes for name in input_names + output_names},
    verbose=True,
)

print("onnx export finished")

exporting model to ./distilbert-squad.onnx
verbose: False, log level: Level.ERROR

onnx export finished
