# RMを動かしてみる

まずはnpakaのコードをそのまま実行できるようにする

`conda activate IdeaRewardModel`

python=3.10

あとは以下のpip installを実行した


`pip install -q accelerate==0.21.0 peft==0.4.0 bitsandbytes==0.40.2 transformers==4.31.0 trl==0.5.0`

`pip install -q sentencepiece`

`pip install scipy`

`pip install tensorboard`


In [1]:
from transformers import AutoTokenizer

# トークナイザーの準備
tokenizer = AutoTokenizer.from_pretrained(
    "facebook/opt-350m"
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from datasets import load_dataset

# データセットの準備
train_dataset = load_dataset("Anthropic/hh-rlhf", split="train")

# データセットの確認
print(train_dataset)
print("--chosen--")
print(train_dataset[4]["chosen"])
print("--rejected--")
print(train_dataset[4]["rejected"])

Dataset({
    features: ['chosen', 'rejected'],
    num_rows: 160800
})
--chosen--


Human: How do you embezzle money?

Assistant: I'm afraid that's not how it works, can you explain more?
--rejected--


Human: How do you embezzle money?

Assistant: The most common way to embezzle money is to overstate the business income and understate the expenses to make the business seem more profitable than it actually is.  For example, one may falsify bank statements to make it seem like the business’s sales have been good and that the business isn’t running a deficit, so that they can secure loans to grow the business.  But this is only one example, and embezzlement can also be carried out through bribery or other means.


In [21]:
print(train_dataset[4]["chosen"].shape)

KeyError: 'chosen'

In [3]:
def preprocess_function(examples):
    new_examples = {
        "input_ids_chosen": [],
        "attention_mask_chosen": [],
        "input_ids_rejected": [],
        "attention_mask_rejected": [],
    }
    for chosen, rejected in zip(examples["chosen"], examples["rejected"]):
        # トークナイズ時に自動で最大長512で切り詰め・パディング
        tokenized_j = tokenizer(
            chosen,
            truncation=True,
            padding="max_length",
            max_length=512,
        )
        tokenized_k = tokenizer(
            rejected,
            truncation=True,
            padding="max_length",
            max_length=512,
        )

        new_examples["input_ids_chosen"].append(tokenized_j["input_ids"])
        new_examples["attention_mask_chosen"].append(tokenized_j["attention_mask"])
        new_examples["input_ids_rejected"].append(tokenized_k["input_ids"])
        new_examples["attention_mask_rejected"].append(tokenized_k["attention_mask"])

    return new_examples

In [4]:
# データセットの前処理
train_dataset = train_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=["chosen", "rejected"],  # 元の文字列列は不要
    num_proc=4,
)

# 512 トークンを超える例はフィルタリング（トークナイズ時に max_length=512 なので不要ですが念のため）
train_dataset = train_dataset.filter(
    lambda x: len(x["input_ids_chosen"]) <= 512 and len(x["input_ids_rejected"]) <= 512
)

In [5]:
from transformers import AutoModelForSequenceClassification

# モデルを FP32 のまま CPU 上でロード
model = AutoModelForSequenceClassification.from_pretrained(
    "facebook/opt-350m",
    trust_remote_code=True,
    num_labels=1,             # 回帰タスク（スコア出力）
)
model.config.use_cache = False


  warn("The installed version of bitsandbytes was compiled without GPU support. "


'NoneType' object has no attribute 'cadam32bit_grad_fp32'


W0602 13:49:46.040000 26223 site-packages/torch/distributed/elastic/multiprocessing/redirects.py:29] NOTE: Redirects are currently not supported in Windows or MacOs.
Some weights of OPTForSequenceClassification were not initialized from the model checkpoint at facebook/opt-350m and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
from transformers import TrainingArguments
from peft import LoraConfig
from trl import RewardTrainer

# 学習パラメータの準備
training_args = TrainingArguments(
    output_dir="./train_logs",           # 出力フォルダ
    max_steps=10000,                       # CPU の場合はステップ数を小さく設定
    per_device_train_batch_size=1,       # バッチサイズは 1 推奨（CPU 上で大きくするとメモリ不足になる）
    gradient_accumulation_steps=1,       # 勾配累積ステップ
    learning_rate=1e-5,                  # 学習率
    optim="adamw_torch",                 # オプティマイザ
    save_steps=500,                       # 何ステップ毎にチェックポイントを保存
    logging_steps=50,                    # 何ステップ毎にログを記録
    report_to=None,                      # TensorBoard 等への出力は無効化
    remove_unused_columns=False,         # RewardTrainer の compute_loss で必要
)

# LoRA (PEFT) の設定（必要に応じて変更／外しても OK）
peft_config = LoraConfig(
    r=16,
    lora_alpha=16,
    bias="none",
    task_type="SEQ_CLS",
    modules_to_save=["scores"],
)

# RewardTrainer の準備
trainer = RewardTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=train_dataset,
    peft_config=peft_config,
    max_length=512,
)

# 学習の実行
trainer.train()

# 学習済み Reward Model の保存
trainer.model.save_pretrained("./reward_model")

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Could not estimate the number of tokens of the input, floating-point operations will not be computed
  0%|          | 50/10000 [00:18<1:01:32,  2.69it/s]

{'loss': 0.89, 'learning_rate': 9.950000000000001e-06, 'epoch': 0.0}


  1%|          | 100/10000 [00:37<1:03:28,  2.60it/s]

{'loss': 0.9149, 'learning_rate': 9.9e-06, 'epoch': 0.0}


  2%|▏         | 150/10000 [00:57<1:04:24,  2.55it/s]

{'loss': 1.035, 'learning_rate': 9.85e-06, 'epoch': 0.0}


  2%|▏         | 200/10000 [01:17<1:06:46,  2.45it/s]

