# Chinese Named Entity Recognition (NER)

In [1]:
# Fetch the tutorial repository; later cells read requirements.txt, utils and data/ by relative path.
!git clone https://github.com/GitYCC/bert-minimal-tutorial.git

Cloning into 'bert-minimal-tutorial'...
remote: Enumerating objects: 76, done.[K
remote: Counting objects: 100% (76/76), done.[K
remote: Compressing objects: 100% (58/58), done.[K
remote: Total 76 (delta 36), reused 50 (delta 16), pack-reused 0[K
Unpacking objects: 100% (76/76), done.


In [2]:
# Work from inside the cloned repository so the relative paths used below resolve.
%cd bert-minimal-tutorial

/content/bert-minimal-tutorial


In [3]:
!pip install -q -r requirements.txt

[?25l[K     |█▍                              | 10kB 29.6MB/s eta 0:00:01[K     |██▉                             | 20kB 16.4MB/s eta 0:00:01[K     |████▎                           | 30kB 12.4MB/s eta 0:00:01[K     |█████▊                          | 40kB 12.3MB/s eta 0:00:01[K     |███████▏                        | 51kB 10.6MB/s eta 0:00:01[K     |████████▋                       | 61kB 10.8MB/s eta 0:00:01[K     |██████████                      | 71kB 11.6MB/s eta 0:00:01[K     |███████████▍                    | 81kB 11.6MB/s eta 0:00:01[K     |████████████▉                   | 92kB 12.4MB/s eta 0:00:01[K     |██████████████▎                 | 102kB 12.5MB/s eta 0:00:01[K     |███████████████▊                | 112kB 12.5MB/s eta 0:00:01[K     |█████████████████▏              | 122kB 12.5MB/s eta 0:00:01[K     |██████████████████▋             | 133kB 12.5MB/s eta 0:00:01[K     |████████████████████            | 143kB 12.5MB/s eta 0:00:01[K     |█████████████

In [4]:
import os

import pandas as pd
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence
from transformers import BertTokenizer, BertForTokenClassification
from tqdm.notebook import tqdm
from seqeval.metrics import f1_score

from utils import RunningAverage, tokenize_and_map

# Name of the pretrained checkpoint; passed to `from_pretrained` for both the tokenizer and the model.
MODEL_NAME = 'bert-base-chinese'

## Dataloader

In [5]:
def read_bio_file(path, encoding='utf-8'):
    """Parse a character-per-line BIO file into sentences and tag lists.

    Each non-empty line is expected to be ``<char>\\t<tag>``; an empty line
    ends the current sentence.

    Args:
        path: Path of the BIO file to read.
        encoding: Text encoding used to open the file. Given explicitly so the
            result does not depend on the platform's default encoding.

    Returns:
        A ``(texts, tag_lists)`` tuple where ``texts[k]`` is a sentence string
        and ``tag_lists[k]`` holds one tag per character of that sentence.

    Raises:
        ValueError: If a line does not split into exactly two tab-separated fields.
        AssertionError: If a line's character field is not a single character.
    """
    texts = []
    tag_lists = []
    text = ''
    tag_list = []

    def flush():
        # Every character must have received exactly one tag.
        assert len(text) == len(tag_list)
        texts.append(text)
        tag_lists.append(tag_list)

    with open(path, encoding=encoding) as fr:
        # Iterate lazily instead of loading the whole file with readlines().
        for line in fr:
            line = line.strip()
            if line == '':
                flush()
                text = ''
                tag_list = []
            elif line == '0':
                # A bare '0' line is recorded as a space tagged 'O'.
                # NOTE(review): presumably a quirk of the data file - confirm.
                text += ' '
                tag_list.append('O')
            else:
                char, tag = line.split('\t')
                assert len(char) == 1
                text += char
                tag_list.append(tag)

    # Keep the final sentence when the file does not end with a blank line.
    if text or tag_list:
        flush()

    return texts, tag_lists


texts, tag_lists = read_bio_file('data/msra_train_bio.txt')

In [6]:
# Spot-check one parsed example: the sentence and its tag list (one tag per character).
idx = 26
print('text:', texts[idx])
print('tag list:', tag_lists[idx])

text: 我們的藏品中有幾十冊為北京圖書館等國家級藏館所未藏。
tag list: ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'I-LOC', 'I-LOC', 'I-LOC', 'I-LOC', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']


In [14]:
# Tag vocabulary. A tag's position in this list is its integer class id
# (`LABELS.index(tag)` encodes, `LABELS[i]` decodes), so the order must stay fixed
# between training and prediction.
LABELS = ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC']


class NerDataset(Dataset):
    """Dataset that converts raw sentences (and optional tags) into BERT inputs.

    Each item is a tuple ``(input_ids, token_type_ids, attention_mask[, labels], info)``;
    ``labels`` is present only when ``for_train`` is true. ``info`` is a dict with the
    (possibly truncated) ``text``, ``tokens`` and ``index_map``.

    Args:
        tokenizer: Tokenizer handed to ``tokenize_and_map`` and used for
            ``convert_tokens_to_ids``.
        texts: Sequence of sentence strings.
        tag_lists: Sequence of per-character tag lists aligned with ``texts``.
            Required when ``for_train`` is true.
        max_len: Maximum number of tokens per item, including ``[CLS]`` and ``[SEP]``.
        for_train: Whether items and batches carry labels.

    Raises:
        ValueError: If ``for_train`` is true but ``tag_lists`` is ``None``.
    """

    # Padding value for label sequences. -100 is the default ``ignore_index`` of
    # ``torch.nn.CrossEntropyLoss``, so padded positions never contribute to the loss
    # even when the model does not mask the loss by attention mask.
    LABEL_PAD_ID = -100

    def __init__(self, tokenizer, texts, tag_lists=None, max_len=512, for_train=True):
        if for_train and tag_lists is None:
            # Fail here instead of with an opaque TypeError inside __getitem__.
            raise ValueError('tag_lists is required when for_train=True')

        self.tokenizer = tokenizer
        self.max_len = max_len
        self.for_train = for_train

        self.texts = texts
        self.tag_lists = tag_lists

    def __getitem__(self, idx):
        """Build the tensors (and info dict) for the sentence at ``idx``."""
        text = self.texts[idx].lower()

        # NOTE(review): `index_map` is used below as one entry per character of `text`,
        # holding that character's token position or None - confirm against
        # utils.tokenize_and_map.
        tokens, index_map = tokenize_and_map(self.tokenizer, text)

        # Leave room for [CLS] and [SEP].
        cut_index = self.max_len - 2
        if cut_index < len(tokens):
            # First character that maps to the first dropped token.
            cut_text_index = index_map.index(cut_index)
            tokens = tokens[:cut_index]
            text = text[:cut_text_index]
            index_map = index_map[:cut_text_index]

        processed_tokens = ['[CLS]'] + tokens + ['[SEP]']

        input_ids = torch.tensor(self.tokenizer.convert_tokens_to_ids(processed_tokens))
        token_type_ids = torch.tensor([0] * len(processed_tokens))
        attention_mask = torch.tensor([1] * len(processed_tokens))

        outputs = (input_ids, token_type_ids, attention_mask)

        if self.for_train:
            labels = []

            tag_list = self.tag_lists[idx]
            for tag, token_index in zip(tag_list, index_map):
                if token_index is None:
                    continue
                # Only the first character mapped to a token sets that token's label.
                if token_index >= len(labels):
                    labels.append(LABELS.index(tag))

            labels = [0] + labels + [0]  # for [CLS] and [SEP]
            labels = torch.tensor(labels)

            assert labels.size(0) == input_ids.size(0), f'{text}, {tokens}, {index_map}, {labels}, {input_ids}, {tag_list}'
            outputs += (labels, )

        info = {
            'text': text,
            'tokens': tokens,
            'index_map': index_map
        }
        outputs += (info, )
        return outputs

    def __len__(self):
        """Return the number of sentences."""
        return len(self.texts)

    def create_mini_batch(self, samples):
        """Collate items into a batch, padding every sequence to the longest one.

        Returns ``(input_ids, token_type_ids, attention_mask, labels)`` when
        ``for_train`` is true, otherwise ``(input_ids, token_type_ids,
        attention_mask, infos)`` where ``infos`` is a tuple of the per-item info dicts.
        """
        outputs = list(zip(*samples))

        # Zero-pad to a common sequence length (0 also marks padding in the attention mask).
        input_ids = pad_sequence(list(outputs[0]), batch_first=True)
        token_type_ids = pad_sequence(list(outputs[1]), batch_first=True)
        attention_mask = pad_sequence(list(outputs[2]), batch_first=True)

        batch_output = (input_ids, token_type_ids, attention_mask)

        if self.for_train:
            # Pad labels with the ignore index rather than 0, which is the id of 'O'.
            labels = pad_sequence(
                list(outputs[3]), batch_first=True, padding_value=self.LABEL_PAD_ID
            )
            batch_output += (labels, )
        else:
            infos = outputs[3]
            batch_output += (infos, )

        return batch_output

In [8]:
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)

# Ids of the special tokens that are dropped before scoring / printing predictions.
SKIP_TOKEN_IDS = [tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id]

dataset = NerDataset(tokenizer, texts, tag_lists)

# Fraction of the data used for training; the remainder is held out for validation.
CUT_RATIO = 0.9
# Fixed seed so the train/validation split (and the reported metrics) is the same
# after every "Restart & Run All".
SPLIT_SEED = 42
train_size = int(CUT_RATIO * len(dataset))
valid_size = len(dataset) - train_size
train_dataset, valid_dataset = random_split(
    dataset,
    [train_size, valid_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=109540.0, style=ProgressStyle(descripti…




In [9]:
batch_size = 8

# Both loaders batch with the dataset's own padding collate function.
# Only the training data is reshuffled on each pass.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=dataset.create_mini_batch,
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=dataset.create_mini_batch,
)

## Model

In [10]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'device: {device}')

# Pretrained encoder with a token-classification head sized to the tag vocabulary.
# return_dict=True makes the forward pass return an object exposing `.loss` / `.logits`.
model = BertForTokenClassification.from_pretrained(
    MODEL_NAME,
    return_dict=True,
    num_labels=len(LABELS),
)
model.to(device)

device: cuda


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=624.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=411577189.0, style=ProgressStyle(descri…




Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForTokenClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-c

BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(21128, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwis

## Train

In [11]:
def train_batch(model, data, optimizer, device):
    """Perform one optimisation step on a single batch.

    Args:
        model: Callable model; invoked with `input_ids`, `token_type_ids`,
            `attention_mask` and `labels`, and expected to return an object
            with a `.loss` attribute.
        data: Exactly four items (input ids, token type ids, attention mask,
            labels), each supporting `.to(device)`.
        optimizer: Optimizer providing `zero_grad()` and `step()`.
        device: Device the batch is moved to before the forward pass.

    Returns:
        The batch loss as returned by `loss.item()`.
    """
    model.train()

    moved = tuple(item.to(device) for item in data)
    input_ids, token_type_ids, attention_mask, labels = moved

    result = model(
        input_ids=input_ids,
        token_type_ids=token_type_ids,
        attention_mask=attention_mask,
        labels=labels,
    )
    batch_loss = result.loss

    # Clear stale gradients, backpropagate, then update the weights.
    optimizer.zero_grad()
    batch_loss.backward()
    optimizer.step()

    return batch_loss.item()


def evaluate(model, valid_loader, device):
    """Score the model on a validation loader.

    Args:
        model: Model returning an object with `.loss` and `.logits` when called
            with `input_ids`, `token_type_ids`, `attention_mask` and `labels`.
        valid_loader: Iterable of batches, each made of exactly four tensors
            (input ids, token type ids, attention mask, labels).
        device: Device the batches are moved to.

    Returns:
        A `(loss, f1)` tuple: the value of `RunningAverage.get()` over the
        per-batch losses, and the `f1_score` of the predicted tag sequences
        against the reference ones.
    """
    model.eval()

    loss_averager = RunningAverage()
    all_labels = []
    all_preds = []
    # Set lookup on plain ints instead of comparing tensors against a list.
    skip_ids = set(SKIP_TOKEN_IDS)

    with torch.no_grad():
        for data in tqdm(valid_loader, desc='evaluate'):
            input_ids, token_type_ids, attention_mask, labels = [d.to(device) for d in data]

            outputs = model(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                labels=labels
            )

            loss_averager.add(outputs.loss.item())

            # Copy each tensor to host memory once per batch; iterating device
            # tensors element by element forces a sync for every token.
            token_id_rows = input_ids.cpu().tolist()
            label_rows = labels.cpu().tolist()
            pred_rows = outputs.logits.argmax(dim=-1).cpu().tolist()

            for token_ids, label_ids, pred_ids in zip(token_id_rows, label_rows, pred_rows):
                # Positions holding special tokens are excluded from scoring.
                keep = [token_id not in skip_ids for token_id in token_ids]
                all_labels.append([LABELS[i] for i, kept in zip(label_ids, keep) if kept])
                all_preds.append([LABELS[i] for i, kept in zip(pred_ids, keep) if kept])

    f1 = f1_score(all_labels, all_preds)
    return loss_averager.get(), f1

In [12]:
# Training hyper-parameters. The *_per_iter values are counted in optimisation steps.
lr = 0.00001
max_iter = 200
show_per_iter = 10
valid_per_iter = 50
save_per_iter = 100
save_checkpoint_dir = 'models/'
model_prefix = 'cn_ner_'

# Checkpoints are named after the validation loss, so every save step must also be
# a validation step.
assert save_per_iter % valid_per_iter == 0

# An empty loader would make the outer `while` loop spin forever.
if len(train_loader) == 0:
    raise ValueError('train_loader is empty; nothing to train on')

optimizer = optim.Adam(model.parameters(), lr=lr)

i = 1
is_running = True
train_loss = RunningAverage()
model_paths = []
valid_loss = None
while is_running:
    # Keep cycling over the training data until `max_iter` steps have been taken.
    for train_data in train_loader:
        batch_loss = train_batch(model, train_data, optimizer, device)
        train_loss.add(batch_loss)

        if i % show_per_iter == 0:
            print('train [{}]: loss={}'.format(i, train_loss.get()))
            train_loss.flush()

        if i % valid_per_iter == 0:
            # Kept in its own name so it can never be confused with the batch loss.
            valid_loss, f1 = evaluate(model, valid_loader, device)
            print(f'valid: loss={valid_loss} f1={f1}')

        if i % save_per_iter == 0:
            path = os.path.join(save_checkpoint_dir, model_prefix + f'loss{valid_loss:.5}/')
            print(f'save model at {path}')
            model.save_pretrained(path)
            model_paths.append(path)

        if i == max_iter:
            is_running = False
            break

        i += 1

train [10]: loss=1.1412302672863006
train [20]: loss=0.6182441264390945
train [30]: loss=0.4170108214020729
train [40]: loss=0.3752801865339279
train [50]: loss=0.3239841714501381


HBox(children=(FloatProgress(value=0.0, description='evaluate', max=563.0, style=ProgressStyle(description_wid…


valid: loss=0.3162050543892235 f1=0.011971423054643755
train [60]: loss=0.2736549161374569
train [70]: loss=0.29652080237865447
train [80]: loss=0.19650333374738693
train [90]: loss=0.1721366174519062
train [100]: loss=0.1303988266736269


HBox(children=(FloatProgress(value=0.0, description='evaluate', max=563.0, style=ProgressStyle(description_wid…


valid: loss=0.13463739691993643 f1=0.6018557318168213
save model at models/cn_ner_loss0.13464/
train [110]: loss=0.15852563753724097
train [120]: loss=0.14877173453569412
train [130]: loss=0.10444948710501194
train [140]: loss=0.10955897644162178
train [150]: loss=0.09543922133743762


HBox(children=(FloatProgress(value=0.0, description='evaluate', max=563.0, style=ProgressStyle(description_wid…


valid: loss=0.08398433841622664 f1=0.7100931404135667
train [160]: loss=0.07400914430618286
train [170]: loss=0.07591968551278114
train [180]: loss=0.08649633470922709
train [190]: loss=0.08112326189875603
train [200]: loss=0.05587856452912092


HBox(children=(FloatProgress(value=0.0, description='evaluate', max=563.0, style=ProgressStyle(description_wid…


valid: loss=0.07005838210715996 f1=0.7712316451902906
save model at models/cn_ner_loss0.070058/


## Predict

In [16]:
# Reload the most recently saved checkpoint.
reload_checkpoint = model_paths[-1]

# Named `pred_texts` so the training corpus held in `texts` is not overwritten;
# earlier cells stay re-runnable after this one.
pred_texts = [
    '王小明去台北市立動物園玩',
    '高雄的西子灣是一個散心絕佳的好去處'
]

pred_dataset = NerDataset(tokenizer, pred_texts, for_train=False)

pred_loader = DataLoader(
    dataset=pred_dataset,
    batch_size=batch_size,
    collate_fn=pred_dataset.create_mini_batch,
)

model = BertForTokenClassification.from_pretrained(reload_checkpoint)
model.to(device)
# Make inference mode explicit (dropout off) rather than relying on loader defaults.
model.eval()

# Set lookup on plain ints instead of comparing tensors against a list.
skip_ids = set(SKIP_TOKEN_IDS)

results = []
with torch.no_grad():
    for data in tqdm(pred_loader, desc='predict'):
        input_ids, token_type_ids, attention_mask = [d.to(device) for d in data[:3]]
        infos = data[3]

        outputs = model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask
        )

        # Copy to host memory once per batch, then work on plain Python lists.
        token_id_rows = input_ids.cpu().tolist()
        preds = outputs.logits.argmax(dim=-1).cpu().tolist()
        for token_ids, pred_list, info in zip(token_id_rows, preds, infos):
            # Drop positions of special tokens so predictions line up with `tokens`.
            pred_list = [LABELS[i] for i, token_id in zip(pred_list, token_ids)
                         if token_id not in skip_ids]
            tokens = info['tokens']
            result = list(zip(tokens, pred_list))
            results.append(result)

print('predict result: ')
for result in results:
    print(result)

HBox(children=(FloatProgress(value=0.0, description='predict', max=1.0, style=ProgressStyle(description_width=…


predict result: 
[('王', 'B-PER'), ('小', 'I-PER'), ('明', 'I-PER'), ('去', 'O'), ('台', 'B-LOC'), ('北', 'I-ORG'), ('市', 'I-ORG'), ('立', 'I-ORG'), ('動', 'I-ORG'), ('物', 'I-ORG'), ('園', 'I-ORG'), ('玩', 'O')]
[('高', 'B-LOC'), ('雄', 'I-LOC'), ('的', 'O'), ('西', 'B-LOC'), ('子', 'I-LOC'), ('灣', 'I-LOC'), ('是', 'O'), ('一', 'O'), ('個', 'O'), ('散', 'O'), ('心', 'O'), ('絕', 'O'), ('佳', 'O'), ('的', 'O'), ('好', 'O'), ('去', 'O'), ('處', 'O')]
