In [None]:
from typing import List

import cv2
import numpy as np
import matplotlib.pyplot as plt

import bz2

import torch

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, LlamaForCausalLM

In [None]:
class LLMCompression:

	def __init__(self,
		llm_name: str,
		context_size: int,
		patch_size: int, # context_window: int,
		color_sep: str="-",
		pixel_sep: str="|",
	):
		self.llm_name = llm_name
		self.llm = AutoModelForCausalLM.from_pretrained(
			llm_name,
			torch_dtype=torch.bfloat16,
			# load_in_4bit=True,
			# quantization_config=quantization_config,
			# device_map=torch.device("cuda"),
			device_map="auto"
		)
		self.llm.eval()
		for param in self.llm.parameters():
			param.requires_grad = False
		self.tokenizer = AutoTokenizer.from_pretrained(llm_name)

		self.context_size = context_size
		self.patch_size = patch_size
		self.color_sep = color_sep
		self.pixel_sep = pixel_sep
		self.words = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', self.color_sep, self.pixel_sep]
		self.tokens = self.tokenizer(self.words, add_special_tokens=False)
		self.tokens = [x for xs in self.tokens["input_ids"] for x in xs]

		self.word2token = {w: idx for w, idx in zip(self.words, self.tokens)}
		self.token2word = {idx: w for idx, w in zip(self.tokens, self.words)}
	
	def _pad(self, tokens):
		if tokens.shape[0] % self.context_size == 0:
			return tokens, torch.zeros(tokens.shape[0], device=tokens.device)
		pad_len = self.context_size - tokens.shape[0] % self.context_size

		pads = torch.full([pad_len], self.tokenizer.eos_token_id, device=tokens.device)
		padded_tokens = torch.cat([tokens, pads])

		return padded_tokens, pad_len
	
	# rank from "stable" sort
	def _get_rank(self, logits, token_ids):  # (B, N, V) (B, N)
		'''
			Find the rank of the tokens
			1/ The number of values less than token_ids
			2/ The number of values equal with token_ids
		'''
		# count the strictly the number of greater values
		selected_logits = logits.gather(-1, token_ids[..., None]).squeeze(-1)
		n_gt = (logits > selected_logits[..., None]).sum(-1)  # (B, N)

		# "mimic" stable sorting
		eq = logits.eq(selected_logits[..., None])  # (B, N, V)
		mask = torch.arange(logits.shape[-1], device=logits.device) < token_ids.unsqueeze(-1)
		n_eq = (eq*mask).sum(-1)

		return n_gt + n_eq

	def encode(self, s):
		tokens = self.tokenizer(s, return_tensors="pt")
		tokens = tokens["input_ids"].squeeze()
		tokens = tokens.to(self.llm.device)

		tokens, pad_len = self._pad(tokens[1:])
		tokens = tokens.view(-1, self.context_size)

		bos = torch.full([tokens.shape[0]], self.tokenizer.bos_token_id, device=tokens.device).unsqueeze(1)
		tokens = torch.cat((bos, tokens), 1)

		output = self.llm(tokens[:, :-1])
		logits = output.logits
		word_logits = logits[..., self.words]
		ranks = self._get_rank(word_logits, tokens[:, 1:])

		return ranks, pad_len
	
	def encode(self, img):
		p_size = self.patch_size
		patches = np.array([
			img[i*p_size:(i+1)*p_size, j*p_size:(j+1)*p_size, :].flatten()
			for i in range(img.shape[0]//p_size)
			for j in range(img.shape[1]//p_size)
		])
		tokens = self.tokenizer(s, return_tensors="pt")
		tokens = tokens["input_ids"].squeeze()
		tokens = tokens.to(self.llm.device)

		tokens, pad_len = self._pad(tokens[1:])
		tokens = tokens.view(-1, self.context_size)

		bos = torch.full([tokens.shape[0]], self.tokenizer.bos_token_id, device=tokens.device).unsqueeze(1)
		tokens = torch.cat((bos, tokens), 1)

		output = self.llm(tokens[:, :-1])
		logits = output.logits
		word_logits = logits[..., self.words]
		ranks = self._get_rank(word_logits, tokens[:, 1:])

		return ranks, pad_len


	def decode(self, rank: List[int], pad_len: int):
		generated_ids = torch.full((rank.shape[0], 1), self.tokenizer.bos_token_id, device=rank.device)
		
		past_key_values = None
		for idx in range(self.context_size):
			output = self.llm(generated_ids, past_key_values=past_key_values, use_cache=False)
			past_key_values = output.past_key_values

			logits = output.logits[:, -1, :]  # shape: (n_chunks, vocab)
			logits, sorted_tokens = torch.sort(logits, descending=True, stable=True)

			next_token_id = sorted_tokens.gather(-1, rank[:, idx].unsqueeze(-1))

			generated_ids = torch.cat([generated_ids, next_token_id], dim=1)

		output = generated_ids[:, 1:].flatten()
		return self.tokenizer.decode(output[:-pad_len], skip_special_tokens=True)

	def evaluate(self, s):
		rank, pad_len = self.encode(s)
		torch.cuda.empty_cache()

		s_hat = self.decode(rank, pad_len)
		assert s_hat == s, f"incorrect (de)-compression \n Expected: {s} \n Got: {s_hat}"

		compressed_s = bz2.compress(s.encode('utf-8'))
		_rank = rank.flatten()
		compressed_s_hat = bz2.compress(_rank.cpu().numpy().tobytes())

		# Get the size of the compressed data
		s_size = len(compressed_s)
		s_hat_size = len(compressed_s_hat)
		# print(s_hat_size, s_size)
		print(f"Compression ratio: {(s_hat_size / s_size)*100:.4f}")

		return _rank, pad_len
	
	def patch2tokens(self, patches):
		sequences = [
			f'{self.pixel_sep}'.join([
					f'{self.color_sep}'.join([
						str(num)
						for num in patch[i*3:(i+1)*3]
					])
				for i in range(len(patches)//3)
			])
			for patch in patches
		]
		return self.tokenizer(sequences, return_tensors="pt")

llm_zip = LLMCompression(
	# llm_name="meta-llama/Llama-3.2-1B",
	llm_name="unsloth/Llama-3.2-1B-bnb-4bit",
	context_size=256,
	patch_size=16,
)

img = cv2.imread("./ILSVRC2012_val_00003014.JPEG", cv2.IMREAD_COLOR)
img = cv2.resize(img, (224, 224))
# img = transform(img)

img.shape

(224, 224, 3)

In [3]:
p_size = 16
patches = np.array([
    img[i*p_size:(i+1)*p_size, j*p_size:(j+1)*p_size, :].flatten()
    for i in range(img.shape[0]//p_size)
    for j in range(img.shape[1]//p_size)
])
patches.shape

(196, 768)

In [4]:
sequences = [
    f'{llm_zip.pixel_sep}'.join([
            f'{llm_zip.color_sep}'.join([
                str(num)
                for num in patch[i*3:(i+1)*3]
            ])
        for i in range(len(patches)//3)
    ])
    for patch in patches
]
sequences[0]

'7-53-17|10-70-32|8-33-6|8-116-43|9-96-32|4-29-7|4-12-3|4-36-10|4-37-8|5-48-10|6-11-6|80-104-107|16-34-17|7-19-7|7-19-7|4-8-3|6-58-22|5-46-10|14-38-18|6-74-26|11-69-15|14-39-15|13-18-9|0-37-8|3-37-7|8-32-8|13-21-14|4-27-24|14-108-51|11-37-22|5-24-3|4-15-4|15-64-23|15-52-25|7-39-11|5-25-6|20-85-42|10-77-32|0-17-4|4-37-11|2-43-10|9-28-7|9-24-2|7-6-6|30-83-45|40-67-40|7-19-9|18-33-8|4-57-21|6-29-9|6-23-7|5-17-4|7-5-3|28-137-70|5-7-3|2-46-9|8-43-14|9-25-6|10-23-6|10-29-9|41-60-53|106-175-121|103-154-109|16-35-17|15-56-21'

In [9]:
abc = llm_zip.tokenizer(sequences, return_tensors="pt")
abc["input_ids"].shape

torch.Size([196, 390])

In [None]:
for sequence in sequences:
    tokens = llm_zip.tokenizer(sequence)
    break

In [None]:
for patch in patches:
    for i in range(len(patches)//3):
        s = "-".join([str(num) for num in patch[i*3:(i+1)*3]])
        break
    break

In [None]:
plt.imshow(img)
plt.show()

In [None]:
img.shape[0]//p_size, img.shape[1]//p_size

In [None]:
patches[0, :20]

In [None]:
_img = cv2.imread("./ILSVRC2012_val_00003014.JPEG", cv2.IMREAD_COLOR)
_img = cv2.resize(_img, (224, 224))
_img[0, 0], _img[0, 1], _img[0, 2], _img[0, 3], _img[1, 0], _img[1, 1]

In [None]:
%timeit llm_zip.encode(s)

In [None]:
%timeit llm_zip.decode(rank, pad_len)

In [None]:
words = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', '-', "|"]
tokens = llm_zip.tokenizer(words, add_special_tokens=False)
tokens = [x for xs in tokens["input_ids"] for x in xs]

word2token = {w: idx for w, idx in zip(words, tokens)}
token2word = {idx: w for idx, w in zip(tokens, words)}
word2token, token2word