In [1]:
"""
BERTの勉強 note3
"""
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

from transformers import AutoModel, AutoTokenizer
from transformers import TrainingArguments, Trainer

import pandas as pd
import numpy as np

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, precision_recall_fscore_support

from tqdm import tqdm
import glob, pickle

# Identifier of the pretrained checkpoint that is later passed to AutoModel.from_pretrained.
pretrained_model_name = "cl-tohoku/bert-base-japanese"

In [2]:
# Dataset class for this classification task
class LivedoorDataset(Dataset):
    """Map-style dataset that pairs per-sample encodings with per-sample labels.

    Attributes:
        encodings: Mapping from feature name to an indexable collection of
            per-sample values.
        labels: Indexable collection of per-sample labels; its length defines
            the dataset length.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        """Return the number of samples, taken from the label collection."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict of tensors.

        The dict holds one tensor per key of ``encodings`` plus the sample's
        label under the key ``"labels"``.
        """
        sample = {}
        for feature_name, feature_values in self.encodings.items():
            sample[feature_name] = torch.tensor(feature_values[idx])
        # The key has to be "labels" (plural), not "label".
        sample["labels"] = torch.tensor(self.labels[idx])
        return sample

In [3]:
# Load the previously saved train / validation / test Dataset objects from pickle files.
# NOTE(review): pickle.load can execute arbitrary code, so only load files you created yourself.
# Presumably these pickles hold LivedoorDataset instances, in which case the class must be
# defined before this cell runs — TODO confirm against the notebook that wrote them.
with open("../../DataSet/ldcc/dataloader/ds_train.pkl", "rb") as f:
    ds_train = pickle.load(f)
with open("../../DataSet/ldcc/dataloader/ds_valid.pkl", "rb") as f:
    ds_valid = pickle.load(f)
with open("../../DataSet/ldcc/dataloader/ds_test.pkl", "rb") as f:
    ds_test = pickle.load(f)

#### データの準備ここまで
ここからはtransformersを活用\
今回はBertModel（AutoModel）から自前で構築してみる

ポイントは
1. forwardの出力の最初は必ずlossの値にする（Trainerの仕様）
2. loss計算のためにforwardの引数には`labels`が必要（`label`ではなく、Datasetが返すキー名`labels`と一致させる）
3. その他のloss計算に必要な設定も引数として渡す。\
使わないとしても実行時エラーや学習に誤りが含まれる可能性があるので引数として受け口だけでも用意しておく

In [4]:
# Hand-built model for fine-tuning
class BertClassifier(nn.Module):
    """Sequence classifier: a pretrained encoder followed by dropout and a linear head.

    The forward pass returns the loss first whenever labels are supplied,
    which is the output order this notebook relies on for ``Trainer``.

    Args:
        pretrained_model: Encoder module whose forward output exposes a
            ``pooler_output`` attribute when called with ``return_dict=True``.
        num_labels: Number of target classes. Defaults to 9.
        dropout_prob: Dropout probability applied to the pooled output.
            Defaults to 0.1.
    """

    def __init__(self, pretrained_model, num_labels=9, dropout_prob=.1):
        super().__init__()

        self.bert = pretrained_model
        self.num_labels = num_labels

        # Take the encoder width from its config when one is available,
        # otherwise fall back to the previously hard-coded 768.
        encoder_config = getattr(pretrained_model, "config", None)
        hidden_size = getattr(encoder_config, "hidden_size", 768)

        self.dropout = nn.Dropout(p=dropout_prob)
        self.classifier = nn.Linear(in_features=hidden_size, out_features=num_labels)

        # Weight initialisation: small-variance normal weights and a zero bias.
        # (nn.init.normal_(bias, 0) would set mean=0 with the default std=1.0,
        # i.e. unit-variance noise rather than zeros.)
        nn.init.normal_(self.classifier.weight, std=.02)
        nn.init.zeros_(self.classifier.bias)

    def forward(
        self,
        input_ids,
        labels=None,
        attention_mask=None, token_type_ids=None, position_ids=None,
        head_mask=None, inputs_embeds=None, output_attentions=None,
        output_hidden_states=None, return_dict=None
    ):
        """Run the encoder and classification head.

        Args:
            input_ids: Token id tensor passed to the encoder.
            labels: Optional class-index tensor. When given, the cross-entropy
                loss is computed and returned first.
            attention_mask, token_type_ids, position_ids, head_mask,
            inputs_embeds, output_attentions, output_hidden_states:
                Forwarded unchanged to the encoder.
            return_dict: Accepted for signature compatibility only. The
                encoder is always called with ``return_dict=True`` because
                ``pooler_output`` is read by attribute.

        Returns:
            ``(loss, logits)`` when ``labels`` is provided, otherwise
            ``(logits,)``.
        """
        output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True
        )
        pooler_output = self.dropout(output.pooler_output)
        output_classifier = self.classifier(pooler_output)

        # Without labels there is nothing to compute a loss against
        # (e.g. plain inference), so return only the logits.
        if labels is None:
            return (output_classifier,)

        # Loss computation
        loss_func = nn.CrossEntropyLoss()
        loss = loss_func(output_classifier.view(-1, self.num_labels), labels.view(-1))

        # The loss comes first in the output
        return loss, output_classifier

In [5]:
# Load the pretrained encoder named by pretrained_model_name, then wrap it in the
# custom BertClassifier so it gains the classification head.
model = AutoModel.from_pretrained(pretrained_model_name)
my_model = BertClassifier(model)

Some weights of the model checkpoint at cl-tohoku/bert-base-japanese were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


#### 学習工程はTrainerを使って定義
* TrainingArgumentsでコンフィグ指定
* Trainerインスタンス作成
    - モデルやデータセットはここで渡す
    - 必要に応じて評価時のメトリクス計算関数をセット（accとかprとかf1とか）
* Trainer.train()で学習

In [6]:
# Metric function handed to the Trainer
def compute_metrics(pred):
    """Compute accuracy and macro-averaged precision / recall / F1.

    Args:
        pred: Object exposing ``label_ids`` (true class indices) and
            ``predictions`` (per-class scores; argmax over the last axis
            gives the predicted class).

    Returns:
        Dict with the keys ``accuracy``, ``f1``, ``precision`` and ``recall``.
    """
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    # zero_division=0 keeps the score at 0 for classes that are never predicted,
    # but without the UndefinedMetricWarning sklearn emits by default.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average='macro', zero_division=0
    )
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }

In [7]:
# Define the TrainingArguments and the Trainer
training_args = TrainingArguments(
    output_dir='./mymodel_outputs/',
    num_train_epochs=1,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=8,
    warmup_steps=100,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    no_cuda=False,
    evaluation_strategy='steps',  # evaluate on a step schedule, every `eval_steps` steps
    eval_steps=50
)

# Drop any Trainer left over from an earlier run of this cell before building a new one
# (presumably so the old instance can be released first — TODO confirm).
if "trainer" in locals():
    del trainer

# The model, the datasets and the metric function are all supplied here.
trainer = Trainer(
    model=my_model,
    args=training_args,
    train_dataset=ds_train,
    eval_dataset=ds_valid,
    compute_metrics=compute_metrics
)

In [8]:
# Fine-tune the model; %time reports the CPU and wall-clock time of the call.
%time trainer.train()

***** Running training *****
  Num examples = 5893
  Num Epochs = 1
  Instantaneous batch size per device = 4
  Total train batch size (w. parallel, distributed & accumulation) = 4
  Gradient Accumulation steps = 1
  Total optimization steps = 1474


Step,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
50,1.9172,1.930099,0.298507,0.17156,0.275792,0.272559
100,1.495,1.135765,0.677069,0.588255,0.665413,0.63226
150,0.908,0.783815,0.770692,0.730839,0.798093,0.740459
200,0.7969,0.917494,0.693351,0.672843,0.801011,0.668779
250,0.7287,0.594911,0.795115,0.760493,0.81794,0.768404
300,0.5105,0.547249,0.824966,0.815931,0.838143,0.826081
350,0.5183,0.488564,0.83175,0.814761,0.865358,0.814704
400,0.7206,0.423532,0.886024,0.882898,0.894517,0.882693
450,0.4323,0.411885,0.869742,0.850046,0.880185,0.846341
500,0.6353,0.336761,0.902307,0.895508,0.900963,0.895338


***** Running Evaluation *****
  Num examples = 737
  Batch size = 8
  _warn_prf(average, modifier, msg_start, len(result))
***** Running Evaluation *****
  Num examples = 737
  Batch size = 8
  _warn_prf(average, modifier, msg_start, len(result))
***** Running Evaluation *****
  Num examples = 737
  Batch size = 8
***** Running Evaluation *****
  Num examples = 737
  Batch size = 8
***** Running Evaluation *****
  Num examples = 737
  Batch size = 8
***** Running Evaluation *****
  Num examples = 737
  Batch size = 8
***** Running Evaluation *****
  Num examples = 737
  Batch size = 8
***** Running Evaluation *****
  Num examples = 737
  Batch size = 8
***** Running Evaluation *****
  Num examples = 737
  Batch size = 8
***** Running Evaluation *****
  Num examples = 737
  Batch size = 8
Saving model checkpoint to ./mymodel_outputs/checkpoint-500
Trainer.model is not a `PreTrainedModel`, only saving its state dict.
***** Running Evaluation *****
  Num examples = 737
  Batch size = 8
*

CPU times: total: 18min 47s
Wall time: 18min 58s


TrainOutput(global_step=1474, training_loss=0.515492757774758, metrics={'train_runtime': 1138.1806, 'train_samples_per_second': 5.178, 'train_steps_per_second': 1.295, 'total_flos': 0.0, 'train_loss': 0.515492757774758, 'epoch': 1.0})

In [23]:
# Evaluate on the validation data (no dataset argument, so the Trainer's eval_dataset is used).
# NOTE(review): this cell's execution count (In [23]) is out of sequence with the cells around
# it; re-run the notebook top to bottom to confirm the recorded output still matches.
trainer.evaluate()

***** Running Evaluation *****
  Num examples = 737
  Batch size = 8


{'eval_loss': 0.2155521959066391,
 'eval_accuracy': 0.9402985074626866,
 'eval_f1': 0.9314369480683583,
 'eval_precision': 0.9318940677639961,
 'eval_recall': 0.9315479030773096,
 'eval_runtime': 19.7575,
 'eval_samples_per_second': 37.302,
 'eval_steps_per_second': 4.707,
 'epoch': 1.0}

In [9]:
# Evaluate the fine-tuned model on the test data.
trainer.evaluate(ds_test)

***** Running Evaluation *****
  Num examples = 737
  Batch size = 8


{'eval_loss': 0.3659582734107971,
 'eval_accuracy': 0.9199457259158752,
 'eval_f1': 0.9123277184731702,
 'eval_precision': 0.917057359689104,
 'eval_recall': 0.9117513363006746,
 'eval_runtime': 19.0669,
 'eval_samples_per_second': 38.653,
 'eval_steps_per_second': 4.878,
 'epoch': 1.0}