In [1]:
import os
import torch
import random
import argparse
import warnings
warnings.filterwarnings("ignore")


# Hyper Params

In [2]:
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='msra')
parser.add_argument('--seed', default=1234)
parser.add_argument('--store_dir', default=None)

parser.add_argument('--max_epoch_num', default=20)
parser.add_argument('--min_epoch_num', default=5)
parser.add_argument('--batch_size', default=32)
parser.add_argument('--max_len', default=128)
parser.add_argument('--patience', default=0.02)
parser.add_argument('--patience_num', default=5)

parser.add_argument('--full_finetuning', default=True)
parser.add_argument('--learning_ratio', default=3e-5)
parser.add_argument('--weight_decay', default=0.01)
parser.add_argument('--clip_grad', default=5)

parser.add_argument('--device', default=None)

args = parser.parse_args([])
args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_params_dir = 'experiments/' + args.dataset
json_path = os.path.join(model_params_dir, 'params.json')
assert os.path.isfile(json_path), "No json configuration file found at {}".format(json_path)
# params = utils.Params(json_path)

data_dir = 'data/' + args.dataset
if args.dataset == 'msra':
    bert_class = 'bert-base-chinese'
else:
    bert_class = 'bert-base-cased'

In [3]:
# set seed for random parts
random.seed(args.seed)
torch.manual_seed(args.seed)
# params.seed = args.seed

<torch._C.Generator at 0x7f95c7455770>

# Define dataloader for following step

In [4]:
import numpy as np
# ! pip install transformers
from transformers import BertTokenizer

