# Hugging Face Transformers: Fine-Tuning a Language Model for Question Answering

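We have already learned how to use a Pipeline to load a pretrained model that supports question answering. This tutorial shows how to fine-tune a model for the question answering task.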

**Note: the fine-tuned model still answers a question by extracting a substring of the context; it does not generate new text.**

#### Example of the model answering a question

![Widget inference representing the QA task](docs/images/question_answering.png)

In [1]:
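# Adjust the following key parameters to match your model and GPU resources

# Set to True to use SQuAD 2.0 (which includes unanswerable questions) instead of SQuAD 1.1
squad_v2 = False
# Model checkpoint name
model_checkpoint = "distilbert-base-uncased"
# Batch size
batch_size = 16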

## Downloading the dataset

In this tutorial we use the [Stanford Question Answering Dataset (SQuAD)](https://rajpurkar.github.io/SQuAD-explorer/).

#### The SQuAD dataset

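The **Stanford Question Answering Dataset (SQuAD)** is a reading comprehension dataset consisting of questions posed by crowdworkers on a set of Wikipedia articles. The answer to every question is a segment of text, or span, from the corresponding reading passage; alternatively, the question may be unanswerable.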

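SQuAD 2.0 combines the 100,000 questions in SQuAD 1.1 with over 50,000 unanswerable questions written adversarially by crowdworkers to look similar to answerable ones. To do well on SQuAD 2.0, a system must not only answer questions when possible, but also determine when no answer is supported by the paragraph and abstain from answering.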


### Loading the dataset

In [2]:
# Import the dataset loading function

from datasets import load_dataset

In [3]:
# Download the dataset selected by the configuration above

datasets = load_dataset("squad_v2" if squad_v2 else "squad")

In [4]:
# Inspect the dataset structure
datasets

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 87599
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10570
    })
})

#### Comparing with the Yelp dataset

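Compared with the Yelp review dataset used in the quick-start tutorial, both the SQuAD training and validation splits have additional columns for the context, the question, and the answers: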

**YelpReviewFull dataset:**

```
DatasetDict({
    train: Dataset({
        features: ['label', 'text'],
        num_rows: 650000
    })
    test: Dataset({
        features: ['label', 'text'],
        num_rows: 50000
    })
})
```

#### Data format:
{'id': '5733be284776f41900661182',  
 'title': 'University_of_Notre_Dame',  
 'context': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.',  
 'question': 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?',  
 'answers':   
    {'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}
}
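- `id`: the ID of the example in the dataset
- `title`: the title of the article
- `context`: the context passage
- `question`: the question
- `answers`: the annotated answer; `text` is the answer text and `answer_start` is the position where the answer starts in the context (in characters)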

In [5]:
# Inspect a single example
datasets['train'][0]

{'id': '5733be284776f41900661182',
 'title': 'University_of_Notre_Dame',
 'context': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.',
 'question': 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?',
 'answers': {'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}}

#### Locating the answer in the context

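Each answer is given by its start position in the text (here 515; note that this is a character offset, not a token offset) and its full text, which is a substring of the context above.  
For example: `'answers': {'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}`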

In [6]:
# Recover the answer from the context using its start offset plus the length of the answer text
answer_start = datasets['train'][0]['answers']['answer_start'][0]
answer_end = answer_start + len(datasets['train'][0]['answers']['text'][0])
print(datasets['train'][0]['context'][answer_start:answer_end])
print(datasets['train'][0]['answers']['text'])

Saint Bernadette Soubirous
['Saint Bernadette Soubirous']


In [7]:
import numpy as np
# Import the datasets feature types
from datasets import ClassLabel, Sequence, Value
import pandas as pd
from IPython.display import display, HTML

In [8]:
# Randomly sample examples from the dataset and display them as an HTML table
def show_random_elements(dataset, num_examples=10):
    assert num_examples <= dataset.num_rows, "Can't pick more elements than there are in the dataset."
    # Sample without replacement so that no example is shown twice
    picks = np.random.choice(dataset.num_rows, size=num_examples, replace=False).tolist()
    # Convert to a pandas DataFrame
    df = pd.DataFrame(dataset[picks])

    # (Optional) replace ClassLabel IDs with label names; SQuAD has no ClassLabel columns, so this is not needed here
    # for column, typ in dataset.features.items():
    #     if isinstance(typ, ClassLabel):
    #         df[column] = df[column].transform(lambda i: typ.names[i])
    #     elif isinstance(typ, Sequence) and isinstance(typ.feature, ClassLabel):
    #         df[column] = df[column].transform(lambda x: [typ.feature.names[i] for i in x])

    display(HTML(df.to_html()))

In [9]:
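# transform applies a function to a DataFrame or Series and must return a result with the same length as the input; unlike apply, it cannot aggregate or reshape the data.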
df = pd.DataFrame({
    'A': [1, 2, 3, 4],
    'B': [5, 6, 7, 8]
})
df['A'] = df['A'].transform(lambda x:x+10)
df

|   | A  | B |
|---|----|---|
| 0 | 11 | 5 |
| 1 | 12 | 6 |
| 2 | 13 | 7 |
| 3 | 14 | 8 |


In [10]:
show_random_elements(datasets['train'], 3)

|   | id | title | context | question | answers |
|---|----|-------|---------|----------|---------|
| 0 | 57277f9d5951b619008f8b69 | Carnival | The Slovenian countryside displays a variety of disguised groups and individual characters among which the most popular and characteristic is the Kurent (plural: Kurenti), a monstrous and demon-like, but fluffy figure. The most significant festival is held in Ptuj (see: Kurentovanje). Its special feature are the Kurents themselves, magical creatures from another world, who visit major events throughout the country, trying to banish the winter and announce spring's arrival, fertility, and new life with noise and dancing. The origin of the Kurent is a mystery, and not much is known of the times, beliefs, or purposes connected with its first appearance. The origin of the name itself is obscure. | Where is the most significant Slovenian festival held? | {'text': ['Ptuj'], 'answer_start': [260]} |
| 1 | 57301192a23a5019007fccef | Printed_circuit_board | At the glass transition temperature the resin in the composite softens and significantly increases thermal expansion; exceeding Tg then exerts mechanical overload on the board components - e.g. the joints and the vias. Below Tg the thermal expansion of the resin roughly matches copper and glass, above it gets significantly higher. As the reinforcement and copper confine the board along the plane, virtually all volume expansion projects to the thickness and stresses the plated-through holes. Repeated soldering or other exposition to higher temperatures can cause failure of the plating, especially with thicker boards; thick boards therefore require high Tg matrix. | What do thick boards require to resist plating failure? | {'text': ['high Tg matrix'], 'answer_start': [655]} |
| 2 | 5728c8262ca10214002da7b6 | Estonia | In 2007, however, a large current account deficit and rising inflation put pressure on Estonia's currency, which was pegged to the Euro, highlighting the need for growth in export-generating industries. Estonia exports mainly machinery and equipment, wood and paper, textiles, food products, furniture, and metals and chemical products. Estonia also exports 1.562 billion kilowatt hours of electricity annually. At the same time Estonia imports machinery and equipment, chemical products, textiles, food products and transportation equipment. Estonia imports 200 million kilowatt hours of electricity annually. | When did a huge deficit and rising inflation place pressure on Estonia's currency? | {'text': ['2007'], 'answer_start': [3]} |


## Preprocessing the data

In [11]:
# AutoTokenizer is a key component of the transformers library: it loads the tokenizer that matches a given pretrained model.
from transformers import AutoTokenizer

In [12]:
# Load the tokenizer of the pretrained model.
# AutoTokenizer.from_pretrained infers which tokenizer to use from the name or path of the checkpoint
# (model_checkpoint), e.g. "bert-base-uncased" or "gpt2", and loads the matching pretrained tokenizer.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [13]:
tokenizer

DistilBertTokenizerFast(name_or_path='distilbert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

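#### The following assertion checks that our tokenizer is a fast tokenizer (implemented in Rust). Besides being faster, fast tokenizers support features used later, such as `return_offsets_mapping`.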

In [14]:
import transformers
assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)

#### Tokenizing text

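Not every model has a fast tokenizer; you can check which models do in the model table of the Transformers documentation.

The tokenizer can be called directly on two sentences (one for the question, one for the context):

Sentence-pair input: when you pass two sentences to the tokenizer, it treats them as a pair. This is common in tasks such as question answering and sentence-relationship classification.

Tokenization: the tokenizer splits each sentence into smaller units (words, subwords, or symbols). For some models, such as BERT, it also adds special tokens like `[CLS]` and `[SEP]` to mark the start of the input and to separate the sentences.

Output: the tokenizer output usually contains several fields. The main one is `input_ids` (the IDs of the tokens in the vocabulary); it may also include `attention_mask` (which IDs are real tokens and which are padding) and `token_type_ids` (which sentence each token belongs to). DistilBERT does not use `token_type_ids`, so they are absent from the output below.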
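To make the layout of a sentence-pair input concrete, here is a minimal pure-Python sketch of how a BERT-style pair is assembled. It assumes the `[CLS]`/`[SEP]` IDs 101 and 102 shown in the tokenizer output above and reuses the word-piece IDs from the example output below; `build_pair_input` is a hypothetical helper for illustration, not part of the transformers API. It also builds the per-token sentence index that `sequence_ids()` reports later in this tutorial.

```python
from typing import List, Optional, Tuple


def build_pair_input(
    question_ids: List[int],
    context_ids: List[int],
    cls_id: int = 101,
    sep_id: int = 102,
) -> Tuple[List[int], List[int], List[Optional[int]]]:
    """Lay out a sentence pair as [CLS] question [SEP] context [SEP]."""
    input_ids = [cls_id] + question_ids + [sep_id] + context_ids + [sep_id]
    # No padding here, so every position is a real token.
    attention_mask = [1] * len(input_ids)
    # None marks special tokens; 0 = first sentence, 1 = second sentence.
    sequence_ids = [None] + [0] * len(question_ids) + [None] + [1] * len(context_ids) + [None]
    return input_ids, attention_mask, sequence_ids


# Word-piece IDs of "what is your name?" and "my name is AnMin" from the tokenizer output below.
ids, mask, seq = build_pair_input([2054, 2003, 2115, 2171, 1029], [2026, 2171, 2003, 2019, 10020])
print(ids)  # [101, 2054, 2003, 2115, 2171, 1029, 102, 2026, 2171, 2003, 2019, 10020, 102]
print(seq)  # [None, 0, 0, 0, 0, 0, None, 1, 1, 1, 1, 1, None]
```

The real tokenizer does the same bookkeeping after splitting the text into word pieces.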


In [16]:
# With the default options, the tokenizer accepts a sentence pair
token = tokenizer("what is your name?", "my name is AnMin")
print(token)
print(token.keys())
print(token['input_ids'])
# Marks which IDs are real tokens and which are padding (1 = real token, 0 = padding)
print(token['attention_mask'])
# decode maps the token IDs back to a string
print(tokenizer.decode(token['input_ids']))
# convert_ids_to_tokens maps the given input_ids back to individual tokens
print(tokenizer.convert_ids_to_tokens(token['input_ids'][:6]))

{'input_ids': [101, 2054, 2003, 2115, 2171, 1029, 102, 2026, 2171, 2003, 2019, 10020, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
dict_keys(['input_ids', 'attention_mask'])
[101, 2054, 2003, 2115, 2171, 1029, 102, 2026, 2171, 2003, 2019, 10020, 102]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
[CLS] what is your name? [SEP] my name is anmin [SEP]
['[CLS]', 'what', 'is', 'your', 'name', '?']


In [22]:
# Passing the two sentences as a list: they are encoded as a batch of two separate sequences, not as a pair
tokens = tokenizer(["what is your name?", "my name is AnMin"])
print(tokens)
print(tokens.keys())
for token in tokens['input_ids']:
    print(token)
for token in tokens['attention_mask']:
    print(token)
# decode maps each sequence of token IDs back to a string
for token in tokens['input_ids']:
    print(tokenizer.decode(token))

{'input_ids': [[101, 2054, 2003, 2115, 2171, 1029, 102], [101, 2026, 2171, 2003, 2019, 10020, 102]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1]]}
dict_keys(['input_ids', 'attention_mask'])
[101, 2054, 2003, 2115, 2171, 1029, 102]
[101, 2026, 2171, 2003, 2019, 10020, 102]
[1, 1, 1, 1, 1, 1, 1]
[1, 1, 1, 1, 1, 1, 1]
[CLS] what is your name? [SEP]
[CLS] my name is anmin [SEP]


### Advanced tokenizer usage: handling long texts

**A specific problem in preprocessing for question answering is how to handle very long documents.**

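In other tasks, we usually truncate documents that exceed the model's maximum input length, but here removing part of the context might remove the very answer we are looking for.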

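To solve this, we allow one (long) example in the dataset to produce several input features, each shorter than the model's maximum length (or the maximum length we set as a hyperparameter).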

In [23]:
# The maximum length of a feature (question and context), in tokens
max_length = 384

# The allowed overlap, in tokens, between two consecutive parts of the context when it has to be split
doc_stride = 128
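How `max_length` and `doc_stride` cut a long context into overlapping features can be sketched in plain Python. This is a simplified model, not the tokenizer's implementation: it assumes the question plus 3 special tokens (`[CLS]`, `[SEP]`, `[SEP]`) are repeated in every feature, so each window holds `max_length - question_tokens - 3` context tokens, and consecutive windows overlap by `stride` tokens. `context_windows` is a hypothetical helper, and the token counts in the usage lines are assumptions chosen to match the Beyoncé example below (8 question tokens, 437 tokens in total).

```python
from typing import List, Tuple


def context_windows(num_context_tokens: int, window: int, stride: int) -> List[Tuple[int, int]]:
    """Return [start, end) ranges over the context tokens, consecutive ranges overlapping by `stride`."""
    if not 0 <= stride < window:
        raise ValueError("stride must be non-negative and smaller than the window size")
    windows = []
    start = 0
    while True:
        end = min(start + window, num_context_tokens)
        windows.append((start, end))
        if end == num_context_tokens:
            break
        # Step forward by (window - stride) so the next window re-reads the last `stride` tokens.
        start += window - stride
    return windows


max_length, doc_stride = 384, 128
question_tokens, context_tokens = 8, 426   # assumed sizes: 8 + 426 + 3 special tokens = 437
window = max_length - question_tokens - 3  # 373 context tokens per feature
spans = context_windows(context_tokens, window, doc_stride)
print(spans)                                                       # [(0, 373), (245, 426)]
print([end - start + question_tokens + 3 for start, end in spans])  # [384, 192]
```

With `stride=0` the windows would simply tile the context without any overlap.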

#### Handling texts longer than the maximum length

Find an example in the training set that exceeds the maximum length (384 tokens) to use for the demonstration:

In [40]:
# Find an example whose context alone is longer than 384 tokens
for i, example in enumerate(datasets["train"]):
    # if len(tokenizer(example["question"], example["context"])["input_ids"]) > 384:
    #     break
    if len(tokenizer(example["context"])["input_ids"]) > 384:
        break
# The selected example, longer than the maximum length (384)
example = datasets["train"][i]

example

{'id': '56be95823aeaaa14008c910c',
 'title': 'Beyoncé',
 'context': 'On April 4, 2008, Beyoncé married Jay Z. She publicly revealed their marriage in a video montage at the listening party for her third studio album, I Am... Sasha Fierce, in Manhattan\'s Sony Club on October 22, 2008. I Am... Sasha Fierce was released on November 18, 2008 in the United States. The album formally introduces Beyoncé\'s alter ego Sasha Fierce, conceived during the making of her 2003 single "Crazy in Love", selling 482,000 copies in its first week, debuting atop the Billboard 200, and giving Beyoncé her third consecutive number-one album in the US. The album featured the number-one song "Single Ladies (Put a Ring on It)" and the top-five songs "If I Were a Boy" and "Halo". Achieving the accomplishment of becoming her longest-running Hot 100 single in her career, "Halo"\'s success in the US helped Beyoncé attain more top-ten singles on the list than any other woman during the 2000s. It also included the suc

In [41]:
# Token length of the question and context together
len(tokenizer(example["question"], example["context"])["input_ids"])

437

#### Preprocessing: truncate the context and discard the overflow

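`truncation` selects the truncation strategy:

- `True` or `'longest_first'`: this is the strategy used when `truncation=True`. When the input exceeds the maximum length, tokens are removed one at a time from whichever sequence is currently the longest, until the total length fits. With a single sequence, that sequence is simply truncated.

- `'only_first'`: for a sequence pair (e.g. in question answering or text-pair tasks), truncates only the first sequence (usually the question or hypothesis) and keeps the second sequence (usually the context or premise) intact.

- `'only_second'`: the reverse of `'only_first'`; truncates only the second sequence and keeps the first intact. In question answering this keeps the question complete.

- `False`: no truncation. Inputs longer than the model's maximum length are returned untruncated (at most with a warning), and the model will usually fail when given them. Use this only when you know every input already fits.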
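The strategies above can be modeled with plain lists of token IDs. This is a simplified sketch, not the tokenizer's code: `truncate_pair` is a hypothetical helper, `budget` is assumed to already exclude the special tokens, and an impossible `'only_first'`/`'only_second'` request raises `ValueError`, mirroring the failure that fast tokenizers report in the same situation. The token counts in the usage lines are assumptions matching the Beyoncé example below.

```python
from typing import List, Tuple


def truncate_pair(first: List[int], second: List[int], budget: int, strategy: str) -> Tuple[List[int], List[int]]:
    """Truncate a (first, second) token pair so that together they hold at most `budget` tokens."""
    first, second = list(first), list(second)
    excess = len(first) + len(second) - budget
    if excess <= 0:
        return first, second  # already fits, nothing to do
    if strategy == "longest_first":
        # Remove one token at a time from whichever sequence is currently longer.
        for _ in range(excess):
            if len(first) > len(second):
                first.pop()
            else:
                second.pop()
    elif strategy in ("only_first", "only_second"):
        target = first if strategy == "only_first" else second
        if len(target) <= excess:
            raise ValueError(f"{strategy}: sequence too short to remove {excess} tokens")
        del target[len(target) - excess:]
    else:
        raise ValueError(f"unknown strategy: {strategy}")
    return first, second


question, context = list(range(8)), list(range(426))  # assumed token counts
budget = 384 - 3                                      # max_length minus [CLS], [SEP], [SEP]
for strategy in ("longest_first", "only_second", "only_first"):
    try:
        q, c = truncate_pair(question, context, budget, strategy)
        print(strategy, len(q), len(c))
    except ValueError as err:
        print(err)
```

For this pair, `'longest_first'` and `'only_second'` both cut the context to 373 tokens, while `'only_first'` fails because the 8-token question cannot absorb the 53 excess tokens.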

##### Truncation strategy 1: `True` (`'longest_first'`)

Truncate starting from the longest sequence until the total length fits:

In [42]:
token = tokenizer(
    example["question"],
    example["context"],
    max_length = max_length,  # maximum feature length
    truncation = True
    )
print(token.keys())
print(len(token['input_ids']))
print(tokenizer.decode(token['input_ids']))

dict_keys(['input_ids', 'attention_mask'])
384
[CLS] beyonce got married in 2008 to whom? [SEP] on april 4, 2008, beyonce married jay z. she publicly revealed their marriage in a video montage at the listening party for her third studio album, i am... sasha fierce, in manhattan's sony club on october 22, 2008. i am... sasha fierce was released on november 18, 2008 in the united states. the album formally introduces beyonce's alter ego sasha fierce, conceived during the making of her 2003 single " crazy in love ", selling 482, 000 copies in its first week, debuting atop the billboard 200, and giving beyonce her third consecutive number - one album in the us. the album featured the number - one song " single ladies ( put a ring on it ) " and the top - five songs " if i were a boy " and " halo ". achieving the accomplishment of becoming her longest - running hot 100 single in her career, " halo "'s success in the us helped beyonce attain more top - ten singles on the list than any other w

In [45]:
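# A list is encoded as a batch of two separate sequences, so each one is truncated on its own
tokens = tokenizer(
    [example["question"],example["context"]],
    max_length = max_length,  # maximum feature length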
    truncation = True
    )
print(tokens.keys())
for token in tokens['input_ids']:
    print(f"token len is {len(token)}")
    print(f"data is {tokenizer.decode(token)}")

dict_keys(['input_ids', 'attention_mask'])
token len is 10
data is [CLS] beyonce got married in 2008 to whom? [SEP]
token len is 384
data is [CLS] on april 4, 2008, beyonce married jay z. she publicly revealed their marriage in a video montage at the listening party for her third studio album, i am... sasha fierce, in manhattan's sony club on october 22, 2008. i am... sasha fierce was released on november 18, 2008 in the united states. the album formally introduces beyonce's alter ego sasha fierce, conceived during the making of her 2003 single " crazy in love ", selling 482, 000 copies in its first week, debuting atop the billboard 200, and giving beyonce her third consecutive number - one album in the us. the album featured the number - one song " single ladies ( put a ring on it ) " and the top - five songs " if i were a boy " and " halo ". achieving the accomplishment of becoming her longest - running hot 100 single in her career, " halo "'s success in the us helped beyonce attain 

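##### Truncation strategy 2: `'only_first'`
Truncates only the first sequence (usually the question or hypothesis) and keeps the second sequence (usually the context or premise) intact.
With this strategy, the second sequence must fit within `max_length` on its own: the first sequence has to be long enough to absorb all of the excess tokens, otherwise truncation fails.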

In [None]:
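# The question has only 8 tokens, far fewer than the 53 excess tokens (437 - 384),
# so with a fast tokenizer this call raises a truncation error.
tokens = tokenizer(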
    example["question"],
    example["context"],
    max_length = max_length,
    truncation = "only_first"
    )
print(tokens.keys())

In [53]:
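# In a batch of single sequences, each sequence is its own "first" sequence, so 'only_first'
# simply truncates the long context here and the output matches strategy 1
tokens = tokenizer(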
    [example["question"],example["context"]],
    max_length = max_length,
    truncation = "only_first"
    )
print(tokens.keys())
for token in tokens['input_ids']:
    print(f"token len is {len(token)}")
    print(f"data is {tokenizer.decode(token)}")

dict_keys(['input_ids', 'attention_mask'])
token len is 10
data is [CLS] beyonce got married in 2008 to whom? [SEP]
token len is 384
data is [CLS] on april 4, 2008, beyonce married jay z. she publicly revealed their marriage in a video montage at the listening party for her third studio album, i am... sasha fierce, in manhattan's sony club on october 22, 2008. i am... sasha fierce was released on november 18, 2008 in the united states. the album formally introduces beyonce's alter ego sasha fierce, conceived during the making of her 2003 single " crazy in love ", selling 482, 000 copies in its first week, debuting atop the billboard 200, and giving beyonce her third consecutive number - one album in the us. the album featured the number - one song " single ladies ( put a ring on it ) " and the top - five songs " if i were a boy " and " halo ". achieving the accomplishment of becoming her longest - running hot 100 single in her career, " halo "'s success in the us helped beyonce attain 

##### Truncation strategy 3: `'only_second'`
Truncates only the second sequence and keeps the first intact. In question answering this guarantees that the question stays complete.

In [54]:
token = tokenizer(
    example["question"],
    example["context"],
    max_length = max_length,
    truncation = "only_second"
    )
print(token.keys())
print(len(token['input_ids']))
print(tokenizer.decode(token['input_ids']))

dict_keys(['input_ids', 'attention_mask'])
384
[CLS] beyonce got married in 2008 to whom? [SEP] on april 4, 2008, beyonce married jay z. she publicly revealed their marriage in a video montage at the listening party for her third studio album, i am... sasha fierce, in manhattan's sony club on october 22, 2008. i am... sasha fierce was released on november 18, 2008 in the united states. the album formally introduces beyonce's alter ego sasha fierce, conceived during the making of her 2003 single " crazy in love ", selling 482, 000 copies in its first week, debuting atop the billboard 200, and giving beyonce her third consecutive number - one album in the us. the album featured the number - one song " single ladies ( put a ring on it ) " and the top - five songs " if i were a boy " and " halo ". achieving the accomplishment of becoming her longest - running hot 100 single in her career, " halo "'s success in the us helped beyonce attain more top - ten singles on the list than any other w

In [None]:
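# 'only_second' needs a sequence pair. A list is encoded as a batch of single sequences, so there is no second sequence to truncate and this call raises an error.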
tokens = tokenizer(
    [example["question"],example["context"]],
    max_length = max_length,
    truncation = "only_second"
    )
print(tokens.keys())
# for token in tokens['input_ids']:
#     print(f"token len is {len(token)}")
#     print(f"data is {tokenizer.decode(token)}")

#### Keeping the overflowing tokens instead of discarding them

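Parameters:
- `truncation`: the truncation strategy.
- `return_overflowing_tokens`: controls what happens to the tokens cut off by truncation. With `return_overflowing_tokens=True`, they are returned as additional features instead of being discarded, and the tokenizer adds an extra field, `overflow_to_sample_mapping`. `input_ids` and `attention_mask` then become lists with one entry per feature: the first truncated feature followed by the features built from the overflow.
- `stride`: the number of tokens by which consecutive features overlap.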


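Returned values:
- `input_ids`: a list of token ID sequences, one per feature, covering both the first truncated feature and the features built from the overflow.
- `attention_mask`: marks which IDs are real tokens and which are padding.
- `overflow_to_sample_mapping`: for each feature, the index of the original example it came from.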

In [65]:
tokenized_example = tokenizer(
    example["question"],
    example["context"],
    max_length=max_length,
    truncation="only_second",
    return_overflowing_tokens=True,
    stride=doc_stride
)
print(tokenized_example.keys())
# First feature: truncated at max_length
print(f"token len is {len(tokenized_example['input_ids'][0])}")
print(f"data is {tokenizer.decode(tokenized_example['input_ids'][0])}")
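# Second feature: the doc_stride overlap plus the rest of the context, preceded by the question
# (192 tokens for this example: 128 overlapping + 53 remaining context tokens + 8 question tokens + 3 special tokens)
print(f"token len is {len(tokenized_example['input_ids'][1])}")
print(f"data is {tokenizer.decode(tokenized_example['input_ids'][1])}")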
print(tokenized_example['overflow_to_sample_mapping'])

dict_keys(['input_ids', 'attention_mask', 'overflow_to_sample_mapping'])
token len is 384
data is [CLS] beyonce got married in 2008 to whom? [SEP] on april 4, 2008, beyonce married jay z. she publicly revealed their marriage in a video montage at the listening party for her third studio album, i am... sasha fierce, in manhattan's sony club on october 22, 2008. i am... sasha fierce was released on november 18, 2008 in the united states. the album formally introduces beyonce's alter ego sasha fierce, conceived during the making of her 2003 single " crazy in love ", selling 482, 000 copies in its first week, debuting atop the billboard 200, and giving beyonce her third consecutive number - one album in the us. the album featured the number - one song " single ladies ( put a ring on it ) " and the top - five songs " if i were a boy " and " halo ". achieving the accomplishment of becoming her longest - running hot 100 single in her career, " halo "'s success in the us helped beyonce attain 

With this strategy, the tokenizer returns several `input_ids` lists, one per feature.

In [47]:
[len(x) for x in tokenized_example["input_ids"]]

[384, 157]

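The first feature is filled up to `max_length`, while the last one holds only the remaining context plus the overlap, so it is shorter. When `stride` is not set, it defaults to 0 and consecutive features do not overlap. (These lengths, and several of the outputs below, were recorded on a different example, the Notre Dame basketball question, as the decoded text shows.)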

In [52]:
print(tokenizer.decode(tokenized_example['input_ids'][0])[:100])
print(tokenizer.decode(tokenized_example['input_ids'][1])[:100])

[CLS] how many wins does the notre dame men's basketball team have? [SEP] the men's basketball team 
[CLS] how many wins does the notre dame men's basketball team have? [SEP] the most by the fighting i


#### Using offset_mapping to map tokens back to the original text

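Setting `return_offsets_mapping=True` lets us map every token in each of the features produced by the split back to its position in the original text.

With `return_offsets_mapping=True`, the tokenizer returns, for each token, a tuple with the token's character-level offsets in the original, untokenized text. The tuple has the form `(start, end)`, where `start` is the position where the token starts in the original text and `end` is the position where it ends (exclusive).

For example, the first token (`[CLS]`) has offsets `(0, 0)` because it does not correspond to any part of the question or the context, while the second token covers characters 0 to 7 of the question (`(0, 7)` in the output below).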
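The idea behind `offset_mapping` can be reproduced with a toy pre-tokenizer: record where each piece starts and ends in the original string, with `end` exclusive so that slicing gives the piece back. `simple_offsets` is a hypothetical helper that only splits on words and punctuation (no subwords, no special tokens), so it is not what the tokenizer actually runs; for the question of the Beyoncé example it happens to produce the same offsets as the question tokens in the output below.

```python
import re
from typing import List, Tuple


def simple_offsets(text: str) -> List[Tuple[str, Tuple[int, int]]]:
    """Split text into words and punctuation marks, keeping (start, end) character offsets."""
    # \w+ matches a run of word characters; [^\w\s] matches a single punctuation mark.
    return [(m.group(), (m.start(), m.end())) for m in re.finditer(r"\w+|[^\w\s]", text)]


question = "Beyonce got married in 2008 to whom?"
pieces = simple_offsets(question)
print([offsets for _, offsets in pieces])
# [(0, 7), (8, 11), (12, 19), (20, 22), (23, 27), (28, 30), (31, 35), (35, 36)]

# Because `end` is exclusive, slicing the original text recovers each piece.
start, end = pieces[0][1]
print(question[start:end])  # Beyonce
```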



In [71]:
tokenized_example = tokenizer(
    example["question"],
    example["context"],
    max_length=max_length,
    truncation="only_second",
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
    stride=doc_stride
)

In [68]:
tokenized_example.keys()

dict_keys(['input_ids', 'attention_mask', 'offset_mapping', 'overflow_to_sample_mapping'])

In [73]:
tokenized_example['offset_mapping']

[[(0, 0),
  (0, 7),
  (8, 11),
  (12, 19),
  (20, 22),
  (23, 27),
  (28, 30),
  (31, 35),
  (35, 36),
  (0, 0),
  (0, 2),
  (3, 8),
  (9, 10),
  (10, 11),
  (12, 16),
  (16, 17),
  (18, 25),
  (26, 33),
  (34, 37),
  (38, 39),
  (39, 40),
  (41, 44),
  (45, 53),
  (54, 62),
  (63, 68),
  (69, 77),
  (78, 80),
  (81, 82),
  (83, 88),
  (89, 93),
  (93, 96),
  (97, 99),
  (100, 103),
  (104, 113),
  (114, 119),
  (120, 123),
  (124, 127),
  (128, 133),
  (134, 140),
  (141, 146),
  (146, 147),
  (148, 149),
  (150, 152),
  (152, 153),
  (153, 154),
  (154, 155),
  (156, 161),
  (162, 168),
  (168, 169),
  (170, 172),
  (173, 182),
  (182, 183),
  (183, 184),
  (185, 189),
  (190, 194),
  (195, 197),
  (198, 205),
  (206, 208),
  (208, 209),
  (210, 214),
  (214, 215),
  (216, 217),
  (218, 220),
  (220, 221),
  (221, 222),
  (222, 223),
  (224, 229),
  (230, 236),
  (237, 240),
  (241, 249),
  (250, 252),
  (253, 261),
  (262, 264),
  (264, 265),
  (266, 270),
  (271, 273),
  (274, 277)

In [74]:

start, end = tokenized_example['offset_mapping'][0][1]
example["question"][start:end]

'Beyonce'

In [84]:
first_token_id = tokenized_example["input_ids"][0][1]
offsets = tokenized_example["offset_mapping"][0][1]
print(tokenizer.convert_ids_to_tokens([first_token_id])[0], example["question"][offsets[0]:offsets[1]])

how How


In [85]:
second_token_id = tokenized_example["input_ids"][0][2]
offsets = tokenized_example["offset_mapping"][0][2]
print(tokenizer.convert_ids_to_tokens([second_token_id])[0], example["question"][offsets[0]:offsets[1]])

many many


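#### Difference between `convert_ids_to_tokens` and `decode`:
- `convert_ids_to_tokens`: maps each token ID to its own token string and returns a list (subword pieces keep their `##` prefix).
- `decode`: turns a whole sequence of IDs into a single string, merging subword pieces; it handles one sequence at a time (use `batch_decode` for several).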
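The difference can be illustrated with a toy WordPiece-style vocabulary. Everything here is hypothetical: the vocabulary and IDs are made up, and the two functions are simplified stand-ins for the tokenizer methods of the same names (the real `decode`, for example, also cleans up spacing around punctuation).

```python
from typing import Dict, List

# Hypothetical toy vocabulary; "##" marks a WordPiece piece that continues the previous token.
VOCAB: Dict[int, str] = {0: "[CLS]", 1: "[SEP]", 2: "my", 3: "name", 4: "is", 5: "an", 6: "##min"}
SPECIAL_TOKENS = {"[CLS]", "[SEP]"}


def convert_ids_to_tokens(ids: List[int]) -> List[str]:
    """One token string per ID; subword pieces keep their "##" prefix."""
    return [VOCAB[i] for i in ids]


def decode(ids: List[int], skip_special_tokens: bool = False) -> str:
    """Join one sequence of IDs into a single string, gluing "##" pieces onto the previous token."""
    text = ""
    for tok in convert_ids_to_tokens(ids):
        if skip_special_tokens and tok in SPECIAL_TOKENS:
            continue
        if tok.startswith("##"):
            text += tok[2:]
        else:
            text += (" " if text else "") + tok
    return text


ids = [0, 2, 3, 4, 5, 6, 1]
print(convert_ids_to_tokens(ids))  # ['[CLS]', 'my', 'name', 'is', 'an', '##min', '[SEP]']
print(decode(ids))                 # [CLS] my name is anmin [SEP]
# decode handles one sequence; for several, apply it per sequence (as batch_decode does)
print([decode(seq) for seq in [[2, 3], [4, 5, 6]]])  # ['my name', 'is anmin']
```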


In [58]:
# The question
example["question"]

"How many wins does the Notre Dame men's basketball team have?"

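With the `sequence_ids` method, we can easily tell which sequence each token comes from:

- for special tokens, it returns `None`;
- for regular tokens, it returns the index of the sentence the token belongs to (starting from 0).

With this, we can easily find the start and end tokens of the answer within a feature.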

In [76]:
sequence_ids = tokenized_example.sequence_ids()
print(sequence_ids)

[None, 0, 0, 0, 0, 0, 0, 0, 0, None, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 

In [56]:
len(sequence_ids)

384

In [63]:
# The reference answer
answers = example['answers']
print(answers)


{'text': ['over 1,600'], 'answer_start': [30]}


In [64]:
# Start position of the answer in the context (in characters)
start_char = answers['answer_start'][0]
# End position of the answer in the context (in characters, exclusive)
end_char = start_char + len(answers['text'][0])

Compute where the context starts and ends within `tokenized_example`:

In [60]:
# Index of the first context token in the current feature.
# sequence_ids gives, for each token, the index of the sequence it comes from (None for special tokens),
# so we skip ahead until we reach the context (sequence 1).
token_start_index = 0
while sequence_ids[token_start_index] != 1:
    token_start_index += 1
# Index of the last context token in the first feature
token_end_index = len(tokenized_example["input_ids"][0]) - 1
while sequence_ids[token_end_index] != 1:
    token_end_index -= 1
# Token span of the context within the full input (question + context)
token_start_index, token_end_index

(16, 382)

In [None]:
# Check whether the answer lies inside this feature's span (if not, the feature would be labeled with the CLS index).
offsets = tokenized_example["offset_mapping"][0]
if (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
    # Move token_start_index and token_end_index to the two ends of the answer.
    # Note: if the answer is the last word, we may move past the last token (edge case).
    while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
        token_start_index += 1
    start_position = token_start_index - 1
    while offsets[token_end_index][1] >= end_char:
        token_end_index -= 1
    end_position = token_end_index + 1
    print(start_position, end_position)
else:
    print("The answer is not in this feature.")

In [None]:
# Decode the answer from the context using the token positions found via offset_mapping
print(tokenizer.decode(tokenized_example["input_ids"][0][start_position: end_position+1]))
# Print the reference answer from the dataset (answers["text"])
print(answers["text"][0])
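The steps of the last few cells can be collected into one helper. This is a sketch that assumes `offsets` and `sequence_ids` describe a single feature, exactly as in the cells above; `char_span_to_token_span` is a hypothetical name, it returns `None` where the tutorial code would fall back to the CLS index, and both of its scans are bounded to the context tokens. The feature in the usage lines is made up for illustration.

```python
from typing import Optional, Sequence, Tuple


def char_span_to_token_span(
    offsets: Sequence[Tuple[int, int]],
    sequence_ids: Sequence[Optional[int]],
    start_char: int,
    end_char: int,
    context_id: int = 1,
) -> Optional[Tuple[int, int]]:
    """Map a character span of the context to (start_token, end_token), or None if the feature lacks it."""
    # First and last positions of context tokens in this feature.
    context_positions = [i for i, s in enumerate(sequence_ids) if s == context_id]
    if not context_positions:
        return None
    first, last = context_positions[0], context_positions[-1]
    # The answer must lie entirely inside the part of the context covered by this feature.
    if not (offsets[first][0] <= start_char and offsets[last][1] >= end_char):
        return None
    start = first
    while start <= last and offsets[start][0] <= start_char:
        start += 1
    end = last
    while end >= first and offsets[end][1] >= end_char:
        end -= 1
    return start - 1, end + 1


# A hypothetical feature: [CLS] who sat [SEP] the cat sat [SEP]  (context = "the cat sat")
offsets = [(0, 0), (0, 3), (4, 7), (0, 0), (0, 3), (4, 7), (8, 11), (0, 0)]
sequence_ids = [None, 0, 0, None, 1, 1, 1, None]
print(char_span_to_token_span(offsets, sequence_ids, 4, 7))    # (5, 5), i.e. "cat"
print(char_span_to_token_span(offsets, sequence_ids, 0, 7))    # (4, 5), i.e. "the cat"
print(char_span_to_token_span(offsets, sequence_ids, 20, 23))  # None
```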

#### Padding strategy

- Texts shorter than the maximum length are padded up to it.
- For models that pad on the left, the order of question and context is swapped.
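Padding on either side can be sketched with plain lists. `pad` is a hypothetical helper; it assumes a pad ID of 0, matching the `[PAD]` ID in the tokenizer output above, and builds the attention mask alongside so that padded positions are marked with 0.

```python
from typing import List, Tuple


def pad(ids: List[int], max_length: int, pad_id: int = 0, side: str = "right") -> Tuple[List[int], List[int]]:
    """Pad `ids` to `max_length` on the given side and return (input_ids, attention_mask)."""
    if side not in ("right", "left"):
        raise ValueError("side must be 'right' or 'left'")
    n_pad = max_length - len(ids)
    if n_pad < 0:
        raise ValueError("sequence is longer than max_length; truncate it first")
    padding, mask = [pad_id] * n_pad, [1] * len(ids)
    if side == "right":
        return ids + padding, mask + [0] * n_pad
    return padding + ids, [0] * n_pad + mask


print(pad([101, 7, 8, 102], 6))               # ([101, 7, 8, 102, 0, 0], [1, 1, 1, 1, 0, 0])
print(pad([101, 7, 8, 102], 6, side="left"))  # ([0, 0, 101, 7, 8, 102], [0, 0, 1, 1, 1, 1])
```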

In [67]:
pad_on_right = tokenizer.padding_side == "right"

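### Putting all the preprocessing steps together

Let's put everything together in one function and apply it to the training set.

For features that cannot be answered (the context was split and the answer lies in another feature), we set both the start and end positions to the index of the CLS token.

If we did not want to allow impossible answers (i.e. an `allow_impossible_answers` flag set to False), we could also simply drop these examples from the training set; the function below keeps them.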
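Dropping those features could look like the sketch below. The `allow_impossible_answers` flag and the `drop_impossible` helper are hypothetical (the tutorial defines neither); the helper works on a batch in the same dict-of-lists layout that `prepare_train_features` returns, and assumes 101 is the `[CLS]` ID, as in the tokenizer output above. On a `datasets.Dataset`, the same test could instead be passed to `Dataset.filter`.

```python
from typing import Dict, List


def drop_impossible(features: Dict[str, List], cls_token_id: int = 101) -> Dict[str, List]:
    """Keep only the features whose start/end labels point at an answer rather than at [CLS]."""
    keep = []
    for i, input_ids in enumerate(features["input_ids"]):
        cls_index = input_ids.index(cls_token_id)
        labeled_impossible = (
            features["start_positions"][i] == cls_index and features["end_positions"][i] == cls_index
        )
        if not labeled_impossible:
            keep.append(i)
    # Select the same rows from every column so the columns stay aligned.
    return {key: [values[i] for i in keep] for key, values in features.items()}


allow_impossible_answers = False  # hypothetical flag
batch = {
    "input_ids": [[101, 5, 6, 102, 7, 8, 102], [101, 5, 6, 102, 9, 10, 102]],
    "start_positions": [4, 0],  # the second feature is labeled with the CLS index
    "end_positions": [5, 0],
}
if not allow_impossible_answers:
    batch = drop_impossible(batch)
print(batch["start_positions"], batch["end_positions"])  # [4] [5]
```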

In [None]:
def prepare_train_features(examples):
    '''examples: a batch of dataset examples (a dict of lists), as passed in by datasets.map(batched=True)'''
    # Some questions have lots of leading whitespace, which is useless and can make truncating the context fail
    # (the tokenized question would take up a lot of space), so we strip it.
    examples["question"] = [q.lstrip() for q in examples["question"]]
    # Tokenize the pairs; the order of question and context, and which one is truncated, depends on the padding side
    tokenized_examples = tokenizer(
        examples["question" if pad_on_right else "context"],
        examples["context" if pad_on_right else "question"],
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    # One example can yield several features when its context is long, so we need a map from each feature
    # to the example it came from. This key provides that mapping.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    # The offset mapping maps each token to its character span in the original context.
    # It lets us compute the start and end positions.
    offset_mapping = tokenized_examples.pop("offset_mapping")
    # Now label the features!
    tokenized_examples["start_positions"] = []
    tokenized_examples["end_positions"] = []
    for i, offsets in enumerate(offset_mapping):
        # Impossible answers are labeled with the index of the CLS token.
        input_ids = tokenized_examples["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)

        # Get the sequence IDs of this feature (to know which part is the question and which is the context).
        sequence_ids = tokenized_examples.sequence_ids(i)

        # One example can yield several spans; this is the index of the example containing this span.
        sample_index = sample_mapping[i]
        answers = examples["answers"][sample_index]
        # If no answer is given, use the CLS index as the answer.
        if len(answers["answer_start"]) == 0:
            tokenized_examples["start_positions"].append(cls_index)
            tokenized_examples["end_positions"].append(cls_index)
        else:
            # Start and end character indices of the answer in the text.
            start_char = answers["answer_start"][0]
            end_char = start_char + len(answers["text"][0])

            # Index of the first context token of the current span.
            token_start_index = 0
            while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
                token_start_index += 1

            # Index of the last context token of the current span.
            token_end_index = len(input_ids) - 1
            while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
                token_end_index -= 1

            # Check whether the answer lies outside the span (in that case this feature is labeled with the CLS index).
            if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                # Otherwise, move token_start_index and token_end_index to the two ends of the answer.
                # Note: if the answer is the last word (edge case), we may move past the last offset.
                while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                    token_start_index += 1
                tokenized_examples["start_positions"].append(token_start_index - 1)
                while offsets[token_end_index][1] >= end_char:
                    token_end_index -= 1
                tokenized_examples["end_positions"].append(token_end_index + 1)

    return tokenized_examples


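#### Advanced use of datasets.map

Use `datasets.map` to apply `prepare_train_features` to all the training and validation data:

- `batched`: process the data in batches.
- `remove_columns`: since preprocessing changes the number of samples, the old columns must be removed when applying it.
- `load_from_cache_file`: whether to use the automatic cache of the datasets library.

The datasets library implements an efficient caching mechanism for large datasets: it automatically detects whether the function passed to `map` has changed (in which case the cached data must not be used). Setting `load_from_cache_file=False` when calling `map` forces the preprocessing to be reapplied.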
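Why `remove_columns` is needed can be seen with a plain-Python stand-in for a batched `map` function. `split_into_chunks` is hypothetical and only imitates the shape of what `prepare_train_features` does: it receives a batch as a dict of columns and returns more rows than it received, together with a mapping back to the source row (playing the role of `overflow_to_sample_mapping`).

```python
from typing import Dict, List


def split_into_chunks(batch: Dict[str, List], chunk_size: int = 3) -> Dict[str, List]:
    """A batched function that may return more rows than it receives."""
    chunks, sample_mapping = [], []
    for sample_idx, tokens in enumerate(batch["tokens"]):
        for start in range(0, len(tokens), chunk_size):
            chunks.append(tokens[start:start + chunk_size])
            sample_mapping.append(sample_idx)  # like overflow_to_sample_mapping
    return {"chunk": chunks, "sample_idx": sample_mapping}


batch = {"id": ["a", "b"], "tokens": [[1, 2, 3, 4, 5], [6, 7]]}
out = split_into_chunks(batch)
print(out)  # {'chunk': [[1, 2, 3], [4, 5], [6, 7]], 'sample_idx': [0, 0, 1]}
# 2 input rows became 3 output rows, so the original 2-row columns ("id", "tokens") no longer
# line up with the new ones; this is why map() is called with remove_columns.
```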

In [None]:
tokenized_datasets = datasets.map(prepare_train_features,
                                  batched=True,
                                  remove_columns=datasets["train"].column_names)

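## Fine-tuning the model

Now that our data is ready for training, we can download the pretrained model and fine-tune it.

Since our task is question answering, we use the `AutoModelForQuestionAnswering` class (by contrast, rating Yelp reviews used the `AutoModelForSequenceClassification` class).

The warning tells us that some weights are newly initialized: the `qa_outputs` layer, as shown in the log below. Weights used only for pretraining, such as those of the masked language modeling head (e.g. `vocab_transform` and `vocab_layer_norm`), are not used. This is perfectly normal when fine-tuning: we remove the head used for the masked language modeling pretraining task and replace it with a new head for which we have no pretrained weights, so the library warns us that we should fine-tune the model before using it for inference, which is exactly what we are about to do.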
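What the newly initialized `qa_outputs` layer does can be sketched with NumPy: it is a linear layer that maps each token's hidden state to two scores, a start logit and an end logit. The sizes and random weights below are toy assumptions (DistilBERT's real hidden size is 768), and `best_span` is a hypothetical, simplified version of how those logits are commonly turned into an answer span by choosing the highest-scoring start/end pair.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, hidden_size = 8, 16                              # toy sizes
hidden_states = rng.normal(size=(seq_len, hidden_size))   # stand-in for the encoder output of one feature
weight = rng.normal(scale=0.02, size=(hidden_size, 2))    # plays the role of qa_outputs.weight
bias = np.zeros(2)                                        # plays the role of qa_outputs.bias

logits = hidden_states @ weight + bias                    # shape (seq_len, 2)
start_logits, end_logits = logits[:, 0], logits[:, 1]


def best_span(start_logits: np.ndarray, end_logits: np.ndarray, max_answer_len: int = 30) -> tuple:
    """Return the (start, end) token pair with start <= end that maximizes start + end logits."""
    best, best_score = (0, 0), -np.inf
    for s in range(len(start_logits)):
        for e in range(s, min(s + max_answer_len, len(end_logits))):
            score = start_logits[s] + end_logits[e]
            if score > best_score:
                best, best_score = (s, e), score
    return best


print(logits.shape)                         # (8, 2)
print(best_span(start_logits, end_logits))  # an arbitrary pair, since the weights are untrained
print(best_span(np.array([0.0, 5.0, 0.0, 0.0]), np.array([0.0, 0.0, 3.0, 10.0])))  # (1, 3)
```

Fine-tuning trains these weights (and the encoder) so that the start and end logits peak at the labeled `start_positions` and `end_positions`.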

In [84]:
from transformers import AutoModelForQuestionAnswering, Trainer, TrainingArguments

In [86]:
model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)


Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


#### Setting the training hyperparameters (TrainingArguments)

In [None]:
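# Overrides the batch_size set at the top of the notebook; lower it if you run out of GPU memory
batch_size=64
model_dir = f"models/{model_checkpoint}-finetuned-squad"

args = TrainingArguments(
    output_dir=model_dir,
    # Evaluate at the end of every epoch; newer transformers versions name this argument eval_strategy
    evaluation_strategy = "epoch",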
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    weight_decay=0.01,
)

#### Data collator

A data collator assembles the training data into batches for the model during training. This tutorial uses the default `default_data_collator`.
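Because every feature was already padded to `max_length` in `prepare_train_features`, collating is only a matter of stacking. The sketch below is a simplified, hypothetical stand-in for `default_data_collator` that returns NumPy arrays; the real collator returns framework tensors (PyTorch by default) and, like this sketch, does no padding of its own.

```python
from typing import Any, Dict, List

import numpy as np


def simple_collate(features: List[Dict[str, Any]]) -> Dict[str, np.ndarray]:
    """Stack each field across features into one array of shape (batch_size, ...)."""
    # This only works because all input_ids / attention_mask lists already have the same length.
    return {key: np.array([f[key] for f in features]) for key in features[0]}


features = [
    {"input_ids": [101, 7, 102, 0], "attention_mask": [1, 1, 1, 0], "start_positions": 1, "end_positions": 1},
    {"input_ids": [101, 8, 9, 102], "attention_mask": [1, 1, 1, 1], "start_positions": 0, "end_positions": 0},
]
batch = simple_collate(features)
print({key: value.shape for key, value in batch.items()})
# {'input_ids': (2, 4), 'attention_mask': (2, 4), 'start_positions': (2,), 'end_positions': (2,)}
```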

In [None]:
from transformers import default_data_collator

data_collator = default_data_collator

### Instantiating the Trainer

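To reduce training time (training requires substantial compute), we do not compute evaluation metrics while training the model in this tutorial.

Instead, the model is evaluated separately after training.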

In [None]:
trainer = Trainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,  # in newer transformers versions, pass processing_class=tokenizer instead
)

In [None]:
trainer.train()