In [8]:
# Filter batch function subroutines
import torch

# Toy batch of three sequences. Row 1 has no tokens left (tokens_remaining == 0);
# row 2's last position is masked out (attention_mask == 0).
batch = {
    "input_ids": torch.tensor([[101, 102, 103], [201, 202, 203], [301, 302, 0]]),
    "position_ids": torch.tensor([[1, 2, 3], [1, 2, 3], [1, 2, 0]]),
    "attention_mask": torch.tensor([[1, 1, 1], [1, 1, 1], [1, 1, 0]]),
    "responses": ["Hello", "World", "Test"],
    "tokens_remaining": [3, 0, 2],
    # A single cache layer stored as a [key, value] list. The toy tensors are 2-D,
    # with the first axis lining up with the batch rows (real caches presumably
    # carry extra head/sequence axes — TODO confirm for the target model).
    "past_key_values": (
        [
            torch.tensor([[1, 1, 1], [1, 1, 1], [1, 1, 0]]),
            torch.tensor([[2, 2, 2], [2, 2, 2], [2, 2, 0]]),
        ],
    ),
}

In [9]:
# 1. Identify Entries to Remove
def identify_removal_indeces(tokens_remaining):
    """Return the positions of entries that have no token budget left.

    Args:
        tokens_remaining: Iterable of per-sequence remaining-token counts.

    Returns:
        List of indices ``i`` (in ascending order) for which the i-th count
        is ``<= 0``.
    """
    exhausted = []
    for idx, budget in enumerate(tokens_remaining):
        if budget <= 0:
            exhausted.append(idx)
    return exhausted

print(batch["tokens_remaining"])
remove_indices = identify_removal_indeces(batch["tokens_remaining"])
print(remove_indices)

# Sanity checks: every selected row really has no token budget left, and no
# exhausted row was missed.
assert all(batch["tokens_remaining"][i] <= 0 for i in remove_indices)
assert len(remove_indices) == sum(t <= 0 for t in batch["tokens_remaining"])

[3, 0, 2]
[1]


In [10]:
#2. Create a Mask and Filter Tensors 
def create_mask_and_filter(batch, remove_indices):
    """Drop the rows listed in ``remove_indices`` from the batch tensors.

    Args:
        batch: Dict holding at least ``"input_ids"``, ``"position_ids"`` and
            ``"attention_mask"`` tensors. Their first dimension must match the
            row count of ``input_ids``, because one mask is applied to all three.
        remove_indices: Row indices to discard.

    Returns:
        ``(filtered_batch, mask)``:
        - ``filtered_batch`` is a new dict containing only those three keys,
          restricted to the kept rows.
        - ``mask`` is a 1-D bool tensor that is ``False`` at every removed row
          and ``True`` elsewhere.
    """
    num_rows = batch["input_ids"].size(0)
    keep = torch.ones(num_rows, dtype=torch.bool)
    keep[remove_indices] = False

    # Boolean indexing selects the kept rows of each tensor.
    tensor_keys = ("input_ids", "position_ids", "attention_mask")
    filtered_batch = {key: batch[key][keep] for key in tensor_keys}
    return filtered_batch, keep

print(batch["input_ids"])

filtered_batch, mask = create_mask_and_filter(batch, remove_indices)
print(mask)
print(filtered_batch["input_ids"])

# Sanity checks: every removed row is masked out, and each filtered tensor
# keeps exactly the rows the mask marks as True.
assert not mask[remove_indices].any()
assert all(t.size(0) == int(mask.sum()) for t in filtered_batch.values())

tensor([[101, 102, 103],
        [201, 202, 203],
        [301, 302,   0]])
tensor([ True, False,  True])
tensor([[101, 102, 103],
        [301, 302,   0]])


In [11]:
# 3. Filter Lists
def filter_lists(responses, tokens_remaining, remove_indices):
    """Drop the entries at ``remove_indices`` from two parallel per-sequence lists.

    Args:
        responses: Per-sequence values, aligned position-by-position with
            ``tokens_remaining``.
        tokens_remaining: Per-sequence remaining-token counts.
        remove_indices: Iterable of positions to discard. Lists, tensors and
            one-shot iterables such as generators are all accepted.

    Returns:
        ``(filtered_responses, filtered_tokens_remaining)`` as new lists that
        keep the original order of the surviving entries.
    """
    # Materialize the indices once, as plain ints in a set:
    # - membership tests become O(1) instead of an O(k) list scan;
    # - a generator is no longer exhausted by the first comprehension;
    # - int() turns 0-d tensors / numpy ints into values that hash by value.
    removed = {int(i) for i in remove_indices}
    filtered_responses = [r for i, r in enumerate(responses) if i not in removed]
    filtered_tokens_remaining = [t for i, t in enumerate(tokens_remaining) if i not in removed]
    return filtered_responses, filtered_tokens_remaining