class DataLoader(object):
    def __init__(self, data_dir, bert_class, args, token_pad_idx=0, tag_pad_idx=1):
        self.data_dir = data_dir
        self.batch_size = args.batch_size
        self.max_len = args.max_len
        self.device = args.device
        self.seed = args.seed
        self.token_pad_idx = token_pad_idx
        self.tag_pad_idx = tag_pad_idx
        
        tags = self.load_tags()
        self.tag2idx = {tag: idx for idx, tag in enumerate(tags)}
        self.idx2tag = {idx: tag for idx, tag in enumerate(tags)}
        
        args.tag2idx = self.tag2idx
        args.idx2tag = self.idx2tag
        
        self.tokenizer = BertTokenizer.from_pretrained(bert_class, do_lower_case=False)
        
    def load_tags(self):
        tags  = []
        tags_path = os.path.join(self.data_dir, 'tags.txt')
        
        with open(tags_path, 'r') as file:
            for tag in file:
                tags.append(tag.strip())
        return tags
    
    def load_sentence_tags(self, sentence_path, tags_path, data={}):
        sentences = []
        tags = []
        
        with open(sentence_path, 'r') as file:
            for line in file:
                tokens = line.strip().split(' ')
                subwords = list(map(self.tokenizer.tokenize, tokens))
                subword_lengths = list(map(len, subwords))
                subwords = ['[CLS]'] + [item for indices in subwords for item in indices]
                # indice words except [CLS]
                token_start_idxs = list(range(1,len(subwords)))
                
                bert_tokens = self.tokenizer.convert_tokens_to_ids(subwords)
                sentences.append((bert_tokens, token_start_idxs))
                # len(bert_tokens) - len(token_start_idxs) = 1
  
                
        if tags_path != None:
            with open(tags_path, 'r') as file:
                for line in file:
                    tag_seq = [self.tag2idx.get(tag) for tag in line.strip().split(' ')]
                    tags.append(tag_seq)
            
            # Check the corresponding between sentences and tags
            assert len(sentences) == len(tags)
            for i in range(len(tags)):
                assert len(tags[i]) == len(sentences[i][0])-1
        
        data['sentences'] = sentences
        data['tags'] = tags
        data['size'] = len(sentences)
        
    def load_data(self, data_class):
        data = {}
        
        if data_class in ['train', 'val', 'test']:
            sentence_path = os.path.join(data_dir, data_class, 'sentences.txt')
            tags_path = os.path.join(data_dir, data_class, 'tags.txt')
            
            self.load_sentence_tags(sentence_path, tags_path, data)
        
        elif data_class == 'interactive':
            sentence_path = os.path.join(sentence_path, data_class, 'sentences.txt')
            tags_path=None
            self.load_sentence_tags(sentence_path, tags_path, data)
            
        else:
            raise ValueError("No data in train/val/test or interactve!")
        
        return data
    
    def data_iterator(self, data, shuffle=False):
        order = list(range(data['size']))
        if shuffle:
            random.seed(self.seed)
            random.shuffle(order)
        InterModel = False if 'tags' in data else True
        
        if data['size'] % self.batch_size == 0:
            BATCH_SIZE = data['size'] // self.batch_size
        else:
            BATCH_SIZE = data['size'] // self.batch_size + 1
        
        for i in range(BATCH_SIZE):
            # fetch sentences and tags
            if i * self.batch_size < data['size'] < (i+1) * self.batch_size:
                sentences = [data['sentences'][idx] for idx in order[i*self.batch_size:]]
                if not InterModel:
                    tags = [data['tags'][idx] for idx in order[i*self.batch_size:]]
            else:
                sentences = [data['sentences'][idx] for idx in order[i*self.batch_size:(i+1)*self.batch_size]]
                if not InterModel:
                    tags = [data['tags'][idx] for idx in order[i*self.batch_size:(i+1)*self.batch_size]]

            # batch length
            batch_len = len(sentences)

            # compute length of longest sentence in batch
            batch_max_subwords_len = max([len(s[0]) for s in sentences])
            max_subwords_len = min(batch_max_subwords_len, self.max_len)
            max_token_len = 0


            # prepare a numpy array with the data, initialising the data with pad_idx
            batch_data = self.token_pad_idx * np.ones((batch_len, max_subwords_len))
            batch_token_starts = []
            
            # copy the data to the numpy array
            for j in range(batch_len):
                cur_subwords_len = len(sentences[j][0])
                if cur_subwords_len <= max_subwords_len:
                    batch_data[j][:cur_subwords_len] = sentences[j][0]
                else:
                    batch_data[j] = sentences[j][0][:max_subwords_len]
                token_start_idx = sentences[j][-1]
                token_starts = np.zeros(max_subwords_len)
                token_starts[[idx for idx in token_start_idx if idx < max_subwords_len]] = 1
                batch_token_starts.append(token_starts)
                max_token_len = max(int(sum(token_starts)), max_token_len)
            
            if not InterModel:
                batch_tags = self.tag_pad_idx * np.ones((batch_len, max_token_len))
                for j in range(batch_len):
                    cur_tags_len = len(tags[j])  
                    if cur_tags_len <= max_token_len:
                        batch_tags[j][:cur_tags_len] = tags[j]
                    else:
                        batch_tags[j] = tags[j][:max_token_len]
            
            # since all data are indices, we convert them to torch LongTensors
            batch_data = torch.tensor(batch_data, dtype=torch.long)
            batch_token_starts = torch.tensor(batch_token_starts, dtype=torch.long)
            if not InterModel:
                batch_tags = torch.tensor(batch_tags, dtype=torch.long)

            # shift tensors to GPU if available
            batch_data, batch_token_starts = batch_data.to(self.device), batch_token_starts.to(self.device)
            if not InterModel:
                batch_tags = batch_tags.to(self.device)
                yield batch_data, batch_token_starts, batch_tags
            else:
                yield batch_data, batch_token_starts

                
data_loader = DataLoader(data_dir, bert_class, args, token_pad_idx=0, tag_pad_idx=-1)

HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=109540.0), HTML(value='')))




In [5]:
import time
load_proc_start = time.time()

train = data_loader.load_data('train')
val = data_loader.load_data('val')
test = data_loader.load_data('test')
load_proc_time = time.time() - load_proc_start
lp_mins, lp_secs = load_proc_time/60, load_proc_time%60
print('Load and Processing data cost: {0}m {1:2f}s.'.format(int(lp_mins), lp_secs))

args.train_size = train['size']
args.val_size = val['size']
args.test_size = test['size']

Load and Processing data cost: 0m 59.781951s.


# Model Structure and optimizers

In [6]:
from SequenceTagger import BertForSequenceTagging
from transformers.optimization import get_linear_schedule_with_warmup, AdamW

