In [None]:
from typing import List

import torch
import torch.nn.functional as F

from numpy import random
import bz2

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

Load the quantized Llama model and tokenizer

In [None]:
model_name = "meta-llama/Llama-3.2-1B"

# 4-bit weights with bfloat16 compute; double quantization for memory efficiency.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
llm = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
).eval()  # inference only; eval() returns the model itself

Rank-based LLM compressor

In [None]:
class LLMCompression:

    def __init__(self,
        llm_name: str,
        context_size: int, # context_window: int,
    ):
        self.llm_name = llm_name
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,  # Use bfloat16 for better performance
            bnb_4bit_use_double_quant=True,  # Double quantization for memory efficiency
        )
        self.llm = AutoModelForCausalLM.from_pretrained(
            llm_name,
            quantization_config=quant_config,
            # device_map=torch.device("cuda"),
            device_map="auto"
        )
        self.llm.eval()
        self.tokenizer = AutoTokenizer.from_pretrained(llm_name)

        self.context_size = context_size
    
    def _pad(self, tokens):
        if tokens.shape[0] % self.context_size == 0:
            return tokens, torch.zeros(tokens.shape[0], device=tokens.device)
        pad_len = self.context_size - tokens.shape[0] % self.context_size

        pads = torch.full([pad_len], self.tokenizer.eos_token_id, device=tokens.device)
        padded_tokens = torch.cat([tokens, pads])

        return padded_tokens, pad_len
    
    def _get_rank(self, logits, token_ids):
        # count the strictly the number of greater values
        selected_logits = logits.gather(-1, token_ids[..., None]).squeeze(-1)
        n_gt = (logits > selected_logits[..., None]).sum(-1)

        # "mimic" stable sorting
        eq = (logits == selected_logits[..., None])
        mask = torch.arange(logits.shape[-1], device=logits.device).unsqueeze(0) < token_ids.unsqueeze(1)
        n_eq = (eq*mask).sum(-1)

        return n_gt + n_eq
    
    def argsort_solution(self, logits, targets):
        sort = torch.argsort(-logits, -1)
        return torch.where(sort == targets[:, None])[1]

    @torch.no_grad
    def encode(self, s):
        tokens = self.tokenizer(s, return_tensors="pt")
        tokens = tokens["input_ids"].squeeze()
        tokens = tokens.to(self.llm.device)

        tokens, pad_len = self._pad(tokens[1:])
        tokens = tokens.view(-1, self.context_size)

        bos = torch.full([tokens.shape[0]], self.tokenizer.bos_token_id, device=tokens.device).unsqueeze(1)
        tokens = torch.cat((bos, tokens), 1)

        ranks = torch.empty_like(tokens[:, :-1])
        past_key_values = None
        for idx in range(self.context_size):
            next_tokens = self.llm(tokens[:, :idx+1], past_key_values=past_key_values)
            past_key_values = next_tokens.past_key_values
            
            rank = self.argsort_solution(next_tokens.logits[:, -1, :], tokens[:, idx+1])
            ranks[:, idx] = rank

        return ranks, pad_len

    @torch.no_grad
    def decode(self, rank: List[int], pad_len: int):
        generated_ids = torch.full((rank.shape[0], 1), tokenizer.bos_token_id, device=rank.device)
        
        past_key_values = None
        for idx in range(self.context_size):
            output = self.llm(generated_ids, past_key_values=past_key_values)
            past_key_values = output.past_key_values

            logits = output.logits[:, -1, :]  # shape: (n_chunks, vocab)
            logits, sorted_tokens = torch.sort(logits, descending=True, stable=True)

            next_token_id = sorted_tokens.gather(-1, rank[:, idx].unsqueeze(-1))

            generated_ids = torch.cat([generated_ids, next_token_id], dim=1)

        output = generated_ids[:, 1:].flatten()
        return tokenizer.decode(output[:-pad_len], skip_special_tokens=True)

    def evaluate(self, s):
        rank, pad_len = self.encode(s)
        torch.cuda.empty_cache()

        s_hat = self.decode(rank, pad_len)
        assert s_hat == s, "incorrect (de)-compression"

        compressed_s = bz2.compress(s.encode('utf-8'))
        _rank = rank.flatten()
        print(_rank)
        compressed_s_hat = bz2.compress(_rank.cpu().numpy().tobytes())

        # Get the size of the compressed data
        s_size = len(compressed_s)
        s_hat_size = len(compressed_s_hat)
        print(s_hat_size, s_size)
        print(f"Compression ratio is: {(s_hat_size / s_size)*100:.4f}")

        return _rank, pad_len



In [None]:
llm_zip = LLMCompression(
    llm_name="meta-llama/Llama-3.2-1B",
    context_size=256,
)

