In [1]:
from typing import List

import torch
import torch.nn.functional as F

from numpy import random

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

Network

In [29]:
# Checkpoint shared by the model and its tokenizer.
model_name = "meta-llama/Llama-3.2-1B"

# 4-bit weights with double quantization; matmuls are computed in bfloat16.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load the quantized causal LM and switch it to inference mode right away.
llm = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
).eval()

tokenizer = AutoTokenizer.from_pretrained(model_name)

Configuration

In [None]:
class LLMCompression:
    """Encode text as per-token ranks under a causal LM and reconstruct it.

    The text is tokenized, cut into chunks of ``context_size`` tokens (the last
    chunk is padded with EOS), and every chunk is prefixed with BOS. Each token
    is then replaced by its rank in the model's next-token logits, where rank 0
    is the largest logit and ties are ordered by ascending token id. Decoding
    replays the same model steps and picks the token with the stored rank.
    """

    def __init__(self,
        llm_name: str,
        context_size: int, # context_window: int,
    ):
        """Load the 4-bit quantized model and tokenizer named ``llm_name``.

        Args:
            llm_name: checkpoint name passed to ``from_pretrained``.
            context_size: number of text tokens per chunk (BOS excluded).

        Raises:
            ValueError: if ``context_size`` is smaller than 1.
        """
        if context_size < 1:
            raise ValueError("context_size must be a positive integer")

        self.llm_name = llm_name
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,  # Use bfloat16 for better performance
            bnb_4bit_use_double_quant=True,  # Double quantization for memory efficiency
        )
        self.llm = AutoModelForCausalLM.from_pretrained(
            llm_name,
            quantization_config=quant_config,
            # device_map=torch.device("cuda"),
            device_map="auto"
        )
        self.llm.eval()
        self.tokenizer = AutoTokenizer.from_pretrained(llm_name)

        self.context_size = context_size

    def _pad(self, tokens):
        """Right-pad a 1-D token tensor with EOS to a multiple of ``context_size``.

        Returns:
            ``(padded_tokens, pad_len)`` where ``pad_len`` is a plain ``int``
            (0 when the length is already a multiple of ``context_size``).
        """
        remainder = tokens.shape[0] % self.context_size
        if remainder == 0:
            # Nothing to add; pad_len must be the integer 0, not a tensor.
            return tokens, 0
        pad_len = self.context_size - remainder

        pads = torch.full(
            [pad_len], self.tokenizer.eos_token_id,
            dtype=tokens.dtype, device=tokens.device,
        )
        padded_tokens = torch.cat([tokens, pads])

        return padded_tokens, pad_len

    def _get_rank(self, logits, token_ids):
        """Rank of ``token_ids`` in ``logits`` under a stable descending sort.

        Args:
            logits: ``(..., vocab)`` scores.
            token_ids: ``(...)`` integer ids indexing the last axis of ``logits``.

        Returns:
            Integer tensor shaped like ``token_ids``.
        """
        # Number of entries strictly greater than the selected logit.
        selected_logits = logits.gather(-1, token_ids[..., None])
        n_gt = (logits > selected_logits).sum(-1)

        # Mimic stable sorting: equal logits with a smaller id come first.
        eq = (logits == selected_logits)
        mask = torch.arange(logits.shape[-1], device=logits.device) < token_ids[..., None]
        n_eq = (eq & mask).sum(-1)

        return n_gt + n_eq

    def argsort_solution(self, logits, targets):
        """Reference rank computation via a full stable sort.

        Gives the same result as ``_get_rank`` for ``(batch, vocab)`` logits and
        ``(batch,)`` targets; ``stable=True`` keeps tie order in line with
        ``decode``.
        """
        sort = torch.argsort(-logits, dim=-1, stable=True)
        return torch.where(sort == targets[:, None])[1]

    @torch.no_grad()
    def encode(self, s):
        """Convert text ``s`` into a ``(n_chunks, context_size)`` tensor of ranks.

        Returns:
            ``(ranks, pad_len)``; ``pad_len`` is the number of EOS tokens that
            were appended to fill the last chunk.
        """
        # Index the batch axis explicitly: a bare squeeze() would also drop the
        # sequence axis when the encoding holds a single token.
        tokens = self.tokenizer(s, return_tensors="pt")["input_ids"][0]
        tokens = tokens.to(self.llm.device)

        # The first token is dropped (assumes the tokenizer prepends BOS); a BOS
        # is re-added in front of every chunk below.
        tokens, pad_len = self._pad(tokens[1:])
        if tokens.shape[0] == 0:
            # Empty text: no chunks to encode.
            return tokens.new_empty((0, self.context_size)), pad_len
        tokens = tokens.view(-1, self.context_size)

        bos = torch.full(
            (tokens.shape[0], 1), self.tokenizer.bos_token_id,
            dtype=tokens.dtype, device=tokens.device,
        )
        tokens = torch.cat((bos, tokens), 1)

        ranks = torch.empty_like(tokens[:, :-1])
        past_key_values = None
        for idx in range(self.context_size):
            # With a KV cache only the newest token is fed; re-sending the whole
            # prefix would append it to the cache a second time. decode() runs
            # the same step-by-step schedule so both sides see the same logits.
            step = self.llm(tokens[:, idx:idx+1], past_key_values=past_key_values)
            past_key_values = step.past_key_values

            logits = step.logits[:, -1, :]
            targets = tokens[:, idx+1].to(logits.device)
            ranks[:, idx] = self._get_rank(logits, targets)

        return ranks, pad_len

    @torch.no_grad()
    def decode(self, rank: torch.Tensor, pad_len: int):
        """Rebuild the text from the ranks produced by ``encode``.

        Args:
            rank: ``(n_chunks, context_size)`` integer tensor of ranks.
            pad_len: number of trailing padding tokens to discard.

        Returns:
            The decoded string (empty string when there are no chunks).
        """
        if rank.shape[0] == 0:
            return ""

        rank = rank.to(device=self.llm.device, dtype=torch.long)
        generated_ids = torch.full(
            (rank.shape[0], 1), self.tokenizer.bos_token_id,
            dtype=torch.long, device=rank.device,
        )

        past_key_values = None
        for idx in range(self.context_size):
            # Feed only the newest token; earlier ones live in the KV cache.
            output = self.llm(generated_ids[:, -1:], past_key_values=past_key_values)
            past_key_values = output.past_key_values

            logits = output.logits[:, -1, :]  # shape: (n_chunks, vocab)
            # Stable sort so ties resolve exactly as _get_rank counted them.
            sorted_tokens = torch.sort(logits, descending=True, stable=True).indices

            wanted = rank[:, idx].unsqueeze(-1).to(sorted_tokens.device)
            next_token_id = sorted_tokens.gather(-1, wanted)

            generated_ids = torch.cat(
                [generated_ids, next_token_id.to(generated_ids.device)], dim=1
            )

        output = generated_ids[:, 1:].flatten()
        # Slice by explicit length: output[:-0] would drop every token.
        n_keep = output.shape[0] - int(pad_len)
        return self.tokenizer.decode(output[:n_keep], skip_special_tokens=True)

    def evaluate(self, s):
        """Round-trip ``s`` through ``encode`` and ``decode``; return the result."""
        rank, pad_len = self.encode(s)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        s_hat = self.decode(rank, pad_len)
        return s_hat



