In [1]:
# !pip install pytorch-crf
# !pip install seqeval
# !pip install transformers

In [2]:
from torch.utils.data import Dataset

categories = set()

class ReadData(Dataset):
    def __init__(self, data_file):
        self.data = self.load_data(data_file)
    
    def load_data(self, data_file):
        Data = {}
        with open(data_file, 'rt', encoding='utf-8') as f:
            for idx, line in enumerate(f.read().split('\n\n')):
                if not line:
                    break
                sentence, tags = '', []
                for i, c in enumerate(line.split('\n')):
                    word, tag = c.split('\t')
                    if word==' ':
                        word='1'
                    sentence += word
                    if tag[0] == 'B':
                        tags.append([i, i, word, tag[2:]]) # Remove the B- or I-
                        # print(tags)
                        categories.add(tag[2:])
                    elif tag[0] == 'I':
                        # print(tags)
                        # print(word)
                        # print(tag)
                        tags[-1][1] = i
                        tags[-1][2] += word
                Data[idx] = {
                    'sentence': sentence, 
                    'tags': tags
                }
        return Data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

In [3]:
train_data = ReadData('./example.train')
valid_data = ReadData('./example.dev')
# test_data = ReadData('./example.test')

print(train_data[0])

{'sentence': '主机厂家已机组提供高电压耐受能力情况说明，未说明具体耐受能力范围，，缺少对应的报告文件支持。能效技术监督。3.常用标准、规程、措施、制度、技术资料和各种记录缺失。主机厂家已提供符合要求的高电压耐受能力证明报告及对应的支持文件', 'tags': [[0, 31, '主机厂家已机组提供高电压耐受能力情况说明，未说明具体耐受能力范围', 'Pro'], [34, 44, '缺少对应的报告文件支持', 'Cau'], [74, 79, '各种记录缺失', 'Cau'], [81, 111, '主机厂家已提供符合要求的高电压耐受能力证明报告及对应的支持文件', 'Met']]}


In [4]:
categories

{'Cau', 'Met', 'Pro'}

In [5]:
id2label = {0:'O'}
for c in list(sorted(categories)):
    id2label[len(id2label)] = f"B-{c}"
    id2label[len(id2label)] = f"I-{c}"
label2id = {v: k for k, v in id2label.items()}

print(id2label)
print(label2id)

{0: 'O', 1: 'B-Cau', 2: 'I-Cau', 3: 'B-Met', 4: 'I-Met', 5: 'B-Pro', 6: 'I-Pro'}
{'O': 0, 'B-Cau': 1, 'I-Cau': 2, 'B-Met': 3, 'I-Met': 4, 'B-Pro': 5, 'I-Pro': 6}


In [6]:
# from transformers import AutoTokenizer
# import numpy as np

# checkpoint = "bert-base-chinese"
# tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# sentence = '主机厂家已机组提供高电压耐受能力情况说明（未说明具体耐受能力范围）'
# tags = [[9, 13, '高电压耐受', 'Phe']]

# encoding = tokenizer(sentence, truncation=True)
# tokens = encoding.tokens()
# label = np.zeros(len(tokens), dtype=int)
# for char_start, char_end, word, tag in tags:
#     token_start = encoding.char_to_token(char_start)
#     token_end = encoding.char_to_token(char_end)
#     label[token_start] = label2id[f"B-{tag}"]
#     label[token_start+1:token_end+1] = label2id[f"I-{tag}"]

# print(tokens)
# print(label)
# print([id2label[id] for id in label])

In [7]:
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
import numpy as np

checkpoint = "bert-base-chinese"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

def collote_fn(batch_samples):
    # batch_sentence, batch_tags,mask = [], [], []
    batch_sentence, batch_tags = [], []
    for sample in batch_samples:
        # print(sample)
        batch_sentence.append(sample['sentence'])
        batch_tags.append(sample['tags'])
        # mask.append(sample['mask_tensor'])
    batch_inputs = tokenizer(
        batch_sentence, 
        padding=True, 
        truncation=True, 
        return_tensors="pt",
        # max_length=256
    )
    batch_label = np.zeros(batch_inputs['input_ids'].shape, dtype=int)
    for s_idx, sentence in enumerate(batch_sentence):
        encoding = tokenizer(sentence, truncation=True)
        batch_label[s_idx][0] = 0
        batch_label[s_idx][len(encoding.tokens())-1:] = 0
        for char_start, char_end, _, tag in batch_tags[s_idx]:
            token_start = encoding.char_to_token(char_start)
            token_end = encoding.char_to_token(char_end)
            batch_label[s_idx][token_start] = label2id[f"B-{tag}"]
            batch_label[s_idx][token_start+1:token_end+1] = label2id[f"I-{tag}"]
    return batch_inputs, torch.tensor(batch_label)

