diff --git a/src/maxtext/layers/decoders.py b/src/maxtext/layers/decoders.py index 95d181e729..279220c3d2 100644 --- a/src/maxtext/layers/decoders.py +++ b/src/maxtext/layers/decoders.py @@ -319,11 +319,21 @@ def get_remat_policy(self): policy = self.minimal_policy() elif cfg.remat_policy == "minimal_with_quantization": if cfg.scan_layers: - warnings.warn('Scan layers can introduce overhead to checkpointed values that in some configurations is slower than not checkpointing at all. If you are using scan layers, benchmark with and without quantization checkpointing in your workflow to see which is faster. Without scan layers, checkpointing quantizations is beneficial for performance.') + warnings.warn( + "Scan layers can introduce overhead to checkpointed values that in some configurations is slower" + "than not checkpointing at all. If you are using scan layers, benchmark with and without quantization " + "checkpointing in your workflow to see which is faster. Without scan layers, checkpointing quantizations is " + "beneficial for performance." + ) policy = self.minimal_policy(with_context=False, with_quantization=True) elif cfg.remat_policy == "minimal_with_context_and_quantization": if cfg.scan_layers: - warnings.warn('Scan layers can introduce overhead to checkpointed values that in some configurations is slower than not checkpointing at all. If you are using scan layers, benchmark with and without quantization checkpointing in your workflow to see which is faster. Without scan layers, checkpointing quantizations is beneficial for performance.') + warnings.warn( + "Scan layers can introduce overhead to checkpointed values that in some configurations is slower" + "than not checkpointing at all. If you are using scan layers, benchmark with and without quantization " + "checkpointing in your workflow to see which is faster. Without scan layers, checkpointing quantizations is " + "beneficial for performance." + ) policy = self.minimal_policy(with_context=True, with_quantization=True) elif cfg.remat_policy == "save_dot_with_context_except_mlp": policy = jax.checkpoint_policies.save_only_these_names( diff --git a/tests/integration/grpo_correctness.py b/tests/integration/grpo_correctness.py index 6bb3fb3897..92fb18a580 100644 --- a/tests/integration/grpo_correctness.py +++ b/tests/integration/grpo_correctness.py @@ -60,12 +60,8 @@ def setUp(self): self.rng = jax.random.PRNGKey(42) devices_array = maxtext_utils.create_device_mesh(self.cfg) mesh = Mesh(devices_array, self.cfg.mesh_axes) - self.model = models.transformer_as_linen( - config=self.cfg, mesh=mesh, quant=None, model_mode=MODEL_MODE_TRAIN - ) - self.state, _ = maxtext_utils.setup_decode_state( - self.model, self.cfg, self.rng, mesh, None - ) + self.model = models.transformer_as_linen(config=self.cfg, mesh=mesh, quant=None, model_mode=MODEL_MODE_TRAIN) + self.state, _ = maxtext_utils.setup_decode_state(self.model, self.cfg, self.rng, mesh, None) self.tokenizer_model = transformers.AutoTokenizer.from_pretrained( "meta-llama/Llama-3.1-8B", add_bos_token=False, @@ -104,16 +100,12 @@ def _prepare_maxtext_inputs(self): """prepare maxtext inputs""" prompt = self.tokenizer_model.encode(self.input_str) input_ids = jnp.pad( - jnp.tile( - jnp.concat([jnp.array(prompt), jnp.array(prompt)], axis=-1), (4, 1) - ), + jnp.tile(jnp.concat([jnp.array(prompt), jnp.array(prompt)], axis=-1), (4, 1)), ((0, 0), (0, 4)), constant_values=0, ) # pad some tokens at the end of input prompt input_segmentation = (input_ids > 0).astype(jnp.int32) - input_position = jnp.where( - input_segmentation, jnp.arange(input_segmentation.shape[1]), 0 - ) + input_position = jnp.where(input_segmentation, jnp.arange(input_segmentation.shape[1]), 0) completion_segmentation = jnp.tile( jnp.pad( jnp.array([0] * len(prompt) + [1] * len(prompt)), @@ -129,12 +121,9 @@ def _prepare_maxtext_inputs(self): ) def _prepare_trl_inputs(self): - tokenized_inputs = self.tokenizer_model( - [self.input_str], return_tensors="pt" - ) - input_ids = torch.cat( - (tokenized_inputs["input_ids"], tokenized_inputs["input_ids"]), axis=-1 - ) + """Prepare TRL inputs.""" + tokenized_inputs = self.tokenizer_model([self.input_str], return_tensors="pt") + input_ids = torch.cat((tokenized_inputs["input_ids"], tokenized_inputs["input_ids"]), axis=-1) attention_mask = torch.cat( ( tokenized_inputs["attention_mask"], @@ -147,9 +136,7 @@ def _prepare_trl_inputs(self): def test_logits(self): def _prepare_inputs(): - input_ids = jnp.tile( - jnp.array(self.tokenizer_model.encode(self.input_str)), (4, 1) - ) + input_ids = jnp.tile(jnp.array(self.tokenizer_model.encode(self.input_str)), (4, 1)) input_segmentation = (input_ids > 0).astype(jnp.int32) input_position = jnp.tile(jnp.arange(input_ids.shape[1]), (4, 1)) @@ -175,17 +162,11 @@ def _prepare_inputs(): .numpy() ) print(f"Max Diff {np.max(np.abs(logits - hf_logits))}") - self.assertTrue( - jax.numpy.allclose( - hf_logits, logits, rtol=1e-2, atol=2e-1, equal_nan=False - ) - ) + self.assertTrue(jax.numpy.allclose(hf_logits, logits, rtol=1e-2, atol=2e-1, equal_nan=False)) def test_logps(self): - input_ids, input_segmentation, input_position, completion_segmentation = ( - self._prepare_maxtext_inputs() - ) + input_ids, input_segmentation, input_position, completion_segmentation = self._prepare_maxtext_inputs() maxtext_per_token_logps, _ = compute_log_probs( self.model, self.state.params, @@ -202,12 +183,7 @@ def test_logps(self): print( "Max Diff", - np.max( - np.abs( - np.trim_zeros(np.asarray(maxtext_per_token_logps)[0]) - - hf_per_token_logps.detach().numpy()[0] - ) - ), + np.max(np.abs(np.trim_zeros(np.asarray(maxtext_per_token_logps)[0]) - hf_per_token_logps.detach().numpy()[0])), ) self.assertTrue( jax.numpy.allclose( @@ -228,27 +204,16 @@ def test_loss_kl_div(self): completions = [{"prompt": self.input_str}] * 4 rewards = torch.tensor( - [ - self.trainer.reward_funcs[0](completion) - for completion in completions - ], + [self.trainer.reward_funcs[0](completion) for completion in completions], dtype=torch.float32, ) # Compute grouped-wise rewards - mean_grouped_rewards = rewards.view(-1, self.trainer.num_generations).mean( - dim=1 - ) - std_grouped_rewards = rewards.view(-1, self.trainer.num_generations).std( - dim=1 - ) + mean_grouped_rewards = rewards.view(-1, self.trainer.num_generations).mean(dim=1) + std_grouped_rewards = rewards.view(-1, self.trainer.num_generations).std(dim=1) # Normalize the rewards to compute the advantages - mean_grouped_rewards = mean_grouped_rewards.repeat_interleave( - self.trainer.num_generations, dim=0 - ) - std_grouped_rewards = std_grouped_rewards.repeat_interleave( - self.trainer.num_generations, dim=0 - ) + mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.trainer.num_generations, dim=0) + std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.trainer.num_generations, dim=0) # since we are using the same completion, so advantages = 0 for every sequence # but we can keep it this way since our on-policy implementation # gets average advantage which becomes zero anyway @@ -273,9 +238,7 @@ def test_loss_kl_div(self): self.trainer._get_per_token_logps(self.hf_model, hf_input_ids, attention_mask, logits_to_keep) # pylint: disable=protected-access - input_ids, input_segmentation, input_position, completion_segmentation = ( - self._prepare_maxtext_inputs() - ) + input_ids, input_segmentation, input_position, completion_segmentation = self._prepare_maxtext_inputs() maxtext_per_token_logps, _ = compute_log_probs( self.model, self.state.params,