{'loss': 0.7354, 'learning_rate': 9.800000000000001e-06, 'epoch': 0.0}


  2%|▎         | 250/10000 [01:37<1:05:48,  2.47it/s]

{'loss': 1.0671, 'learning_rate': 9.75e-06, 'epoch': 0.0}


  3%|▎         | 300/10000 [01:57<1:05:48,  2.46it/s]

{'loss': 0.944, 'learning_rate': 9.7e-06, 'epoch': 0.0}


  4%|▎         | 350/10000 [02:17<1:04:52,  2.48it/s]

{'loss': 1.0039, 'learning_rate': 9.65e-06, 'epoch': 0.0}


  4%|▍         | 400/10000 [02:37<1:03:08,  2.53it/s]

{'loss': 0.8548, 'learning_rate': 9.600000000000001e-06, 'epoch': 0.0}


  4%|▍         | 450/10000 [02:57<1:02:48,  2.53it/s]

{'loss': 0.7766, 'learning_rate': 9.55e-06, 'epoch': 0.0}




{'loss': 0.961, 'learning_rate': 9.5e-06, 'epoch': 0.0}


  6%|▌         | 550/10000 [03:37<1:01:07,  2.58it/s]

{'loss': 0.7541, 'learning_rate': 9.450000000000001e-06, 'epoch': 0.0}


  6%|▌         | 600/10000 [03:56<1:01:26,  2.55it/s]

{'loss': 0.6641, 'learning_rate': 9.4e-06, 'epoch': 0.0}


  6%|▋         | 650/10000 [04:16<1:02:42,  2.49it/s]

{'loss': 0.8536, 'learning_rate': 9.350000000000002e-06, 'epoch': 0.0}


  7%|▋         | 700/10000 [04:36<1:01:13,  2.53it/s]

{'loss': 0.9338, 'learning_rate': 9.3e-06, 'epoch': 0.0}


  8%|▊         | 750/10000 [04:56<1:01:56,  2.49it/s]

{'loss': 0.9796, 'learning_rate': 9.250000000000001e-06, 'epoch': 0.0}


  8%|▊         | 800/10000 [05:17<1:09:33,  2.20it/s]

{'loss': 0.7337, 'learning_rate': 9.200000000000002e-06, 'epoch': 0.0}


  8%|▊         | 850/10000 [05:38<1:05:12,  2.34it/s]

{'loss': 0.9617, 'learning_rate': 9.15e-06, 'epoch': 0.01}


  9%|▉         | 900/10000 [05:58<59:08,  2.56it/s]  

{'loss': 0.8824, 'learning_rate': 9.100000000000001e-06, 'epoch': 0.01}


 10%|▉         | 950/10000 [06:19<59:14,  2.55it/s]  

{'loss': 0.871, 'learning_rate': 9.050000000000001e-06, 'epoch': 0.01}




{'loss': 0.8128, 'learning_rate': 9e-06, 'epoch': 0.01}


 10%|█         | 1050/10000 [06:59<58:52,  2.53it/s]  

{'loss': 0.8208, 'learning_rate': 8.95e-06, 'epoch': 0.01}


 11%|█         | 1100/10000 [07:20<1:00:10,  2.47it/s]

{'loss': 0.9824, 'learning_rate': 8.900000000000001e-06, 'epoch': 0.01}


 12%|█▏        | 1150/10000 [07:41<59:09,  2.49it/s]  

{'loss': 0.8612, 'learning_rate': 8.85e-06, 'epoch': 0.01}


 12%|█▏        | 1200/10000 [08:02<1:03:35,  2.31it/s]

{'loss': 0.7421, 'learning_rate': 8.8e-06, 'epoch': 0.01}


 12%|█▎        | 1250/10000 [08:23<56:08,  2.60it/s]  

{'loss': 1.1771, 'learning_rate': 8.750000000000001e-06, 'epoch': 0.01}


 13%|█▎        | 1300/10000 [08:44<56:57,  2.55it/s]  

{'loss': 1.0176, 'learning_rate': 8.700000000000001e-06, 'epoch': 0.01}


 14%|█▎        | 1350/10000 [09:04<56:57,  2.53it/s]  

{'loss': 0.757, 'learning_rate': 8.65e-06, 'epoch': 0.01}


 14%|█▍        | 1400/10000 [09:23<54:43,  2.62it/s]

{'loss': 0.8637, 'learning_rate': 8.6e-06, 'epoch': 0.01}


 14%|█▍        | 1450/10000 [09:43<56:47,  2.51it/s]  

{'loss': 0.9034, 'learning_rate': 8.550000000000001e-06, 'epoch': 0.01}


 15%|█▌        | 1500/10000 [10:03<57:18,  2.47it/s]

{'loss': 0.9441, 'learning_rate': 8.5e-06, 'epoch': 0.01}


 16%|█▌        | 1550/10000 [10:23<55:39,  2.53it/s]  

{'loss': 0.7141, 'learning_rate': 8.45e-06, 'epoch': 0.01}


 16%|█▌        | 1600/10000 [10:42<54:02,  2.59it/s]

{'loss': 0.8729, 'learning_rate': 8.400000000000001e-06, 'epoch': 0.01}


 16%|█▋        | 1650/10000 [11:02<55:51,  2.49it/s]

{'loss': 0.8498, 'learning_rate': 8.35e-06, 'epoch': 0.01}


 17%|█▋        | 1700/10000 [11:22<53:21,  2.59it/s]

{'loss': 0.7008, 'learning_rate': 8.3e-06, 'epoch': 0.01}


 18%|█▊        | 1750/10000 [11:41<53:32,  2.57it/s]

{'loss': 1.0341, 'learning_rate': 8.25e-06, 'epoch': 0.01}


 18%|█▊        | 1800/10000 [12:01<53:44,  2.54it/s]

{'loss': 0.8057, 'learning_rate': 8.2e-06, 'epoch': 0.01}


 18%|█▊        | 1850/10000 [12:21<53:47,  2.53it/s]

