In [1]:
import json
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import sys
from datetime import datetime
import random

import numpy as np
import torch

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
from transformers import set_seed as hf_set_seed
from nltk.tokenize import TextTilingTokenizer
from wtpsplit import SaT
from tqdm.notebook import tqdm
from llmlingua import PromptCompressor


In [2]:
from llmlingua import PromptCompressor

# compressor = PromptCompressor(
#     model_name="microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank",
#     use_llmlingua2=True
# )
# use_llmlingua = True  # Toggle compressor on/off

compressor = PromptCompressor("microsoft/phi-2")
use_llmlingua = True

# llm_lingua_model_name = "microsoft/phi-2"

# compressor = PromptCompressor(
#     model_name="/assets/models/meta-llama-3.2-instruct-3b",
#     use_llmlingua2=False,        # disable llmlingua-2 to use coarse-to-fine pipeline
#     device_map="auto",          # distribute across available devices
# )
# use_llmlingua = True

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
datasets = ["gov_report", "summ_screen_fd",
            "qmsum", "qasper", "narrative_qa", "quality"]


chunking = True
system_prompt = True

if chunking:
    print("Using chunking with LLMLingua default compressor")

CHUNK_SIZE = 4
BATCH_SIZE = 32
max_examples_per_task = -1

Using chunking with LLMLingua default compressor


In [4]:

model_to_max_input_tokens = {
    "Qwen/Qwen2.5-1.5B-Instruct": 8192,
    "MBZUAI/LaMini-GPT-1.5B" : 512,
    "/assets/models/meta-llama-2-chat-7b" : 8192,
    "instruction-pretrain/InstructLM-1.3B":2048,
    "nvidia/AceInstruct-1.5B": 8192,
    "/assets/models/meta-llama-3.2-instruct-3b": 8192 * 32   
}

def llama3_prompt(user_message):
    BEGIN = "<|begin_of_text|>"
    START = "<|start_header_id|>"
    END = "<|end_header_id|>"
    EOT = "<|eot_id|>"

    system_prompt = (
        "You are a helpful assistant. Always follow the task instruction carefully. "
        "The first paragraph before the first double line break contains the task instruction. "
        "Generate text as a natural continuation of the user message. Do not include any meta-commentary or explanations."
    )

    if system_prompt:
        prompt = (
        f"{BEGIN}"
        f"{START}system{END}\n\n{system_prompt}{EOT}\n"
        f"{START}user{END}\n\n{user_message}{EOT}\n"
        f"{START}assistant{END}\n\n"
    )
    else:
        prompt = (
        f"{BEGIN}"
        f"{START}user{END}\n\n{user_message}{EOT}\n"
        f"{START}assistant{END}\n\n"
    )
    

    return prompt


model_to_chat_template = {
    "/assets/models/meta-llama-3.2-instruct-3b": llama3_prompt 
}

In [5]:
def trim_doc_keeping_suffix(tokenizer, tokenized_input_full, example, suffix_index, max_tokens, device):
    seperator_and_suffix = f"{example['truncation_seperator'].strip()}\n\n{example['input'][suffix_index:].strip()}\n"
    tokenized_seperator_and_suffix = tokenizer(seperator_and_suffix, return_tensors="pt").input_ids.to(device)
    tokenized_input_trimmed = tokenized_input_full[:, :max_tokens - tokenized_seperator_and_suffix.shape[1]]
    tokenized_input = torch.cat([tokenized_input_trimmed, tokenized_seperator_and_suffix], dim=1)
    return tokenized_input

In [6]:
sat = SaT("sat-3l")
sat.half().to("cuda")

model_name = "/assets/models/meta-llama-3.2-instruct-3b"

print("Loading tokenizer")
tokenizer = AutoTokenizer.from_pretrained(
    model_name, trust_remote_code=True, padding_side="left")
tokenizer.pad_token_id = tokenizer.eos_token_id
print(f"Loading model: {model_name}")
device = "cuda" if torch.cuda.is_available() else "cpu"

max_input_length = model_to_max_input_tokens[model_name]

model = AutoModelForCausalLM.from_pretrained(
    model_name, device_map="auto", torch_dtype=torch.bfloat16)


