In [1]:
import warnings
warnings.filterwarnings('ignore')

import os

os.environ['CUDA_VISIBLE_DEVICES'] = '2,3'

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import transformers
from transformers import AutoConfig, AutoTokenizer, AutoModel, BertTokenizer

import tqdm

from datetime import datetime
from pathlib import Path
import time

import numpy as np


In [210]:
# Load the tokenizer shipped with the Chinese ALBERT checkpoint and the one
# shipped with bert-base-chinese, in that order, so the next cells can compare
# their vocabularies.
tokenizer_a, tokenizer_b = (
    BertTokenizer.from_pretrained(checkpoint_name)
    for checkpoint_name in ("voidful/albert_chinese_large", "bert-base-chinese")
)

In [217]:
# Sanity check: the two tokenizers are expected to expose the same vocabulary.
assert len(tokenizer_a.vocab) == len(tokenizer_b.vocab)

# Walk both vocabularies in parallel and report the first pair of entries that
# differ.  The `break` must live inside the mismatch branch: placed at loop
# level it stopped after the very first pair, so nothing was actually compared.
for a, b in zip(tokenizer_a.vocab, tokenizer_b.vocab):
    if a != b:
        print(a, b)
        break

In [218]:
# Write the tokenizer files of `tokenizer_a` into ./config; the call returns the
# paths of the files it wrote (shown as the cell output below).
tokenizer_a.save_pretrained('./config')

('./config/vocab.txt',
 './config/special_tokens_map.json',
 './config/added_tokens.json')

In [None]:
# tokenizer_a.   <- unfinished scratch expression (never executed); kept only as
# a comment because the bare `tokenizer_a.` is a SyntaxError and would stop a
# "Restart Kernel -> Run All".

In [2]:
# Tokenizer used for inference in the rest of the notebook.
# The AutoTokenizer variant below is intentionally left disabled; BertTokenizer
# is used instead.
# NOTE(review): presumably this checkpoint ships a BERT-style vocab.txt rather
# than a sentencepiece model — confirm against the model card.
# albert_tokenizer = AutoTokenizer.from_pretrained('voidful/albert_chinese_large')
tokenizer = BertTokenizer.from_pretrained("voidful/albert_chinese_large")


In [3]:
# Build the tag <-> id lookup tables from the tag dictionary file.
# The first whitespace-separated field of every non-empty line is the tag name;
# ids are assigned in file order starting from 0.
ner_tag_file = "pos.tgt.dict"
ner_tag2id = {}

count = 0
# Explicit encoding so the result does not depend on the platform's locale.
with open(ner_tag_file, encoding="utf-8") as fp:
    for line in fp:
        fields = line.split()
        if not fields:
            # Blank line (e.g. a trailing newline at the end of the file):
            # nothing to register, and indexing fields[0] would raise.
            continue
        ner_tag2id[fields[0]] = count
        count += 1

# Reverse lookup used to turn predicted ids back into tag names.
ner_id2tag = {tag_id: tag_name for tag_name, tag_id in ner_tag2id.items()}

print(len(ner_tag2id))

222


In [4]:
from typing import List, Optional, Mapping