In [50]:
# Sample sentence for the round-trip check.
s = "The quick brown fox jumps over the lazy dog."

# Compressor using chunks of 8 tokens.
llm_zip = LLMCompression(llm_name="meta-llama/Llama-3.2-1B", context_size=8)
llm_zip.evaluate(s)

The quick brown fox jumps over the lazy dog.
The quick brown fox jumps over the lazy dog.


In [None]:
# Text used by the experiments below.
s = "The quick brown fox jumps over the lazy dog."

# Alternative inputs: colon-separated random numbers.
# s = ":".join(str(x) for x in random.rand(25).tolist())
# s = ":".join(str(x) for x in random.randint(0, 5000, (50,)).tolist())

print("String length:", len(s))

Encoding

In [46]:
def pad(tokens, padding_val, context_size=None):
    """Right-pad a 1-D token tensor to a multiple of the context size.

    Args:
        tokens: 1-D tensor of token ids.
        padding_val: value appended as padding.
        context_size: chunk length; defaults to the global ``CONTEXT_SIZE``.

    Returns:
        ``(padded_tokens, pad_len)`` where ``pad_len`` is a plain ``int``
        (0 when no padding was required).
    """
    if context_size is None:
        context_size = CONTEXT_SIZE

    remainder = tokens.shape[0] % context_size
    if remainder == 0:
        # pad_len must be the integer 0 here, not a tensor of zeros.
        return tokens, 0
    pad_len = context_size - remainder

    pads = torch.full([pad_len], padding_val, dtype=tokens.dtype, device=tokens.device)
    padded_tokens = torch.cat([tokens, pads])

    return padded_tokens, pad_len

def text_to_tokens(text):
    """Tokenize ``text`` with the global ``tokenizer`` and return 1-D token ids."""
    tokens = tokenizer(text, return_tensors="pt")
    # Drop only the batch axis; a bare squeeze() would also remove the sequence
    # axis when the encoding holds a single token.
    return tokens["input_ids"].squeeze(0)

