# Iterative Prompt Compression

To summarize a text $\mathbf x$, we want a compressed version $\mathbf x'$ such
that $P_{LM}(\mathbf x \mid \mathbf x')$ is extremely high.

We can use a large LLM (e.g., GPT-4) to propose progressively shorter versions
$\mathbf x'$ of the text. At each iteration we keep the candidate that, when
used as the prompt, maximizes $P_{LM}(\mathbf x \mid \mathbf x')$.

Acknowledgements: Discussion with Dr. Alessandro Achille, Prof. Stefano Soatto 
at AWS research.

In [42]:
# import box
import os
import re
import datetime
import copy
from tqdm import tqdm 

import numpy as np
import torch 
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM

import anthropic

In [43]:
# Notebook-wide configuration.

# Hugging Face model id handed to from_pretrained() in the model-loading cell.
# Alternative that was tried: "meta-llama/Llama-2-7b-hf"
HF_MODEL = "gpt2"
# Whether the model-loading cell should target Apple's 'mps' device.
MPS = True

# Paths of the files that hold the API keys.
ANTHROPIC_KEY_PATH = "ANTHROPIC_KEY.txt"
OAI_KEY_PATH = "OAI_KEY.txt"

In [44]:
# Read the Anthropic API key from ANTHROPIC_KEY_PATH. strip() removes
# surrounding whitespace such as a trailing newline.
with open(ANTHROPIC_KEY_PATH, 'r', encoding='utf-8') as f:
    anthropic_key = f.read().strip()

# Fail here, with a clear message, rather than on the first API request.
if not anthropic_key:
    raise ValueError(f"{ANTHROPIC_KEY_PATH} is empty; expected an Anthropic API key.")

# The key is passed explicitly; without api_key the SDK would default to
# os.environ.get("ANTHROPIC_API_KEY").
client = anthropic.Anthropic(
    api_key=anthropic_key,
)

In [45]:
# Load the scoring language model and its tokenizer.
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL)
model = AutoModelForCausalLM.from_pretrained(HF_MODEL)
model = model.eval()  # evaluation mode; the model is only used for scoring

# Default device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Honour the MPS flag only when the backend is actually usable; forcing 'mps'
# unconditionally makes model.to() fail on machines without it.
if MPS and torch.backends.mps.is_available():
    device = torch.device("mps")

model = model.to(device)

In [46]:
# Re-apply the device chosen in the model-loading cell instead of hard-coding
# 'mps', so this cell also runs on machines without the MPS backend.
model = model.to(device)

In [47]:
# Sample passage used as the original text x throughout the notebook.
text_to_compress = """Mathematics
Mathematics is an area of knowledge that includes the topics of numbers,
formulas and related structures, shapes and the spaces in which they are
contained, and quantities and their changes. These topics are represented in
modern mathematics with the major subdisciplines of number theory,[1]
algebra,[2] geometry,[1] and analysis,[3] respectively. There is no general
consensus among mathematicians about a common definition for their academic
discipline.
"""

In [48]:
# Base system prompt for a single compression request. It contains no `{}`
# placeholders, so compress_text's system_prompt_.format(**kwargs) leaves it
# unchanged. It asks for the answer to be wrapped in <c>...</c>, which is what
# compress_text extracts from the reply.
system_prompt = """Hello Claude this is Aman. I'm building this system to
compress text with small ~7b param language models. For text x, you're gonna
produce a compressed version such that P(x | x') is maximized while
minimizing the length of x'. You can use any prompting strategies you want. The
user will give you the text x and you will respond with the compressed version.
Note that the compressed version must be smaller than the input.


REMEMBER TO DELIMIT YOUR COMPRESSED RESPONSE WITH <c> AND </c>!!

"""

In [49]:
class _MissingCompressedSpanError(ValueError, IndexError):
    """Raised when Claude's reply has no usable ``<c>...</c>`` span.

    It is a ValueError because the reply is malformed, and it also stays an
    IndexError so code written against the previous behaviour (a bare
    ``matches[0]`` / ``content[0]`` failure) keeps catching it.
    """


def compress_text(text_to_compress, client, system_prompt_, n=1, **kwargs):
    """Ask Claude (claude-3-haiku-20240307) for a compressed version of a text.

    The system prompt is built as ``system_prompt_.format(**kwargs)`` and
    ``text_to_compress`` is sent as the single user message. The first
    ``<c>...</c>`` span of the reply is returned.

    Args:
        text_to_compress: str, sent verbatim as the user message.
        client: anthropic.Anthropic() (any object exposing a compatible
            ``messages.create`` works).
        system_prompt_: str, guide for compression (includes the <c> </c>
            spec); may contain ``{name}`` placeholders.
        n: unused; kept so existing call sites keep working.
        **kwargs: values for the placeholders in ``system_prompt_``.

    Returns:
        str: the text between the first ``<c>`` and ``</c>`` (newlines
        included), with every ``{`` and ``}`` character removed.

    Raises:
        KeyError: a placeholder in ``system_prompt_`` has no value in kwargs.
        ValueError: the reply has no content block or no ``<c>...</c>`` span
            (the exception is also an IndexError for backward compatibility).
    """
    system_prompt = system_prompt_.format(**kwargs)
    _message = client.messages.create(
        model="claude-3-haiku-20240307",
        max_tokens=1000,
        temperature=1,
        system=system_prompt,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": text_to_compress
                    }
                ]
            }
        ]
    )

    content = _message.content
    if not content:
        raise _MissingCompressedSpanError(
            "Claude returned a reply without any content block."
        )
    text = content[0].text

    # DOTALL lets the span cover several lines; the lazy `.*?` stops at the
    # first closing tag, so only the first span is used.
    match = re.search(r"<c>(.*?)</c>", text, re.DOTALL)
    if match is None:
        raise _MissingCompressedSpanError(
            f"Claude's reply contains no <c>...</c> span: {text!r}"
        )

    # remove any {, } characters
    return match.group(1).replace("{", "").replace("}", "")