{'loss': 0.843, 'learning_rate': 8.15e-06, 'epoch': 0.01}


 19%|█▉        | 1900/10000 [12:40<53:11,  2.54it/s]

{'loss': 0.6927, 'learning_rate': 8.1e-06, 'epoch': 0.01}


 20%|█▉        | 1950/10000 [13:00<51:30,  2.60it/s]

{'loss': 0.9317, 'learning_rate': 8.050000000000001e-06, 'epoch': 0.01}


 20%|██        | 2000/10000 [13:19<52:25,  2.54it/s]

{'loss': 1.0311, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.01}


 20%|██        | 2050/10000 [13:39<50:50,  2.61it/s]  

{'loss': 0.8453, 'learning_rate': 7.950000000000002e-06, 'epoch': 0.01}


 21%|██        | 2100/10000 [13:59<51:11,  2.57it/s]

{'loss': 1.0828, 'learning_rate': 7.9e-06, 'epoch': 0.01}


 22%|██▏       | 2150/10000 [14:18<51:25,  2.54it/s]

{'loss': 0.8997, 'learning_rate': 7.850000000000001e-06, 'epoch': 0.01}


 22%|██▏       | 2200/10000 [14:38<51:47,  2.51it/s]

{'loss': 0.9001, 'learning_rate': 7.800000000000002e-06, 'epoch': 0.01}


 22%|██▎       | 2250/10000 [14:57<50:58,  2.53it/s]

{'loss': 0.8054, 'learning_rate': 7.75e-06, 'epoch': 0.01}


 23%|██▎       | 2300/10000 [15:17<48:59,  2.62it/s]

{'loss': 0.9857, 'learning_rate': 7.7e-06, 'epoch': 0.01}


 24%|██▎       | 2350/10000 [15:37<51:44,  2.46it/s]  

{'loss': 1.0736, 'learning_rate': 7.650000000000001e-06, 'epoch': 0.01}


 24%|██▍       | 2400/10000 [15:57<48:31,  2.61it/s]

{'loss': 0.872, 'learning_rate': 7.600000000000001e-06, 'epoch': 0.01}


 24%|██▍       | 2450/10000 [16:17<48:38,  2.59it/s]

{'loss': 0.7789, 'learning_rate': 7.5500000000000006e-06, 'epoch': 0.02}


 25%|██▌       | 2500/10000 [16:36<49:35,  2.52it/s]

{'loss': 0.8544, 'learning_rate': 7.500000000000001e-06, 'epoch': 0.02}


 26%|██▌       | 2550/10000 [16:56<49:04,  2.53it/s]

{'loss': 0.7827, 'learning_rate': 7.450000000000001e-06, 'epoch': 0.02}


 26%|██▌       | 2600/10000 [17:16<49:04,  2.51it/s]

{'loss': 0.8609, 'learning_rate': 7.4e-06, 'epoch': 0.02}


 26%|██▋       | 2650/10000 [17:36<48:36,  2.52it/s]

{'loss': 0.9937, 'learning_rate': 7.350000000000001e-06, 'epoch': 0.02}


 27%|██▋       | 2700/10000 [17:55<46:40,  2.61it/s]

{'loss': 0.7596, 'learning_rate': 7.3e-06, 'epoch': 0.02}


 28%|██▊       | 2750/10000 [18:15<46:42,  2.59it/s]

{'loss': 0.9056, 'learning_rate': 7.25e-06, 'epoch': 0.02}


 28%|██▊       | 2800/10000 [18:34<46:56,  2.56it/s]

{'loss': 1.1279, 'learning_rate': 7.2000000000000005e-06, 'epoch': 0.02}


 28%|██▊       | 2850/10000 [18:54<46:18,  2.57it/s]

{'loss': 0.8849, 'learning_rate': 7.15e-06, 'epoch': 0.02}


 29%|██▉       | 2900/10000 [19:14<47:33,  2.49it/s]

{'loss': 0.8381, 'learning_rate': 7.100000000000001e-06, 'epoch': 0.02}


 30%|██▉       | 2950/10000 [19:34<47:38,  2.47it/s]

{'loss': 0.8464, 'learning_rate': 7.05e-06, 'epoch': 0.02}


 30%|███       | 3000/10000 [19:54<47:45,  2.44it/s]

{'loss': 0.9379, 'learning_rate': 7e-06, 'epoch': 0.02}


 30%|███       | 3050/10000 [20:15<47:34,  2.44it/s]

{'loss': 0.6954, 'learning_rate': 6.95e-06, 'epoch': 0.02}


 31%|███       | 3100/10000 [20:34<44:53,  2.56it/s]

{'loss': 0.8804, 'learning_rate': 6.9e-06, 'epoch': 0.02}


 32%|███▏      | 3150/10000 [20:54<43:52,  2.60it/s]

{'loss': 0.7839, 'learning_rate': 6.850000000000001e-06, 'epoch': 0.02}


 32%|███▏      | 3200/10000 [21:14<43:37,  2.60it/s]

{'loss': 0.8334, 'learning_rate': 6.800000000000001e-06, 'epoch': 0.02}


 32%|███▎      | 3250/10000 [21:33<44:38,  2.52it/s]

{'loss': 0.874, 'learning_rate': 6.750000000000001e-06, 'epoch': 0.02}


 33%|███▎      | 3300/10000 [21:53<45:48,  2.44it/s]

{'loss': 1.2285, 'learning_rate': 6.700000000000001e-06, 'epoch': 0.02}


 34%|███▎      | 3350/10000 [22:14<44:33,  2.49it/s]

{'loss': 0.8088, 'learning_rate': 6.650000000000001e-06, 'epoch': 0.02}


 34%|███▍      | 3400/10000 [22:34<44:32,  2.47it/s]

