In [None]:
# Download the livedoor news corpus and unpack it into news_data/.
# The archive extracts to a directory named text/, which is renamed to news_data/.
# The whole step is skipped when news_data/ already exists, so re-running the cell
# neither downloads again nor moves a fresh text/ *inside* the existing news_data/.
# `wget -nc` reuses an already-downloaded tarball; tar runs without `v` so the
# cell output is not flooded with one line per extracted file.
! test -d news_data || ( \
    wget -nc http://www.rondhuit.com/download/ldcc-20140209.tar.gz \
    && tar xzf ldcc-20140209.tar.gz \
    && mv text news_data )

In [None]:
# Build the classification dataset (train.tsv, dev.tsv, test.tsv) from the news text files.
import pandas as pd
from pathlib import Path
from sklearn.model_selection import train_test_split

news_dir   = 'news_data/'
train_file = news_dir + 'train.tsv'
dev_file   = news_dir + 'dev.tsv'
test_file  = news_dir + 'test.tsv'
# Category directories; a directory's position in this list is its integer label.
dirs = ['dokujo-tsushin', 'it-life-hack', 'kaden-channel', 'livedoor-homme', 'movie-enter', 'peachy', 'smax', 'sports-watch', 'topic-news']
SEED = 42  # fixed so the train/dev/test split is reproducible across runs

# Collect rows in a plain list and build the DataFrame once at the end:
# DataFrame.append was removed in pandas 2.0 and copied the whole frame per call.
records = []
for label, category in enumerate(dirs):
    # sorted() makes the row order independent of the filesystem's listing order,
    # which the seeded split below needs in order to be reproducible.
    for path in sorted(Path(news_dir, category).glob('*.txt')):
        # The corpus ships a LICENSE.txt next to the articles in each category
        # directory; it matches *.txt but is not an article.
        if path.name == 'LICENSE.txt':
            continue
        # Explicit encoding so the result does not depend on the platform default.
        with path.open('r', encoding='utf-8') as article:
            lines = article.readlines()
        if len(lines) < 3:
            # The sample text is taken from the third line; skip files too short to have one.
            continue
        # Drop the line terminator so the TSV field does not contain an embedded newline.
        records.append({'text': lines[2].rstrip('\r\n'), 'label': label})

df = pd.DataFrame(records, columns=['text', 'label'])
if df.empty:
    raise FileNotFoundError('no article files found under ' + news_dir)

# 80/20 split into (train+dev)/test, then 80/20 again into train/dev,
# both stratified so every split keeps the same label proportions.
train_dev, test = train_test_split(
    df, test_size=0.2, shuffle=True, stratify=df['label'], random_state=SEED)
train, dev = train_test_split(
    train_dev, test_size=0.2, shuffle=True, stratify=train_dev['label'], random_state=SEED)

# Write each split as a tab-separated file with a header row and no index column.
for split, out_file in ((train, train_file), (dev, dev_file), (test, test_file)):
    split.to_csv(out_file, sep='\t', index=False, header=True)

In [None]:
# Download the pre-trained model (Chinese Base) into albert_checkpoint/albert_base.
# `mkdir -p` guarantees the target directory exists before the move: without it,
# `mv albert_base albert_checkpoint/` would rename albert_base to albert_checkpoint
# and the albert_checkpoint/albert_base path used by the later cells would not exist.
# The step is skipped when the model directory is already in place, so the cell
# can be re-run safely; `wget -nc` reuses an already-downloaded tarball.
! test -d albert_checkpoint/albert_base || ( \
    wget -nc https://storage.googleapis.com/albert_models/albert_base_zh.tar.gz \
    && tar xzf albert_base_zh.tar.gz \
    && mkdir -p albert_checkpoint \
    && mv albert_base albert_checkpoint/ )

In [None]:
# Fine-tuning: train and evaluate the classifier on the news dataset,
# initialised from the downloaded pre-trained checkpoint.
TASK = 'JpNews'  # task name passed to run_classifier.py (presumably registered in albert_repo - verify)
BERT_BASE_DIR = 'albert_checkpoint/albert_base'  # pre-trained model: checkpoint, config and vocab
DATASET_DIR = 'news_data'    # directory given to --data_dir
OUTPUT_DIR = 'output_news'   # directory given to --output_dir

# {NAME} is IPython's brace form of variable interpolation for shell escapes.
# Flags are grouped as: task/mode, data and output, pre-trained model, hyper-parameters.
! python3 ./albert_repo/run_classifier.py \
  --task_name={TASK} \
  --do_train \
  --do_eval \
  --data_dir={DATASET_DIR} \
  --output_dir={OUTPUT_DIR} \
  --init_checkpoint={BERT_BASE_DIR}/model.ckpt-best \
  --albert_config_file={BERT_BASE_DIR}/albert_config.json \
  --vocab_file={BERT_BASE_DIR}/vocab_chinese.txt \
  --do_lower_case=False \
  --max_seq_length=128 \
  --optimizer=adamw \
  --learning_rate=3e-5 \
  --warmup_step=100 \
  --train_step=5000 \
  --train_batch_size=32

In [None]:
# Prediction (test): run the fine-tuned classifier in predict mode.
TASK = 'JpNews'  # task name passed to run_classifier.py
BERT_BASE_DIR = 'albert_checkpoint/albert_base'  # supplies the config and vocab files
GLUE_DIR = 'news_data'                # directory given to --data_dir
OUTPUT_DIR = 'output_news'            # directory given to --output_dir
TRAINED_CLASSIFIER = 'output_news'    # passed as --init_checkpoint (same directory as OUTPUT_DIR)

# {NAME} is IPython's brace form of variable interpolation for shell escapes.
# Flags are grouped as: task/mode, data and output, model, sequence length.
! python3 ./albert_repo/run_classifier.py \
  --task_name={TASK} \
  --do_predict=true \
  --data_dir={GLUE_DIR} \
  --output_dir={OUTPUT_DIR} \
  --init_checkpoint={TRAINED_CLASSIFIER} \
  --albert_config_file={BERT_BASE_DIR}/albert_config.json \
  --vocab_file={BERT_BASE_DIR}/vocab_chinese.txt \
  --do_lower_case=False \
  --max_seq_length=128

In [None]:
# Export the model in SavedModel format.
# NOTE(review): --do_export is presumably provided by the run_classifier.py in albert_repo - confirm.
TASK = 'JpNews'  # task name passed to run_classifier.py
BERT_BASE_DIR = 'albert_checkpoint/albert_base'  # supplies the config and vocab files
GLUE_DIR = 'news_data'                # directory given to --data_dir
OUTPUT_DIR = 'output_news'            # directory given to --output_dir
TRAINED_CLASSIFIER = 'output_news'    # passed as --init_checkpoint (same directory as OUTPUT_DIR)

# {NAME} is IPython's brace form of variable interpolation for shell escapes.
# Flags are grouped as: task/mode, data and output, model, sequence length.
! python3 ./albert_repo/run_classifier.py \
  --task_name={TASK} \
  --do_export=true \
  --data_dir={GLUE_DIR} \
  --output_dir={OUTPUT_DIR} \
  --init_checkpoint={TRAINED_CLASSIFIER} \
  --albert_config_file={BERT_BASE_DIR}/albert_config.json \
  --vocab_file={BERT_BASE_DIR}/vocab_chinese.txt \
  --do_lower_case=False \
  --max_seq_length=128