In [1]:
from torch.nn.utils.rnn import pack_padded_sequence
import torch

In [2]:
def pack_input_ids(input_ids, max_length=32):
    """
    Pack input_ids into batches of size max_length.
    """
    packed_inputs = []
    attention_masks = []
    current_batch = []
    current_attention_mask = []
    current_length = 0
    current_attention_mask_no = 1

    for i, ids in enumerate(input_ids):
        ids_length = len(ids)

        if current_length + ids_length > max_length:
            # Calculate how many ids can be added to the current batch
            ids_to_add = max_length - current_length
            current_batch += ids[:ids_to_add]
            current_attention_mask += [current_attention_mask_no] * ids_length
            packed_inputs.append(torch.tensor(current_batch))
            attention_masks.append(torch.tensor(current_attention_mask))
            
            # Start a new batch
            current_batch = []
            current_attention_mask = []
            current_length = 0
            current_attention_mask_no = 1
        else:
            # Add to the current batch
            current_batch += ids
            current_length += ids_length
            current_attention_mask += [current_attention_mask_no] * ids_length 
            current_attention_mask_no += 1

    # Add the last batch
    if current_batch:
        packed_inputs.append(torch.tensor(current_batch))
        attention_masks.append(torch.tensor(current_attention_mask))

    return packed_inputs, attention_masks


In [3]:
prompts = [
    [101, 1045, 2293, 2026, 2171, 102],  # Example prompt 1 token ids
    [101, 1045, 2572, 2293, 102],         # Example prompt 2 token ids
    [101, 1045, 2293, 2026, 2171, 100, 102]  # Example prompt 3 token ids
]



In [4]:
packed_prompts, attention_mask = pack_input_ids(prompts, max_length=10)

In [5]:
packed_prompts

[tensor([ 101, 1045, 2293, 2026, 2171,  102,  101, 1045, 2572, 2293]),
 tensor([ 101, 1045, 2293, 2026, 2171,  100,  102])]

In [6]:
attention_mask

[tensor([1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2]), tensor([1, 1, 1, 1, 1, 1, 1])]

In [7]:
from transformers import AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')
tokenizer.pad_token_id=tokenizer.eos_token_id

In [9]:
packed_prompts = torch.nn.utils.rnn.pad_sequence(packed_prompts, batch_first=True, padding_value=tokenizer.pad_token_id)

In [10]:
attention_mask = torch.nn.utils.rnn.pad_sequence(attention_mask, batch_first=True, padding_value=0)

In [11]:
packed_prompts

tensor([[ 101, 1045, 2293, 2026, 2171,  102,  101, 1045, 2572, 2293],
        [ 101, 1045, 2293, 2026, 2171,  100,  102,    2,    2,    2]])

In [12]:
attention_mask

tensor([[1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0]])

In [13]:
tokenizer

LlamaTokenizerFast(name_or_path='meta-llama/Llama-2-7b-hf', vocab_size=32000, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '</s>'}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
	0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

In [1]:
import torch

In [2]:
a = torch.ones([1, 7])
b = torch.ones([1, 7])
c = torch.ones([1, 84])

In [3]:
d = torch.cat([a, b, c], dim=-1)

In [4]:
d

tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1.]])

In [2]:
import torch

In [3]:
att_mask = torch.ones([2, 7])
att_mask[0][2] = 0

In [4]:
att_mask

tensor([[1., 1., 0., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1.]])

In [9]:
unique_mask_values = torch.unique(att_mask)

In [10]:
unique_mask_values

tensor([0., 1.])

In [15]:
import torch

def build_batched_causal_mask_from_attention_mask(attention_mask: torch.Tensor, dtype: torch.dtype = torch.float32):
    """
    Creates a batched 4D causal mask from attention_mask where only elements with the same values in the attention_mask
    can attend to each other, and elements with value 0 in the attention_mask are ignored.

    Args:
        attention_mask (torch.Tensor): A 2D tensor of shape (batch_size, sequence_length) with different values representing groups.
        dtype (torch.dtype): Data type for the mask (default: torch.float32).

    Returns:
        causal_mask (torch.Tensor): A batched 4D causal mask of shape (batch_size, 1, sequence_length, sequence_length)
                                    where elements with the same values in attention_mask can attend to each other,
                                    and elements with value 0 are ignored.
    """
    batch_size, seq_len = attention_mask.size()

    # Create a base causal mask (lower triangular matrix), ensuring tokens can't attend to future tokens
    causal_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=dtype)).unsqueeze(0).expand(batch_size, -1, -1)

    # Create a mask where elements with the same attention_mask values can attend to each other (for each batch)
    same_value_mask = (attention_mask.unsqueeze(1) == attention_mask.unsqueeze(2)).to(dtype)

    # Create a mask to ignore elements with value 0 in the attention_mask (for each batch)
    non_zero_mask = (attention_mask != 0).unsqueeze(1).to(dtype)  # Mask where 1 indicates non-zero elements

    # Combine the causal mask with the same-value mask (element-wise multiplication)
    combined_mask = causal_mask * same_value_mask

    # Apply the non-zero mask to ignore tokens with attention_mask value 0
    combined_mask = combined_mask * non_zero_mask * non_zero_mask.transpose(1, 2)

    # Expand the mask to 4D with the second dimension as 1
    combined_mask = combined_mask.unsqueeze(1)

    return combined_mask

# Example usage:
batch_attention_mask = torch.tensor([
    [1, 1, 0, 2, 3],  # Batch 1
    [2, 2, 2, 0, 1],  # Batch 2
])  # Shape: [batch_size=2, seq_len=5]

causal_mask = build_batched_causal_mask_from_attention_mask(batch_attention_mask)

print(causal_mask.shape)  # Should be [batch_size=2, 1, seq_len=5, seq_len=5]
print(causal_mask)

torch.Size([2, 1, 5, 5])
tensor([[[[1., 0., 0., 0., 0.],
          [1., 1., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 1., 0.],
          [0., 0., 0., 0., 1.]]],


        [[[1., 0., 0., 0., 0.],
          [1., 1., 0., 0., 0.],
          [1., 1., 1., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 1.]]]])


In [14]:
causal_mask.shape

torch.Size([2, 5, 5])