# Benchmark input: 50 random integers in [0, 5000) joined with ":".
# (The natural-language sample text is assigned in a later cell.)
# Seeding numpy's global RNG keeps this string, and the printed ratio, reproducible.
random.seed(0)
s = ":".join(str(x) for x in random.randint(0, 5000, (50,)).tolist())
rank, pad_len = llm_zip.evaluate(s)

In [None]:
# Two-paragraph sample text. Adjacent string literals are concatenated; the
# explicit "\n" keeps the paragraph break, since a "..." literal cannot span a
# raw newline (that is a SyntaxError).
s = (
    "The rapid advancement of technology has dramatically reshaped the way humans live, work, and interact with the world. In just a few decades, society has transitioned from relying on traditional forms of communication, such as letters and landline telephones, to an era dominated by smartphones, social media, and artificial intelligence. This transformation has brought numerous benefits, making information more accessible, improving efficiency in various industries, and enhancing global connectivity. However, it has also introduced new challenges, including privacy concerns, the digital divide, and the potential for job displacement due to automation. One of the most significant changes driven by technology is the way people communicate. In the past, communication was often slow and limited to physical mail, face-to-face conversations, or expensive long-distance phone calls. Today, instant messaging, video conferencing, and social media platforms allow individuals to stay connected regardless of geographic location. This has strengthened personal relationships, enabled remote work opportunities, and facilitated the exchange of ideas on a global scale. However, the convenience of digital communication has also led to a decline in face-to-face interactions, raising concerns about its impact on social skills and mental health. Moreover, the rise of misinformation and the spread of fake news through digital platforms pose a significant challenge in today's interconnected world. The ability to share information instantly means that false narratives can gain traction quickly, influencing public opinion and even political outcomes. While technology companies have implemented algorithms and fact-checking mechanisms to combat misinformation, the responsibility ultimately lies with users to critically evaluate the information they consume and share. \n"
    "The integration of artificial intelligence and automation has also transformed various industries, improving productivity and efficiency. In healthcare, AI-powered diagnostic tools assist doctors in identifying diseases more accurately, while robotic surgeries enable precision procedures. In the business sector, automation streamlines supply chains, enhances customer service through chatbots, and improves decision-making with data-driven insights. Despite these advantages, the increasing reliance on technology raises concerns about job displacement, as automation continues to replace human workers in certain roles. This shift necessitates a focus on reskilling and upskilling workers to prepare them for the evolving job market. Education systems must adapt to equip students with the skills needed for the digital age, including proficiency in coding, data analysis, and critical thinking. Additionally, ethical considerations surrounding artificial intelligence must be addressed, ensuring that AI systems are developed and used responsibly. Cybersecurity is another pressing issue in the digital era. With the rise of online transactions, cloud computing, and interconnected devices, cyber threats have become more sophisticated. Data breaches, hacking attempts, and identity theft pose risks to individuals and organizations alike. As a result, cybersecurity measures must continually evolve to protect sensitive information and maintain trust in digital platforms. While technology has undoubtedly improved many aspects of life, it is essential to strike a balance between embracing innovation and addressing its challenges. Responsible use, ethical considerations, and continued education will play a crucial role in shaping a future where technology serves humanity in a positive and sustainable manner."
)


In [None]:
# Recorded output of an earlier run (bz2 size of the ranks, bz2 size of the text, ratio in %):
# 1287 1677
# Compression ratio is: 76.7442

In [None]:

# Re-encode the current `s`. The `rank` from the earlier evaluate cell belongs
# to the random-number string, not to the sample text assigned afterwards, so
# comparing against it would mix two different inputs.
rank, pad_len = llm_zip.encode(s)

compressed_s = bz2.compress(s.encode('utf-8'))
compressed_s_hat = bz2.compress(rank.cpu().numpy().tobytes())

print(len(compressed_s_hat), len(compressed_s))
print(f"Compression ratio is: {(len(compressed_s_hat) / len(compressed_s))*100:.4f}")

In [None]:
# Inspect the rank tensor currently bound to `rank`.
rank

In [None]:
# Short sample sentence (a pangram) for quick experiments.
s = "The quick brown fox jumps over the lazy dog."

print(f"String length: {len(s)}")

Step-by-step encoding (with a decoding check)

In [None]:
def pad(tokens, padding_val, context_size=None):
    """Right-pad a 1-D token tensor up to a multiple of the chunk size.

    Args:
        tokens: 1-D tensor of token ids.
        padding_val: id written into the padding positions.
        context_size: chunk length; defaults to the global ``CONTEXT_SIZE``.

    Returns:
        ``(padded_tokens, pad_len)``. ``pad_len`` is a Python int, 0 when the
        length is already a multiple. Strip the padding with
        ``x[:len(x) - pad_len]``: ``x[:-pad_len]`` is empty when ``pad_len == 0``.
    """
    if context_size is None:
        context_size = CONTEXT_SIZE
    pad_len = -tokens.shape[0] % context_size
    if pad_len == 0:
        return tokens, 0

    pads = torch.full([pad_len], padding_val, dtype=tokens.dtype, device=tokens.device)
    return torch.cat([tokens, pads]), pad_len

