# rinna GPT-2 モデルの Fine Tuning
HuggingFace の transformers ライブラリを用いて [rinna gpt-2](https://huggingface.co/rinna/japanese-gpt2-medium) モデルの Fine Tuning を行います。

## 事前準備
必要なライブラリをインポートします。

In [None]:
from azureml.core import Experiment, Workspace, Environment
from azureml.core.compute import ComputeTarget
from azureml.core import ScriptRunConfig
from azureml.core.runconfig import PyTorchConfiguration

import os
# Create the src/ directory that the %%writefile cell later in this notebook
# writes train.py into; exist_ok=True makes re-running this cell harmless.
os.makedirs('src', exist_ok=True)

Azure ML Workspace へ接続します。

In [None]:
# Connect to the Azure ML workspace. No arguments are passed, so the connection
# details come from a saved workspace config file (see the azureml-core docs).
ws = Workspace.from_config()

実験 Experiment の名称

In [None]:
# Experiment (in workspace `ws`) that the fine-tuning run is submitted to below.
model_experiment = Experiment(ws, name="rinna-gpt2-exp")

分散学習の設定

In [None]:
# PyTorch job configuration with one process on one node; it is handed to
# ScriptRunConfig as distributed_job_config further down.
distr_config = PyTorchConfiguration(process_count=1, node_count=1)

環境 Environment の設定

In [None]:
# Define the run environment from the file at the relative path 'Dockerfile';
# it is used as the `environment` of the ScriptRunConfig below.
hf_ort_env = Environment.from_dockerfile(name='rinna-docker-env', dockerfile='Dockerfile')

学習コードの準備

In [None]:
%%writefile src/train.py
#!/usr/bin/env python
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import sys
from azureml.core import Run
import argparse
import mlflow
from datasets import load_dataset
from transformers import (AutoModelForCausalLM,
                          DataCollatorForLanguageModeling, T5Tokenizer,
                          TextDataset, Trainer, TrainerCallback,
                          TrainingArguments, default_data_collator)

# Japanese output support: re-wrap the stdout byte buffer with a UTF-8 encoder.
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')

# Command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument(
    '--max_steps',
    type=int,
    default=100,
)
parser.add_argument(
    '--output_dir',
    type=str,
)
parser.add_argument(
    '--model_name_or_path',
    default='rinna/japanese-gpt2-medium',
)

args = parser.parse_args()

# Azure ML setup: resolve the workspace from the current run context.
run = Run.get_context()
ws = run.experiment.workspace

# Point MLflow at the workspace's tracking URI.
_tracking_uri = ws.get_mlflow_tracking_uri()
mlflow.set_tracking_uri(_tracking_uri)

# Load the tokenizer and model from the checkpoint named by --model_name_or_path
# (default 'rinna/japanese-gpt2-medium') rather than a hard-coded name, so the
# command-line flag actually takes effect.
tokenizer = T5Tokenizer.from_pretrained(args.model_name_or_path, do_lower_case=True)
model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)
# Resize the embedding matrix to the tokenizer's vocabulary size.
model.resize_token_embeddings(len(tokenizer))

# Datasets: plain-text files addressed by relative path.
train_path = 'train.txt'
test_path = 'test.txt'

# Build the train and eval datasets with identical settings (train first).
train_dataset, eval_dataset = (
    TextDataset(tokenizer=tokenizer, file_path=text_file, block_size=512)
    for text_file in (train_path, test_path)
)

