In [5]:
import gc
import logging
import warnings
import multiprocessing
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, hamming_loss
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch_optimizer import Lamb
import pytorch_lightning as pl
from pytorch_lightning.utilities.seed import seed_everything
from transformers import BertModel, BertJapaneseTokenizer

warnings.filterwarnings('ignore')
logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
logging.getLogger('transformers').setLevel(logging.ERROR)

In [6]:
class Cfg:
    debug = False
    seed = 42
    lr = 1e-5
    max_len = 32
    epochs = 20
    n_folds = 5
    train_batch_size = 512
    val_batch_size = 8192
    num_classes = 5
    model_name = 'cl-tohoku/bert-base-japanese' 
    data_name = 'text'
    multilabel_columns = ['pos&neg', 'pos', 'neg', 'neu', 'non']
    n_gpus = torch.cuda.device_count()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [7]:
seed_everything(Cfg.seed, workers=True)

42

In [8]:
data = pd.read_csv(
    '../input/preprocessed_data.csv', 
    dtype={
        'topic': object, 
        'text': object, 
        'pos&neg': np.float32, 
        'pos': np.float32, 
        'neg': np.float32, 
        'neu': np.float32, 
        'non': np.float32
    }
)

In [9]:
if Cfg.debug == True:
    data = data.sample(10000, random_state=Cfg.seed)

print(f'Data: {len(data)}')

Data: 299145


In [10]:
data.head(1)

Unnamed: 0,text,pos&neg,pos,neg,neu,non
0,エクスペリアのGPS南北が逆になるのはデフォだったのか。,0.0,0.0,1.0,1.0,0.0


