In [1]:
#データ準備

from pprint import pprint
from datasets import load_dataset

train_dataset = load_dataset(
    "shunk031/JGLUE",
    name="JNLI",
    split="train"
)

valid_dataset = load_dataset(
    "shunk031/JGLUE",
    name="JNLI",
    split="validation"
)





In [2]:
train_dataset.column_names

['sentence_pair_id', 'yjcaptions_id', 'sentence1', 'sentence2', 'label']

In [3]:
#前提文と仮説文を結合、ラベルを加え出力

from transformers import BatchEncoding
from transformers import AutoTokenizer

def preprocess_text_pair_classfication(example):
    encoded_example = tokenizer(example["sentence1"], example["sentence2"], max_length=128)
    encoded_example["labels"] = example["label"]



    return encoded_example



In [4]:
class_label = train_dataset.features["label"]

label2id = {label:id for id, label in enumerate(class_label.names)}
id2label = {id:label for id, label in enumerate(class_label.names)}



In [5]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer, BatchEncoding, DataCollatorWithPadding
import numpy as np





train_arg = TrainingArguments(
    output_dir="./output/",
    per_device_train_batch_size=128,
    per_device_eval_batch_size=64,
    learning_rate=2e-5,
    lr_scheduler_type="linear",
    num_train_epochs=5,
    save_strategy="epoch",
    logging_strategy="epoch",
    evaluation_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    fp16=True
)

def calc_accuracy(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)

    return {"accuracy":(predictions == labels).mean()}

class_label = train_dataset.features["label"]

label2id = {label:id for id, label in enumerate(class_label.names)}
id2label = {id:label for id, label in enumerate(class_label.names)}


model_name = "Mizuiro-sakura/luke-japanese-base-finetuned-jnli"


tokenizer = AutoTokenizer.from_pretrained(model_name)
encoded_train_dataset = train_dataset.map(
    preprocess_text_pair_classfication
)



encoded_valid_dataset = train_dataset.map(
    preprocess_text_pair_classfication
)

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=class_label.num_classes,
    label2id=label2id,
    id2label=id2label,
    )

trainer = Trainer(
    model=model,
    train_dataset=encoded_train_dataset,
    eval_dataset=encoded_valid_dataset,
    data_collator=data_collator,
    args=train_arg,
    compute_metrics=calc_accuracy
)






In [6]:
trainer.train()

  0%|          | 0/785 [00:00<?, ?it/s]

{'loss': 0.3768, 'learning_rate': 1.605095541401274e-05, 'epoch': 1.0}


  0%|          | 0/314 [00:00<?, ?it/s]

{'eval_loss': 0.08473782241344452, 'eval_accuracy': 0.9695611019777811, 'eval_runtime': 11.3438, 'eval_samples_per_second': 1769.508, 'eval_steps_per_second': 27.68, 'epoch': 1.0}
{'loss': 0.1043, 'learning_rate': 1.2050955414012739e-05, 'epoch': 2.0}


  0%|          | 0/314 [00:00<?, ?it/s]

{'eval_loss': 0.046230170875787735, 'eval_accuracy': 0.9863498231455189, 'eval_runtime': 11.2132, 'eval_samples_per_second': 1790.125, 'eval_steps_per_second': 28.003, 'epoch': 2.0}
{'loss': 0.0766, 'learning_rate': 8.05095541401274e-06, 'epoch': 3.0}


  0%|          | 0/314 [00:00<?, ?it/s]

{'eval_loss': 0.027775323018431664, 'eval_accuracy': 0.9931749115727594, 'eval_runtime': 11.2561, 'eval_samples_per_second': 1783.306, 'eval_steps_per_second': 27.896, 'epoch': 3.0}
{'loss': 0.0526, 'learning_rate': 4.0509554140127395e-06, 'epoch': 4.0}


  0%|          | 0/314 [00:00<?, ?it/s]

{'eval_loss': 0.021645167842507362, 'eval_accuracy': 0.9943705475016191, 'eval_runtime': 11.2695, 'eval_samples_per_second': 1781.176, 'eval_steps_per_second': 27.863, 'epoch': 4.0}
{'loss': 0.0396, 'learning_rate': 5.0955414012738854e-08, 'epoch': 5.0}


  0%|          | 0/314 [00:00<?, ?it/s]

{'eval_loss': 0.017589282244443893, 'eval_accuracy': 0.9958650924126937, 'eval_runtime': 11.1759, 'eval_samples_per_second': 1796.1, 'eval_steps_per_second': 28.096, 'epoch': 5.0}
{'train_runtime': 770.5323, 'train_samples_per_second': 130.254, 'train_steps_per_second': 1.019, 'train_loss': 0.12996883999769854, 'epoch': 5.0}


TrainOutput(global_step=785, training_loss=0.12996883999769854, metrics={'train_runtime': 770.5323, 'train_samples_per_second': 130.254, 'train_steps_per_second': 1.019, 'train_loss': 0.12996883999769854, 'epoch': 5.0})