Resolve pylint c-class: C0301,C3001,C0114,C0115,C0116,C0200,C0121,C0201,C0206,C0209,C0412,C0415,C2801#1668
shralex
left a comment
Thanks Samuel!
Two general comments.
As much as possible, when fixing these, let's adhere to the Google Python style guide (https://g3doc.corp.google.com/eng/doc/devguide/py/style/index.md?cl=head), since we want to improve the quality of the code and will eventually use Google's lint, if we're not already. In particular, the first sentence of each docstring (just the ones you're adding or changing) should be capitalized (except when it starts with a variable name) and should end with a period. Thank you!!
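For illustration, a minimal sketch of a docstring that follows this convention. The function `scale` and its parameters are hypothetical and not taken from the codebase; only the docstring layout is the point.

```python
def scale(values, factor):
  """Multiplies every element of `values` by `factor`.

  Args:
    values: A list of numbers.
    factor: The number each element is multiplied by.

  Returns:
    A new list containing the scaled values.
  """
  # Hypothetical helper: it exists only to carry the docstring above.
  return [v * factor for v in values]
```

The summary line starts with a capital letter and ends with a period; a summary that opens with a variable name would keep that name's original casing.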
Could we remove the comment before this function and replace the docstring with the following, courtesy of Gemini :)
"""Generates a combined attention mask for Transformer models.

This function constructs an attention mask by potentially combining
several types of masks based on the input parameters and model
configuration. The generated mask dictates which query-key pairs are
allowed to attend to each other.

The masking logic can enforce:

1. **Sequence Separation:** Using `decoder_segment_ids`, attention is
   confined within distinct sequences in a batch. This is crucial when
   multiple unrelated sequences are packed together.
2. **Causality:** Preventing attention to future positions. This is
   standard for autoregressive decoding. For chunked prefill, as
   described in the SARATHI paper [2], causality is adjusted based
   on `previous_chunk` information.
3. **Specialized Attention Patterns:** Depending on `self.attention_type`,
   it can apply:
   * Local Sliding Window Attention: Restricts attention to a
     fixed-size window around each query position.
   * Chunk Attention: Divides sequences into chunks and applies
     masking at the chunk level.
4. **Bidirectional Attention for Sub-sequences:** If `bidirectional_mask`
   is provided (e.g., for image tokens in a multimodal model),
   those parts of the sequence can attend bidirectionally, and this
   mask is OR-ed with other generated masks.

The overall approach and specific masking techniques are influenced by
efficient attention mechanisms like those found in the Pallas MHA
Flash Attention reference [1].

Args:
  query: The query tensor, typically of shape
    `[batch_size, q_sequence_length, num_heads, head_dim]`.
    Used primarily for deriving sequence length.
  key: The key tensor, typically of shape
    `[batch_size, kv_sequence_length, num_heads, head_dim]`.
    Used primarily for deriving sequence length.
  decoder_segment_ids: Optional `Array` of shape `[batch_size, q_sequence_length]`.
    Identifies distinct sequences within the batch. Attention is
    restricted to elements within the same segment ID. In autoregressive
    mode, specific values (e.g., `common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR`)
    can mark the currently active sequence for decoding.
  model_mode: A string (e.g., `common_types.MODEL_MODE_AUTOREGRESSIVE`,
    `common_types.MODEL_MODE_PREFILL`) indicating the operational
    mode. This significantly influences mask generation, particularly
    how causality and segment separation are handled.
  previous_chunk: Optional. Information about previously processed
    key/value chunks, often a tensor representing the previous keys/values.
    Used to correctly offset causal masks in chunked attention or
    streaming scenarios. Its shape might be
    `[batch_size, prev_kv_sequence_length, ...]`.
  bidirectional_mask: Optional `Array` of shape `[batch_size, kv_sequence_length]`.
    If provided, this boolean mask indicates tokens (e.g., image tokens)
    that are allowed to attend bidirectionally. The resulting
    block-wise bidirectional mask is combined with other masks using a
    logical OR.

Returns:
  An `Array` representing the attention mask, broadcastable to the shape
  `[batch_size, num_heads, q_sequence_length, kv_sequence_length]`.
  Positions with `0.0` allow attention, while positions with
  `DEFAULT_MASK_VALUE` (a large negative number) prevent it.
  Returns `None` if no masking is determined to be necessary based on
  the inputs and configuration.

References:
  [1] JAX Pallas MHA Flash Attention:
      https://github.com/jax-ml/jax/blob/main/jax/experimental/pallas/ops/tpu/flash_attention.py
  [2] SARATHI: Efficient LLM Inference by Piggybacking Decodes with
      Chunked Prefills - ArXiv:2308.16369 (https://arxiv.org/abs/2308.16369)
"""
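To make the combination logic in that docstring concrete, here is a rough sketch for a single unbatched sequence. It is not the function under review: it uses plain Python lists instead of JAX arrays, the names `combined_mask`, `to_additive`, and `MASK_VALUE` are hypothetical, `MASK_VALUE` is only a stand-in for `DEFAULT_MASK_VALUE`, and the bidirectional rule is simplified to "any two flagged tokens may attend to each other" rather than the block-wise version. Chunked prefill and chunk attention are left out.

```python
# Stand-in for DEFAULT_MASK_VALUE; the real constant may differ (assumption).
MASK_VALUE = -1e30


def combined_mask(segment_ids, window=None, bidirectional=None):
  """Builds a boolean [q, kv] mask for one sequence; True means "may attend"."""
  n = len(segment_ids)
  mask = []
  for q in range(n):
    row = []
    for k in range(n):
      # Causality: a query never attends to a future position.
      allowed = k <= q
      # Sequence separation: only positions in the same packed segment.
      allowed = allowed and segment_ids[q] == segment_ids[k]
      if window is not None:
        # Local sliding window: only the last `window` positions.
        allowed = allowed and q - k < window
      if bidirectional is not None:
        # Simplified bidirectional rule, OR-ed with the masks above.
        allowed = allowed or (bidirectional[q] and bidirectional[k])
      row.append(allowed)
    mask.append(row)
  return mask


def to_additive(mask, mask_value=MASK_VALUE):
  """Converts a boolean mask to 0.0 (attend) or `mask_value` (blocked)."""
  return [[0.0 if allowed else mask_value for allowed in row] for row in mask]
```

For example, `combined_mask([1, 1, 2, 2])` keeps two packed sequences from attending to each other while staying causal inside each one, and `to_additive` turns the result into the additive form the `Returns` section describes.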
@shralex Sure thing. In fact, the plan was to run Gemini across the whole codebase to improve all the docstrings once they are in place.
…121,C0201,C0206,C0209,C0412,C0415,C2801 ; [code_style.sh,.github/workflows/CPUTests.yml] Enable c-class
8266254
into
AI-Hypercomputer:main
Description
Linting is fun?
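As a rough illustration of the kind of change several of the checks in the title ask for, here is a small sketch. The functions are hypothetical and not taken from this PR's diff; each comment names the pylint message the line is written to satisfy.

```python
"""Illustrative fixes for a few pylint convention (C) messages."""
# The module docstring above addresses C0114 (missing-module-docstring).


def describe(scores):
  """Formats each non-empty entry of a name-to-score mapping."""
  # The function docstrings address C0116 (missing-function-docstring).
  lines = []
  # C0206 (consider-using-dict-items): iterate over items(), not keys + index.
  for name, score in scores.items():
    # C0121 (singleton-comparison): `is None` instead of `== None`.
    if score is None:
      continue
    # C0209 (consider-using-f-string): f-string instead of `%` or .format().
    lines.append(f"{name}: {score}")
  return lines


def numbered(items):
  """Prefixes each item with its position."""
  # C0200 (consider-using-enumerate): enumerate() instead of range(len(items)).
  return [f"{i}: {item}" for i, item in enumerate(items)]


def has_key(mapping, key):
  """Reports whether `key` is present in `mapping`."""
  # C0201 (consider-iterating-dictionary): `in mapping`, not `in mapping.keys()`.
  return key in mapping


# C3001 (unnecessary-lambda-assignment): a `def` instead of `size = lambda ...`.
def size(items):
  """Returns the number of items."""
  # C2801 (unnecessary-dunder-call): len(items) instead of items.__len__().
  return len(items)
```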
Tests
N/A
Checklist
Before submitting this PR, please make sure (put X in square brackets):