In [1]:
import os
import nltk
import torch
import random
import argparse
import warnings
import numpy as np
import utils as utils
warnings.filterwarnings('ignore')


# Hyper Params

In [2]:
# Restrict CUDA to device 0 via the environment.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Notebook configuration. argparse is used purely as a typed namespace: the
# parser is fed an empty argv below, so every option takes its default.
# Explicit `type=` converters keep numeric options numeric should real
# command-line values ever be parsed through this parser.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='msra', type=str)
parser.add_argument('--seed', default=1234, type=int)
parser.add_argument('--store_dir', default=None)

parser.add_argument('--batch_size', default=32, type=int)
parser.add_argument('--max_len', default=128, type=int)
parser.add_argument('--patience', default=0.02, type=float)
parser.add_argument('--patience_num', default=5, type=int)

parser.add_argument('--device', default=None)

# Empty argv: ignore the kernel's own command line and use the defaults above.
args = parser.parse_args([])
args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Model hyper-parameters are read from experiments/<dataset>/params.json.
model_params_dir = 'experiments/' + args.dataset
json_path = os.path.join(model_params_dir, 'params.json')
assert os.path.isfile(json_path), "No json configuration file found at {}".format(json_path)
params = utils.Params(json_path)
params.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
params.seed = args.seed

# Dataset location and the pretrained BERT vocabulary that goes with it:
# the Chinese checkpoint for 'msra', the cased English one otherwise.
data_dir = 'data/' + args.dataset
if args.dataset == 'msra':
    bert_class = 'bert-base-chinese'
else:
    bert_class = 'bert-base-cased'

In [3]:
# Seed every random number generator this notebook imports so runs are repeatable.
random.seed(args.seed)
np.random.seed(args.seed)
# Kept as the last expression so the cell still displays the returned torch Generator.
torch.manual_seed(args.seed)


<torch._C.Generator at 0x7f825c82cf50>

In [4]:
import numpy as np
from transformers import BertTokenizer

