In [10]:
from collections import defaultdict

import numpy as np
import pandas as pd
import jieba

from scipy.sparse import csr_matrix, csr_array

In [11]:
import numpy as np

def viterbi(y, A, B, Pi=None):
    """
    Return the MAP estimate of state trajectory of Hidden Markov Model.

    Parameters
    ----------
    y : array (T,)
        Observation state sequence. int dtype.
    A : array (K, K)
        State transition matrix. See HiddenMarkovModel.state_transition  for
        details.
    B : array (K, M)
        Emission matrix. See HiddenMarkovModel.emission for details.
    Pi: optional, (K,)
        Initial state probabilities: Pi[i] is the probability x[0] == i. If
        None, uniform initial distribution is assumed (Pi[:] == 1/K).

    Returns
    -------
    x : array (T,)
        Maximum a posteriori probability estimate of hidden state trajectory,
        conditioned on observation sequence y under the model parameters A, B,
        Pi.
    T1: array (K, T)
        the probability of the most likely path so far
    T2: array (K, T)
        the x_j-1 of the most likely path so far
    """
    # Cardinality of the state space
    K = A.shape[0]
    # Initialize the priors with default (uniform dist) if not given by caller
    Pi = Pi if Pi is not None else np.full(K, 1 / K)
    T = len(y)
    T1 = np.empty((K, T), 'd')
    T2 = np.empty((K, T), 'B')

    # Initilaize the tracking tables from first observation
    T1[:, 0] = Pi * B[:, y[0]]
    T2[:, 0] = 0

    # Iterate throught the observations updating the tracking tables
    for i in range(1, T):
        T1[:, i] = np.max(T1[:, i - 1] * A.T * B[np.newaxis, :, y[i]].T, 1)
        T2[:, i] = np.argmax(T1[:, i - 1] * A.T, 1)

    # Build the output, optimal model trajectory
    x = np.empty(T, 'B')
    x[-1] = np.argmax(T1[:, T - 1])
    for i in reversed(range(1, T)):
        x[i - 1] = T2[x[i], i]

    return x, T1, T2

In [12]:
df = pd.read_csv("./ok_data_level3-4/ok_data_level3.csv").iloc[:-3]

In [13]:
states = ["s", "m", "e", "suf"]

In [18]:
df.head(20)

Unnamed: 0,id,pid,deep,name,pinyin_prefix,pinyin,ext_id,ext_name
0,11,0,0,北京,b,bei jing,110000000000,北京市
1,1101,11,1,北京,b,bei jing,110100000000,北京市
2,110101,1101,2,东城,d,dong cheng,110101000000,东城区
3,110102,1101,2,西城,x,xi cheng,110102000000,西城区
4,110105,1101,2,朝阳,c,chao yang,110105000000,朝阳区
5,110106,1101,2,丰台,f,feng tai,110106000000,丰台区
6,110107,1101,2,石景山,s,shi jing shan,110107000000,石景山区
7,110108,1101,2,海淀,h,hai dian,110108000000,海淀区
8,110109,1101,2,门头沟,m,men tou gou,110109000000,门头沟区
9,110111,1101,2,房山,f,fang shan,110111000000,房山区


In [20]:
df["suffix"] = df.apply(lambda row: row["ext_name"].split(row["name"])[1], axis=1)

# HMM

In [87]:
emission_freq = defaultdict(int)
transition_matrix = np.zeros((len(states), len(states) + 1))

token_dict = dict()
token_increment = 0

for name in df["name"]:
    for i, token in enumerate(name):
        if i == 0:
            prev_state = None
            state = "s"
        elif i == len(name) - 1:
            prev_state = "m" if len(name) > 2 else "s"
            state = "e"
        else:
            prev_state = "s" if i == 1 else "m"
            state = "m"

        state_index = states.index(state)

        if prev_state:
            prev_state_index = states.index(prev_state)
            transition_matrix[prev_state_index, state_index] += 1

        if token in token_dict:
            token_index = token_dict[token]
        else:
            token_increment += 1
            token_dict[token] = token_increment
            token_index = token_increment

        emission_freq[(state_index, token_index)] += 1

for suffix in df["suffix"]:
    for i, token in enumerate(suffix):
        state = "suf"
        state_index = states.index(state)
        
        if token in token_dict:
            token_index = token_dict[token]
        else:
            token_increment += 1
            token_dict[token] = token_increment
            token_index = token_increment

        emission_freq[(state_index, token_index)] += 1

transition_matrix[2, 3] = df.shape[0]
transition_matrix[2, 0] = 0.5 * df.shape[0]
transition_matrix[2, 4] = 0.5 * df.shape[0]
transition_matrix[3, 0] = df.shape[0]
transition_matrix[3, 3] = df.shape[0]
transition_matrix[3, 4] = df.shape[0]

transition_matrix = transition_matrix / transition_matrix.sum(axis=1, keepdims=True)
transition_matrix = transition_matrix[:, :-1]

initial_probs = np.array([1, 0, 0, 0])

In [88]:
state_indices = []
token_indices = []
frequencies = []

for (si, ti), freq in emission_freq.items():
    state_indices.append(si)
    token_indices.append(ti)
    frequencies.append(freq)

emission_matrix = csr_matrix((frequencies, (state_indices, token_indices))).toarray()

In [89]:
def infer(text, token_dict, transition_matrix, emission_matrix, initial_probs=None):
    indices = []
    for c in text:
        indices.append(token_dict[c])
    
    return viterbi(indices, A=transition_matrix, B=emission_matrix, Pi=initial_probs)
        

In [99]:
infer("广西市海淀土家族苗族自治县", token_dict, transition_matrix, emission_matrix, initial_probs)