# summary_model_name = "google/flan-t5-xl"
# summary_tokenizer = AutoTokenizer.from_pretrained(summary_model_name)
# summary_model = AutoModelForSeq2SeqLM.from_pretrained(
#     summary_model_name, torch_dtype=torch.bfloat16).to(device)

summary_model = model
summary_tokenizer = tokenizer

Loading tokenizer
Loading model: /assets/models/meta-llama-3.2-instruct-3b


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [7]:
def extract_segments(text, chunksize=CHUNK_SIZE):
    sentences = sat.split(text)
    segments = []
    for i in range(0, len(sentences), chunksize):
        segment = " ".join(sentences[i:i + chunksize])
        segments.append(segment.strip())
    # Remove empty segments
    segments = [s for s in segments if s]

    return segments

def batch_concise_rewrite_chunks(chunks, query=None, min_length=50):
    # (keep your existing code for computing token_lens, to_rewrite, passthrough, etc.)
    token_lens = [
        len(tokenizer(chunk, return_tensors="pt").input_ids[0])
        for chunk in chunks
    ]

    to_rewrite = []
    indices = []
    passthrough = {}
    for i, (chunk, tok_len) in enumerate(zip(chunks, token_lens)):
        if tok_len < min_length:
            passthrough[i] = chunk
        else:
            indices.append(i)
            to_rewrite.append(chunk)

    rewritten_chunks = [""] * len(chunks)

    if to_rewrite:
        # ← HERE: call your PromptCompressor instead of HF generate()
        # I’m assuming its API looks like: compress_batch(texts, query=None) -> List[str]
        compressed_texts = llm_lingua.compress_batch(to_rewrite, query=query)

        for k, idx in enumerate(indices):
            rewritten_chunks[idx] = compressed_texts[k]

    # put back any chunks we skipped
    for i, original in passthrough.items():
        rewritten_chunks[i] = original

    return rewritten_chunks



def process_model_input_chunking_withoutlingua(tokenizer, example, max_tokens, device, dataset):

    instruction = example["input"][:example['document_start_index']]
    truncation_seperator = example['truncation_seperator']

    query = example["input"][example['query_start_index']:]
    if len(query) == 0:
        query = None
    doc = example["input"][example['document_start_index']
        :example['document_end_index']]

    # Apply semantic chunking
    chunks = extract_segments(doc)

    # Compress in batches of 5 in parallel
    compressed_chunks = []
    for i in range(0, len(chunks), BATCH_SIZE):
        batch = chunks[i:i + BATCH_SIZE]
        compressed_batch = batch_concise_rewrite_chunks(batch, query)
        compressed_chunks.extend(compressed_batch)


    compressed_doc = "\n".join(compressed_chunks)

    # Compute ratio of compressed doc to original doc
    ratio = len(compressed_doc) / len(doc)
    print(f"Compression ratio: {ratio:.2f}")
    if ratio >= 0.95:
        print("Compression ratio is too high, skipping compression")
        compressed_doc = doc

    input_text = f"{instruction}{compressed_doc}{truncation_seperator}{query or ''}"

    compressed_input = model_to_chat_template.get(model_name, lambda x: x)(input_text)

    # Write compressed input to file
    with open(f"compressed_input/{dataset}/{example['id']}.txt", "w") as f:
        f.write(compressed_input)

    # Write original input to file
    with open(f"original_input/{dataset}/{example['id']}.txt", "w") as f:
        f.write(example["input"])

    tokenized_input_full = tokenizer(
        compressed_input, return_tensors="pt").input_ids.to(device)

    return tokenized_input_full

def process_model_input(tokenizer, example, max_tokens, device):
    instr = example['input'][:example['document_start_index']]
    sep = example['truncation_seperator']
    query = example['input'][example['query_start_index']:] or ''
    doc = example['input'][example['document_start_index']:example['document_end_index']]
    prompt = f"{instr}{doc}{sep}{query}"
    chat = model_to_chat_template.get(model_name, lambda x:x)(prompt)
    return tokenizer(chat, return_tensors="pt").input_ids.to(device)