class DataLoader(object):
    """Reads sentence/tag files from a dataset directory and yields padded batches.

    The directory is expected to contain a top-level ``tags.txt`` (one tag per
    line) and one sub-directory per split (``train``, ``val``, ``test``,
    ``interactive``), each holding a ``sentences.txt`` and, except for
    ``interactive``, a ``tags.txt``. All files are opened with the platform's
    default text encoding.
    """

    def __init__(self, data_dir, bert_class, args, token_pad_idx=0, tag_pad_idx=1):
        """Build the tag vocabulary and the BERT tokenizer.

        Args:
            data_dir: root directory of the dataset.
            bert_class: name or path handed to ``BertTokenizer.from_pretrained``.
            args: namespace providing ``batch_size``, ``max_len``, ``device`` and
                ``seed``. It is mutated: ``tag2idx`` and ``idx2tag`` are attached.
            token_pad_idx: id used to pad token-id rows.
            tag_pad_idx: id used to pad tag rows.
        """
        self.data_dir = data_dir
        self.batch_size = args.batch_size
        self.max_len = args.max_len
        self.device = args.device
        self.seed = args.seed
        self.token_pad_idx = token_pad_idx
        self.tag_pad_idx = tag_pad_idx

        tags = self.load_tags()
        self.tag2idx = {tag: idx for idx, tag in enumerate(tags)}
        self.idx2tag = {idx: tag for idx, tag in enumerate(tags)}

        # Expose the vocabularies to the rest of the program through `args`.
        args.tag2idx = self.tag2idx
        args.idx2tag = self.idx2tag

        self.tokenizer = BertTokenizer.from_pretrained(bert_class, do_lower_case=False)

    def load_tags(self):
        """Return the list of tag names in ``<data_dir>/tags.txt``, one per line."""
        tags_path = os.path.join(self.data_dir, 'tags.txt')
        with open(tags_path, 'r') as file:
            return [tag.strip() for tag in file]

    def load_sentence_tags(self, sentence_path, tags_path, data=None):
        """Tokenise a sentence file (and optionally its tag file) into ``data``.

        Args:
            sentence_path: file with one space-separated sentence per line.
            tags_path: matching tag file, or ``None`` when there are no labels.
            data: dict to fill in place. A fresh dict is created when omitted
                (a shared mutable default would leak state between calls).

        Returns:
            The filled dict with keys ``'sentences'`` (list of
            ``(token_ids, token_start_idxs)``), ``'size'`` and, when
            ``tags_path`` is given, ``'tags'`` (list of tag-id lists).

        Raises:
            ValueError: a tag in the tag file is missing from the vocabulary.
            AssertionError: sentences and tags do not line up.
        """
        if data is None:
            data = {}
        sentences = []
        tags = []

        with open(sentence_path, 'r') as file:
            for line in file:
                tokens = line.strip().split(' ')
                pieces_per_token = [self.tokenizer.tokenize(token) for token in tokens]
                subwords = ['[CLS]'] + [piece for pieces in pieces_per_token for piece in pieces]
                # Every position except the leading [CLS] is marked as a token start.
                token_start_idxs = list(range(1, len(subwords)))

                bert_tokens = self.tokenizer.convert_tokens_to_ids(subwords)
                # Invariant: len(bert_tokens) - len(token_start_idxs) == 1
                sentences.append((bert_tokens, token_start_idxs))

        if tags_path is not None:
            with open(tags_path, 'r') as file:
                for line_no, line in enumerate(file, start=1):
                    tag_seq = []
                    for tag in line.strip().split(' '):
                        # Fail here with a clear message instead of storing None
                        # and crashing later while building the batch arrays.
                        if tag not in self.tag2idx:
                            raise ValueError(
                                "Unknown tag {!r} at line {} of {}".format(tag, line_no, tags_path))
                        tag_seq.append(self.tag2idx[tag])
                    tags.append(tag_seq)

            # Check the correspondence between sentences and tags.
            assert len(sentences) == len(tags)
            for i in range(len(tags)):
                assert len(tags[i]) == len(sentences[i][0]) - 1
            data['tags'] = tags

        data['sentences'] = sentences
        data['size'] = len(sentences)
        return data

    def load_data(self, data_class):
        """Load one split from this loader's ``data_dir``.

        Args:
            data_class: ``'train'``, ``'val'``, ``'test'`` (sentences and tags)
                or ``'interactive'`` (sentences only).

        Returns:
            The dict produced by ``load_sentence_tags``.

        Raises:
            ValueError: ``data_class`` is not one of the supported splits.
        """
        data = {}

        # Paths are resolved against self.data_dir, not a module-level global,
        # so the loader honours the directory it was constructed with.
        if data_class in ['train', 'val', 'test']:
            sentence_path = os.path.join(self.data_dir, data_class, 'sentences.txt')
            tags_path = os.path.join(self.data_dir, data_class, 'tags.txt')
            self.load_sentence_tags(sentence_path, tags_path, data)

        elif data_class == 'interactive':
            sentence_path = os.path.join(self.data_dir, data_class, 'sentences.txt')
            self.load_sentence_tags(sentence_path, None, data)

        else:
            raise ValueError("No data in train/val/test or interactve!")

        return data

    def data_iterator(self, data, shuffle=False):
        """Yield the data batch by batch as ``torch.long`` tensors on ``self.device``.

        Args:
            data: dict with ``'sentences'``, ``'size'`` and optionally ``'tags'``.
            shuffle: shuffle the sentence order first. The global ``random``
                module is re-seeded with ``self.seed`` each time, so the order
                is the same on every call.

        Yields:
            ``(batch_data, batch_token_starts, batch_tags)`` when ``data`` has
            tags, otherwise ``(batch_data, batch_token_starts)``. Token-id rows
            are padded with ``token_pad_idx`` and cut to ``max_len``; tag rows
            are padded with ``tag_pad_idx``.
        """
        order = list(range(data['size']))
        if shuffle:
            random.seed(self.seed)
            random.shuffle(order)
        has_tags = 'tags' in data

        # Slicing clamps at the end of `order`, so a short final batch needs no special case.
        for start in range(0, data['size'], self.batch_size):
            batch_order = order[start:start + self.batch_size]
            sentences = [data['sentences'][idx] for idx in batch_order]
            if has_tags:
                tags = [data['tags'][idx] for idx in batch_order]

            batch_len = len(sentences)

            # Width of the batch: longest sentence, capped at max_len.
            batch_max_subwords_len = max(len(s[0]) for s in sentences)
            max_subwords_len = min(batch_max_subwords_len, self.max_len)
            max_token_len = 0

            # Integer arrays pre-filled with padding; rows are copied in below.
            batch_data = np.full((batch_len, max_subwords_len), self.token_pad_idx, dtype=np.int64)
            batch_token_starts = np.zeros((batch_len, max_subwords_len), dtype=np.int64)

            for j in range(batch_len):
                token_ids = sentences[j][0]
                kept = min(len(token_ids), max_subwords_len)
                batch_data[j, :kept] = token_ids[:kept]

                # Mark token-start positions that survive truncation.
                token_start_idx = sentences[j][-1]
                visible = [idx for idx in token_start_idx if idx < max_subwords_len]
                batch_token_starts[j, np.asarray(visible, dtype=np.intp)] = 1
                max_token_len = max(int(batch_token_starts[j].sum()), max_token_len)

            if has_tags:
                batch_tags = np.full((batch_len, max_token_len), self.tag_pad_idx, dtype=np.int64)
                for j in range(batch_len):
                    kept = min(len(tags[j]), max_token_len)
                    batch_tags[j, :kept] = tags[j][:kept]

            # All data are indices, so convert to LongTensors and move to the target device.
            batch_data = torch.tensor(batch_data, dtype=torch.long).to(self.device)
            batch_token_starts = torch.tensor(batch_token_starts, dtype=torch.long).to(self.device)
            if has_tags:
                batch_tags = torch.tensor(batch_tags, dtype=torch.long).to(self.device)
                yield batch_data, batch_token_starts, batch_tags
            else:
                yield batch_data, batch_token_starts

