In [1]:
import logging
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import Iterator
from arithmetic_coder import arithmetic_coder, ac_utils

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
logger = logging.getLogger()
logger.setLevel(logging.INFO)


class Metric:
    def __init__(self):
        self.total_length = 0
        self.compressed_length = 0

    def compute_ratio(self):
        if self.total_length != 0 and self.compressed_length != 0:
            return (
                self.total_length / self.compressed_length,
                self.compressed_length / self.total_length,
            )
        else:
            return 0, 0

    def accumulate(self, compressed, original):
        if isinstance(compressed, list):
            self.compressed_length += len(compressed)
        elif isinstance(compressed, int):
            self.compressed_length += compressed
        else:
            raise ValueError(f"Unsupported compressed length type: {type(compressed)}")

        if isinstance(original, list):
            self.total_length += len(original)
        elif isinstance(original, int):
            self.total_length += original
        else:
            raise ValueError(f"Unsupported original length type: {type(original)}")


def compress(compress_input, logits, metric):
    """
    :param compress_input: symbols to be compressed
    :param logits: generation probabilities from the model
    :param metric: compression metrics
    :return: compressed result, a floating number
    """
    output = []
    # Initialize a Encoder Object
    # Precision is for the encoder, not the model
    # You must have the same precision for encoder and decoder
    # Tricky things here: Though theoratically prefill == decode, but in practice there are numerical problems
    encoder = arithmetic_coder.Encoder(
        base=2,
        precision=64,
        output_fn=output.append,
    )
    # the first symbol should be saved for generation in decoding
    start_symbol = compress_input[:, :1]
    probs = logits.softmax(dim=-1).to(torch.float32)
    pd = torch.gather(probs, dim=-1, index=compress_input[:, 1:].unsqueeze(-1)).squeeze(
        -1
    )

    probs = np.vstack(probs.detach().cpu().numpy().squeeze())

    sequence_array = compress_input[:, 1:].detach().cpu().numpy().reshape(-1)

    pd = pd.squeeze()

    # compress the sequence
    for symbol, prob, pd_prob in zip(sequence_array, probs, pd):
        encoder.encode(
            ac_utils.normalize_pdf_for_arithmetic_coding(prob, np.float32), symbol
        )
    encoder.terminate()

    # to visualize and compute metrics, map to str
    compressed_bits = "".join(map(str, output))
    # you can only save in bytes, so need to pad some bits
    compressed_bytes, num_padded_bits = ac_utils.bits_to_bytes(compressed_bits)
    metric.accumulate(len(compressed_bytes) + num_padded_bits, len(sequence_array))

    compress_rate, compress_ratio = metric.compute_ratio()
    logger.info(f"compressed length: {metric.compressed_length}")
    logger.info(f"original length: {metric.total_length}")
    logger.info(f"compression ratio: {compress_ratio:.6f}")
    logger.info(f"compression rate: {compress_rate:.6f}")

    return compressed_bytes, num_padded_bits, start_symbol, sequence_array, pd, probs