(array([0, 2, 3, 0, 2, 3, 3, 3, 3, 3, 3, 3, 3], dtype=uint8),
 array([[2.10000000e+01, 0.00000000e+00, 1.10984191e+03, 2.83724918e+06,
         0.00000000e+00, 2.56320558e+06, 0.00000000e+00, 0.00000000e+00,
         2.48061340e+09, 0.00000000e+00, 1.55121025e+12, 4.03314664e+13,
         0.00000000e+00],
        [0.00000000e+00, 4.05668685e+00, 0.00000000e+00, 2.14394337e+02,
         0.00000000e+00, 0.00000000e+00, 2.72331776e+06, 3.78632283e+06,
         0.00000000e+00, 9.58387785e+08, 1.66559959e+09, 7.49139786e+11,
         0.00000000e+00],
        [0.00000000e+00, 7.39894606e+02, 7.67528167e+01, 2.90766975e+04,
         2.56320558e+06, 0.00000000e+00, 4.63126247e+06, 0.00000000e+00,
         0.00000000e+00, 0.00000000e+00, 6.25267867e+08, 4.20414686e+12,
         4.62736432e+15],
        [0.00000000e+00, 0.00000000e+00, 2.65992111e+05, 0.00000000e+00,
         0.00000000e+00, 1.66608363e+07, 5.55361209e+07, 3.72092010e+09,
         3.47285876e+10, 2.32681537e+12, 1.20994399e+14, 

# CRF

In [6]:
from itertools import chain

# import nltk
import sklearn
import scipy.stats

import sklearn_crfsuite
from sklearn_crfsuite import scorers,CRF
from sklearn_crfsuite.metrics import flat_classification_report
from sklearn_crfsuite import metrics

In [24]:
sentences = []

current_area_1 = None
current_area_2 = None
for _, row in df.iterrows():
    level = row["deep"]
    location_start = row["name"][0]
    location_middle = row["name"][1:]
    location_suffix = row["suffix"]
    location_key = (location_start, location_middle, location_suffix)

    if level == 0:
        current_area_2 = None
        current_area_1 = location_key
        sentences.append([location_key])
        continue

    if level == 1:
        current_area_2 = location_key
        sentences.append([current_area_1, location_key])
        sentences.append([location_key])
        continue

    sentences.append([current_area_1, current_area_2, location_key])


        
    

In [None]:
# Feature set
def word2features(sent, i):
    word = sent[i][0]
    postag = sent[i][1]

    features = {
        'bias': 1.0,
        'word.lower()': word.lower(),
        'word[-3:]': word[-3:],
        'word[-2:]': word[-2:],
        'word.isupper()': word.isupper(),
        'word.istitle()': word.istitle(),
        'word.isdigit()': word.isdigit(),
        'postag': postag,
        'postag[:2]': postag[:2],
    }
    if i > 0:
        word1 = sent[i-1][0]
        postag1 = sent[i-1][1]
        features.update({
            '-1:word.lower()': word1.lower(),
            '-1:word.istitle()': word1.istitle(),
            '-1:word.isupper()': word1.isupper(),
            '-1:postag': postag1,
            '-1:postag[:2]': postag1[:2],
        })
    else:
        features['BOS'] = True

    if i < len(sent)-1:
        word1 = sent[i+1][0]
        postag1 = sent[i+1][1]
        features.update({
            '+1:word.lower()': word1.lower(),
            '+1:word.istitle()': word1.istitle(),
            '+1:word.isupper()': word1.isupper(),
            '+1:postag': postag1,
            '+1:postag[:2]': postag1[:2],
        })
    else:
        features['EOS'] = True

    return features

# CRF torch

In [267]:
# Author: Robert Guthrie
from itertools import product
from tqdm import tqdm

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, RandomSampler, random_split
from torch.nn.utils.rnn import pad_sequence

import pandas as pd
import numpy as np
from datasets import load_metric

from bi_lstm_crf.model import BiRnnCrf

torch.manual_seed(1234)

<torch._C.Generator at 0x29bc41d1450>

In [129]:
torch.randn(2, 1, 4).shape

torch.Size([2, 1, 4])

In [185]:
test_feats = torch.randn(32, 12)
test_index_1 = torch.arange(32)
test_index_2 = torch.ones(32, dtype=torch.long)

In [187]:
test_feats[test_index_1, test_index_2].shape

torch.Size([32])

In [215]:
def argmax(vec):
    # return the argmax as a python int
    _, idx = torch.max(vec, 1)
    return idx.item()


def prepare_sequence(seq, to_ix):
    idxs = [to_ix[w] for w in seq]
    return torch.tensor(idxs, dtype=torch.long)


# Compute log sum exp in a numerically stable way for the forward algorithm
def log_sum_exp(vec):
    max_score = vec[0, argmax(vec)]
    max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])
    return max_score + \
        torch.log(torch.sum(torch.exp(vec - max_score_broadcast)))

class BiLSTM_CRF(nn.Module):

    def __init__(self, vocab_size, tag_to_ix, batch_size, embedding_dim, hidden_dim):
        super(BiLSTM_CRF, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.tag_to_ix = tag_to_ix
        self.batch_size = batch_size
        self.tagset_size = len(tag_to_ix)

        self.word_embeds = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2,
                            num_layers=1, bidirectional=True)

        # Maps the output of the LSTM into tag space.
        self.hidden2tag = nn.Linear(hidden_dim, self.tagset_size)

        # Matrix of transition parameters.  Entry i,j is the score of
        # transitioning *to* i *from* j.
        self.transitions = nn.Parameter(
            torch.randn(self.tagset_size, self.tagset_size))

        # These two statements enforce the constraint that we never transfer
        # to the start tag and we never transfer from the stop tag
        self.transitions.data[tag_to_ix[START_TAG], :] = -10000
        self.transitions.data[:, tag_to_ix[STOP_TAG]] = -10000

        self.hidden = self.init_hidden()

    def init_hidden(self):
        return (torch.randn(2, self.batch_size, self.hidden_dim // 2),
                torch.randn(2, self.batch_size, self.hidden_dim // 2))

    def _forward_alg(self, feats):
        # Do the forward algorithm to compute the partition function
        init_alphas = torch.full((1, self.tagset_size), -10000.)
        # START_TAG has all of the score.
        init_alphas[0][self.tag_to_ix[START_TAG]] = 0.

        # Wrap in a variable so that we will get automatic backprop
        forward_var = init_alphas

        # Iterate through the sentence
        for feat in feats:
            alphas_t = []  # The forward tensors at this timestep
            for next_tag in range(self.tagset_size):
                # broadcast the emission score: it is the same regardless of
                # the previous tag
                emit_score = feat[next_tag].view(
                    1, -1).expand(1, self.tagset_size)
                # the ith entry of trans_score is the score of transitioning to
                # next_tag from i
                trans_score = self.transitions[next_tag].view(1, -1)
                # The ith entry of next_tag_var is the value for the
                # edge (i -> next_tag) before we do log-sum-exp
                next_tag_var = forward_var + trans_score + emit_score
                # The forward variable for this tag is log-sum-exp of all the
                # scores.
                alphas_t.append(log_sum_exp(next_tag_var).view(1))
            forward_var = torch.cat(alphas_t).view(1, -1)
        terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]]
        alpha = log_sum_exp(terminal_var)
        return alpha

    def _get_lstm_features(self, sentence):
        self.hidden = self.init_hidden()
        embeds = self.word_embeds(sentence)#.view(len(sentence), 1, -1)
        embeds = torch.transpose(embeds, 0, 1)
        # print(torch.transpose(self.word_embeds(sentence), 0, 1).shape)
        lstm_out, self.hidden = self.lstm(embeds, self.hidden)  
        lstm_out = lstm_out.view(-1, self.batch_size, self.hidden_dim)
        lstm_feats = self.hidden2tag(lstm_out)
        return lstm_feats

    def _score_sentence(self, feats, tags):
        sequence_length = tags.shape[0]
        # Gives the score of a provided tag sequence
        score = torch.zeros(1)
        tags = torch.cat(
            [
                torch.full((1, self.batch_size), self.tag_to_ix[START_TAG], dtype=torch.long), 
                tags
            ], 
            dim=0
        )
        
        # tags = torch.cat([torch.tensor([self.tag_to_ix[START_TAG]], dtype=torch.long), tags])
        # for i, feat in enumerate(feats):
        for i in range(sequence_length):
            print("score shape", score.shape)
            print("transition shape", self.transitions[tags[i + 1], tags[i]].shape)
            print("likelihood shape", feats[i, torch.arange(self.batch_size), tags[i + 1]].shape)
            score = score + self.transitions[tags[i + 1], tags[i]] + feats[i, torch.arange(self.batch_size), tags[i + 1]]
        score = score + self.transitions[self.tag_to_ix[STOP_TAG], tags[-1]]
        return score

    def _viterbi_decode(self, feats):
        backpointers = []

        # Initialize the viterbi variables in log space
        init_vvars = torch.full((1, self.tagset_size), -10000.)
        init_vvars[0][self.tag_to_ix[START_TAG]] = 0
        # print("F", feats.shape)
        # forward_var at step i holds the viterbi variables for step i-1
        forward_var = init_vvars
        for feat in feats:
            # print("feat", feat.shape)
            bptrs_t = []  # holds the backpointers for this step
            viterbivars_t = []  # holds the viterbi variables for this step

            for next_tag in range(self.tagset_size):
                # next_tag_var[i] holds the viterbi variable for tag i at the
                # previous step, plus the score of transitioning
                # from tag i to next_tag.
                # We don't include the emission scores here because the max
                # does not depend on them (we add them in below)
                # print(forward_var.shape)
                # print(self.transitions[next_tag].shape)
                next_tag_var = forward_var + self.transitions[next_tag]
                best_tag_id = argmax(next_tag_var)
                bptrs_t.append(best_tag_id)
                viterbivars_t.append(next_tag_var[0][best_tag_id].view(1))
            # Now add in the emission scores, and assign forward_var to the set
            # of viterbi variables we just computed
            forward_var = (torch.cat(viterbivars_t) + feat).view(1, -1)
            backpointers.append(bptrs_t)

        # Transition to STOP_TAG
        terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]]
        best_tag_id = argmax(terminal_var)
        path_score = terminal_var[0][best_tag_id]

        # Follow the back pointers to decode the best path.
        best_path = [best_tag_id]
        for bptrs_t in reversed(backpointers):
            best_tag_id = bptrs_t[best_tag_id]
            best_path.append(best_tag_id)
        # Pop off the start tag (we dont want to return that to the caller)
        start = best_path.pop()
        assert start == self.tag_to_ix[START_TAG]  # Sanity check
        best_path.reverse()
        return path_score, best_path

    def neg_log_likelihood(self, sentence, tags):
        feats = self._get_lstm_features(sentence)
        forward_score = self._forward_alg(feats)
        tags_batch_second = torch.transpose(tags, 0, 1)
        gold_score = self._score_sentence(feats, tags_batch_second)
        return forward_score - gold_score

    def forward(self, sentence):  # dont confuse this with _forward_alg above.
        # Get the emission scores from the BiLSTM
        lstm_feats = self._get_lstm_features(sentence)

        # Find the best path, given the features.
        score, tag_seq = self._viterbi_decode(lstm_feats)
        return score, tag_seq

In [406]:

current_area_1, current_area_2, current_area_3 = None, None, None
combinations = []
for _, row in df.iterrows():
    level = row["deep"]
    lid = row["id"]

    if level == 0:
        current_area_1 = lid
        current_area_2 = np.nan
        current_area_3 = np.nan

    elif level == 1:
        current_area_2 = lid
        current_area_3 = np.nan
    else:
        current_area_3 = lid

    combinations.append([current_area_1, current_area_2, current_area_3])

    if not pd.isna(current_area_2):
        combinations.append([np.nan, current_area_2, current_area_3])

    if not pd.isna(current_area_3):
        combinations.append([np.nan, np.nan, current_area_3])


In [407]:
df_combinations = pd.DataFrame(combinations, columns=["id1", "id2", "id3"])

df_valid_1 = df_combinations.merge(df[["id", "suffix"]], left_on="id1", right_on="id", how="left")
df_valid_2 = df_combinations.merge(df[["id", "suffix"]], left_on="id2", right_on="id", how="left")
df_valid_3 = df_combinations.merge(df[["id", "suffix"]], left_on="id3", right_on="id", how="left")

df_valid_1["is_valid"] = df_valid_1["suffix"].str.len() != 0
df_valid_2["is_valid"] = df_valid_2["suffix"].str.len() != 0
df_valid_3["is_valid"] = df_valid_3["suffix"].str.len() != 0

df_combinations["is_valid"] = df_valid_1["is_valid"] & df_valid_2["is_valid"] & df_valid_3["is_valid"]

In [408]:

def extract_tags_from_df(dataframe, lid):
    if pd.isna(lid):
        return []

    row = dataframe[dataframe["id"] == lid].iloc[0]

    level = row["deep"]
    location_start = row["name"][0]
    location_middle = row["name"][1:]
    location_suffix_list = list(row["suffix"]) if row["suffix"] else []
    location_suffix_tags = [f"E-{level}"] * len(row["suffix"]) if row["suffix"] else []

    location_short_strs = list(location_start) + list(location_middle)
    location_strs = list(location_start) + list(location_middle) + location_suffix_list

    location_short_tags = [f"B-{level}"] * len(location_start) + [f"I-{level}"] * len(location_middle)
    location_tags = [f"B-{level}"] * len(location_start) + [f"I-{level}"] * len(location_middle) + location_suffix_tags

    location_key = [(location_strs, location_tags), (location_short_strs, location_short_tags)]

    return [location_key]

def extract_sentence_from_row(dataframe, id_1, id_2, id_3):
    location_key_1 = extract_tags_from_df(dataframe, id_1)
    location_key_2 = extract_tags_from_df(dataframe, id_2)
    location_key_3 = extract_tags_from_df(dataframe, id_3)

    stack = location_key_3

    diff_name_23_condition = not (location_key_2 and location_key_3 and location_key_2[0][-1][0] == location_key_3[0][-1][0])
    diff_name_12_condition = not (location_key_1 and location_key_2 and location_key_1[0][-1][0] == location_key_2[0][-1][0])
    
    if diff_name_23_condition:
        stack = location_key_2 + stack

    if diff_name_12_condition and diff_name_23_condition:
        stack = location_key_1 + stack

    combination = product(*stack)

    sentences = []
    for comb in combination:
        strs = []
        tags = []
        for ele in comb:
            strs.extend(ele[0])
            tags.extend(ele[1])

        sentences.append((strs, tags))

    return sentences

def extract_sentences(df_ids, df_ref):
    sentences = []
    for _, row in df_ids.iterrows():
        sents = extract_sentence_from_row(df_ref, row["id1"], row["id2"], row["id3"])
        sentences.extend(sents)
    return sentences

In [409]:

df_train = df_combinations[df_combinations["is_valid"]].sample(frac=0.8)
df_test = df_combinations[df_combinations["is_valid"]].drop(df_train.index)
# df_test = pd.concat([df_test, df_combinations[~df_combinations["is_valid"]]], axis=0)

train_sentences = extract_sentences(df_train, df)
test_sentences = extract_sentences(df_test, df)

START_TAG = "<START>"
STOP_TAG = "<STOP>"
UNKNOWN_WORD = "<UNK>"

WORD_PAD_IDX = 0
EMBEDDING_DIM = 12
HIDDEN_DIM = 8

word_to_ix = {UNKNOWN_WORD: 0}

for sentence, tags in train_sentences:
    for word in sentence:
        if word not in word_to_ix:
            word_to_ix[word] = len(word_to_ix)

for sentence, tags in test_sentences:
    for word in sentence:
        if word not in word_to_ix:
            word_to_ix[word] = len(word_to_ix)
ix_to_word = {v: k for k, v in word_to_ix.items()} 

tag_to_ix = {"B-0": 0, "I-0": 1, "E-0": 2, "B-1": 3, "I-1": 4, "E-1": 5, "B-2": 6, "I-2": 7, "E-2": 8}
ix_to_tag = {v: k for k, v in tag_to_ix.items()}

# for future work involving whole sentence only, otherwise the UNKNOWN not likely useful.
UNKNOWN_TAG = "O"
UNKNOWN_TAG_IDX = len(tag_to_ix)
tag_to_ix[UNKNOWN_TAG] = UNKNOWN_IDX

# padding tag. Exact value doesn't matter. 
TAG_PAD_IDX = UNKNOWN_TAG_IDX


In [410]:
# model = BiLSTM_CRF(len(word_to_ix), tag_to_ix, 32, EMBEDDING_DIM, HIDDEN_DIM)
model = BiRnnCrf(
    len(word_to_ix), len(tag_to_ix),
    embedding_dim=EMBEDDING_DIM, hidden_dim=HIDDEN_DIM, num_rnn_layers=1)

In [411]:
class NERDataSet(Dataset):
    def __init__(self, sentences, word_to_ix, tag_to_ix):
        self.strings, self.tags = zip(*sentences)
        self.word_to_ix = word_to_ix
        self.tag_to_tx = tag_to_ix

    def __len__(self):
        return len(self.strings)

    def __getitem__(self, item):
        string_lst = self.strings[item]
        tag_lst = self.tags[item]

        word_ids = [self.word_to_ix.get(w, WORD_PAD_IDX) for w in string_lst]
        tag_ids = [self.tag_to_tx.get(t, TAG_PAD_IDX) for t in tag_lst]
        return word_ids, tag_ids

In [412]:
def custom_collate(indices):
    word_ids, tag_ids = zip(*indices)
    sequence_length = len(max(word_ids, key=len))

    padded_word_ids = [ids + [WORD_PAD_IDX] * (sequence_length - len(ids)) for ids in word_ids]
    padded_tag_ids = [ids + [TAG_PAD_IDX] * (sequence_length - len(ids)) for ids in tag_ids]

    batch = {"input_ids": torch.tensor(padded_word_ids, dtype=torch.int64), "labels": torch.tensor(padded_tag_ids, dtype=torch.int64)}
    return batch

In [413]:
train_dataset = NERDataSet(train_sentences, word_to_ix, tag_to_ix)
test_dataset = NERDataSet(test_sentences, word_to_ix, tag_to_ix)

In [414]:
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=32, collate_fn=custom_collate)
test_loader = DataLoader(test_dataset, shuffle=True, batch_size=32, collate_fn=custom_collate)

