# rm_trainer.py
from abc import ABC

import torch
from torch.optim import Optimizer
from torch.utils.data import DistributedSampler
from tqdm import tqdm

from openrlhf.models import LogExpLoss, PairWiseLoss


class RewardModelTrainer(ABC):
    """
    Trainer for reward model training.

    Args:
        model (torch.nn.Module): the model to train
        strategy (Strategy): the strategy to use for training
        optim (Optimizer): the optimizer to use for training
        train_dataloader (DataLoader): the dataloader for the training set
        eval_dataloader (DataLoader): the dataloader for the evaluation set
        scheduler (LRScheduler): the learning-rate scheduler
        tokenizer (Tokenizer): the tokenizer used to pad concatenated inputs
        max_norm (float, defaults to 0.5): the gradient clipping norm
        max_epochs (int, defaults to 2): the number of epochs to train
        loss (str, defaults to "sigmoid"): the pairwise loss ("sigmoid" for log-sigmoid, otherwise log-exp)
    """

    def __init__(
        self,
        model,
        strategy,
        optim: Optimizer,
        train_dataloader,
        eval_dataloader,
        scheduler,
        tokenizer,
        max_norm=0.5,
        max_epochs: int = 2,
        loss="sigmoid",
    ) -> None:
        super().__init__()
        self.strategy = strategy
        self.epochs = max_epochs
        self.max_norm = max_norm
        self.model = model
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader
        self.scheduler = scheduler
        self.optimizer = optim
        self.tokenizer = tokenizer
        self.args = strategy.args

        if loss == "sigmoid":
            self.loss_fn = PairWiseLoss()
            self.strategy.print("LogSigmoid Loss")
        else:
            self.loss_fn = LogExpLoss()
            self.strategy.print("LogExp Loss")

        # Mixtral 8x7B: enable the MoE auxiliary (load-balancing) loss
        self.aux_loss = self.args.aux_loss_coef > 1e-8
        self.margin_loss = self.strategy.args.margin_loss
        self.compute_fp32_loss = self.strategy.args.compute_fp32_loss

        self._wandb = None
        if self.strategy.args.use_wandb and self.strategy.is_rank_0():
            import wandb

            self._wandb = wandb
            if not wandb.api.api_key:
                # use_wandb is expected to carry the wandb API key
                wandb.login(key=strategy.args.use_wandb)
            wandb.init(
                entity=strategy.args.wandb_org,
                project=strategy.args.wandb_project,
                group=strategy.args.wandb_group,
                name=strategy.args.wandb_run_name,
                config=strategy.args.__dict__,
                reinit=True,
            )

            wandb.define_metric("train/global_step")
            wandb.define_metric("train/*", step_metric="train/global_step", step_sync=True)
            wandb.define_metric("eval/global_step")
            wandb.define_metric("eval/*", step_metric="eval/global_step", step_sync=True)

    def fit(self, args):
        # get eval and save steps
        if args.eval_steps == -1:
            args.eval_steps = len(self.train_dataloader)  # evaluate once per epoch
        if args.save_steps == -1:
            args.save_steps = float("inf")  # do not save checkpoints

        global_step = 1
        epoch_bar = tqdm(range(self.epochs), desc="Train epoch", disable=not self.strategy.is_rank_0())
        for epoch in range(self.epochs):
            # train
            step_bar = tqdm(
                range(len(self.train_dataloader)),
                desc="Train step of epoch %d" % epoch,
                disable=not self.strategy.is_rank_0(),
            )

            if isinstance(self.train_dataloader.sampler, DistributedSampler):
                self.train_dataloader.sampler.set_epoch(epoch)

            self.model.train()
            acc_mean = 0
            loss_mean = 0
            for chosen_ids, c_mask, reject_ids, r_mask, margin in self.train_dataloader:
                chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
                c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
                reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
                r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())

                if self.margin_loss:
                    margin = torch.tensor(margin).to(torch.cuda.current_device())
                else:
                    margin = None

                chosen_reward, reject_reward, aux_loss = self.concatenated_forward(
                    self.model, chosen_ids, c_mask, reject_ids, r_mask
                )

                # loss function
                if self.compute_fp32_loss:
                    chosen_reward = chosen_reward.float()
                    reject_reward = reject_reward.float()

                preference_loss = self.loss_fn(chosen_reward, reject_reward, margin)
                # mixtral
                if not self.aux_loss:
                    aux_loss = 0

                loss = preference_loss + aux_loss * self.args.aux_loss_coef
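                # strategy.backward/optimizer_step delegate to the distributed
                # backend (e.g. DeepSpeed), which handles loss scaling and
                # gradient clipping to max_norm (editor's note, based on the
                # strategy abstraction used throughout this file).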
                self.strategy.backward(loss, self.model, self.optimizer)
                self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler)

                # running exponential moving averages (decay 0.9) for logging
                acc_mean = acc_mean * 0.9 + 0.1 * (chosen_reward > reject_reward).float().mean().item()
                loss_mean = loss_mean * 0.9 + 0.1 * preference_loss.item()

                # optional rm info
                logs_dict = {
                    "preference_loss": preference_loss.item(),
                    "chosen_reward": chosen_reward.mean().item(),
                    "reject_reward": reject_reward.mean().item(),
                    "acc_mean": acc_mean,
                    "loss_mean": loss_mean,
                }
                if self.aux_loss:
                    logs_dict["aux_loss"] = aux_loss.item()

                # logs/checkpoints/evaluate
                self.save_logs_and_checkpoints(args, global_step, step_bar, logs_dict)

                step_bar.update()
                global_step += 1
            epoch_bar.update()

        if self._wandb is not None and self.strategy.is_rank_0():
            self._wandb.finish()

    # logs/checkpoints/evaluate
    def save_logs_and_checkpoints(self, args, global_step, step_bar, logs_dict=None):
        logs_dict = logs_dict or {}
        if global_step % args.logging_steps == 0:
            # step bar
            logs_dict = self.strategy.all_reduce(logs_dict)
            step_bar.set_postfix(logs_dict)
            # wandb: log only on optimizer-step boundaries (gradient accumulation)
            if (
                self._wandb is not None
                and self.strategy.is_rank_0()
                and global_step % self.strategy.accumulated_gradient == 0
            ):
                logs = {"train/%s" % k: v for k, v in {**logs_dict, "global_step": global_step}.items()}
                self._wandb.log(logs)

        # eval
        if global_step % args.eval_steps == 0:
            self.evaluate(self.eval_dataloader, global_step)

        # save ckpt
        # TODO: save best model on dev, use loss/perplexity on whole dev dataset as metric
        if global_step % args.save_steps == 0:
            tag = f"global_step{global_step}"
            self.strategy.save_ckpt(self.model, args.ckpt_path, tag, args.max_ckpt_num, args.max_ckpt_mem)

    def evaluate(self, eval_dataloader, steps=0):
        step_bar = tqdm(
            range(len(eval_dataloader)),
            desc="Eval stage of steps %d" % steps,
            disable=not self.strategy.is_rank_0(),
        )
        self.model.eval()
        with torch.no_grad():
            acc = 0
            rewards = []
            loss_sum = 0
            for chosen_ids, c_mask, reject_ids, r_mask, margin in eval_dataloader:
                chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
                c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
                reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
                r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
                margin = torch.tensor(margin).to(torch.cuda.current_device())

                chosen_reward, reject_reward, _ = self.concatenated_forward(
                    self.model, chosen_ids, c_mask, reject_ids, r_mask
                )
                loss = self.loss_fn(chosen_reward, reject_reward, margin)

                rewards += [chosen_reward.flatten(), reject_reward.flatten()]
                acc += (chosen_reward > reject_reward).float().mean().item()
                loss_sum += loss.item()
                step_bar.update()

            acc_mean = acc / len(eval_dataloader)
            loss_mean = loss_sum / len(eval_dataloader)

            rewards = torch.cat(rewards).float()
            rewards = self.strategy.all_gather(rewards)
            reward_mean = torch.mean(rewards)
            reward_std = torch.std(rewards).clamp(min=1e-8)
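            # Persist the reward statistics in the model config (assumption:
            # downstream RL stages, e.g. PPO, use them to normalize rewards).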
            self.strategy.print("Set reward mean std")
            unwrap_model = self.strategy._unwrap_model(self.model)
            unwrap_model.config.mean = reward_mean.item()
            unwrap_model.config.std = reward_std.item()

            bar_dict = {
                "eval_loss": loss_mean,
                "acc_mean": acc_mean,
                "reward_mean": reward_mean.item(),
                "reward_std": reward_std.item(),
            }
            logs = self.strategy.all_reduce(bar_dict)
            step_bar.set_postfix(logs)

            # torch.histogram returns (hist, bin_edges); with density=True and
            # a bin width of 2, scaling by 2 converts density to per-bin probability
            hist, bin_edges = torch.histogram(rewards.cpu(), bins=10, range=(-10, 10), density=True)
            self.strategy.print("histogram")
            self.strategy.print(hist * 2)
            self.strategy.print(bin_edges)

            if self._wandb is not None and self.strategy.is_rank_0():
                logs = {"eval/%s" % k: v for k, v in {**logs, "global_step": steps}.items()}
                self._wandb.log(logs)
        self.model.train()  # reset model state

    def concatenated_forward(self, model, chosen_ids, c_mask, reject_ids, r_mask):
        """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

        We do this to avoid doing two forward passes, because it's faster for FSDP.
        """
        input_ids, att_masks = self.concatenated_inputs(chosen_ids, c_mask, reject_ids, r_mask)
        all_values, output = model(input_ids, attention_mask=att_masks, return_output=True)
        chosen_rewards = all_values[: chosen_ids.shape[0]]
        rejected_rewards = all_values[chosen_ids.shape[0] :]
        # use 0 rather than [] so the aux-loss term stays well-typed in fit()
        aux_loss = output.aux_loss if "aux_loss" in output else 0
        return chosen_rewards, rejected_rewards, aux_loss
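
    # Shape note: with B preference pairs per device, the concatenated batch
    # is (2*B, L) and `all_values` holds one reward per sequence, chosen
    # first; hence the split at chosen_ids.shape[0] above (assumes the reward
    # model returns one scalar value per sequence).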

    def concatenated_inputs(self, chosen_ids, c_mask, reject_ids, r_mask):
        """Concatenate the chosen and rejected inputs into a single tensor.

        Args:
            chosen_ids, c_mask: input ids and attention mask of the chosen responses, shape (batch_size, seq_len).
            reject_ids, r_mask: input ids and attention mask of the rejected responses, shape (batch_size, seq_len).

        Returns:
            A tuple (input_ids, att_masks) with chosen and rejected inputs concatenated along the batch dimension.
        """

        def pad_to_length(tensor, length, pad_value, dim=-1):
            if tensor.size(dim) >= length:
                return tensor
            else:
                pad_size = list(tensor.shape)
                pad_size[dim] = length - tensor.size(dim)
                # left pad
                return torch.cat(
                    [pad_value * torch.ones(*pad_size, dtype=tensor.dtype, device=tensor.device), tensor], dim=dim
                )
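        # Example: pad_to_length(t, 8, pad_id) on a (B, 5) tensor prepends
        # three pad columns, yielding (B, 8) with the original tokens
        # right-aligned.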
        max_length = max(chosen_ids.shape[1], reject_ids.shape[1])
        input_ids = torch.cat(
            (
                pad_to_length(chosen_ids, max_length, self.tokenizer.pad_token_id),
                pad_to_length(reject_ids, max_length, self.tokenizer.pad_token_id),
            ),
            dim=0,
        )
        max_length = max(c_mask.shape[1], r_mask.shape[1])
        att_masks = torch.cat((pad_to_length(c_mask, max_length, 0), pad_to_length(r_mask, max_length, 0)), dim=0)
        return input_ids, att_masks
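

# ---------------------------------------------------------------------------
# Minimal usage sketch (editor's illustration; `strategy`, the dataloaders,
# `scheduler`, `tokenizer`, `optim`, and `args` come from the surrounding
# OpenRLHF training script and are assumed here, not defined in this module):
#
#   trainer = RewardModelTrainer(
#       model=model,
#       strategy=strategy,
#       optim=optim,
#       train_dataloader=train_dataloader,
#       eval_dataloader=eval_dataloader,
#       scheduler=scheduler,
#       tokenizer=tokenizer,
#       max_norm=0.5,
#       max_epochs=1,
#       loss="sigmoid",
#   )
#   trainer.fit(args)
# ---------------------------------------------------------------------------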