Skip to content

Commit

Permalink
Add ArmoRM to RewardBench (#135)
Browse files Browse the repository at this point in the history
  • Loading branch information
Haoxiang-Wang committed May 24, 2024
1 parent 2eacee8 commit 0851402
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 2 deletions.
10 changes: 10 additions & 0 deletions rewardbench/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from transformers import (
AutoModel,
AutoModelForCausalLM,
Expand All @@ -23,6 +24,7 @@
pipeline,
)

from .armorm import ArmoRMPipeline
from .beaver import BeaverCostPipeline, BeaverPipeline, LlamaForScore
from .betterpairrm import BetterPairRMPipeline
from .openassistant import * # noqa
Expand Down Expand Up @@ -130,6 +132,14 @@
"custom_dialogue": True,
"model_type": "Custom Classifier",
},
"RLHFlow/ArmoRM-Llama3-8B-v0.1": {
"model_builder": AutoModelForSequenceClassification.from_pretrained,
"pipeline_builder": ArmoRMPipeline,
"quantized": False,
"custom_dialogue": True,
"model_type": "Custom Classifier",
"torch_dtype": torch.bfloat16,
},
}

DPO_MODEL_CONFIG = {
Expand Down
32 changes: 32 additions & 0 deletions rewardbench/models/armorm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import random
from typing import List

import torch


class ArmoRMPipeline:
def __init__(self, task, model, tokenizer):
self.task = task
self.model = model
self.tokenizer = tokenizer
random.seed(0)

def __call__(self, candidates_A: List[str], candidates_B: List[str], **kwargs):
"""
samples: List[str]
"""
device = self.model.device
out = []
with torch.no_grad():
for candidate_A, candidate_B in zip(candidates_A, candidates_B):
pair_scores = []
for candidate in [candidate_A, candidate_B]:
input_ids = self.tokenizer.apply_chat_template(candidate, return_tensors="pt").to(device)
output = self.model(input_ids)
score = output.score.float().item()
pair_scores.append(score)
if pair_scores[0] == pair_scores[1]:
out.append(random.choice([True, False]))
else:
out.append(pair_scores[0] > pair_scores[1])
return torch.Tensor(out).bool()
7 changes: 7 additions & 0 deletions scripts/configs/eval_configs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -532,3 +532,10 @@ RLHFlow/RewardModel-Mistral-7B-for-DPA-v1:
batch_size: 4
trust_remote_code: True
dpo: False
RLHFlow/ArmoRM-Llama3-8B-v0.1:
model: RLHFlow/ArmoRM-Llama3-8B-v0.1
tokenizer: RLHFlow/ArmoRM-Llama3-8B-v0.1
chat_template: # none for tokenizer
batch_size: 4
trust_remote_code: True
dpo: False
11 changes: 9 additions & 2 deletions scripts/run_rm.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@
from rewardbench.constants import EXAMPLE_COUNTS, SUBSET_MAPPING
from rewardbench.utils import calculate_scores_per_section

# Enable TensorFloat32 (TF32) tensor cores on Ampere GPUs for matrix multiplications (faster than FP32)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# get token from HF_TOKEN env variable, but if it doesn't exist pass none
HF_TOKEN = os.getenv("HF_TOKEN", None)
# this is necessary to automatically log in when running this script in docker/batch beaker jobs
Expand Down Expand Up @@ -117,7 +121,7 @@ def main():
model_type = config["model_type"]
model_builder = config["model_builder"]
pipeline_builder = config["pipeline_builder"]

torch_dtype = config.get("torch_dtype", None)
# not included in config to make user explicitly understand they are passing this
trust_remote_code = args.trust_remote_code

Expand Down Expand Up @@ -167,7 +171,10 @@ def main():
"torch_dtype": torch.float16 if torch.cuda.is_available() else None,
}
else:
model_kwargs = {"device_map": {"": current_device}}
model_kwargs = {
"device_map": {"": current_device},
"torch_dtype": torch_dtype,
}

model = model_builder(args.model, **model_kwargs, trust_remote_code=trust_remote_code)
reward_pipe = pipeline_builder(
Expand Down

0 comments on commit 0851402

Please sign in to comment.