In [1]:
!pip install trl



In [5]:
!pip install datasets

Collecting datasets
  Using cached datasets-2.20.0-py3-none-any.whl (547 kB)
Collecting multiprocess
  Using cached multiprocess-0.70.16-py38-none-any.whl (132 kB)
Collecting pyarrow-hotfix
  Using cached pyarrow_hotfix-0.6-py3-none-any.whl (7.9 kB)
Collecting xxhash
  Using cached xxhash-3.4.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)
Collecting dill<0.3.9,>=0.3.0
  Using cached dill-0.3.8-py3-none-any.whl (116 kB)
Collecting requests>=2.32.2
  Downloading requests-2.32.3-py3-none-any.whl (64 kB)
[K     |████████████████████████████████| 64 kB 1.1 MB/s eta 0:00:01
[?25hCollecting pandas
  Using cached pandas-2.0.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (12.4 MB)
Collecting pyarrow>=15.0.0
  Using cached pyarrow-16.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (41.0 MB)
Collecting tzdata>=2022.1
  Using cached tzdata-2024.1-py2.py3-none-any.whl (345 kB)
Installing collected packages: dill, multiprocess, pyarrow-hotfix, xxhash, 

In [6]:
!pip install sentencepiece

Collecting sentencepiece
  Downloading sentencepiece-0.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[K     |████████████████████████████████| 1.3 MB 972 kB/s eta 0:00:01
[?25hInstalling collected packages: sentencepiece
Successfully installed sentencepiece-0.2.0


In [7]:
!pip install peft

Collecting peft
  Using cached peft-0.11.1-py3-none-any.whl (251 kB)
Installing collected packages: peft
Successfully installed peft-0.11.1


In [8]:
!pip install tensorboardX

Collecting tensorboardX
  Downloading tensorboardX-2.6.2.2-py2.py3-none-any.whl (101 kB)
[K     |████████████████████████████████| 101 kB 1.3 MB/s ta 0:00:01
Collecting protobuf>=3.20
  Downloading protobuf-5.27.2-cp38-abi3-manylinux2014_x86_64.whl (309 kB)
[K     |████████████████████████████████| 309 kB 5.4 MB/s eta 0:00:01
[?25hInstalling collected packages: protobuf, tensorboardX
Successfully installed protobuf-5.27.2 tensorboardX-2.6.2.2


In [1]:
from datasets import load_dataset
from peft import LoraConfig, get_peft_model

**Note:**

Add pad_token to resolve `AssertionError`: Cannot handle batch sizes > 1 if no padding token is defined.

In [2]:
rm_dataset = load_dataset("gen_rm_dataset.py", trust_remote_code=True)
rm_dataset

DatasetDict({
    train: Dataset({
        features: ['weibo', 'text_j', 'text_k', 'text_j_like', 'text_k_like'],
        num_rows: 9985
    })
    validation: Dataset({
        features: ['weibo', 'text_j', 'text_k', 'text_j_like', 'text_k_like'],
        num_rows: 3451
    })
    test: Dataset({
        features: ['weibo', 'text_j', 'text_k', 'text_j_like', 'text_k_like'],
        num_rows: 7249
    })
})

In [3]:
def preprocess_function(examples):
    new_examples = {
        "input_ids_chosen": [],
        "attention_mask_chosen": [],
        "input_ids_rejected": [],
        "attention_mask_rejected": [],
    }
    for text_j, text_k, text_j_like, text_k_like in zip(examples["text_j"], examples["text_k"], examples["text_j_like"], examples["text_k_like"]):
        if text_j_like >= text_k_like:
            chosen = text_j
            rejected = text_k
        else:
            chosen = text_k
            rejected = text_j        
        
        tokenized_chosen = tokenizer(chosen, padding="max_length", max_length=32, truncation=True) # Same Problem as SFT
        tokenized_rejected = tokenizer(rejected, padding="max_length", max_length=32, truncation=True) # Same Problem as SFT

        new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
        new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
        new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
        new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])

    return new_examples

In [4]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Initial Reward Model
rm_model = AutoModelForSequenceClassification.from_pretrained("gpt2")

# Initial Tokenizer for RM
tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
tokenizer.pad_token = tokenizer.eos_token
rm_model.config.pad_token_id = tokenizer.pad_token_id # Suggested by Claude 3.5 Sonnet

Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
# Preprocess the dataset
rm_dataset = rm_dataset.map(
    preprocess_function,
    batched=True,
    num_proc=4,
)

In [6]:
rm_train_dataset = rm_dataset["train"]
rm_eval_dataset = rm_dataset["validation"]

In [9]:
from trl import RewardTrainer, RewardConfig

config = RewardConfig(
    output_dir="RM_model7",
    per_device_train_batch_size=32,
    num_train_epochs=150,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    learning_rate=5e-3,
    report_to="tensorboard",
    logging_dir="./results/rm7-Latest7",
    remove_unused_columns=False,
    optim="adamw_torch",
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=500,
    max_length=256,
)

In [10]:
peft_config = LoraConfig(
    task_type="SEQ_CLS", # NOT CAUSAL_LM (bug)
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
)

##### `First` Formal Trial (1560 steps) 
bs = 8, lr ~ 5e-5, 10 epochs

In [19]:
trainer = RewardTrainer(
    model=rm_model,
    tokenizer=tokenizer,
    args=config,
    train_dataset=rm_train_dataset,
    eval_dataset=rm_eval_dataset,
    peft_config=peft_config,
)
trainer.train()
trainer.save_model(config.output_dir)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Step,Training Loss,Validation Loss,Accuracy
500,0.7246,0.705777,0.509128
1000,0.6919,0.690434,0.538974
1500,0.7041,0.688446,0.541292










##### `Second` Formal Trial 
bs = 16, lr = 2e-3, 40 epochs

In [22]:
trainer = RewardTrainer(
    model=rm_model,
    tokenizer=tokenizer,
    args=config,
    train_dataset=rm_train_dataset,
    eval_dataset=rm_eval_dataset,
    peft_config=peft_config,
)
trainer.train()
trainer.save_model(config.output_dir)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss,Validation Loss,Accuracy
500,0.6925,0.70663,0.548826
1000,0.6971,0.708805,0.549406
1500,0.6927,0.711551,0.542741
2000,0.7057,0.70641,0.542162
2500,0.6784,0.709342,0.545059
3000,0.685,0.727206,0.545639
3500,0.6841,0.718486,0.540133
4000,0.6799,0.706571,0.548247
4500,0.668,0.701843,0.551434
5000,0.683,0.70061,0.555781


























##### `Third` Formal Trial (Failed, Gradient Explosion)
bs = 32, lr = 2e-2, 150 epochs, grad_accum = 1

In [33]:
trainer = RewardTrainer(
    model=rm_model,
    tokenizer=tokenizer,
    args=config,
    train_dataset=rm_train_dataset,
    eval_dataset=rm_eval_dataset,
    peft_config=peft_config,
)
trainer.train()
trainer.save_model(config.output_dir)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Step,Training Loss,Validation Loss,Accuracy
500,0.7587,0.691762,0.525065
1000,0.73,0.692003,0.532889
1500,0.0,,1.0
2000,0.0,,1.0






NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf fou

NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf fou

NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.

KeyboardInterrupt



##### `Seventh` Formal Trial (11700 steps, time cost 24:24 on RTX 4090)
bs = 32, lr = 5e-3, grad_accum = 4, 150 epochs

In [12]:
trainer = RewardTrainer(
    model=rm_model,
    tokenizer=tokenizer,
    args=config,
    train_dataset=rm_train_dataset,
    eval_dataset=rm_eval_dataset,
    peft_config=peft_config,
)
trainer.train()
trainer.save_model(config.output_dir)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss,Validation Loss,Accuracy
500,0.6979,0.716391,0.541582
1000,0.6989,0.70338,0.55636
1500,0.7011,0.712082,0.551145
2000,0.6832,0.71182,0.548826
2500,0.6915,0.708923,0.549696
3000,0.6848,0.73357,0.527673
3500,0.6769,0.704638,0.565054
4000,0.6648,0.705384,0.553463
4500,0.6934,0.705815,0.55665
5000,0.7038,0.712788,0.543611
















































#### Some summary

In [20]:
metrics = trainer.evaluate()
trainer.log_metrics("eval", metrics)
print(metrics)





***** eval metrics *****
  epoch                   =      9.992
  eval_accuracy           =      0.541
  eval_loss               =     0.6884
  eval_runtime            = 0:00:08.94
  eval_samples_per_second =    385.626
  eval_steps_per_second   =     48.273
{'eval_loss': 0.6884294748306274, 'eval_accuracy': 0.5410026079397277, 'eval_runtime': 8.9491, 'eval_samples_per_second': 385.626, 'eval_steps_per_second': 48.273, 'epoch': 9.9919935948759}