model = BertForSequenceTagging.from_pretrained(bert_class, num_labels=len(args.tag2idx))
model.to(args.device)

HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=624.0), HTML(value='')))




HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=411577189.0), HTML(value='')))




Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForSequenceTagging: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceTagging from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForSequenceTagging from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceTagging were not initialized from the model checkpoint at bert-base-chinese and are n

BertForSequenceTagging(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(21128, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_af

In [7]:
# finetuing whole model or only classifier
if args.full_finetuning:
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 
         'weight_decay': 0.0}
    ]
else:
    param_optimizer = list(model.classifier.named_parameters())
    optimizer_grouped_parameters = [{'params': [p for n, p in param_optimizer]}]

optimizer = AdamW(optimizer_grouped_parameters,
                  lr=args.learning_ratio,
                  correct_bias=False)
train_steps = args.train_size // args.batch_size
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=train_steps, num_training_steps=args.max_epoch_num * train_steps)

# Train Step

In [8]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import defaultdict

import numpy as np


def get_entities(seq, suffix=False):
    if any(isinstance(s, list) for s in seq):
        seq = [item for sublist in seq for item in sublist + ['O']]
    prev_tag = 'O'
    prev_type = ''
    begin_offset = 0
    chunks = []
    for i, chunk in enumerate(seq + ['O']):
        if suffix:
            tag = chunk[-1]
            type_ = chunk.split('-')[0]
        else:
            tag = chunk[0]
            type_ = chunk.split('-')[-1]

        if end_of_chunk(prev_tag, tag, prev_type, type_):
            chunks.append((prev_type, begin_offset, i-1))
        if start_of_chunk(prev_tag, tag, prev_type, type_):
            begin_offset = i
        prev_tag = tag
        prev_type = type_

    return chunks


def end_of_chunk(prev_tag, tag, prev_type, type_):
    chunk_end = False

    if prev_tag == 'E': chunk_end = True
    if prev_tag == 'S': chunk_end = True

    if prev_tag == 'B' and tag == 'B': chunk_end = True
    if prev_tag == 'B' and tag == 'S': chunk_end = True
    if prev_tag == 'B' and tag == 'O': chunk_end = True
    if prev_tag == 'I' and tag == 'B': chunk_end = True
    if prev_tag == 'I' and tag == 'S': chunk_end = True
    if prev_tag == 'I' and tag == 'O': chunk_end = True

    if prev_tag != 'O' and prev_tag != '.' and prev_type != type_:
        chunk_end = True

    return chunk_end


def start_of_chunk(prev_tag, tag, prev_type, type_):
    chunk_start = False

    if tag == 'B': chunk_start = True
    if tag == 'S': chunk_start = True

    if prev_tag == 'E' and tag == 'E': chunk_start = True
    if prev_tag == 'E' and tag == 'I': chunk_start = True
    if prev_tag == 'S' and tag == 'E': chunk_start = True
    if prev_tag == 'S' and tag == 'I': chunk_start = True
    if prev_tag == 'O' and tag == 'E': chunk_start = True
    if prev_tag == 'O' and tag == 'I': chunk_start = True

    if tag != 'O' and tag != '.' and prev_type != type_:
        chunk_start = True

    return chunk_start


def f1_score(y_true, y_pred, average='micro', digits=2, suffix=False):
    true_entities = set(get_entities(y_true, suffix))
    pred_entities = set(get_entities(y_pred, suffix))

    nb_correct = len(true_entities & pred_entities)
    nb_pred = len(pred_entities)
    nb_true = len(true_entities)

    p = 100 * nb_correct / nb_pred if nb_pred > 0 else 0
    r = 100 * nb_correct / nb_true if nb_true > 0 else 0
    score = 2 * p * r / (p + r) if p + r > 0 else 0

    return score


def accuracy_score(y_true, y_pred):
    if any(isinstance(s, list) for s in y_true):
        y_true = [item for sublist in y_true for item in sublist]
        y_pred = [item for sublist in y_pred for item in sublist]

    nb_correct = sum(y_t==y_p for y_t, y_p in zip(y_true, y_pred))
    nb_true = len(y_true)

    score = nb_correct / nb_true

    return score



