In [4]:
!pip install -q -r requirements.txt

## Import Libraries

In [5]:
import gc
import logging
import warnings
import itertools
import multiprocessing
import numpy as np
import pandas as pd
from sklearn.model_selection import GroupKFold
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.utilities.seed import seed_everything
from transformers import (
    BertForTokenClassification, 
    BertJapaneseTokenizer, 
    get_linear_schedule_with_warmup
    )

warnings.filterwarnings('ignore')
logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
logging.getLogger('transformers').setLevel(logging.ERROR)

In [6]:
class Cfg:
    debug = False
    seed = 42
    epochs = 5
    lr = 1e-5
    weight_decay = 1e-2
    max_len = 193
    n_folds = 5
    num_entities = 8
    train_batch_size = 32
    val_batch_size = 256
    group_col = 'curid'
    label_col = 'label'
    model_name = 'cl-tohoku/bert-base-japanese-whole-word-masking' 
    n_gpus = torch.cuda.device_count()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [7]:
seed_everything(Cfg.seed, workers=True)

42

In [8]:
data = pd.read_csv(
    'data/preprocessed_data.csv',
    dtype={
        'curid': object,
        'text_body': object,
        'text': object,
        'label': np.int32
    })

In [9]:
if Cfg.debug:
    data = data.sample(1000, random_state=Cfg.seed)

In [10]:
def init_logger(file_path):
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(logging.Formatter('%(message)s'))
    file_handler = logging.FileHandler(filename=file_path)
    file_handler.setFormatter(logging.Formatter('%(message)s'))
    logger.addHandler(stream_handler)
    logger.addHandler(file_handler)
    return logger

## Create Dataset