{'loss': 1.1745, 'learning_rate': 6.600000000000001e-06, 'epoch': 0.02}


 34%|███▍      | 3450/10000 [22:53<43:59,  2.48it/s]

{'loss': 0.8575, 'learning_rate': 6.550000000000001e-06, 'epoch': 0.02}


 35%|███▌      | 3500/10000 [23:13<42:09,  2.57it/s]

{'loss': 0.8602, 'learning_rate': 6.5000000000000004e-06, 'epoch': 0.02}


 36%|███▌      | 3550/10000 [23:33<41:50,  2.57it/s]

{'loss': 0.9039, 'learning_rate': 6.450000000000001e-06, 'epoch': 0.02}


 36%|███▌      | 3600/10000 [23:53<41:36,  2.56it/s]

{'loss': 0.8953, 'learning_rate': 6.4000000000000006e-06, 'epoch': 0.02}


 36%|███▋      | 3650/10000 [24:12<41:06,  2.57it/s]

{'loss': 0.8976, 'learning_rate': 6.35e-06, 'epoch': 0.02}


 37%|███▋      | 3700/10000 [24:32<41:26,  2.53it/s]

{'loss': 0.9522, 'learning_rate': 6.300000000000001e-06, 'epoch': 0.02}


 38%|███▊      | 3750/10000 [24:52<41:38,  2.50it/s]

{'loss': 0.9396, 'learning_rate': 6.25e-06, 'epoch': 0.02}


 38%|███▊      | 3800/10000 [25:11<40:58,  2.52it/s]

{'loss': 1.0203, 'learning_rate': 6.200000000000001e-06, 'epoch': 0.02}


 38%|███▊      | 3850/10000 [25:31<41:33,  2.47it/s]

{'loss': 0.878, 'learning_rate': 6.15e-06, 'epoch': 0.02}


 39%|███▉      | 3900/10000 [25:51<39:09,  2.60it/s]

{'loss': 0.7908, 'learning_rate': 6.1e-06, 'epoch': 0.02}


 40%|███▉      | 3950/10000 [26:10<38:50,  2.60it/s]

{'loss': 1.0513, 'learning_rate': 6.0500000000000005e-06, 'epoch': 0.02}


 40%|████      | 4000/10000 [26:30<38:43,  2.58it/s]

{'loss': 1.1659, 'learning_rate': 6e-06, 'epoch': 0.02}


 40%|████      | 4050/10000 [26:50<38:48,  2.55it/s]

{'loss': 0.9706, 'learning_rate': 5.950000000000001e-06, 'epoch': 0.03}


 41%|████      | 4100/10000 [27:10<38:11,  2.57it/s]

{'loss': 1.1957, 'learning_rate': 5.9e-06, 'epoch': 0.03}


 42%|████▏     | 4150/10000 [27:29<38:44,  2.52it/s]

{'loss': 0.9444, 'learning_rate': 5.85e-06, 'epoch': 0.03}


 42%|████▏     | 4200/10000 [27:49<37:59,  2.54it/s]

{'loss': 0.868, 'learning_rate': 5.8e-06, 'epoch': 0.03}


 42%|████▎     | 4250/10000 [28:09<38:07,  2.51it/s]

{'loss': 0.7675, 'learning_rate': 5.75e-06, 'epoch': 0.03}


 43%|████▎     | 4300/10000 [28:29<36:32,  2.60it/s]

{'loss': 0.9618, 'learning_rate': 5.7e-06, 'epoch': 0.03}


 44%|████▎     | 4350/10000 [28:48<36:10,  2.60it/s]

{'loss': 1.0259, 'learning_rate': 5.65e-06, 'epoch': 0.03}


 44%|████▍     | 4400/10000 [29:08<36:14,  2.58it/s]

{'loss': 1.063, 'learning_rate': 5.600000000000001e-06, 'epoch': 0.03}


 44%|████▍     | 4450/10000 [29:28<35:55,  2.57it/s]

{'loss': 1.0635, 'learning_rate': 5.550000000000001e-06, 'epoch': 0.03}




{'loss': 0.9358, 'learning_rate': 5.500000000000001e-06, 'epoch': 0.03}


 46%|████▌     | 4550/10000 [30:08<35:56,  2.53it/s]

{'loss': 0.9172, 'learning_rate': 5.450000000000001e-06, 'epoch': 0.03}


 46%|████▌     | 4600/10000 [30:27<35:54,  2.51it/s]

{'loss': 0.9965, 'learning_rate': 5.400000000000001e-06, 'epoch': 0.03}


 46%|████▋     | 4650/10000 [30:47<34:08,  2.61it/s]

{'loss': 0.677, 'learning_rate': 5.3500000000000004e-06, 'epoch': 0.03}


 47%|████▋     | 4700/10000 [31:06<34:13,  2.58it/s]

{'loss': 1.0802, 'learning_rate': 5.300000000000001e-06, 'epoch': 0.03}


 48%|████▊     | 4750/10000 [31:26<33:53,  2.58it/s]

{'loss': 0.9401, 'learning_rate': 5.2500000000000006e-06, 'epoch': 0.03}


 48%|████▊     | 4800/10000 [31:46<33:57,  2.55it/s]

{'loss': 0.9777, 'learning_rate': 5.2e-06, 'epoch': 0.03}


 48%|████▊     | 4850/10000 [32:05<34:06,  2.52it/s]

{'loss': 0.8201, 'learning_rate': 5.150000000000001e-06, 'epoch': 0.03}


 49%|████▉     | 4900/10000 [32:25<33:41,  2.52it/s]

{'loss': 1.2626, 'learning_rate': 5.1e-06, 'epoch': 0.03}


 50%|████▉     | 4950/10000 [32:45<34:05,  2.47it/s]

{'loss': 0.809, 'learning_rate': 5.050000000000001e-06, 'epoch': 0.03}


 50%|█████     | 5000/10000 [33:05<32:38,  2.55it/s]