In [415]:
metric = load_metric("seqeval")

In [416]:
def __eval_model(model, device, dataloader, desc):
    model.eval()
    with torch.no_grad():
        # eval
        losses = []
        nums = []

        all_predictions = []
        all_labels = []
        for batch in tqdm(dataloader, desc=desc):
            xb = batch["input_ids"].to(device)
            yb = batch["labels"].to(device)

            predictions = model(xb)[1]
            
            current_loss = model.loss(xb, yb)
            current_num = len(xb)

            preds_list = [[ix_to_tag[p] for p in pred if p in ix_to_tag] for pred in predictions]
            labels_list = [[ix_to_tag[l] for l in label if l in ix_to_tag] for label in yb.tolist()]
            
            all_predictions.extend(preds_list)
            all_labels.extend(labels_list)

        metric_result = metric.compute(predictions=all_predictions, references=all_labels)
        return np.sum(np.multiply(losses, nums)) / np.sum(nums), metric_result

In [None]:
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0)

device = "cpu"
model.to(device)

val_loss = 0
best_val_loss = 1e4

losses = []
for epoch in range(20):
    # train
    model.train()
    bar = tqdm(train_loader)
    for bi, batch in enumerate(bar):
        xb = batch["input_ids"]
        yb = batch["labels"]

        model.zero_grad()

        loss = model.loss(xb, yb)
        loss.backward()
        optimizer.step()
        
        bar.set_description("{:2d}/{} loss: {:5.2f}, val_loss: {:5.2f}".format(
            epoch+1, 20, loss, val_loss))
        losses.append([epoch, bi, loss.item(), np.nan])

    # evaluation
    val_loss, metric_result = __eval_model(model, device, dataloader=test_loader, desc="eval")
    print("precision", metric_result["overall_precision"])
    print("recall", metric_result["overall_recall"])
    print("f1", metric_result["overall_f1"])
    print("accuracy", metric_result["overall_accuracy"])
    # save losses
    losses[-1][-1] = val_loss.item()
    # __save_loss(losses, loss_path)

    # # save model
    # if not args.save_best_val_model or val_loss < best_val_loss:
    #     best_val_loss = val_loss
    #     __save_model(args.model_dir, model)
    #     print("save model(epoch: {}) => {}".format(epoch, loss_path))

 1/20 loss:  2.14, val_loss:  0.00: 100%|██████████████████████████████████████████| 1062/1062 [00:40<00:00, 25.91it/s]