In [11]:
class NERTokenizer(BertJapaneseTokenizer):
    def bio_tagger(self, text, label, num_entities):
        '''
        IO法でラベリングされているテキストに、
        トークナイズと合わせて、BIO法を適用する
        '''
        tokens = self.tokenize(text)
        if label > 0:
            labels = [label + num_entities] * len(tokens)
            labels[0] = label
        else:
            labels = [0] * len(tokens)
        return tokens, labels

    def encode_plus_tagged(self, text, label, max_length, num_entities):
        '''
        トークナイズ結果に合わせてラベル付けをし、
        エンコーディング
        '''
        token_arr, label_arr = [], []
        tokens, labels = self.bio_tagger(text, label, num_entities)
        token_arr.extend(tokens)
        label_arr.extend(labels)

        input_ids = self.convert_tokens_to_ids(token_arr)
        encoded = self.prepare_for_model(input_ids,
                                         max_length=max_length,
                                         padding='max_length',
                                         truncation=True)
        # [CLS], [SEP], [PAD]のラベルを0として追加
        label_arr = [0] + label_arr[:max_length-2] + [0]
        encoded['labels'] = label_arr + [0] * (max_length - len(label_arr))
        return encoded

    def encode_plus_untagged(self, text_body, text, max_length):
        '''
        トークナイズとスパン取得を行い、
        エンコーディング
        '''
        tokens, tokens_for_spans = [], []
        words = self.word_tokenizer.tokenize(text)
        for word in words:
            subwords = self.subword_tokenizer.tokenize(word)
            tokens.extend(subwords)
            if subwords[0] == '[UNK]':
                tokens_for_spans.append(word)
            else:
                tokens_for_spans.extend([subword.replace('##','') for subword in subwords])

        pos = 0
        spans = []
        for token in tokens_for_spans:
            token_len = len(token)
            while True:
                if token != text_body[pos:pos+token_len]:
                    pos += 1
                else:
                    spans.append([pos, pos+token_len])
                    pos += token_len
                    break

        input_ids = self.convert_tokens_to_ids(tokens)
        encoded = self.prepare_for_model(input_ids,
                                         max_length=max_length,
                                         padding='max_length',
                                         truncation=True)
        # [CLS], [SEP], [PAD]に対応するスパン追加
        n_seq = len(encoded['input_ids'])
        spans = [[-1, -1]] + spans[:n_seq-2]
        spans = spans + [[-1, -1]] * (n_seq - len(spans))
        
        encoded = {k: torch.tensor([v]) for k, v in encoded.items()}
        return encoded, spans

    @staticmethod
    def viterbi_optimizer(preds, num_entities, penalty=10000):
        '''
        BIO法のルールに従わない予測ラベル列に、
        ペナルティを与えて、予測値を最適化する
        '''
        m = 2 * num_entities + 1
        penalty_matrix = np.zeros([m,m])
        for i in range(m):
            for j in range(num_entities+1, m):
                if not ((i == j) or (num_entities+i == j)):
                    penalty_matrix[i,j] = penalty
        
        path = [[i] for i in range(m)]
        preds_path = preds[0] - penalty_matrix[0,:]
        preds = preds[1:]

        for pred in preds:
            assert len(pred) == 2 * num_entities + 1
            pred_matrix = np.array(preds_path).reshape(-1,1) + np.array(pred).reshape(1,-1)
            pred_matrix -= penalty_matrix
            preds_path = pred_matrix.max(axis=0)
            pred_argmax = pred_matrix.argmax(axis=0)
            path = [path[idx]+[i] for i, idx in enumerate(pred_argmax)]

        optimized_preds = path[np.argmax(preds_path)]
        return optimized_preds

    def convert_bert_output_to_entities(self, text_body, preds, spans, num_entities):
        '''
        同じラベルが連続するトークンをまとめて、
        固有表現として抽出する
        '''
        assert len(spans) == len(preds)
        # [CLS], [SEP], [PAD]に対応する箇所を削除
        preds = [pred for pred, span in zip(preds, spans) if span[0] != -1]
        spans = [span for span in spans if span[0] != -1]
        preds = self.viterbi_optimizer(preds, num_entities)
        
        entities = []
        for pred, group in itertools.groupby(enumerate(preds), key=lambda x: x[1]):
            group = list(group)
            start = spans[group[0][0]][0]
            end = spans[group[-1][0]][1]

            if pred != 0:
                # Bならば
                if 1 <= pred <= num_entities:
                    entity = {
                        'name': text_body[start:end],
                        'span': [start, end],
                        'type_id': pred
                    }
                    entities.append(entity)
                # Iならば
                else:
                    entity['span'][1] = end
                    entity['name'] = text_body[entity['span'][0]:entity['span'][1]]
        return entities

In [12]:
class NERDataset(Dataset):
    def __init__(self, data, tokenizer, config):
        super().__init__()
        self.data = data
        self.tokenizer = tokenizer
        self.config = config

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        data_row = self.data.iloc[index]
        text = data_row['text']
        label = data_row['label']
        
        encoded = self.tokenizer.encode_plus_tagged(
            text, label, self.config.max_len, self.config.num_entities
            )
        encoded = {k: torch.tensor(v) for k, v in encoded.items()}
        return {
            'input_ids': encoded['input_ids'].flatten(),
            'token_type_ids': encoded['token_type_ids'].flatten(),
            'attention_mask': encoded['attention_mask'].flatten(),
            'labels': encoded['labels'].flatten()
        }


class NERDataModule(pl.LightningDataModule):
    def __init__(self, train_data, val_data, tokenizer, config):
        super().__init__()
        self.train_data = train_data
        self.val_data = val_data
        self.tokenizer = tokenizer
        self.config = config

    def create_dataset(self, mode):
        return (
            NERDataset(self.train_data, self.tokenizer, self.config)
            if mode == 'train'
            else NERDataset(self.val_data, self.tokenizer, self.config)
        )

    def train_dataloader(self):
        train_ds = self.create_dataset(mode='train')
        train_loader = DataLoader(train_ds,
                                  batch_size=self.config.train_batch_size,
                                  num_workers=multiprocessing.cpu_count(),
                                  pin_memory=True,
                                  drop_last=True,
                                  shuffle=True)
        return train_loader

    def val_dataloader(self):
        val_ds = self.create_dataset(mode='val')
        val_loader = DataLoader(val_ds,
                                batch_size=self.config.val_batch_size,
                                num_workers=multiprocessing.cpu_count(),
                                pin_memory=True,
                                drop_last=False,
                                shuffle=False)
        return val_loader