In [8]:
def compress_chunks_llmlingua(chunks, max_ratio=None):
    """
    Compresses a list of text chunks using LLMLingua, but only if
    the compressed length is within the specified max_ratio of the original.
    If the ratio is exceeded, returns the original chunk.

    Args:
        chunks (List[str]): List of text segments to compress.
        max_ratio (float or None): Maximum allowed compressed/original token ratio.
            e.g., 0.5 means compressed chunk must be <= 50% of original tokens.
            If None, no ratio enforcement.
    """
    compressed = []
    for chunk in chunks:
        # Count original tokens (simple whitespace split)
        orig_tokens = len(chunk.split())
        target = orig_tokens // 2
        # Perform LLMLingua compression
        res = compressor.compress_prompt(
            chunk,
            instruction="",
            question="",
            target_token=target
        )
        short = res.get('compressed_prompt', chunk)
        # Enforce max_ratio if provided
        if max_ratio is not None and orig_tokens > 0:
            comp_tokens = len(short.split())
            if comp_tokens / orig_tokens > max_ratio:
                # Skip compression if ratio too high
                compressed.append(chunk)
                continue
        compressed.append(short)
    return compressed

# Build model input when chunking
def process_model_input_chunking(tokenizer, example, max_tokens, device, dataset):
    instr = example['input'][:example['document_start_index']]
    sep = example['truncation_seperator']
    query = example['input'][example['query_start_index']:] or ''
    doc = example['input'][example['document_start_index']:example['document_end_index']]

    # Split into segments and compress each chunk, enforcing a max ratio
    chunks = extract_segments(doc)
    if use_llmlingua:
        compressed_chunks = compress_chunks_llmlingua(chunks, max_ratio=0.95)
    else:
        compressed_chunks = chunks

    # Join compressed chunks with newlines
    compressed_doc = "".join(compressed_chunks)

    # Build the prompt by combining instruction, compressed doc, separator, and query
    prompt = f"{instr}{compressed_doc}{sep}{query}"
    chat = model_to_chat_template.get(model_name, lambda x: x)(prompt)

    # Save compressed and original inputs for inspection
    os.makedirs(f"compressed_input/{dataset}", exist_ok=True)
    with open(f"compressed_input/{dataset}/{example['id']}.txt", "w") as f:
        f.write(chat)
    os.makedirs(f"original_input/{dataset}", exist_ok=True)
    with open(f"original_input/{dataset}/{example['id']}.txt", "w") as f:
        f.write(example['input'])

    # Tokenize and return the input ids for generation
    inputs = tokenizer(chat, return_tensors="pt").to(device)
    return inputs.input_ids

In [9]:
generations_dir = "generations/ipynb"


In [10]:
llm_name = "phi-.5"
seed = 43
random.seed(seed)
np.random.seed(seed)
hf_set_seed(seed)
print("Params:")
print(f"model: {model_name}")
model_suffix = model_name.split("/")[-1]
model_suffix = f"{model_suffix}-{llm_name}"
if system_prompt:
    model_suffix = f"{model_suffix}--system-prompt"
generations_dir = os.path.join(generations_dir, model_suffix)
print(f"generations_dir: {generations_dir}")
print(f"max_examples_per_task: {max_examples_per_task}")
print("=" * 50)
time = datetime.now().strftime("%d_%m_%Y_%H_%M_%S")
print(f"time as start: {time}")




Params:
model: /assets/models/meta-llama-3.2-instruct-3b
generations_dir: generations/ipynb/meta-llama-3.2-instruct-3b-phi-.5--system-prompt
max_examples_per_task: -1
time as start: 26_04_2025_12_34_51


In [11]:
model = model.eval()

mid_layer_index = 16
print(f"model loaded!, device:{model.device}")