# mlm=False selects the non-masked (causal) language-modeling collation.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Callback class used to record training logs with MLflow
class MyCallback(TrainerCallback):
    """Trainer callback that forwards numeric log entries to MLflow metrics."""

    def __init__(self, azureml_run=None):
        # `azureml_run` is accepted but not used; metrics are written through
        # the module-level `mlflow` import.
        self.mlflow = mlflow

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Log every int/float value in ``logs`` at ``state.global_step``.

        Nothing is logged when ``logs`` is None or empty, or when
        ``state.is_world_process_zero`` is false.
        """
        # `logs` defaults to None, so guard before iterating over it.
        if not logs or not state.is_world_process_zero:
            return
        for key, value in logs.items():
            if isinstance(value, (int, float)):
                self.mlflow.log_metric(key, value, step=state.global_step)

# Trainer arguments
# Honor --output_dir instead of a hard-coded path; fall back to ./outputs when
# the flag is omitted (argparse then yields None).
# NOTE: the notebook later downloads files under the run's 'outputs/' folder, so
# a different --output_dir needs a matching change on the notebook side.
output_dir = args.output_dir or "./outputs"

training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    max_steps=args.max_steps,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    do_train=True,
    do_eval=True,
    evaluation_strategy="steps",
    eval_steps=50,
    fp16=True,
    report_to=["none"],
    # NOTE(review): `ort` is presumably provided by the ONNX Runtime-enabled
    # transformers build in the Docker image -- confirm against that image.
    ort=True,
)


# Trainer setup
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    callbacks=[MyCallback]
)

# Start training
trainer.train()

# Save the model (no path given, so it is written to training_args.output_dir)
trainer.save_model()

# Validation: reload the saved model from the directory it was just written to
# and print sampled continuations for a few prompts. The tokenizer loaded
# earlier in this script is reused, so it always matches the trained model.
model = AutoModelForCausalLM.from_pretrained(training_args.output_dir)

for prompt in ("仕事", "料理", "握手をしたら、"):
    # `input_ids` rather than `input`, to avoid shadowing the builtin.
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    output = model.generate(input_ids, do_sample=True, max_length=100, num_return_sequences=100)
    print(tokenizer.batch_decode(output))

スクリプトの引数の定義

In [None]:
# Command-line arguments passed to train.py, laid out as flag/value pairs.
script_params = [
    '--max_steps', 100,
    '--output_dir', './outputs',
    '--model_name_or_path', 'rinna/japanese-gpt2-medium',
]

## モデル学習
`ScriptRunConfig` を用いて Azure Machine Learning Compute Cluster 上で学習ができるように設定します。

In [None]:
# Compute target named "gpuinstance" in the workspace; the script runs there.
gpu_target = ComputeTarget(workspace=ws, name="gpuinstance")

# Bundle script, arguments, compute, environment and job configuration.
model_run_config = ScriptRunConfig(
    source_directory='./src',
    script='./train.py',
    arguments=script_params,
    compute_target=gpu_target,
    environment=hf_ort_env,
    distributed_job_config=distr_config,
)

モデル学習の開始

In [None]:
# Submit the configured run to the experiment; the bare `run` expression on the
# last line makes the notebook display the run as the cell output.
run = model_experiment.submit(model_run_config)
run

In [None]:
# Wait for the submitted run to complete, showing its output in the notebook.
run.wait_for_completion(show_output=True)

## モデルテスト
ローカル環境でモデルの推論を行います。Run の outputs フォルダのモデルファイルをダウンロード & ロードして利用します。

In [None]:
# Fetch the run by id and download its saved model files.
# - ws.get_run() already returns the run object, so download_files() is called
#   on it directly (it has no `.run` attribute).
# - train.py saves the model straight into ./outputs, so the run's files live
#   under the 'outputs/' prefix, not 'outputs/models/'. append_prefix=False
#   strips that prefix so the files land in outputs/models/, the directory the
#   loading cell reads from.
run_test = ws.get_run(run.id)
run_test.download_files(prefix='outputs/', output_directory='outputs/models/', append_prefix=False)

In [None]:
from transformers import T5Tokenizer, AutoModelForCausalLM

# train.py only calls trainer.save_model() and never saves the tokenizer, so
# the downloaded folder holds no tokenizer files. Load the tokenizer from the
# base checkpoint, exactly as train.py does for its own validation step; only
# the fine-tuned model weights come from the local folder.
tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt2-medium", do_lower_case=True)
model = AutoModelForCausalLM.from_pretrained("outputs/models/")

In [None]:
# Sample 10 continuations of a greeting prompt with the fine-tuned model.
# The encoded prompt is named `input_ids` so the builtin `input` is not
# shadowed for the rest of the notebook session.
input_ids = tokenizer.encode("こんにちは、", return_tensors="pt")
output = model.generate(input_ids, do_sample=True, max_length=100, num_return_sequences=10)
print(tokenizer.batch_decode(output))