class CRF(nn.Module):
    """Conditional random field.
    This module implements a conditional random field [LMP]. The forward computation
    of this class computes the log likelihood of the given sequence of tags and
    emission score tensor. This class also has ``decode`` method which finds the
    best tag sequence given an emission score tensor using `Viterbi algorithm`_.

    Masks may be given as ``uint8`` or ``bool`` tensors; they are converted to
    ``bool`` internally because ``torch.where`` deprecates ``uint8`` conditions.

    Arguments
    ---------
    num_tags : int
        Number of tags.
    batch_first : bool, optional
        Whether the first dimension corresponds to the size of a minibatch.
    tag_id_to_name : Mapping[int, str], optional
        Maps every tag id to its name (``PREFIX-SUFFIX`` such as ``B-LOC``, or a
        bare name such as ``[PAD]``). When given, transitions that are impossible
        under the B/I/E/S scheme are initialised to a very low score.
    Attributes
    ----------
    start_transitions : :class:`~torch.nn.Parameter`
        Start transition score tensor of size ``(num_tags,)``.
    end_transitions : :class:`~torch.nn.Parameter`
        End transition score tensor of size ``(num_tags,)``.
    transitions : :class:`~torch.nn.Parameter`
        Transition score tensor of size ``(num_tags, num_tags)``.
    tag_id_to_name : dict
        ``tag id -> (prefix, suffix)``; empty when no mapping was supplied.
    References
    ----------
    .. [LMP] Lafferty, J., McCallum, A., Pereira, F. (2001).
             "Conditional random fields: Probabilistic models for segmenting and
             labeling sequence data". *Proc. 18th International Conf. on Machine
             Learning*. Morgan Kaufmann. pp. 282–289.
    .. _Viterbi algorithm: https://en.wikipedia.org/wiki/Viterbi_algorithm
    """

    def __init__(self, num_tags: int, batch_first: bool = False, tag_id_to_name: Optional[Mapping[int, str]] = None) -> None:
        if num_tags <= 0:
            raise ValueError(f'invalid number of tags: {num_tags}')
        super().__init__()
        self.num_tags = num_tags
        self.batch_first = batch_first
        self.start_transitions = nn.Parameter(torch.empty(num_tags))
        self.end_transitions = nn.Parameter(torch.empty(num_tags))
        self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))
        # Always defined, so set_transition_constraint() can be called safely
        # even when the CRF was built without a tag mapping.
        self.tag_id_to_name = dict()

        self.reset_parameters()
        if not tag_id_to_name:
            return

        for t_i, t_name in tag_id_to_name.items():
            t_name_part = t_name.split('-')
            assert len(t_name_part) <= 3, "tag name {} error".format(t_name)
            t_pref = t_name_part[0]
            # A bare name (no '-') is its own suffix; otherwise the suffix is
            # everything after the first '-' (so 'B-X-Y' -> ('B', 'X-Y')).
            t_suff = '-'.join(t_name_part[1:]) if len(t_name_part) > 1 else t_pref
            self.tag_id_to_name[t_i] = (t_pref, t_suff)

        self.set_transition_constraint()

    def reset_parameters(self) -> None:
        """Initialize the transition parameters.
        The parameters will be initialized randomly from a uniform distribution
        between -0.01 and 0.01.
        """
        nn.init.uniform_(self.start_transitions, -0.01, 0.01)
        nn.init.uniform_(self.end_transitions, -0.01, 0.01)
        nn.init.uniform_(self.transitions, -0.01, 0.01)

    def set_transition_constraint(self) -> None:
        """Set impossible transitions
        Set transitions between impossible labels to a very low score, effectively disabling them.
        Allowed transitions that are set explicitly get the small score 0.001.
        """
        if len(self.tag_id_to_name) < 1:
            print("no tag id to name dict")
            return
        names = self.tag_id_to_name
        # Parameters are leaf tensors that require grad; in-place initialisation
        # therefore has to happen with autograd disabled.
        with torch.no_grad():
            for source_i, (source_pref, source_suff) in names.items():
                if source_pref in ['[PAD]', '[SEP]']:
                    self.start_transitions[source_i] = -10000.
                    # can only go to PAD
                    self.transitions[source_i].fill_(-10000.)
                    for targ_i, (targ_pref, targ_suff) in names.items():
                        if targ_pref in ['[PAD]']:
                            self.transitions[source_i, targ_i] = 0.001
                if source_pref in ['[CLS]']:
                    # possible ends are S- and B-
                    self.transitions[source_i].fill_(-10000.)
                    for targ_i, (targ_pref, targ_suff) in names.items():
                        if targ_pref in ['S', 'B']:
                            self.transitions[source_i, targ_i] = 0.001
                        if targ_pref in ['I', 'E']:
                            self.transitions[source_i, targ_i] = -10000.
                if source_pref in ['B', 'I']:
                    # possible ends are I- and E- of the same category
                    self.transitions[source_i].fill_(-10000.)
                    for targ_i, (targ_pref, targ_suff) in names.items():
                        if (targ_suff == source_suff) and (targ_pref in ['I', 'E']):
                            self.transitions[source_i, targ_i] = 0.001
                if source_pref in ['S', 'E']:
                    # cannot go to I or E
                    for targ_i, (targ_pref, targ_suff) in names.items():
                        if targ_pref in ['I', 'E']:
                            self.transitions[source_i, targ_i] = -10000.
                if source_pref in ['I', 'E']:
                    # cannot be start transitions
                    self.start_transitions[source_i] = -10000.

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(num_tags={self.num_tags})'

    def forward(
            self,
            emissions: torch.Tensor,
            tags: torch.LongTensor,
            mask: Optional[torch.ByteTensor] = None,
            reduction: str = 'sum',
    ) -> torch.Tensor:
        """Compute the conditional log likelihood of a sequence of tags given emission scores.
        Arguments
        ---------
        emissions : :class:`~torch.Tensor`
            Emission score tensor of size ``(seq_length, batch_size, num_tags)`` if
            ``batch_first`` is ``False``, ``(batch_size, seq_length, num_tags)`` otherwise.
        tags : :class:`~torch.LongTensor`
            Sequence of tags tensor of size ``(seq_length, batch_size)`` if
            ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
        mask : :class:`~torch.ByteTensor`, optional
            Mask tensor of size ``(seq_length, batch_size)`` if ``batch_first`` is ``False``,
            ``(batch_size, seq_length)`` otherwise. ``uint8`` and ``bool`` are both accepted.
        reduction : str, optional
            Specifies  the reduction to apply to the output: 'none'|'sum'|'mean'|'token_mean'.
            'none': no reduction will be applied. 'sum': the output will be summed over batches.
            'mean': the output will be averaged over batches. 'token_mean': the output will be
            averaged over tokens.
        Returns
        -------
        :class:`~torch.Tensor`
            The log likelihood. This will have size ``(batch_size,)`` if reduction is 'none',
            ``()`` otherwise.
        Raises
        ------
        ValueError
            If the shapes are inconsistent, the first timestep is masked out, or
            ``reduction`` is not one of the supported values.
        """
        self._validate(emissions, tags=tags, mask=mask)
        if reduction not in ('none', 'sum', 'mean', 'token_mean'):
            raise ValueError(f'invalid reduction: {reduction}')
        if mask is None:
            mask = torch.ones_like(tags, dtype=torch.bool)
        else:
            # torch.where deprecates uint8 conditions; normalise once here.
            mask = mask.bool()

        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            tags = tags.transpose(0, 1)
            mask = mask.transpose(0, 1)

        # shape: (batch_size,)
        numerator = self._compute_score(emissions, tags, mask)
        # shape: (batch_size,)
        denominator = self._compute_normalizer(emissions, mask)
        # shape: (batch_size,)
        llh = numerator - denominator

        if reduction == 'none':
            return llh
        if reduction == 'sum':
            return llh.sum()
        if reduction == 'mean':
            return llh.mean()
        assert reduction == 'token_mean'
        return llh.sum() / mask.float().sum()

    def decode(self, emissions: torch.Tensor,
               mask: Optional[torch.ByteTensor] = None) -> List[List[int]]:
        """Find the most likely tag sequence using Viterbi algorithm.
        Arguments
        ---------
        emissions : :class:`~torch.Tensor`
            Emission score tensor of size ``(seq_length, batch_size, num_tags)`` if
            ``batch_first`` is ``False``, ``(batch_size, seq_length, num_tags)`` otherwise.
        mask : :class:`~torch.ByteTensor`, optional
            Mask tensor of size ``(seq_length, batch_size)`` if ``batch_first`` is ``False``,
            ``(batch_size, seq_length)`` otherwise. ``uint8`` and ``bool`` are both accepted.
        Returns
        -------
        List[List[int]]
            List of list containing the best tag sequence for each batch.
        Raises
        ------
        ValueError
            If the shapes are inconsistent or the first timestep is masked out.
        """
        self._validate(emissions, mask=mask)
        if mask is None:
            mask = emissions.new_ones(emissions.shape[:2], dtype=torch.bool)
        else:
            mask = mask.bool()

        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            mask = mask.transpose(0, 1)

        return self._viterbi_decode(emissions, mask)

    def _validate(
            self,
            emissions: torch.Tensor,
            tags: Optional[torch.LongTensor] = None,
            mask: Optional[torch.ByteTensor] = None) -> None:
        """Check the shapes of the inputs (in the caller's layout) and raise ``ValueError`` on mismatch."""
        if emissions.dim() != 3:
            raise ValueError(f'emissions must have dimension of 3, got {emissions.dim()}')
        if emissions.size(2) != self.num_tags:
            raise ValueError(
                f'expected last dimension of emissions is {self.num_tags}, '
                f'got {emissions.size(2)}')

        if tags is not None:
            if emissions.shape[:2] != tags.shape:
                raise ValueError(
                    'the first two dimensions of emissions and tags must match, '
                    f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}')

        if mask is not None:
            if emissions.shape[:2] != mask.shape:
                raise ValueError(
                    'the first two dimensions of emissions and mask must match, '
                    f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}')
            # Every sequence must contain at least its first timestep.
            first_timestep = mask[:, 0] if self.batch_first else mask[0]
            if not first_timestep.all():
                raise ValueError('mask of the first timestep must all be on')

    def _compute_score(
            self, emissions: torch.Tensor, tags: torch.LongTensor,
            mask: torch.ByteTensor) -> torch.Tensor:
        """Return the unnormalised score of the given tag sequences, shape ``(batch_size,)``."""
        # emissions: (seq_length, batch_size, num_tags)
        # tags: (seq_length, batch_size)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and tags.dim() == 2
        assert emissions.shape[:2] == tags.shape
        assert emissions.size(2) == self.num_tags
        assert mask.shape == tags.shape
        assert mask[0].all()

        seq_length, batch_size = tags.shape
        mask = mask.float()
        # Index tensor created on the same device as the tags it is paired with.
        batch_idx = torch.arange(batch_size, device=tags.device)

        # Start transition score and first emission
        # shape: (batch_size,)
        score = self.start_transitions[tags[0]]
        score += emissions[0, batch_idx, tags[0]]

        for i in range(1, seq_length):
            # Transition score to next tag, only added if next timestep is valid (mask == 1)
            # shape: (batch_size,)
            score += self.transitions[tags[i - 1], tags[i]] * mask[i]

            # Emission score for next tag, only added if next timestep is valid (mask == 1)
            # shape: (batch_size,)
            score += emissions[i, batch_idx, tags[i]] * mask[i]

        # End transition score
        # shape: (batch_size,)
        seq_ends = mask.long().sum(dim=0) - 1
        # shape: (batch_size,)
        last_tags = tags[seq_ends, batch_idx]
        # shape: (batch_size,)
        score += self.end_transitions[last_tags]

        return score

    def _compute_normalizer(
            self, emissions: torch.Tensor, mask: torch.ByteTensor) -> torch.Tensor:
        """Return the log partition function (forward algorithm), shape ``(batch_size,)``."""
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)
        # torch.where below needs a boolean condition (uint8 is deprecated).
        mask = mask.bool()
        assert emissions.dim() == 3 and mask.dim() == 2
        assert emissions.shape[:2] == mask.shape
        assert emissions.size(2) == self.num_tags
        assert mask[0].all()

        seq_length = emissions.size(0)

        # Start transition score and first emission; score has size of
        # (batch_size, num_tags) where for each batch, the j-th column stores
        # the score that the first timestep has tag j
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]

        for i in range(1, seq_length):
            # Broadcast score for every possible next tag
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.unsqueeze(2)

            # Broadcast emission score for every possible current tag
            # shape: (batch_size, 1, num_tags)
            broadcast_emissions = emissions[i].unsqueeze(1)

            # Compute the score tensor of size (batch_size, num_tags, num_tags) where
            # for each sample, entry at row i and column j stores the sum of scores of all
            # possible tag sequences so far that end with transitioning from tag i to tag j
            # and emitting
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emissions

            # Sum over all possible current tags, but we're in score space, so a sum
            # becomes a log-sum-exp: for each sample, entry i stores the sum of scores of
            # all possible tag sequences so far, that end in tag i
            # shape: (batch_size, num_tags)
            next_score = torch.logsumexp(next_score, dim=1)

            # Set score to the next score if this timestep is valid (mask == 1)
            # shape: (batch_size, num_tags)
            score = torch.where(mask[i].unsqueeze(1), next_score, score)

        # End transition score
        # shape: (batch_size, num_tags)
        score += self.end_transitions

        # Sum (log-sum-exp) over all possible tags
        # shape: (batch_size,)
        return torch.logsumexp(score, dim=1)

    def _viterbi_decode(self, emissions: torch.FloatTensor,
                        mask: torch.ByteTensor) -> List[List[int]]:
        """Return the best tag-id sequence for every sample (time-major inputs).

        Each returned list has as many entries as the sample has unmasked timesteps.
        """
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)
        # torch.where below needs a boolean condition (uint8 is deprecated); the
        # conversion lives here as well because callers may invoke this method
        # directly with a uint8 mask.
        mask = mask.bool()
        assert emissions.dim() == 3 and mask.dim() == 2
        assert emissions.shape[:2] == mask.shape
        assert emissions.size(2) == self.num_tags
        assert mask[0].all()

        seq_length, batch_size = mask.shape

        # Start transition and first emission
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]
        history = []

        # score is a tensor of size (batch_size, num_tags) where for every batch,
        # value at column j stores the score of the best tag sequence so far that ends
        # with tag j
        # history saves where the best tags candidate transitioned from; this is used
        # when we trace back the best tag sequence

        # Viterbi algorithm recursive case: we compute the score of the best tag sequence
        # for every possible next tag
        for i in range(1, seq_length):
            # Broadcast viterbi score for every possible next tag
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.unsqueeze(2)

            # Broadcast emission score for every possible current tag
            # shape: (batch_size, 1, num_tags)
            broadcast_emission = emissions[i].unsqueeze(1)

            # Compute the score tensor of size (batch_size, num_tags, num_tags) where
            # for each sample, entry at row i and column j stores the score of the best
            # tag sequence so far that ends with transitioning from tag i to tag j and emitting
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emission

            # Find the maximum score over all possible current tag
            # shape: (batch_size, num_tags)
            next_score, indices = next_score.max(dim=1)

            # Set score to the next score if this timestep is valid (mask == 1)
            # and save the index that produces the next score
            # shape: (batch_size, num_tags)
            score = torch.where(mask[i].unsqueeze(1), next_score, score)
            history.append(indices)

        # End transition score
        # shape: (batch_size, num_tags)
        score += self.end_transitions

        # Now, compute the best path for each sample

        # Index of the last valid timestep of every sample, as plain Python ints
        # so they can be used directly as list-slice bounds below.
        seq_ends = (mask.long().sum(dim=0) - 1).tolist()
        best_tags_list = []

        for idx in range(batch_size):
            # Find the tag which maximizes the score at the last timestep; this is our best tag
            # for the last timestep
            _, best_last_tag = score[idx].max(dim=0)
            best_tags = [best_last_tag.item()]

            # We trace back where the best last tag comes from, append that to our best tag
            # sequence, and trace it back again, and so on
            for hist in reversed(history[:seq_ends[idx]]):
                best_last_tag = hist[idx][best_tags[-1]]
                best_tags.append(best_last_tag.item())

            # Reverse the order because we start from the last timestep
            best_tags.reverse()
            best_tags_list.append(best_tags)

        return best_tags_list