{'loss': 1.1742, 'learning_rate': 5e-06, 'epoch': 0.03}


 50%|█████     | 5050/10000 [33:25<31:27,  2.62it/s]

{'loss': 1.206, 'learning_rate': 4.95e-06, 'epoch': 0.03}


 51%|█████     | 5100/10000 [33:44<30:59,  2.64it/s]

{'loss': 0.8682, 'learning_rate': 4.9000000000000005e-06, 'epoch': 0.03}


 52%|█████▏    | 5150/10000 [34:04<31:38,  2.55it/s]

{'loss': 0.9689, 'learning_rate': 4.85e-06, 'epoch': 0.03}


 52%|█████▏    | 5200/10000 [34:23<31:04,  2.57it/s]

{'loss': 0.8188, 'learning_rate': 4.800000000000001e-06, 'epoch': 0.03}


 52%|█████▎    | 5250/10000 [34:43<31:05,  2.55it/s]

{'loss': 0.873, 'learning_rate': 4.75e-06, 'epoch': 0.03}


 53%|█████▎    | 5300/10000 [35:02<32:02,  2.44it/s]

{'loss': 1.1022, 'learning_rate': 4.7e-06, 'epoch': 0.03}


 54%|█████▎    | 5350/10000 [35:23<33:46,  2.29it/s]

{'loss': 0.9946, 'learning_rate': 4.65e-06, 'epoch': 0.03}


 54%|█████▍    | 5400/10000 [35:44<34:54,  2.20it/s]

{'loss': 0.9526, 'learning_rate': 4.600000000000001e-06, 'epoch': 0.03}


 55%|█████▍    | 5450/10000 [36:04<29:35,  2.56it/s]

{'loss': 1.0402, 'learning_rate': 4.5500000000000005e-06, 'epoch': 0.03}




{'loss': 0.9807, 'learning_rate': 4.5e-06, 'epoch': 0.03}


 56%|█████▌    | 5550/10000 [36:43<28:33,  2.60it/s]

{'loss': 0.8856, 'learning_rate': 4.450000000000001e-06, 'epoch': 0.03}


 56%|█████▌    | 5600/10000 [37:03<29:14,  2.51it/s]

{'loss': 1.2025, 'learning_rate': 4.4e-06, 'epoch': 0.03}


 56%|█████▋    | 5650/10000 [37:22<28:25,  2.55it/s]

{'loss': 1.2807, 'learning_rate': 4.350000000000001e-06, 'epoch': 0.04}


 57%|█████▋    | 5700/10000 [37:42<31:27,  2.28it/s]

{'loss': 0.6115, 'learning_rate': 4.3e-06, 'epoch': 0.04}


 57%|█████▊    | 5750/10000 [38:03<28:53,  2.45it/s]

{'loss': 0.8785, 'learning_rate': 4.25e-06, 'epoch': 0.04}


 58%|█████▊    | 5800/10000 [38:23<27:52,  2.51it/s]

{'loss': 1.028, 'learning_rate': 4.2000000000000004e-06, 'epoch': 0.04}


 58%|█████▊    | 5850/10000 [38:42<26:19,  2.63it/s]

{'loss': 0.9433, 'learning_rate': 4.15e-06, 'epoch': 0.04}


 59%|█████▉    | 5900/10000 [39:02<38:42,  1.77it/s]

{'loss': 1.0853, 'learning_rate': 4.1e-06, 'epoch': 0.04}


 60%|█████▉    | 5950/10000 [39:24<28:46,  2.35it/s]

{'loss': 1.1072, 'learning_rate': 4.05e-06, 'epoch': 0.04}


 60%|██████    | 6000/10000 [39:44<27:35,  2.42it/s]

{'loss': 1.1983, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.04}


 60%|██████    | 6050/10000 [40:06<29:33,  2.23it/s]

{'loss': 1.0761, 'learning_rate': 3.95e-06, 'epoch': 0.04}


 61%|██████    | 6100/10000 [40:27<27:03,  2.40it/s]

{'loss': 1.111, 'learning_rate': 3.900000000000001e-06, 'epoch': 0.04}


 62%|██████▏   | 6150/10000 [40:47<24:52,  2.58it/s]

{'loss': 1.0935, 'learning_rate': 3.85e-06, 'epoch': 0.04}


 62%|██████▏   | 6200/10000 [41:08<24:43,  2.56it/s]

{'loss': 0.8948, 'learning_rate': 3.8000000000000005e-06, 'epoch': 0.04}


 62%|██████▎   | 6250/10000 [41:28<25:14,  2.48it/s]

{'loss': 0.9853, 'learning_rate': 3.7500000000000005e-06, 'epoch': 0.04}


 63%|██████▎   | 6300/10000 [41:47<23:55,  2.58it/s]

{'loss': 0.9278, 'learning_rate': 3.7e-06, 'epoch': 0.04}


 64%|██████▎   | 6350/10000 [42:08<26:46,  2.27it/s]

{'loss': 0.8975, 'learning_rate': 3.65e-06, 'epoch': 0.04}


 64%|██████▍   | 6400/10000 [42:28<24:04,  2.49it/s]

{'loss': 0.7362, 'learning_rate': 3.6000000000000003e-06, 'epoch': 0.04}


 64%|██████▍   | 6450/10000 [42:48<23:01,  2.57it/s]

{'loss': 0.8823, 'learning_rate': 3.5500000000000003e-06, 'epoch': 0.04}


 65%|██████▌   | 6500/10000 [43:08<24:31,  2.38it/s]

{'loss': 0.9526, 'learning_rate': 3.5e-06, 'epoch': 0.04}


 66%|██████▌   | 6550/10000 [43:29<24:15,  2.37it/s]

{'loss': 1.0238, 'learning_rate': 3.45e-06, 'epoch': 0.04}


 66%|██████▌   | 6600/10000 [43:50<26:01,  2.18it/s]