def text_to_tokens(text):
    """Tokenize ``text`` with the global ``tokenizer`` and return a 1-D tensor of ids.

    Takes the batch row with ``[0]`` rather than ``squeeze()``, so a one-token
    result (for example, when the tokenizer adds only a BOS to an empty string)
    stays 1-D and can still be sliced with ``tokens[1:]``.
    """
    return tokenizer(text, return_tensors="pt")["input_ids"][0]

def get_rank(logits, indices):
    """Position of ``indices`` in a stable descending sort of 2-D ``logits``, computed without sorting.

    rank = #(logits strictly greater than the target's logit)
         + #(logits equal to it at a smaller vocab index)   # stable tie-break
    """
    target = torch.gather(logits, -1, indices.unsqueeze(-1))  # (rows, 1)
    n_greater = torch.sum(logits > target, dim=-1)

    vocab_pos = torch.arange(logits.shape[-1], device=logits.device)
    tied_before = (logits == target) & (vocab_pos.unsqueeze(0) < indices.unsqueeze(1))
    return n_greater + torch.sum(tied_before, dim=-1)

def argsort_solution(logits, targets):
    """Rank of ``targets[i]`` within row ``i`` of 2-D ``logits`` via a full stable descending argsort.

    ``stable=True`` breaks ties by lower vocab index. That matches ``get_rank``
    and the ``torch.sort(..., stable=True)`` used when decoding; an unstable
    sort may order tied logits differently from the decoder.
    """
    order = torch.argsort(logits, dim=-1, descending=True, stable=True)
    return torch.where(order == targets[:, None])[1]

def get_token_by_rank(logits, ranks):
    """Inverse of ``get_rank``: token id at position ``ranks[i]`` of row ``i``'s stable descending order.

    Args:
        logits: ``(n, vocab)`` scores.
        ranks: ``(n,)`` integer ranks.

    Returns:
        ``(n,)`` tensor of token ids.
    """
    order = torch.argsort(logits, dim=-1, descending=True, stable=True)
    return order.gather(-1, ranks.unsqueeze(-1)).squeeze(-1)

CONTEXT_SIZE = 8
s = "The quick brown fox jumps over the lazy dog."

print("String length:", len(s))

# Run on the model's device instead of hard-coding .cuda().
device = llm.device

tokens = text_to_tokens(s).to(device)
# Real text tokens (the tokenizer's leading BOS excluded). Used to strip the
# padding after decoding: output[:-pad_len] would be empty when pad_len == 0.
n_text_tokens = tokens.numel() - 1

tokens, pad_len = pad(tokens[1:], tokenizer.eos_token_id)
tokens = tokens.view(-1, CONTEXT_SIZE)

bos = torch.full([tokens.shape[0]], tokenizer.bos_token_id, device=tokens.device).unsqueeze(1)
tokens = torch.cat((bos, tokens), 1)

# NOTE(review): each step passes the whole prefix *and* the KV cache, so the
# cache accumulates duplicated tokens. Both decoding loops (below and in the
# "Decoding" section) do the same, which keeps them consistent with these
# ranks. Switching to feeding only the newest token must be done in all of
# them together.
ranks = torch.empty_like(tokens[:, :-1])
with torch.no_grad():  # inference only: don't build an autograd graph
    past_key_values = None
    for idx in range(CONTEXT_SIZE):
        next_tokens = llm(tokens[:, :idx+1], past_key_values=past_key_values)
        past_key_values = next_tokens.past_key_values
        ranks[:, idx] = get_rank(next_tokens.logits[:, -1, :], tokens[:, idx+1])

torch.cuda.empty_cache()
print(tokens.shape, ranks.shape)
generated_ids = torch.full((ranks.shape[0], 1), tokenizer.bos_token_id, device=device)

with torch.no_grad():
    past_key_values = None
    for idx in range(CONTEXT_SIZE):
        output = llm(generated_ids, past_key_values=past_key_values)
        past_key_values = output.past_key_values

        logits = output.logits[:, -1, :]  # shape: (n_chunks, vocab)
        # stable=True reproduces get_rank's tie-break (lower vocab index first).
        logits, sorted_tokens = torch.sort(logits, descending=True, stable=True)

        next_token_id = sorted_tokens.gather(-1, ranks[:, idx].unsqueeze(-1))
        generated_ids = torch.cat([generated_ids, next_token_id], dim=1)

