In [1]:
from typing import List

import torch

import bz2

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, LlamaForCausalLM

In [2]:
class LLMCompression:
    """Rank-based text compression driven by a causal language model.

    The text is tokenized and split into chunks of `context_size` tokens.
    Each chunk is prefixed with BOS, and every token is replaced by its rank
    in the model's next-token logits (0 = most likely token). Decoding
    replays the model and picks the token at the stored rank at every step.

    Attributes:
        llm_name: Hugging Face model identifier passed to `from_pretrained`.
        llm: The loaded causal LM (4-bit quantized, in eval mode).
        tokenizer: Tokenizer loaded for `llm_name`.
        context_size: Number of tokens per chunk (excluding the BOS prefix).
    """

    def __init__(self,
        llm_name: str,
        context_size: int,
    ):
        """Load the model and tokenizer for `llm_name` and store the chunk size."""
        self.llm_name = llm_name
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,  # Use bfloat16 for better performance
            bnb_4bit_use_double_quant=True,  # Double quantization for memory efficiency
        )
        self.llm = AutoModelForCausalLM.from_pretrained(
            llm_name,
            quantization_config=quant_config,
            device_map="auto"
        )
        self.llm.eval()
        self.tokenizer = AutoTokenizer.from_pretrained(llm_name)

        self.context_size = context_size

    def _pad(self, tokens):
        """Right-pad a 1-D token tensor with EOS ids to a multiple of `context_size`.

        Returns:
            `(padded_tokens, pad_len)` where `pad_len` is a plain `int`
            (0 when the length is already a multiple of `context_size`).
        """
        remainder = tokens.shape[0] % self.context_size
        if remainder == 0:
            # Must be the integer 0 (not a tensor) so that decode() can use it
            # to compute how many trailing tokens to strip.
            return tokens, 0
        pad_len = self.context_size - remainder

        pads = torch.full(
            [pad_len],
            self.tokenizer.eos_token_id,
            dtype=tokens.dtype,
            device=tokens.device,
        )
        padded_tokens = torch.cat([tokens, pads])

        return padded_tokens, pad_len

    def _get_rank(self, logits, token_ids):  # (..., V) (...)
        """Return the position of each `token_ids` entry in a stable descending sort of `logits`.

        Computed without sorting: rank = (#logits strictly greater) +
        (#logits equal to the target's logit at a smaller vocabulary index).
        The second term reproduces the tie order of a stable sort, which is
        what decode() uses to invert the rank.
        """
        selected_logits = logits.gather(-1, token_ids[..., None]).squeeze(-1)
        n_gt = (logits > selected_logits[..., None]).sum(-1)

        # Ties: in a stable descending sort, equal values keep index order,
        # so only equal logits at smaller vocabulary indices precede the target.
        eq = logits.eq(selected_logits[..., None])
        vocab_idx = torch.arange(logits.shape[-1], device=logits.device)
        before_target = vocab_idx < token_ids.unsqueeze(-1)
        n_eq = (eq & before_target).sum(-1)

        return n_gt + n_eq

    def argsort_solution(self, logits, targets):
        """Reference implementation of `_get_rank` via an explicit sort.

        Expects `logits` of shape (B, V) and `targets` of shape (B,). A stable
        sort is used so tied logits are ordered exactly as in `_get_rank` and
        `decode`.
        """
        order = torch.sort(logits, dim=-1, descending=True, stable=True)[1]
        return torch.where(order == targets[:, None])[1]

    @torch.no_grad()
    def encode(self, s):
        """Convert string `s` into per-token ranks.

        Returns:
            `(ranks, pad_len)`: `ranks` has shape (n_chunks, context_size) and
            the dtype of the token ids; `pad_len` is the integer number of EOS
            tokens appended to fill the last chunk.
        """
        encoded = self.tokenizer(s, return_tensors="pt")
        # (1, T) -> (T,). squeeze(0) keeps the tensor 1-D even when T == 1.
        tokens = encoded["input_ids"].squeeze(0)
        tokens = tokens.to(self.llm.device)

        # Drop the first token: assumes the tokenizer prepends BOS (a BOS is
        # re-added per chunk below). evaluate() would catch a violation.
        tokens, pad_len = self._pad(tokens[1:])
        if tokens.shape[0] == 0:
            # Nothing to encode; view(-1, context_size) is ambiguous on 0 elements.
            return tokens.new_empty((0, self.context_size)), pad_len
        tokens = tokens.view(-1, self.context_size)

        bos = torch.full(
            [tokens.shape[0]], self.tokenizer.bos_token_id, device=tokens.device
        ).unsqueeze(1)
        tokens = torch.cat((bos, tokens), 1)

        ranks = torch.empty_like(tokens[:, :-1])
        # The growing prefix is re-run at every step, with no KV cache, so that
        # the forward passes here have the same inputs as the ones decode()
        # issues; the stored ranks are only meaningful against identical logits.
        for idx in range(self.context_size):
            step = self.llm(tokens[:, :idx+1])
            ranks[:, idx] = self._get_rank(step.logits[:, -1, :], tokens[:, idx+1])

        return ranks, pad_len

    @torch.no_grad()
    def decode(self, rank: torch.Tensor, pad_len: int) -> str:
        """Reconstruct the text from ranks produced by `encode`.

        Args:
            rank: Integer tensor of ranks, either (n_chunks, context_size) as
                returned by `encode` or flattened as returned by `evaluate`.
            pad_len: Number of trailing padding tokens to drop (may be 0).

        Returns:
            The decoded string with special tokens removed.
        """
        if rank.numel() == 0:
            return ""
        rank = rank.reshape(-1, self.context_size)

        generated_ids = torch.full(
            (rank.shape[0], 1), self.tokenizer.bos_token_id, device=rank.device
        )

        for idx in range(self.context_size):
            # Same call pattern as encode(): full prefix, no KV cache.
            output = self.llm(generated_ids)

            logits = output.logits[:, -1, :]  # shape: (n_chunks, vocab)
            # Stable sort so ties are resolved exactly as in _get_rank.
            _, sorted_tokens = torch.sort(logits, descending=True, stable=True)

            next_token_id = sorted_tokens.gather(-1, rank[:, idx].unsqueeze(-1))

            generated_ids = torch.cat([generated_ids, next_token_id], dim=1)

        output = generated_ids[:, 1:].flatten()
        # Explicit end index: `output[:-pad_len]` would be empty for pad_len == 0.
        n_keep = output.shape[0] - pad_len
        return self.tokenizer.decode(output[:n_keep], skip_special_tokens=True)

    def evaluate(self, s):
        """Round-trip `s` through encode/decode and print a compression ratio.

        The printed value is the bz2-compressed size of the raw rank bytes as a
        percentage of the bz2-compressed size of the UTF-8 text.

        Returns:
            `(flattened_ranks, pad_len)`.

        Raises:
            AssertionError: if the decoded text differs from `s`.
        """
        rank, pad_len = self.encode(s)
        torch.cuda.empty_cache()

        s_hat = self.decode(rank, pad_len)
        assert s_hat == s, f"incorrect (de)-compression \n Expected: {s} \n Got: {s_hat}"

        compressed_s = bz2.compress(s.encode('utf-8'))
        _rank = rank.flatten()
        compressed_s_hat = bz2.compress(_rank.cpu().numpy().tobytes())

        # Get the size of the compressed data
        s_size = len(compressed_s)
        s_hat_size = len(compressed_s_hat)
        print(f"Compression ratio: {(s_hat_size / s_size)*100:.4f}")

        return _rank, pad_len