def get_rank(logits, indices):
    """Position of each indexed entry in a stable descending sort of ``logits``.

    ``logits`` is ``(batch, vocab)`` and ``indices`` is ``(batch,)``. The rank
    counts the entries that are strictly larger, plus the entries that are equal
    and sit at a smaller position.
    """
    target_logit = torch.gather(logits, -1, indices.unsqueeze(-1))

    # Entries strictly above the target.
    strictly_greater = logits > target_logit

    # Ties located before the target position (what a stable sort puts first).
    positions = torch.arange(logits.shape[-1], device=logits.device)
    precedes_target = positions.unsqueeze(0) < indices.unsqueeze(1)
    tied_before = (logits == target_logit) & precedes_target

    return strictly_greater.sum(-1) + tied_before.sum(-1)

def argsort_solution(logits, targets):
    """Rank of each target in ``logits`` via a full descending sort.

    ``logits`` is ``(batch, vocab)`` and ``targets`` is ``(batch,)``. The sort is
    stable so that tied logits are ordered by ascending index, matching the
    stable sort used when decoding.
    """
    sort = torch.argsort(-logits, dim=-1, stable=True)
    return torch.where(sort == targets[:, None])[1]

def get_token_by_rank(logits, ranks):
    """Return the index that sits at position ``ranks`` in a stable descending sort.

    Inverse of ``get_rank``: ``logits`` is ``(..., vocab)`` and ``ranks`` is an
    integer tensor of shape ``(...)``; the result has the shape of ``ranks``.
    """
    order = torch.sort(logits, dim=-1, descending=True, stable=True).indices
    return order.gather(-1, ranks[..., None]).squeeze(-1)

# Number of text tokens per chunk (BOS not counted).
CONTEXT_SIZE = 8

# Text to compress.
s = "The quick brown fox jumps over the lazy dog."

# Alternative inputs: colon-separated random numbers.
# s = ":".join(str(x) for x in random.rand(25).tolist())
# s = ":".join(str(x) for x in random.randint(0, 5000, (50,)).tolist())

print("String length:", len(s))

tokens = text_to_tokens(s)

# Drop the first token (assumed to be the BOS added by the tokenizer), pad the
# rest with EOS to a multiple of CONTEXT_SIZE and split it into chunks.
tokens, pad_len = pad(tokens[1:], tokenizer.eos_token_id)
tokens = tokens.view(-1, CONTEXT_SIZE)

# Every chunk starts from its own BOS token.
bos = torch.full([tokens.shape[0]], tokenizer.bos_token_id, device=tokens.device).unsqueeze(1)
tokens = torch.cat((bos, tokens), 1)

ranks = torch.empty_like(tokens[:, :-1])
device = llm.device  # follow the model's placement instead of hard-coding CUDA

# Inference only: without no_grad the cached keys/values keep an autograd graph
# alive across the whole loop.
with torch.no_grad():
    past_key_values = None
    for idx in range(CONTEXT_SIZE):
        # NOTE(review): the whole prefix is sent together with past_key_values.
        # The decoding loops in this notebook do the same, so the ranks stay
        # consistent; consider feeding only the newest token in both places.
        next_tokens = llm(tokens[:, :idx+1].to(device), past_key_values=past_key_values)
        past_key_values = next_tokens.past_key_values

        step_logits = next_tokens.logits[:, -1, :]
        rank = get_rank(step_logits, tokens[:, idx+1].to(step_logits.device))
        ranks[:, idx] = rank

torch.cuda.empty_cache()
print(tokens.shape, ranks.shape)
# One BOS per chunk; tokens are appended column by column below.
generated_ids = torch.tensor([[tokenizer.bos_token_id]]*ranks.shape[0], device=llm.device)
ranks_on_device = ranks.to(generated_ids.device)

with torch.no_grad():
    past_key_values = None
    for idx in range(CONTEXT_SIZE):
        # `top_k` is a sampling option, not a forward() argument, so it is not
        # passed to the model call.
        output = llm(generated_ids, past_key_values=past_key_values)
        past_key_values = output.past_key_values

        logits = output.logits[:, -1, :]  # shape: (n_chunks, vocab)
        # Stable sort so ties are ordered exactly as get_rank counted them.
        logits, sorted_tokens = torch.sort(logits, descending=True, stable=True)

        wanted = ranks_on_device[:, idx].unsqueeze(-1).to(sorted_tokens.device)
        next_token_id = sorted_tokens.gather(-1, wanted)

        generated_ids = torch.cat([generated_ids, next_token_id.to(generated_ids.device)], dim=1)
output = generated_ids[:, 1:].flatten()
# Slice by explicit length: output[:-pad_len] would be empty when pad_len == 0.
generated_text = tokenizer.decode(output[:output.shape[0] - pad_len], skip_special_tokens=True)
print("Final generated sequence:")
print(generated_text)
print(s)