output = generated_ids[:, 1:].flatten()
generated_text = tokenizer.decode(output[:n_text_tokens], skip_special_tokens=True)
print("Final generated sequence:")
print(generated_text)
print(s)

Decoding

In [None]:
# One BOS start token per chunk (one row per row of `ranks`).
input_ids = torch.tensor(
    [[tokenizer.bos_token_id] for _ in range(ranks.shape[0])],
    device=tokens.device,
)
input_ids.shape

In [None]:
# Decode `ranks` back into token ids, one position per step for all chunks at once.
input_ids = torch.tensor([[tokenizer.bos_token_id]]*ranks.shape[0], device=llm.device)
ranks_dev = ranks.to(llm.device)

with torch.no_grad():
    past_key_values = None
    for idx in range(CONTEXT_SIZE):
        print(f'\r{idx}/{CONTEXT_SIZE}', end='')
        # NOTE(review): feeds the full prefix together with the cache, mirroring
        # the encoding loop so the logits (and hence the ranks) line up.
        output = llm(input_ids, past_key_values=past_key_values)
        past_key_values = output.past_key_values

        logits = output.logits[:, -1, :]  # shape: (n_chunks, vocab)
        # stable=True: get_rank breaks ties by lower vocab index, so the decoder must too.
        logits, sorted_tokens = torch.sort(logits, descending=True, stable=True)

        next_token_id = sorted_tokens.gather(-1, ranks_dev[:, idx].unsqueeze(-1))

        input_ids = torch.cat([input_ids, next_token_id], dim=1)
input_ids

In [None]:
# Compare the decoded text with the original piece by piece. The pieces are
# numbers only when `s` is the ":"-joined random-number string. For other
# text (e.g. the sample sentence) float() raises, so differing pieces are
# reported verbatim instead.
decoded_parts = generated_text.split(":")
original_parts = s.split(":")
if len(decoded_parts) != len(original_parts):
    print(f"piece count differs: {len(decoded_parts)} decoded vs {len(original_parts)} original")
for a, b in zip(decoded_parts, original_parts):
    try:
        print(float(a) - float(b))
    except ValueError:
        if a != b:
            print(f"mismatch: {a!r} != {b!r}")

In [None]:
# Random row-normalised scores used to cross-check the two rank implementations below.
torch.manual_seed(44)
probas = torch.rand(4000, 50000)
probas.div_(probas.sum(dim=1, keepdim=True))  # each row now sums to 1
targets = torch.randint(0, 50000, (4000,))

def argsort_solution(x, targets):
    """Reference rank: position of ``targets[i]`` in row ``i`` of ``x`` under a stable descending sort."""
    order = torch.argsort(-x, dim=1, stable=True)
    hits = order == targets.unsqueeze(1)
    return hits.nonzero(as_tuple=True)[1]

def get_rank(x, indices):
    """Position of ``indices`` in a stable descending sort of 2-D ``x``, computed without sorting.

    rank = #(values strictly greater than the target's value)
         + #(values equal to it at a smaller column index)   # stable tie-break

    ``arange`` is created on ``x.device``. This definition replaces the earlier
    ``get_rank`` in the notebook namespace, so it must also work for CUDA inputs.
    """
    vals = x.gather(-1, indices[..., None])  # (rows, 1)
    n_gt = (x > vals).sum(-1)

    col = torch.arange(x.shape[-1], device=x.device)
    n_eq = ((x == vals) & (col.unsqueeze(0) < indices.unsqueeze(1))).sum(-1)

    return n_gt + n_eq

a = argsort_solution(probas, targets)
b = get_rank(probas, targets)

# Print every row where the two rank implementations disagree; no output means they agree.
for x, y in zip(a, b):
    if x == y:
        continue
    print(x, y, x - y)

In [None]:
# Number of sampled targets equal to the largest sampled target id.
(targets == targets.max()).sum()

In [None]:
# Inspect the sampled target ids.
targets

In [None]:
# Worked example of the sort-free stable rank:
# rank = (#strictly greater) + (#equal values at a lower index).
v = torch.tensor([
    [4, 3, 5, 4, 7],
    [4, 4, 5, 7, 4],
])
idx = torch.tensor([3, 4])
val = torch.gather(v, 1, idx.unsqueeze(1)).squeeze(1)
print("idx", idx, "val", val)

gt = torch.sum(v > val.unsqueeze(1), dim=-1)
print("gt", gt)

eq = v == val.unsqueeze(1)
print("eq", eq.long())

mask = torch.arange(v.shape[-1])[None, :] < idx[:, None]
print(mask.int())
n_eq = torch.sum(eq & mask, dim=-1)
print("n_eq", n_eq)

rank = gt + n_eq
print("rank", rank)