-
Notifications
You must be signed in to change notification settings - Fork 248
/
ppo_trainer.py
515 lines (459 loc) · 20.9 KB
/
ppo_trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
import math
import os
import os.path
from abc import ABC
from typing import Any, Callable, Dict, List, Optional, Union
import ray
import torch
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from tqdm import tqdm
from openrlhf.models import Actor, GPTLMLoss, PolicyLoss, ValueLoss
from openrlhf.models.utils import masked_mean
from openrlhf.utils.distributed_sampler import DistributedSampler
from .ppo_utils import AdaptiveKLController, Experience, FixedKLController, NaiveExperienceMaker, NaiveReplayBuffer
class PPOTrainer(ABC):
    """
    Trainer for Proximal Policy Optimization (PPO) algorithm.

    Args:
        strategy (Strategy): The training strategy to use.
        actor (Actor): The actor model in the PPO algorithm.
        critic (nn.Module): The critic model in the PPO algorithm.
        reward_model (nn.Module): The reward model for calculating rewards in the RLHF setup.
            May also be a list of reward models, in which case ``reward_fn`` is required
            when the list has more than one entry.
        initial_model (Actor): The initial model for reference logits to limit actor updates in RLHF.
        ema_model (Actor): The exponential moving average model for stable training.
        actor_optim (Optimizer): The optimizer for the actor model.
        critic_optim (Optimizer): The optimizer for the critic model.
        actor_scheduler (Scheduler): The learning rate scheduler for the actor.
        critic_scheduler (Scheduler): The learning rate scheduler for the critic.
        ema_beta (float, defaults to 0.992): EMA decay rate for model stability.
        init_kl_coef (float, defaults to 0.001): Initial coefficient for KL divergence.
        kl_target (float, optional): Target value for KL divergence. When truthy an adaptive
            KL controller is used, otherwise a fixed one.
        kl_horizon (int, defaults to 10000): Horizon for KL annealing.
        ptx_coef (float, defaults to 0): Coefficient for supervised loss from pre-trained data.
        micro_train_batch_size (int, defaults to 8): Micro-batch size for actor training.
        buffer_limit (int, defaults to 0): Maximum size of the replay buffer.
        buffer_cpu_offload (bool, defaults to True): If True, offloads replay buffer to CPU.
        eps_clip (float, defaults to 0.2): Clipping coefficient for policy loss.
        value_clip (float, defaults to 0.2): Clipping coefficient for value function loss.
        micro_rollout_batch_size (int, defaults to 8): Micro-batch size for generating rollouts.
        gradient_checkpointing (bool, defaults to False): If True, enables gradient checkpointing.
        max_epochs (int, defaults to 1): Number of epochs to train on each replay buffer.
        max_norm (float, defaults to 1.0): Maximum gradient norm for gradient clipping.
        tokenizer (Callable, optional): Tokenizer for input data.
        prompt_max_len (int, defaults to 128): Maximum length for prompts.
        dataloader_pin_memory (bool, defaults to True): If True, pins memory in the data loader.
        remote_rm_url (str, optional): URL for remote reward model API.
        reward_fn (Callable, optional): Custom reward function for computing rewards.
        **generate_kwargs: Additional arguments for model generation.
    """

    def __init__(
        self,
        strategy,
        actor: Actor,
        critic: nn.Module,
        reward_model: nn.Module,
        initial_model: Actor,
        ema_model: Actor,
        actor_optim: Optimizer,
        critic_optim: Optimizer,
        actor_scheduler,
        critic_scheduler,
        ema_beta: float = 0.992,
        init_kl_coef: float = 0.001,
        kl_target: Optional[float] = None,
        kl_horizon: int = 10000,
        ptx_coef: float = 0,
        micro_train_batch_size: int = 8,
        buffer_limit: int = 0,
        buffer_cpu_offload: bool = True,
        eps_clip: float = 0.2,
        value_clip: float = 0.2,
        micro_rollout_batch_size: int = 8,
        gradient_checkpointing: bool = False,
        max_epochs: int = 1,
        max_norm: float = 1.0,
        tokenizer: Optional[Callable[[Any], dict]] = None,
        prompt_max_len: int = 128,
        dataloader_pin_memory: bool = True,
        remote_rm_url: Optional[str] = None,
        reward_fn: Optional[Callable[[List[torch.Tensor]], torch.Tensor]] = None,
        **generate_kwargs,
    ) -> None:
        # Several reward models need a reward_fn to combine their scores.
        assert (
            not isinstance(reward_model, list) or len(reward_model) == 1 or reward_fn is not None
        ), "reward_fn must be specified if using multiple reward models"

        super().__init__()
        self.strategy = strategy
        self.args = strategy.args
        self.micro_rollout_batch_size = micro_rollout_batch_size
        self.max_epochs = max_epochs
        self.tokenizer = tokenizer
        self.generate_kwargs = generate_kwargs
        self.dataloader_pin_memory = dataloader_pin_memory
        self.max_norm = max_norm
        self.ptx_coef = ptx_coef
        self.micro_train_batch_size = micro_train_batch_size
        self.kl_target = kl_target
        self.prompt_max_len = prompt_max_len
        self.ema_beta = ema_beta
        self.gradient_checkpointing = gradient_checkpointing
        self.reward_fn = reward_fn

        self.actor = actor
        self.critic = critic
        self.reward_model = reward_model
        self.remote_rm_url = remote_rm_url
        self.initial_model = initial_model
        self.ema_model = ema_model
        self.actor_optim = actor_optim
        self.critic_optim = critic_optim
        self.actor_scheduler = actor_scheduler
        self.critic_scheduler = critic_scheduler

        self.actor_loss_fn = PolicyLoss(eps_clip)
        self.critic_loss_fn = ValueLoss(value_clip)
        self.ptx_loss_fn = GPTLMLoss()

        # Actor updates are skipped while global_steps <= freezing_actor_steps (see training_step).
        self.freezing_actor_steps = getattr(self.args, "freezing_actor_steps", -1)

        # Mixtral 8x7b: only add the model's aux_loss when its coefficient is non-negligible.
        self.aux_loss = self.args.aux_loss_coef > 1e-8

        if self.kl_target:
            self.kl_ctl = AdaptiveKLController(init_kl_coef, kl_target, kl_horizon)
        else:
            self.kl_ctl = FixedKLController(init_kl_coef)

        self.experience_maker = NaiveExperienceMaker(
            actor,
            critic,
            reward_model,
            initial_model,
            tokenizer,
            prompt_max_len,
            self.kl_ctl,
            strategy,
            remote_rm_url,
            reward_fn,
        )
        packing_samples = getattr(self.args, "packing_samples", False)
        self.replay_buffer = NaiveReplayBuffer(
            micro_train_batch_size, buffer_limit, buffer_cpu_offload, packing_samples
        )

        # Assigned for real in fit(); initialized here so that training_step_actor() does not
        # raise AttributeError if ppo_train() is invoked without going through fit().
        self.prompts_dataloader = None
        self.pretrain_dataloader = None

        # wandb/tensorboard setting (rank 0 only)
        self._wandb = None
        self._tensorboard = None
        if self.strategy.args.use_wandb and self.strategy.is_rank_0():
            import wandb

            self._wandb = wandb
            if not wandb.api.api_key:
                # args.use_wandb doubles as the API key here.
                wandb.login(key=strategy.args.use_wandb)
            wandb.init(
                entity=strategy.args.wandb_org,
                project=strategy.args.wandb_project,
                group=strategy.args.wandb_group,
                name=strategy.args.wandb_run_name,
                config=strategy.args.__dict__,
                reinit=True,
            )

            wandb.define_metric("train/global_step")
            wandb.define_metric("train/*", step_metric="train/global_step", step_sync=True)
            wandb.define_metric("eval/epoch")
            wandb.define_metric("eval/*", step_metric="eval/epoch", step_sync=True)

        # Initialize TensorBoard writer if wandb is not available
        if self.strategy.args.use_tensorboard and self._wandb is None and self.strategy.is_rank_0():
            from torch.utils.tensorboard import SummaryWriter

            os.makedirs(self.strategy.args.use_tensorboard, exist_ok=True)
            log_dir = os.path.join(self.strategy.args.use_tensorboard, strategy.args.wandb_run_name)
            self._tensorboard = SummaryWriter(log_dir=log_dir)

    def fit(
        self,
        args,
        prompts_dataloader,
        pretrain_dataloader,
        consumed_samples=0,
        num_update_steps_per_episodes=1,
    ) -> None:
        """Run the PPO loop: generate experience, train on it, then log and checkpoint.

        Args:
            args: Training arguments. Reads train_batch_size, max_epochs, rollout_batch_size,
                n_samples_per_prompt, eval_steps, save_steps and num_episodes, plus whatever
                save_logs_and_checkpoints reads. eval_steps / save_steps equal to -1 are
                rewritten in place.
            prompts_dataloader: Iterable of prompt batches passed to the experience maker.
            pretrain_dataloader: Object consumed with ``next()`` for the ptx loss, or None
                to disable the ptx loss.
            consumed_samples: Number of samples already consumed; used to resume the step
                counter, the episode index and the sampler position.
            num_update_steps_per_episodes: Number of optimizer update steps per episode.

        Raises:
            ValueError: if the arguments yield zero rollouts per episode (which would
                otherwise surface as a ZeroDivisionError further down).
        """
        num_rollouts_per_episodes = (
            num_update_steps_per_episodes
            * args.train_batch_size
            // args.max_epochs
            // args.rollout_batch_size
            // args.n_samples_per_prompt
        )
        if num_rollouts_per_episodes <= 0:
            raise ValueError(
                f"num_rollouts_per_episodes must be positive, got {num_rollouts_per_episodes}; "
                "check num_update_steps_per_episodes, train_batch_size, max_epochs, "
                "rollout_batch_size and n_samples_per_prompt"
            )

        # get eval and save steps
        if args.eval_steps == -1:
            args.eval_steps = num_rollouts_per_episodes  # Evaluate once per epoch
        if args.save_steps == -1:
            args.save_steps = float("inf")  # do not save ckpt

        self.prompts_dataloader = prompts_dataloader
        self.pretrain_dataloader = pretrain_dataloader

        # Restore step and start_epoch
        steps = consumed_samples // args.rollout_batch_size + 1
        start_episode = consumed_samples // args.rollout_batch_size // num_rollouts_per_episodes
        consumed_samples = consumed_samples % (num_rollouts_per_episodes * args.rollout_batch_size)

        for episode in range(start_episode, args.num_episodes):
            if isinstance(self.prompts_dataloader.sampler, DistributedSampler):
                # Only the resumed episode skips already-consumed samples.
                self.prompts_dataloader.sampler.set_epoch(
                    episode, consumed_samples=0 if episode > start_episode else consumed_samples
                )
            pbar = tqdm(
                range(len(self.prompts_dataloader)),
                desc=f"Episode [{episode + 1}/{args.num_episodes}]",
                disable=not self.strategy.is_rank_0(),
            )

            for rand_prompts in self.prompts_dataloader:
                for i, experience in enumerate(
                    self.experience_maker.make_experience_list(rand_prompts, **self.generate_kwargs)
                ):
                    if i == 0:
                        # Print one decoded sample per rollout batch for inspection.
                        output = self.tokenizer.batch_decode(
                            experience.sequences[0].unsqueeze(0), skip_special_tokens=True
                        )
                        self.strategy.print(output)
                    self.replay_buffer.append(experience)

                torch.cuda.empty_cache()
                self.replay_buffer.normalize("advantages", self.strategy)
                status = self.ppo_train(steps)
                self.replay_buffer.clear()
                torch.cuda.empty_cache()

                if "kl" in status:
                    self.kl_ctl.update(status["kl"], args.rollout_batch_size * args.n_samples_per_prompt)
                pbar.set_postfix(status)

                # logs/checkpoints
                client_states = {"consumed_samples": steps * args.rollout_batch_size}
                self.save_logs_and_checkpoints(args, steps, pbar, status, client_states)

                pbar.update()
                steps = steps + 1

            pbar.close()

        if self._wandb is not None and self.strategy.is_rank_0():
            self._wandb.finish()
        if self._tensorboard is not None and self.strategy.is_rank_0():
            self._tensorboard.close()

    def ppo_train(self, global_steps=0):
        """Train for ``max_epochs`` passes over the current replay buffer.

        Args:
            global_steps: Current rollout step, forwarded to training_step (controls
                actor freezing).

        Returns:
            dict: Per-key mean of the DP-reduced status dicts of every training step,
            or an empty dict if the replay buffer produced no batches.
        """
        # replay buffer may be empty at first, we should rebuild at each training
        dataloader = DataLoader(
            self.replay_buffer,
            batch_size=self.replay_buffer.sample_batch_size,
            shuffle=True,
            drop_last=True,
            pin_memory=self.dataloader_pin_memory,
            collate_fn=self.replay_buffer.collate_fn,
        )
        device = torch.cuda.current_device()

        status_list = []
        status_mean = {}
        for epoch in range(self.max_epochs):
            pbar = tqdm(
                dataloader,
                desc=f"Train epoch [{epoch + 1}/{self.max_epochs}]",
                disable=not self.strategy.is_rank_0(),
            )
            for experience in pbar:
                experience.to_device(device)
                status = self.training_step(experience, global_steps)

                # for DP: average metrics across ranks. KL is weighted by response length,
                # so scale it up before the reduction and divide by the reduced length after.
                # Reduce even without "kl" (critic-only steps while the actor is frozen) so
                # that logged critic metrics are DP averages rather than per-rank values.
                if "kl" in status:
                    status["kl"] *= status["response_length"]
                    status = self.strategy.all_reduce(status)
                    status["kl"] /= status["response_length"]
                else:
                    status = self.strategy.all_reduce(status)

                short_status = {}
                if "policy_loss" in status:
                    short_status = {
                        "pg": status["policy_loss"],
                        "rm": status["reward"],
                        "ret": status["return"],
                        "glen": status["response_length"],
                        "tlen": status["total_length"],
                        "kl": status["kl"],
                        "act_lr": status["actor_lr"],
                    }

                if "critic_loss" in status:
                    short_status["cri"] = status["critic_loss"]
                    short_status["vals"] = status["values"]
                    short_status["cri_lr"] = status["critic_lr"]

                if "ptx_loss" in status:
                    short_status["ptx"] = status["ptx_loss"]

                status_list.append(status)
                pbar.set_postfix(short_status)

        if status_list:
            # Copy the first dict so the accumulation does not mutate a collected status.
            status_mean = dict(status_list[0])
            for m in status_list[1:]:
                for k, v in m.items():
                    status_mean[k] += v
            status_mean = {k: v / len(status_list) for k, v in status_mean.items()}
        return status_mean

    def training_step(self, experience: Experience, global_steps) -> Dict[str, float]:
        """Run one actor step (unless the actor is still frozen) and one critic step (if any).

        Returns:
            Merged status dict of the steps that ran; empty if neither ran.
        """
        status = {}
        if global_steps > self.freezing_actor_steps:
            status = self.training_step_actor(experience)
        if self.critic is not None:
            status.update(self.training_step_critic(experience))
        return status

    @staticmethod
    def _packed_attention_mask(sequences: List[torch.Tensor]) -> torch.Tensor:
        """Build the attention mask for packed samples.

        Every position of sample ``i`` is filled with ``i + 1``; the per-sample masks are
        concatenated and given a leading dimension of size 1.
        """
        return torch.cat([torch.full_like(s, i + 1) for i, s in enumerate(sequences)], dim=0).unsqueeze(0)

    def training_step_actor(self, experience: Experience) -> Dict[str, float]:
        """Perform one actor update (policy loss, optional aux and ptx losses) and return metrics.

        Returns:
            dict with "policy_loss", "actor_lr", "ptx_loss" (only when a pretrain dataloader
            is set) and one entry per key of ``experience.info`` (kl is response-length
            weighted, other keys are plain means).
        """
        self.actor.train()

        # TODO: this is a bad indicator to say that data is packed...
        if isinstance(experience.sequences, list):
            # Packed samples: concatenate everything into a single row.
            sequences = torch.cat(experience.sequences, dim=0).unsqueeze(0)
            old_action_log_probs = torch.cat(experience.action_log_probs, dim=0).unsqueeze(0)
            advantages = torch.cat(experience.advantages, dim=0).unsqueeze(0)
            num_actions = [v.numel() for v in experience.advantages]
            packed_seq_lens = [s.numel() for s in experience.sequences]
            attention_mask = self._packed_attention_mask(experience.sequences)
        else:
            sequences = experience.sequences
            old_action_log_probs = experience.action_log_probs
            advantages = experience.advantages
            num_actions = experience.action_mask.size(1)
            packed_seq_lens = None
            attention_mask = experience.attention_mask

        # actor loss
        action_log_probs, output = self.actor(
            sequences,
            num_actions,
            attention_mask=attention_mask,
            return_output=True,
            packed_seq_lens=packed_seq_lens,
        )

        # loss function
        actor_loss = self.actor_loss_fn(
            action_log_probs,
            old_action_log_probs,
            advantages,
            action_mask=experience.action_mask,
        )
        # mixtral
        aux_loss = output.aux_loss if self.aux_loss else 0
        loss = actor_loss + aux_loss * self.args.aux_loss_coef
        self.strategy.backward(loss, self.actor, self.actor_optim)

        # ptx loss: supervised LM loss on pretraining data, scaled by ptx_coef.
        if self.pretrain_dataloader is not None:
            data = next(self.pretrain_dataloader)
            inputs = data[1].squeeze(1).to(torch.cuda.current_device())
            attention_mask = data[2].squeeze(1).to(torch.cuda.current_device())
            # Padding positions are excluded from the loss.
            label = torch.where(
                attention_mask.bool(),
                inputs,
                self.ptx_loss_fn.IGNORE_INDEX,
            )

            output = self.actor(inputs, attention_mask=attention_mask, return_output=True)
            ptx_log_probs = output["logits"]

            # loss function
            ptx_loss = self.ptx_loss_fn(ptx_log_probs, label)
            # mixtral
            aux_loss = output.aux_loss if self.aux_loss else 0
            loss = ptx_loss + aux_loss * self.args.aux_loss_coef
            self.strategy.backward(self.ptx_coef * loss, self.actor, self.actor_optim)

        self.strategy.optimizer_step(self.actor_optim, self.actor, self.actor_scheduler, name="actor")
        if self.ema_model is not None:
            self.strategy.moving_average(self.actor, self.ema_model, self.ema_beta, "cpu")

        # status
        status = {"policy_loss": actor_loss.item(), "actor_lr": self.actor_scheduler.get_last_lr()[0]}
        if self.pretrain_dataloader is not None:
            status["ptx_loss"] = ptx_loss.item()
        for k, v in experience.info.items():
            if k == "kl":
                response_length = experience.info["response_length"]
                status[k] = ((v * response_length).sum() / response_length.sum()).item()
            else:
                status[k] = v.mean().item()
        return status

    def training_step_critic(self, experience: Experience) -> Dict[str, float]:
        """Perform one critic update and return its loss, mean value and learning rate."""
        self.critic.train()

        # TODO: this is a bad indicator to say that data is packed...
        if isinstance(experience.sequences, list):
            # Packed samples: concatenate everything into a single row.
            sequences = torch.cat(experience.sequences, dim=0).unsqueeze(0)
            old_values = torch.cat(experience.values, dim=0).unsqueeze(0)
            returns = torch.cat(experience.returns, dim=0).unsqueeze(0)
            num_actions = [v.numel() for v in experience.advantages]
            packed_seq_lens = [s.numel() for s in experience.sequences]
            attention_mask = self._packed_attention_mask(experience.sequences)
        else:
            sequences = experience.sequences
            old_values = experience.values
            returns = experience.returns
            num_actions = experience.action_mask.size(1)
            packed_seq_lens = None
            attention_mask = experience.attention_mask

        # critic loss
        values, output = self.critic(
            sequences,
            num_actions=num_actions,
            attention_mask=attention_mask,
            return_output=True,
            packed_seq_lens=packed_seq_lens,
        )
        # loss function
        critic_loss = self.critic_loss_fn(
            values,
            old_values,
            returns,
            action_mask=experience.action_mask,
        )
        # mixtral
        aux_loss = output.aux_loss if self.aux_loss else 0
        loss = critic_loss + aux_loss * self.args.aux_loss_coef
        self.strategy.backward(loss, self.critic, self.critic_optim)
        self.strategy.optimizer_step(self.critic_optim, self.critic, self.critic_scheduler, name="critic")

        # status
        status = {
            "critic_loss": critic_loss.item(),
            "values": masked_mean(values, experience.action_mask).item(),
            "critic_lr": self.critic_scheduler.get_last_lr()[0],
        }
        return status

    def save_logs_and_checkpoints(self, args, global_step, step_bar, logs_dict=None, client_states=None):
        """Log metrics every ``args.logging_steps`` and checkpoint every ``args.save_steps``.

        Args:
            args: Reads logging_steps, eval_steps, save_steps (and checkpoint settings).
            global_step: Current rollout step.
            step_bar: Progress bar (currently unused).
            logs_dict: Metrics to log; defaults to an empty dict.
            client_states: Extra state stored with the actor checkpoint; defaults to an empty dict.
        """
        if logs_dict is None:
            logs_dict = {}
        if client_states is None:
            client_states = {}

        if global_step % args.logging_steps == 0:
            # wandb
            if self._wandb is not None and self.strategy.is_rank_0():
                logs = {f"train/{k}": v for k, v in {**logs_dict, "global_step": global_step}.items()}
                if self.experience_maker.perf_stats is not None:
                    logs.update({f"perf/experience_maker/{k}": v for k, v in self.experience_maker.perf_stats.items()})
                self._wandb.log(logs)
            # TensorBoard
            elif self._tensorboard is not None and self.strategy.is_rank_0():
                for k, v in logs_dict.items():
                    self._tensorboard.add_scalar(f"train/{k}", v, global_step)
                if self.experience_maker.perf_stats is not None:
                    for k, v in self.experience_maker.perf_stats.items():
                        self._tensorboard.add_scalar(f"perf/experience_maker/{k}", v, global_step)

        # TODO: Add evaluation mechanism for PPO
        if global_step % args.eval_steps == 0:
            # self.evaluate(self.eval_dataloader, global_step)
            pass
        # save ckpt
        # TODO: save best model on dev, use loss/perplexity/others on whole dev dataset as metric
        if global_step % args.save_steps == 0:
            tag = f"global_step{global_step}"
            self._save_checkpoint(args, tag, client_states)

    def _save_checkpoint(self, args, tag, client_states):
        """Save actor (with client_states) and, if present, critic checkpoints under args.ckpt_path."""
        self.strategy.save_ckpt(
            self.actor.model,
            os.path.join(args.ckpt_path, "_actor"),
            tag,
            args.max_ckpt_num,
            args.max_ckpt_mem,
            client_states,
        )
        if self.critic is not None:
            self.strategy.save_ckpt(
                self.critic, os.path.join(args.ckpt_path, "_critic"), tag, args.max_ckpt_num, args.max_ckpt_mem
            )