In [None]:
# Intended for deterministic cuBLAS results (see PyTorch's reproducibility notes);
# it must be set before CUDA is initialized, hence its own first cell.
%env CUBLAS_WORKSPACE_CONFIG=:16:8

In [None]:
!pip install torch transformers pandas tqdm samplings==0.1.7 constriction bitsandbytes accelerate

# LLM-Based Compressor Implementation

Below is the implementation of the LLM-based compressor.

## Instructions

1. Run the **two cells above** (a CUDA GPU runtime is required).
2. Upload a file you wish to compress (ideally a `.txt` file).
After uploading your file, update the variable at the bottom of the notebook:

```python
input_path = "ENTER FILENAME HERE"
```
3. Run the third cell to compress the file and then decompress it again.
4. Once it completes, two files will be created next to the input:
   - `<input_path>.bin` is the compressed (encoded) version of the input.
   - `<input_path>.recovered` is the file decoded back from the `.bin` file.

The original file and the recovered file should match **exactly**.

In [None]:
"""
LLM-based lossless compression using arithmetic coding
"""

import gzip
import math
import os
import random
import struct
import time
from enum import Enum

import numpy as np
import torch
import zstandard as zstd
import constriction
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

# --- Model configuration -------------------------------------------------
# Hugging Face model id, weight/KV-cache precision, and the number of tokens
# in each independently coded segment.
MODEL_NAME = "unsloth/Llama-3.2-1B"
DTYPE = torch.float16
MAX_SEQ_LEN = 2048

# --- Run configuration ---------------------------------------------------
# Seed used by set_determinism() and the number of segments batched together.
SEED = 42
NUM_PARALLEL_SEGMENTS = 2

# --- Numeric backends ----------------------------------------------------
# Permit TF32 for float32 matmuls and cuDNN kernels on GPUs that support it.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("high")