filtered_responses, filtered_tokens_remaining = filter_lists(
    batch["responses"], batch["tokens_remaining"], remove_indices
)

print(filtered_responses)
print(filtered_tokens_remaining)

# Sanity checks: the two lists stay aligned with each other and with the kept
# rows, and no exhausted entry survived the filter.
assert (
    len(filtered_responses)
    == len(filtered_tokens_remaining)
    == len(batch["responses"]) - len(set(remove_indices))
)
assert all(t > 0 for t in filtered_tokens_remaining)

['Hello', 'Test']
[3, 2]


In [14]:
# 4. Filter Past Key Values
def filter_past_key_values(past_key_values, mask):
    """Keep only the selected rows of every cached key/value tensor.

    Args:
        past_key_values: Iterable of per-layer key/value pairs. Each element is
            unpacked as ``(key, value)``, so any two-item sequence works.
            Tensors are indexed as ``tensor[mask]``, i.e. along their first
            dimension, which presumably is the batch axis (TODO confirm
            against the model's cache layout).
        mask: Row selector applied to every tensor. In this notebook it is the
            1-D bool mask returned by ``create_mask_and_filter``.

    Returns:
        A new list holding one ``(filtered_key, filtered_value)`` tuple per layer.
    """
    return [(key[mask], value[mask]) for key, value in past_key_values]


print(batch["past_key_values"])
filtered_past_key_values = filter_past_key_values(batch["past_key_values"], mask)
print(filtered_past_key_values)

# Sanity check: every layer's key and value keep exactly the rows selected by `mask`.
assert all(k.size(0) == v.size(0) == int(mask.sum()) for k, v in filtered_past_key_values)

([tensor([[1, 1, 1],
        [1, 1, 1],
        [1, 1, 0]]), tensor([[2, 2, 2],
        [2, 2, 2],
        [2, 2, 0]])],)
[(tensor([[1, 1, 1],
        [1, 1, 0]]), tensor([[2, 2, 2],
        [2, 2, 0]]))]


In [18]:
# 5. Truncate Left
import torch

# Example mask: every row starts with a different number of zeros (left padding).
# The middle row also ends with zeros, which a left-only truncation must keep.
attention_mask = torch.tensor(
    [
        [0, 0, 1, 1, 1, 1, 1],
        [0, 1, 1, 1, 1, 0, 0],
        [0, 0, 0, 1, 1, 1, 1],
    ]
)

def truncate_left(attention_mask):
    """Drop the left-padding columns that every row of ``attention_mask`` shares.

    Each row's leading zeros are counted. Columns are then removed from the left
    only up to the smallest such count, so no row ever loses a non-zero entry.

    Args:
        attention_mask: 2-D tensor of shape ``(rows, columns)``. Zero entries
            are treated as padding.

    Returns:
        ``attention_mask[:, offset:]``, where ``offset`` is the minimum
        leading-zero count over all rows. An input with no rows is returned
        unchanged: there is no shared padding to measure, and ``torch.min``
        raises on an empty tensor.
    """
    # Without rows, min() over the (empty) per-row counts would raise.
    if attention_mask.size(0) == 0:
        return attention_mask

    # Along each row the cumulative product stays 1 only while every entry seen
    # so far is zero, so its sum is that row's number of leading zeros.
    is_zero = (attention_mask == 0).long()
    leading_zeros_count = is_zero.cumprod(dim=1).sum(dim=1)

    # The smallest count is the number of leading columns that are padding in every row.
    truncation_offset = int(leading_zeros_count.min().item())
    return attention_mask[:, truncation_offset:]

# Apply the truncate_left method to the attention_mask
truncated_attention_mask = truncate_left(attention_mask)

# Print the original and truncated attention masks
print("Original Attention Mask:\n", attention_mask)
print("Truncated Attention Mask:\n", truncated_attention_mask)

# Sanity checks: rows are untouched, only whole leading columns were dropped,
# and every dropped column was zero in all rows.
n_dropped = attention_mask.size(1) - truncated_attention_mask.size(1)
assert truncated_attention_mask.size(0) == attention_mask.size(0)
assert torch.equal(truncated_attention_mask, attention_mask[:, n_dropped:])
assert not attention_mask[:, :n_dropped].any()


Original Attention Mask:
 tensor([[0, 0, 1, 1, 1, 1, 1],
        [0, 1, 1, 1, 1, 0, 0],
        [0, 0, 0, 1, 1, 1, 1]])
Truncated Attention Mask:
 tensor([[0, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 0, 0],
        [0, 0, 1, 1, 1, 1]])