# Smoke test: build the compressor and round-trip a single short sentence.
# Instantiating LLMCompression loads the model and tokenizer via from_pretrained.
llm_zip = LLMCompression(
    llm_name="meta-llama/Llama-3.2-1B",
    context_size=8
)
s = "The quick brown fox jumps over the lazy dog."
# evaluate() asserts that decode(encode(s)) == s and prints the ratio of the
# bz2-compressed rank bytes to the bz2-compressed UTF-8 text (as a percentage).
# The returned (ranks, pad_len) pair is discarded here.
_, _ = llm_zip.evaluate(s)


Compression ratio: 90.1235


In [None]:
def _pad(tokens):
    """Right-pad a 1-D token tensor with EOS ids to a multiple of `context_size`.

    Reads the notebook-level globals `context_size` and `tokenizer`, which must
    be defined before this function is called.

    Returns:
        `(padded_tokens, pad_len)` where `pad_len` is a plain `int`
        (0 when the length is already a multiple of `context_size`).
    """
    remainder = tokens.shape[0] % context_size
    if remainder == 0:
        # Return the integer 0, not a zeros tensor, so pad_len has one type
        # on both paths and can be used directly in length arithmetic.
        return tokens, 0
    pad_len = context_size - remainder

    pads = torch.full([pad_len], tokenizer.eos_token_id, device=tokens.device)
    padded_tokens = torch.cat([tokens, pads])

    return padded_tokens, pad_len


# Scratch experiment: load the model directly and inspect the last-position
# logits produced while feeding a growing prefix of each chunk.
model_name = "meta-llama/Llama-3.2-1B"

# NOTE(review): quant_config is built but not used below, because the
# `quantization_config=` argument is commented out in from_pretrained.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
llm = LlamaForCausalLM.from_pretrained(
    model_name,
    # torch_dtype=torch.bfloat16,
    # quantization_config=quant_config,
    # device_map=torch.device("cuda"),
    device_map="auto"
)
llm.eval()

tokenizer = AutoTokenizer.from_pretrained(model_name)

s = "The quick brown fox jumps over the lazy dog."
# Chunk length; read as a global by _pad() defined at the top of this cell.
context_size = 8


# Tokenize, move to the model's device, and reshape into BOS-prefixed chunks.
tokens = tokenizer(s, return_tensors="pt")
tokens = tokens["input_ids"].squeeze()
tokens = tokens.to(llm.device)

# Drops the first token (presumably the tokenizer's BOS — confirm) before padding.
tokens, pad_len = _pad(tokens[1:])
tokens = tokens.view(-1, context_size)

# Prepend one BOS column so every chunk row starts with BOS.
bos = torch.full([tokens.shape[0]], tokenizer.bos_token_id, device=tokens.device).unsqueeze(1)
tokens = torch.cat((bos, tokens), 1)

# NOTE(review): past_key_values is assigned in the loop but never passed back
# to llm(...), and output_1 is never appended to (the append is commented out).
past_key_values = None
prev_logits = None
output_1 = []
for idx in range(context_size):
    # Full forward pass over the prefix tokens[:, :idx+1]; no cache is supplied.
    next_tokens = llm(tokens[:, :idx+1])
    # output_1.append(next_tokens.logits[:, -1, :])
    past_key_values = next_tokens.past_key_values

    # NOTE(review): this compares the last-position logits of the current
    # prefix with the last-position logits of the previous, shorter prefix,
    # i.e. logits for two different positions. If the goal was to compare
    # cached vs. uncached logits for the same position, the cached call is
    # missing — confirm intent.
    if idx != 0:
        print(torch.equal(next_tokens.logits[:, -1], prev_logits))
        print(((next_tokens.logits[:, -1] - prev_logits)**2).mean())
    
    prev_logits = next_tokens.logits[:, -1]