# train_dataloader = DataLoader(train_data, batch_size=4, shuffle=True, collate_fn=collote_fn)
train_dataloader = DataLoader(train_data, batch_size=32, shuffle=True, collate_fn=collote_fn)
valid_dataloader = DataLoader(valid_data, batch_size=32, shuffle=False, collate_fn=collote_fn)
# test_dataloader = DataLoader(test_data, batch_size=4, shuffle=False, collate_fn=collote_fn)

batch_X, batch_y = next(iter(train_dataloader))
# print('batch_X shape:', {k: v.shape for k, v in batch_X.items()})
# print('batch_y shape:', batch_y.shape)
# print(batch_X)
# print(batch_y)

In [8]:
from torch import nn
from transformers import AutoModel
from torchcrf import CRF

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using {device} device')

class model(nn.Module):
    def __init__(self):
        super(model, self).__init__()
        self.bert = AutoModel.from_pretrained(checkpoint)
        self.config = self.bert.config
        self.BiLstm=nn.LSTM(input_size=self.config.hidden_size,hidden_size=512,batch_first=True,bidirectional=True,num_layers=2)
        self.Linear = nn.Linear(512*2, len(id2label))
        self.crf = CRF(len(id2label),batch_first=True)
            
    # def forward(self, x):
    def forward(self, x, y):
        # 1.
        # output = self.bert(**x).last_hidden_state
        # output, _ = self.BiLstm(output)
        # output = self.Linear(output)
        # return output
        # 2.
        output = self.bert(**x).last_hidden_state
        output, _ = self.BiLstm(output)
        output = self.Linear(output)
        # loss = self.crf(emissions=output,tags=y,mask=mask_tensor)
        # tag = self.crf.decode(emissions=output,,mask=mask_tensor)
        loss = self.crf(emissions=output,tags=y)
        tag = self.crf.decode(emissions=output)
        tag=torch.tensor(tag)
        return loss, tag
    
    def decode(self,x):
        output = self.bert(**x).last_hidden_state
        output, _ = self.BiLstm(output)
        output = self.Linear(output)
        tag = self.crf.decode(emissions=output)
        tag=torch.tensor(tag)
        return tag
    
model = model().to(device)
# print(model)

Using cuda device


Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [9]:
from tqdm.auto import tqdm

def train_loop(dataloader, model, loss_fn, optimizer, lr_scheduler, epoch, total_loss):
# def train_loop(dataloader, model, optimizer, lr_scheduler, epoch, total_loss):
    progress_bar = tqdm(range(len(dataloader)))
    progress_bar.set_description(f'loss: {0:>7f}')
    finish_batch_num = (epoch-1) * len(dataloader)
    
    model.train()
    for batch, (X, y) in enumerate(dataloader, start=1):
        optimizer.zero_grad()
        X, y = X.to(device), y.to(device)
        loss, tag = model(X, y)
        #  通过 pred.permute(0, 2, 1) 交换后两维，将模型预测结果从(batch,seq,7) 调整为 (batch,7,seq)。
        # loss = loss_fn(pred.permute(0, 2, 1), y)
        loss = abs(loss)
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        total_loss += loss.item()
        progress_bar.set_description(f'loss: {total_loss/(finish_batch_num + batch):>7f}')
        progress_bar.update(1)
    return total_loss

In [10]:
# !pip install seqeval
# from seqeval.metrics import classification_report
# from seqeval.scheme import IOB2

# y_true = [['O', 'O', 'O', 'B-LOC', 'I-LOC', 'I-LOC', 'B-LOC', 'O'], ['B-PER', 'I-PER', 'O']]
# y_pred = [['O', 'O', 'B-LOC', 'I-LOC', 'I-LOC', 'I-LOC', 'B-LOC', 'O'], ['B-PER', 'I-PER', 'O']]

# print(classification_report(y_true, y_pred, mode='strict', scheme=IOB2))

In [11]:
from seqeval.metrics import classification_report
from seqeval.scheme import IOB2

def test_loop(dataloader, model):
    true_labels, true_predictions = [], []

    model.eval()
    with torch.no_grad():
        for X, y in tqdm(dataloader):
            X, y = X.to(device), y.to(device)
            
            # pred = model(X)
            loss, tag = model(X, y)
            
            # predictions = pred.argmax(dim=-1)
            predictions = tag
            
            true_labels += [[id2label[int(l)] for l in label if l != -100] for label in y]
            true_predictions += [
                [id2label[int(p)] for (p, l) in zip(prediction, label) if l != -100]
                for prediction, label in zip(predictions, y)
            ]
    print(classification_report(true_labels, true_predictions, mode='strict', scheme=IOB2))

In [None]:
from transformers import AdamW, get_scheduler

learning_rate = 1e-5
epoch_num = 100
loss_fn = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=learning_rate)

lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=epoch_num*len(train_dataloader),
)

