## Module
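- nlp: used to download the CNN/DailyMail dataset
- logging: makes debugging easier when you want to follow the overall processing flow
- transformers: handles everything BERT-related (tokenizer, model, and trainer)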

In [1]:
import nlp
import logging
from transformers import BertTokenizer, EncoderDecoderModel, Trainer, TrainingArguments
import torch

## Model,Tokenizer

Set the log level with logging.basicConfig(level=logging.INFO).

In [3]:
logging.basicConfig(level=logging.INFO)

- Both the encoder and the decoder are initialized from the pretrained "bert-base-uncased" checkpoint
- The tokenizer is also loaded from "bert-base-uncased"

In [2]:
model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertLMHeadModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertLMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertLMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertLMHeadModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer

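- Use the CLS token as the BOS token. Why? BERT has no dedicated BOS token, and its tokenizer already places [CLS] at the start of every sequence.
- Use the SEP token as the EOS token. Why? BERT has no dedicated EOS token either, and [SEP] already closes every sequence.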

In [3]:
# CLS token will work as BOS token
tokenizer.bos_token = tokenizer.cls_token

# SEP token will work as EOS token
tokenizer.eos_token = tokenizer.sep_token
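
As a toy illustration of the wrapping the tokenizer performs, the sketch below builds a `[CLS] ... [SEP]` sequence by hand. The helper `wrap_with_special_tokens` and the two word IDs are hypothetical; 101 and 102 are assumed to be the `[CLS]` and `[SEP]` IDs of bert-base-uncased.

```python
# Toy illustration only: a real BertTokenizer does this internally.
CLS_ID, SEP_ID = 101, 102  # assumed [CLS] and [SEP] ids in bert-base-uncased


def wrap_with_special_tokens(token_ids):
    """Return [CLS] + tokens + [SEP], mirroring BERT's single-sequence format."""
    return [CLS_ID] + list(token_ids) + [SEP_ID]


wrapped = wrap_with_special_tokens([7592, 2088])  # hypothetical word ids
print(wrapped)  # [101, 7592, 2088, 102]
```

Because every target summary already starts with `[CLS]` and ends with `[SEP]`, the decoder can learn to start from the first and stop at the second without any new vocabulary entries.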

## Dataset

- Load CNN/DailyMail and split it into train and validation

In [4]:
# load train and validation data
train_dataset = nlp.load_dataset("cnn_dailymail", "3.0.0", split="train")
val_dataset = nlp.load_dataset("cnn_dailymail", "3.0.0", split="validation[:10%]")

- Inspect the data structure of train and val

In [6]:
train_dataset

Dataset(features: {'article': Value(dtype='string', id=None), 'highlights': Value(dtype='string', id=None), 'id': Value(dtype='string', id=None)}, num_rows: 287113)

In [3]:
type(train_dataset)

nlp.arrow_dataset.Dataset

In [7]:
val_dataset

Dataset(features: {'article': Value(dtype='string', id=None), 'highlights': Value(dtype='string', id=None), 'id': Value(dtype='string', id=None)}, num_rows: 1337)

- Shrink train to shorten training time

In [4]:
# load train and validation data
train_dataset = nlp.load_dataset("cnn_dailymail", "3.0.0", split="train[:1%]")
val_dataset = nlp.load_dataset("cnn_dailymail", "3.0.0", split="validation[:5%]")

In [5]:
train_dataset.features

{'article': Value(dtype='string', id=None),
 'highlights': Value(dtype='string', id=None),
 'id': Value(dtype='string', id=None)}

In [3]:
type(train_dataset)

nlp.arrow_dataset.Dataset

- Structure of train and val

In [5]:
train_dataset


Dataset(features: {'article': Value(dtype='string', id=None), 'highlights': Value(dtype='string', id=None), 'id': Value(dtype='string', id=None)}, num_rows: 2871)

In [6]:
val_dataset

Dataset(features: {'article': Value(dtype='string', id=None), 'highlights': Value(dtype='string', id=None), 'id': Value(dtype='string', id=None)}, num_rows: 668)
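
The row counts printed above follow from the percent slices in the split strings. A minimal sketch of the arithmetic, assuming the slice boundary is rounded to the nearest row: the helper `rows_for_percent` is hypothetical, and the full validation size of 13368 is an assumption that this notebook does not print.

```python
def rows_for_percent(num_rows, percent):
    """Approximate row count selected by a split slice such as "train[:1%]"."""
    return int(round(num_rows * percent / 100))


TRAIN_ROWS = 287113  # num_rows printed for the full train split
VAL_ROWS = 13368     # assumed size of the full validation split

print(rows_for_percent(TRAIN_ROWS, 1))  # 2871
print(rows_for_percent(VAL_ROWS, 10))   # 1337
print(rows_for_percent(VAL_ROWS, 5))    # 668
```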

## Evaluation metric

- Load ROUGE

In [7]:
# load rouge for validation
rouge = nlp.load_metric("rouge")

INFO:nlp.load:Checking /home/ats432/.cache/huggingface/datasets/5ecb6e4b474317b41ae1fe5d702d1af8d86d452f0b1d70f77a12f6f014ded6ac.35bc2c477aa456d2f589656477ccb0b463c21cdfb83a9de86d63de8560a96d1b.py for additional imports.
INFO:filelock:Lock 47772153456208 acquired on /home/ats432/.cache/huggingface/datasets/5ecb6e4b474317b41ae1fe5d702d1af8d86d452f0b1d70f77a12f6f014ded6ac.35bc2c477aa456d2f589656477ccb0b463c21cdfb83a9de86d63de8560a96d1b.py.lock
INFO:nlp.load:Found main folder for metric https://s3.amazonaws.com/datasets.huggingface.co/nlp/metrics/rouge/rouge.py at /home/ats432/anaconda3/envs/myenv_torch/lib/python3.7/site-packages/nlp/metrics/rouge
INFO:nlp.load:Found specific version folder for metric https://s3.amazonaws.com/datasets.huggingface.co/nlp/metrics/rouge/rouge.py at /home/ats432/anaconda3/envs/myenv_torch/lib/python3.7/site-packages/nlp/metrics/rouge/06783dbed5f6b6a5413f84d2a5f0d9dc9cb871f1aeb3787f2c90a8e3fe60b1c1
INFO:nlp.load:Found script file from https://s3.amazonaws.com

## Config

- Set the model's decoding parameters

In [5]:
# set decoding params
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.eos_token_id = tokenizer.eos_token_id
model.config.max_length = 142
model.config.min_length = 56
model.config.no_repeat_ngram_size = 3
# set these on model.config as well, so that generate() picks them up
model.config.early_stopping = True
model.config.length_penalty = 2.0
model.config.num_beams = 4
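
Of these parameters, no_repeat_ngram_size is probably the least self-explanatory: with a value of 3, no 3-gram may appear twice in the generated summary. The sketch below is a simplified illustration of that rule, not the library's implementation; `banned_next_tokens` is a hypothetical helper.

```python
def banned_next_tokens(generated, ngram_size):
    """Tokens that would complete an n-gram already present in `generated`."""
    if len(generated) < ngram_size:
        return set()
    # The last (n - 1) tokens form the prefix of the n-gram being built.
    prefix = tuple(generated[len(generated) - (ngram_size - 1):])
    banned = set()
    for i in range(len(generated) - ngram_size + 1):
        ngram = generated[i:i + ngram_size]
        if tuple(ngram[:-1]) == prefix:
            banned.add(ngram[-1])
    return banned


tokens = ["the", "cat", "sat", "on", "the", "cat"]
print(banned_next_tokens(tokens, 3))  # {'sat'}
```

Here "the cat sat" already occurred, so after a second "the cat" the token "sat" is excluded from the next step.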

## Data Mapping module

- Create inputs and outputs
- inputs is the tokenized article body, truncated to at most 512 tokens
- outputs is the tokenized summary, truncated to at most 128 tokens


In [6]:
# map data correctly
def map_to_encoder_decoder_inputs(batch):
    # Tokenizer will automatically set [BOS] <text> [EOS]
    # cut off at BERT max length 512
    inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=512)
    # force summarization <= 128
    outputs = tokenizer(batch["highlights"], padding="max_length", truncation=True, max_length=128)

    batch["input_ids"] = inputs.input_ids  # token IDs of the article
    batch["attention_mask"] = inputs.attention_mask  # 1 for real tokens, 0 for padding (encoder side)

    batch["decoder_input_ids"] = outputs.input_ids  # token IDs of the summary
    batch["labels"] = outputs.input_ids.copy()  # copy of the summary IDs, used as labels
    # mask loss for padding: -100 is the label value the cross-entropy loss ignores,
    # so padded positions do not contribute to the loss
    batch["labels"] = [
        [-100 if token == tokenizer.pad_token_id else token for token in labels] for labels in batch["labels"]
    ]

    batch["decoder_attention_mask"] = outputs.attention_mask  # 1 for real tokens, 0 for padding (decoder side)

    # sanity checks: every sequence is padded or truncated to exactly the maximum length
    assert all(len(x) == 512 for x in inputs.input_ids), "encoder inputs must have length 512"
    assert all(len(x) == 128 for x in outputs.input_ids), "decoder inputs must have length 128"

    return batch
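
The padding and label masking in the mapping function can be tried on a tiny example without a tokenizer. This is a simplified sketch: `pad_to_length` and `mask_padding` are hypothetical helpers, the token IDs are made up, and unlike the real tokenizer the truncation here does not keep the closing [SEP].

```python
PAD_ID = 0           # assumed pad token id of bert-base-uncased
IGNORE_INDEX = -100  # label value skipped by PyTorch's cross-entropy loss by default


def pad_to_length(token_ids, max_length, pad_id=PAD_ID):
    """Truncate or right-pad `token_ids` to exactly `max_length` items."""
    clipped = list(token_ids)[:max_length]
    return clipped + [pad_id] * (max_length - len(clipped))


def mask_padding(labels, pad_id=PAD_ID):
    """Replace padding positions with IGNORE_INDEX so they add nothing to the loss."""
    return [IGNORE_INDEX if tok == pad_id else tok for tok in labels]


decoder_input_ids = pad_to_length([101, 7592, 2088, 102], max_length=6)
labels = mask_padding(decoder_input_ids)
print(decoder_input_ids)  # [101, 7592, 2088, 102, 0, 0]
print(labels)             # [101, 7592, 2088, 102, -100, -100]
```

The decoder still receives the padded IDs as input, but only the four real positions count toward the loss.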


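## Metric function

- Decode the predictions and references, then pass them to ROUGE to compute ROUGE-2 (an evaluation metric, not the training loss)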

In [7]:
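def compute_metrics(pred):
    labels_ids = pred.label_ids  # token IDs of the reference summaries
    pred_ids = pred.predictions  # token IDs of the predictions (assumes generated IDs, not logits)

    # labels use -100 for padding, which is not a valid token ID; restore the pad ID before decoding
    labels_ids[labels_ids == -100] = tokenizer.pad_token_id

    # all unnecessary tokens are removed
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)  # drop special tokens from the predictions
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)  # drop special tokens from the references

    # ROUGE-2 measures bigram overlap between predictions and references;
    # .mid is the middle estimate of the aggregated score (between its low and high bounds)
    rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid

    return {
        "rouge2_precision": round(rouge_output.precision, 4),  # precision
        "rouge2_recall": round(rouge_output.recall, 4),  # recall
        "rouge2_fmeasure": round(rouge_output.fmeasure, 4),  # F-measure
    }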
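
To make the three returned numbers concrete, here is a hand-rolled ROUGE-2 for a single prediction and reference. It is a rough sketch that splits on whitespace only; the real metric applies its own tokenization and aggregates over all examples, so its scores can differ. The function names are hypothetical.

```python
from collections import Counter


def bigrams(text):
    """Count adjacent word pairs in a lowercased, whitespace-split text."""
    tokens = text.lower().split()
    return Counter(zip(tokens, tokens[1:]))


def rouge2(prediction, reference):
    """Return (precision, recall, f-measure) of bigram overlap."""
    pred, ref = bigrams(prediction), bigrams(reference)
    overlap = sum((pred & ref).values())  # clipped count of shared bigrams
    precision = overlap / max(sum(pred.values()), 1)
    recall = overlap / max(sum(ref.values()), 1)
    f = 0.0 if overlap == 0 else 2 * precision * recall / (precision + recall)
    return precision, recall, f


p, r, f = rouge2("the cat sat on the mat", "the cat lay on the mat")
print(round(p, 4), round(r, 4), round(f, 4))  # 0.6 0.6 0.6
```

Three of the five bigrams match in each direction ("the cat", "on the", "the mat"), so precision and recall are both 3/5.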

## Main

In [8]:
# set batch size here
batch_size = 16

# make train dataset ready
train_dataset = train_dataset.map(
    map_to_encoder_decoder_inputs, batched=True, batch_size=batch_size, remove_columns=["article", "highlights"],
)
train_dataset.set_format(
    type="torch", columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)

# same for validation dataset
val_dataset = val_dataset.map(
    map_to_encoder_decoder_inputs, batched=True, batch_size=batch_size, remove_columns=["article", "highlights"],
)
val_dataset.set_format(
    type="torch", columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)


In [9]:
train_dataset

Dataset(features: {'id': Value(dtype='string', id=None), 'input_ids': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), 'attention_mask': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), 'decoder_input_ids': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), 'labels': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), 'decoder_attention_mask': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None)}, num_rows: 2871)

In [17]:
# set training arguments - these params are not really tuned, feel free to change
training_args = TrainingArguments(
    output_dir="./",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # predict_from_generate=True,
    evaluate_during_training=True,
    do_train=True,
    do_eval=True,
    logging_steps=1000,
    save_steps=1000,
    # eval_steps=1000,
    overwrite_output_dir=True,
    warmup_steps=2000,
    save_total_limit=10,
)

In [18]:
# instantiate trainer
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)


INFO:transformers.training_args:PyTorch: setting up devices
INFO:transformers.trainer:You are instantiating a Trainer but W&B is not installed. To use wandb logging, run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface.


In [19]:
# start training
trainer.train()

INFO:transformers.trainer:***** Running training *****
INFO:transformers.trainer:  Num examples = 2871
INFO:transformers.trainer:  Num Epochs = 3
INFO:transformers.trainer:  Instantaneous batch size per device = 16
INFO:transformers.trainer:  Total train batch size (w. parallel, distributed & accumulation) = 16
INFO:transformers.trainer:  Gradient Accumulation steps = 1
INFO:transformers.trainer:  Total optimization steps = 540


INFO:transformers.trainer:

Training completed. Do not forget to share your model on huggingface.co/models =)

TrainOutput(global_step=540, training_loss=6.535290581208688)
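
The 540 optimization steps in the log follow directly from the dataset size and batch size, and comparing that number with the TrainingArguments is instructive. A small sketch of the arithmetic, assuming a single device and a linear warmup schedule:

```python
import math

num_examples = 2871   # "Num examples" in the training log
batch_size = 16       # per-device batch size, single device assumed
num_epochs = 3        # "Num Epochs" in the training log
warmup_steps = 2000   # from TrainingArguments
logging_steps = 1000  # from TrainingArguments

steps_per_epoch = math.ceil(num_examples / batch_size)
total_steps = steps_per_epoch * num_epochs
print(steps_per_epoch, total_steps)  # 180 540

# Fraction of the peak learning rate reached by the last step under linear warmup
warmup_fraction = min(1.0, total_steps / warmup_steps)
print(warmup_fraction)  # 0.27

# True would mean at least one logging step is reached during training
print(total_steps >= logging_steps)  # False
```

With only 1% of the training data, the run ends before warmup finishes and before the first logging or saving step at 1000, so warmup_steps, logging_steps, and save_steps likely need to be lowered for a run this short.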

In [21]:
model.save_pretrained("bert2bert")

INFO:transformers.configuration_utils:Configuration saved in bert2bert/config.json
INFO:transformers.modeling_utils:Model weights saved in bert2bert/pytorch_model.bin


In [3]:
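# NOTE: from_pretrained returns a new model and the result is not assigned here, so `model` itself is unchanged.
# To actually use the saved weights: model = EncoderDecoderModel.from_pretrained("bert2bert")
model.from_pretrained("bert2bert")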

EncoderDecoderModel(
  (encoder): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_af

In [12]:
input_ids = torch.tensor(tokenizer.encode("Harry Potter is a series of seven fantasy novels written by British author J. K. Rowling. The novels chronicle the lives of a young wizard, Harry Potter, and his friends Hermione Granger and Ron Weasley, all of whom are students at Hogwarts School of Witchcraft and Wizardry. The main story arc concerns Harry's struggle against Lord Voldemort, a dark wizard who intends to become immortal, overthrow the wizard governing body known as the Ministry of Magic and subjugate all wizards and Muggles (non-magical people).", add_special_tokens=True)).unsqueeze(0)  # Batch size 1


In [13]:
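# NOTE: this overrides the decoder start token with the pad token id (0) instead of the CLS/BOS id set in the Config section
generated = model.generate(input_ids, decoder_start_token_id=model.config.decoder.pad_token_id)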

In [14]:
generated

tensor([[   0, 1012, 1012, 1010, 1010, 1010, 1010, 1010, 1010, 1010, 1010, 1010,
         1010, 1010, 1010, 1010, 1010, 1010, 1010, 1010]])
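
To read the generated ids, they can be mapped back to tokens. The sketch below uses a tiny hand-written lookup table instead of the real tokenizer; the three entries are assumed from the bert-base-uncased vocabulary and should be verified against the actual vocabulary file.

```python
# Partial id -> token map, assumed from the bert-base-uncased vocabulary
ID_TO_TOKEN = {0: "[PAD]", 1010: ",", 1012: "."}

# The 20 ids printed in the tensor: 0, 1012, 1012, followed by seventeen 1010s
generated_ids = [0, 1012, 1012] + [1010] * 17

tokens = [ID_TO_TOKEN[i] for i in generated_ids]
print(" ".join(tokens))
```

With the real tokenizer the equivalent call would be tokenizer.decode(generated[0], skip_special_tokens=True). Under the assumed mapping the output is only punctuation, and its length of 20 tokens suggests that the max_length and min_length values from the Config section were not in effect for this call.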