# 第2章 ファインチューニング: 言語モデルの追加学習

In [30]:
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Pretrained Japanese causal language model used throughout this chapter.
model_name = "cyberagent/open-calm-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Plain SGD with a small learning rate, used for the single manual update step below.
optimizer = torch.optim.SGD(params=model.parameters(), lr=1e-4)

In [31]:
# Tokenize one sentence. return_tensors="pt" yields a batch of one: shape (1, seq_len).
# NOTE(review): `input` shadows the Python builtin; kept because later cells use this name.
input = tokenizer.encode(
    "私は犬が好き。",
    return_tensors="pt",
)
print(input)

tensor([[2727, 3807, 9439,  247]])


In [32]:
# Decode each token id on its own to see how the sentence was split into tokens.
# Iterating the id tensor directly replaces the range(len(...)) index loop.
a = [tokenizer.decode(token_id) for token_id in input[0]]
print(a)

['私は', '犬', 'が好き', '。']


In [33]:
# Forward pass without labels; the loss is computed by hand in the next cells.
# The logits tensor has shape (batch, seq_len, vocab_size), as printed below.
output = model(input)
logits = output.logits
print(type(output))
print(logits)
print(logits.shape)

<class 'transformers.modeling_outputs.CausalLMOutputWithPast'>
tensor([[[  7.2745, -17.2270,   8.0928,  ..., -16.3880, -17.2634, -17.3245],
         [  7.2613, -16.8069,  10.7188,  ..., -16.0717, -17.0571, -16.9093],
         [ 12.8202, -16.5409,  15.7865,  ..., -15.9594, -16.8466, -16.8622],
         [ 12.0014, -17.0628,   7.1439,  ..., -16.2650, -17.1278, -17.2398]]],
       grad_fn=<UnsafeViewBackward0>)
torch.Size([1, 4, 52096])


In [34]:
# Per-token losses for the input sentence, and their average.
loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
# Next-token targets: position i is scored against token i+1. The last position
# has no next token, so it gets ignore_index (-100) and contributes 0 loss.
from_tensor = torch.cat([input[0][1:], torch.tensor([loss_fn.ignore_index])])
print(from_tensor)
loss0 = loss_fn(output.logits[0], from_tensor)
print(loss0)
# Average over the positions that actually have a target. Counting them,
# rather than hard-coding 3, keeps this correct for sentences of any length.
num_targets = (from_tensor != loss_fn.ignore_index).sum()
print(torch.sum(loss0) / num_targets)

tensor([3807, 9439,  247, -100])
tensor([8.0847, 5.4859, 3.0050, 0.0000], grad_fn=<NllLossBackward0>)
tensor(5.5252, grad_fn=<DivBackward0>)


In [35]:
# Loss for the whole input sentence. reduction="mean" (the default) averages
# only over targets that are not ignore_index (-100), so this matches the
# per-token average computed manually above.
loss_fn = torch.nn.CrossEntropyLoss(reduction="mean")
loss1 = loss_fn(output.logits[0], from_tensor)
print(loss1)

tensor(5.5252, grad_fn=<NllLossBackward0>)


In [36]:
# Let the model compute the loss itself. With labels=input it shifts the labels
# internally, giving the same value as the manual calculation above.
output = model(input_ids=input, labels=input)
loss = output.loss
print(loss)

tensor(5.5252, grad_fn=<NllLossBackward0>)


In [37]:
# One manual parameter update using `loss` from the previous cell:
# clear stale gradients, backpropagate, then apply the SGD step.
optimizer.zero_grad()
loss.backward()
optimizer.step()

言語モデルの追加学習では、コーパスの各文（実際にはミニバッチ単位）に対して、上記の損失計算とパラメータ更新の処理を繰り返し実行する。

# 2.2 Trainerの利用

In [38]:
from transformers import Trainer, TrainingArguments
from transformers import DataCollatorForLanguageModeling

# Training hyperparameters; checkpoints and outputs go to ./output.
training_args = TrainingArguments(
    output_dir="./output",
    overwrite_output_dir=True,
    per_device_train_batch_size=5,
    num_train_epochs=10,
)