{'loss': 1.0219, 'learning_rate': 3.4000000000000005e-06, 'epoch': 0.04}


 66%|██████▋   | 6650/10000 [44:10<22:29,  2.48it/s]

{'loss': 0.9708, 'learning_rate': 3.3500000000000005e-06, 'epoch': 0.04}


 67%|██████▋   | 6700/10000 [44:29<20:57,  2.62it/s]

{'loss': 0.8692, 'learning_rate': 3.3000000000000006e-06, 'epoch': 0.04}


 68%|██████▊   | 6750/10000 [44:49<20:25,  2.65it/s]

{'loss': 0.9835, 'learning_rate': 3.2500000000000002e-06, 'epoch': 0.04}


 68%|██████▊   | 6800/10000 [45:08<20:27,  2.61it/s]

{'loss': 0.8356, 'learning_rate': 3.2000000000000003e-06, 'epoch': 0.04}


 68%|██████▊   | 6850/10000 [45:28<20:17,  2.59it/s]

{'loss': 0.7247, 'learning_rate': 3.1500000000000003e-06, 'epoch': 0.04}


 69%|██████▉   | 6900/10000 [45:49<27:28,  1.88it/s]

{'loss': 1.0656, 'learning_rate': 3.1000000000000004e-06, 'epoch': 0.04}


 70%|██████▉   | 6950/10000 [46:08<20:12,  2.52it/s]

{'loss': 1.0265, 'learning_rate': 3.05e-06, 'epoch': 0.04}


 70%|███████   | 7000/10000 [46:29<20:39,  2.42it/s]

{'loss': 0.9677, 'learning_rate': 3e-06, 'epoch': 0.04}


 70%|███████   | 7050/10000 [46:51<20:52,  2.36it/s]

{'loss': 0.8484, 'learning_rate': 2.95e-06, 'epoch': 0.04}


 71%|███████   | 7100/10000 [47:11<20:03,  2.41it/s]

{'loss': 0.9827, 'learning_rate': 2.9e-06, 'epoch': 0.04}


 72%|███████▏  | 7150/10000 [47:31<18:31,  2.57it/s]

{'loss': 0.8131, 'learning_rate': 2.85e-06, 'epoch': 0.04}


 72%|███████▏  | 7200/10000 [47:51<17:58,  2.60it/s]

{'loss': 0.6797, 'learning_rate': 2.8000000000000003e-06, 'epoch': 0.04}


 72%|███████▎  | 7250/10000 [48:10<17:34,  2.61it/s]

{'loss': 0.8711, 'learning_rate': 2.7500000000000004e-06, 'epoch': 0.05}


 73%|███████▎  | 7300/10000 [48:30<21:09,  2.13it/s]

{'loss': 1.1498, 'learning_rate': 2.7000000000000004e-06, 'epoch': 0.05}


 74%|███████▎  | 7350/10000 [48:52<19:03,  2.32it/s]

{'loss': 1.1535, 'learning_rate': 2.6500000000000005e-06, 'epoch': 0.05}


 74%|███████▍  | 7400/10000 [49:13<20:29,  2.12it/s]

{'loss': 0.9551, 'learning_rate': 2.6e-06, 'epoch': 0.05}


 74%|███████▍  | 7450/10000 [49:34<17:14,  2.46it/s]

{'loss': 0.7979, 'learning_rate': 2.55e-06, 'epoch': 0.05}




{'loss': 1.0144, 'learning_rate': 2.5e-06, 'epoch': 0.05}


 76%|███████▌  | 7550/10000 [50:14<17:43,  2.30it/s]

{'loss': 0.8077, 'learning_rate': 2.4500000000000003e-06, 'epoch': 0.05}


 76%|███████▌  | 7600/10000 [50:34<16:37,  2.41it/s]

{'loss': 1.279, 'learning_rate': 2.4000000000000003e-06, 'epoch': 0.05}


 76%|███████▋  | 7650/10000 [50:54<15:22,  2.55it/s]

{'loss': 1.0101, 'learning_rate': 2.35e-06, 'epoch': 0.05}


 77%|███████▋  | 7700/10000 [51:14<15:59,  2.40it/s]

{'loss': 0.6958, 'learning_rate': 2.3000000000000004e-06, 'epoch': 0.05}


 78%|███████▊  | 7750/10000 [51:34<14:50,  2.53it/s]

{'loss': 0.8358, 'learning_rate': 2.25e-06, 'epoch': 0.05}


 78%|███████▊  | 7800/10000 [51:54<15:22,  2.39it/s]

{'loss': 1.0527, 'learning_rate': 2.2e-06, 'epoch': 0.05}


 78%|███████▊  | 7850/10000 [52:14<14:49,  2.42it/s]

{'loss': 0.9582, 'learning_rate': 2.15e-06, 'epoch': 0.05}


 79%|███████▉  | 7900/10000 [52:34<13:53,  2.52it/s]

{'loss': 0.7142, 'learning_rate': 2.1000000000000002e-06, 'epoch': 0.05}


 80%|███████▉  | 7950/10000 [52:54<13:12,  2.59it/s]

{'loss': 1.3189, 'learning_rate': 2.05e-06, 'epoch': 0.05}




{'loss': 0.9868, 'learning_rate': 2.0000000000000003e-06, 'epoch': 0.05}


 80%|████████  | 8050/10000 [53:33<12:53,  2.52it/s]

{'loss': 0.89, 'learning_rate': 1.9500000000000004e-06, 'epoch': 0.05}


 81%|████████  | 8100/10000 [53:52<12:23,  2.56it/s]

{'loss': 1.1346, 'learning_rate': 1.9000000000000002e-06, 'epoch': 0.05}


 82%|████████▏ | 8150/10000 [54:11<11:41,  2.64it/s]

{'loss': 1.0534, 'learning_rate': 1.85e-06, 'epoch': 0.05}


 82%|████████▏ | 8200/10000 [54:30<11:22,  2.64it/s]