total_loss = 0.
loss_list=[]
for t in range(epoch_num):
    print(f"Epoch {t+1}/{epoch_num}\n-------------------------------")
    total_loss = train_loop(train_dataloader, model, loss_fn, optimizer, lr_scheduler, t+1, total_loss)
    # total_loss = train_loop(train_dataloader, model, optimizer, lr_scheduler, t+1, total_loss)
    if(t%10==0):
        test_loop(valid_dataloader, model)

print("Done!")

Epoch 1/100
-------------------------------




  0%|          | 0/24 [00:00<?, ?it/s]

  score = torch.where(mask[i].unsqueeze(1), next_score, score)


  0%|          | 0/24 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         Cau       0.00      0.00      0.00       393
         Met       0.00      0.00      0.00       585
         Pro       0.00      0.00      0.00       900

   micro avg       0.00      0.00      0.00      1878
   macro avg       0.00      0.00      0.00      1878
weighted avg       0.00      0.00      0.00      1878

Epoch 2/100
-------------------------------


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch 3/100
-------------------------------


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch 4/100
-------------------------------


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch 5/100
-------------------------------


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch 6/100
-------------------------------


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch 7/100
-------------------------------


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch 8/100
-------------------------------


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch 9/100
-------------------------------


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch 10/100
-------------------------------


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch 11/100
-------------------------------


  0%|          | 0/24 [00:00<?, ?it/s]

  0%|          | 0/24 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         Cau       0.00      0.00      0.00       393
         Met       0.00      0.00      0.00       585
         Pro       0.00      0.00      0.00       900

   micro avg       0.00      0.00      0.00      1878
   macro avg       0.00      0.00      0.00      1878
weighted avg       0.00      0.00      0.00      1878

Epoch 12/100
-------------------------------


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch 13/100
-------------------------------


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch 14/100
-------------------------------


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch 15/100
-------------------------------


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch 16/100
-------------------------------


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch 17/100
-------------------------------


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch 18/100
-------------------------------


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch 19/100
-------------------------------


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch 20/100
-------------------------------


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch 21/100
-------------------------------


  0%|          | 0/24 [00:00<?, ?it/s]

  0%|          | 0/24 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         Cau       0.00      0.00      0.00       393
         Met       0.00      0.00      0.00       585
         Pro       0.00      0.00      0.00       900

   micro avg       0.00      0.00      0.00      1878
   macro avg       0.00      0.00      0.00      1878
weighted avg       0.00      0.00      0.00      1878

Epoch 22/100
-------------------------------


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch 23/100
-------------------------------


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch 24/100
-------------------------------


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch 25/100
-------------------------------


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch 26/100
-------------------------------


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch 27/100
-------------------------------


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch 28/100
-------------------------------


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch 29/100
-------------------------------


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch 30/100
-------------------------------


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch 31/100
-------------------------------


  0%|          | 0/24 [00:00<?, ?it/s]

  0%|          | 0/24 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         Cau       0.00      0.00      0.00       393
         Met       0.00      0.00      0.00       585
         Pro       0.73      0.23      0.35       900

   micro avg       0.73      0.11      0.19      1878
   macro avg       0.24      0.08      0.12      1878
weighted avg       0.35      0.11      0.17      1878

Epoch 32/100
-------------------------------


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch 33/100
-------------------------------


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch 34/100
-------------------------------


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch 35/100
-------------------------------


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch 36/100
-------------------------------


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch 37/100
-------------------------------


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch 38/100
-------------------------------


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch 39/100
-------------------------------


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch 40/100
-------------------------------


  0%|          | 0/24 [00:00<?, ?it/s]

Epoch 41/100
-------------------------------


  0%|          | 0/24 [00:00<?, ?it/s]

  0%|          | 0/24 [00:00<?, ?it/s]

In [None]:
sentence = '在使用过程中若发现油位指示窗内出现油面，说明波纹囊有渗漏，绝缘油进入空气腔。发现指示窗有油应马上通知厂家处理，并采取临时措施。'

results = []
with torch.no_grad():
    inputs = tokenizer(sentence, truncation=True, return_tensors="pt")
    inputs = inputs.to(device)
    pred = model.decode(inputs)
    predictions = pred[0].tolist()
    pred_label = []
    inputs_with_offsets = tokenizer(sentence, return_offsets_mapping=True)
    tokens = inputs_with_offsets.tokens()
    offsets = inputs_with_offsets["offset_mapping"]
    idx = 0
    while idx < len(predictions):
        pred = predictions[idx]
        label = id2label[pred]
        if label != "O":
            label = label[2:] # Remove the B- or I-
            start, end = offsets[idx]
            while (
                idx + 1 < len(predictions) and 
                id2label[predictions[idx + 1]] == f"I-{label}"
            ):
                # all_scores.append(probabilities[idx + 1][predictions[idx + 1]])
                _, end = offsets[idx + 1]
                idx += 1
            word = sentence[start:end]
            pred_label.append(
                {
                    "entity_group": label,
                    "word": word,
                    "start": start,
                    "end": end,
                }
            )
        idx += 1
    print(pred_label)

