In [None]:
# BERTリポジトリのダウンロード
!test -d bert_repo || git clone https://github.com/google-research/bert bert_repo

In [None]:
# BERT学習済みチェックポイントのダウンロード
!wget https://storage.googleapis.com/bert_models/2018_11_23/multi_cased_L-12_H-768_A-12.zip -P bert_checkpoint
!unzip ./bert_checkpoint/multi_cased_L-12_H-768_A-12.zip -d bert_checkpoint

In [None]:
# ライブドアニュースコーパスのダウンロード
!wget http://www.rondhuit.com/download/ldcc-20140209.tar.gz && tar xvfz ldcc-20140209.tar.gz
!mv text news_data

In [None]:
# Build the train / dev / test TSV datasets from the news article files.
import pandas as pd
from pathlib import Path
from sklearn.model_selection import train_test_split

news_dir   = 'news_data/'
train_file = news_dir + 'train.tsv'
dev_file   = news_dir + 'dev.tsv'
test_file  = news_dir + 'test.tsv'
# Fixed seed so that Restart & Run All reproduces the same splits.
SEED = 42
# A category's position in this list is used as its integer class label.
dirs = ['dokujo-tsushin', 'it-life-hack', 'kaden-channel', 'livedoor-homme', 'movie-enter', 'peachy', 'smax', 'sports-watch', 'topic-news']

# Collect all rows first and build the frame once: DataFrame.append was removed
# in pandas 2.0 and copied the whole frame on every call.
records = []
for index, category in enumerate(dirs):
    # Sorted so the row order (and therefore the seeded split) does not depend
    # on the filesystem's listing order.
    for file in sorted(Path(news_dir + category).glob('*.txt')):
        # NOTE(review): the corpus appears to ship a LICENSE.txt in each category
        # directory; skip it so it is not treated as an article - confirm.
        if file.name == 'LICENSE.txt':
            continue
        # Assumes the corpus files are UTF-8 encoded - TODO confirm.
        with open(str(file), 'r', encoding='utf-8') as text:
            lines = text.readlines()
        # The sample text is the third line of the file (presumably the article
        # title - verify against the corpus README). Skip files that are too short.
        if len(lines) < 3:
            continue
        # Strip the trailing newline so every sample stays on one TSV line.
        records.append({'text': lines[2].strip(), 'label': index})

df = pd.DataFrame(records, columns=['text', 'label'])
assert not df.empty, 'no article files found under ' + news_dir

# 80/20 split into (train + dev) / test, then 80/20 again into train / dev,
# both stratified by label.
train_dev, test = train_test_split(df, test_size=0.2, shuffle=True, stratify=df.label, random_state=SEED)
train, dev = train_test_split(train_dev, test_size=0.2, shuffle=True, stratify=train_dev.label, random_state=SEED)
assert len(train) + len(dev) + len(test) == len(df)

train.to_csv(train_file, sep='\t', index=False, header=True)
dev.to_csv(dev_file, sep='\t', index=False, header=True)
test.to_csv(test_file, sep='\t', index=False, header=True)

In [None]:
# 再学習（ファインチューニング）
TASK = 'Jp'
BERT_BASE_DIR = 'bert_checkpoint/multi_cased_L-12_H-768_A-12'
DATASET_DIR = 'news_data'
OUTPUT_DIR = 'output_news'

!python3 ./bert_repo/run_classifier.py \
  --task_name=$TASK \
  --do_train=true \
  --do_eval=true \
  --do_lower_case=False \
  --data_dir=$DATASET_DIR \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
  --max_seq_length=128 \
  --train_batch_size=32 \
  --learning_rate=2e-5 \
  --num_train_epochs=3.0 \
  --output_dir=$OUTPUT_DIR

In [None]:
# 予測（テスト）
TASK = 'Jp'
BERT_BASE_DIR = 'bert_checkpoint/multi_cased_L-12_H-768_A-12'
GLUE_DIR = 'news_data'
OUTPUT_DIR = 'output_news'
TRAINED_CLASSIFIER = 'output_news'

!python3 ./bert_repo/run_classifier.py \
  --task_name=$TASK \
  --do_predict=true \
  --do_lower_case=False \
  --data_dir=$GLUE_DIR \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$TRAINED_CLASSIFIER \
  --max_seq_length=128 \
  --output_dir=$OUTPUT_DIR

In [None]:
# SavedModel形式でモデルをエクスポート
TASK = 'Jp'
BERT_BASE_DIR = 'bert_checkpoint/multi_cased_L-12_H-768_A-12'
GLUE_DIR = 'news_data'
OUTPUT_DIR = 'output_news'
TRAINED_CLASSIFIER = 'output_news'

!python3 ./bert_repo/run_classifier.py \
  --task_name=$TASK \
  --do_export=true \
  --do_lower_case=False \
  --data_dir=$GLUE_DIR \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$TRAINED_CLASSIFIER \
  --max_seq_length=128 \
  --output_dir=$OUTPUT_DIR

In [None]:
# エクスポートしたモデルを確認
LATEST_MODEL_DIR=!(ls ./saved_model | tail -1)
LATEST_MODEL_DIR=LATEST_MODEL_DIR[0]
!echo $LATEST_MODEL_DIR
!saved_model_cli show --all --dir saved_model/$LATEST_MODEL_DIR