# One-shot check: compress the sample text with the base system prompt.
# compress_text already returns a str, so no extra string conversion is needed.
compressed = compress_text(text_to_compress, client, system_prompt)
compressed

'Mathematics is the study of numbers, formulas, shapes, and quantities. The main areas are number theory, algebra, geometry, and analysis.'

In [50]:
def score_compressed(original, compressed, model, tokenizer):
    """Score how well ``compressed`` lets ``model`` predict ``original``.

    The model is run on the token sequence [compressed tokens][original
    tokens]. The labels are -100 at every compressed position and the token
    ids at the original positions, so only the original's tokens contribute
    to the loss. The result is therefore -log P(x | x') in the form of the
    loss reported by the model (for Hugging Face causal LMs: the mean
    cross-entropy over the non-ignored label positions); lower is better.

    Args:
        original: str, the text x to be predicted.
        compressed: str, the candidate prompt x' placed before x.
        model: HF transformer (causal LM that accepts ``input_ids`` and
            ``labels`` and exposes ``.device``).
        tokenizer: HF tokenizer.

    Returns:
        float: the loss of ``original`` conditioned on ``compressed``.

    Raises:
        ValueError: ``original`` tokenizes to nothing, so there is no target
            to score (the loss would otherwise be NaN).
    """
    # The original is appended mid-sequence, so it must not carry its own
    # special tokens: tokenizers that prepend BOS would otherwise insert a
    # BOS between x' and x and make it a prediction target.
    x_ids = tokenizer(original, add_special_tokens=False).input_ids  # list[int]
    x_p_ids = tokenizer(compressed).input_ids  # list[int]
    if not x_ids:
        raise ValueError("`original` produced no tokens; nothing to score.")

    # -100 marks positions that are excluded from the loss.
    label_ids = [-100] * len(x_p_ids) + x_ids  # list[int]
    input_ids = x_p_ids + x_ids  # list[int]

    input_dict = {
        'input_ids': torch.tensor(input_ids).unsqueeze(0).to(model.device),
        'labels': torch.tensor(label_ids).unsqueeze(0).to(model.device),
    }
    # Forward pass only; no gradients are needed for scoring.
    with torch.no_grad():
        loss = model(**input_dict).loss
    return loss.item()


# Score the one-shot compression. Each string gets a plain-text header in
# front of it before being handed to the scorer.
score_compressed(
    original="\n\nUNCOMPRESSED VERSION: \n\n" + text_to_compress,
    compressed="\n\nCOMPRESSED VERSION: \n\n" + compressed,
    model=model,
    tokenizer=tokenizer,
)
                 

3.6254160404205322