String length: 44
torch.Size([2, 9]) torch.Size([2, 8])
Final generated sequence:
The quick brown fox jumps over the lazy dog.
The quick brown fox jumps over the lazy dog.


Decoding

In [None]:
# One BOS row per encoded chunk; the shape is shown as a quick sanity check.
bos_rows = [[tokenizer.bos_token_id] for _ in range(ranks.shape[0])]
input_ids = torch.tensor(bos_rows, device=tokens.device)
input_ids.shape

In [None]:
# Start every chunk from BOS and rebuild it token by token from the ranks.
input_ids = torch.tensor([[tokenizer.bos_token_id]]*ranks.shape[0], device=llm.device)
ranks_on_device = ranks.to(input_ids.device)

with torch.no_grad():
    past_key_values = None
    for idx in range(CONTEXT_SIZE):
        print(f'\r{idx}/{CONTEXT_SIZE}', end='')
        # `top_k` is a sampling option, not a forward() argument, so it is not
        # passed to the model call.
        output = llm(input_ids, past_key_values=past_key_values)
        past_key_values = output.past_key_values

        logits = output.logits[:, -1, :]  # shape: (n_chunks, vocab)
        # The ranks were computed with stable tie-breaking (see get_rank), so the
        # sort must be stable too or tied logits can map to the wrong token.
        logits, sorted_tokens = torch.sort(logits, descending=True, stable=True)

        wanted = ranks_on_device[:, idx].unsqueeze(-1).to(sorted_tokens.device)
        next_token_id = sorted_tokens.gather(-1, wanted)

        input_ids = torch.cat([input_ids, next_token_id.to(input_ids.device)], dim=1)
input_ids

In [None]:
# Field-by-field numeric difference between the decoded and the original text.
# Both strings are split on ":" and every field is parsed as a float.
decoded_fields = generated_text.split(":")
original_fields = s.split(":")
for decoded_field, original_field in zip(decoded_fields, original_fields):
    print(float(decoded_field) - float(original_field))

In [None]:
# Reproducible random fixture: 4000 rows of 50000 scores, each row scaled to
# sum to 1, plus one target index per row.
torch.manual_seed(44)
probas = torch.rand(4000, 50000)
probas.div_(probas.sum(dim=1, keepdim=True))
targets = torch.randint(0, 50000, (4000,))

def argsort_solution(x, targets):
    """Rank of each target column in ``x`` using a full stable sort.

    ``x`` is ``(rows, cols)`` and ``targets`` is ``(rows,)``; larger values get
    smaller ranks and ties keep their original column order.
    """
    order = torch.argsort(torch.neg(x), dim=1, stable=True)
    hits = order == targets.unsqueeze(1)
    # Column coordinate of every hit, i.e. where the target landed in the order.
    return hits.nonzero(as_tuple=True)[1]

def get_rank(x, indices):
    """Rank of ``x[row, indices[row]]`` in a stable descending sort of each row.

    Computed by counting instead of sorting: entries strictly larger than the
    target, plus equal entries located at a smaller column.
    """
    target = x.gather(-1, indices.unsqueeze(-1))

    # Entries strictly above the target value.
    larger = x > target

    # Equal entries that a stable sort would place before the target.
    columns = torch.arange(x.shape[-1])
    left_of_target = columns.unsqueeze(0) < indices.unsqueeze(1)
    tied_left = (x == target) & left_of_target

    return larger.sum(-1) + tied_left.sum(-1)

# Compare the sort-based and the counting-based ranks; print only the rows
# where they disagree (no output means the two agree everywhere).
a = argsort_solution(probas, targets)
b = get_rank(probas, targets)

differs = a != b
for x, y in zip(a[differs], b[differs]):
    print(x, y, x - y)

In [None]:
# How many targets equal the largest sampled index.
targets.eq(targets.max()).sum()

In [None]:
# Display the sampled target indices.
targets

In [None]:
# Worked example of the counting-based rank on a tiny tensor with ties.
v = torch.tensor([
    [4, 3, 5, 4, 7],
    [4, 4, 5, 7, 4],
])
idx = torch.tensor([3, 4])

# Value sitting at the queried column of each row.
val = v[torch.arange(v.shape[0]), idx]
print("idx", idx, "val", val)

# Entries strictly larger than the queried value.
gt = v.gt(val.unsqueeze(1)).sum(dim=-1)
print("gt", gt)

# Entries equal to the queried value.
eq = v.eq(val.unsqueeze(1))
print("eq", eq.long())

# Columns located before the queried column.
mask = torch.arange(v.shape[-1])[None, :] < idx[:, None]
print(mask.int())

# Ties that a stable sort would place ahead of the queried entry.
n_eq = (eq & mask).sum(dim=-1)
print("n_eq", n_eq)

rank = gt + n_eq
print("rank", rank)