-
Notifications
You must be signed in to change notification settings - Fork 467
/
accelerate_ppo_trainer.py
553 lines (470 loc) · 24.9 KB
/
accelerate_ppo_trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
import json
import os
import uuid
from time import time
from typing import Any, Callable, Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
import transformers
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
import trlx.utils.logging as logging
from trlx.data.accelerate_base_datatypes import PromptBatch
from trlx.data.configs import TRLConfig
from trlx.data.ppo_types import PPORLBatch, PPORLElement
from trlx.models.modeling_ppo import (
AdaptiveKLController,
AutoModelForCausalLMWithHydraValueHead,
AutoModelForSeq2SeqLMWithHydraValueHead,
FixedKLController,
)
from trlx.pipeline.offline_pipeline import PromptPipeline
from trlx.pipeline.ppo_pipeline import PPORolloutStorage
from trlx.trainer import register_trainer
from trlx.trainer.accelerate_base_trainer import AccelerateRLTrainer
from trlx.utils import Clock, infinite_dataloader
from trlx.utils.modeling import RunningMoments, gather_dict, logprobs_of_labels
# Module-level logger from trlx's logging wrapper; `make_experience` uses it to announce rollout collection.
logger = logging.get_logger(__name__)
@register_trainer
class AcceleratePPOTrainer(AccelerateRLTrainer):
    """PPO Accelerate Trainer

    Proximal Policy Optimization trainer built on `AccelerateRLTrainer`. Rollouts are
    collected by `make_experience` into a `PPORolloutStorage` (and re-collected after every
    epoch in `post_epoch_callback`); `loss` computes the PPO objective through
    `config.method.loss`.
    """

    # Reward function used by `make_experience`.
    # NOTE(review): the annotation lists three positional string lists, but
    # `make_experience` calls it with keyword arguments `samples`, `prompts`, `outputs`,
    # `tokenizer` plus any extra prompt-batch metadata, and converts each returned entry
    # to a 1-D float tensor (a scalar score or a per-token list).
    reward_fn: Callable[[List[str], List[str], List[str]], List[float]]
    # Tokenizer supplying pad/eos ids, decoding, and re-tokenization of sampled outputs.
    tokenizer: AutoTokenizer
