# BERTで含意関係認識（日本語）
文1（前提）と文2（仮説）の関係が、含意・中立・矛盾のいずれであるかを分類する。

# ライブラリの準備

In [None]:
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
!pip3 install sentence-transformers
!pip3 install fugashi ipadic
!pip3 install unidic-lite

In [None]:
from sentence_transformers.readers import InputExample
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import (
    CEBinaryClassificationEvaluator,
    CESoftmaxAccuracyEvaluator
)
from torch.utils.data import DataLoader
import pathlib

# 含意関係認識の日本語データセットのダウンロード
データセットの配布ページ（日本語SNLI(JSNLI)データセット）: https://nlp.ist.i.kyoto-u.ac.jp/?%E6%97%A5%E6%9C%AC%E8%AA%9ESNLI%28JSNLI%29%E3%83%87%E3%83%BC%E3%82%BF%E3%82%BB%E3%83%83%E3%83%88

In [None]:
!wget -O jsnli.zip "https://nlp.ist.i.kyoto-u.ac.jp/DLcounter/lime.cgi?down=https://nlp.ist.i.kyoto-u.ac.jp/nl-resource/JSNLI/jsnli_1.1.zip&name=JSNLI.zip"
!unzip jsnli.zip

# データの読み込み

In [None]:
import csv

# Load the JSNLI splits. The files have no header row; the columns are
# named label / text (premise) / hypothesis here.
# QUOTE_NONE: assumes the files are plain TSV without CSV-style quoting, so a
# sentence that happens to begin with '"' is not parsed as a quoted field
# (which could swallow tabs/newlines and merge rows).
data_train = pd.read_table('./jsnli_1.1/train_w_filtering.tsv',
                           names=('label', 'text', 'hypothesis'),
                           quoting=csv.QUOTE_NONE)
display(data_train)
data_test = pd.read_table('./jsnli_1.1/dev.tsv',
                          names=('label', 'text', 'hypothesis'),
                          quoting=csv.QUOTE_NONE)
display(data_test)

# データの確認（統計量・欠損値・型）

In [None]:
# Show the summary statistics of the training split (count/unique/top/freq for text columns)
data_train.describe()

In [None]:
# Same summary statistics for the dev split used as the test set
data_test.describe()

In [None]:
# Per-column missing-value check for both splits:
# True = the column contains at least one missing value, False = none.
display(data_train.isna().any(), data_test.isna().any())

In [None]:
# Check the column dtypes of the training split
data_train.dtypes

# 前処理（ラベルの数値化・テキストの正規化）

In [None]:
# Convert the string labels to integer class ids.
# entailment:0 neutral:1 contradiction:2
label2int = {"contradiction": 2, "entailment": 0, "neutral": 1}
# Reassign the mapped column instead of calling
# data_train['label'].replace(..., inplace=True): an in-place call on a
# selected column is chained assignment, which pandas warns about and which
# has no effect under Copy-on-Write. Mapping through label2int also keeps the
# label/id table in one place. Values that are already ids pass through
# unchanged, so re-running this cell is harmless.
data_train['label'] = data_train['label'].map(lambda v: label2int.get(v, v))
data_test['label'] = data_test['label'].map(lambda v: label2int.get(v, v))
# Fail fast on any label outside the three known classes (including missing
# labels) instead of letting it surface later inside training.
if not (data_train['label'].isin(list(label2int.values())).all()
        and data_test['label'].isin(list(label2int.values())).all()):
    raise ValueError(f"unexpected label value; expected one of {list(label2int)}")

In [None]:
# Remove half-width spaces from both sentence columns (presumably the
# word-segmentation spaces of the JSNLI files — confirm). str.replace returns
# a new Series; each result is assigned back to its column (nothing in place).
data_train['text'], data_train['hypothesis'] = (
    data_train['text'].str.replace(' ', '', regex=False),
    data_train['hypothesis'].str.replace(' ', '', regex=False),
)
data_test['text'], data_test['hypothesis'] = (
    data_test['text'].str.replace(' ', '', regex=False),
    data_test['hypothesis'].str.replace(' ', '', regex=False),
)

In [None]:
# Remove the Japanese full stop '。' from both sentence columns, for the case
# where dialogue-system utterances come without a trailing full stop.
# Each result is a new Series assigned back to its column (nothing in place).
data_train['text'], data_train['hypothesis'] = (
    data_train['text'].str.replace('。', '', regex=False),
    data_train['hypothesis'].str.replace('。', '', regex=False),
)
data_test['text'], data_test['hypothesis'] = (
    data_test['text'].str.replace('。', '', regex=False),
    data_test['hypothesis'].str.replace('。', '', regex=False),
)

In [None]:
# Show both splits after label conversion and text normalization.
display(data_train, data_test)

# sentence-transformers の入力形式（InputExample）にデータを変換

In [None]:
def line2inp(row):
    """Build a sentence-transformers InputExample from one DataFrame row.

    The row must provide 'text', 'hypothesis' and 'label'. The two sentences
    become the input pair (in that order) and 'label' is passed through as-is.
    """
    pair = [row['text'], row['hypothesis']]
    return InputExample(texts=pair, label=row['label'])

In [None]:
# Turn every row into an InputExample (texts=[text, hypothesis], label=id).
data_train_st = data_train.apply(line2inp, axis=1)
# DataLoader's samplers index the dataset by position (0..len-1), but a pandas
# Series indexes by label, which only coincides with position for a default
# RangeIndex. Hand the loader a plain list so shuffling never depends on the
# DataFrame's index.
train_dataloader = DataLoader(data_train_st.tolist(), shuffle=True, batch_size=16)

# The dev examples are consumed directly by the evaluator, so no loader is built.
data_test_st = data_test.apply(line2inp, axis=1)

In [None]:
# Peek at the converted examples and report how many each split has.
print(data_train_st, len(data_train_st), sep="\n")
display(data_test_st)
print(len(data_test_st))

# モデルの作成

In [None]:
# Model definition: a CrossEncoder over the Japanese BERT checkpoint, with one
# output per class in label2int.
model_name = "cl-tohoku/bert-base-japanese-v2"
model = CrossEncoder(model_name, num_labels=len(label2int))

# Evaluator that reports softmax/argmax classification accuracy on the
# JSNLI dev examples during training.
evaluator = CESoftmaxAccuracyEvaluator.from_input_examples(
    data_test_st, name="jsnli-dev"
)

# モデルの学習

In [None]:
# Training hyper-parameters.
num_epochs = 4
# Warm-up spans the first 10% of all training steps
# (batches per epoch * epochs), not 10% of the training data.
# Integer steps / 10 is exact, whereas steps * 0.1 can pick up float error
# (e.g. 30 * 0.1 == 3.0000000000000004), making math.ceil add a spurious step.
warmup_steps = math.ceil(len(train_dataloader) * num_epochs / 10)
model_save_path = './jsnli_model'

In [None]:
# Fine-tune the cross-encoder on the training loader with warm-up.
# The dev evaluator runs every `evaluation_steps` training steps, and
# output_path is where fit() saves the model (presumably the best checkpoint
# by evaluator score — confirm against the installed library version).
model.fit(
    train_dataloader=train_dataloader,
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    evaluator=evaluator,
    evaluation_steps=10000,
    output_path=model_save_path,
)