In [47]:
import torch
import numpy as np

In [145]:
def encode_consecutive_bool_tensor(bool_tensor):
    """Run-length encode a 1-D boolean tensor into a ``uint8`` tensor.

    The tensor is split into maximal runs of equal values. Each run is written,
    in order, as zero or more ``255`` bytes followed by exactly one terminator
    byte in ``0..254``; the run length is the sum of those bytes. A run whose
    length is an exact multiple of 255 therefore ends with an explicit ``0``
    byte, so that it cannot be confused with the run that follows it.

    The values themselves are not stored: runs alternate, so the decoder only
    needs the value of the first element.

    Args:
        bool_tensor: 1-D boolean tensor on any device.

    Returns:
        1-D ``torch.uint8`` tensor on the same device as ``bool_tensor``.
        Empty input yields an empty tensor.
    """
    max_chunk = 255  # largest value a single uint8 byte can carry
    device = bool_tensor.device
    total_len = bool_tensor.size(0)

    if total_len == 0:
        return torch.empty(0, dtype=torch.uint8, device=device)

    # Non-zero differences mark the index where a new run starts. Prepending
    # the first element makes the difference at index 0 zero by construction.
    int_tensor = bool_tensor.to(torch.int)
    diff_tensor = torch.diff(int_tensor, prepend=int_tensor[:1])
    run_starts = torch.nonzero(diff_tensor, as_tuple=True)[0]

    # Run boundaries: 0, every run start, and the end of the tensor.
    change_positions = torch.cat((
        run_starts.new_zeros(1),
        run_starts,
        run_starts.new_full((1,), total_len),
    ))
    run_lengths = torch.diff(change_positions)

    num_full_chunks = torch.div(run_lengths, max_chunk, rounding_mode='floor')
    remainders = run_lengths % max_chunk

    # Every run occupies (full chunks + 1 terminator) bytes. Start with all
    # bytes set to 255, then overwrite the last byte of each run with its
    # remainder so the chunks stay grouped per run, in run order.
    chunks_per_run = num_full_chunks + 1
    total_chunks = int(chunks_per_run.sum().item())
    final_lengths = torch.full((total_chunks,), max_chunk, dtype=torch.uint8, device=device)
    terminator_positions = torch.cumsum(chunks_per_run, dim=0) - 1
    final_lengths[terminator_positions] = remainders.to(torch.uint8)

    return final_lengths


In [153]:
def merge_255_values(tensor):
    """Collapse each block of ``255`` entries into the entry that follows it.

    Every entry different from 255 closes a group; the group consists of that
    entry plus all directly preceding 255 entries, and is replaced by the sum
    of its members. Example: ``[255, 255, 10, 5] -> [520, 5]`` and
    ``[255, 0, 5] -> [255, 5]``.

    A trailing block of 255s with no closing entry is summed into one final
    value instead of raising an index error.

    Args:
        tensor: 1-D integer tensor of chunk values, on any device. Use a dtype
            wide enough to hold the merged sums (e.g. ``int32``); the result is
            cast back to this dtype.

    Returns:
        1-D tensor of merged values with the same dtype and device as
        ``tensor``.
    """
    if tensor.numel() == 0:
        return tensor.clone()

    # Indices of the entries that close a group.
    closes_group = tensor != 255
    group_ends = torch.nonzero(closes_group, as_tuple=True)[0]

    # If the input ends inside a 255 block, close that block at the last index.
    if not bool(closes_group[-1]):
        last_index = tensor.size(0) - 1
        group_ends = torch.cat((group_ends, group_ends.new_full((1,), last_index)))

    # The sum of a group is the running total at its end minus the running
    # total at the end of the previous group. Accumulate in int64 so the
    # intermediate totals cannot overflow a narrow input dtype.
    running_total = torch.cumsum(tensor, dim=0, dtype=torch.int64)
    totals_at_ends = running_total[group_ends]
    merged_tensor = torch.diff(totals_at_ends, prepend=totals_at_ends.new_zeros(1))

    return merged_tensor.to(tensor.dtype)

In [155]:
def reconstruct_bool_tensor(encoded_list, first_value, device=None):
    """Decode run-length chunks back into a 1-D boolean tensor.

    Args:
        encoded_list: Sequence (list, NumPy array, ...) of ``uint8`` chunk
            values as produced by ``encode_consecutive_bool_tensor``.
        first_value: Boolean value of the first run; later runs alternate.
        device: Device for the result. ``None`` (default) selects ``"cuda"``
            when a CUDA device is available and ``"cpu"`` otherwise.

    Returns:
        1-D ``torch.bool`` tensor on ``device``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Read the chunks as uint8 first, then widen: merged run lengths exceed
    # 255, and repeat_interleave needs an int32/int64 repeat count.
    encoded_tensor = torch.tensor(encoded_list, dtype=torch.uint8, device=device)
    encoded_tensor = encoded_tensor.to(torch.int32)

    # Fold the 255 continuation chunks into their runs. This is done exactly
    # once: re-merging would wrongly fuse a run whose merged length happens to
    # equal 255 with the run that follows it.
    if len(encoded_tensor) != 0:
        encoded_tensor = merge_255_values(encoded_tensor)

    # Runs alternate, starting from first_value: flip every odd-indexed run.
    num_segments = len(encoded_tensor)
    values = torch.tensor([first_value], dtype=torch.bool, device=device).repeat(num_segments)
    values[1::2] = ~values[1::2]

    # Expand each run value to its run length.
    decoded_tensor = torch.repeat_interleave(values, encoded_tensor)

    return decoded_tensor
    

In [159]:
# Value of the first element as a plain Python bool.
# NOTE(review): `bool_tensor` is only defined in a cell further down, so this
# cell fails on a fresh kernel (Restart & Run All) in the current cell order.
bool_tensor[0].cpu().item()

False

In [146]:
# Sample input: runs of 300 False, 5 True and 3 False. The first run is longer
# than 255, so its length does not fit into a single uint8 value.
# Fall back to the CPU when no CUDA device is present so the cell runs anywhere.
bool_tensor = torch.tensor(
    [False] * 300 + [True] * 5 + [False] * 3,
    dtype=torch.bool,
    device="cuda" if torch.cuda.is_available() else "cpu",
)
bool_tensor

tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, 

In [150]:
# Encode the sample tensor. Despite its name, `encoded_list` is a uint8 tensor.
# The runs of 300, 5 and 3 elements come out as 255 + 45, 5, 3.
encoded_list = encode_consecutive_bool_tensor(bool_tensor)
print(encoded_list)  # Output: tensor([255, 45, 5, 3], device='cuda:0', dtype=torch.uint8)

tensor([255,  45,   5,   3], device='cuda:0', dtype=torch.uint8)


In [134]:
# Serialize the encoded uint8 values to a raw byte string (moved to the CPU
# and converted to NumPy first).
buf = encoded_list.cpu().numpy().tobytes()
buf

b'\xff-\x05\x03'

In [135]:
# Read the byte string back as a NumPy array of uint8 chunk values.
reconst = np.frombuffer(buf, dtype=np.uint8)
reconst

array([255,  45,   5,   3], dtype=uint8)

In [157]:
# Decode the chunk values back into a boolean tensor.
# The byte buffer holds run lengths only, so the value of the first run is
# passed separately; the literal False must match bool_tensor[0].
decode_tensor = reconstruct_bool_tensor(reconst, False)
decode_tensor

tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, 