In [51]:
def _request_candidate(original, client, prompt, prompt_kwargs, max_len):
    """Ask Claude for one candidate and clip it to ``max_len`` characters.

    Returns ``(candidate, raw_len)``; ``raw_len`` is the length before clipping.
    """
    candidate = compress_text(original, client, prompt, **prompt_kwargs)
    return candidate[:max_len], len(candidate)


def evolve_compressed_prompt(original,
                             client,
                             model,
                             tokenizer,
                             prompt,
                             prompt_kwargs,
                             max_compressed_len=0.3,
                             num_evolutions=10,
                             pool_size=10,
                             kill_frac=0.3,
                             out_dir = 'results/exp1/'):
    """Greedily evolve a pool of compressed versions of ``original``.

    First ``pool_size`` candidates are requested from Claude and scored with
    ``score_compressed``. Then, ``num_evolutions`` times, a uniformly random
    pool member is shown to Claude as ``best_prompt``, the new candidate is
    scored, and it replaces the pool member with the highest loss (even when
    the new candidate's own loss is higher). Candidates longer than the
    character budget are cut to it.

    After every evolution the whole history so far is saved with
    ``torch.save`` to ``<out_dir>/pool_hist_iter_<i>.pt``.

    args:
        original: str, original text to compress.
        client: anthropic.Anthropic()
        model: HF transformer
        tokenizer: HF tokenizer
        prompt: str, system prompt for compression to be formatted with prompt_kwargs
        prompt_kwargs: dict{str: str}, kwargs with which to format prompt.
            Must contain 'best_prompt'. If it contains 'max_compressed_len'
            and/or 'original_len', those are filled in with the character
            budget and ``len(original)``. The dict is copied, so the caller's
            object is left untouched.
        max_compressed_len: int for max compressed length (chars), or float for fraction of original. Default=0.3
        num_evolutions: number of rounds of Claude calls to optimize pool
        pool_size: pool of prompts to keep (>= 1)
        kill_frac: fraction of pool to kill off each round. Default=0.3.
            Currently unused: exactly one member is replaced per round.
        out_dir: directory that receives the per-iteration history files.

    returns:
        (best, pool_hist): ``best`` is the compressed str with the lowest
        loss in the final pool; ``pool_hist`` maps ``'pool_<i>'`` to a deep
        copy of the pool (list of {'compressed': str, 'loss': float}) after
        evolution ``i``.

    raises:
        ValueError: 'best_prompt' is missing from prompt_kwargs, pool_size is
            smaller than 1, or max_compressed_len is out of range.
    """
    # Validate before any paid API call is made.
    if 'best_prompt' not in prompt_kwargs:
        raise ValueError("The 'best_prompt' key is missing in prompt_kwargs. Please provide a prompt that supports evolution.")
    if pool_size < 1:
        raise ValueError(f"pool_size must be at least 1, got {pool_size}.")

    print("Length of original: ", len(original))
    if isinstance(max_compressed_len, float):
        if not 0 <= max_compressed_len <= 1:
            raise ValueError(f"A float max_compressed_len must lie in [0, 1], got {max_compressed_len}.")
        max_compressed_len = round(max_compressed_len * len(original))
    if not 0 < max_compressed_len <= len(original):
        raise ValueError(
            f"max_compressed_len must resolve to a value in (0, {len(original)}], got {max_compressed_len}."
        )

    # Work on a copy: the dict is rewritten below and on every evolution.
    prompt_kwargs = dict(prompt_kwargs)
    if 'max_compressed_len' in prompt_kwargs:
        prompt_kwargs['max_compressed_len'] = max_compressed_len
    if 'original_len' in prompt_kwargs:
        # Without this the prompt would keep whatever placeholder was passed.
        prompt_kwargs['original_len'] = len(original)

    # Initial pool: list[{'compressed': str, 'loss': float}].
    pool = []
    for i in range(pool_size):
        compressed_i, raw_len = _request_candidate(
            original, client, prompt, prompt_kwargs, max_compressed_len)
        print("Received response from Claude!")

        if raw_len > max_compressed_len:
            print(f"[WARN] Trimming compressed text from Claude from len = {raw_len} to max_compressed_len = {max_compressed_len}...")

        loss = score_compressed(original=original,
                                compressed=compressed_i,
                                model=model,
                                tokenizer=tokenizer)

        pool.append({'compressed': compressed_i, 'loss': loss})
        print(f"Compressed {i} length: {len(compressed_i)}, loss {loss}")

    # Greedy evolution: random parent in, worst member out.
    if num_evolutions > 0 and out_dir:
        os.makedirs(out_dir, exist_ok=True)
    pool_hist = {}
    with tqdm(range(num_evolutions), desc='[first loop]') as pbar:
        for i in pbar:
            # Uniform parent selection. (Loss-weighted sampling with
            # probabilities proportional to exp(-loss) was tried and is
            # disabled.)
            idx = np.random.choice(len(pool))
            parent_loss = pool[idx]['loss']
            prompt_kwargs['best_prompt'] = pool[idx]['compressed']

            new_compressed, _ = _request_candidate(
                original, client, prompt, prompt_kwargs, max_compressed_len)
            new_loss = score_compressed(original=original,
                                        compressed=new_compressed,
                                        model=model,
                                        tokenizer=tokenizer)

            # Replace the worst-performing prompt in the pool with the new compressed text
            worst_idx = np.argmax([p['loss'] for p in pool])
            pool[worst_idx] = {'compressed': new_compressed, 'loss': new_loss}

            # Snapshot the pool; deepcopy keeps later replacements from
            # altering earlier snapshots.
            pool_hist[f'pool_{i}'] = copy.deepcopy(pool)

            # Each file holds the full history up to and including round i.
            out_path = os.path.join(out_dir, f'pool_hist_iter_{i}.pt')
            torch.save(pool_hist, out_path)

            pbar_text = f"Evolution {i+1}/{num_evolutions}: Child Loss: {new_loss:.2f}, Best Loss: {np.min([p['loss'] for p in pool]):.2f}, Parent Loss: {parent_loss:.2f}"
            pbar.set_description(pbar_text)

    # Return the best-performing prompt from the final pool
    best_idx = np.argmin([p['loss'] for p in pool])
    return pool[best_idx]['compressed'], pool_hist

    




                 