In [11]:
class TweetsDataset(Dataset):
    def __init__(self, data, tokenizer, config):
        super().__init__()
        self.data = data
        self.tokenizer = tokenizer
        self.config = config

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        data_row = self.data.iloc[index]
        text = data_row[self.config.data_name]
        labels = data_row[self.config.multilabel_columns]

        encoded = self.tokenizer.encode_plus(
            text, 
            add_special_tokens=True,
            max_length=self.config.max_len,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        return {
            'input_ids': encoded['input_ids'].flatten(),
            'token_type_ids': encoded['token_type_ids'].flatten(),
            'attention_mask': encoded['attention_mask'].flatten(),
            'labels': torch.tensor(labels)
        }

Optimizer: `LAMB`

https://arxiv.org/abs/1904.00962

LAMBは層ごとの適応的な学習率の正当化が働きます。

<br>

ラージバッチトレーニングの効率化と精度の安定が期待出来るため、採用します。

<br>

Loss Function: `Focal Loss`

https://arxiv.org/abs/1708.02002

Focal Lossは、予測が上手く出来ているサンプルデータのlossを小さくし、

予測が上手く出来ていないサンプルデータの学習を促進させます。

それにより、学習データセットのクラス間が不均衡なことが要因で起きる問題に作用します。

<br>

不均衡な学習データセットであるため、採用します。


In [12]:
# reference: 
# https://www.kaggle.com/c/tgs-salt-identification-challenge/discussion/65938
class BinaryFocalLossWithLogits(nn.Module):
    def __init__(self, alpha=1, gamma=2, reduce=False):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduce = reduce

    def forward(self, preds, labels):
        bce_criterion = nn.BCEWithLogitsLoss()
        bce_loss = bce_criterion(preds, labels)
        pt = torch.exp(-bce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * bce_loss
        if self.reduce == True:
            focal_loss = torch.mean(focal_loss)
        return focal_loss

In [13]:
class CustomModel(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.bert = BertModel.from_pretrained(self.config.model_name)
        self.linear = nn.Linear(
            self.bert.config.hidden_size, 
            self.config.num_classes
            )
        self.criterion = BinaryFocalLossWithLogits()
    
    def forward(self, input_ids, attention_mask, labels=None):
        output = self.bert(
            input_ids, 
            attention_mask=attention_mask
            )
        preds = self.linear(output.pooler_output)

        if labels is not None:
            loss = self.criterion(preds, labels)
            return loss
        else:
            return preds

    def training_step(self, batch, batch_idx):
        loss = self.forward(input_ids=batch['input_ids'],
                            token_type_ids=batch['token_type_ids'],
                            attention_mask=batch['attention_mask'],
                            labels=batch['labels'])
        self.log('train_loss', loss)
        return loss
    
    def validation_step(self, batch, batch_idx):
        val_loss = self.forward(input_ids=batch['input_ids'],
                                token_type_ids=batch['token_type_ids'],
                                attention_mask=batch['attention_mask'],
                                labels=batch['labels'])
        self.log('val_loss', val_loss)
        return val_loss
    
    def test_step(self, batch, batch_idx):
        return self.valadation_step(batch, batch_idx)

    def configure_optimizers(self):
        return Lamb(self.parameters(), lr=self.config.lr)

In [14]:
def inference(model, data_loader, config):
    all_outputs = []
    all_labels = []
    
    for batch_idx, batch in enumerate(data_loader):
        ids = batch['input_ids']
        token_type_ids = batch['token_type_ids']
        mask = batch['attention_mask']
        labels = batch['labels']
        
        ids = ids.to(config.device)
        token_type_ids = token_type_ids.to(config.device)
        mask = mask.to(config.device)
        labels = labels.to(config.device)
        
        model.to(config.device)
        model.eval()
        with torch.no_grad():
            outputs = model(
                input_ids=ids,
                token_type_ids=token_type_ids, 
                attention_mask=mask
                )
        outputs = outputs.cpu().detach().numpy().tolist()
        labels = labels.cpu().detach().numpy().tolist() 
        all_outputs.extend(outputs)
        all_labels.extend(labels)
    
    return all_outputs, all_labels

In [15]:
def run_train(fold, tokenizer, data, tr_idx, val_idx, config):
    print(f'     ----- Fold: {fold} -----')
    train, val = data.iloc[tr_idx], data.iloc[val_idx]

    train_ds = TweetsDataset(train, tokenizer, config)
    val_ds = TweetsDataset(val, tokenizer, config)

    train_loader = DataLoader(train_ds, 
                              batch_size=config.train_batch_size, 
                              num_workers=multiprocessing.cpu_count(), 
                              pin_memory=True, 
                              drop_last=True,
                              shuffle=True)
    
    val_loader = DataLoader(val_ds, 
                            batch_size=config.val_batch_size, 
                            num_workers=multiprocessing.cpu_count(), 
                            pin_memory=True, 
                            drop_last=False,
                            shuffle=False)

    checkpoint = pl.callbacks.ModelCheckpoint(monitor='val_loss', 
                                              mode='min', 
                                              save_top_k=1,  
                                              save_weights_only=True, 
                                              dirpath=f'model_fold{fold}/')
    
    es_callback = pl.callbacks.EarlyStopping(monitor='val_loss', 
                                             patience=1)
    
    tb_logger = pl.loggers.TensorBoardLogger(f'model_fold{fold}_logs/')

    trainer = pl.Trainer(max_epochs=config.epochs,
                         gpus=config.n_gpus,
                         logger=tb_logger,
                         callbacks=[checkpoint,es_callback],
                         progress_bar_refresh_rate=0)
    
    model = CustomModel(config)
    model.to(config.device)
    trainer.fit(model, train_loader, val_loader)
    
    model.load_state_dict(torch.load(checkpoint.best_model_path)['state_dict'])
    outputs, labels = inference(model, val_loader, config)

    del train_ds, val_ds, train_loader, val_loader, model
    gc.collect()
    torch.cuda.empty_cache()

    return outputs, labels

Metrics: `Hamming Loss`

参考: https://buildersbox.corp-sansan.com/entry/2019/12/11/110000

<br>

$$
\text{Hamming Loss} = \frac{1}{n}\sum_{i \in \mathcal{I}} \frac{|\boldsymbol
{y}_i \bigtriangleup \boldsymbol {z}_i|}{m}
$$

<br>

$yi$: 真のラベルベクトルの真偽値

$zi$: 予測ラベルベクトルの真偽値

$\bigtriangleup$: XOR

$m$: 全事例

<br>

ハミング距離で損失を測る評価指標。

各事例集合（事例: マルチラベルの組み合わせ）ごとに、

真のラベルベクトルの真偽値と、予測ラベルベクトルの真偽値との間でXOR演算を取り、 

真偽値が異なる部分が全事例のラベル中何個あるかを割合で出して、

全事例に対して平均を取ります。

各事例ごとの評価値は候補ラベル数に依存します。

<br>

マルチラベルクラス分類に向いていると思われる評価指標のため、採用します。

<br>

CV: `MultilabelStratifiedKFold`

マルチラベルクラス分類のため、採用します。

In [16]:
mskf = MultilabelStratifiedKFold(n_splits=Cfg.n_folds, shuffle=True, random_state=Cfg.seed)
tokenizer = BertJapaneseTokenizer.from_pretrained(Cfg.model_name)
acc_scores = 0.0
hl_scores = 0.0

for i, (tr_idx, val_idx) in enumerate(mskf.split(data['text'], data[Cfg.multilabel_columns])):
    outputs, labels = run_train(i, tokenizer, data, tr_idx, val_idx, Cfg)
    outputs = np.array(outputs)
    labels = np.array(labels)
    preds = outputs > 0

    acc_score = accuracy_score(labels, preds)
    hl_score = hamming_loss(labels, preds)
    acc_scores += acc_score
    hl_scores += hl_score

    print(f'FOLD{i} ACCURACY SCORE: {acc_score:.5f}')
    print(f'FOLD{i} HAMMING LOSS SCORE: {hl_score:.5f}')
print(f'{Cfg.n_folds}FOLDS ACCURACY CV SCORE: {acc_scores/Cfg.n_folds:.5f}')
print(f'{Cfg.n_folds}FOLDS HAMMING LOSS CV SCORE: {hl_scores/Cfg.n_folds:.5f}')

Downloading:   0%|          | 0.00/258k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/104 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/479 [00:00<?, ?B/s]

     ----- Fold: 0 -----


Downloading:   0%|          | 0.00/445M [00:00<?, ?B/s]

FOLD0 ACCURACY SCORE: 0.68282
FOLD0 HAMMING LOSS SCORE: 0.10581
     ----- Fold: 1 -----
FOLD1 ACCURACY SCORE: 0.68897
FOLD1 HAMMING LOSS SCORE: 0.10683
     ----- Fold: 2 -----
FOLD2 ACCURACY SCORE: 0.68884
FOLD2 HAMMING LOSS SCORE: 0.10493
     ----- Fold: 3 -----
FOLD3 ACCURACY SCORE: 0.67873
FOLD3 HAMMING LOSS SCORE: 0.10709
     ----- Fold: 4 -----
FOLD4 ACCURACY SCORE: 0.68295
FOLD4 HAMMING LOSS SCORE: 0.10637
5FOLDS ACCURACY CV SCORE: 0.68446
5FOLDS HAMMING LOSS CV SCORE: 0.10621