In [12]:
import torch.nn as nn

if args.store_dir is not None:
    model = BertForSequenceTagging.from_pretrained(model_params_dir)

best_val_f1 = 0.0
patience_counter = 0

for epoch in range(1, args.max_epoch_num +1):
    epoch_start = time.time()
    
    print('\n*****',' Train Epoch {0}/{1} '.format(epoch, args.max_epoch_num), '*****')
    
    train_steps = args.train_size // args.batch_size
    val_steps = args.val_size // args.batch_size
    
    train_data_iterator = data_loader.data_iterator(train, shuffle=True)
    val_data_iterator = data_loader.data_iterator(val, shuffle=True)
    
    # Train step
    model.train()
    epoch_loss, epoch_avg_loss = 0.0, 0.0

    for batch in range(1, train_steps+1):
        batch_data, batch_token_starts, batch_tags = next(train_data_iterator)
        batch_masks = batch_data.gt(0)
    
        loss = model((batch_data, batch_token_starts), token_type_ids=None, attention_mask=batch_masks, labels=batch_tags)[0]
        
        model.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=args.clip_grad)
        
        optimizer.step()
        scheduler.step()
        
        epoch_loss += loss.item()
        epoch_avg_loss = epoch_loss / (batch+1)
        
        if batch % 50 == 0:
            print('| Train_epoch: {0} | Train_batch: {1}/{2} | Train_batch_loss: {3:4f} |'.format(epoch, batch, train_steps, epoch_avg_loss))
    
    # Val step
    model.eval()
    val_loss = 0.0
    true_tags, pred_tags = [], []
    
    for _ in range(val_steps):
        batch_data, batch_token_starts, batch_tags = next(val_data_iterator)
        batch_masks = batch_data.gt(0)
        
        loss = model((batch_data, batch_token_starts), token_type_ids=None, attention_mask=batch_masks, labels=batch_tags)[0]
        val_loss += loss.item()
        
        batch_outs = model((batch_data, batch_token_starts), token_type_ids=None, attention_mask=batch_masks)[0]
        batch_outs = batch_outs.detach().cpu().numpy()
        batch_tags = batch_tags.to('cpu').numpy()
        
        pred_tags.extend([[args.idx2tag.get(idx) for idx in indices] for indices in np.argmax(batch_outs, axis=2)])
        true_tags.extend([[args.idx2tag.get(idx) if idx != -1 else 'O' for idx in indices] for indices in batch_tags])
    
    assert len(pred_tags) == len(true_tags)

    f1 = f1_score(true_tags, pred_tags)
    loss = val_loss / val_steps
    acc = accuracy_score(true_tags, pred_tags)
    epoch_time = time.time() - epoch_start
    epoch_mins, epoch_secs = int(epoch_time/60), int(epoch_time%60)
    
    print('| Train_epoch: {0} | Train_loss: {2:4f} | Train&Val_time: {3}m {4}s |'.format(epoch, args.max_epoch_num, epoch_avg_loss, epoch_mins, epoch_secs))
    print('                 | Valid_loss: {0:4f} | Valid_f1: {1:4f} | Valid_Acc: {2:2f}% |'.format(loss, f1, 100*acc))

    if f1 - best_val_f1 >= 0.1:
        best_val_f1 = f1
        
        model_to_save = model.module if hasattr(model, 'module') else model
        model_path = model_params_dir + '/pytorch_model.bin'
        config_path = model_params_dir + '/config.json'
        
        torch.save(model_to_save.state_dict(), model_path, _use_new_zipfile_serialization=False)
        model_to_save.config.to_json_file(config_path)


*****  Train Epoch 1/20  *****
| Train_epoch: 1 | Train_loss: 1.004717 | Train&Val_time: 0m 14s |
                 | Valid_loss: 0.029796 | Valid_f1: 0.845309 | Valid_Acc: 64.592890% |

*****  Train Epoch 2/20  *****
| Train_epoch: 2 | Train_loss: 1.001204 | Train&Val_time: 0m 15s |
                 | Valid_loss: 0.029559 | Valid_f1: 0.684346 | Valid_Acc: 64.879587% |

*****  Train Epoch 3/20  *****


KeyboardInterrupt: 