In [None]:
sentence = '气体继电器保护装置的信号动作时，值班员应立即停止报警信号，并检查变压器，查明信号动作的原因，是否因空气侵入变压器内，或是油位降低，或是二次回路故障。'
results = []
with torch.no_grad():
    inputs = tokenizer(sentence, truncation=True, return_tensors="pt")
    inputs = inputs.to(device)
    pred = model.decode(inputs)
    predictions = pred[0].tolist()
    pred_label = []
    inputs_with_offsets = tokenizer(sentence, return_offsets_mapping=True)
    tokens = inputs_with_offsets.tokens()
    offsets = inputs_with_offsets["offset_mapping"]
    idx = 0
    while idx < len(predictions):
        pred = predictions[idx]
        label = id2label[pred]
        if label != "O":
            label = label[2:] # Remove the B- or I-
            start, end = offsets[idx]
            while (
                idx + 1 < len(predictions) and 
                id2label[predictions[idx + 1]] == f"I-{label}"
            ):
                # all_scores.append(probabilities[idx + 1][predictions[idx + 1]])
                _, end = offsets[idx + 1]
                idx += 1
            word = sentence[start:end]
            pred_label.append(
                {
                    "entity_group": label,
                    "word": word,
                    "start": start,
                    "end": end,
                }
            )
        idx += 1
    print(pred_label)

In [None]:
test_loop(valid_dataloader, model)

In [None]:
# # 打印模型的 state_dict
# print("Model's state_dict:")
# for param_tensor in model.state_dict():
#     print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# # 打印优化器的 state_dict
# print("Optimizer's state_dict:")
# for var_name in optimizer.state_dict():
#     print(var_name, "\t", optimizer.state_dict()[var_name])

# 保存/加载 state_dict（推荐）
要注意这个细节，如果使用nn.DataParallel在一台电脑上使用了多个GPU，那么加载模型的时候也必须先进行nn.DataParallel。

保存模型的推理过程的时候，只需要保存模型训练好的参数，使用torch.save()保存state_dict，能够方便模型的加载。因此推荐使用这种方式进行模型保存。

记住一定要使用model.eval()来固定dropout和归一化层，否则每次推理会生成不同的结果。

In [None]:
# # 保存：
torch.save(model.state_dict(), './model1')
# 加载：
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load('./model1'))
model.eval()


# 保存/加载整个模型
这种保存/加载模型的过程使用了最直观的语法，所用代码量少。这使用Python的pickle保存所有模块。这种方法的缺点是，保存模型的时候，序列化的数据被绑定到了特定的类和确切的目录。这是因为pickle不保存模型类本身，而是保存这个类的路径，并且在加载的时候会使用。因此，当在其他项目里使用或者重构的时候，加载模型的时候会出错。

一般来说，PyTorch的模型以.pt或者.pth文件格式保存。

一定要记住在评估模式的时候调用model.eval()来固定dropout和批次归一化。否则会产生不一致的推理结果。

In [None]:
# # 保存：
# torch.save(model, './model1')
# # 加载：
# model = torch.load('./model1')
# model.eval()

# 保存加载用于推理的常规Checkpoint/或继续训练
在保存用于推理或者继续训练的常规检查点的时候，除了模型的state_dict之外，还必须保存其他参数。保存优化器的state_dict也非常重要，因为它包含了模型在训练时候优化器的缓存和参数。除此之外，还可以保存停止训练时epoch数，最新的模型损失，额外的torch.nn.Embedding层等。

要保存多个组件，则将它们放到一个字典中，然后使用torch.save()序列化这个字典。一般来说，使用.tar文件格式来保存这些检查点。

加载各个组件，首先初始化模型和优化器，然后使用torch.load()加载保存的字典，然后可以直接查询字典中的值来获取保存的组件。

同样，评估模型的时候一定不要忘了调用model.eval()。

In [None]:
# # 保存：
# torch.save({
#             'epoch': epoch,
#             'model_state_dict': model.state_dict(),
#             'optimizer_state_dict': optimizer.state_dict(),
#             'loss': loss,
#             ...
#             }, './model1')
# # 加载：
# model = TheModelClass(*args, **kwargs)
# optimizer = TheOptimizerClass(*args, **kwargs)

# checkpoint = torch.load(PATH)
# model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# epoch = checkpoint['epoch']
# loss = checkpoint['loss']

# model.eval()
# # - 或者 -
# model.train()