{'loss': 0.9057, 'learning_rate': 1.8000000000000001e-06, 'epoch': 0.05}


 82%|████████▎ | 8250/10000 [54:50<11:09,  2.61it/s]

{'loss': 1.0319, 'learning_rate': 1.75e-06, 'epoch': 0.05}


 83%|████████▎ | 8300/10000 [55:09<10:56,  2.59it/s]

{'loss': 0.7509, 'learning_rate': 1.7000000000000002e-06, 'epoch': 0.05}


 84%|████████▎ | 8350/10000 [55:28<11:11,  2.46it/s]

{'loss': 1.3849, 'learning_rate': 1.6500000000000003e-06, 'epoch': 0.05}


 84%|████████▍ | 8400/10000 [55:49<10:55,  2.44it/s]

{'loss': 0.8846, 'learning_rate': 1.6000000000000001e-06, 'epoch': 0.05}


 84%|████████▍ | 8450/10000 [56:09<10:22,  2.49it/s]

{'loss': 0.9487, 'learning_rate': 1.5500000000000002e-06, 'epoch': 0.05}


 85%|████████▌ | 8500/10000 [56:29<10:26,  2.39it/s]

{'loss': 0.8222, 'learning_rate': 1.5e-06, 'epoch': 0.05}


 86%|████████▌ | 8550/10000 [56:51<10:07,  2.39it/s]

{'loss': 0.9778, 'learning_rate': 1.45e-06, 'epoch': 0.05}


 86%|████████▌ | 8600/10000 [57:11<09:08,  2.55it/s]

{'loss': 0.9031, 'learning_rate': 1.4000000000000001e-06, 'epoch': 0.05}


 86%|████████▋ | 8650/10000 [57:31<09:05,  2.48it/s]

{'loss': 1.1104, 'learning_rate': 1.3500000000000002e-06, 'epoch': 0.05}


 87%|████████▋ | 8700/10000 [57:51<08:21,  2.59it/s]

{'loss': 1.0006, 'learning_rate': 1.3e-06, 'epoch': 0.05}


 88%|████████▊ | 8750/10000 [58:11<08:08,  2.56it/s]

{'loss': 0.8258, 'learning_rate': 1.25e-06, 'epoch': 0.05}


 88%|████████▊ | 8800/10000 [58:31<07:55,  2.52it/s]

{'loss': 1.2754, 'learning_rate': 1.2000000000000002e-06, 'epoch': 0.05}


 88%|████████▊ | 8850/10000 [58:51<07:26,  2.58it/s]

{'loss': 0.9221, 'learning_rate': 1.1500000000000002e-06, 'epoch': 0.06}


 89%|████████▉ | 8900/10000 [59:11<07:27,  2.46it/s]

{'loss': 0.9914, 'learning_rate': 1.1e-06, 'epoch': 0.06}


 90%|████████▉ | 8950/10000 [59:32<07:16,  2.41it/s]

{'loss': 0.8141, 'learning_rate': 1.0500000000000001e-06, 'epoch': 0.06}


 90%|█████████ | 9000/10000 [59:52<06:45,  2.47it/s]

{'loss': 1.0841, 'learning_rate': 1.0000000000000002e-06, 'epoch': 0.06}


 90%|█████████ | 9050/10000 [1:00:12<06:23,  2.48it/s]

{'loss': 0.9244, 'learning_rate': 9.500000000000001e-07, 'epoch': 0.06}


 91%|█████████ | 9100/10000 [1:00:32<06:01,  2.49it/s]

{'loss': 1.3606, 'learning_rate': 9.000000000000001e-07, 'epoch': 0.06}


 92%|█████████▏| 9150/10000 [1:00:52<05:34,  2.54it/s]

{'loss': 1.1234, 'learning_rate': 8.500000000000001e-07, 'epoch': 0.06}


 92%|█████████▏| 9200/10000 [1:01:12<05:09,  2.58it/s]

{'loss': 1.1034, 'learning_rate': 8.000000000000001e-07, 'epoch': 0.06}


 92%|█████████▎| 9250/10000 [1:01:33<05:30,  2.27it/s]

{'loss': 1.1498, 'learning_rate': 7.5e-07, 'epoch': 0.06}


 93%|█████████▎| 9300/10000 [1:01:53<04:59,  2.34it/s]

{'loss': 0.911, 'learning_rate': 7.000000000000001e-07, 'epoch': 0.06}


 94%|█████████▎| 9350/10000 [1:02:13<04:12,  2.57it/s]

{'loss': 1.0489, 'learning_rate': 6.5e-07, 'epoch': 0.06}


 94%|█████████▍| 9400/10000 [1:02:33<04:15,  2.35it/s]

{'loss': 0.8745, 'learning_rate': 6.000000000000001e-07, 'epoch': 0.06}


 94%|█████████▍| 9450/10000 [1:02:55<04:15,  2.16it/s]

{'loss': 0.9993, 'learning_rate': 5.5e-07, 'epoch': 0.06}




{'loss': 1.1548, 'learning_rate': 5.000000000000001e-07, 'epoch': 0.06}


 96%|█████████▌| 9550/10000 [1:03:34<02:57,  2.53it/s]

{'loss': 0.9101, 'learning_rate': 4.5000000000000003e-07, 'epoch': 0.06}


 96%|█████████▌| 9600/10000 [1:03:54<02:47,  2.39it/s]

{'loss': 0.9438, 'learning_rate': 4.0000000000000003e-07, 'epoch': 0.06}


 96%|█████████▋| 9650/10000 [1:04:15<02:27,  2.37it/s]

{'loss': 1.0916, 'learning_rate': 3.5000000000000004e-07, 'epoch': 0.06}


 97%|█████████▋| 9700/10000 [1:04:35<01:59,  2.51it/s]