# Module-level loader instance; BertNerResponse uses it to read and batch the query.
# Note that tag_pad_idx is passed as -1 here, overriding the constructor default of 1.
data_loader = DataLoader(data_dir, bert_class, args, token_pad_idx=0, tag_pad_idx=-1)
# Index -> tag-name lookup used to decode predicted tag ids.
idx2tag = data_loader.idx2tag

# Load Pre-Trained Model

In [5]:
from SequenceTagger import BertForSequenceTagging

# Load the tagger from `model_params_dir` (experiments/<dataset>, set in the config cell).
model = BertForSequenceTagging.from_pretrained(model_params_dir)
# Move the model to the selected device; as the cell's last expression, its repr is displayed.
model.to(args.device)

BertForSequenceTagging(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(21128, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_af

In [6]:
# Interactive Function
from metrics import get_entities

def BertNerResponse(model, querystring):
    """Run the tagger on one query string and return the entities it finds.

    The query is split into single characters, written to the dataset's
    ``interactive/sentences.txt`` file, re-read through the module-level
    ``data_loader`` and passed through ``model`` as a single batch.

    Args:
        model: tagging model; called with ``(batch_data, batch_token_starts)``
            plus ``token_type_ids`` and ``attention_mask`` keyword arguments.
        querystring: raw text typed by the user.

    Returns:
        A list of ``(entity_text, entity_label)`` tuples; empty when the query
        contains no non-whitespace characters.
    """
    # Character-level tokens. Whitespace characters are dropped: the sentence
    # file is space-delimited, so they would turn into extra empty tokens and
    # shift the predicted offsets relative to this character list.
    querystring = [ch for ch in querystring if not ch.isspace()]
    if not querystring:
        # Nothing to tag; an empty sentence file would yield no batch at all
        # and next() below would raise StopIteration.
        return []

    with open('data/' + args.dataset + '/interactive/sentences.txt', 'w') as f:
        f.write(' '.join(querystring))

    inter_data = data_loader.load_data('interactive')
    inter_data_iterator = data_loader.data_iterator(inter_data, shuffle=False)

    model.eval()

    batch_data, batch_token_starts = next(inter_data_iterator)
    # Positions holding the pad id (0) are masked out of attention.
    batch_masks = batch_data.gt(0)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        batch_outs = model((batch_data, batch_token_starts), token_type_ids=None, attention_mask=batch_masks)[0]
    batch_outs = batch_outs.detach().cpu().numpy()

    # Highest-scoring tag id per position, mapped back to tag names.
    pred_tags = [[idx2tag[idx] for idx in indices] for indices in np.argmax(batch_outs, axis=2)]
    # Each item is indexed as (label, start, end) below, with `end` inclusive
    # -- presumably the contract of metrics.get_entities; confirm there.
    result = get_entities(pred_tags)

    return [(''.join(querystring[item[1]:item[2]+1]), item[0]) for item in result]


# MAIN

In [7]:
# Interactive prompt: tag each query until the user types 'exit'.
query = input('Input:')
while query != 'exit':
    print(BertNerResponse(model, query))
    query = input('Input:')

Input:郑雨柔是个小傻瓜
[('郑雨柔', 'PER')]
Input:无时无刻，天安门广场飘扬的五星红旗都是我们的希望
[('天安门广场', 'LOC')]
Input:exit