In [56]:
class Albert_CRF(nn.Module):
    """ALBERT encoder followed by a linear tag classifier and a CRF decoder.

    Only inference is implemented: ``forward`` returns the Viterbi-decoded tag
    ids instead of a loss.

    Arguments
    ---------
    pretrained_name : str, optional
        Name or path handed to ``AutoModel.from_pretrained``.
    tag_id_to_name : Mapping[int, str], optional
        Tag id -> tag name. Defaults to the notebook-level ``ner_id2tag``; its
        length determines the number of output tags.
    """

    def __init__(self, pretrained_name="voidful/albert_chinese_large", tag_id_to_name=None):

        super().__init__()

        if tag_id_to_name is None:
            # Fall back to the tag table loaded earlier in the notebook.
            tag_id_to_name = ner_id2tag
        num_tags = len(tag_id_to_name)

        self.albert = AutoModel.from_pretrained(pretrained_name, return_dict=True)

        self.classifier = nn.Linear(self.albert.config.hidden_size, num_tags)

        self.crf = CRF(num_tags, batch_first=True, tag_id_to_name=tag_id_to_name)

    def forward(self, input_seqs, input_mask):
        """Return the best tag-id sequence for every sample.

        Arguments
        ---------
        input_seqs : token ids, batch-first ``(batch_size, seq_length)``.
        input_mask : attention mask of the same shape; non-zero marks a valid token.

        Returns
        -------
        List[List[int]]
            One list of tag ids per sample, covering its unmasked positions.
        """
        encoded = self.albert(input_ids=input_seqs,
                              attention_mask=input_mask,)

        # emissions: (batch_size, seq_length, num_tags)
        emissions = self.classifier(encoded['last_hidden_state'])

        # The CRF was built with batch_first=True, so its public decode() does
        # the validation and the transposition to time-major itself; the mask
        # is passed as bool because uint8 conditions are deprecated in torch.where.
        return self.crf.decode(emissions, mask=input_mask.bool())
        
        

