In [None]:
from __future__ import annotations
import os
import json
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Tuple

# def compression_experiment(
#     model_name: str = "meta-llama/Meta-Llama-3.1-8B",
#     text: str = "",
#     seed_length: int = 10,
#     max_new_tokens: int = 1000
# ):
#     """See if we can use a language model to try and compress text reasonably well."""
#     # Load model and tokenizer
# ---- Experiment configuration -------------------------------------------------
# Greedy-prediction "compression": keep the first `seed_length` tokens verbatim, let the
# model greedily predict each following token, and record only the positions it gets wrong.
model_name = "meta-llama/Meta-Llama-3.1-8B"
seed_length = 40  # number of leading tokens stored verbatim (never predicted)
print(f"Loading model and tokenizer: {model_name}")
# CUDA_VISIBLE_DEVICES (set before importing torch) exposes a single GPU, which torch calls cuda:0.
device = "cuda:0"
# Generate text based on the seed
# print("Generating text from seed...")
# with torch.no_grad():
#     output = model.generate(
#         seed_tokens,
#         max_new_tokens=min(max_new_tokens, full_encoding.shape[1] - seed_length),
#         do_sample=False  # Use greedy decoding for deterministic output
#     )


# Example usage
# Decode explicitly instead of relying on the locale-dependent default encoding, which can
# raise UnicodeDecodeError (or silently mis-decode) on some machines. The file is presumably
# UTF-8 (Project Gutenberg's usual plain-text encoding) — adjust if the source differs.
with open("gutenberg_book.txt", "r", encoding="utf-8") as file:
    text = file.read()
# compression_experiment(
#     text=example_text,
#     seed_length=20,
#     max_new_tokens=200
# )

In [None]:
# Load the tokenizer and the fp16 model straight onto the single visible GPU.
# Passing `device_map=device` (rather than device_map="auto" followed by `.to(device)`) means
# the model is never moved after accelerate has placed it: with "auto", layers that do not fit
# may be offloaded to CPU/disk, and moving such a dispatched model with `.to()` is unsupported.
# The next cell asserts every parameter is on cuda:0, so one explicit device is the intent.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map=device)

# Tokenize the entire text (return_tensors="pt" yields a batch holding one sequence).
full_encoding = tokenizer.encode(text, return_tensors="pt").to(model.device)

# The seed (first `seed_length` tokens) is stored verbatim, so the text must be longer than it.
if seed_length >= full_encoding.shape[1]:
    raise ValueError(f"Seed length ({seed_length}) must be less than the total number of tokens ({full_encoding.shape[1]})")

# Upper bound on how many tokens after the seed the model is asked to predict.
max_new_tokens = 999
seed_tokens = full_encoding[:, :seed_length]
# Slicing clamps at the end of the sequence, so this holds at most `max_new_tokens` tokens.
target_tokens = full_encoding[:, seed_length : seed_length + max_new_tokens]
print("seed tokens: ", seed_tokens.shape, seed_tokens.device)
print("target tokens: ", target_tokens.shape, target_tokens.device)
print("model.device: ", model.device)