# mlm=False selects causal (next-token) language modeling rather than masked LM.
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# NOTE(review): no train_dataset is passed here (the dataset is only built in a
# later cell). Confirm it is supplied before trainer.train() is called.
trainer = Trainer(model=model, args=training_args, data_collator=collator)

## データの出典

https://sites.google.com/view/kvjcorpus/%E3%83%9B%E3%83%BC%E3%83%A0/%E6%97%A5%E6%9C%AC%E8%AA%9E/%E3%83%87%E3%83%BC%E3%82%BF%E3%83%95%E3%82%A1%E3%82%A4%E3%83%AB?authuser=0

ヘファナン・ケビン（2012）「関西弁コーパスの紹介」『総合政策研究』41号 157-164.

- ./chap2-kjs-corpus/test.txt
- ./chap2-kjs-corpus/train.txt
- ./chap2-kjs-corpus/val.txt

In [39]:
from torch.utils.data import Dataset

class MyDataset(Dataset):
    """Line-per-example text dataset for causal-LM fine-tuning.

    Each non-blank line of ``filename`` is tokenized into one example of the
    form ``{"input_ids": <1-D tensor of token ids>}``. No padding is applied
    here; the data collator pads each batch.

    Args:
        filename: Path to a text file with one sentence per line, read as UTF-8
            (assumed encoding of the corpus files — confirm against the download).
        tokenizer: Hugging Face tokenizer used to encode each line.
        max_length: Lines longer than this many tokens are truncated (default 512).
    """

    def __init__(self, filename, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.features = []
        # Explicit encoding: the corpus is Japanese text, and the platform default
        # (e.g. cp932 on Windows) could mis-decode it or raise.
        with open(filename, "r", encoding="utf-8") as f:
            for line in f.read().splitlines():
                # Skip blank lines (including the one a trailing newline used to
                # produce); they would become zero-length training examples.
                if not line.strip():
                    continue
                # truncation=True is required for max_length to take effect.
                # Per-line padding is pointless; the collator pads per batch.
                input_ids = tokenizer.encode(
                    line,
                    return_tensors="pt",
                    max_length=max_length,
                    truncation=True,
                )[0]
                self.features.append({"input_ids": input_ids})

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx]

# Training split, built from the lines of train.txt.
train_dataset = MyDataset(filename="./chap2-kjs-corpus/train.txt", tokenizer=tokenizer)

In [40]:
from transformers import DataCollatorForLanguageModeling
# Causal-LM collator (mlm=False), configured identically to the one given to the
# Trainer earlier; re-created so this section can be run on its own.
collator = DataCollatorForLanguageModeling(mlm=False, tokenizer=tokenizer)

In [41]:
from torch.utils.data import DataLoader
# Shuffled loader over the training set, used here only to inspect one collated
# batch. The seeded generator makes the shuffle order, and therefore the batch
# printed below, reproducible across runs.
dataloader = DataLoader(
    train_dataset,
    batch_size=10,
    shuffle=True,
    collate_fn=collator,
    generator=torch.Generator().manual_seed(0),
)

In [42]:
# Pull the first collated batch from the loader and show it.
dl = iter(dataloader)
a = next(dl)
print(a)
# Rows shorter than the longest sequence in the batch are padded on the right
# (the trailing 1s in the output, presumably the tokenizer's pad id).

{'input_ids': tensor([[ 3479,  5955,  4180,   245,  5020,  1332,  2443, 21892,   258,  2076,
           338,   529,   256,   245,  8267,  2831,   245,   623,  5081,   267,
           256,  3467,   247],
        [  592, 21164,   270,  2868,  7253, 12127, 34609,   247,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1],
        [18092, 12172, 30105, 28820,    32,   279, 46776,   245,   676,   247,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1],
        [  308, 25821, 42586,   247,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1],
        [51901,    32, 51901,  3972, 23848,   463, 27916,   314,   259,   247,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1],
        [ 5020, 14190,   363,   245,  488

# 2.5 保存されたモデルからの文生成