def __init__(self, config: TRLConfig, **kwargs):
"""PPO Accelerate Trainer initialization
Args:
config: `TRLConfig`
kwargs: Additional keyword arguments passed to `AccelerateRLTrainer`
"""
super().__init__(config, **kwargs)
# Setup rollout logging
if config.train.rollout_logging_dir is not None:
self.log_rollouts = True
self.setup_rollout_logging(config)
else:
self.log_rollouts = False
# Setup the rollout store
# Rollouts contain the prompt & response, log probs, values and rewards - from each rollout
self.store = PPORolloutStorage(self.tokenizer.pad_token_id, self.tokenizer.padding_side)
# Create the rollout store dataloader (for batching up rollouts)
# TODO (jon-tow): This is only used to satisfy to `accelerator.prepare` call constraint below - remove in future
rollout_loader: DataLoader = self.store.create_loader(self.config.train.batch_size, shuffle=True)
# Prepare multi-GPU acceleration
self.model, self.opt, self.scheduler, rollout_loader = self.accelerator.prepare(
self.model, self.opt, self.scheduler, rollout_loader
)
self.store.clear_history() # Clear the rollout store
# Set up a reference model when hydra heads are not used
if not hasattr(self.model, "frozen_head") and not self.model.peft_type:
self.ref_model = self.get_arch(self.config)
self.ref_model.to(self.accelerator.device)
self.ref_model.eval()
# Set up the KL controller
# This helps prevent large divergences in the controller (policy)
if config.method.target is not None:
self.kl_ctl = AdaptiveKLController(config.method.init_kl_coef, config.method.target, config.method.horizon)
else:
self.kl_ctl = FixedKLController(config.method.init_kl_coef)
# Create the parameters for the Hugging Face language model's generator
# method (that generates new tokens from a prompt).
# https://huggingface.co/docs/transformers/v4.25.1/en/main_classes/text_generation#transformers.GenerationMixin.generate
generate_kwargs = dict(
do_sample=True,
use_cache=True,
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.pad_token_id,
synced_gpus=os.environ.get("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3",
)
self.generate_kwargs = {**generate_kwargs, **config.method.gen_kwargs}
if config.method.gen_experience_kwargs is not None:
self.generate_experience_kwargs = {**generate_kwargs, **config.method.gen_experience_kwargs}
else:
self.generate_experience_kwargs = None
# Setup stats tracker
self.running_moments = RunningMoments()
self.ref_mean = self.config.method.ref_mean
self.ref_std = self.config.method.ref_std
def get_arch(self, config: TRLConfig):
"""Returns a specific wrapper given a model's architecture"""
model_class = AutoModelForCausalLMWithHydraValueHead
if config.model.model_arch_type == "seq2seq":
model_class = AutoModelForSeq2SeqLMWithHydraValueHead
from_fn = model_class.from_pretrained
# backward-compat: Try to create a randomly initialized architecture from a config
if issubclass(type(config.model.model_path), transformers.PretrainedConfig):
from_fn = model_class.from_config
return from_fn(
config.model.model_path,
num_layers_unfrozen=config.model.num_layers_unfrozen,
num_value_layers_unfrozen=config.method.num_value_layers_unfrozen,
peft_config=self.config.model.peft_config,
**self.config.model.model_extra_configs,
)
def loss(self, batch: PPORLBatch) -> Tuple[float, Dict[str, Any]]:
"""Computes loss on a batch of data and returns statistics
Args:
batch: `PPORLBatch` Previous batch of episodes
Returns:
loss: `Float` Loss value
stats: `Dict[str, Any]` PPO Statistics values
"""
# Move `batch` data to `accelerator` device
query_tensors = batch.query_tensors.to(self.accelerator.device)
response_tensors = batch.response_tensors.to(self.accelerator.device)
old_logprobs = batch.logprobs.to(self.accelerator.device)
old_values = batch.values.to(self.accelerator.device)
old_rewards = batch.rewards.to(self.accelerator.device)
response_length = old_rewards.shape[1]
advantages, returns = self.config.method.get_advantages_and_returns(old_values, old_rewards, response_length)
if self.config.model.model_arch_type == "seq2seq":
input_ids = query_tensors
decoder_input_ids = response_tensors
attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long().to(self.accelerator.device)
decoder_attention_mask = (
decoder_input_ids.ne(self.tokenizer.pad_token_id).long().to(self.accelerator.device)
)
decoder_attention_mask[:, 0] = 1
# Forward pass
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
)
logits = outputs.logits
values_pred = outputs.value
logprobs = logprobs_of_labels(logits[:, :-1, :], decoder_input_ids[:, 1:])
mask = decoder_input_ids.ne(self.tokenizer.pad_token_id).long().to(self.accelerator.device)
start = 0
end = start + response_length
logprobs, values_pred, mask = (
logprobs[:, start:end],
values_pred[:, start:end],
mask[:, start + 1 : end + 1],
)
else:
tokens = torch.cat((query_tensors, response_tensors), dim=1)
attention_mask = tokens.not_equal(self.tokenizer.pad_token_id).long().to(tokens.device)
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
outputs = self.model(tokens, attention_mask, return_dict=True, position_ids=position_ids)
logits = outputs.logits
values_pred = outputs.value
values_pred = values_pred[:, :-1]
logprobs = logprobs_of_labels(logits[:, :-1, :], tokens[:, 1:])
start = query_tensors.shape[1] - 1
end = start + response_length
logprobs, values_pred, mask = (
logprobs[:, start:end],
values_pred[:, start:end],
attention_mask[:, start + 1 : end + 1],
)
loss, stats = self.config.method.loss(
logprobs=logprobs,
values=values_pred,
old_logprobs=old_logprobs,
old_values=old_values,
advantages=advantages,
returns=returns,
mask=mask,
)
return loss, stats
def setup_rollout_logging(self, config):
"""Make rollout logging directory to log rollouts to"""
exists = os.path.exists(config.train.rollout_logging_dir)
isdir = os.path.isdir(config.train.rollout_logging_dir)
assert exists and isdir
self.run_id = f"run-{uuid.uuid4()}"
self.rollout_logging_dir = os.path.join(config.train.rollout_logging_dir, self.run_id)
os.mkdir(self.rollout_logging_dir)
with open(os.path.join(self.rollout_logging_dir, "config.json"), "w") as f:
f.write(json.dumps(config.to_dict(), indent=2))
def post_epoch_callback(self):
"""Clears the rollout store and creates `num_rollouts` new samples"""
if self.log_rollouts:
self.store.export_history(location=self.rollout_logging_dir)
self.store.clear_history()
# Collect more rollouts for training
self.make_experience(self.config.method.num_rollouts, self.iter_count)
def post_backward_callback(self):
self.kl_ctl.update(self.mean_kl, n_steps=self.config.train.batch_size)
def create_train_dataloader(self):
return self.store.create_loader(self.config.train.batch_size, shuffle=True)
def prepare_learning(self):
eval_dataloader = self.eval_pipeline.create_loader(self.config.method.chunk_size)
self.eval_dataloader = self.accelerator.prepare_data_loader(eval_dataloader)
self.make_experience(self.config.method.num_rollouts)
self.train_dataloader = self.create_train_dataloader()
self.n_inner_epochs = self.config.method.ppo_epochs
self.total_steps = self.config.train.epochs * self.n_inner_epochs * len(self.train_dataloader)
self.total_steps = min(self.total_steps, self.config.train.total_steps)
def add_prompt_pipeline(self, pipeline: PromptPipeline):
"""Add a prompt pipeline dataloader to a trainer instance for the `make_experience` stage"""
prompt_dataloader = pipeline.create_loader(self.config.method.chunk_size, shuffle=True)
prompt_dataloader = self.accelerator.prepare_data_loader(prompt_dataloader)
self.prompt_iterator = infinite_dataloader(prompt_dataloader)
def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noqa:
"""
Takes `chunk_size` number of prompts from `prompt_iterator`, samples
from the model and then computes the KL against a reference model. Finally it
then appends PPOElements to trainer's `store`.
Args:
num_rollouts: Number of rollouts to generate
iter_count: Total number of updates for all batches & epochs
"""
logger.info("Collecting rollouts")
tbar = logging.tqdm(
total=num_rollouts,
disable=os.environ.get("RANK", 0) != "0",
desc=f"[rollout 0 / {num_rollouts}]",
# Lower progress bar by 1 if we're in WARNING mode or above to avoid hiding high priority progress
# bars (e.g. loss progress in trainers)
position=logging.get_verbosity() >= logging.WARNING,
# Leave progress bar if we're in INFO mode or lower to avoid spamming in suppressed verbosity levels
leave=logging.get_verbosity() < logging.WARNING,
)
clock = Clock()
ppo_rl_elements = []
accumulated_stats = []
while len(ppo_rl_elements) < num_rollouts:
stats = {}
# Get next batch in prompt dataset
batch: PromptBatch = next(self.prompt_iterator)
rollout_generate_time = time()
# Generate samples from the language model (similar to using HuggingFace `generate` method)
samples = self.generate(batch["input_ids"], batch["attention_mask"])
stats["time/rollout_generate"] = time() - rollout_generate_time
prompt_tensors = batch.input_ids
device = samples.device
prompt_sizes = torch.tensor([prompt_tensors.shape[1]] * len(prompt_tensors), device=device)
padded_samples = self.accelerator.pad_across_processes(
samples, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False
)
padded_prompts = self.accelerator.pad_across_processes(
prompt_tensors, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False
)
gathered_samples = self.accelerator.gather(padded_samples)
gathered_prompts = self.accelerator.gather(padded_prompts)
gathered_prompt_sizes = self.accelerator.gather(prompt_sizes)
metadata = gather_dict({k: v for k, v in batch.items() if k != "input_ids" and k != "attention_mask"})
if self.accelerator.is_main_process:
all_str_samples, all_str_prompts, all_str_outputs = self.decode(
gathered_prompts, gathered_samples, gathered_prompt_sizes, append_eos_token=True
)
rollout_score_time = time()
# reward_fn should return list of rewards at each token per sample
# NOTE: all_scores[0][i] is the reward due to token (action) i in prompt + response (b/c of how kl is computed)
all_scores = self.reward_fn(
samples=all_str_samples,
prompts=all_str_prompts,
outputs=all_str_outputs,
tokenizer=self.tokenizer,
**metadata,
)
all_scores = [
torch.tensor(score, dtype=torch.float, device=device).view(
-1,
)
for score in all_scores
]
# Pad 0 reward on the ends
all_scores = pad_sequence(all_scores, batch_first=True, padding_value=-np.inf)
max_len = torch.tensor(all_scores.shape[1], dtype=torch.long, device=device)
stats["time/rollout_score"] = time() - rollout_score_time
all_scores = list(all_scores.reshape(self.accelerator.num_processes, -1, max_len).unbind())
else:
all_scores = None
max_len = torch.tensor(0, dtype=torch.long, device=device)
if torch.distributed.is_initialized():
torch.distributed.broadcast(max_len, 0)
scores = torch.empty((len(samples), max_len), device=device)
torch.distributed.scatter(scores, all_scores)
else:
scores = all_scores[0].clone().detach()
scores_mask = scores != -np.inf
str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples, append_eos_token=True)
# Pad the sample outputs
outputs = self.tokenizer(str_outputs).input_ids
if self.config.model.model_arch_type == "seq2seq":
# add <pad> to the start of the output
for i in range(len(outputs)):
outputs[i] = [self.tokenizer.pad_token_id] + outputs[i]
outputs = list(map(torch.LongTensor, outputs))
maxsize = max(map(len, outputs))
outputs = [
F.pad(
output,
(0, maxsize - len(output)),
value=self.tokenizer.pad_token_id,
)
for output in outputs
]
sample_outputs = torch.vstack(outputs).to(device)
if self.config.method.cliprange_reward:
scores = torch.clip(scores, -self.config.method.cliprange_reward, self.config.method.cliprange_reward)
# store statistics of the initial rollout as reference
if self.ref_mean is None:
self.ref_mean, self.ref_std = (scores * scores_mask).sum(dim=1).mean(), (scores * scores_mask).sum(
dim=1
).std()
all_scores_mean, all_scores_std = self.running_moments.update(torch.sum(scores * scores_mask, dim=1))
stats["rollout_scores/mean"] = all_scores_mean.item()
stats["rollout_scores/std"] = all_scores_std.item()
stats["rollout_scores/running_mean"] = self.running_moments.mean.item()
stats["rollout_scores/running_std"] = self.running_moments.std.item()
if self.config.method.scale_reward == "running":
scores /= self.running_moments.std
elif self.config.method.scale_reward == "ref":
scores /= self.ref_std
# Precompute logprobs, values
if self.config.model.model_arch_type == "seq2seq":
attention_mask = batch.attention_mask.to(device)
prompt_tensors = batch.input_ids.to(device)
decoder_attention_mask = sample_outputs.not_equal(self.tokenizer.pad_token_id)
decoder_attention_mask[:, 0] = 1
with torch.no_grad():
outputs = self.model(
input_ids=prompt_tensors,
attention_mask=attention_mask,
decoder_input_ids=sample_outputs,
decoder_attention_mask=decoder_attention_mask,
)
logits = outputs.logits
values = outputs.value
if hasattr(self.model, "frozen_head") or self.model.peft_type:
ref_logits = self.model.forward_hydra(
input_ids=prompt_tensors,
attention_mask=attention_mask,
decoder_input_ids=sample_outputs,
decoder_attention_mask=decoder_attention_mask,
return_dict=True,
).logits
else:
ref_logits = self.ref_model(
input_ids=prompt_tensors,
attention_mask=attention_mask,
decoder_input_ids=sample_outputs,
decoder_attention_mask=decoder_attention_mask,
return_dict=True,
).logits
else:
all_tokens = torch.cat((prompt_tensors.to(device), sample_outputs), dim=1)
attention_mask = all_tokens.not_equal(self.tokenizer.pad_token_id).long().to(device)
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
with torch.no_grad():
logits, *_, values = self.model(
all_tokens, attention_mask=attention_mask, position_ids=position_ids
)
# TODO(dahoas): When hydra model works need to also support generation on hydra head
if hasattr(self.model, "frozen_head") or self.model.peft_type:
ref_logits = self.model.forward_hydra(
all_tokens,
attention_mask=attention_mask,
position_ids=position_ids,
return_dict=True,
).logits
else:
ref_logits = self.ref_model(
all_tokens,
attention_mask=attention_mask,
position_ids=position_ids,
return_dict=True,
).logits
ref_logits = ref_logits.to(device)
if self.config.model.model_arch_type == "seq2seq":
logprobs = logprobs_of_labels(logits[:, :-1, :], sample_outputs[:, 1:])
ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], sample_outputs[:, 1:])
else:
# NOTE: logprob[i] is (log)prob at which all_token[i+1] was sampled
logprobs = logprobs_of_labels(logits[:, :-1, :], all_tokens[:, 1:])
ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], all_tokens[:, 1:])
n_samples: int = samples.shape[0]
# Estimate the KL divergence between the model and reference model
if self.config.model.model_arch_type == "seq2seq":
attention_mask = sample_outputs != self.tokenizer.pad_token_id
start = 0
else:
start = prompt_tensors.shape[1] - 1
log_ratio = (logprobs - ref_logprobs) * attention_mask[:, :-1]
kl = log_ratio.exp() - 1 - log_ratio
mean_kl_per_token = kl.mean()
mean_kl = kl.sum(1).mean()
logprobs = logprobs.cpu()
ref_logprobs = ref_logprobs.cpu()
prompt_tensors = prompt_tensors.cpu()
sample_outputs = sample_outputs.cpu()
values = values.cpu()[:, :-1]
# Get the logprobs and values, for tokens that are not padding,
# from the end of the prompt up to the <eos> token, while also including the latter
# (these are taken from the student model and not the reference model)
ends = start + attention_mask[:, start:].sum(1) + 1
all_values = [values[ix, start : ends[ix]] for ix in range(n_samples)]
all_logprobs = [logprobs[ix, start : ends[ix]] for ix in range(n_samples)]
kl_penalty = self.kl_ctl.value * -log_ratio.cpu()
kl_penalty = [xs[start : ends[ix]] for ix, xs in enumerate(kl_penalty)]
rollout_count = 0
for sample_idx in range(n_samples):
rewards = kl_penalty[sample_idx]
# Then add in rewards
if scores.shape[1] == 1:
# NOTE: Final reward given at EOS token following HHH practice
rewards[-1] += scores[sample_idx][0].cpu()
else:
score = scores[sample_idx]
score_right_padding = torch.sum(scores_mask[sample_idx])
score = score[:score_right_padding].cpu()
p_score = torch.zeros_like(rewards)
p_score[: score.shape[0]] += score
rewards += p_score
ppo_rl_elements.append(
PPORLElement(
query_tensor=prompt_tensors[sample_idx],
response_tensor=sample_outputs[sample_idx],
logprobs=all_logprobs[sample_idx],
values=all_values[sample_idx],
rewards=rewards,
)
)
rollout_count += 1
if torch.distributed.is_initialized():
torch.distributed.all_reduce(mean_kl, torch.distributed.ReduceOp.AVG)
stats["time/rollout_time"] = clock.tick()
stats["policy/sqrt_kl"] = torch.sqrt(mean_kl).item()
stats["policy/kl_per_token"] = torch.sqrt(mean_kl_per_token).item()
accumulated_stats.append(stats)
tbar.set_description(f"[rollout {len(ppo_rl_elements)} / {num_rollouts}]")
tbar.update(min(rollout_count, num_rollouts))
tbar.close()
stats = {k: sum([xs[k] for xs in accumulated_stats]) / len(accumulated_stats) for k in stats}
stats["kl_ctl_value"] = self.kl_ctl.value
self.mean_kl = stats["policy/sqrt_kl"] ** 2
self.accelerator.log(stats, step=iter_count)
# Push samples and rewards to trainer's rollout storage
self.push_to_store(ppo_rl_elements)
def save_pretrained(self, directory: Optional[str] = None, **kwargs):
"""
Args:
directory (str, *optional*): The directory to save the trainer files to.
NOTE: If not specified, the model will be saved to a directory named `hf_model` in the
checkpoint directory as specified by the Trainer's config.
**kwargs: Additional keyword arguments passed to the underlying Hugging Face model's
`save_pretrained` method.
"""
if directory is None:
directory = os.path.join(self.config.train.checkpoint_dir, "hf_model")
self.accelerator.wait_for_everyone()
# Save only the base model, so that is could be loaded directly
# with Hugging Face's `from_pretrained` method
state_dict = self.accelerator.get_state_dict(self.model, unwrap=True)
self.accelerator.unwrap_model(self.model).save_pretrained(
directory,
save_function=self.accelerator.save,
is_main_process=self.accelerator.is_main_process,
state_dict=state_dict,
**kwargs,
)
if self.accelerator.is_main_process:
self.tokenizer.save_pretrained(directory)