In [1]:
from transformers.trainer_utils import set_seed

# Fix the global random seed (transformers' set_seed seeds Python `random`,
# NumPy and PyTorch) so that e.g. collate_fn's random.choice sampling is reproducible.
set_seed(seed=42)

In [2]:
from datasets import load_dataset

# Load the training split. The dataset ships a loading (builder) script, so
# running it is opted into explicitly; the `datasets` warning notes this will be
# mandatory in its next major release.
train_dataset = load_dataset(
    "llm-book/aio-retriever",
    split="train",
    trust_remote_code=True,
)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


Downloading builder script:   0%|          | 0.00/3.58k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/2.56k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/637M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/28.6M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/22335 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [3]:
# Show the features and row count of the raw training split
# (a bare last expression uses Jupyter's display instead of print).
train_dataset

Dataset({
    features: ['qid', 'competition', 'timestamp', 'section', 'number', 'original_question', 'original_answer', 'original_additional_info', 'question', 'answers', 'passages', 'positive_passage_indices', 'negative_passage_indices'],
    num_rows: 22335
})


In [4]:
from pprint import pprint

# Look at one raw training example. compact=True packs the long passage-index
# lists onto as few lines as fit instead of printing one index per line.
pprint(train_dataset[0], compact=True)

{'answers': ['26文字'],
 'competition': 'abc ～the first～',
 'negative_passage_indices': [1,
                              2,
                              3,
                              4,
                              5,
                              6,
                              7,
                              8,
                              9,
                              10,
                              11,
                              12,
                              13,
                              14,
                              15,
                              16,
                              17,
                              18,
                              19,
                              20,
                              21,
                              22,
                              23,
                              24,
                              25,
                              26,
                              27,
                              28,


In [5]:
# BPR training needs at least one positive and at least one negative passage
# per question, so drop examples where either index list is empty.
def has_positive_and_negative(example: dict) -> bool:
    """Return True when the example has >= 1 positive and >= 1 negative passage index."""
    return not (
        len(example["positive_passage_indices"]) == 0
        or len(example["negative_passage_indices"]) == 0
    )


train_dataset = train_dataset.filter(has_positive_and_negative)

Filter:   0%|          | 0/22335 [00:00<?, ? examples/s]

In [6]:
# The original BPR implementation trains only on the positive passage most
# relevant to the question (the first entry of positive_passage_indices); we
# follow it. Negative indices are kept in full so collate_fn can sample among them.
def filter_passages(example: dict) -> dict:
    """Keep only the first positive passage index of a training example.

    The value stays a list (``[first_index]``) so the column keeps its
    list-of-int type and ``random.choice`` in ``collate_fn`` keeps working;
    assigning the bare int would make ``random.choice`` raise ``TypeError``.
    Slicing also leaves an empty list empty instead of raising ``IndexError``.

    The example dict is updated in place and also returned (for ``Dataset.map``).
    """
    example["positive_passage_indices"] = example["positive_passage_indices"][:1]
    return example


# Reduce every training example to its single most relevant positive passage.
train_dataset = train_dataset.map(function=filter_passages)

Map:   0%|          | 0/19596 [00:00<?, ? examples/s]

In [7]:
# Inspect the training set after filtering and positive-passage reduction.
train_dataset

Dataset({
    features: ['qid', 'competition', 'timestamp', 'section', 'number', 'original_question', 'original_answer', 'original_additional_info', 'question', 'answers', 'passages', 'positive_passage_indices', 'negative_passage_indices'],
    num_rows: 19596
})


In [8]:
# Apply the same preprocessing to the validation set, starting with loading it.
# As for the training split, running the dataset's loading script is opted
# into explicitly (mandatory from the next major `datasets` release).
valid_dataset = load_dataset(
    "llm-book/aio-retriever",
    split="validation",
    trust_remote_code=True,
)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [9]:
# Show the features and row count of the raw validation split.
valid_dataset

Dataset({
    features: ['qid', 'competition', 'timestamp', 'section', 'number', 'original_question', 'original_answer', 'original_additional_info', 'question', 'answers', 'passages', 'positive_passage_indices', 'negative_passage_indices'],
    num_rows: 1000
})


In [10]:
# Drop validation examples that lack either a positive or a negative passage,
# exactly as done for the training set.
valid_dataset = valid_dataset.filter(
    lambda example: all(
        len(example[column]) > 0
        for column in ("positive_passage_indices", "negative_passage_indices")
    )
)

Filter:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [11]:
# On the validation set, always evaluate the model on the same content: use only
# the first positive and the first negative passage of each question.
def filter_passages(example: dict) -> dict:
    """Keep only the first positive and first negative passage index of an example.

    Both values stay lists (``[first_index]``) so the columns keep their
    list-of-int type and ``random.choice`` in ``collate_fn`` (which then has a
    single candidate and therefore always picks it) keeps working; assigning
    bare ints would make ``random.choice`` raise ``TypeError``. Slicing also
    leaves an empty list empty instead of raising ``IndexError``.

    The example dict is updated in place and also returned (for ``Dataset.map``).
    """
    example["positive_passage_indices"] = example["positive_passage_indices"][:1]
    example["negative_passage_indices"] = example["negative_passage_indices"][:1]
    return example


# Fix each validation example to one positive and one negative passage.
valid_dataset = valid_dataset.map(function=filter_passages)

Map:   0%|          | 0/864 [00:00<?, ? examples/s]

In [12]:
# Inspect the validation set after filtering and passage reduction.
valid_dataset

Dataset({
    features: ['qid', 'competition', 'timestamp', 'section', 'number', 'original_question', 'original_answer', 'original_additional_info', 'question', 'answers', 'passages', 'positive_passage_indices', 'negative_passage_indices'],
    num_rows: 864
})


In [13]:
from transformers import AutoTokenizer


# Pretrained Japanese BERT checkpoint; its tokenizer is the one collate_fn uses.
base_model_name = "tohoku-nlp/bert-base-japanese-v3"
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=base_model_name
)

tokenizer_config.json:   0%|          | 0.00/251 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/231k [00:00<?, ?B/s]

loading file vocab.txt from cache at C:\Users\<user>\.cache\huggingface\hub\models--tohoku-nlp--bert-base-japanese-v3\snapshots\65243d6e5629b969c77309f217bd7b1a79d43c7e\vocab.txt
loading file spiece.model from cache at None
loading file added_tokens.json from cache at None
loading file special_tokens_map.json from cache at None
loading file tokenizer_config.json from cache at C:\Users\<user>\.cache\huggingface\hub\models--tohoku-nlp--bert-base-japanese-v3\snapshots\65243d6e5629b969c77309f217bd7b1a79d43c7e\tokenizer_config.json
loading file tokenizer.json from cache at None


In [14]:
import random
import torch
from torch import Tensor
from transformers import BatchEncoding

In [15]:
def _as_index_list(indices: int | list[int]) -> list[int]:
    """Return passage indices as a list, wrapping a single int index.

    Lets collate_fn accept both list-valued index columns and columns that a
    preprocessing step collapsed to a single int.
    """
    if isinstance(indices, int):
        return [indices]
    return indices


def collate_fn(
    examples: list[dict],
    *,
    question_max_length: int = 128,
    passage_max_length: int = 256,
) -> dict[str, BatchEncoding | Tensor]:
    """Build one BPR training/validation mini-batch.

    For each example one positive and one negative passage index are drawn with
    ``random.choice`` (positive first), so the batch holds ``2 * len(examples)``
    passages ordered ``[pos_0, neg_0, pos_1, neg_1, ...]``. Each index column may
    hold either a list of indices or a single int.

    :param examples: dataset rows with ``question``, ``passages`` (indexable by
        passage index, each entry having ``title`` and ``text``),
        ``positive_passage_indices`` and ``negative_passage_indices``.
    :param question_max_length: truncation length for questions.
    :param passage_max_length: truncation length for (title, text) passage pairs;
        only the text part is truncated.
    :return: dict with ``tokenized_questions`` and ``tokenized_passages``
        (PyTorch tensors produced by the module-level ``tokenizer``) and
        ``labels``, the column of each question's positive passage.
    """
    questions: list[str] = []
    passage_titles: list[str] = []
    passage_texts: list[str] = []

    for example in examples:
        questions.append(example["question"])

        # Draw one positive and one negative passage for this question
        # (positive first, so the random stream is consumed in a fixed order).
        positive = example["passages"][
            random.choice(_as_index_list(example["positive_passage_indices"]))
        ]
        negative = example["passages"][
            random.choice(_as_index_list(example["negative_passage_indices"]))
        ]

        # Keep each positive immediately before its negative.
        passage_titles.extend([positive["title"], negative["title"]])
        passage_texts.extend([positive["text"], negative["text"]])

    tokenized_questions = tokenizer(
        questions,
        padding=True,
        truncation=True,
        max_length=question_max_length,
        return_tensors="pt",
    )
    # Encode each passage as a (title, text) pair; only the text gets truncated.
    tokenized_passages = tokenizer(
        passage_titles,
        passage_texts,
        padding=True,
        truncation="only_second",
        max_length=passage_max_length,
        return_tensors="pt",
    )

    # Position of each question's positive passage in the
    # (num_questions, num_passages) score matrix: row i -> column 2 * i,
    # i.e. [0, 2, ..., 2 * (len(questions) - 1)].
    labels = torch.arange(0, 2 * len(questions), 2)

    return {
        "tokenized_questions": tokenized_questions,
        "tokenized_passages": tokenized_passages,
        "labels": labels,
    }

In [16]:
import math
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel
from transformers.utils import ModelOutput

In [17]:
class BPRModel(nn.Module):
    """Binary Passage Retriever (BPR) model.

    Holds two encoders (questions and passages), each loaded separately from
    ``base_model_name``, and trains them with the two BPR losses: a margin
    ranking loss on binary-binary scores (candidate generation) and a
    cross-entropy loss on dense-binary scores (reranking).

    :param base_model_name: model name/path passed to ``AutoModel.from_pretrained``
        for both encoders.
    :param hashnet_gamma: growth rate of the tanh scale that approximates ``sign``
        during training: ``scale = sqrt(1 + global_step * hashnet_gamma)``.
    :param ranking_margin: margin of the candidate-generation ranking loss.
    """

    def __init__(
        self,
        base_model_name: str,
        hashnet_gamma: float = 0.1,
        ranking_margin: float = 0.1,
    ):
        super().__init__()
        # Question encoder
        self.question_encoder = AutoModel.from_pretrained(base_model_name)
        # Passage encoder (a separately loaded copy, so weights are not shared)
        self.passage_encoder = AutoModel.from_pretrained(base_model_name)
        self.hashnet_gamma = hashnet_gamma
        self.ranking_margin = ranking_margin
        # Number of forward passes run in training mode; drives the tanh
        # annealing in binary_encode. A plain attribute, so not in state_dict.
        self.global_step = 0

    def binary_encode(self, x: Tensor) -> Tensor:
        """Convert real-valued embeddings into (approximately) binary ones.

        Training mode: the non-differentiable ``sign`` is approximated by
        ``tanh(scale * x)`` with ``scale = sqrt(1 + global_step * hashnet_gamma)``,
        which approaches ``sign`` as ``global_step`` grows.
        Eval mode: exact ``+1`` where ``x >= 0`` and ``-1`` elsewhere, with the
        same dtype and device as ``x``.
        """
        if self.training:
            scale = math.sqrt(1.0 + self.global_step * self.hashnet_gamma)
            return torch.tanh(x * scale)
        # Cast to x's dtype so half-precision evaluation does not mix float32
        # binary codes with half-precision dense embeddings in a matmul.
        return torch.where(x >= 0, 1.0, -1.0).to(device=x.device, dtype=x.dtype)

    def encode_questions(self, tokenized_questions: BatchEncoding) -> tuple[Tensor, Tensor]:
        """Return (real-valued, binary) question embeddings.

        Both come from the question encoder's final hidden state at the first
        token position (the [CLS] token for BERT-style tokenizers).
        """
        encoded = self.question_encoder(**tokenized_questions).last_hidden_state[:, 0]
        binary_encoded = self.binary_encode(encoded)
        return encoded, binary_encoded

    def encode_passages(self, tokenized_passages: BatchEncoding) -> Tensor:
        """Return binary passage embeddings (first-token final hidden state, binarized)."""
        encoded = self.passage_encoder(**tokenized_passages).last_hidden_state[:, 0]
        binary_encoded = self.binary_encode(encoded)
        return binary_encoded

    def compute_loss(
        self,
        encoded_questions: Tensor,
        binary_encoded_questions: Tensor,
        binary_encoded_passages: Tensor,
        labels: Tensor,
    ) -> Tensor:
        """Compute the BPR loss: candidate-generation loss + reranking loss.

        :param encoded_questions: real-valued question embeddings, one row per question.
        :param binary_encoded_questions: binary (or tanh-relaxed) question embeddings.
        :param binary_encoded_passages: binary (or tanh-relaxed) passage embeddings,
            one row per passage in the batch.
        :param labels: column index of each question's positive passage; with
            ``collate_fn`` this is ``[0, 2, 4, ..., 2 * (len(questions) - 1)]``.
        :return: ``loss_cand + loss_rerank`` (both mean-reduced).
        """
        num_passages = binary_encoded_passages.size(0)

        # Candidate generation: score every (question, passage) pair by the inner
        # product of their binary embeddings -> (num_questions, num_passages).
        binary_scores = torch.matmul(
            binary_encoded_questions, binary_encoded_passages.transpose(0, 1)
        )
        # Boolean mask of each row's positive column, e.g. for labels [0, 2, 4]:
        # [[1, 0, 0, 0, 0, ...],
        #  [0, 0, 1, 0, 0, ...],
        #  [0, 0, 0, 0, 1, ...]]
        positive_mask = F.one_hot(labels, num_classes=num_passages).bool()

        # masked_select flattens row-major, so each row yields one positive score
        # and (num_passages - 1) negative scores; repeating every positive
        # (num_passages - 1) times lines it up with its own row's negatives.
        positive_binary_scores = torch.masked_select(
            binary_scores, positive_mask
        ).repeat_interleave(num_passages - 1)
        negative_binary_scores = torch.masked_select(binary_scores, ~positive_mask)
        # target = 1: the first input (the positive score) should rank higher.
        target = torch.ones_like(positive_binary_scores)
        loss_cand = F.margin_ranking_loss(
            positive_binary_scores,
            negative_binary_scores,
            target,
            margin=self.ranking_margin,
        )

        # Reranking: score real-valued question embeddings against binary passage
        # embeddings and treat picking the positive column as classification
        # over all passages in the batch.
        dense_scores = torch.matmul(
            encoded_questions, binary_encoded_passages.transpose(0, 1)
        )
        loss_rerank = F.cross_entropy(dense_scores, labels)

        return loss_cand + loss_rerank

    def forward(
        self,
        tokenized_questions: BatchEncoding,
        tokenized_passages: BatchEncoding,
        labels: Tensor,
    ) -> ModelOutput:
        """Encode a batch and return its BPR loss as ``ModelOutput(loss=...)``.

        Each call in training mode increments ``global_step`` afterwards, which
        sharpens the tanh approximation used by ``binary_encode``.
        """
        encoded_questions, binary_encoded_questions = self.encode_questions(
            tokenized_questions
        )
        binary_encoded_passages = self.encode_passages(tokenized_passages)

        loss = self.compute_loss(
            encoded_questions,
            binary_encoded_questions,
            binary_encoded_passages,
            labels,
        )

        if self.training:
            self.global_step += 1

        return ModelOutput(loss=loss)

- 以下では、著者が訓練済みのモデルを Hugging Face Hub から読み込むため、実際の訓練は行わない。