In [52]:
# System prompt template for the evolutionary run. It is filled in with
# str.format() inside compress_text and has three placeholders:
#   {original_len}       - length in characters of the original text
#   {max_compressed_len} - character budget for the compressed text
#   {best_prompt}        - a previously found compressed text shown as context
char_compress_prompt = """Hello Claude this is Aman. I'm building this system to
compress text with small ~7b param language models. For text x, you're gonna
produce a compressed version such that P(x | x') is maximized while
minimizing the length of x'. You can use any prompting strategies you want. The
user will give you the text x and you will respond with the compressed version.
Note that the compressed version must be smaller than the input.

Length (chars) of original (below): {original_len}
Length (chars) of compressed sequence: {max_compressed_len}

For context, here's one of the best ones to date: 

<c>
{best_prompt}
</c>

REMEMBER TO DELIMIT YOUR COMPRESSED RESPONSE WITH <c> AND </c>!!

I want you to be as creative as possible. Add some entropy to the system! 
Experiment with crazy ideas to make the prompt better :)

Anyway, here's the original text: 
"""

# Values for the placeholders in char_compress_prompt.
char_compress_kwargs = {
    # Real character count of the text being compressed; a -1 placeholder
    # here would be shown to Claude verbatim as the original's length.
    "original_len": len(text_to_compress),
    # Placeholder: evolve_compressed_prompt overwrites this key with the
    # resolved character budget.
    "max_compressed_len": -1,
    # Shown while the initial pool is built; replaced by a pool member on
    # every evolution step.
    "best_prompt": "None yet, you're the first!",
}

# Timestamp used to give every run its own results directory. An earlier
# variant used the format "%I:%M:%S%p on %B %d, %Y".
datetime_string = datetime.datetime.now().strftime("%Y_%m_%d__%H_%M_%S")

# Run the evolution. max_compressed_len=0.3 is a float, so it is interpreted
# as a fraction of len(text_to_compress).
# NOTE: evolve_compressed_prompt returns (best compressed string, history
# dict), so `pool` is bound to a single string, not to the pool itself.
pool, pool_hist = evolve_compressed_prompt(original = text_to_compress, 
                            client = client, 
                            model = model, 
                            tokenizer = tokenizer,
                            prompt = char_compress_prompt, 
                            prompt_kwargs = char_compress_kwargs,
                            max_compressed_len=0.3, 
                            num_evolutions=1000, 
                            pool_size=30, 
                            out_dir = f'results/{datetime_string}')