In [None]:
import tqdm
import einops
import io
import contextlib
from typing import Dict
# Sanity checks: the whole model must live on one GPU and seed/target must be one sequence.
devices = {pn: str(sx.device) for pn, sx in model.named_parameters()}
print(json.dumps(devices, indent=4))
assert set(devices.values()) == {"cuda:0"}
assert seed_tokens.ndim == target_tokens.ndim <= 2, f"seed_tokens.ndim: {seed_tokens.ndim}, target_tokens.ndim: {target_tokens.ndim}" # fmt: skip
assert (seed_tokens.ndim == 1) or (seed_tokens.shape[0] == target_tokens.shape[0] == 1), f"seed_tokens.shape[0]: {seed_tokens.shape[0]}, target_tokens.shape[0]: {target_tokens.shape[0]}" # fmt: skip
seed_tokens = seed_tokens.flatten()
target_tokens = target_tokens.flatten()
# First position the model has to predict (everything before it is the verbatim seed).
running_amt = len(seed_tokens)
zero_pad_start = torch.zeros_like(seed_tokens)
zero_pad_end = torch.zeros_like(target_tokens)
# Both buffers have shape (1, len(seed) + len(target)):
#   seed_tokens_padded   = [seed | 0 ... 0]    <- filled in left-to-right as we "decode"
#   target_tokens_padded = [0 ... 0 | target]  <- ground truth for the predicted positions
seed_tokens_padded = einops.rearrange(torch.cat([seed_tokens, zero_pad_end]), "b -> 1 b") # fmt: skip
target_tokens_padded = einops.rearrange(torch.cat([zero_pad_start, target_tokens]), "b -> 1 b") # fmt: skip
print(seed_tokens_padded.shape, target_tokens_padded.shape, "from", seed_tokens.shape, target_tokens.shape) # fmt: skip
assert seed_tokens_padded.device == torch.device("cuda:0")
assert target_tokens_padded.device == torch.device("cuda:0")
assert seed_tokens_padded.shape == target_tokens_padded.shape
# model(seed_tokens).logits.shape # batch seq vocab
# Position -> true token, for every position where the greedy prediction was wrong.
correction_on_running_amt: Dict[int, torch.Tensor] = {}
assert target_tokens_padded.shape[0] == 1, f"target_tokens_padded.shape[0]: {target_tokens_padded.shape[0]}" # fmt: skip
assert running_amt < target_tokens_padded.shape[1], f"running_amt: {running_amt}, target_tokens_padded.shape[1]: {target_tokens_padded.shape[1]}" # fmt: skip
# Inference only: disable autograd so each forward pass does not build (and then discard) a
# gradient graph over the whole prefix. NOTE: every step re-runs the model on the full prefix
# (no KV cache), so total work grows at least quadratically with the sequence length.
with torch.no_grad():
    for i in tqdm.trange(running_amt, target_tokens_padded.shape[1], desc=f"Running inference... run from [{running_amt}, {target_tokens_padded.shape[1]})"): # fmt: skip
        with contextlib.redirect_stdout(io.StringIO()):
            # Greedy next-token prediction from the (already corrected) prefix [0, i).
            prediction = model(seed_tokens_padded[:, :i]).logits[:, -1, :].argmax(dim=-1)
            if (prediction == target_tokens_padded[0, i]).all():
                seed_tokens_padded[0, i] = prediction
            else:
                # Record the miss and teacher-force the true token so later steps see the real text.
                correction_on_running_amt[i] = target_tokens_padded[0, i]
                seed_tokens_padded[0, i] = target_tokens_padded[0, i]
            assert (seed_tokens_padded[0, i] == target_tokens_padded[0, i]).all()
print("Total number wrong: ", len(correction_on_running_amt))
# Stored size = verbatim seed + 2 values (position, token) per correction; ratio = total / stored.
print("Hypthetical compression ratio: ", 1 / ((len(seed_tokens.flatten()) + 2 * len(correction_on_running_amt)) / len(seed_tokens_padded.flatten()))) # ignoring model since we assume that will go to zero

In [None]:
"""Try one other dataset"""
from datasets import load_dataset
# Same experiment on a different distribution: the first 1000 questions of
# facebook/natural_reasoning, joined into one long document.
fb_natural_reasoning_hf = load_dataset("facebook/natural_reasoning", split="train").select(range(1000)) # fmt: skip
questions = [x["question"] for x in fb_natural_reasoning_hf]
questions_mega_str = "\n\n".join(questions)
questions_tok = tokenizer.encode(questions_mega_str, return_tensors="pt").to(model.device)
# Keep the seed plus up to `max_new_tokens` predicted tokens, the same budget as the first
# experiment (which slices `seed_length : seed_length + max_new_tokens`), so the two
# compression ratios are measured over comparable lengths. Slicing clamps at the end.
questions_tok = questions_tok[:, : seed_length + max_new_tokens]
assert isinstance(questions_tok, torch.Tensor), f"questions_tok: {type(questions_tok)}"
assert questions_tok.ndim == 2, f"questions_tok.ndim: {questions_tok.ndim}"
assert questions_tok.shape[0] == 1, f"questions_tok.shape[0]: {questions_tok.shape[0]}"
# Same guard as for the book: the seed is stored verbatim, so there must be tokens left to predict.
if seed_length >= questions_tok.shape[1]:
    raise ValueError(f"Seed length ({seed_length}) must be less than the total number of tokens ({questions_tok.shape[1]})")