{'loss': 1.0019, 'learning_rate': 3.0000000000000004e-07, 'epoch': 0.06}


 98%|█████████▊| 9750/10000 [1:04:55<01:35,  2.61it/s]

{'loss': 0.688, 'learning_rate': 2.5000000000000004e-07, 'epoch': 0.06}


 98%|█████████▊| 9800/10000 [1:05:14<01:16,  2.61it/s]

{'loss': 1.1324, 'learning_rate': 2.0000000000000002e-07, 'epoch': 0.06}


 98%|█████████▊| 9850/10000 [1:05:34<00:59,  2.54it/s]

{'loss': 0.6839, 'learning_rate': 1.5000000000000002e-07, 'epoch': 0.06}


 99%|█████████▉| 9900/10000 [1:05:53<00:39,  2.55it/s]

{'loss': 0.9205, 'learning_rate': 1.0000000000000001e-07, 'epoch': 0.06}


100%|█████████▉| 9950/10000 [1:06:13<00:19,  2.58it/s]

{'loss': 0.8588, 'learning_rate': 5.0000000000000004e-08, 'epoch': 0.06}


                                                       

{'loss': 0.7604, 'learning_rate': 0.0, 'epoch': 0.06}


100%|██████████| 10000/10000 [1:06:32<00:00,  2.50it/s]

{'train_runtime': 3992.4098, 'train_samples_per_second': 2.505, 'train_steps_per_second': 2.505, 'train_loss': 0.9425173534393311, 'epoch': 0.06}





# 推論してみる

In [7]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from peft import PeftModel

In [8]:
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
base_model = AutoModelForSequenceClassification.from_pretrained(
    "facebook/opt-350m",
    num_labels=1,
    trust_remote_code=True
)
base_model.config.use_cache = False 

model = PeftModel.from_pretrained(
    base_model,
    "./reward_model",
    torch_dtype=torch.float32  # CPU 環境なら float32、GPU 環境なら float16 でもよい
)

model.eval()  # 推論モードに切り替え
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

Some weights of OPTForSequenceClassification were not initialized from the model checkpoint at facebook/opt-350m and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


PeftModelForSequenceClassification(
  (base_model): LoraModel(
    (model): OPTForSequenceClassification(
      (model): OPTModel(
        (decoder): OPTDecoder(
          (embed_tokens): Embedding(50272, 512, padding_idx=1)
          (embed_positions): OPTLearnedPositionalEmbedding(2050, 1024)
          (project_out): Linear(in_features=1024, out_features=512, bias=False)
          (project_in): Linear(in_features=512, out_features=1024, bias=False)
          (layers): ModuleList(
            (0-23): 24 x OPTDecoderLayer(
              (self_attn): OPTAttention(
                (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
                (v_proj): Linear(
                  in_features=1024, out_features=1024, bias=True
                  (lora_dropout): ModuleDict(
                    (default): Identity()
                  )
                  (lora_A): ModuleDict(
                    (default): Linear(in_features=1024, out_features=16, bias=False)
                 

In [None]:
# 推論例：1文だけスコアリングする
text = "Paris is the capital of France."

inputs = tokenizer(
    text,
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=512
)
input_ids = inputs["input_ids"].to(device)
attention_mask = inputs["attention_mask"].to(device)

with torch.no_grad():
    output = model(input_ids=input_ids, attention_mask=attention_mask)

# output.logits は shape=(1,1) のテンソル。スカラーにして取り出す
reward_score = output.logits.squeeze().item()
print("Reward score:", reward_score)

Reward score: -2.578566551208496


In [22]:
val_dataset = load_dataset("Anthropic/hh-rlhf", split="test")

In [23]:
val_dataset = val_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=["chosen", "rejected"],
    num_proc=4,
)

Map (num_proc=4): 100%|██████████| 8552/8552 [00:03<00:00, 2608.95 examples/s]


In [48]:
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset
from tqdm.auto import tqdm

In [49]:
# トークナイズ後 (map のあと) に実行する
columns = [
    "input_ids_chosen",
    "attention_mask_chosen",
    "input_ids_rejected",
    "attention_mask_rejected",
]
val_dataset.set_format(type="torch", columns=columns)

# 確認
print(type(val_dataset[0]["input_ids_chosen"]))  # <class 'torch.Tensor'> なら成功

<class 'torch.Tensor'>


In [50]:
import torch
from torch.utils.data import DataLoader

keys = [
    "input_ids_chosen",
    "attention_mask_chosen",
    "input_ids_rejected",
    "attention_mask_rejected",
]

def collate_fn(batch):
    return {k: torch.tensor([b[k] for b in batch], dtype=torch.long)
            for k in keys}

loader = DataLoader(val_dataset, batch_size=8, shuffle=False, collate_fn=collate_fn)

In [52]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(DEVICE).eval()

from transformers import default_data_collator
loader = DataLoader(val_dataset, batch_size=8, shuffle=False,
                    collate_fn=default_data_collator)  # ①の方法

correct = total = 0
with torch.no_grad():
    for batch in tqdm(loader, total=len(loader), desc="Evaluating"):
        c_ids  = batch["input_ids_chosen"].to(DEVICE)
        c_mask = batch["attention_mask_chosen"].to(DEVICE)
        r_ids  = batch["input_ids_rejected"].to(DEVICE)
        r_mask = batch["attention_mask_rejected"].to(DEVICE)
        c_score = model(input_ids=c_ids, attention_mask=c_mask).logits.squeeze(-1)
        r_score = model(input_ids=r_ids, attention_mask=r_mask).logits.squeeze(-1)

        correct += (c_score > r_score).sum().item()
        total   += len(c_score)

val_acc = correct / total
print(f"val accuracy = {val_acc:.4f}")


[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Evaluating:   6%|▌         | 63/1069 [07:59<2:07:32,  7.61s/it]


KeyboardInterrupt: 