eval: 100%|██████████████████████████████████████████████████████████████████████████| 270/270 [00:05<00:00, 50.83it/s]
  return np.sum(np.multiply(losses, nums)) / np.sum(nums), metric_result


precision 0.791603401153963
recall 0.8654907657190288
f1 0.8268998195840521
accuracy 0.9077858023604016


 2/20 loss:  0.79, val_loss:   nan: 100%|██████████████████████████████████████████| 1062/1062 [00:39<00:00, 26.57it/s]
eval: 100%|██████████████████████████████████████████████████████████████████████████| 270/270 [00:05<00:00, 53.08it/s]


precision 0.9069890288500609
recall 0.9263747665490766
f1 0.9165794066317626
accuracy 0.9565263343315131


 3/20 loss:  0.14, val_loss:   nan: 100%|██████████████████████████████████████████| 1062/1062 [00:40<00:00, 26.17it/s]
eval: 100%|██████████████████████████████████████████████████████████████████████████| 270/270 [00:08<00:00, 31.50it/s]


precision 0.9348610314011642
recall 0.9464619215604897
f1 0.9406257089236733
accuracy 0.9687334859961247


 4/20 loss:  0.07, val_loss:   nan: 100%|██████████████████████████████████████████| 1062/1062 [00:46<00:00, 22.75it/s]
eval: 100%|██████████████████████████████████████████████████████████████████████████| 270/270 [00:04<00:00, 58.89it/s]


precision 0.956555756876956
recall 0.9640589333886699
f1 0.9602926889766221
accuracy 0.9768715871058657


 5/20 loss:  0.09, val_loss:   nan: 100%|██████████████████████████████████████████| 1062/1062 [00:55<00:00, 19.24it/s]
eval: 100%|██████████████████████████████████████████████████████████████████████████| 270/270 [00:03<00:00, 70.21it/s]


precision 0.967871319880913
recall 0.9714463581655945
f1 0.9696555438182233
accuracy 0.9814162409723446


 6/20 loss:  0.04, val_loss:   nan: 100%|██████████████████████████████████████████| 1062/1062 [00:40<00:00, 26.30it/s]
eval: 100%|██████████████████████████████████████████████████████████████████████████| 270/270 [00:09<00:00, 27.57it/s]


precision 0.9739415030242771
recall 0.9756796015770907
f1 0.9748097775382001
accuracy 0.9838647172802537


 7/20 loss:  0.09, val_loss:   nan: 100%|██████████████████████████████████████████| 1062/1062 [00:45<00:00, 23.29it/s]
eval: 100%|██████████████████████████████████████████████████████████████████████████| 270/270 [00:04<00:00, 60.81it/s]


precision 0.9786546193061715
recall 0.9799543473749741
f1 0.9793040520924059
accuracy 0.98615465915096


 8/20 loss:  0.28, val_loss:   nan: 100%|██████████████████████████████████████████| 1062/1062 [01:02<00:00, 16.88it/s]
eval: 100%|██████████████████████████████████████████████████████████████████████████| 270/270 [00:06<00:00, 42.23it/s]


precision 0.9798763536782706
recall 0.9801203569205229
f1 0.9799983401112127
accuracy 0.9866831072749692


 9/20 loss:  0.03, val_loss:   nan: 100%|██████████████████████████████████████████| 1062/1062 [01:21<00:00, 13.03it/s]
eval: 100%|██████████████████████████████████████████████████████████████████████████| 270/270 [00:05<00:00, 52.09it/s]


precision 0.9813708405941416
recall 0.98165594521685
f1 0.9815133722015893
accuracy 0.9870706358992426


10/20 loss:  0.03, val_loss:   nan: 100%|██████████████████████████████████████████| 1062/1062 [01:08<00:00, 15.49it/s]
eval: 100%|██████████████████████████████████████████████████████████████████████████| 270/270 [00:11<00:00, 24.03it/s]


precision 0.9841506928885569
recall 0.9844366051047935
f1 0.9842936282341225
accuracy 0.98865598027127


11/20 loss:  0.01, val_loss:   nan: 100%|██████████████████████████████████████████| 1062/1062 [00:59<00:00, 17.98it/s]
eval: 100%|██████████████████████████████████████████████████████████████████████████| 270/270 [00:08<00:00, 33.15it/s]


precision 0.9842369436263326
recall 0.9847271218095041
f1 0.984481971702419
accuracy 0.9886207503963361


12/20 loss:  0.03, val_loss:   nan:  67%|████████████████████████████▋              | 709/1062 [00:29<00:12, 28.85it/s]

In [402]:
with torch.no_grad():
    # eval
    for batch in tqdm(test_loader):
        xb = batch["input_ids"].to(device)
        yb = batch["labels"].to(device)

        inputs_list = [[ix_to_word[i] for i in input if i != WORD_PAD_IDX] for input in xb.tolist()]

        predictions = model(xb)[1]

        preds_list = [[ix_to_tag[p] for p in pred if p in ix_to_tag] for pred in predictions]
        labels_list = [[ix_to_tag[l] for l in label if l in ix_to_tag] for label in yb.tolist()]

        for j in range(len(inputs_list)):
            if preds_list[j] != labels_list[j]:
                print(inputs_list[j])
                print(preds_list[j])
                print(labels_list[j])

  2%|█▌                                                                                | 5/268 [00:00<00:05, 44.43it/s]

['鞍', '山']
['B-2', 'I-2']
['B-1', 'I-1']
['图', '木', '舒', '克']
['B-2', 'I-2', 'I-2', 'I-2']
['B-1', 'I-1', 'I-1', 'I-1']
['孝', '感', '市']
['B-2', 'I-2', 'E-2']
['B-1', 'I-1', 'E-1']
['阳', '江']
['B-2', 'I-2']
['B-1', 'I-1']
['陇', '南']
['B-2', 'I-2']
['B-1', 'I-1']
['怀', '化']
['B-2', 'I-2']
['B-1', 'I-1']
['海', '南', '澄', '迈', '县']
['B-1', 'I-1', 'B-2', 'I-2', 'E-2']
['B-0', 'I-0', 'B-1', 'I-1', 'E-1']
['安', '康', '市']
['B-2', 'I-2', 'E-2']
['B-1', 'I-1', 'E-1']
['琼', '海', '市']
['B-2', 'I-2', 'E-2']
['B-1', 'I-1', 'E-1']
['鞍', '山', '市']
['B-2', 'I-2', 'E-2']
['B-1', 'I-1', 'E-1']
['邳', '州', '市']
['B-1', 'I-1', 'E-1']
['B-2', 'I-2', 'E-2']
['益', '阳']
['B-2', 'I-2']
['B-1', 'I-1']
['屯', '昌', '县']
['B-2', 'I-2', 'E-2']
['B-1', 'I-1', 'E-1']
['芜', '湖', '市']
['B-2', 'I-2', 'E-2']
['B-1', 'I-1', 'E-1']
['广', '州', '南', '沙', '区']
['B-0', 'I-0', 'B-2', 'I-2', 'E-2']
['B-1', 'I-1', 'B-2', 'I-2', 'E-2']
['阿', '拉', '尔', '市']
['B-1', 'I-1', 'I-1', 'E-1']
['B-2', 'I-2', 'I-2', 'E-2']
['海', '南', '东', '方']


  6%|█████▏                                                                           | 17/268 [00:00<00:05, 49.92it/s]