print("questions_tok.shape: ", questions_tok.shape)
# Now we will just use the strategy from above:
#   questions_tok_seed = [seed | 0 ... 0]   (filled in left-to-right as we "decode")
#   questions_tok_tgt  = [0 ... 0 | rest]   (ground truth for the predicted positions)
questions_tok_seed = torch.zeros_like(questions_tok)
questions_tok_tgt = torch.zeros_like(questions_tok)
questions_tok_seed[:, :seed_length] = questions_tok[:, :seed_length]
questions_tok_tgt[:, seed_length:] = questions_tok[:, seed_length:]
# Position -> true token, for every position where the greedy prediction was wrong.
correction_on_running_amt: Dict[int, torch.Tensor] = {}
# Inference only: no autograd graphs for the per-prefix forward passes.
with torch.no_grad():
    for i in tqdm.trange(seed_length, questions_tok_tgt.shape[1], desc=f"Running inference... run from [{seed_length}, {questions_tok_tgt.shape[1]})"): # fmt: skip
        with contextlib.redirect_stdout(io.StringIO()):
            prediction = model(questions_tok_seed[:, :i]).logits[:, -1, :].argmax(dim=-1)
            if (prediction == questions_tok_tgt[0, i]).all():
                questions_tok_seed[0, i] = prediction
            else:
                # Record the miss and teacher-force the true token.
                correction_on_running_amt[i] = questions_tok_tgt[0, i]
                questions_tok_seed[0, i] = questions_tok_tgt[0, i]
            assert (questions_tok_seed[0, i] == questions_tok_tgt[0, i]).all()
print("Total number wrong: ", len(correction_on_running_amt))
# Stored size = the verbatim seed only (not the full padded buffer) + 2 values (position, token)
# per correction — the same accounting as the first experiment; ratio = total / stored.
n_total = questions_tok_seed.shape[1]
n_stored = seed_length + 2 * len(correction_on_running_amt)
print("Hypthetical compression ratio: ", n_total / n_stored) # ignoring model since we assume that will go to zero

In [None]:
"""
Try recursive compression LOL

Language modeling is compression: https://arxiv.org/abs/2309.10668

It seems like 10:1 might be possible with `https://bellard.org/nncp/` for language on
enwik8

conclusion is this is not super good ngl (probably need a better algorithm/model if this
is to work at all)
"""
# wrong_str = "".join([f"{k}{v}" for k, v in correction_on_running_amt.items()]) # <---- can you predict all the data in a dumb way?
# Serialize the mispredicted positions (dict keys, inserted in increasing order) as "i1,i2,...".
wrong_str = ",".join(str(k) for k in correction_on_running_amt) # <----- can you predict just the indices?
wrong_str_tok = tokenizer.encode(wrong_str, return_tensors="pt").to(model.device)
print(wrong_str_tok.shape)
# Same layout as the experiments above: seed = [prefix | 0...0], tgt = [0...0 | rest].
wrong_str_seed = torch.zeros_like(wrong_str_tok)
wrong_str_seed[:, :seed_length] = wrong_str_tok[:, :seed_length]
wrong_str_tgt = torch.zeros_like(wrong_str_tok)
wrong_str_tgt[:, seed_length:] = wrong_str_tok[:, seed_length:]
n_wrong = 0
# Inference only: no autograd graphs for the per-prefix forward passes.
with torch.no_grad():
    for i in tqdm.trange(seed_length, wrong_str_tgt.shape[1], desc=f"Running inference... run from [{seed_length}, {wrong_str_tgt.shape[1]})"): # fmt: skip
        with contextlib.redirect_stdout(io.StringIO()):
            prediction = model(wrong_str_seed[:, :i]).logits[:, -1, :].argmax(dim=-1)
            if (prediction == wrong_str_tgt[0, i]).all():
                wrong_str_seed[0, i] = prediction
            else:
                n_wrong += 1
                # Teacher-force the true token, as the experiments above do; otherwise every
                # later prediction would condition on a placeholder 0 instead of the real prefix.
                wrong_str_seed[0, i] = wrong_str_tgt[0, i]
print(f"Total number wrong: {n_wrong}")
# Only positions after the seed are predicted, so they form the denominator. The guard covers
# the case where the serialized string is no longer than the seed (nothing was predicted).
n_predicted = max(wrong_str_tgt.shape[1] - seed_length, 0)
wrong_rate = n_wrong / n_predicted if n_predicted else 0.0
print(f"Percent wrong: n_wrong/n_predicted = {100 * wrong_rate:.2f}%")