In [None]:
from google.colab import drive
# Mount Google Drive so the pickled data and model checkpoint under /content/drive/MyDrive are reachable.
drive.mount(mountpoint='/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!pip install transformers
!pip install pytorch-crf

Collecting transformers
  Downloading transformers-4.10.0-py3-none-any.whl (2.8 MB)
[K     |████████████████████████████████| 2.8 MB 11.4 MB/s 
[?25hCollecting huggingface-hub>=0.0.12
  Downloading huggingface_hub-0.0.16-py3-none-any.whl (50 kB)
[K     |████████████████████████████████| 50 kB 5.9 MB/s 
[?25hCollecting sacremoses
  Downloading sacremoses-0.0.45-py3-none-any.whl (895 kB)
[K     |████████████████████████████████| 895 kB 42.7 MB/s 
Collecting pyyaml>=5.1
  Downloading PyYAML-5.4.1-cp37-cp37m-manylinux1_x86_64.whl (636 kB)
[K     |████████████████████████████████| 636 kB 44.0 MB/s 
Collecting tokenizers<0.11,>=0.10.1
  Downloading tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)
[K     |████████████████████████████████| 3.3 MB 43.2 MB/s 
Installing collected packages: tokenizers, sacremoses, pyyaml, huggingface-hub, transformers
  Attempting uninstall: pyyaml
    Found existing installation: P

In [None]:
import torch
# Sanity check that PyTorch can see a GPU before any model is built.
gpu_visible = torch.cuda.is_available()
print(gpu_visible)

True


In [None]:
import os
class Config:
    """Central hyper-parameters, file paths and tag vocabulary for the BERT-BiLSTM-CRF NER notebook."""

    def __init__(self):
        # Raw training CSV (not referenced by the cells shown here) and the pickled
        # (token ids, tag ids) lists that the training cell loads.
        self.train_data_path = r"/content/char_ner_train.csv"
        self.saved_ids_data_path = r"/content/drive/MyDrive/saved_ids.pkl"
        # Destination of torch.save(model, ...) for the best model so far.
        self.model_save_path = '/content/drive/MyDrive/bert-lstm-crf.bin'

        # Tag -> id mapping: B/M/E/S prefixes over GPE/ORG/PER/LOC, plus 'O'.
        self.label2id = {
            'M-GPE': 0,
            'M-ORG': 1,
            'M-PER': 2,
            'M-LOC': 3,
            'S-ORG': 4,
            'E-ORG': 5,
            'S-PER': 6,
            'E-LOC': 7,
            'B-LOC': 8,
            'B-ORG': 9,
            'S-GPE': 10,
            'S-LOC': 11,
            'B-PER': 12,
            'E-GPE': 13,
            'O': 14,
            'B-GPE': 15,
            'E-PER': 16,
        }
        # Reverse lookup (id -> tag) for turning decoded ids back into tag strings.
        self.id2label = {idx: label for label, idx in self.label2id.items()}
        # Derived from the vocabulary instead of hard-coding 17, so the classifier/CRF
        # size can never drift out of sync with the tag set.
        self.num_labels = len(self.label2id)

        self.hidden_dropout_prob = 0.1
        # Used as the BiLSTM input size, so it must equal BERT's hidden size.
        self.bert_embedding_size = 768
        # Total BiLSTM output width; each direction gets hidden_size // 2.
        self.hidden_size = 768
        self.lstm_dropout_prob = 0.5
        self.epoch_num = 20
        self.batch_size = 25
        self.learning_rate = 3e-5
        self.weight_decay = 0.01
        # Fraction of samples held out for validation (训练集、验证集划分比例).
        # NOTE(review): 0.5% is tiny, so validation F1 is computed on very few
        # sentences and will be noisy.
        self.train_test_split_size = 0.005
        # Prefer the GPU but fall back to CPU so the notebook still runs without one.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


config = Config()
print(os.getcwd())

/content


### Dataset construction (数据集构建)

In [None]:
import pickle
from transformers import BertModel, BertTokenizer
from torch.utils.data import Dataset
import torch
from torch.nn.utils.rnn import pad_sequence

class MyDataSet(Dataset):
    """Map-style dataset pairing each token-id sequence with its tag-id sequence.

    Items are returned as a pair of float tensors (``torch.Tensor`` default dtype).
    """

    def __init__(self, wordlists, taglists):
        self.wordlists, self.taglists = wordlists, taglists

    def __len__(self):
        return len(self.wordlists)

    def __getitem__(self, item):
        words = self.wordlists[item]
        tags = self.taglists[item]
        return torch.Tensor(words), torch.Tensor(tags)


def collate_fn(batch):
    """Pad a list of (token_ids, tag_ids) samples into batch tensors.

    :param batch: list of ``(token_ids, tag_ids)`` 1-D tensor pairs, as returned by ``MyDataSet``.
    :return: ``(x, y, lengths)`` where ``x`` is ``(batch, max_tokens)`` padded with 0,
        ``y`` is ``(batch, max_tags)`` padded with -1 (the value the loss and metric mask out),
        and ``lengths`` is the list of unpadded token counts.
    """
    token_seqs = [sample[0] for sample in batch]
    tag_seqs = [sample[1] for sample in batch]
    lengths = [len(tokens) for tokens in token_seqs]
    # batch_first=True yields the same (batch, seq) layout the old transpose(0, 1) produced.
    x = pad_sequence(token_seqs, batch_first=True, padding_value=0)
    y = pad_sequence(tag_seqs, batch_first=True, padding_value=-1)
    return x, y, lengths

# Model: BERT + BiLSTM + CRF (最基本的模型文件)

In [None]:
from transformers import BertModel, BertTokenizer
from torch.nn import Module
from torch.nn.utils.rnn import pad_sequence
import torch.nn as nn
from torch.nn import CrossEntropyLoss
import torch
from torchcrf import CRF

class BertNER(Module):
    """Character-level NER model: BERT encoder -> 2-layer BiLSTM -> linear emission layer -> CRF.

    ``forward`` returns per-token emission scores (logits), ``get_loss`` turns them into the
    CRF negative log-likelihood, and ``self.crf.decode`` yields the best tag sequence.
    """

    def __init__(self):
        super(BertNER, self).__init__()
        self.num_labels = config.num_labels

        # Pre-trained Chinese BERT. Only the last hidden layer (outputs[0]) is used, so the
        # intermediate hidden states are not requested (saves activation memory).
        self.bert = BertModel.from_pretrained('bert-base-chinese')
        # Two-layer bidirectional LSTM; each direction gets half of hidden_size so the
        # concatenated output is hidden_size wide.
        self.bilstm = nn.LSTM(
            input_size=config.bert_embedding_size,
            hidden_size=config.hidden_size // 2,
            batch_first=True,
            num_layers=2,
            dropout=config.lstm_dropout_prob,
            bidirectional=True,
        )

        # Currently unused: dropout on the BERT output is disabled in forward().
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.crf = CRF(config.num_labels, batch_first=True)

    def forward(self, input_data, lengths):
        """Compute emission scores for every real (non-special) token.

        Args:
            input_data: (batch, max(lengths)) token ids padded with 0; each row presumably
                starts/ends with the [CLS]/[SEP] special tokens — TODO confirm against preprocessing.
            lengths: per-sentence token counts *including* those two special tokens.

        Returns:
            (batch, max(lengths) - 2, num_labels) emission scores. Positions past a sentence's
            own length are padding and must be masked by the caller.
        """
        # Local import so this cell does not depend on a later cell having imported these.
        from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

        device = config.device
        token_lengths = torch.as_tensor([int(n) for n in lengths], dtype=torch.long)
        max_len = int(token_lengths.max())

        # Attention mask: 1 for the first `length` positions of each row, 0 for padding.
        # One vectorised comparison instead of concatenating rows in a loop.
        positions = torch.arange(max_len)
        attention_mask = (positions.unsqueeze(0) < token_lengths.unsqueeze(1)).long()

        outputs = self.bert(input_data.long().to(device),
                            attention_mask=attention_mask.to(device))
        sequence_output = outputs[0]

        # Drop position 0 and position length-1 (the special tokens) so outputs align with labels.
        stripped = [sentence[1:n - 1]
                    for sentence, n in zip(sequence_output, token_lengths.tolist())]
        padded_sequence_output = pad_sequence(stripped, batch_first=True)

        # Pack so the backward LSTM direction starts at each sentence's real last token instead
        # of first running over zero padding. Packing lengths must be a CPU tensor;
        # clamp(min=1) only guards degenerate rows that have no real tokens.
        real_lengths = (token_lengths - 2).clamp(min=1)
        packed = pack_padded_sequence(padded_sequence_output, real_lengths,
                                      batch_first=True, enforce_sorted=False)
        packed_output, _ = self.bilstm(packed)
        lstm_output, _ = pad_packed_sequence(packed_output, batch_first=True,
                                             total_length=padded_sequence_output.size(1))

        # (batch_size, seq, num_labels)
        return self.classifier(lstm_output)

    def get_loss(self, output_logits, labels):
        """Negative CRF log-likelihood of the gold tags.

        Args:
            output_logits: emission scores returned by ``forward``.
            labels: (batch, seq) gold tag ids padded with -1.

        Returns:
            Scalar loss tensor (torchcrf reduces with a sum over the batch by default).
        """
        labels = labels.long().to(config.device)
        # -1 marks padding; only real positions contribute to the likelihood.
        loss_mask = labels.gt(-1)
        # The CRF may still index tags at masked positions; give them a valid id (0)
        # rather than relying on -1 silently wrapping to the last label.
        safe_labels = labels.masked_fill(~loss_mask, 0)
        return -self.crf(output_logits, safe_labels, mask=loss_mask)



# Run and train

### Evaluation: token-level macro F1 (计算F1值)

In [None]:
import numpy as np
from sklearn.metrics import f1_score
import torch

def get_f1(model, test_dataLoader):
    """Token-level macro F1 of the CRF-decoded predictions over a data loader.

    Args:
        model: module exposing ``forward(x, lengths)`` and ``crf.decode(logits, mask=...)``.
        test_dataLoader: yields ``(x, labels, lengths)`` batches whose labels are padded with -1.

    Returns:
        Macro-averaged F1 over every non-padding token (no tag id is excluded), or 0.0
        when the loader yields no labelled tokens.
    """
    # Evaluate with dropout disabled (e.g. the BiLSTM's inter-layer dropout); otherwise the
    # score is noisy. Remember the previous mode so the caller's training loop is unaffected.
    was_training = model.training
    model.eval()
    all_y_pred = []
    all_y_true = []
    try:
        with torch.no_grad():
            for x, labels, lengths in test_dataLoader:
                logits = model.forward(x, lengths)
                loss_mask = labels.gt(-1).to(logits.device)

                # One list of predicted tag ids per sentence, covering only unmasked positions.
                decoded_outputs = model.crf.decode(logits, mask=loss_mask)
                for sentence_tags in decoded_outputs:
                    all_y_pred.extend(sentence_tags)

                # Flatten gold labels in the same row-major order and keep real positions only.
                flat_mask = loss_mask.reshape(-1).cpu().numpy()
                y_true = labels.long().cpu().numpy().reshape(-1)[flat_mask]
                all_y_true.append(y_true)
    finally:
        model.train(was_training)

    if not all_y_pred:
        return 0.0
    all_y_true = np.concatenate(all_y_true, axis=0)
    all_y_pred = np.asarray(all_y_pred)
    return f1_score(all_y_true, all_y_pred, average='macro')

### Simple training: load data and checkpoint, train, save the best model (读取、保存、简单训练)

In [None]:

from torch.optim import AdamW
from torch.utils.data import DataLoader
import logging
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch import nn
from torch.nn import init
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# Load the pre-tokenised (token ids, tag ids) lists (先加载分词好的数据).
# NOTE(review): pickle.load can execute arbitrary code; only load files you created yourself.
with open(config.saved_ids_data_path, "rb") as ids_file:
    word_lists, tag_lists = pickle.load(ids_file)
logging.info("--------已加载好训练数据!--------")

# Only the first MAX_SAMPLES sentences are used.
MAX_SAMPLES = 7000
# Evaluate every EVAL_EVERY batches; checkpoint only when F1 improves by more than MIN_F1_GAIN.
EVAL_EVERY = 50
MIN_F1_GAIN = 0.01

# Split into training and validation sets (分割训练集和测试集); fixed seed for reproducibility.
x_train, x_test, y_train, y_test = train_test_split(
    word_lists[:MAX_SAMPLES], tag_lists[:MAX_SAMPLES],
    test_size=config.train_test_split_size, random_state=0)

# Build the DataLoaders. Training batches are shuffled so each epoch sees a new order.
train_dataset = MyDataSet(x_train, y_train)
train_dataloader = DataLoader(dataset=train_dataset, batch_size=config.batch_size,
                              shuffle=True, collate_fn=collate_fn)
test_dataset = MyDataSet(x_test, y_test)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=config.batch_size,
                             collate_fn=collate_fn)
logging.info("--------已构建好DataLoader!--------")

# Resume from the saved checkpoint when possible; otherwise keep the fresh model.
model = BertNER()
try:
    # The checkpoint is a whole pickled module (see torch.save below), not a state_dict.
    # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which rejects
    # full-module pickles; this then falls back to the fresh model.
    model = torch.load(config.model_save_path, map_location=config.device)
except Exception as exc:  # best effort, but report why loading failed
    print("加载模型失败", exc)
model.to(config.device)
logging.info("--------已准备好模型!--------")

# Optimizer: BERT uses the base learning rate, while the LSTM / classifier / CRF use 5x.
# Bias and LayerNorm parameters get no weight decay.
bert_optimizer = list(model.bert.named_parameters())
lstm_optimizer = list(model.bilstm.named_parameters())
classifier_optimizer = list(model.classifier.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in bert_optimizer if not any(nd in n for nd in no_decay)],
     'weight_decay': config.weight_decay},
    {'params': [p for n, p in bert_optimizer if any(nd in n for nd in no_decay)],
     'weight_decay': 0.0},
    {'params': [p for n, p in lstm_optimizer if not any(nd in n for nd in no_decay)],
     'lr': config.learning_rate * 5, 'weight_decay': config.weight_decay},
    {'params': [p for n, p in lstm_optimizer if any(nd in n for nd in no_decay)],
     'lr': config.learning_rate * 5, 'weight_decay': 0.0},
    {'params': [p for n, p in classifier_optimizer if not any(nd in n for nd in no_decay)],
     'lr': config.learning_rate * 5, 'weight_decay': config.weight_decay},
    {'params': [p for n, p in classifier_optimizer if any(nd in n for nd in no_decay)],
     'lr': config.learning_rate * 5, 'weight_decay': 0.0},
    {'params': model.crf.parameters(), 'lr': config.learning_rate * 5}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=config.learning_rate)
logging.info("--------已构建优化器!--------")

# Baseline F1 of the starting model, used to decide when a checkpoint is worth saving.
best_f1 = get_f1(model, test_dataloader)

for epoch in range(1, config.epoch_num + 1):
    model.train()
    for batch_num, (x, y, lengths) in enumerate(tqdm(train_dataloader)):
        output_logits = model(x, lengths)
        loss = model.get_loss(output_logits, y)
        model.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_num % EVAL_EVERY == 0:
            current_f1 = get_f1(model, test_dataloader)
            model.train()  # evaluation may have switched the model to eval mode
            print("loss:", loss.item(), "  f1_score:", current_f1)
            if current_f1 - best_f1 > MIN_F1_GAIN:
                torch.save(model, config.model_save_path)
                best_f1 = current_f1
                print("已经保存最好模型")



Downloading:   0%|          | 0.00/624 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/412M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Downloading:   0%|          | 0.00/110k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/269k [00:00<?, ?B/s]

  0%|          | 1/279 [00:08<40:13,  8.68s/it]

loss: 47.66716384887695   f1_score: 0.8207484161142796


 18%|█▊        | 51/279 [03:10<20:07,  5.30s/it]

loss: 3.167123794555664   f1_score: 0.8179869907440148


 36%|███▌      | 101/279 [06:13<15:42,  5.29s/it]

loss: 1.1402969360351562   f1_score: 0.7786242353847403


 54%|█████▍    | 151/279 [09:13<12:01,  5.63s/it]

loss: 32.616424560546875   f1_score: 0.7882031678174386


 72%|███████▏  | 201/279 [12:19<06:53,  5.31s/it]

loss: 18.284854888916016   f1_score: 0.8045598025191897


 90%|████████▉ | 251/279 [15:19<02:25,  5.18s/it]

loss: 0.9881744384765625   f1_score: 0.7394383433265949


100%|██████████| 279/279 [16:56<00:00,  3.64s/it]
  0%|          | 1/279 [00:08<39:50,  8.60s/it]

loss: 9.156547546386719   f1_score: 0.6514629246158008


 18%|█▊        | 51/279 [03:12<19:52,  5.23s/it]

loss: 0.761505126953125   f1_score: 0.7477896246414765


 36%|███▌      | 101/279 [06:18<15:42,  5.29s/it]

loss: 7.200550079345703   f1_score: 0.7373447213044612


 54%|█████▍    | 151/279 [09:16<11:58,  5.61s/it]

loss: 39.769920349121094   f1_score: 0.7827100031846257


 72%|███████▏  | 201/279 [12:19<06:32,  5.04s/it]

loss: 18.356578826904297   f1_score: 0.7981090390593369


 90%|████████▉ | 251/279 [15:20<02:25,  5.19s/it]

loss: 16.771018981933594   f1_score: 0.7649678617144999


100%|██████████| 279/279 [16:57<00:00,  3.65s/it]
  0%|          | 1/279 [00:09<42:45,  9.23s/it]

loss: 3.013408660888672   f1_score: 0.7670568729897916


 18%|█▊        | 51/279 [03:10<19:41,  5.18s/it]

loss: 0.935302734375   f1_score: 0.7666644154844643


 36%|███▌      | 101/279 [06:13<15:40,  5.29s/it]

loss: 0.9424629211425781   f1_score: 0.7658814096325085


 54%|█████▍    | 151/279 [09:13<12:02,  5.65s/it]

loss: 26.501487731933594   f1_score: 0.7691049182323134


 72%|███████▏  | 201/279 [12:15<06:33,  5.04s/it]

loss: 10.299741744995117   f1_score: 0.7843390902366588


 90%|████████▉ | 251/279 [15:15<02:25,  5.20s/it]

loss: 0.7551116943359375   f1_score: 0.7618571894178812


100%|██████████| 279/279 [16:52<00:00,  3.63s/it]
  0%|          | 1/279 [00:08<40:01,  8.64s/it]

loss: 8.673528671264648   f1_score: 0.7675164499037045


 18%|█▊        | 51/279 [03:10<19:52,  5.23s/it]

loss: 0.45090675354003906   f1_score: 0.7770590602497169


 36%|███▌      | 101/279 [06:19<15:40,  5.28s/it]

loss: 0.7654762268066406   f1_score: 0.7872382375711223


 54%|█████▍    | 151/279 [09:19<12:02,  5.64s/it]

loss: 1.9141921997070312   f1_score: 0.8231186975992265


 72%|███████▏  | 201/279 [12:22<06:37,  5.10s/it]

loss: 5.442707061767578   f1_score: 0.7777492940426712


 90%|████████▉ | 251/279 [15:21<02:25,  5.20s/it]

loss: 0.47759246826171875   f1_score: 0.7757904333646207


100%|██████████| 279/279 [16:58<00:00,  3.65s/it]
  0%|          | 1/279 [00:08<40:07,  8.66s/it]

loss: 0.5153312683105469   f1_score: 0.8079793058695832


 18%|█▊        | 51/279 [03:09<19:42,  5.19s/it]

loss: 0.42707061767578125   f1_score: 0.8233589681843246


 36%|███▌      | 101/279 [06:12<15:35,  5.25s/it]

loss: 0.5201377868652344   f1_score: 0.8077706020816705


 54%|█████▍    | 151/279 [09:12<12:00,  5.63s/it]

loss: 27.848838806152344   f1_score: 0.7834149498877496


 72%|███████▏  | 201/279 [12:15<06:34,  5.06s/it]

loss: 3.4817333221435547   f1_score: 0.8128855019597905


 90%|████████▉ | 251/279 [15:15<02:25,  5.20s/it]

loss: 14.23570442199707   f1_score: 0.8128855019597905


100%|██████████| 279/279 [16:52<00:00,  3.63s/it]
  0%|          | 1/279 [00:08<39:53,  8.61s/it]

loss: 0.5703029632568359   f1_score: 0.801359808115539


 18%|█▊        | 51/279 [03:10<19:45,  5.20s/it]

loss: 0.1899242401123047   f1_score: 0.805327041309399


 36%|███▌      | 101/279 [06:13<15:39,  5.28s/it]

loss: 7.136119842529297   f1_score: 0.7664350595812414


 54%|█████▍    | 151/279 [09:13<11:59,  5.62s/it]

loss: 1.5432929992675781   f1_score: 0.77617911351101


 72%|███████▏  | 201/279 [12:15<06:34,  5.05s/it]

loss: 23.20339012145996   f1_score: 0.7947042001314865


 90%|████████▉ | 251/279 [15:14<02:25,  5.21s/it]

loss: 1.9976310729980469   f1_score: 0.8098233921517792


100%|██████████| 279/279 [16:51<00:00,  3.63s/it]
  0%|          | 1/279 [00:08<39:57,  8.62s/it]

loss: 0.28069114685058594   f1_score: 0.8140069282097105


 18%|█▊        | 51/279 [03:09<19:47,  5.21s/it]

loss: 3.169607162475586   f1_score: 0.8128337482090966


 36%|███▌      | 101/279 [06:11<15:34,  5.25s/it]

loss: 0.5148506164550781   f1_score: 0.7979229464183027


 54%|█████▍    | 151/279 [09:11<11:59,  5.62s/it]

loss: 1.2427711486816406   f1_score: 0.7980600265242683


 72%|███████▏  | 201/279 [12:15<06:39,  5.12s/it]

loss: 4.164628982543945   f1_score: 0.7827801848587935


 90%|████████▉ | 251/279 [15:15<02:25,  5.19s/it]

loss: 0.2531757354736328   f1_score: 0.8107375778266911


100%|██████████| 279/279 [16:51<00:00,  3.63s/it]
  0%|          | 1/279 [00:09<43:15,  9.34s/it]

loss: 0.24936676025390625   f1_score: 0.8181709154191169


 18%|█▊        | 51/279 [03:11<19:43,  5.19s/it]

loss: 0.127288818359375   f1_score: 0.8161212414408544


 36%|███▌      | 101/279 [06:13<15:36,  5.26s/it]

loss: 0.2599201202392578   f1_score: 0.8138761362993434


 54%|█████▍    | 151/279 [09:12<11:59,  5.62s/it]

loss: 0.410186767578125   f1_score: 0.807841812160257


 72%|███████▏  | 201/279 [12:15<06:39,  5.12s/it]

loss: 4.838579177856445   f1_score: 0.807841812160257


 90%|████████▉ | 251/279 [15:14<02:25,  5.19s/it]

loss: 0.1734142303466797   f1_score: 0.8032827317029402


100%|██████████| 279/279 [16:51<00:00,  3.62s/it]
  0%|          | 1/279 [00:08<39:43,  8.57s/it]

loss: 0.12755203247070312   f1_score: 0.8138761362993434


 18%|█▊        | 51/279 [03:11<19:41,  5.18s/it]

loss: 0.1145172119140625   f1_score: 0.8138761362993434


 36%|███▌      | 101/279 [06:16<15:36,  5.26s/it]

loss: 0.23633384704589844   f1_score: 0.8084924719300446


 54%|█████▍    | 151/279 [09:16<12:00,  5.63s/it]

loss: 0.289886474609375   f1_score: 0.8127655635879215


 72%|███████▏  | 201/279 [12:18<06:32,  5.03s/it]

loss: 2.2004852294921875   f1_score: 0.8002871784791291


 90%|████████▉ | 251/279 [15:18<02:24,  5.18s/it]

loss: 1.6336517333984375   f1_score: 0.7717133214200661


100%|██████████| 279/279 [16:55<00:00,  3.64s/it]
  0%|          | 1/279 [00:08<39:52,  8.61s/it]

loss: 0.14798927307128906   f1_score: 0.7841259612174155


 18%|█▊        | 51/279 [03:08<19:38,  5.17s/it]

loss: 0.10191917419433594   f1_score: 0.7843180762403688


 36%|███▌      | 101/279 [06:11<15:30,  5.23s/it]

loss: 0.21211624145507812   f1_score: 0.7804745639121365


 54%|█████▍    | 151/279 [09:12<12:46,  5.99s/it]

loss: 2.3967628479003906   f1_score: 0.7751421176811396


 72%|███████▏  | 201/279 [12:14<06:32,  5.03s/it]

loss: 3.0957260131835938   f1_score: 0.7803537722203389


 90%|████████▉ | 251/279 [15:13<02:24,  5.15s/it]

loss: 0.18291282653808594   f1_score: 0.7759855703738762


100%|██████████| 279/279 [16:49<00:00,  3.62s/it]
  0%|          | 1/279 [00:08<39:45,  8.58s/it]

loss: 0.10910224914550781   f1_score: 0.7696203680918179


 18%|█▊        | 51/279 [03:08<19:39,  5.17s/it]

loss: 0.09691619873046875   f1_score: 0.7369183995195989


 27%|██▋       | 74/279 [04:30<11:33,  3.38s/it]