In [198]:
from collections import OrderedDict

exp_dir = Path("albert_ner/exp/")

# Checkpoint actually loaded below.  For provenance, the path previously
# assigned here (and immediately overwritten) was:
#   exp_dir / 'albert_ner_len256_batch_20_2020-10-05 10:41:11/epoch_2_1.ckpt'
model_path = "albert_large.ckpt"

# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code — only load checkpoints from a trusted source.
ckpt = torch.load(model_path, map_location='cpu')
state_dict = ckpt['net']

# Weights saved from a DataParallel-wrapped model carry a `module.` prefix on
# every key.  Strip it only when it is really there: slicing 7 characters
# unconditionally would mangle the keys of a checkpoint saved without the wrapper.
parallel_prefix = 'module.'
new_state_dict = OrderedDict(
    (k[len(parallel_prefix):] if k.startswith(parallel_prefix) else k, v)
    for k, v in state_dict.items()
)


model = Albert_CRF()

model.load_state_dict(new_state_dict)

# Inference only: evaluation mode and no gradients for any parameter.
model = model.eval()

for parms in model.parameters():
    parms.requires_grad = False

print('total parms : ', sum(p.numel() for p in model.parameters()))
print('trainable parms : ', sum(p.numel() for p in model.parameters() if p.requires_grad))