def decode(
    compressed_bytes,
    num_padded_bits,
    model,
    start_symbol,
    device,
    original_seq_len,
    original_sequence=None,
    pd=None,
    probs=None,
    do_test=True,
):
    """

    :param compressed_bytes:  compressed data
    :param num_padded_bits:  padded bits
    :param model: same model as encoder
    :param start_symbol: first symbol to generate
    :param original_sequence: original symbol sequence, for testing purpose
    :param pd: actually not needed, used for testing
    :param probs:
    :param device:
    :return:
    """
    # convert bytes back to bit stream
    data_iter = iter(
        ac_utils.bytes_to_bits(compressed_bytes, num_padded_bits=num_padded_bits)
    )

    # utils function to read bits
    def _input_fn(bit_sequence: Iterator[str] = data_iter) -> int | None:
        try:
            return int(next(bit_sequence))
        except StopIteration:
            return None

    # initialize a Decoder Object
    decoder = arithmetic_coder.Decoder(
        base=2,
        precision=64,
        input_fn=_input_fn,
    )

    sequence_array_de = start_symbol.squeeze(0).detach().cpu().numpy()
    sequence_array_de_input = start_symbol
    target_diff_list = []
    target_in_top5_list = []

    # loop for decompressing
    # pad the input to the original length
    sequence_array_de_input = torch.tensor(sequence_array_de_input, dtype=torch.long, device=device)
    sequence_array_de_input = torch.nn.functional.pad(sequence_array_de_input, (0, original_seq_len-1), value=0)

    for i in range(original_seq_len):
        # attention_mask = (sequence_array_de_input != 0).long()
        with torch.no_grad():
            logits = model(sequence_array_de_input, use_cache=False).logits.to(
                torch.float32
            )
        # get generaton probabilities, decode the next token
        prob_de = logits.softmax(dim=-1).detach().cpu().numpy().squeeze(0)

        de_token = decoder.decode(
            ac_utils.normalize_pdf_for_arithmetic_coding(prob_de[i], np.float32)
        )
        # using the original probs to decode, for testing purpose
        # de_token = decoder.decode(ac_utils.normalize_pdf_for_arithmetic_coding(probs[i]))
        # append to the generated sequence
        sequence_array_de = np.append(sequence_array_de, de_token)

        current_len = len(sequence_array_de)
        target_len = original_seq_len

        if current_len < target_len:
            padded = np.pad(
                sequence_array_de, (0, (target_len - current_len)), constant_values=0
            )
        else:
            padded = sequence_array_de
        sequence_array_de_input = torch.tensor(
            padded, dtype=torch.long, device=device
        ).unsqueeze(0)

        if do_test:
            top_indices_de = prob_de[i].argsort()[-5:][::-1]
            top_indices = probs[i].argsort()[-5:][::-1]

            # target diff
            target_diff = probs[i, original_sequence[i]] - prob_de[i, original_sequence[i]]
            target_diff_list.append(target_diff)

            # target in top 5
            target_in_top5 = original_sequence[i] in top_indices
            target_in_top5_list.append(target_in_top5)
            print(
                f"idx: {i}, original token: {original_sequence[i]}, decoder token: {de_token}"
            )
            print(
                f"diff probs max: {max(abs(probs[i] - prob_de[i]))}, original sum error: {abs(sum(prob_de[i]) - 1.0)}, decoder sum error: {abs(sum(probs[i]) - 1.0)}"
            )
            print(
                f"original: {top_indices}, target_in_top5: {target_in_top5} decode: {top_indices_de}, "
            )
            print(f"target diff: {target_diff}")
            if original_sequence[i] != de_token:
                import pdb
                pdb.set_trace()
        
    return sequence_array_de_input

In [3]:
def write_padded_bytes(filename: str, data: bytes, num_padded_bits: int, original_length: int):
    """
    file format:
    - first byte: number of padded bit
    - second and third byte: original length (usually, llm context will not exceed 65535)
    - subsequent bytes: actual bytes data

    :param filename: output file name
    :param data: bytes data to write
    :param padding_bits: number of padded bits (must be between 0 and 7)
    :param original_length: original length of the uncompressed data (in tokens)
    """

    if not 0 <= num_padded_bits <= 7:
        raise ValueError("num_padded_bits must be between 0 and 7.")

    if not 0 <= original_length <= 65535:
        raise ValueError("original_length must be between 0 and 65535.")

    if not isinstance(data, bytes):
        raise TypeError("data must be of bytes type.")

    with open(filename, 'wb') as f:
        padding_byte = num_padded_bits.to_bytes(1, 'big')
        f.write(padding_byte)
        f.write(original_length.to_bytes(2, 'big'))
        f.write(data)

def read_padded_bytes(filename: str) -> tuple[bytes, int]:
    """
    Read data and padding bits from a file.

    :param filename: The name of the file to read.
    :return: A tuple containing (bytes data, number of padded bits).
             May raise an error if the file is empty or improperly formatted.
    """

    with open(filename, 'rb') as f:
        # the first byte indicates the number of padded bits
        padding_byte = f.read(1)

        # If the file is empty, f.read(1) will return an empty bytes object b''
        if not padding_byte:
            raise EOFError("File is empty or improperly formatted: unable to read padding bits byte.")

        original_length_bytes = f.read(2)
        if not original_length_bytes:
            raise EOFError("File is empty or improperly formatted: unable to read original length bytes.")
    
        padding_bits = int.from_bytes(padding_byte, 'big')
        original_length = int.from_bytes(original_length_bytes, 'big')

        data = f.read()
        
        return data, padding_bits, original_length

