# ALBERT
[ALBERT](https://arxiv.org/pdf/1909.11942.pdf), A Lite BERT    
Reference:   
https://amitness.com/2020/02/albert-visual-summary/

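## Drawbacks of BERT
Memory limits and communication overhead:
- BERT is a very large model with a huge number of trainable parameters. Training it from scratch takes substantial compute, and device memory caps how large the model can be.
- One workaround is distributed training: the data is split across machines, each machine trains on its shard, and the gradients are then synchronized. Synchronizing that many parameters adds considerable communication overhead and slows training down. Model parallelism runs into the same bottleneck.

Model degradation:
- A larger model does not necessarily perform better; performance can actually drop.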

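## ALBERT's improvements
 ALBERT is essentially a compressed version of BERT.
 > Model compression covers many techniques, including pruning, parameter sharing, low-rank factorization, architecture design, and knowledge distillation.
 
 ALBERT reduces the number of parameters, which shortens training time (less communication overhead). However, **inference time does not go down**, because the same number of layers still has to be computed.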

### Cross-layer parameter sharing
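- BERT-LARGE has 24 layers and BERT-BASE has 12; without sharing, the encoder's parameter count grows in proportion to the number of layers
- ALBERT introduces cross-layer parameter sharing: instead of learning separate parameters for every layer, all layers use the same parameters
    - Sharing can be limited to the feed-forward parameters or to the attention parameters, or it can cover the whole encoder layer
    - With whole-layer sharing the parameter count drops markedly, from 110M to 31M
    
- Parameter sharing costs little in performance and helps stabilize the network's parameters. The stabilizing effect was observed by measuring the L2 distance and cosine similarity between each layer's input and output embeddings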
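
To make the effect of whole-layer sharing concrete, the sketch below counts encoder parameters for a standard Transformer layer with hidden size 768, with and without sharing. The helper name `encoder_layer_params` and the 4x feed-forward width are assumptions for illustration (the usual BERT-Base layout). Embeddings and the pooler are left out, which is why the totals sit below the 110M and 31M figures quoted in the list.

```python
def encoder_layer_params(hidden, ffn_mult=4):
    """Parameter count of one Transformer encoder layer (weights + biases)."""
    attention = 4 * (hidden * hidden + hidden)        # Q, K, V and output projections
    ffn_dim = ffn_mult * hidden
    feed_forward = (hidden * ffn_dim + ffn_dim) + (ffn_dim * hidden + hidden)
    layer_norms = 2 * (2 * hidden)                    # two LayerNorms, each scale + shift
    return attention + feed_forward + layer_norms


hidden, num_layers = 768, 12
per_layer = encoder_layer_params(hidden)

unshared = num_layers * per_layer   # BERT-style: every layer owns its parameters
shared = per_layer                  # ALBERT-style: one set reused by all layers

print(f"per layer: {per_layer:,}")
print(f"12 layers, unshared: {unshared:,}")
print(f"12 layers, shared:   {shared:,}")
```

Sharing shrinks storage, not computation: the input still passes through 12 layer applications, which is consistent with the unchanged inference time noted earlier.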

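### SOP replaces NSP
- SOP stands for Sentence Order Prediction; it takes over the role NSP plays in BERT.
    - RoBERTa and XLNet showed that NSP not only fails to help but can hurt the model; removing the NSP task improves performance on a range of tasks
    - Compared with MLM, NSP is not a hard task. NSP mixes topic prediction with coherence prediction, and topic prediction is easy to learn because it overlaps with the MLM task, so NSP can reach high accuracy without the model ever learning coherence
    
- SOP is set up like NSP in that it also asks whether the second segment follows the first. For negatives, however, SOP does not sample an unrelated sentence; it swaps the order of two consecutive segments.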
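
The construction of SOP training pairs can be sketched in a few lines. This is a minimal illustration, not the ALBERT authors' data pipeline: the function name `make_sop_examples`, the fixed seed, and the 50/50 split between positives and negatives are all assumptions.

```python
import random


def make_sop_examples(segments, seed=0):
    """Build (first, second, label) triples from consecutive text segments.

    label 1: the two segments appear in their original order
    label 0: the same two segments with their order swapped
    """
    rng = random.Random(seed)
    examples = []
    for first, second in zip(segments, segments[1:]):
        if rng.random() < 0.5:
            examples.append((first, second, 1))   # positive: original order
        else:
            examples.append((second, first, 0))   # negative: swapped order
    return examples


document = ["segment A", "segment B", "segment C", "segment D"]
for example in make_sop_examples(document):
    print(example)
```

Because both segments of a negative pair come from the same passage, topic cues cannot separate positives from negatives; the model has to rely on coherence.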

### Factorized embedding parameterization

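Parameters are reduced by lowering the dimension of the embedding.
- In BERT-Base the embedding dimension equals the hidden dimension: both are 768.
- Distributed word representations such as Word2Vec typically use much smaller dimensions, such as 50 or 300.
- ALBERT factorizes the embedding matrix to cut the parameter count:

$$O(V\times H) \rightarrow O(V\times E+E\times H)$$

Here V is the vocabulary size, H the hidden dimension, and E the word-embedding dimension.

Taking BERT-Base as an example:
- With a vocabulary of 30,000, the embedding holds 768 * 30,000 = 23,040,000 parameters.
- With the embedding dimension reduced to 128, the embedding layer holds 128 * 30,000 + 128 * 768 = 3,938,304 parameters.
- The embedding parameters shrink from about 23M to about 4M.
- Seen against the whole model, though, BERT-Base has about 110M parameters, so saving 19M is no revolutionary change.
- In other words, factorizing the embedding layer is not the main source of the parameter reduction.
- Note: the position-embedding parameters are ignored here, since 512 is negligible next to 30,000.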
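
The arithmetic in the list can be reproduced with a small helper. The function name `embedding_params` is hypothetical; the values of V, H, and E are the ones used in the example.

```python
def embedding_params(vocab_size, hidden, embed=None):
    """Embedding parameter count: V*H when unfactorized, V*E + E*H when factorized."""
    if embed is None:
        return vocab_size * hidden
    return vocab_size * embed + embed * hidden


V, H, E = 30_000, 768, 128
full = embedding_params(V, H)
factorized = embedding_params(V, H, E)

print("unfactorized:", full)
print("factorized:  ", factorized)
print("saved:       ", full - factorized)
```

The saving grows with H: the larger the hidden size, the more the V*H term dominates, so factorization matters more for the wider ALBERT configurations than for a Base-sized model.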

### ALBERT's performance gains
- 18x fewer parameters than BERT-large
- Trains 1.7x faster than BERT
- Achieved state-of-the-art results on GLUE, RACE, and SQuAD at the time of publication

# ALBERT for Chinese named entity recognition

## Tokenizer

In [1]:
from transformers import BertTokenizer, AlbertForTokenClassification
model_path = "../../H/models/huggingface/torch/albert_chinese_xlarge/"
tokenizer = BertTokenizer.from_pretrained(model_path)

## Dataset
Corpus source: https://github.com/InsaneLife/ChineseNLPCorpus/tree/master/NER/renMinRiBao

In [2]:
import codecs

input_data = codecs.open('../datasets/ner/renmin/renmin4.txt', 'r', 'utf-8')

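# 1. Split each annotated clause into a list of characters and the matching list of tags
#####################################################
datas = []
labels = []

# Open question: should a dedicated 'PAD' tag mark padded positions?
# tags = set(['PAD']) 
tags = set()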
tags = set()

for line in input_data.readlines():
    linedata = list()
    linelabel = list()

    line = line.split()

    numNotO = 0
    for word in line:
        word = word.split('/')
        linedata.append(word[0])
        linelabel.append(word[1])

        tags.add(word[1])

        if word[1] != 'O':  # count the tokens whose tag is not O
            numNotO += 1

    if numNotO != 0:  # keep only clauses that contain at least one non-O tag
        datas.append(linedata)
        labels.append(linelabel)

input_data.close()
print("Number of text sequences:", len(datas))  # list of character lists
assert (len(labels) == len(datas))  # matching list of tag lists

# 2. Build the tag dictionaries
#####################################################

print("All tags:", tags)
tag2id = {tag: i for i, tag in enumerate(tags)}

id2tag = {i: tag for tag, i in tag2id.items()}
print(id2tag)

print("-" * 80)


Number of text sequences: 37924
All tags: {'B_nt', 'M_nr', 'B_nr', 'E_ns', 'E_nr', 'B_ns', 'M_ns', 'O', 'E_nt', 'M_nt'}
{0: 'B_nt', 1: 'M_nr', 2: 'B_nr', 3: 'E_ns', 4: 'E_nr', 5: 'B_ns', 6: 'M_ns', 7: 'O', 8: 'E_nt', 9: 'M_nt'}
--------------------------------------------------------------------------------
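
The parsing step in the cell above assumes each line is a whitespace-separated run of `char/tag` items. A minimal sketch of that step on a single line follows. The sample line is hypothetical (written in the corpus format, not copied from the file), and the helper name `parse_tagged_line` is an assumption. It uses `rsplit('/', 1)` where the cell uses `split('/')`, so that a token which is itself a slash would not be misparsed.

```python
def parse_tagged_line(line):
    """Split a whitespace-separated line of `char/tag` items into two parallel lists."""
    chars, tags = [], []
    for item in line.split():
        char, tag = item.rsplit('/', 1)   # split on the last slash only
        chars.append(char)
        tags.append(tag)
    return chars, tags


# Hypothetical sample line in the `char/tag` format
sample = "在/O 中/B_ns 国/E_ns 的/O"
sample_chars, sample_tags = parse_tagged_line(sample)
print(sample_chars)
print(sample_tags)
```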


In [3]:
tag2id

{'B_nt': 0,
 'M_nr': 1,
 'B_nr': 2,
 'E_ns': 3,
 'E_nr': 4,
 'B_ns': 5,
 'M_ns': 6,
 'O': 7,
 'E_nt': 8,
 'M_nt': 9}

In [4]:
# Fix the tag order explicitly so the ids are easy to read and reproducible

tags = [
    'B_ns', 'M_ns', 'E_ns', 'B_nr', 'M_nr', 'E_nr', 'B_nt', 'M_nt', 'E_nt',
    'O',
]
tag2id = {tag: idx for idx, tag in enumerate(tags)}
id2tag = {idx: tag for idx, tag in enumerate(tags)}
tag2id, id2tag

({'B_ns': 0,
  'M_ns': 1,
  'E_ns': 2,
  'B_nr': 3,
  'M_nr': 4,
  'E_nr': 5,
  'B_nt': 6,
  'M_nt': 7,
  'E_nt': 8,
  'O': 9},
 {0: 'B_ns',
  1: 'M_ns',
  2: 'E_ns',
  3: 'B_nr',
  4: 'M_nr',
  5: 'E_nr',
  6: 'B_nt',
  7: 'M_nt',
  8: 'E_nt',
  9: 'O'})

In [5]:
tags_count = {}
for seq in labels:
    for tag in seq:
        tags_count[tag] = tags_count.get(tag, 0) + 1
print("Tag counts:")
for tag in tags:
    if tag in tags_count:
        print(tag, ":", tags_count[tag])
print("Total number of tags:", sum([len(seq) for seq in labels]))

Tag counts:
B_ns : 22427
M_ns : 13579
E_ns : 22427
B_nr : 19981
M_nr : 18235
E_nr : 19981
B_nt : 10834
M_nt : 40955
E_nt : 10834
O : 389034
Total number of tags: 568287


In [6]:
# Pad or truncate every sequence to the same length

import numpy as np


def pad_sequences(sequences,
                  maxlen=None,
                  dtype='int64',
                  padding='post',
                  truncating='post',
                  value=0.):

    num_samples = len(sequences)
    lengths = [len(sample) for sample in sequences]

    if maxlen is None:
        maxlen = np.max(lengths)

    x = np.full((num_samples, maxlen), value, dtype=dtype)
    for idx, s in enumerate(sequences):
        if not len(s):
            continue  # empty list/array was found
        if truncating == 'pre':
            trunc = s[-maxlen:]
        elif truncating == 'post':
            trunc = s[:maxlen]
        else:
            raise ValueError('Truncating type "%s" '
                             'not understood' % truncating)

        trunc = np.asarray(trunc, dtype=dtype)

        if padding == 'post':
            x[idx, :len(trunc)] = trunc
        elif padding == 'pre':
            x[idx, -len(trunc):] = trunc
        else:
            raise ValueError('Padding type "%s" not understood' % padding)
    return x

In [7]:
# Input ids
input_ids = pad_sequences(
    [tokenizer.convert_tokens_to_ids(seq) for seq in datas],
    maxlen = 60,
)
input_ids.shape

(37924, 60)

In [8]:
input_ids[0]

array([ 704, 1066,  704, 1925, 2600,  741, 6381,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0])

In [9]:
# Sanity check: decode the ids back into tokens
print(datas[999])
print(tokenizer.convert_ids_to_tokens(input_ids[999]))

['在', '一', '个', '中', '国', '的', '原', '则', '下']
['在', '一', '个', '中', '国', '的', '原', '则', '下', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']


In [10]:
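# Target tag sequences for training
# NOTE: the pad value 0 is also the id of 'B_ns'; padded positions are told apart only by the attention mask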
tags = pad_sequences(
    [[tag2id[l] for l in seq] for seq in labels],
    maxlen=60,
    value=0.,
)
tags.shape

(37924, 60)

In [11]:
tags[0]

array([6, 7, 7, 8, 9, 9, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [12]:
print(labels[999])
print([id2tag[idx] for idx in tags[999]])

['O', 'O', 'O', 'B_ns', 'E_ns', 'O', 'O', 'O', 'O']
['O', 'O', 'O', 'B_ns', 'E_ns', 'O', 'O', 'O', 'O', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns', 'B_ns']
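
The padded positions above decode to 'B_ns' because the pad value 0 is also the id of 'B_ns'. The evaluation loop further down filters them out with the attention mask, so the reported metrics are computed on real tokens only. One common alternative, sketched here as an assumption and not what this notebook does, is to pad labels with -100, which is the default `ignore_index` of `torch.nn.CrossEntropyLoss`. The helper name `pad_labels` is hypothetical.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def pad_labels(label_ids, maxlen, pad_value=IGNORE_INDEX):
    """Truncate or right-pad each label sequence to `maxlen` with `pad_value`."""
    padded = []
    for seq in label_ids:
        kept = list(seq[:maxlen])
        padded.append(kept + [pad_value] * (maxlen - len(kept)))
    return padded


demo_tag2id = {'B_ns': 0, 'M_ns': 1, 'E_ns': 2, 'O': 9}   # subset of the mapping above
demo_labels = [['O', 'B_ns', 'E_ns'], ['O']]
demo_padded = pad_labels([[demo_tag2id[t] for t in seq] for seq in demo_labels], maxlen=4)
print(demo_padded)
```

With this scheme the padded ids no longer collide with a real tag, but they also have no entry in `id2tag`, so any decoding step has to skip them first.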


In [13]:
# Attention mask: 1 for real tokens, 0 for padding

masks = (input_ids != 0).astype(float)  # np.float was removed in NumPy 1.24; builtin float gives float64
masks
masks[999]

array([1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [14]:
# Split the data into training and validation sets

from sklearn.model_selection import train_test_split
tr_inputs, val_inputs, tr_tags, val_tags, tr_masks, val_masks = train_test_split(
    input_ids,
    tags,
    masks,
    random_state=2018,
    test_size=0.25,
)

In [15]:
tr_inputs.shape, val_inputs.shape

((28443, 60), (9481, 60))

In [16]:
# Convert to torch tensors

import torch

# input_ids, dtype: torch.LongTensor, shape: (batch_size, sequence_length)
tr_inputs = torch.tensor(tr_inputs)
val_inputs = torch.tensor(val_inputs)

# labels, dtype: torch.LongTensor, shape: (batch_size, sequence_length)
tr_tags = torch.tensor(tr_tags)
val_tags = torch.tensor(val_tags)

# attention_mask, dtype: torch.FloatTensor, shape: (batch_size, sequence_length)
tr_masks = torch.tensor(tr_masks, dtype=torch.float)
val_masks = torch.tensor(val_masks, dtype=torch.float)

In [17]:
tr_inputs.shape, val_inputs.shape

(torch.Size([28443, 60]), torch.Size([9481, 60]))

In [18]:
# Build the batched datasets
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

bs = 32
train_data = TensorDataset(tr_inputs, tr_masks, tr_tags)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=bs)

valid_data = TensorDataset(val_inputs, val_masks, val_tags)
valid_sampler = SequentialSampler(valid_data)
valid_dataloader = DataLoader(valid_data, sampler=valid_sampler, batch_size=bs)

In [19]:
print(len(train_dataloader), bs * len(train_dataloader))

889 28448


In [20]:
print(len(valid_dataloader), bs * len(valid_dataloader))

297 9504
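
The two products printed above (28448 and 9504) are larger than the actual example counts (28443 and 9481) because `DataLoader` keeps the final, partially filled batch by default (`drop_last=False`). A short sketch of the arithmetic, with the hypothetical helper name `batch_stats`:

```python
import math


def batch_stats(num_examples, batch_size):
    """Return (number of batches, size of the last batch) when the last batch is kept."""
    num_batches = math.ceil(num_examples / batch_size)
    last_batch = num_examples - (num_batches - 1) * batch_size
    return num_batches, last_batch


for name, n in [("train", 28443), ("valid", 9481)]:
    print(name, batch_stats(n, 32))
```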


## Building the model
Model source: https://huggingface.co/voidful/albert_chinese_xlarge#

In [21]:
from transformers import AlbertForTokenClassification

In [22]:
tokenizer = BertTokenizer.from_pretrained(model_path)
model = AlbertForTokenClassification.from_pretrained(
    model_path,
    num_labels=len(id2tag),
    output_attentions=False,
    output_hidden_states=False,
)

In [23]:
model.cuda()

AlbertForTokenClassification(
  (albert): AlbertModel(
    (embeddings): AlbertEmbeddings(
      (word_embeddings): Embedding(21128, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0, inplace=False)
    )
    (encoder): AlbertTransformer(
      (embedding_hidden_mapping_in): Linear(in_features=128, out_features=2048, bias=True)
      (albert_layer_groups): ModuleList(
        (0): AlbertLayerGroup(
          (albert_layers): ModuleList(
            (0): AlbertLayer(
              (full_layer_layer_norm): LayerNorm((2048,), eps=1e-12, elementwise_affine=True)
              (attention): AlbertAttention(
                (query): Linear(in_features=2048, out_features=2048, bias=True)
                (key): Linear(in_features=2048, out_features=2048, bias=True)
                (value): Linear(in_features=2048, out_features=

## Training the model

In [None]:
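import torch.nn as nn
from transformers import AlbertModel

class ZhAlbert(nn.Module):
    def __init__(self, model_path, ff_dim, num_tags, dropout):
        super().__init__()
        self.num_tags = num_tags
        self.albert = AlbertModel.from_pretrained(model_path)
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(ff_dim, num_tags)

    def forward(self, x, mask, y):
        # The original cell stops here; what follows is an assumed minimal completion:
        # token-level logits plus cross-entropy over the unpadded positions.
        hidden = self.albert(x, attention_mask=mask)[0]
        logits = self.linear(self.dropout(hidden))
        active = mask.reshape(-1).bool()
        loss = nn.functional.cross_entropy(
            logits.reshape(-1, self.num_tags)[active], y.reshape(-1)[active])
        return loss, logits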

In [24]:
import transformers
from transformers import AdamW

# 1. Fine-tune the whole model: the ALBERT encoder and the classification layer on top

FULL_FINETUNING = True
if FULL_FINETUNING:
    param_optimizer = list(model.named_parameters())
    # PyTorch names LayerNorm parameters "weight"/"bias", not "gamma"/"beta"
    no_decay = ['bias', 'LayerNorm.weight', 'layer_norm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        # AdamW reads 'weight_decay'; with the key 'weight_decay_rate' the 0.01 decay was never applied
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

# 2. Train only the top classification layer
else:
    param_optimizer = list(model.classifier.named_parameters())
    optimizer_grouped_parameters = [{
        "params": [p for n, p in param_optimizer]
    }]

# Optimizer
optimizer = AdamW(optimizer_grouped_parameters, lr=3e-5, eps=1e-8)
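
The parameter grouping in the cell above relies on substring matching against parameter names, so the result depends entirely on which substrings are listed. The sketch below shows the matching on a handful of names written in the style of the module printout earlier. The names are an illustrative subset and the helper `split_by_decay` is hypothetical. In PyTorch, LayerNorm parameters are called `weight` and `bias`, so TensorFlow-style entries such as `gamma` and `beta` match nothing.

```python
# Parameter names in the style of the module printout (illustrative subset)
names = [
    "albert.embeddings.word_embeddings.weight",
    "albert.embeddings.LayerNorm.weight",
    "albert.embeddings.LayerNorm.bias",
    "albert.encoder.embedding_hidden_mapping_in.weight",
    "albert.encoder.embedding_hidden_mapping_in.bias",
    "classifier.weight",
    "classifier.bias",
]


def split_by_decay(param_names, no_decay):
    """Return (decayed, not_decayed) name lists using substring matching."""
    decayed = [n for n in param_names if not any(nd in n for nd in no_decay)]
    not_decayed = [n for n in param_names if any(nd in n for nd in no_decay)]
    return decayed, not_decayed


# TensorFlow-style list: 'gamma' and 'beta' never occur in these names
decayed_tf, _ = split_by_decay(names, ["bias", "gamma", "beta"])
# PyTorch-style list: the LayerNorm scale parameter is matched explicitly
decayed_pt, _ = split_by_decay(names, ["bias", "LayerNorm.weight"])

target = "albert.embeddings.LayerNorm.weight"
print("decayed with gamma/beta list:      ", target in decayed_tf)
print("decayed with LayerNorm.weight list:", target in decayed_pt)
```

Printing the two groups for the real model is a cheap way to confirm that every normalization parameter ends up in the no-decay group.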

In [28]:
from transformers import get_linear_schedule_with_warmup

epochs = 20
max_grad_norm = 1.0

# Total number of training steps
total_steps = len(train_dataloader) * epochs

# Learning-rate schedule
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps,
)
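
With `num_warmup_steps=0`, the schedule above simply decays the learning rate linearly from its base value to zero over `total_steps`. The sketch below reproduces the multiplier in plain Python so the shape is easy to inspect; the function name `linear_schedule_factor` is hypothetical, and the step count 889 * 20 is taken from the batch count and epoch setting in this notebook.

```python
def linear_schedule_factor(step, num_warmup_steps, num_training_steps):
    """Multiplier applied to the base learning rate at a given optimizer step."""
    if step < num_warmup_steps:
        # Linear ramp from 0 to 1 during warmup
        return step / max(1, num_warmup_steps)
    # Linear decay from 1 to 0 over the remaining steps
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))


base_lr = 3e-5
demo_total_steps = 889 * 20   # batches per epoch * epochs
for step in (0, demo_total_steps // 2, demo_total_steps):
    print(step, base_lr * linear_schedule_factor(step, 0, demo_total_steps))
```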

In [29]:
from sklearn.metrics import f1_score, accuracy_score, classification_report
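
These metrics are applied token by token in the training loop, with `average='macro'` for the F1 score. As a reminder of what macro averaging means, here is a minimal pure-Python sketch; the function name `macro_f1` and the toy tags are made up for illustration, and in the notebook itself the scikit-learn functions are used.

```python
def macro_f1(y_true, y_pred):
    """Unweighted mean of per-label F1 scores over all labels seen in either list."""
    labels = sorted(set(y_true) | set(y_pred))
    scores = []
    for label in labels:
        tp = sum(t == label and p == label for t, p in zip(y_true, y_pred))
        fp = sum(t != label and p == label for t, p in zip(y_true, y_pred))
        fn = sum(t == label and p != label for t, p in zip(y_true, y_pred))
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / len(scores)


# Toy token-level tags, made up for illustration
toy_true = ['O', 'O', 'B_ns', 'E_ns']
toy_pred = ['O', 'B_ns', 'B_ns', 'E_ns']
print(macro_f1(toy_true, toy_pred))
```

Every tag counts equally in the macro average, so the rare entity tags weigh as much as the dominant `O` tag. That is why the macro F1 after the first epoch (about 0.50) sits far below the accuracy (about 0.83).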

In [30]:
from tqdm import tqdm, trange
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Average training loss and validation loss recorded after each epoch
loss_values, validation_loss_values = [], []

for epoch in trange(epochs, desc="Epoch"):
    # ========================================
    #               Training
    # ========================================

    # Training mode
    model.train()

    # Running loss
    total_loss = 0

    # Training loop
    for step, batch in enumerate(train_dataloader):
        # Move the batch to the device
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch

        # Reset gradients
        model.zero_grad()

        # Forward pass; the first output is the loss
        outputs = model(b_input_ids,
                        token_type_ids=None,
                        attention_mask=b_input_mask,
                        labels=b_labels)
        loss = outputs[0]

        # Backward pass
        loss.backward()

        # Accumulate the loss
        total_loss += loss.item()

        # Clip gradients to guard against exploding gradients
        torch.nn.utils.clip_grad_norm_(parameters=model.parameters(),
                                       max_norm=max_grad_norm)
        # Update the parameters
        optimizer.step()
        # Update the learning rate
        scheduler.step()

        if step % 100 == 0:
            print("Epoch: {}, Step: {}, Train loss: {}".format(
                epoch, step, total_loss / (step + 1)))
        

    # Average loss over this training epoch
    avg_train_loss = total_loss / len(train_dataloader)
    print("Epoch: {}, Average train loss: {} ".format(epoch, avg_train_loss))

    loss_values.append(avg_train_loss)

    # ========================================
    #               Validation
    # ========================================

    # Evaluation mode
    model.eval()

    # Validation loss and accuracy accumulators
    eval_loss, eval_accuracy = 0, 0

    nb_eval_steps, nb_eval_examples = 0, 0

    predictions, true_labels = [], []
    for batch in valid_dataloader:
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch

        # No gradient tracking during validation
        with torch.no_grad():

            outputs = model(
                b_input_ids,
                token_type_ids=None,
                attention_mask=b_input_mask,
                labels=b_labels,  # labels are passed, so outputs[0] is the loss and outputs[1] the logits
            )

        # Move the data to the CPU
        logits = outputs[1].detach().cpu()
        label_ids = b_labels.to('cpu')
        b_input_mask = b_input_mask.to('cpu')
        b_input_mask = b_input_mask.to('cpu')

        # Accumulate the validation loss
        eval_loss += outputs[0].mean().item()

        # Predicted tag ids
        b_preds = torch.argmax(logits, dim=2)

        # Keep only the unpadded positions
        predictions.append(b_preds.masked_select(b_input_mask.bool()))
        true_labels.append(label_ids.masked_select(b_input_mask.bool()))

    eval_loss = eval_loss / len(valid_dataloader)
    validation_loss_values.append(eval_loss)
    print("Validation loss: {} at epoch {}".format(eval_loss, epoch))
    
    predictions = torch.cat(predictions)
    true_labels = torch.cat(true_labels)
    

    
    # Token-level accuracy, macro F1, and per-tag report
    pred_tags = [id2tag[idx] for idx in predictions.tolist()]
    valid_tags = [id2tag[idx] for idx in true_labels.tolist()]
    print("Validation Accuracy: {}at epoch {}".format(
        accuracy_score(valid_tags, pred_tags), epoch))
    print("Validation F1-Score: {}at epoch {}".format(
        f1_score(valid_tags, pred_tags, average='macro'), epoch))
    valid_report = classification_report(valid_tags, pred_tags)
    print(valid_report)
    print()

Epoch:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch: 0, Step: 0, Train loss: 0.5352237224578857
Epoch: 0, Step: 100, Train loss: 0.42720396300353625
Epoch: 0, Step: 200, Train loss: 0.43819765605736727
Epoch: 0, Step: 300, Train loss: 0.43702682377492075
Epoch: 0, Step: 400, Train loss: 0.4368162702889811
Epoch: 0, Step: 500, Train loss: 0.43662129369324554
Epoch: 0, Step: 600, Train loss: 0.43581701044036625
Epoch: 0, Step: 700, Train loss: 0.433231198753678
Epoch: 0, Step: 800, Train loss: 0.4292947600843308
Epoch: 0, Average train loss: 0.42743512190233063 
Validation loss: 0.39762468525656947 at epoch 0
Validation Accuracy: 0.8272046281555606at epoch 0
Validation F1-Score: 0.49681934244747694at epoch 0


Epoch:   5%|▌         | 1/20 [19:16<6:06:17, 1156.73s/it]

              precision    recall  f1-score   support

        B_nr       0.35      0.58      0.43      4839
        B_ns       0.33      0.25      0.28      5576
        B_nt       0.55      0.17      0.26      2695
        E_nr       0.35      0.32      0.33      4838
        E_ns       0.45      0.53      0.49      5574
        E_nt       0.64      0.81      0.72      2692
        M_nr       0.58      0.17      0.27      4382
        M_ns       0.44      0.47      0.45      3453
        M_nt       0.72      0.80      0.76     10073
           O       0.98      0.98      0.98     96582

    accuracy                           0.83    140704
   macro avg       0.54      0.51      0.50    140704
weighted avg       0.83      0.83      0.82    140704


Epoch: 1, Step: 0, Train loss: 0.32099899649620056
Epoch: 1, Step: 100, Train loss: 0.38655441940420926
Epoch: 1, Step: 200, Train loss: 0.3707519302617258
Epoch: 1, Step: 300, Train loss: 0.36644711590486506
Epoch: 1, Step: 400, Train loss

Epoch:  10%|█         | 2/20 [38:36<5:47:15, 1157.51s/it]

              precision    recall  f1-score   support

        B_nr       0.79      0.85      0.82      4839
        B_ns       0.91      0.84      0.88      5576
        B_nt       0.85      0.88      0.86      2695
        E_nr       0.86      0.69      0.76      4838
        E_ns       0.84      0.79      0.82      5574
        E_nt       0.84      0.84      0.84      2692
        M_nr       0.82      0.79      0.81      4382
        M_ns       0.70      0.79      0.74      3453
        M_nt       0.87      0.89      0.88     10073
           O       0.98      0.98      0.98     96582

    accuracy                           0.93    140704
   macro avg       0.85      0.84      0.84    140704
weighted avg       0.93      0.93      0.93    140704


Epoch: 2, Step: 0, Train loss: 0.16232149302959442
Epoch: 2, Step: 100, Train loss: 0.17172129481735796
Epoch: 2, Step: 200, Train loss: 0.1650525854335199
Epoch: 2, Step: 300, Train loss: 0.1634592230546217
Epoch: 2, Step: 400, Train loss:

Epoch:  15%|█▌        | 3/20 [57:55<5:28:05, 1157.98s/it]

              precision    recall  f1-score   support

        B_nr       0.95      0.96      0.95      4839
        B_ns       0.93      0.92      0.92      5576
        B_nt       0.90      0.90      0.90      2695
        E_nr       0.90      0.95      0.93      4838
        E_ns       0.90      0.90      0.90      5574
        E_nt       0.85      0.91      0.88      2692
        M_nr       0.94      0.95      0.94      4382
        M_ns       0.89      0.74      0.81      3453
        M_nt       0.88      0.91      0.89     10073
           O       0.98      0.98      0.98     96582

    accuracy                           0.96    140704
   macro avg       0.91      0.91      0.91    140704
weighted avg       0.96      0.96      0.96    140704


Epoch: 3, Step: 0, Train loss: 0.08002409338951111
Epoch: 3, Step: 100, Train loss: 0.11343243734745106
Epoch: 3, Step: 200, Train loss: 0.11030179945594487
Epoch: 3, Step: 300, Train loss: 0.10515119568243672
Epoch: 3, Step: 400, Train los

Epoch:  20%|██        | 4/20 [1:17:13<5:08:50, 1158.18s/it]

              precision    recall  f1-score   support

        B_nr       0.95      0.97      0.96      4839
        B_ns       0.94      0.91      0.93      5576
        B_nt       0.88      0.93      0.91      2695
        E_nr       0.94      0.96      0.95      4838
        E_ns       0.95      0.87      0.91      5574
        E_nt       0.89      0.94      0.92      2692
        M_nr       0.96      0.96      0.96      4382
        M_ns       0.85      0.85      0.85      3453
        M_nt       0.90      0.94      0.92     10073
           O       0.99      0.98      0.99     96582

    accuracy                           0.97    140704
   macro avg       0.93      0.93      0.93    140704
weighted avg       0.97      0.97      0.97    140704


Epoch: 4, Step: 0, Train loss: 0.004194826819002628
Epoch: 4, Step: 100, Train loss: 0.08078027643853485
Epoch: 4, Step: 200, Train loss: 0.08059446531259895
Epoch: 4, Step: 300, Train loss: 0.07731485844549546
Epoch: 4, Step: 400, Train lo

Epoch:  25%|██▌       | 5/20 [1:36:32<4:49:34, 1158.31s/it]

              precision    recall  f1-score   support

        B_nr       0.95      0.98      0.96      4839
        B_ns       0.96      0.92      0.94      5576
        B_nt       0.92      0.94      0.93      2695
        E_nr       0.95      0.96      0.95      4838
        E_ns       0.93      0.90      0.92      5574
        E_nt       0.90      0.91      0.91      2692
        M_nr       0.96      0.95      0.96      4382
        M_ns       0.86      0.87      0.86      3453
        M_nt       0.90      0.94      0.92     10073
           O       0.99      0.98      0.99     96582

    accuracy                           0.97    140704
   macro avg       0.93      0.94      0.93    140704
weighted avg       0.97      0.97      0.97    140704


Epoch: 5, Step: 0, Train loss: 0.03187444806098938
Epoch: 5, Step: 100, Train loss: 0.05894992280622373
Epoch: 5, Step: 200, Train loss: 0.05865482999062842
Epoch: 5, Step: 300, Train loss: 0.05919558091624408
Epoch: 5, Step: 400, Train los

Epoch:  30%|███       | 6/20 [1:55:51<4:30:19, 1158.54s/it]

              precision    recall  f1-score   support

        B_nr       0.97      0.98      0.97      4839
        B_ns       0.94      0.95      0.94      5576
        B_nt       0.92      0.93      0.93      2695
        E_nr       0.94      0.96      0.95      4838
        E_ns       0.90      0.94      0.92      5574
        E_nt       0.91      0.93      0.92      2692
        M_nr       0.97      0.97      0.97      4382
        M_ns       0.83      0.90      0.86      3453
        M_nt       0.92      0.93      0.93     10073
           O       0.99      0.98      0.99     96582

    accuracy                           0.97    140704
   macro avg       0.93      0.95      0.94    140704
weighted avg       0.97      0.97      0.97    140704


Epoch: 6, Step: 0, Train loss: 0.017148686572909355
Epoch: 6, Step: 100, Train loss: 0.04474899180274973
Epoch: 6, Step: 200, Train loss: 0.04890611589902465
Epoch: 6, Step: 300, Train loss: 0.04962078881631002
Epoch: 6, Step: 400, Train lo

Epoch:  35%|███▌      | 7/20 [2:15:09<4:10:59, 1158.41s/it]

              precision    recall  f1-score   support

        B_nr       0.98      0.96      0.97      4839
        B_ns       0.94      0.95      0.94      5576
        B_nt       0.94      0.93      0.93      2695
        E_nr       0.97      0.95      0.96      4838
        E_ns       0.93      0.92      0.93      5574
        E_nt       0.93      0.92      0.93      2692
        M_nr       0.98      0.96      0.97      4382
        M_ns       0.90      0.86      0.88      3453
        M_nt       0.93      0.93      0.93     10073
           O       0.98      0.99      0.99     96582

    accuracy                           0.97    140704
   macro avg       0.95      0.94      0.94    140704
weighted avg       0.97      0.97      0.97    140704


Epoch: 7, Step: 0, Train loss: 0.03858230635523796
Epoch: 7, Step: 100, Train loss: 0.03429088222512994
Epoch: 7, Step: 200, Train loss: 0.04127531766541654
Epoch: 7, Step: 300, Train loss: 0.039859507077039405
Epoch: 7, Step: 400, Train lo

Epoch:  40%|████      | 8/20 [2:34:28<3:51:44, 1158.67s/it]

              precision    recall  f1-score   support

        B_nr       0.98      0.97      0.97      4839
        B_ns       0.96      0.94      0.95      5576
        B_nt       0.92      0.95      0.94      2695
        E_nr       0.96      0.96      0.96      4838
        E_ns       0.95      0.91      0.93      5574
        E_nt       0.93      0.94      0.93      2692
        M_nr       0.97      0.98      0.97      4382
        M_ns       0.84      0.89      0.87      3453
        M_nt       0.91      0.96      0.93     10073
           O       0.99      0.99      0.99     96582

    accuracy                           0.97    140704
   macro avg       0.94      0.95      0.94    140704
weighted avg       0.97      0.97      0.97    140704


Epoch: 8, Step: 0, Train loss: 0.02617947943508625
Epoch: 8, Step: 100, Train loss: 0.03863345995420903
Epoch: 8, Step: 200, Train loss: 0.035441361563175976
Epoch: 8, Step: 300, Train loss: 0.035436487860515085
Epoch: 8, Step: 400, Train l

Epoch:  45%|████▌     | 9/20 [2:53:46<3:32:20, 1158.25s/it]

              precision    recall  f1-score   support

        B_nr       0.98      0.96      0.97      4839
        B_ns       0.94      0.95      0.95      5576
        B_nt       0.94      0.91      0.92      2695
        E_nr       0.97      0.96      0.96      4838
        E_ns       0.92      0.93      0.92      5574
        E_nt       0.94      0.89      0.92      2692
        M_nr       0.97      0.97      0.97      4382
        M_ns       0.93      0.83      0.87      3453
        M_nt       0.94      0.90      0.92     10073
           O       0.98      0.99      0.99     96582

    accuracy                           0.97    140704
   macro avg       0.95      0.93      0.94    140704
weighted avg       0.97      0.97      0.97    140704


Epoch: 9, Step: 0, Train loss: 0.0999036505818367
Epoch: 9, Step: 100, Train loss: 0.02913193331021668
Epoch: 9, Step: 200, Train loss: 0.029320975582564798
Epoch: 9, Step: 300, Train loss: 0.030679794494256042
Epoch: 9, Step: 400, Train lo

Epoch:  50%|█████     | 10/20 [3:13:03<3:12:59, 1158.00s/it]

              precision    recall  f1-score   support

        B_nr       0.98      0.98      0.98      4839
        B_ns       0.95      0.95      0.95      5576
        B_nt       0.93      0.94      0.94      2695
        E_nr       0.96      0.97      0.96      4838
        E_ns       0.94      0.93      0.94      5574
        E_nt       0.92      0.95      0.93      2692
        M_nr       0.96      0.98      0.97      4382
        M_ns       0.89      0.87      0.88      3453
        M_nt       0.92      0.95      0.93     10073
           O       0.99      0.99      0.99     96582

    accuracy                           0.97    140704
   macro avg       0.94      0.95      0.95    140704
weighted avg       0.97      0.97      0.97    140704


Epoch: 10, Step: 0, Train loss: 0.002057593548670411
Epoch: 10, Step: 100, Train loss: 0.016354165911342543
Epoch: 10, Step: 200, Train loss: 0.019615491087356145
Epoch: 10, Step: 300, Train loss: 0.020212128447667323
Epoch: 10, Step: 400, 

Epoch:  55%|█████▌    | 11/20 [3:32:21<2:53:42, 1158.06s/it]

              precision    recall  f1-score   support

        B_nr       0.98      0.97      0.98      4839
        B_ns       0.93      0.96      0.94      5576
        B_nt       0.95      0.93      0.94      2695
        E_nr       0.97      0.96      0.96      4838
        E_ns       0.89      0.95      0.92      5574
        E_nt       0.94      0.93      0.94      2692
        M_nr       0.97      0.97      0.97      4382
        M_ns       0.85      0.91      0.88      3453
        M_nt       0.94      0.93      0.93     10073
           O       0.99      0.99      0.99     96582

    accuracy                           0.97    140704
   macro avg       0.94      0.95      0.94    140704
weighted avg       0.97      0.97      0.97    140704


Epoch: 11, Step: 0, Train loss: 0.04864921420812607
Epoch: 11, Step: 100, Train loss: 0.020617362255308923
Epoch: 11, Step: 200, Train loss: 0.01870826850106606
Epoch: 11, Step: 300, Train loss: 0.018346095874602366
Epoch: 11, Step: 400, Tr

Epoch:  60%|██████    | 12/20 [3:51:39<2:34:23, 1157.98s/it]

              precision    recall  f1-score   support

        B_nr       0.98      0.97      0.98      4839
        B_ns       0.96      0.94      0.95      5576
        B_nt       0.93      0.95      0.94      2695
        E_nr       0.98      0.96      0.97      4838
        E_ns       0.94      0.92      0.93      5574
        E_nt       0.94      0.93      0.93      2692
        M_nr       0.97      0.98      0.97      4382
        M_ns       0.84      0.92      0.88      3453
        M_nt       0.92      0.95      0.93     10073
           O       0.99      0.99      0.99     96582

    accuracy                           0.97    140704
   macro avg       0.94      0.95      0.95    140704
weighted avg       0.97      0.97      0.97    140704


Epoch: 12, Step: 0, Train loss: 0.03301943838596344
Epoch: 12, Step: 100, Train loss: 0.015860037982756602
Epoch: 12, Step: 200, Train loss: 0.017225448519109624
Epoch: 12, Step: 300, Train loss: 0.01650072940825593
Epoch: 12, Step: 400, Tr

Epoch:  65%|██████▌   | 13/20 [4:10:57<2:15:05, 1157.87s/it]

              precision    recall  f1-score   support

        B_nr       0.98      0.98      0.98      4839
        B_ns       0.96      0.94      0.95      5576
        B_nt       0.93      0.95      0.94      2695
        E_nr       0.97      0.97      0.97      4838
        E_ns       0.95      0.92      0.93      5574
        E_nt       0.93      0.95      0.94      2692
        M_nr       0.97      0.98      0.98      4382
        M_ns       0.92      0.85      0.88      3453
        M_nt       0.93      0.94      0.94     10073
           O       0.99      0.99      0.99     96582

    accuracy                           0.98    140704
   macro avg       0.95      0.95      0.95    140704
weighted avg       0.98      0.98      0.98    140704


Epoch: 13, Step: 0, Train loss: 0.000479490146972239
Epoch: 13, Step: 100, Train loss: 0.015417802542342002
Epoch: 13, Step: 200, Train loss: 0.014208690174281903
Epoch: 13, Step: 300, Train loss: 0.01347711198876977
Epoch: 13, Step: 400, T

Epoch:  70%|███████   | 14/20 [4:30:15<1:55:48, 1158.04s/it]

              precision    recall  f1-score   support

        B_nr       0.98      0.98      0.98      4839
        B_ns       0.94      0.96      0.95      5576
        B_nt       0.95      0.93      0.94      2695
        E_nr       0.96      0.97      0.97      4838
        E_ns       0.92      0.95      0.93      5574
        E_nt       0.93      0.94      0.94      2692
        M_nr       0.97      0.98      0.98      4382
        M_ns       0.88      0.89      0.89      3453
        M_nt       0.94      0.92      0.93     10073
           O       0.99      0.99      0.99     96582

    accuracy                           0.98    140704
   macro avg       0.95      0.95      0.95    140704
weighted avg       0.98      0.98      0.98    140704


Epoch: 14, Step: 0, Train loss: 0.0015078940195962787
Epoch: 14, Step: 100, Train loss: 0.012611548313588633
Epoch: 14, Step: 200, Train loss: 0.011818829071748567
Epoch: 14, Step: 300, Train loss: 0.010764713678741053
Epoch: 14, Step: 400,

Epoch:  75%|███████▌  | 15/20 [4:49:33<1:36:29, 1157.92s/it]

              precision    recall  f1-score   support

        B_nr       0.98      0.98      0.98      4839
        B_ns       0.96      0.95      0.95      5576
        B_nt       0.95      0.95      0.95      2695
        E_nr       0.98      0.96      0.97      4838
        E_ns       0.94      0.94      0.94      5574
        E_nt       0.94      0.94      0.94      2692
        M_nr       0.97      0.98      0.98      4382
        M_ns       0.90      0.89      0.89      3453
        M_nt       0.94      0.93      0.94     10073
           O       0.99      0.99      0.99     96582

    accuracy                           0.98    140704
   macro avg       0.95      0.95      0.95    140704
weighted avg       0.98      0.98      0.98    140704


Epoch: 15, Step: 0, Train loss: 0.0005320027121342719
Epoch: 15, Step: 100, Train loss: 0.009322206869820538
Epoch: 15, Step: 200, Train loss: 0.009169618311987968
Epoch: 15, Step: 300, Train loss: 0.010665729227499027
Epoch: 15, Step: 400,

Epoch:  80%|████████  | 16/20 [5:08:50<1:17:10, 1157.72s/it]

              precision    recall  f1-score   support

        B_nr       0.98      0.98      0.98      4839
        B_ns       0.95      0.95      0.95      5576
        B_nt       0.93      0.95      0.94      2695
        E_nr       0.97      0.97      0.97      4838
        E_ns       0.93      0.94      0.94      5574
        E_nt       0.94      0.94      0.94      2692
        M_nr       0.97      0.98      0.98      4382
        M_ns       0.91      0.88      0.89      3453
        M_nt       0.94      0.95      0.94     10073
           O       0.99      0.99      0.99     96582

    accuracy                           0.98    140704
   macro avg       0.95      0.95      0.95    140704
weighted avg       0.98      0.98      0.98    140704


Epoch: 16, Step: 0, Train loss: 0.0003239345969632268
Epoch: 16, Step: 100, Train loss: 0.007251097451210239
Epoch: 16, Step: 200, Train loss: 0.009396115280692154
Epoch: 16, Step: 300, Train loss: 0.009491230373274097
Epoch: 16, Step: 400,

Epoch:  85%|████████▌ | 17/20 [5:28:08<57:53, 1157.71s/it]  

              precision    recall  f1-score   support

        B_nr       0.98      0.98      0.98      4839
        B_ns       0.95      0.95      0.95      5576
        B_nt       0.93      0.95      0.94      2695
        E_nr       0.97      0.96      0.97      4838
        E_ns       0.94      0.94      0.94      5574
        E_nt       0.94      0.94      0.94      2692
        M_nr       0.97      0.98      0.98      4382
        M_ns       0.90      0.89      0.89      3453
        M_nt       0.93      0.95      0.94     10073
           O       0.99      0.99      0.99     96582

    accuracy                           0.98    140704
   macro avg       0.95      0.95      0.95    140704
weighted avg       0.98      0.98      0.98    140704


Epoch: 17, Step: 0, Train loss: 0.020014125853776932
Epoch: 17, Step: 100, Train loss: 0.005708400528842818
Epoch: 17, Step: 200, Train loss: 0.006152419265122466
Epoch: 17, Step: 300, Train loss: 0.0058768058959419835
Epoch: 17, Step: 400,

Epoch:  90%|█████████ | 18/20 [5:47:26<38:35, 1157.74s/it]

              precision    recall  f1-score   support

        B_nr       0.98      0.98      0.98      4839
        B_ns       0.96      0.95      0.95      5576
        B_nt       0.93      0.96      0.94      2695
        E_nr       0.97      0.97      0.97      4838
        E_ns       0.94      0.93      0.94      5574
        E_nt       0.93      0.95      0.94      2692
        M_nr       0.97      0.98      0.98      4382
        M_ns       0.91      0.88      0.89      3453
        M_nt       0.93      0.95      0.94     10073
           O       0.99      0.99      0.99     96582

    accuracy                           0.98    140704
   macro avg       0.95      0.95      0.95    140704
weighted avg       0.98      0.98      0.98    140704


Epoch: 18, Step: 0, Train loss: 0.0001083395181922242
Epoch: 18, Step: 100, Train loss: 0.004615276137532293
Epoch: 18, Step: 200, Train loss: 0.004288479154095498
Epoch: 18, Step: 300, Train loss: 0.005680499854976229
Epoch: 18, Step: 400,

Epoch:  95%|█████████▌| 19/20 [6:06:43<19:17, 1157.74s/it]

              precision    recall  f1-score   support

        B_nr       0.98      0.98      0.98      4839
        B_ns       0.96      0.95      0.95      5576
        B_nt       0.94      0.96      0.95      2695
        E_nr       0.97      0.97      0.97      4838
        E_ns       0.94      0.94      0.94      5574
        E_nt       0.94      0.95      0.94      2692
        M_nr       0.97      0.98      0.98      4382
        M_ns       0.92      0.88      0.90      3453
        M_nt       0.94      0.95      0.94     10073
           O       0.99      0.99      0.99     96582

    accuracy                           0.98    140704
   macro avg       0.95      0.95      0.95    140704
weighted avg       0.98      0.98      0.98    140704


Epoch: 19, Step: 0, Train loss: 0.00011570870992727578
Epoch: 19, Step: 100, Train loss: 0.006113917242947708
Epoch: 19, Step: 200, Train loss: 0.006143870602468427
Epoch: 19, Step: 300, Train loss: 0.006183979809760894
Epoch: 19, Step: 400

Epoch: 100%|██████████| 20/20 [6:26:01<00:00, 1158.06s/it]

              precision    recall  f1-score   support

        B_nr       0.98      0.98      0.98      4839
        B_ns       0.96      0.95      0.96      5576
        B_nt       0.94      0.95      0.95      2695
        E_nr       0.97      0.97      0.97      4838
        E_ns       0.94      0.94      0.94      5574
        E_nt       0.94      0.95      0.94      2692
        M_nr       0.97      0.98      0.98      4382
        M_ns       0.91      0.88      0.90      3453
        M_nt       0.94      0.95      0.94     10073
           O       0.99      0.99      0.99     96582

    accuracy                           0.98    140704
   macro avg       0.95      0.95      0.95    140704
weighted avg       0.98      0.98      0.98    140704







> Judging from the whole training run, the model is fairly hard to train, but the final accuracy is high.
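
The summary rows of the last report can be checked by hand. The sketch below copies the per-class f1-score and support columns from the epoch-20 report and recomputes the two averages; the function name `macro_and_weighted_f1` is made up for illustration and is not part of the notebook. Because the inputs are already rounded to two decimals, the macro average comes out as 0.955, while the report prints 0.95, presumably because it averages the unrounded scores.

```python
# (label, f1-score, support) rows copied from the final classification report.
ROWS = [
    ("B_nr", 0.98, 4839), ("B_ns", 0.96, 5576), ("B_nt", 0.95, 2695),
    ("E_nr", 0.97, 4838), ("E_ns", 0.94, 5574), ("E_nt", 0.94, 2692),
    ("M_nr", 0.98, 4382), ("M_ns", 0.90, 3453), ("M_nt", 0.94, 10073),
    ("O", 0.99, 96582),
]


def macro_and_weighted_f1(rows):
    """Return (macro F1, support-weighted F1, total support)."""
    total = sum(support for _, _, support in rows)
    # Macro average: every class counts equally, whatever its size.
    macro = sum(f1 for _, f1, _ in rows) / len(rows)
    # Weighted average: every class counts in proportion to its support.
    weighted = sum(f1 * support for _, f1, support in rows) / total
    return macro, weighted, total


macro, weighted, total = macro_and_weighted_f1(ROWS)
o_share = ROWS[-1][2] / total  # fraction of tokens labelled "O"
print(f"macro={macro:.3f} weighted={weighted:.3f} O share={o_share:.1%}")
```

The `O` class accounts for roughly two thirds of the evaluated tokens, so accuracy and the weighted average are dominated by it; the macro average and the per-class rows (for example `M_ns`) say more about how well the entities themselves are recognised.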

In [69]:
# # Save the model in Hugging Face format
# model.save_pretrained("../../H/models/huggingface/torch/zh_albert_ner/")
torch.save(model.state_dict(),
           "../../H/models/huggingface/torch/zh_albert_ner/pytorch_model.bin")

# # Load the model directly from the path
# model_path = "../../H/models/huggingface/torch/zh_albert_ner/"
# model = AlbertForTokenClassification.from_pretrained(model_path)

## Using the model

In [64]:
input_ids = torch.tensor(
    tokenizer.encode("冯永祥突发奇想，跑到阿尔及利亚旅行，意外结识了印度人民党的领导",
                     add_special_tokens=False)).unsqueeze(0).to(device)

In [65]:
input_ids

tensor([[1101, 3719, 4872, 4960, 1355, 1936, 2682, 8024, 6651, 1168, 7350, 2209,
         1350, 1164,  762, 3180, 6121, 8024, 2692, 1912, 5310, 6399,  749, 1313,
         2428,  782, 3696, 1054, 4638, 7566, 2193]], device='cuda:0')

In [66]:
outputs = model(input_ids)

In [67]:
pred_dis = outputs[0]
pred = torch.argmax(pred_dis, dim=2)

In [68]:
[id2tag[idx] for idx in pred[0].cpu().numpy()]

['B_nr',
 'M_nr',
 'E_nr',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'B_ns',
 'M_ns',
 'M_ns',
 'M_ns',
 'E_ns',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'O',
 'B_nt',
 'M_nt',
 'M_nt',
 'M_nt',
 'E_nt',
 'O',
 'O',
 'O']

In [71]:
pred

tensor([[3, 4, 5, 9, 9, 9, 9, 9, 9, 9, 0, 1, 1, 1, 2, 9, 9, 9, 9, 9, 9, 9, 9, 6,
         7, 7, 7, 8, 9, 9, 9]], device='cuda:0')
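
The per-character tags above are easier to use once they are grouped into entity spans. The following is a minimal sketch of that step, not the notebook's own code: `decode_entities` is a hypothetical helper, and `ID2TAG` is reconstructed by lining up the tag list and the `pred` tensor shown above, on the assumption that it matches the `id2tag` defined in an earlier cell. It also assumes one token per character, which holds for this sentence (31 characters, 31 input ids).

```python
# Mapping inferred from the two cell outputs above (an assumption, see lead-in).
ID2TAG = {0: "B_ns", 1: "M_ns", 2: "E_ns",
          3: "B_nr", 4: "M_nr", 5: "E_nr",
          6: "B_nt", 7: "M_nt", 8: "E_nt", 9: "O"}


def decode_entities(text, tag_ids, id2tag=ID2TAG):
    """Group B_/M_/E_ tags into (entity_text, entity_type, start, end) spans."""
    entities = []
    start, current_type = None, None
    for i, tag_id in enumerate(tag_ids):
        prefix, _, etype = id2tag[tag_id].partition("_")
        if prefix == "B":
            # Open a new span; an unfinished previous span is discarded.
            start, current_type = i, etype
        elif prefix == "M" and current_type == etype:
            continue  # still inside the open span
        elif prefix == "E" and current_type == etype:
            entities.append((text[start:i + 1], etype, start, i + 1))
            start, current_type = None, None
        else:
            # "O" or an inconsistent tag sequence: drop any open span.
            start, current_type = None, None
    return entities


text = "冯永祥突发奇想，跑到阿尔及利亚旅行，意外结识了印度人民党的领导"
tag_ids = [3, 4, 5, 9, 9, 9, 9, 9, 9, 9, 0, 1, 1, 1, 2, 9, 9, 9, 9, 9, 9, 9, 9, 6,
           7, 7, 7, 8, 9, 9, 9]
for entity in decode_entities(text, tag_ids):
    print(entity)
```

In the notebook, `tag_ids` would come from `pred[0].cpu().tolist()`. Spans that never reach a matching `E_` tag are silently dropped here; a stricter or more lenient policy may suit other data better.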