Length of original:  469
Received response from Claude!
Compressed 0 length: 86, loss 2.9625706672668457
Received response from Claude!
Compressed 1 length: 141, loss 2.903292417526245
Received response from Claude!
Compressed 2 length: 65, loss 2.934356212615967
Received response from Claude!
[WARN] Trimming compressed text from Claude from len = 154 to max_compressed_len = 141...
Compressed 3 length: 141, loss 2.9368834495544434
Received response from Claude!
Compressed 4 length: 131, loss 2.813264846801758
Received response from Claude!
[WARN] Trimming compressed text from Claude from len = 169 to max_compressed_len = 141...
Compressed 5 length: 141, loss 2.953397512435913
Received response from Claude!
[WARN] Trimming compressed text from Claude from len = 257 to max_compressed_len = 141...
Compressed 6 length: 141, loss 3.1293246746063232
Received response from Claude!
Compressed 7 length: 72, loss 3.1635634899139404
Received response from Claude!
[WARN] Trimming compressed text f

Evolution 1000/1000: Child Loss: 2.83, Best Loss: 2.67, Parent Loss: 2.70: 100%|██████████| 1000/1000 [46:41<00:00,  2.80s/it]


In [53]:
# Despite its name, `pool` is the first return value of
# evolve_compressed_prompt: the lowest-loss compressed string of the final pool.
pool

'\nMathematics: numbers, formulas, shapes, changes.\nSubdisciplines: number theory, algebra, geometry, analysis.\nNo common definition.\n'

In [54]:
text_to_compress

'Mathematics\nMathematics is an area of knowledge that includes the topics of numbers,\nformulas and related structures, shapes and the spaces in which they are\ncontained, and quantities and their changes. These topics are represented in\nmodern mathematics with the major subdisciplines of number theory,[1]\nalgebra,[2] geometry,[1] and analysis,[3] respectively. There is no general\nconsensus among mathematicians about a common definition for their academic\ndiscipline.\n'

In [55]:
# Pool snapshots in evolution order: hist[i] is the pool after evolution i.
hist = [pool_hist[f'pool_{step}'] for step in range(len(pool_hist))]

In [56]:
hist[0]

[{'compressed': '\nMath: numbers, formulas, shapes, quantities & changes. No consensus on a definition.\n',
  'loss': 2.9625706672668457},
 {'compressed': '\nMath: numbers, formulas, shapes, quantities & changes. Major areas: number theory, algebra, geometry, analysis. No consensus on definition.\n',
  'loss': 2.903292417526245},
 {'compressed': '\nNum, forms, shapes, changes\nMaths, diverse fields, no consensus\n',
  'loss': 2.934356212615967},
 {'compressed': '\nMath: numbers, formulas, shapes, quantities & their changes. Main subdiscs:\nnumber theory, algebra, geometry, analysis. No consensus on a un',
  'loss': 2.9368834495544434},
 {'compressed': '\nMath: numbers, formulas, shapes, quantities. Subdisciplines: number theory, algebra, geometry, analysis. No universal definition.\n',
  'loss': 2.813264846801758},
 {'compressed': '\nMaths: numbers, formulas, shapes, quantities & changes. Major subfields: number theory, algebra, geometry, analysis. No common definition fo',
  'loss': 2

In [57]:
hist[-1]

[{'compressed': '\nMathematics: numbers, formulas, shapes, quantities, changes\nSubdisciplines: number theory, algebra, geometry, analysis\nAmbiguous definition\n',
  'loss': 2.6868135929107666},
 {'compressed': '\nMathematics: study of numbers, formulas, shapes, changes.\nSubdisciplines: number theory, algebra, geometry, analysis.\nNo common definition.\n',
  'loss': 2.6769490242004395},
 {'compressed': '\nMathematics: numbers, formulas, shapes, changes.\nSubdisciplines: number theory, algebra, geometry, analysis.\nNo common definition.\n',
  'loss': 2.673139810562134},
 {'compressed': '\nMath: numbers, formulas, shapes, quantities, changes \nSubdisciplines: number theory, algebra, geometry, analysis\nNo definitive definition\n',
  'loss': 2.6774086952209473},
 {'compressed': '\nMaths: numbers, formulas, shapes, changes.\nSubdisciplines: number theory, algebra, geometry, analysis.\nNo common definition.\n',
  'loss': 2.6999852657318115},
 {'compressed': '\nMaths: numbers, formulas, sh