['昆', '玉', '市']
['B-2', 'I-2', 'E-2']
['B-1', 'I-1', 'E-1']
['乌', '兰', '察', '布', '市']
['B-1', 'I-1', 'B-2', 'I-2', 'E-2']
['B-1', 'I-1', 'I-1', 'I-1', 'E-1']
['阆', '中', '市']
['B-1', 'I-1', 'E-1']
['B-2', 'I-2', 'E-2']
['温', '州', '市']
['B-2', 'I-2', 'E-2']
['B-1', 'I-1', 'E-1']
['阿', '里']
['B-2', 'I-2']
['B-1', 'I-1']
['海', '西']
['B-2', 'I-2']
['B-1', 'I-1']
['海', '南', '省', '屯', '昌']
['B-0', 'I-0', 'E-0', 'B-1', 'I-1']
['B-0', 'I-0', 'E-0', 'B-2', 'I-2']
['海', '南', '省', '昌', '江', '黎', '族', '自', '治', '县']
['B-0', 'I-0', 'E-0', 'B-1', 'I-1', 'E-1', 'E-1', 'E-1', 'E-1', 'E-1']
['B-0', 'I-0', 'E-0', 'B-2', 'I-2', 'E-2', 'E-2', 'E-2', 'E-2', 'E-2']
['昭', '通', '市']
['B-2', 'I-2', 'E-2']
['B-1', 'I-1', 'E-1']
['茂', '名']
['B-2', 'I-2']
['B-1', 'I-1']
['盘', '锦', '市']
['B-2', 'I-2', 'E-2']
['B-1', 'I-1', 'E-1']
['东', '方', '市']
['B-2', 'I-2', 'E-2']
['B-1', 'I-1', 'E-1']
['乌', '鲁', '木', '齐', '水', '磨', '沟', '区']
['B-1', 'I-1', 'I-1', 'I-1', 'I-1', 'B-2', 'I-2', 'E-2']
['B-1', 'I-1', 'I-1', 'I-1', '

 14%|███████████▏                                                                     | 37/268 [00:00<00:04, 57.46it/s]

['鄂', '城']
['B-1', 'I-1']
['B-2', 'I-2']
['东', '营', '广', '饶']
['B-0', 'I-0', 'B-1', 'I-1']
['B-1', 'I-1', 'B-2', 'I-2']
['平', '凉']
['B-2', 'I-2']
['B-1', 'I-1']
['吴', '忠']
['B-2', 'I-2']
['B-1', 'I-1']
['赣', '州', '赣', '县']
['B-2', 'I-2', 'I-2', 'E-2']
['B-1', 'I-1', 'B-2', 'I-2']
['淄', '博', '市']
['B-2', 'I-2', 'E-2']
['B-1', 'I-1', 'E-1']
['海', '南', '乐', '东']
['B-0', 'I-0', 'B-1', 'I-1']
['B-0', 'I-0', 'B-2', 'I-2']
['河', '北', '省', '秦', '皇', '岛', '山', '海', '关']
['B-0', 'I-0', 'E-0', 'B-1', 'I-1', 'I-1', 'I-1', 'B-2', 'I-2']
['B-0', 'I-0', 'E-0', 'B-1', 'I-1', 'I-1', 'B-2', 'I-2', 'I-2']
['临', '海', '市']
['B-1', 'I-1', 'E-1']
['B-2', 'I-2', 'E-2']
['宁', '德']
['B-2', 'I-2']
['B-1', 'I-1']
['海', '南', '陵', '水']
['B-0', 'I-0', 'B-1', 'I-1']
['B-0', 'I-0', 'B-2', 'I-2']


 20%|████████████████                                                                 | 53/268 [00:00<00:03, 65.60it/s]

['曲', '靖']
['B-2', 'I-2']
['B-1', 'I-1']
['宝', '鸡', '市']
['B-2', 'I-2', 'E-2']
['B-1', 'I-1', 'E-1']
['莆', '田']
['B-2', 'I-2']
['B-1', 'I-1']
['前', '郭', '尔', '罗', '斯', '蒙', '古', '族', '自', '治', '县']
['B-1', 'I-1', 'I-1', 'B-2', 'I-2', 'E-2', 'E-2', 'E-2', 'E-2', 'E-2', 'E-2']
['B-2', 'I-2', 'I-2', 'I-2', 'I-2', 'E-2', 'E-2', 'E-2', 'E-2', 'E-2', 'E-2']
['沧', '州']
['B-2', 'I-2']
['B-1', 'I-1']
['彰', '化', '县']
['B-2', 'I-2', 'E-2']
['B-1', 'I-1', 'E-1']
['滨', '州']
['B-2', 'I-2']
['B-1', 'I-1']
['上', '栗']
['B-1', 'I-1']
['B-2', 'I-2']
['河', '南', '新', '乡']
['B-1', 'I-1', 'B-2', 'I-2']
['B-0', 'I-0', 'B-1', 'I-1']
['乌', '苏']
['B-1', 'I-1']
['B-2', 'I-2']
['广', '东', '茂', '名']
['B-0', 'I-0', 'B-2', 'I-2']
['B-0', 'I-0', 'B-1', 'I-1']


 23%|██████████████████▋                                                              | 62/268 [00:01<00:02, 72.02it/s]

['崇', '左']
['B-2', 'I-2']
['B-1', 'I-1']
['运', '城', '市']
['B-2', 'I-2', 'E-2']
['B-1', 'I-1', 'E-1']
['甘', '肃', '省', '嘉', '峪', '关']
['B-0', 'I-0', 'E-0', 'B-2', 'I-2', 'I-2']
['B-0', 'I-0', 'E-0', 'B-1', 'I-1', 'I-1']
['汉', '川', '市']
['B-1', 'I-1', 'E-1']
['B-2', 'I-2', 'E-2']
['防', '城', '港']
['B-2', 'I-2', 'I-2']
['B-1', 'I-1', 'I-1']
['张', '掖']
['B-2', 'I-2']
['B-1', 'I-1']
['神', '木', '市']
['B-1', 'I-1', 'E-1']
['B-2', 'I-2', 'E-2']
['丹', '江', '口']
['B-1', 'I-1', 'I-1']
['B-2', 'I-2', 'I-2']
['张', '湾', '区']
['B-0', 'I-0', 'E-0']
['B-2', 'I-2', 'E-2']
['杜', '尔', '伯', '特', '蒙', '古', '族', '自', '治', '县']
['B-1', 'I-1', 'B-2', 'I-2', 'E-2', 'E-2', 'E-2', 'E-2', 'E-2', 'E-2']
['B-2', 'I-2', 'I-2', 'I-2', 'E-2', 'E-2', 'E-2', 'E-2', 'E-2', 'E-2']
['新', '疆', '维', '吾', '尔', '自', '治', '区', '五', '家', '渠', '市']
['B-0', 'I-0', 'E-0', 'E-0', 'E-0', 'E-0', 'E-0', 'E-0', 'B-2', 'I-2', 'I-2', 'E-2']
['B-0', 'I-0', 'E-0', 'E-0', 'E-0', 'E-0', 'E-0', 'E-0', 'B-1', 'I-1', 'I-1', 'E-1']
['回', '民', '区']
[

 29%|███████████████████████▌                                                         | 78/268 [00:01<00:02, 73.40it/s]

['商', '州']
['B-1', 'I-1']
['B-2', 'I-2']
['宝', '鸡']
['B-2', 'I-2']
['B-1', 'I-1']
['济', '南', '市']
['B-2', 'I-2', 'E-2']
['B-1', 'I-1', 'E-1']
['黑', '龙', '江', '哈', '尔', '滨']
['B-0', 'I-0', 'B-1', 'I-1', 'B-2', 'I-2']
['B-0', 'I-0', 'I-0', 'B-1', 'I-1', 'I-1']
['黑', '河', '市']
['B-2', 'I-2', 'E-2']
['B-1', 'I-1', 'E-1']
['江', '西', '省', '景', '德', '镇', '昌', '江', '区']
['B-0', 'I-0', 'E-0', 'B-1', 'I-1', 'B-2', 'I-2', 'I-2', 'E-2']
['B-0', 'I-0', 'E-0', 'B-1', 'I-1', 'I-1', 'B-2', 'I-2', 'E-2']
['辽', '宁']
['B-1', 'I-1']
['B-0', 'I-0']
['广', '州', '南', '沙']
['B-0', 'I-0', 'B-1', 'I-1']
['B-1', 'I-1', 'B-2', 'I-2']
['宣', '威', '市']
['B-1', 'I-1', 'E-1']
['B-2', 'I-2', 'E-2']
['海', '南', '省', '屯', '昌', '县']
['B-0', 'I-0', 'E-0', 'B-1', 'I-1', 'E-1']
['B-0', 'I-0', 'E-0', 'B-2', 'I-2', 'E-2']
['九', '江', '市']
['B-2', 'I-2', 'E-2']
['B-1', 'I-1', 'E-1']
['兴', '安']
['B-2', 'I-2']
['B-1', 'I-1']
['海', '南', '省', '陵', '水', '黎', '族', '自', '治', '县']
['B-0', 'I-0', 'E-0', 'B-1', 'I-1', 'E-1', 'E-1', 'E-1', '

 35%|████████████████████████████▍                                                    | 94/268 [00:01<00:02, 69.60it/s]

['安', '康']
['B-2', 'I-2']
['B-1', 'I-1']
['银', '州']
['B-1', 'I-1']
['B-2', 'I-2']
['新', '疆', '乌', '鲁', '木', '齐', '水', '磨', '沟']
['B-0', 'I-0', 'B-1', 'I-1', 'I-1', 'I-1', 'I-1', 'B-2', 'I-2']
['B-0', 'I-0', 'B-1', 'I-1', 'I-1', 'I-1', 'B-2', 'I-2', 'I-2']
['邯', '山']
['B-1', 'I-1']
['B-2', 'I-2']
['广', '东', '省', '茂', '名']
['B-0', 'I-0', 'E-0', 'B-2', 'I-2']
['B-0', 'I-0', 'E-0', 'B-1', 'I-1']
['新', '疆', '维', '吾', '尔', '自', '治', '区', '乌', '鲁', '木', '齐', '水', '磨', '沟', '区']
['B-0', 'I-0', 'E-0', 'E-0', 'E-0', 'E-0', 'E-0', 'E-0', 'B-1', 'I-1', 'I-1', 'I-1', 'I-1', 'B-2', 'I-2', 'E-2']
['B-0', 'I-0', 'E-0', 'E-0', 'E-0', 'E-0', 'E-0', 'E-0', 'B-1', 'I-1', 'I-1', 'I-1', 'B-2', 'I-2', 'I-2', 'E-2']
['河', '北', '秦', '皇', '岛', '山', '海', '关']
['B-0', 'I-0', 'B-1', 'I-1', 'I-1', 'I-1', 'B-2', 'I-2']
['B-0', 'I-0', 'B-1', 'I-1', 'I-1', 'B-2', 'I-2', 'I-2']
['崇', '左', '市']
['B-2', 'I-2', 'E-2']
['B-1', 'I-1', 'E-1']
['新', '疆', '维', '吾', '尔', '自', '治', '区', '北', '屯']
['B-0', 'I-0', 'E-0', 'E-0', 'E-

 41%|█████████████████████████████████▏                                              | 111/268 [00:01<00:02, 70.68it/s]

['上', '海', '市']
['B-1', 'I-1', 'E-1']
['B-0', 'I-0', 'E-0']
['新', '疆', '阿', '勒', '泰', '哈', '巴', '河', '县']
['B-0', 'I-0', 'B-1', 'I-1', 'I-1', 'I-1', 'B-2', 'I-2', 'E-2']
['B-0', 'I-0', 'B-1', 'I-1', 'I-1', 'B-2', 'I-2', 'I-2', 'E-2']
['青', '岛', '市']
['B-2', 'I-2', 'E-2']
['B-1', 'I-1', 'E-1']
['新', '疆', '五', '家', '渠']
['B-0', 'I-0', 'B-2', 'I-2', 'I-2']
['B-0', 'I-0', 'B-1', 'I-1', 'I-1']
['珠', '海']
['B-2', 'I-2']
['B-1', 'I-1']
['保', '定', '市']
['B-2', 'I-2', 'E-2']
['B-1', 'I-1', 'E-1']
['海', '南', '省', '昌', '江']
['B-0', 'I-0', 'E-0', 'B-1', 'I-1']
['B-0', 'I-0', 'E-0', 'B-2', 'I-2']
['海', '南', '定', '安', '县']
['B-0', 'I-0', 'B-2', 'I-2', 'E-2']
['B-0', 'I-0', 'B-1', 'I-1', 'E-1']
['甘', '洛']
['B-1', 'I-1']
['B-2', 'I-2']
['玉', '树']
['B-2', 'I-2']
['B-1', 'I-1']
['新', '乡']
['B-2', 'I-2']
['B-1', 'I-1']
['张', '家', '口']
['B-2', 'I-2', 'I-2']
['B-1', 'I-1', 'I-1']
['四', '川', '省', '内', '江', '威', '远']
['B-0', 'I-0', 'E-0', 'B-0', 'I-0', 'B-1', 'I-1']
['B-0', 'I-0', 'E-0', 'B-1', 'I-1', 'B-2',

 44%|███████████████████████████████████▌                                            | 119/268 [00:01<00:02, 73.04it/s]

['山', '西']
['B-2', 'I-2']
['B-0', 'I-0']
['乌', '鲁', '木', '齐', '水', '磨', '沟']
['B-1', 'I-1', 'I-1', 'I-1', 'I-1', 'B-2', 'I-2']
['B-1', 'I-1', 'I-1', 'I-1', 'B-2', 'I-2', 'I-2']
['黄', '岩']
['B-1', 'I-1']
['B-2', 'I-2']
['池', '州']
['B-2', 'I-2']
['B-1', 'I-1']
['林', '州', '市']
['B-1', 'I-1', 'E-1']
['B-2', 'I-2', 'E-2']


 50%|████████████████████████████████████████                                        | 134/268 [00:02<00:02, 58.87it/s]

['上', '海']
['B-2', 'I-2']
['B-0', 'I-0']
['花', '莲']
['B-2', 'I-2']
['B-1', 'I-1']
['淄', '博']
['B-2', 'I-2']
['B-1', 'I-1']
['海', '南', '昌', '江', '黎', '族', '自', '治', '县']
['B-0', 'I-0', 'B-2', 'I-2', 'E-2', 'E-2', 'E-2', 'E-2', 'E-2']
['B-0', 'I-0', 'B-1', 'I-1', 'E-1', 'E-1', 'E-1', 'E-1', 'E-1']
['佛', '山', '市']
['B-2', 'I-2', 'E-2']
['B-1', 'I-1', 'E-1']
['黄', '冈']
['B-2', 'I-2']
['B-1', 'I-1']
['芜', '湖']
['B-2', 'I-2']
['B-1', 'I-1']
['海', '南', '省', '定', '安', '县']
['B-0', 'I-0', 'E-0', 'B-2', 'I-2', 'E-2']
['B-0', 'I-0', 'E-0', 'B-1', 'I-1', 'E-1']


 53%|██████████████████████████████████████████                                      | 141/268 [00:02<00:02, 58.39it/s]

['韶', '关', '市']
['B-2', 'I-2', 'E-2']
['B-1', 'I-1', 'E-1']
['运', '城']
['B-2', 'I-2']
['B-1', 'I-1']
['图', '木', '舒', '克', '市']
['B-2', 'I-2', 'I-2', 'I-2', 'E-2']
['B-1', 'I-1', 'I-1', 'I-1', 'E-1']
['林', '内']
['B-1', 'I-1']
['B-2', 'I-2']
['新', '疆', '图', '木', '舒', '克']
['B-0', 'I-0', 'B-1', 'I-1', 'I-1', 'I-1']
['B-0', 'I-0', 'B-2', 'I-2', 'I-2', 'I-2']
['海', '南', '省', '澄', '迈']
['B-0', 'I-0', 'E-0', 'B-2', 'I-2']
['B-0', 'I-0', 'E-0', 'B-1', 'I-1']


 57%|█████████████████████████████████████████████▉                                  | 154/268 [00:02<00:02, 48.76it/s]

['阳', '江', '市']
['B-2', 'I-2', 'E-2']
['B-1', 'I-1', 'E-1']
['青', '岛']
['B-2', 'I-2']
['B-1', 'I-1']
['三', '亚']
['B-2', 'I-2']
['B-1', 'I-1']
['海', '南', '省', '陵', '水']
['B-0', 'I-0', 'E-0', 'B-1', 'I-1']
['B-0', 'I-0', 'E-0', 'B-2', 'I-2']
['甘', '肃', '省', '嘉', '峪', '关', '市']
['B-0', 'I-0', 'E-0', 'B-2', 'I-2', 'I-2', 'E-2']
['B-0', 'I-0', 'E-0', 'B-1', 'I-1', 'I-1', 'E-1']
['济', '源', '市']
['B-2', 'I-2', 'E-2']
['B-1', 'I-1', 'E-1']
['海', '南', '省', '东', '方', '市']
['B-0', 'I-0', 'E-0', 'B-1', 'I-1', 'E-1']
['B-0', 'I-0', 'E-0', 'B-2', 'I-2', 'E-2']
['铜', '陵']
['B-2', 'I-2']
['B-1', 'I-1']
['新', '疆', '图', '木', '舒', '克', '市']
['B-0', 'I-0', 'B-1', 'I-1', 'I-1', 'I-1', 'E-1']
['B-0', 'I-0', 'B-2', 'I-2', 'I-2', 'I-2', 'E-2']
['宁', '夏']
['B-2', 'I-2']
['B-0', 'I-0']


 63%|██████████████████████████████████████████████████▋                             | 170/268 [00:02<00:02, 41.81it/s]

['四', '川', '内', '江', '市']
['B-1', 'I-1', 'B-2', 'I-2', 'E-2']
['B-0', 'I-0', 'B-1', 'I-1', 'E-1']
['河', '北', '秦', '皇', '岛', '山', '海', '关', '区']
['B-0', 'I-0', 'B-1', 'I-1', 'I-1', 'I-1', 'B-2', 'I-2', 'E-2']
['B-0', 'I-0', 'B-1', 'I-1', 'I-1', 'B-2', 'I-2', 'I-2', 'E-2']
['昭', '通']
['B-2', 'I-2']
['B-1', 'I-1']
['海', '南', '澄', '迈']
['B-0', 'I-0', 'B-2', 'I-2']
['B-0', 'I-0', 'B-1', 'I-1']
['古', '塔']
['B-0', 'I-0']
['B-2', 'I-2']


 68%|██████████████████████████████████████████████████████                          | 181/268 [00:03<00:01, 46.41it/s]

['香', '港', '特', '别', '行', '政', '区']
['B-2', 'I-2', 'E-2', 'E-2', 'E-2', 'E-2', 'E-2']
['B-1', 'I-1', 'E-1', 'E-1', 'E-1', 'E-1', 'E-1']
['蚌', '埠']
['B-2', 'I-2']
['B-1', 'I-1']
['丹', '江', '口', '市']
['B-1', 'I-1', 'I-1', 'E-1']
['B-2', 'I-2', 'I-2', 'E-2']
['六', '枝', '特']
['B-1', 'I-1', 'I-1']
['B-2', 'I-2', 'I-2']
['昆', '玉']
['B-2', 'I-2']
['B-1', 'I-1']
['延', '安']
['B-2', 'I-2']
['B-1', 'I-1']
['琼', '海']
['B-2', 'I-2']
['B-1', 'I-1']
['株', '洲']
['B-2', 'I-2']
['B-1', 'I-1']
['海', '南', '昌', '江']
['B-0', 'I-0', 'B-1', 'I-1']
['B-0', 'I-0', 'B-2', 'I-2']
['昌', '江']
['B-1', 'I-1']
['B-2', 'I-2']


 72%|█████████████████████████████████████████████████████████▎                      | 192/268 [00:03<00:01, 45.96it/s]

['大', '连']
['B-2', 'I-2']
['B-1', 'I-1']
['阿', '拉', '山', '口']
['B-1', 'I-1', 'B-2', 'I-2']
['B-2', 'I-2', 'I-2', 'I-2']
['绵', '竹', '市']
['B-1', 'I-1', 'E-1']
['B-2', 'I-2', 'E-2']
['赞', '皇']
['B-1', 'I-1']
['B-2', 'I-2']
['西', '藏', '林', '芝', '工', '布', '江', '达', '县']
['B-0', 'I-0', 'I-0', 'B-1', 'I-1', 'B-2', 'I-2', 'I-2', 'E-2']
['B-0', 'I-0', 'B-1', 'I-1', 'B-2', 'I-2', 'I-2', 'I-2', 'E-2']
['海', '南', '省', '万', '宁']
['B-0', 'I-0', 'E-0', 'B-2', 'I-2']
['B-0', 'I-0', 'E-0', 'B-1', 'I-1']


 75%|████████████████████████████████████████████████████████████▎                   | 202/268 [00:03<00:01, 40.49it/s]

['赣', '州']
['B-2', 'I-2']
['B-1', 'I-1']
['乌', '兰', '察', '布']
['B-1', 'I-1', 'B-2', 'I-2']
['B-1', 'I-1', 'I-1', 'I-1']
['丹', '阳', '市']
['B-1', 'I-1', 'E-1']
['B-2', 'I-2', 'E-2']
['丹', '阳']
['B-1', 'I-1']
['B-2', 'I-2']
['海', '南']
['B-2', 'I-2']
['B-0', 'I-0']
['海', '南', '东', '方', '市']
['B-0', 'I-0', 'B-1', 'I-1', 'E-1']
['B-0', 'I-0', 'B-2', 'I-2', 'E-2']
['工', '布', '江', '达', '县']
['B-1', 'I-1', 'B-2', 'I-2', 'E-2']
['B-2', 'I-2', 'I-2', 'I-2', 'E-2']
['新', '疆', '维', '吾', '尔', '自', '治', '区', '北', '屯', '市']
['B-0', 'I-0', 'E-0', 'E-0', 'E-0', 'E-0', 'E-0', 'E-0', 'B-1', 'I-1', 'E-1']
['B-0', 'I-0', 'E-0', 'E-0', 'E-0', 'E-0', 'E-0', 'E-0', 'B-2', 'I-2', 'E-2']
['南', '宁', '西', '乡', '塘', '区']
['B-1', 'I-1', 'I-1', 'B-2', 'I-2', 'E-2']
['B-1', 'I-1', 'B-2', 'I-2', 'I-2', 'E-2']
['新', '疆', '北', '屯']
['B-0', 'I-0', 'B-1', 'I-1']
['B-0', 'I-0', 'B-2', 'I-2']
['宜', '宾']
['B-2', 'I-2']
['B-1', 'I-1']


 77%|█████████████████████████████████████████████████████████████▊                  | 207/268 [00:03<00:01, 40.20it/s]

['乌', '苏', '市']
['B-1', 'I-1', 'E-1']
['B-2', 'I-2', 'E-2']
['曲', '靖', '市']
['B-2', 'I-2', 'E-2']
['B-1', 'I-1', 'E-1']
['杜', '尔', '伯', '特']
['B-1', 'I-1', 'B-2', 'I-2']
['B-2', 'I-2', 'I-2', 'I-2']
['佛', '山']
['B-2', 'I-2']
['B-1', 'I-1']
['新', '疆', '五', '家', '渠', '市']
['B-0', 'I-0', 'B-2', 'I-2', 'I-2', 'E-2']
['B-0', 'I-0', 'B-1', 'I-1', 'I-1', 'E-1']
['新', '疆', '维', '吾', '尔', '自', '治', '区', '五', '家', '渠']
['B-0', 'I-0', 'E-0', 'E-0', 'E-0', 'E-0', 'E-0', 'E-0', 'B-2', 'I-2', 'I-2']
['B-0', 'I-0', 'E-0', 'E-0', 'E-0', 'E-0', 'E-0', 'E-0', 'B-1', 'I-1', 'I-1']
['花', '莲', '县']
['B-2', 'I-2', 'E-2']
['B-1', 'I-1', 'E-1']
['蚌', '埠', '市']
['B-2', 'I-2', 'E-2']
['B-1', 'I-1', 'E-1']


 81%|████████████████████████████████████████████████████████████████▍               | 216/268 [00:04<00:01, 32.62it/s]

['晋', '中']
['B-2', 'I-2']
['B-1', 'I-1']
['广', '饶']
['B-1', 'I-1']
['B-2', 'I-2']
['宁', '德', '市']
['B-2', 'I-2', 'E-2']
['B-1', 'I-1', 'E-1']
['锦', '州']
['B-2', 'I-2']
['B-1', 'I-1']
['屯', '昌']
['B-2', 'I-2']
['B-1', 'I-1']
['陇', '南', '市']
['B-2', 'I-2', 'E-2']
['B-1', 'I-1', 'E-1']
['泸', '州']
['B-2', 'I-2']
['B-1', 'I-1']
['和', '林', '格', '尔', '县']
['B-1', 'I-1', 'B-2', 'I-2', 'E-2']
['B-2', 'I-2', 'I-2', 'I-2', 'E-2']


 84%|██████████████████████████████████████████████████████████████████▊             | 224/268 [00:04<00:01, 34.25it/s]

['黑', '河']
['B-2', 'I-2']
['B-1', 'I-1']
['三', '亚', '市']
['B-2', 'I-2', 'E-2']
['B-1', 'I-1', 'E-1']
['白', '沙', '黎', '族', '自', '治', '县']
['B-2', 'I-2', 'E-2', 'E-2', 'E-2', 'E-2', 'E-2']
['B-1', 'I-1', 'E-1', 'E-1', 'E-1', 'E-1', 'E-1']
['西', '藏', '林', '芝', '工', '布', '江', '达']
['B-0', 'I-0', 'I-0', 'B-1', 'I-1', 'B-2', 'I-2', 'I-2']
['B-0', 'I-0', 'B-1', 'I-1', 'B-2', 'I-2', 'I-2', 'I-2']
['陕', '西', '安', '康']
['B-0', 'I-0', 'B-2', 'I-2']
['B-0', 'I-0', 'B-1', 'I-1']
['河', '北', '省', '秦', '皇', '岛', '山', '海', '关', '区']
['B-0', 'I-0', 'E-0', 'B-1', 'I-1', 'I-1', 'I-1', 'B-2', 'I-2', 'E-2']
['B-0', 'I-0', 'E-0', 'B-1', 'I-1', 'I-1', 'B-2', 'I-2', 'I-2', 'E-2']
['温', '州']
['B-2', 'I-2']
['B-1', 'I-1']
['盘', '锦']
['B-2', 'I-2']
['B-1', 'I-1']


 89%|███████████████████████████████████████████████████████████████████████         | 238/268 [00:04<00:00, 49.53it/s]

['广', '东', '茂', '名', '市']
['B-0', 'I-0', 'B-2', 'I-2', 'E-2']
['B-0', 'I-0', 'B-1', 'I-1', 'E-1']
['广', '西', '壮', '族', '自', '治', '区', '崇', '左']
['B-0', 'I-0', 'E-0', 'E-0', 'E-0', 'E-0', 'E-0', 'B-2', 'I-2']
['B-0', 'I-0', 'E-0', 'E-0', 'E-0', 'E-0', 'E-0', 'B-1', 'I-1']
['新', '疆', '北', '屯', '市']
['B-0', 'I-0', 'B-1', 'I-1', 'E-1']
['B-0', 'I-0', 'B-2', 'I-2', 'E-2']
['新', '疆', '维', '吾', '尔', '自', '治', '区', '乌', '鲁', '木', '齐', '水', '磨', '沟']
['B-0', 'I-0', 'E-0', 'E-0', 'E-0', 'E-0', 'E-0', 'E-0', 'B-1', 'I-1', 'I-1', 'I-1', 'I-1', 'B-2', 'I-2']
['B-0', 'I-0', 'E-0', 'E-0', 'E-0', 'E-0', 'E-0', 'E-0', 'B-1', 'I-1', 'I-1', 'I-1', 'B-2', 'I-2', 'I-2']
['宁', '夏', '回', '族', '自', '治', '区']
['B-2', 'I-2', 'E-2', 'E-2', 'E-0', 'E-0', 'E-0']
['B-0', 'I-0', 'E-0', 'E-0', 'E-0', 'E-0', 'E-0']
['铁', '门', '关']
['B-2', 'I-2', 'I-2']
['B-1', 'I-1', 'I-1']
['延', '安', '市']
['B-2', 'I-2', 'E-2']
['B-1', 'I-1', 'E-1']
['丁', '青', '县']
['B-1', 'I-1', 'E-1']
['B-2', 'I-2', 'E-2']
['香', '港']
['B-2', 'I-2']


 95%|███████████████████████████████████████████████████████████████████████████▊    | 254/268 [00:04<00:00, 59.65it/s]

['喀', '什']
['B-2', 'I-2']
['B-1', 'I-1']
['海', '南', '陵', '水', '黎', '族', '自', '治', '县']
['B-0', 'I-0', 'B-1', 'I-1', 'E-1', 'E-1', 'E-1', 'E-1', 'E-1']
['B-0', 'I-0', 'B-2', 'I-2', 'E-2', 'E-2', 'E-2', 'E-2', 'E-2']
['海', '南', '省', '乐', '东', '黎', '族', '自', '治', '县']
['B-0', 'I-0', 'E-0', 'B-1', 'I-1', 'E-1', 'E-1', 'E-1', 'E-1', 'E-1']
['B-0', 'I-0', 'E-0', 'B-2', 'I-2', 'E-2', 'E-2', 'E-2', 'E-2', 'E-2']
['怀', '化', '市']
['B-2', 'I-2', 'E-2']
['B-1', 'I-1', 'E-1']
['江', '西', '景', '德', '镇', '昌', '江', '区']
['B-0', 'I-0', 'B-1', 'I-1', 'B-2', 'I-2', 'I-2', 'E-2']
['B-0', 'I-0', 'B-1', 'I-1', 'I-1', 'B-2', 'I-2', 'E-2']
['南', '充']
['B-2', 'I-2']
['B-1', 'I-1']
['新', '疆', '维', '吾', '尔', '自', '治', '区', '阿', '勒', '泰', '哈', '巴', '河', '县']
['B-0', 'I-0', 'E-0', 'E-0', 'E-0', 'E-0', 'E-0', 'E-0', 'B-1', 'I-1', 'I-1', 'I-1', 'B-2', 'I-2', 'E-2']
['B-0', 'I-0', 'E-0', 'E-0', 'E-0', 'E-0', 'E-0', 'E-0', 'B-1', 'I-1', 'I-1', 'B-2', 'I-2', 'I-2', 'E-2']
['新', '乡', '市']
['B-2', 'I-2', 'E-2']
['B-1', 'I

 97%|█████████████████████████████████████████████████████████████████████████████▉  | 261/268 [00:04<00:00, 56.15it/s]

['海', '南', '屯', '昌']
['B-0', 'I-0', 'B-1', 'I-1']
['B-0', 'I-0', 'B-2', 'I-2']
['防', '城', '港', '市']
['B-2', 'I-2', 'I-2', 'E-2']
['B-1', 'I-1', 'I-1', 'E-1']
['海', '南', '省', '东', '方']
['B-0', 'I-0', 'E-0', 'B-1', 'I-1']
['B-0', 'I-0', 'E-0', 'B-2', 'I-2']
['四', '方', '台']
['B-1', 'I-1', 'I-1']
['B-2', 'I-2', 'I-2']
['香', '港', '特', '别', '行', '政', '区']
['B-2', 'I-2', 'E-2', 'E-2', 'E-2', 'E-2', 'E-2']
['B-0', 'I-0', 'E-0', 'E-0', 'E-0', 'E-0', 'E-0']
['阿', '拉', '山', '口', '市']
['B-1', 'I-1', 'B-2', 'I-2', 'E-2']
['B-2', 'I-2', 'I-2', 'I-2', 'E-2']
['铁', '门', '关', '市']
['B-2', 'I-2', 'I-2', 'E-2']
['B-1', 'I-1', 'I-1', 'E-1']


100%|████████████████████████████████████████████████████████████████████████████████| 268/268 [00:05<00:00, 52.12it/s]

['白', '沙']
['B-2', 'I-2']
['B-1', 'I-1']





In [405]:
df[df["name"].str.contains("乐东")]

Unnamed: 0,id,pid,deep,name,pinyin_prefix,pinyin,ext_id,ext_name,suffix
2234,469027,46,1,乐东,l,le dong,469027000000,乐东黎族自治县,黎族自治县
2235,469027000,469027,2,乐东,l,le dong,469027000000,乐东黎族自治县,黎族自治县


In [255]:
def sent_to_vector(txt):
    vector = torch.tensor([word_to_ix[cha] for cha in txt])
    return vector.view(1, -1)

In [273]:
input_vector = sent_to_vector("枣庄薛城区")
type(model(batch["input_ids"])[1])

list

In [285]:
batch["labels"].tolist()

[[0, 1, 2, 2, 2, 2, 2, 2, 3, 4, 5, 5, 6, 7, 9],
 [0, 1, 2, 3, 4, 6, 7, 8, 9, 9, 9, 9, 9, 9, 9],
 [0, 1, 3, 4, 5, 6, 7, 9, 9, 9, 9, 9, 9, 9, 9],
 [0, 1, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9, 9, 9, 9],
 [0, 1, 2, 3, 4, 4, 5, 5, 5, 5, 5, 5, 5, 6, 7],
 [0, 1, 1, 3, 4, 5, 6, 7, 9, 9, 9, 9, 9, 9, 9],
 [0, 1, 2, 3, 4, 5, 6, 7, 9, 9, 9, 9, 9, 9, 9],
 [0, 1, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9, 9, 9, 9],
 [0, 1, 2, 3, 4, 5, 9, 9, 9, 9, 9, 9, 9, 9, 9],
 [0, 1, 2, 3, 4, 6, 7, 8, 9, 9, 9, 9, 9, 9, 9],
 [0, 1, 3, 4, 6, 7, 8, 9, 9, 9, 9, 9, 9, 9, 9],
 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9, 9, 9],
 [0, 1, 2, 3, 4, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
 [0, 1, 3, 4, 5, 6, 7, 9, 9, 9, 9, 9, 9, 9, 9],
 [0, 1, 2, 3, 4, 5, 6, 7, 9, 9, 9, 9, 9, 9, 9],
 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9, 9, 9],
 [0, 1, 3, 4, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9, 9],
 [0, 1, 3, 4, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
 [0, 1, 2, 3, 4, 6, 7, 8, 9, 9, 9, 9, 9, 9, 9],
 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9, 9, 9],
 [0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9,

In [319]:
ord("a")

97