total parms :  16825630
trainable parms :  0


In [199]:
# Bare expression: displays the module tree of the loaded network as the cell output.
model

Albert_CRF(
  (albert): AlbertModel(
    (embeddings): AlbertEmbeddings(
      (word_embeddings): Embedding(21128, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0, inplace=False)
    )
    (encoder): AlbertTransformer(
      (embedding_hidden_mapping_in): Linear(in_features=128, out_features=1024, bias=True)
      (albert_layer_groups): ModuleList(
        (0): AlbertLayerGroup(
          (albert_layers): ModuleList(
            (0): AlbertLayer(
              (full_layer_layer_norm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
              (attention): AlbertAttention(
                (query): Linear(in_features=1024, out_features=1024, bias=True)
                (key): Linear(in_features=1024, out_features=1024, bias=True)
                (value): Linear(in_features=1024, out_features=1024, bias=True)
 

In [200]:
# Example sentence to tag, wrapped in the [CLS] / [SEP] markers.  Adjacent
# string literals are concatenated by the parser, so `sent` is assigned once.
sent = (
    '[CLS] '
    '高加索地區的亞塞拜然與亞美尼亞兩國，9月27日為了主權爭議的納哥諾卡拉巴克地區（簡稱納卡）爆發衝突，雙方坦克、大砲與無人機齊發，連番交火，已知有250人喪命。'
    '10月5日再度互控對方攻擊平民區，亞塞拜然境內第二大城甘賈（Ganja）、亞美尼亞控制的納卡地區首府史提帕納科特（Stepanakert）分別遭砲火攻擊。'
    ' [SEP]'
)

In [201]:
# Split the sentence into word pieces and look up their vocabulary ids.
token = tokenizer.tokenize(sent)
ids = tokenizer.convert_tokens_to_ids(token)

print(len(token))

# 1-D model inputs (the batch dimension is added at call time): the id
# sequence, and a mask of ones because a single sentence has no padding.
input_seq = torch.tensor(ids, dtype=torch.long)
input_mask = torch.ones(len(ids))

145


In [202]:
# Run the model on a batch of one sentence, keep the only decoded path and
# translate its tag ids into tag names.
best_tags_list = [
    ner_id2tag[tag_id]
    for tag_id in model(input_seq.unsqueeze(0), input_mask.unsqueeze(0))[0]
]

In [203]:
def _join_word_pieces(pieces):
    """Join the tokens of one multi-token word into its display form.

    WordPiece continuations ('##xx') are glued to the previous piece.  Words made
    only of ASCII letters keep single spaces between their tokens; everything
    else (e.g. Chinese characters, digits) is concatenated without spaces.
    """
    merged = ' '.join(pieces).replace(' ##', '')
    if not merged.replace(' ', '').encode().isalpha():
        merged = merged.replace(' ', '')
    return merged


def merge_tagged_tokens(tokens, tag_names):
    """Group per-token boundary tags into words with one tag each.

    Arguments
    ---------
    tokens : sequence of str
        Tokens, without the [CLS] / [SEP] markers.
    tag_names : sequence of str
        One tag per token, ``BOUND-POS`` or ``BOUND-X-POS`` with BOUND in
        B / I / E / S; for three-part tags only the last part is kept as POS.

    Returns
    -------
    (list of str, list of str)
        Words and their tags; both lists always have the same length.

    Malformed sequences are handled on a best-effort basis instead of
    desynchronising the two lists: a tag without '-' is treated as a
    single-token word tagged with the whole name, a word still open when a new
    B/S arrives (or at the end of the input) is closed as is, and an I/E with
    no open word starts a new word.
    """
    words, word_tags = [], []
    pending = []  # tokens of the word currently being assembled

    def close_pending():
        if pending:
            words.append(_join_word_pieces(pending))
            pending.clear()

    for piece, tag_name in zip(tokens, tag_names):
        parts = tag_name.split('-')
        if len(parts) == 1:
            bound, pos = 'S', parts[0]
        else:
            bound, pos = parts[0], parts[-1]

        if bound == 'S':
            close_pending()
            words.append(piece)
            word_tags.append(pos)
            continue

        if bound == 'B':
            close_pending()

        if not pending:
            # First token of a word: a regular B, or a dangling I/E.
            word_tags.append(pos)
        pending.append(piece)

        if bound == 'E':
            close_pending()

    close_pending()  # special case: no 'E-' closing the last word
    return words, word_tags


if len(best_tags_list) != len(token):
    print('something error')

# Drop the leading [CLS] and trailing [SEP] positions before merging.
toks, tags = merge_tagged_tokens(token[1:-1], best_tags_list[1:-1])

if len(toks) != len(tags):
    print('something error')

In [204]:
# Echo every (word, tag) pair, one per line.
for tok, tag in zip(toks, tags):
    print(tok, tag)

# Single-line rendering: each word is followed by its tag in braces and
# three spaces.
output = ''.join(
    word + ' {' + word_tag + '}' + '   '
    for word, word_tag in zip(toks, tags)
)

高加索 LOC
地區 Nc
的 DE
亞塞拜然 LOC
與 Caa
亞美尼亞 LOC
兩 Neu
國 Nc
， COMMACATEGORY
9月 Nd
27日 Nd
為了 P
主權 Na
爭議 Na
的 DE
納哥諾卡拉巴克 LOC
地區 Nc
（ PARENTHESISCATEGORY
簡稱 VG
納卡 Nb
） PARENTHESISCATEGORY
爆發 VJ
衝突 Na
， COMMACATEGORY
雙方 Nh
坦克 Na
、 PAUSECATEGORY
大砲 Na
與 Caa
無人機 Na
齊發 VH
， COMMACATEGORY
連番 D
交火 VA
， COMMACATEGORY
已 D
知 VK
有 V_2
250 Neu
人 Na
喪命 VH
。 PERIODCATEGORY
10月 Nd
5日 Nd
再度 D
互控 VC
對方 Nh
攻擊 VC
平民區 Nc
， COMMACATEGORY
亞塞拜然 LOC
境 Na
內 Ncd
第二 Neu
大城 Na
甘賈 PER
（ PARENTHESISCATEGORY
ganja  PER
） PARENTHESISCATEGORY
、 PAUSECATEGORY
亞美尼亞 LOC
控制 VC
的 DE
納卡 LOC
地區 Nc
首府 Nc
史提帕納科特 PER
（ PARENTHESISCATEGORY
stepanakert  FW
） PARENTHESISCATEGORY
分別 D
遭 P
砲火 Na
攻擊 Nv
。 PERIODCATEGORY
