src/maxtext/layers/decoders.py (14 changes: 12 additions & 2 deletions)
@@ -319,11 +319,21 @@ def get_remat_policy(self):
       policy = self.minimal_policy()
     elif cfg.remat_policy == "minimal_with_quantization":
       if cfg.scan_layers:
-        warnings.warn('Scan layers can introduce overhead to checkpointed values that in some configurations is slower than not checkpointing at all. If you are using scan layers, benchmark with and without quantization checkpointing in your workflow to see which is faster. Without scan layers, checkpointing quantizations is beneficial for performance.')
+        warnings.warn(
+            "Scan layers can introduce overhead to checkpointed values that in some configurations is slower "
+            "than not checkpointing at all. If you are using scan layers, benchmark with and without quantization "
+            "checkpointing in your workflow to see which is faster. Without scan layers, checkpointing quantizations is "
+            "beneficial for performance."
+        )
       policy = self.minimal_policy(with_context=False, with_quantization=True)
     elif cfg.remat_policy == "minimal_with_context_and_quantization":
       if cfg.scan_layers:
-        warnings.warn('Scan layers can introduce overhead to checkpointed values that in some configurations is slower than not checkpointing at all. If you are using scan layers, benchmark with and without quantization checkpointing in your workflow to see which is faster. Without scan layers, checkpointing quantizations is beneficial for performance.')
+        warnings.warn(
+            "Scan layers can introduce overhead to checkpointed values that in some configurations is slower "
+            "than not checkpointing at all. If you are using scan layers, benchmark with and without quantization "
+            "checkpointing in your workflow to see which is faster. Without scan layers, checkpointing quantizations is "
+            "beneficial for performance."
+        )
       policy = self.minimal_policy(with_context=True, with_quantization=True)
     elif cfg.remat_policy == "save_dot_with_context_except_mlp":
       policy = jax.checkpoint_policies.save_only_these_names(
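For readers unfamiliar with named remat policies, a minimal, self-contained sketch of the JAX mechanism this code relies on; the layer and the "qkv_proj" tensor name are illustrative, not MaxText's actual wiring. Values tagged with checkpoint_name and listed in save_only_these_names are saved on the forward pass, while everything else is rematerialized during the backward pass.

# Minimal sketch (illustrative names, not MaxText's configuration).
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def layer(x, w):
  # Tag the projection so the policy below saves it instead of recomputing it.
  h = checkpoint_name(jnp.dot(x, w), "qkv_proj")
  return jnp.tanh(h)

# Save only tensors named "qkv_proj"; rematerialize everything else on backward.
policy = jax.checkpoint_policies.save_only_these_names("qkv_proj")
layer_remat = jax.checkpoint(layer, policy=policy)

x, w = jnp.ones((4, 8)), jnp.ones((8, 8))
grads = jax.grad(lambda w: layer_remat(x, w).sum())(w)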
tests/integration/grpo_correctness.py (71 changes: 17 additions & 54 deletions)
@@ -60,12 +60,8 @@ def setUp(self):
     self.rng = jax.random.PRNGKey(42)
     devices_array = maxtext_utils.create_device_mesh(self.cfg)
     mesh = Mesh(devices_array, self.cfg.mesh_axes)
-    self.model = models.transformer_as_linen(
-        config=self.cfg, mesh=mesh, quant=None, model_mode=MODEL_MODE_TRAIN
-    )
-    self.state, _ = maxtext_utils.setup_decode_state(
-        self.model, self.cfg, self.rng, mesh, None
-    )
+    self.model = models.transformer_as_linen(config=self.cfg, mesh=mesh, quant=None, model_mode=MODEL_MODE_TRAIN)
+    self.state, _ = maxtext_utils.setup_decode_state(self.model, self.cfg, self.rng, mesh, None)
     self.tokenizer_model = transformers.AutoTokenizer.from_pretrained(
         "meta-llama/Llama-3.1-8B",
         add_bos_token=False,
@@ -104,16 +100,12 @@ def _prepare_maxtext_inputs(self):
     """prepare maxtext inputs"""
     prompt = self.tokenizer_model.encode(self.input_str)
     input_ids = jnp.pad(
-        jnp.tile(
-            jnp.concat([jnp.array(prompt), jnp.array(prompt)], axis=-1), (4, 1)
-        ),
+        jnp.tile(jnp.concat([jnp.array(prompt), jnp.array(prompt)], axis=-1), (4, 1)),
         ((0, 0), (0, 4)),
         constant_values=0,
     )  # pad some tokens at the end of input prompt
     input_segmentation = (input_ids > 0).astype(jnp.int32)
-    input_position = jnp.where(
-        input_segmentation, jnp.arange(input_segmentation.shape[1]), 0
-    )
+    input_position = jnp.where(input_segmentation, jnp.arange(input_segmentation.shape[1]), 0)
     completion_segmentation = jnp.tile(
         jnp.pad(
             jnp.array([0] * len(prompt) + [1] * len(prompt)),
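The hand-built inputs in this hunk follow a simple packing convention. A tiny sketch with toy token ids (illustrative values, not the test's actual prompt) of how the segmentation and position arrays fall out: id 0 is padding, segmentation marks real tokens, and positions reset to 0 on pads.

import jax.numpy as jnp

input_ids = jnp.array([[5, 7, 9, 0, 0]])                # one padded sequence
input_segmentation = (input_ids > 0).astype(jnp.int32)  # [[1, 1, 1, 0, 0]]
input_position = jnp.where(input_segmentation, jnp.arange(input_ids.shape[1]), 0)
# input_position -> [[0, 1, 2, 0, 0]]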
@@ -129,12 +121,9 @@ def _prepare_maxtext_inputs(self):
         )

   def _prepare_trl_inputs(self):
-    tokenized_inputs = self.tokenizer_model(
-        [self.input_str], return_tensors="pt"
-    )
-    input_ids = torch.cat(
-        (tokenized_inputs["input_ids"], tokenized_inputs["input_ids"]), axis=-1
-    )
+    """Prepare TRL inputs."""
+    tokenized_inputs = self.tokenizer_model([self.input_str], return_tensors="pt")
+    input_ids = torch.cat((tokenized_inputs["input_ids"], tokenized_inputs["input_ids"]), axis=-1)
     attention_mask = torch.cat(
         (
             tokenized_inputs["attention_mask"],
@@ -147,9 +136,7 @@ def _prepare_trl_inputs(self):

   def test_logits(self):
     def _prepare_inputs():
-      input_ids = jnp.tile(
-          jnp.array(self.tokenizer_model.encode(self.input_str)), (4, 1)
-      )
+      input_ids = jnp.tile(jnp.array(self.tokenizer_model.encode(self.input_str)), (4, 1))
       input_segmentation = (input_ids > 0).astype(jnp.int32)
       input_position = jnp.tile(jnp.arange(input_ids.shape[1]), (4, 1))

@@ -175,17 +162,11 @@ def _prepare_inputs():
         .numpy()
     )
     print(f"Max Diff {np.max(np.abs(logits - hf_logits))}")
-    self.assertTrue(
-        jax.numpy.allclose(
-            hf_logits, logits, rtol=1e-2, atol=2e-1, equal_nan=False
-        )
-    )
+    self.assertTrue(jax.numpy.allclose(hf_logits, logits, rtol=1e-2, atol=2e-1, equal_nan=False))

   def test_logps(self):

-    input_ids, input_segmentation, input_position, completion_segmentation = (
-        self._prepare_maxtext_inputs()
-    )
+    input_ids, input_segmentation, input_position, completion_segmentation = self._prepare_maxtext_inputs()
     maxtext_per_token_logps, _ = compute_log_probs(
         self.model,
         self.state.params,
@@ -202,12 +183,7 @@ def test_logps(self):

     print(
         "Max Diff",
-        np.max(
-            np.abs(
-                np.trim_zeros(np.asarray(maxtext_per_token_logps)[0])
-                - hf_per_token_logps.detach().numpy()[0]
-            )
-        ),
+        np.max(np.abs(np.trim_zeros(np.asarray(maxtext_per_token_logps)[0]) - hf_per_token_logps.detach().numpy()[0])),
     )
     self.assertTrue(
         jax.numpy.allclose(
@@ -228,27 +204,16 @@ def test_loss_kl_div(self):

     completions = [{"prompt": self.input_str}] * 4
     rewards = torch.tensor(
-        [
-            self.trainer.reward_funcs[0](completion)
-            for completion in completions
-        ],
+        [self.trainer.reward_funcs[0](completion) for completion in completions],
         dtype=torch.float32,
     )
     # Compute group-wise rewards
-    mean_grouped_rewards = rewards.view(-1, self.trainer.num_generations).mean(
-        dim=1
-    )
-    std_grouped_rewards = rewards.view(-1, self.trainer.num_generations).std(
-        dim=1
-    )
+    mean_grouped_rewards = rewards.view(-1, self.trainer.num_generations).mean(dim=1)
+    std_grouped_rewards = rewards.view(-1, self.trainer.num_generations).std(dim=1)

     # Normalize the rewards to compute the advantages
-    mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(
-        self.trainer.num_generations, dim=0
-    )
-    std_grouped_rewards = std_grouped_rewards.repeat_interleave(
-        self.trainer.num_generations, dim=0
-    )
+    mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.trainer.num_generations, dim=0)
+    std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.trainer.num_generations, dim=0)
     # since we are using the same completion, advantages = 0 for every sequence,
     # but we can keep it this way since our on-policy implementation
     # gets the average advantage, which becomes zero anyway
@@ -273,9 +238,7 @@ def test_loss_kl_div(self):

     self.trainer._get_per_token_logps(self.hf_model, hf_input_ids, attention_mask, logits_to_keep)  # pylint: disable=protected-access

-    input_ids, input_segmentation, input_position, completion_segmentation = (
-        self._prepare_maxtext_inputs()
-    )
+    input_ids, input_segmentation, input_position, completion_segmentation = self._prepare_maxtext_inputs()
     maxtext_per_token_logps, _ = compute_log_probs(
         self.model,
         self.state.params,
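For reference, a minimal sketch of the per-token log-probability computation that both compute_log_probs and TRL's _get_per_token_logps implement in their own ways; the helper name and shapes here are illustrative, not either library's API. The log-prob of token t is read from the distribution the model predicts at position t-1.

import jax
import jax.numpy as jnp

def per_token_logps(logits, input_ids):  # hypothetical helper, not MaxText's API
  # logits: (batch, seq, vocab); input_ids: (batch, seq)
  logps = jax.nn.log_softmax(logits[:, :-1, :], axis=-1)  # predictions for tokens 1..T-1
  targets = input_ids[:, 1:, None]                        # the tokens actually observed
  return jnp.take_along_axis(logps, targets, axis=-1)[..., 0]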
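And a sketch of the group-wise reward normalization that test_loss_kl_div reproduces from TRL (standard GRPO advantages; the reward values and the 1e-4 epsilon are illustrative):

import torch

num_generations = 4
rewards = torch.tensor([1.0, 0.5, 0.5, 1.0])  # one scalar reward per sampled completion

# Group rewards by prompt, then normalize within each group.
grouped = rewards.view(-1, num_generations)   # (num_prompts, num_generations)
mean_g = grouped.mean(dim=1).repeat_interleave(num_generations, dim=0)
std_g = grouped.std(dim=1).repeat_interleave(num_generations, dim=0)
advantages = (rewards - mean_g) / (std_g + 1e-4)

Identical completions within a group yield advantages of exactly zero, which is why the test can reuse one completion four times and still exercise the KL term in isolation.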