def set_determinism(seed=SEED):
    """
    Seed the Python, NumPy and PyTorch RNGs with ``seed``.

    When CUDA is available, also seed every CUDA device and force cuDNN onto
    deterministic, non-autotuned kernels.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def normalize_pdf(logits, epsilon=1e-7):
    """
    Convert logits to a valid probability distribution for arithmetic coding.

    The softmax is computed in float32 (the model may emit float16 logits,
    whose softmax would lose precision and flush small probabilities to 0).
    The result is then mixed with a uniform floor so that every symbol keeps
    probability >= ``epsilon`` and each distribution sums to 1:

        p_i = softmax(logits)_i * (1 - V * epsilon) + epsilon

    where ``V`` is the size of the last dimension. The non-zero floor keeps
    every token encodable, even ones the model considers impossible.

    Note: this function defines the coding probabilities, so a file must be
    decompressed with the same version of it that compressed the file.

    Args:
        logits: tensor of unnormalized scores; normalized over the last dim.
        epsilon: minimum probability per symbol; requires ``V * epsilon < 1``.

    Returns:
        float32 tensor with the same shape as ``logits``.

    Raises:
        ValueError: if ``V * epsilon >= 1`` (no mass left for the model).
    """
    vocab_size = logits.shape[-1]
    # Scale chosen so that sum(p) = (1 - V*eps) * 1 + V*eps = 1 exactly.
    scale = 1.0 - vocab_size * epsilon
    if scale <= 0.0:
        raise ValueError(
            f"epsilon={epsilon} is too large for a vocabulary of {vocab_size} symbols."
        )
    probs = torch.softmax(logits.to(torch.float32), dim=-1)
    return (probs * scale) + epsilon


def format_duration(seconds):
    """Render a duration: milliseconds (2 decimals) below 10 ms, else seconds (3 decimals)."""
    below_threshold = seconds < 0.01
    return f"{seconds * 1000:.2f}ms" if below_threshold else f"{seconds:.3f}s"


def get_gpu_memory_str():
    """Summarize CUDA memory as "GPU: <allocated>/<reserved> GB" (GiB, one decimal)."""
    gib = 1024**3
    used, held = (
        torch.cuda.memory_allocated() / gib,
        torch.cuda.memory_reserved() / gib,
    )
    return f"GPU: {used:.1f}/{held:.1f} GB"


def report_compression(codec, duration, input_bytes, output_bytes, num_tokens=None, filename=None):
    """
    Print a one-line summary of a compression run.

    The line shows the codec, elapsed time, input/output sizes, compression
    ratio, output bits per input byte and input throughput. A tokens/s figure
    is appended when ``num_tokens`` is truthy. A zero ``output_bytes`` gives an
    infinite ratio; other zero denominators give a rate of 0.
    """
    if output_bytes:
        ratio = input_bytes / output_bytes
    else:
        ratio = float("inf")

    bits_per_byte = 0
    if input_bytes:
        bits_per_byte = (output_bytes * 8) / input_bytes

    throughput_kb = 0
    if duration:
        throughput_kb = (input_bytes / 1024) / duration

    prefix = "" if not filename else f"{filename}: "
    line = (
        f"{prefix}{codec.value:12} | {format_duration(duration):>10} | "
        f"{input_bytes} B -> {output_bytes} B | "
        f"ratio: {ratio:.2f}x | {bits_per_byte:.2f} bits/byte | "
        f"{throughput_kb:.1f} KB/s"
    )

    if num_tokens:
        rate = 0
        if duration:
            rate = num_tokens / duration
        line += f" | {rate:.1f} tok/s"

    print(line)


def report_decompression(codec, duration, input_bytes, output_bytes, num_tokens=None):
    """
    Print a one-line summary of a decompression run.

    Throughput is measured on the decompressed (output) size. A tokens/s
    figure is appended when ``num_tokens`` is truthy. A zero ``duration``
    reports rates of 0.
    """
    def per_second(amount):
        return amount / duration if duration else 0

    kb_per_second = per_second(output_bytes / 1024)

    summary = (
        f"{codec.value:12} | {format_duration(duration):>10} | "
        f"{input_bytes} B -> {output_bytes} B | "
        f"{kb_per_second:.1f} KB/s"
    )

    if num_tokens:
        token_rate = per_second(num_tokens)
        summary += f" | {token_rate:.1f} tok/s"

    print(summary)


class Codec(Enum):
    """Codecs whose results are printed side by side in the benchmark output."""

    # Range coding driven by the language model's next-token probabilities.
    LLM = "LLM"
    # Zstandard baseline.
    ZSTD = "ZSTD"
    # gzip baseline.
    GZIP = "GZIP"


class LLMCompressor:
    """
    Compressor that uses a language model for arithmetic coding.

    Uses a pretrained LLM to predict token probabilities, then encodes
    tokens using range coding. Achieves high compression ratios on
    natural language text by leveraging the model's predictions.

    The token stream is split into segments of ``max_seq_len`` tokens. The
    first token of each segment is stored verbatim in the file header. Every
    later token is range-coded with the model's prediction, given the
    preceding tokens of the same segment. Up to ``NUM_PARALLEL_SEGMENTS``
    segments run through the model as one batch, and their symbols are
    interleaved in the coded stream in that batch order. Compression and
    decompression must therefore use the same model, dtype, ``max_seq_len``
    and ``NUM_PARALLEL_SEGMENTS``.
    """

    def __init__(self, model_name=MODEL_NAME, max_seq_len=MAX_SEQ_LEN, dtype=DTYPE):
        """
        Store the configuration. Model and tokenizer are loaded lazily by ``load_model``.

        Args:
            model_name: Hugging Face model id passed to ``from_pretrained``.
            max_seq_len: number of tokens per independently coded segment.
            dtype: torch dtype for the model weights and the KV cache.
        """
        self.model_name = model_name
        self.max_seq_len = max_seq_len
        self.dtype = dtype
        self.model = None
        self.tokenizer = None

    def load_model(self):
        """Load the tokenizer and model on first use; later calls are no-ops."""
        if self.model is not None:
            return

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name, use_fast=True, trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=self.dtype,
            device_map="auto",
            attn_implementation="sdpa",
        )
        self.model.eval()

    def create_static_cache(self, device, total_tokens, batch_size=1):
        """
        Create a static KV cache for efficient inference.

        The cache holds ``min(max_seq_len, total_tokens)`` positions, which
        covers the longest possible segment.
        """
        max_cache_len = min(self.max_seq_len, total_tokens)
        return StaticCache(
            config=self.model.config,
            batch_size=batch_size,
            max_cache_len=max_cache_len,
            device=device,
            dtype=self.dtype,
        )

    def warmup_and_capture_graph(
        self, device, vocab_size, past_key_values, static_input, static_position, first_tokens, batch_size=1
    ):
        """
        Run warmup passes and capture the forward pass as a CUDA graph.

        Requires a CUDA device. The captured graph reads ``static_input`` and
        ``static_position`` and writes ``normalize_pdf`` probabilities into the
        returned ``static_probs`` buffer. Callers drive it by mutating those
        tensors in place and calling ``graph.replay()``.

        Args:
            device: device for the probability buffer.
            vocab_size: length of the probability vectors.
            past_key_values: static KV cache; it is reset after warmup.
            static_input: token-id tensor fed to the model (the callers in
                this class allocate it as ``(batch_size, 1)``).
            static_position: one-element cache-position tensor.
            first_tokens: an int, or one token per batch row, used for warmup.
            batch_size: number of rows; selects a 2-D vs 1-D output buffer.

        Returns:
            ``(graph, static_probs)``. ``static_probs`` is
            ``(batch_size, vocab_size)`` if ``batch_size > 1``, else ``(vocab_size,)``.
        """
        # Warmup passes before graph capture
        warmup_stream = torch.cuda.Stream()
        warmup_stream.wait_stream(torch.cuda.current_stream())

        # Handle single token or batch of tokens
        if isinstance(first_tokens, int):
            static_input.fill_(first_tokens)
        else:
            for i, tok in enumerate(first_tokens):
                static_input[i, 0] = tok

        with torch.cuda.stream(warmup_stream):
            for _ in range(3):
                self.model(
                    input_ids=static_input,
                    past_key_values=past_key_values,
                    use_cache=True,
                    cache_position=static_position.zero_(),
                )

        torch.cuda.current_stream().wait_stream(warmup_stream)
        past_key_values.reset()
        static_position.zero_()

        # Allocate output buffer (batched or single)
        if batch_size > 1:
            static_probs = torch.empty((batch_size, vocab_size), device=device, dtype=torch.float32)
        else:
            static_probs = torch.empty((vocab_size,), device=device, dtype=torch.float32)

        # Capture forward pass as CUDA graph
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            outputs = self.model(
                input_ids=static_input,
                past_key_values=past_key_values,
                use_cache=True,
                cache_position=static_position,
            )
            if batch_size > 1:
                probs_tensor = normalize_pdf(outputs.logits[:, 0, :])
            else:
                probs_tensor = normalize_pdf(outputs.logits[0, 0, :])
            static_probs.copy_(probs_tensor)

        return graph, static_probs

    def print_progress(self, current, total, segment_index, num_segments, operation):
        """Print a progress line, only when ``current`` is a multiple of 200."""
        if current % 200 == 0:
            print(
                f"\r {operation} token {current}/{total} "
                f"(segment {segment_index + 1}/{num_segments})",
                end="",
            )

    def write_header(self, file_handle, total_tokens, segment_first_tokens):
        """
        Write the compressed file header.

        Format: total_tokens (4B) + num_segments (4B) + first_tokens (4B each),
        all little-endian unsigned 32-bit integers.
        """
        file_handle.write(struct.pack("<I", total_tokens))
        file_handle.write(struct.pack("<I", len(segment_first_tokens)))
        for token in segment_first_tokens:
            file_handle.write(struct.pack("<I", token))

    def read_header(self, raw_data):
        """
        Read and parse the compressed file header written by ``write_header``.

        Returns:
            ``(total_tokens, segment_first_tokens, payload_offset)``, where
            ``payload_offset`` is the byte index at which the range-coder
            payload starts.

        Raises:
            ValueError: if the data is too short or declares zero segments/tokens.
        """
        if len(raw_data) < 8:
            raise ValueError("File too small to contain header.")

        total_tokens = struct.unpack_from("<I", raw_data, 0)[0]
        num_segments = struct.unpack_from("<I", raw_data, 4)[0]

        if num_segments == 0 or total_tokens == 0:
            raise ValueError("Invalid header: zero segments or tokens.")

        offset = 8
        if len(raw_data) < offset + 4 * num_segments:
            raise ValueError("File too small for segment first tokens.")

        segment_first_tokens = []
        for _ in range(num_segments):
            token = struct.unpack_from("<I", raw_data, offset)[0]
            segment_first_tokens.append(int(token))
            offset += 4

        return total_tokens, segment_first_tokens, offset

    def _segment_layout(self, total_tokens):
        """Return ``(starts, lengths)`` of the ``max_seq_len``-sized segments covering ``total_tokens``."""
        # Shared by compress() and decompress() so both always agree on the
        # layout; only the last segment can be shorter than max_seq_len.
        starts = list(range(0, total_tokens, self.max_seq_len))
        lengths = [min(self.max_seq_len, total_tokens - start) for start in starts]
        return starts, lengths

    def run_zstd_compress_baseline(self, raw_bytes, filename=None):
        """Run zstd (level 22) compression for baseline comparison."""
        start = time.time()
        compressor = zstd.ZstdCompressor(level=22)
        compressed = compressor.compress(raw_bytes)
        report_compression(
            Codec.ZSTD, time.time() - start, len(raw_bytes), len(compressed),
            filename=filename
        )

    def run_zstd_decompress_baseline(self, raw_bytes):
        """Compress ``raw_bytes`` with zstd, then time only the decompression."""
        compressor = zstd.ZstdCompressor(level=22)
        compressed = compressor.compress(raw_bytes)
        decompressor = zstd.ZstdDecompressor()

        start = time.time()
        decompressed = decompressor.decompress(compressed)
        report_decompression(
            Codec.ZSTD, time.time() - start, len(compressed), len(decompressed)
        )

    def run_gzip_compress_baseline(self, raw_bytes, filename=None):
        """Run gzip (compresslevel 9) compression for baseline comparison."""
        start = time.time()
        compressed = gzip.compress(raw_bytes, compresslevel=9)
        report_compression(
            Codec.GZIP, time.time() - start, len(raw_bytes), len(compressed),
            filename=filename
        )

    def run_gzip_decompress_baseline(self, raw_bytes):
        """Compress ``raw_bytes`` with gzip, then time only the decompression."""
        compressed = gzip.compress(raw_bytes, compresslevel=9)

        start = time.time()
        decompressed = gzip.decompress(compressed)
        report_decompression(
            Codec.GZIP, time.time() - start, len(compressed), len(decompressed)
        )

    def compress(self, input_path, output_path, show_baselines=True):
        """
        Compress a text file using LLM-based arithmetic coding.

        The input is read as UTF-8 text with newlines preserved, then
        tokenized. The output file is the ``write_header`` header followed by
        the range-coder payload as little-endian 32-bit words.

        Args:
            input_path: path of the UTF-8 text file to compress.
            output_path: destination of the compressed file.
            show_baselines: also report zstd and gzip results for comparison.

        Raises:
            ValueError: if the input tokenizes to zero tokens.
        """
        filename = os.path.basename(input_path)
        print(f"\n{'='*60}\nCompressing: {filename}\n{'='*60}")
        set_determinism()
        self.load_model()

        device = self.model.device
        vocab_size = self.model.config.vocab_size

        # Read file as UTF-8 text. newline="" disables universal-newline
        # translation, so "\r\n" and "\r" survive and the round trip is exact.
        with open(input_path, "r", encoding="utf-8", newline="") as file:
            content = file.read()

        input_ids = self.tokenizer(content, add_special_tokens=False)["input_ids"]
        tokens = np.array(input_ids, dtype=np.int32)
        total_tokens = len(tokens)

        if total_tokens == 0:
            raise ValueError("Input has no tokens to compress.")

        # Precompute segment info (decompress() recomputes it from total_tokens)
        segment_starts, segment_lengths = self._segment_layout(total_tokens)
        num_segments = len(segment_starts)
        # A segment's first token has no context, so it goes raw into the header.
        segment_first_tokens = [int(tokens[start]) for start in segment_starts]
        print(f" Total tokens: {total_tokens}, segments: {num_segments}, parallel: {min(NUM_PARALLEL_SEGMENTS, num_segments)}")

        # Rows in the static batch; the final batch may leave some rows idle.
        batch_size = min(NUM_PARALLEL_SEGMENTS, num_segments)

        # Setup inference state for batched processing
        past_key_values = self.create_static_cache(device, total_tokens, batch_size=batch_size)
        static_input = torch.empty((batch_size, 1), device=device, dtype=torch.long)
        static_position = torch.zeros((1,), device=device, dtype=torch.long)

        # Pinned CPU buffer for probability transfer
        probs_cpu = torch.empty((batch_size, vocab_size), dtype=torch.float32, pin_memory=True)
        probs_np = probs_cpu.numpy()

        encoder = constriction.stream.queue.RangeEncoder()
        entropy_family = constriction.stream.model.Categorical(perfect=False)

        with torch.inference_mode():
            # Get first tokens for initial batch for warmup
            first_batch_tokens = segment_first_tokens[:batch_size]
            graph, static_probs = self.warmup_and_capture_graph(
                device, vocab_size, past_key_values,
                static_input, static_position, first_batch_tokens, batch_size=batch_size
            )

            start_time = time.time()
            num_coded_tokens = total_tokens - num_segments
            tokens_coded = 0

            # Process segments in batches
            num_batches = math.ceil(num_segments / NUM_PARALLEL_SEGMENTS)

            for batch_idx in range(num_batches):
                batch_start = batch_idx * NUM_PARALLEL_SEGMENTS
                batch_end = min(batch_start + NUM_PARALLEL_SEGMENTS, num_segments)
                active_segments = batch_end - batch_start

                # Get segment info for this batch
                batch_seg_starts = segment_starts[batch_start:batch_end]
                batch_seg_lengths = segment_lengths[batch_start:batch_end]
                batch_first_tokens = segment_first_tokens[batch_start:batch_end]
                max_seg_len = max(batch_seg_lengths)

                # Reset KV cache and set first tokens
                past_key_values.reset()
                static_position.zero_()

                # Fill input with first tokens; idle rows repeat row 0 as padding
                for i in range(batch_size):
                    if i < active_segments:
                        static_input[i, 0] = batch_first_tokens[i]
                    else:
                        static_input[i, 0] = batch_first_tokens[0]  # Padding

                for step in range(1, max_seg_len):
                    # Predict token `step` from the token at position step - 1.
                    static_position.fill_(step - 1)
                    graph.replay()

                    probs_cpu.copy_(static_probs, non_blocking=False)

                    # Only segments that still have a token at `step` are coded.
                    active_idx = []
                    symbols = []
                    for i in range(active_segments):
                        if step < batch_seg_lengths[i]:
                            seg_start = batch_seg_starts[i]
                            symbols.append(int(tokens[seg_start + step]))
                            active_idx.append(i)

                    if symbols:
                        symbols_np = np.array(symbols, dtype=np.int32)
                        probs_active = np.ascontiguousarray(probs_np[active_idx, :])
                        encoder.encode(symbols_np, entropy_family, probs_active)
                        tokens_coded += len(symbols)

                    # Feed the true tokens back as the next step's input.
                    for i in range(batch_size):
                        if i < active_segments and step < batch_seg_lengths[i]:
                            seg_start = batch_seg_starts[i]
                            static_input[i, 0] = int(tokens[seg_start + step])

                    # Progress update
                    if tokens_coded % 200 == 0:
                        print(
                            f"\r Compressing token {tokens_coded}/{num_coded_tokens} "
                            f"(batch {batch_idx + 1}/{num_batches}) | {get_gpu_memory_str()}",
                            end="",
                        )

            print("")

            compressed_data = encoder.get_compressed()

            with open(output_path, "wb") as file:
                self.write_header(file, total_tokens, segment_first_tokens)
                # Store the payload little-endian, matching the "<I" header
                # fields regardless of host byte order.
                file.write(np.asarray(compressed_data, dtype="<u4").tobytes())

            input_bytes = content.encode("utf-8")
            report_compression(
                Codec.LLM,
                time.time() - start_time,
                len(input_bytes),
                os.path.getsize(output_path),
                total_tokens,
                filename=filename,
            )

        if show_baselines:
            input_bytes = content.encode("utf-8")
            self.run_zstd_compress_baseline(input_bytes, filename=filename)
            self.run_gzip_compress_baseline(input_bytes, filename=filename)

    def decompress(self, input_path, output_path):
        """
        Decompress a file compressed with LLM-based arithmetic coding.

        Must run with the same model, dtype, ``max_seq_len`` and
        ``NUM_PARALLEL_SEGMENTS`` as ``compress``. Otherwise the probabilities,
        or the order in which symbols were interleaved, will differ. The text
        is written as UTF-8 without newline translation.

        Raises:
            ValueError: on a malformed header/payload, or if the header's
                segment count does not match ``max_seq_len``.
            RuntimeError: if the number of decoded tokens is wrong.
        """
        print(f"Starting decompression [{input_path}]...")
        set_determinism()
        self.load_model()

        device = self.model.device
        vocab_size = self.model.config.vocab_size

        # Read compressed file
        with open(input_path, "rb") as file:
            raw_data = file.read()

        total_tokens, segment_first_tokens, data_offset = self.read_header(raw_data)
        num_segments = len(segment_first_tokens)

        # Precompute segment info (same as compression) and check that it
        # agrees with the header; it won't if max_seq_len has changed.
        segment_starts, segment_lengths = self._segment_layout(total_tokens)
        if len(segment_starts) != num_segments:
            raise ValueError(
                f"Header lists {num_segments} segments, but {total_tokens} tokens "
                f"with max_seq_len={self.max_seq_len} give {len(segment_starts)}."
            )

        # Setup decoder: the payload is a run of little-endian uint32 words.
        payload = raw_data[data_offset:]
        if len(payload) % 4 != 0:
            raise ValueError("Corrupt payload: length is not a multiple of 4 bytes.")
        compressed_array = np.frombuffer(payload, dtype="<u4").astype(np.uint32)
        decoder = constriction.stream.queue.RangeDecoder(compressed_array)
        entropy_family = constriction.stream.model.Categorical(perfect=False)

        # Output array for decoded tokens
        decoded_tokens = np.empty(total_tokens, dtype=np.int32)

        # Write first tokens to their positions
        for seg_start, tok in zip(segment_starts, segment_first_tokens):
            decoded_tokens[seg_start] = tok

        # Batched inference setup
        batch_size = min(NUM_PARALLEL_SEGMENTS, num_segments)
        past_key_values = self.create_static_cache(device, total_tokens, batch_size=batch_size)
        static_input = torch.empty((batch_size, 1), device=device, dtype=torch.long)
        static_position = torch.zeros((1,), device=device, dtype=torch.long)

        probs_cpu = torch.empty((batch_size, vocab_size), dtype=torch.float32, pin_memory=True)
        probs_np = probs_cpu.numpy()

        with torch.inference_mode():
            # Warmup with first batch tokens
            first_batch_tokens = segment_first_tokens[:batch_size]
            graph, static_probs = self.warmup_and_capture_graph(
                device, vocab_size, past_key_values,
                static_input, static_position, first_batch_tokens, batch_size=batch_size
            )

            start_time = time.time()
            num_coded_tokens = total_tokens - num_segments
            tokens_decoded = 0

            # Process segments in batches
            num_batches = math.ceil(num_segments / NUM_PARALLEL_SEGMENTS)

            for batch_idx in range(num_batches):
                batch_start = batch_idx * NUM_PARALLEL_SEGMENTS
                batch_end = min(batch_start + NUM_PARALLEL_SEGMENTS, num_segments)
                active_segments = batch_end - batch_start

                # Get segment info for this batch
                batch_seg_starts = segment_starts[batch_start:batch_end]
                batch_seg_lengths = segment_lengths[batch_start:batch_end]
                batch_first_tokens = segment_first_tokens[batch_start:batch_end]
                max_seg_len = max(batch_seg_lengths)

                # Reset KV cache and set first tokens
                past_key_values.reset()
                static_position.zero_()

                # Fill input with first tokens; idle rows use the same padding as compress()
                for i in range(batch_size):
                    if i < active_segments:
                        static_input[i, 0] = batch_first_tokens[i]
                    else:
                        static_input[i, 0] = batch_first_tokens[0]

                for step in range(1, max_seg_len):
                    static_position.fill_(step - 1)
                    graph.replay()

                    probs_cpu.copy_(static_probs, non_blocking=False)

                    active_idx = []
                    for i in range(active_segments):
                        if step < batch_seg_lengths[i]:
                            active_idx.append(i)

                    if active_idx:
                        probs_active = np.ascontiguousarray(probs_np[active_idx, :])
                        decoded_step = decoder.decode(entropy_family, probs_active)

                        # Record each decoded token and feed it back as input.
                        for j, i in enumerate(active_idx):
                            pos = batch_seg_starts[i] + step
                            decoded_tokens[pos] = int(decoded_step[j])
                            static_input[i, 0] = int(decoded_step[j])

                        tokens_decoded += len(active_idx)

                    if tokens_decoded % 200 == 0:
                        print(
                            f"\r Decompressing token {tokens_decoded}/{num_coded_tokens} "
                            f"(batch {batch_idx + 1}/{num_batches}) | {get_gpu_memory_str()}",
                            end="",
                        )

            print("")

            if tokens_decoded != num_coded_tokens:
                raise RuntimeError(
                    f"Token count mismatch: expected {num_coded_tokens}, "
                    f"decoded {tokens_decoded}"
                )

            text = self.tokenizer.decode(
                decoded_tokens.tolist(), clean_up_tokenization_spaces=False
            )

            # newline="" writes the decoded text verbatim (no "\n" -> os.linesep).
            with open(output_path, "w", encoding="utf-8", newline="") as file:
                file.write(text)

            output_bytes = text.encode("utf-8")
            report_decompression(
                Codec.LLM,
                time.time() - start_time,
                len(raw_data),
                len(output_bytes),
                total_tokens,
            )

        self.run_zstd_decompress_baseline(output_bytes)
        self.run_gzip_decompress_baseline(output_bytes)

    def compress_directory(self, input_dir):
        """
        Compress every non-``.bin`` regular file in ``input_dir``.

        Outputs go to ``<input_dir>/compressed/<name>.bin``. A failure on one
        file is printed and the remaining files are still processed.

        Raises:
            ValueError: if ``input_dir`` is not a directory.
        """
        if not os.path.isdir(input_dir):
            raise ValueError(f"Not a directory: {input_dir}")

        output_dir = os.path.join(input_dir, "compressed")
        os.makedirs(output_dir, exist_ok=True)

        files = []
        for filename in sorted(os.listdir(input_dir)):
            filepath = os.path.join(input_dir, filename)
            if os.path.isfile(filepath) and not filename.endswith(".bin"):
                files.append(filepath)

        if not files:
            print(f"No files found in {input_dir}")
            return

        print(f"Found {len(files)} files to compress")

        self.load_model()

        for filepath in files:
            filename = os.path.basename(filepath)
            output_path = os.path.join(output_dir, filename + ".bin")

            try:
                self.compress(filepath, output_path, show_baselines=True)
            except Exception as e:
                # Best-effort batch: report and move on to the next file.
                print(f"Error compressing {filename}: {e}")
                continue

        print(f"\n{'='*60}")
        print(f"Compression complete. Output directory: {output_dir}")
        print(f"{'='*60}")

    def decompress_directory(self, input_dir):
        """
        Decompress every ``.bin`` file in ``input_dir``.

        ``<name>.bin`` is recovered to ``<input_dir>/recovered/<name>``. A
        failure on one file is printed and the remaining files are still
        processed.

        Raises:
            ValueError: if ``input_dir`` is not a directory.
        """
        if not os.path.isdir(input_dir):
            raise ValueError(f"Not a directory: {input_dir}")

        output_dir = os.path.join(input_dir, "recovered")
        os.makedirs(output_dir, exist_ok=True)

        # Find all .bin files
        files = []
        for filename in sorted(os.listdir(input_dir)):
            filepath = os.path.join(input_dir, filename)
            if os.path.isfile(filepath) and filename.endswith(".bin"):
                files.append(filepath)

        if not files:
            print(f"No .bin files found in {input_dir}")
            return

        print(f"Found {len(files)} files to decompress")

        self.load_model()

        for filepath in files:
            filename = os.path.basename(filepath)
            original_name = filename[:-4] if filename.endswith(".bin") else filename
            output_path = os.path.join(output_dir, original_name)

            try:
                self.decompress(filepath, output_path)
            except Exception as e:
                # Best-effort batch: report and move on to the next file.
                print(f"Error decompressing {filename}: {e}")
                continue

        print(f"\n{'='*60}")
        print(f"Decompression complete. Output directory: {output_dir}")
        print(f"{'='*60}")



if __name__ == "__main__":

    input_path = "ENTER FILENAME HERE" # For example, input_path = "test.txt"
    compressor = LLMCompressor()

    def files_match(original_path, recovered_path):
        """Return True when both files contain exactly the same bytes."""
        with open(original_path, "rb") as original, open(recovered_path, "rb") as recovered:
            return original.read() == recovered.read()

    if os.path.isdir(input_path):
        # Compress all files, then decompress all, then verify
        compressor.compress_directory(input_path)
        compressed_dir = os.path.join(input_path, "compressed")
        compressor.decompress_directory(compressed_dir)

        # Verify every original against its recovered copy, byte for byte.
        recovered_dir = os.path.join(compressed_dir, "recovered")
        print(f"\n{'='*60}")
        print("Verifying integrity...")
        print(f"{'='*60}")
        failures = 0
        for filename in sorted(os.listdir(input_path)):
            original_path = os.path.join(input_path, filename)
            # Same selection rule as compress_directory().
            if not os.path.isfile(original_path) or filename.endswith(".bin"):
                continue
            recovered_path = os.path.join(recovered_dir, filename)
            if not os.path.exists(recovered_path):
                status = "MISSING"
            elif files_match(original_path, recovered_path):
                status = "OK"
            else:
                status = "MISMATCH"
            if status != "OK":
                failures += 1
            print(f" {status:8} {filename}")
        print(f"{'='*60}")
        if failures:
            print(f"{failures} file(s) failed verification.")
        else:
            print("All files match.")

    elif os.path.isfile(input_path):
        compressed_path = input_path + ".bin"
        recovered_path = input_path + ".recovered"

        compressor.compress(input_path, compressed_path)
        compressor.decompress(compressed_path, recovered_path)

        if files_match(input_path, recovered_path):
            print("Verification: recovered file matches the original exactly.")
        else:
            print("Verification FAILED: recovered file differs from the original.")

    else:
        raise FileNotFoundError(f"Path not found: {input_path}")