## Create Model

#### BERT Tagger

In [13]:
class NERModel(pl.LightningModule):
    def __init__(self, config, num_training_steps):
        super().__init__()
        self.config = config
        self.num_training_steps = num_training_steps
        self.bert = BertForTokenClassification.from_pretrained(
            self.config.model_name,
            num_labels=2 * self.config.num_entities + 1
            )
        self.criterion = nn.CrossEntropyLoss()
        
    def forward(self, input_ids, token_type_ids, attention_mask, labels=None):
        preds = self.bert(
            input_ids, 
            token_type_ids=token_type_ids, 
            attention_mask=attention_mask
            )
        if labels is not None:
            loss = self.criterion(preds, labels)
            return loss
        return  preds
    
    def training_step(self, batch, batch_idx):
        output = self.bert(**batch)
        loss = output.loss
        self.log('train_loss', loss)
        return loss
    
    def validation_step(self, batch, batch_idx):
        output = self.bert(**batch)
        val_loss = output.loss
        self.log('val_loss', val_loss)
        
    def configure_optimizers(self):
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in self.bert.named_parameters()
                            if not any(nd in n for nd in no_decay)],
                'weight_decay': self.config.weight_decay
            },
            {
                'params': [p for n, p in self.bert.named_parameters()
                            if any(nd in n for nd in no_decay)],
                'weight_decay': 0.0
            }
        ]
        optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=self.config.lr)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=0,
            num_training_steps=self.num_training_steps
        )
        return [optimizer], [scheduler]

## FineTuning -> Inference -> Evaluation

In [14]:
def ner_inference(model, val_data, tokenizer, config):
    all_entities, all_targets = [], []
    n_val = len(val_data)
    for i in range(n_val):
        encoded, spans = tokenizer.encode_plus_untagged(
            val_data.iloc[i]['text_body'],
            val_data.iloc[i]['text'], 
            config.max_len
            )
        encoded = {k: v.to(config.device) for k, v in encoded.items()}
        model.to(config.device)
        model.eval()
        with torch.no_grad():
            output = model(**encoded)
        preds = output.logits[0].cpu().detach().numpy().tolist()
        entities = tokenizer.convert_bert_output_to_entities(
            val_data.iloc[i]['text_body'], 
            preds, 
            spans,
            config.num_entities
            )
        all_entities.append(entities)
        # モデル評価の為に、ターゲットのエンティティを作成
        target_name = val_data.iloc[i]['text']
        target_span = [val_data.iloc[i]['start'], val_data.iloc[i]['end']]
        target_typeId = val_data.iloc[i]['label']
        if target_typeId == 0:
            targets = []
        else:
            targets = [{'name': target_name, 'span': target_span, 'type_id': target_typeId}]
        all_targets.append(targets)

    return all_entities, all_targets

In [15]:
def ner_evaluation(entities_arr, targets_arr):
    n_entities, n_targets, n_correct = 0, 0, 0
    for entities, targets in zip(entities_arr, targets_arr):
        get_span_type = lambda x: (x['span'][0], x['span'][1], x['type_id'])
        set_entities = set(get_span_type(entity) for entity in entities)
        set_targets = set(get_span_type(target) for target in targets)
        n_targets += len(targets)
        n_entities += len(entities)
        n_correct += len(set_entities & set_targets)
    precision = n_correct / n_entities
    recall = n_correct / n_targets
    f1 = 2 * precision * recall / (precision + recall)
    return {'precision': precision, 'recall': recall, 'f1': f1}