In [4]:
import time
# model and tokenizer loading
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
llm = AutoModelForCausalLM.from_pretrained("pretrained/Qwen2.5-0.5B", torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained("pretrained/Qwen2.5-0.5B", use_fast=False)
llm.eval()

pad_token_id = tokenizer.pad_token_id

# data
sample_text = "Super simple text to be tested."
# sample_text = "In a shocking finding, scientists discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English."
#sample_text = r"""Greenhouse gas emissions from the burning of fossil fuels have pushed the acidity of the world's oceans past a safe threshold, scientists warn, threatening their ability to sustain shellfish and corals and help us in the fight against climate change. A new report says that ocean acidification is the latest "planetary boundary" to be crossed, a reference to a set of warning signs related to key planetary systems that keep the Earth safe for human civilization. Other planetary boundaries that have already been crossed — including dangerous levels of chemical pollution, the warming atmosphere and changes to the nutrient cycle — have already signalled threats to people.  "Go outside of these boundaries and you first enter a danger zone, with higher risk of causing changes that would undermine that ability to support human life and human development," said Johan Rockström, director of the Potsdam Institute for Climate Impact Research, which is behind the Planetary Health Check report released on Wednesday. "And once you are at the upper end of the uncertainty range ... you enter the red zone, the high-risk zone where most science agrees that we are very likely to depress buttons that will cause irreversible changes, basically committing ourselves to drifting away from livable conditions on Earth." Adding the oceans to the planetary boundaries list is a major concern because of the billions of people who depend on them. Continuing ocean acidification could not only destroy fisheries that people rely on for food but reduce the ability of the ocean to absorb carbon dioxide and moderate global warming. As humans burn fossil fuels and pump carbon dioxide into the atmosphere, it's estimated that the ocean is absorbing more than a quarter of that CO2.  "Just like when we add carbon dioxide to Coke or soda, that makes the soft drink more acidic," said Christopher Harley, a professor who studies climate change and the ocean at the University of British Columbia.  But when CO2 is absorbed, the chemical process effectively lowers the availability of a mineral that certain marine life — from shellfish to coral — need to develop their bodies. "It makes it harder to build shells — and you need to add shell if you want to grow bigger," Harley explained, comparing it to the construction of a house.  "All of a sudden, the building materials become more costly. You're either going to build smaller homes or not as many." """

# work flow
compression_start_time = time.time()

tokenized = tokenizer(sample_text, return_tensors="pt")

metric = Metric()
with torch.inference_mode():
    # we don't need the last token's logits
    logits = (
        llm(tokenized["input_ids"], use_cache=False).logits[:, :-1].to(torch.float32)
    )
compressed_bytes, num_padded_bits, start_symbol, sequence_array, pd, probs = compress(
    tokenized["input_ids"], logits, metric
)

compression_end_time = time.time()

print(compressed_bytes)
print(num_padded_bits)
original_length = tokenized["input_ids"].shape[1] - 1
print(original_length)
write_padded_bytes("compressed.bin", compressed_bytes, num_padded_bits, original_length)
compressed_bytes, num_padded_bits, original_length = read_padded_bytes("compressed.bin")
print(compressed_bytes)
print(num_padded_bits)
print(original_length)

decompression_start_time = time.time()

decompressed = decode(
    compressed_bytes,
    num_padded_bits,
    llm,
    start_symbol,
    device,
    original_length,
    sequence_array,
    pd,
    probs,
    do_test=True,
)

decompression_end_time = time.time()

print(tokenized["input_ids"].squeeze(0).numpy())
print(decompressed)

print(f"Compression time: {compression_end_time - compression_start_time:.2f} seconds")
print(f"Decompression time: {decompression_end_time - decompression_start_time:.2f} seconds")

`torch_dtype` is deprecated! Use `dtype` instead!


b'2G\xd9[\xa4d'
1
6
b'2G\xd9[\xa4d'
1
6
idx: 0, original token: 4285, decoder token: 4285
diff probs max: 0.0, original sum error: 0.000357210636138916, decoder sum error: 0.000357210636138916
original: [ 2462  1515    69  1040 76652], target_in_top5: False decode: [ 2462  1515    69  1040 76652], 
target diff: 0.0


  sequence_array_de_input = torch.tensor(sequence_array_de_input, dtype=torch.long, device=device)


idx: 1, original token: 1467, decoder token: 1467
diff probs max: 0.0, original sum error: 0.00028574466705322266, decoder sum error: 0.00028574466705322266
original: [  11  323 3405 1616  714], target_in_top5: False decode: [  11  323 3405 1616  714], 
target diff: 0.0
idx: 2, original token: 311, decoder token: 311
diff probs max: 0.0, original sum error: 0.000306546688079834, decoder sum error: 0.000306546688079834
original: [6440  311 1034 5980 8692], target_in_top5: True decode: [6440  311 1034 5980 8692], 
target diff: 0.0
idx: 3, original token: 387, decoder token: 387
diff probs max: 0.0, original sum error: 0.00027436017990112305, decoder sum error: 0.00027436017990112305
original: [2168 1467 8806 9308 7699], target_in_top5: False decode: [2168 1467 8806 9308 7699], 
target diff: 0.0
idx: 4, original token: 12510, decoder token: 12510
diff probs max: 0.0, original sum error: 0.00024390220642089844, decoder sum error: 0.00024390220642089844
original: [ 1349  2952  1483 12596  5