print("Will write to:", generations_dir)
os.makedirs(generations_dir, exist_ok=True)
for dataset in datasets:
    generations = dict()
    print(f"Processing {dataset}")
    time = datetime.now().strftime("%d_%m_%Y_%H_%M_%S")
    print(f"time as start {dataset}: {time}")
    print(f"Loading {dataset}")
    data = load_dataset("tau/zero_scrolls", dataset, trust_remote_code=True)
    print(f"Loaded {dataset}")
    # Create dir compressed_input if it doesn't exist
    compressed_dir = os.path.join("compressed_input", dataset)
    original_dir = os.path.join("original_input", dataset)
    if not os.path.exists(compressed_dir):
        os.makedirs(compressed_dir)
    if not os.path.exists(original_dir):
        os.makedirs(original_dir)

    for i, example in tqdm(enumerate(data["validation"])):
        print("Processing example:", example["id"])

        if 0 < max_examples_per_task == i:
            print(f"Reached {max_examples_per_task} for {dataset}. Breaking")
            break

        try:
            if chunking:
                model_input = process_model_input_chunking(
                    tokenizer, example, max_input_length, device, dataset)
            else:
                model_input = process_model_input(
                    tokenizer, example, max_input_length, device)
        except Exception as e:
            print(f"Error processing example {i} in {dataset}: {e}")
            continue

        # Get hidden states from the 16th layer
        with torch.no_grad():
            prediction_token_ids = model.generate(model_input,
                                                  max_new_tokens=512,
                                                  do_sample=True,
                                                  top_p=0.9,
                                                  top_k=0,
                                                  temperature=0.5,
                                                  pad_token_id=tokenizer.eos_token_id,
                                                  )

            predicted_text = tokenizer.decode(
                prediction_token_ids[0][model_input.shape[1]:], skip_special_tokens=True)
            del prediction_token_ids, model_input
            torch.cuda.empty_cache()

        generations[example["id"]] = predicted_text

    out_file_path_pred = os.path.join(generations_dir, f"{dataset}.json")
    with open(out_file_path_pred, 'w') as f_out:
        json.dump(generations, f_out, indent=4)

    print(f"Done generating {len(generations)} examples from {dataset}")
    time = datetime.now().strftime("%d_%m_%Y_%H_%M_%S")
    print(f"time at end: {time}")
    print(f"Look for predictions in {generations_dir}")

model loaded!, device:cuda:0
Will write to: generations/ipynb/meta-llama-3.2-instruct-3b-phi-.5--system-prompt
Processing narrative_qa
time as start narrative_qa: 26_04_2025_12_34_52
Loading narrative_qa
Loaded narrative_qa


0it [00:00, ?it/s]

Processing example: 3858


The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Processing example: 3858
Processing example: 5947
Processing example: 5947
Processing example: 12164
Processing example: 12164
Processing example: 23148
Processing example: 23148
Processing example: 25278
Processing example: 25278
Processing example: 33068
Processing example: 33068
Processing example: 35134
Processing example: 35134
Processing example: 38322
Processing example: 38322
Processing example: 42586
Processing example: 42586
Processing example: 16886
Processing example: 16886
Done generating 10 examples from narrative_qa
time at end: 26_04_2025_12_49_16
Look for predictions in generations/ipynb/meta-llama-3.2-instruct-3b-phi-.5--system-prompt
Processing quality
time as start quality: 26_04_2025_12_49_16
Loading quality
Loaded quality


0it [00:00, ?it/s]

Processing example: 62139_J05FWZR6_1
Processing example: 52855_MV65I88C_9
Processing example: 62085_C1SL2YBE_3
Processing example: 63616_MQ1O9T2Q_6
Processing example: 63833_V187YO4H_2
Processing example: 63392_7YS4HHFI_6
Processing example: 63473_1VIHQ8TY_4
Processing example: 51650_B3KKWWD1_7
Processing example: 51274_8Q2YNHG5_6
Processing example: 20077_ZF5G55FD_1
Processing example: 22579_RQ3GB4A1_3
Processing example: 22867_TJ9SPIHC_9
Processing example: 22875_L821878U_6
Processing example: 22967_0XT2L7PI_7
Processing example: 22867_IZGAWLCJ_4
Processing example: 22462_BUA2LH2S_5
Processing example: 31736_TV0CUXDH_4
Processing example: 99927_EVLEI3Q2_6
Processing example: 31282_BQYW9TCH_4
Processing example: 99914_0Q5X8VEX_4
Processing example: 32665_VRYQXG3Y_9
Done generating 21 examples from quality
time at end: 26_04_2025_12_51_50
Look for predictions in generations/ipynb/meta-llama-3.2-instruct-3b-phi-.5--system-prompt