In [16]:
def run_train(fold, tokenizer, data, tr_idx, val_idx, logger, config):
    logger.info(f'\t----- Fold: {fold} -----')
    train, val = data.iloc[tr_idx], data.iloc[val_idx]

    checkpoint = pl.callbacks.ModelCheckpoint(monitor='val_loss', 
                                              mode='min', 
                                              save_top_k=1,  
                                              save_weights_only=True, 
                                              dirpath=f'model_folds_seed_{config.seed}/model_fold{fold}/')
    
    es_callback = pl.callbacks.EarlyStopping(monitor='val_loss', 
                                             patience=3)
    
    tb_logger = pl.loggers.TensorBoardLogger(f'model_folds_seed_{config.seed}/model_fold{fold}_logs/')

    trainer = pl.Trainer(max_epochs=config.epochs,
                         gpus=config.n_gpus,
                         logger=tb_logger,
                         callbacks=[checkpoint,es_callback],
                         progress_bar_refresh_rate=0)
    
    num_training_steps = ((len(train)) // (config.train_batch_size)) * float(config.epochs)
    model = NERModel(config, num_training_steps)
    datamodule = NERDataModule(train, val, tokenizer, config)
    trainer.fit(model, datamodule=datamodule)

    model.load_state_dict(torch.load(checkpoint.best_model_path)['state_dict'])
    entities_arr, targets_arr = ner_inference(model, val, tokenizer, config)

    del datamodule, model
    gc.collect()
    torch.cuda.empty_cache()

    return entities_arr, targets_arr

## Metrics
#### Precision, Recall, F1
## CV
#### GroupKFold
`curid`をGroup IDとして、 バリデーション分割をする

In [17]:
def run_all(config):
    gkf = GroupKFold(n_splits=config.n_folds)
    tokenizer = NERTokenizer.from_pretrained(config.model_name)
    logger = init_logger('cv_results/cv.log')
    precision_score = 0.0
    recall_score = 0.0
    f1_score = 0.0

    for i, (tr_idx, val_idx) in enumerate(gkf.split(data, data[config.label_col], data[config.group_col])):
        entities_arr, targets_arr = run_train(i, tokenizer, data, tr_idx, val_idx, logger, config)
        eval_result = ner_evaluation(entities_arr, targets_arr)
        precision_score += eval_result['precision']
        recall_score += eval_result['recall']
        f1_score += eval_result['f1']
        
        logger.info(f'FOLD{i} PRECISION SCORE: {eval_result["precision"]:.5f}')
        logger.info(f'FOLD{i} RECALL SCORE: {eval_result["recall"]:.5f}')
        logger.info(f'FOLD{i} F1 SCORE: {eval_result["f1"]:.5f}')
    logger.info(f'{config.n_folds}FOLDS PRECISION CV SCORE: {precision_score/config.n_folds:.5f}')
    logger.info(f'{config.n_folds}FOLDS RECALL CV SCORE: {recall_score/config.n_folds:.5f}')
    logger.info(f'{config.n_folds}FOLDS F1 CV SCORE: {f1_score/config.n_folds:.5f}')

In [18]:
run_all(Cfg)

	----- Fold: 0 -----
FOLD0 PRECISION SCORE: 0.76272
FOLD0 RECALL SCORE: 0.76243
FOLD0 F1 SCORE: 0.76257
	----- Fold: 1 -----
FOLD1 PRECISION SCORE: 0.75396
FOLD1 RECALL SCORE: 0.75681
FOLD1 F1 SCORE: 0.75538
	----- Fold: 2 -----
FOLD2 PRECISION SCORE: 0.75427
FOLD2 RECALL SCORE: 0.75656
FOLD2 F1 SCORE: 0.75541
	----- Fold: 3 -----
FOLD3 PRECISION SCORE: 0.77144
FOLD3 RECALL SCORE: 0.77407
FOLD3 F1 SCORE: 0.77275
	----- Fold: 4 -----
FOLD4 PRECISION SCORE: 0.77386
FOLD4 RECALL SCORE: 0.77416
FOLD4 F1 SCORE: 0.77401
5FOLDS PRECISION CV SCORE: 0.76325
5FOLDS RECALL CV SCORE: 0.76481
5FOLDS F1 CV SCORE: 0.76403
