In [None]:
import json
import os

import sys
from datetime import datetime
import random

import numpy as np
import torch

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoConfig
from transformers import set_seed as hf_set_seed
from nltk.tokenize import TextTilingTokenizer
from wtpsplit import SaT
from tqdm.notebook import tqdm


In [None]:
# Task names loaded from the "tau/zero_scrolls" benchmark, in processing order.
datasets = [
    "gov_report",
    "summ_screen_fd",
    "qmsum",
    "qasper",
    "narrative_qa",
    "quality",
]

# Chunk-and-compress mode is opt-in through the CHUNKING environment variable:
# only a case-insensitive "true" switches it on.
chunking = os.environ.get("CHUNKING", "False").lower() == "true"
if chunking:
    print("Using chunking")

# Number of sentences grouped into one segment before rewriting.
CHUNK_SIZE = 4
# Number of segments handed to the rewriting model per batch.
BATCH_SIZE = 32
# Per-task cap on processed examples; a non-positive value disables the cap.
max_examples_per_task = -1

In [None]:

# Input-token budget per supported model, keyed by the identifier (hub name or
# local path) that is later passed to `from_pretrained`.
model_to_max_input_tokens = {
    "Qwen/Qwen2.5-1.5B-Instruct": 8192,
    "MBZUAI/LaMini-GPT-1.5B": 512,
    "/assets/models/meta-llama-2-chat-7b": 8192,
    "instruction-pretrain/InstructLM-1.3B": 2048,
    "nvidia/AceInstruct-1.5B": 8192,
    "/assets/models/meta-llama-3.2-instruct-3b": 262144,  # 8192 * 32
}

def llama3_prompt(user_message):
    """Wrap ``user_message`` in the Llama 3 instruct chat format.

    The result holds a fixed system turn, the user turn, and an opened
    assistant header so that generation continues as the assistant's reply.

    The string deliberately does NOT start with ``<|begin_of_text|>``: every
    caller in this notebook passes the result to the tokenizer with its default
    special-token handling, which already prepends the BOS token. Emitting it
    here as well would feed the model two BOS tokens. Turns are also joined
    without a newline after ``<|eot_id|>``, matching the published Llama 3
    prompt format.

    Args:
        user_message: Text placed verbatim in the user turn.

    Returns:
        The formatted prompt string, ending right after the assistant header.
    """
    START = "<|start_header_id|>"
    END = "<|end_header_id|>"
    EOT = "<|eot_id|>"

    system_prompt = (
        "You are a helpful assistant. Always follow the task instruction carefully. "
        "The first paragraph before the first double line break contains the task instruction. "
        "Generate text as a natural continuation of the user message. Do not include any meta-commentary or explanations."
    )

    def _turn(role, content):
        # One complete turn: role header, blank line, content, end-of-turn marker.
        return f"{START}{role}{END}\n\n{content}{EOT}"

    return (
        _turn("system", system_prompt)
        + _turn("user", user_message)
        + f"{START}assistant{END}\n\n"
    )


# Maps a model identifier to the callable that wraps raw text in that model's
# chat format. Lookups elsewhere use `.get(model_name, identity)`, so a model
# without an entry is prompted with the raw text unchanged.
model_to_chat_template = {}
model_to_chat_template["/assets/models/meta-llama-3.2-instruct-3b"] = llama3_prompt

In [None]:
def trim_doc_keeping_suffix(tokenizer, tokenized_input_full, example, suffix_index, max_tokens, device):
    """Truncate a tokenized input while keeping its trailing suffix intact.

    The tail is rebuilt as ``"<separator>\\n\\n<suffix>\\n"`` from the stripped
    ``example['truncation_seperator']`` and the stripped text of
    ``example['input']`` starting at ``suffix_index``. The head of
    ``tokenized_input_full`` is cut so that head + tail fits in ``max_tokens``.

    Args:
        tokenizer: Callable returning an object with ``input_ids`` when called
            with ``(text, return_tensors="pt")``.
        tokenized_input_full: 2-D tensor of token ids, sliced along dim 1.
        example: Mapping with ``'truncation_seperator'`` and ``'input'`` strings.
        suffix_index: Character offset in ``example['input']`` where the suffix starts.
        max_tokens: Total token budget for the returned tensor.
        device: Device the tokenized tail is moved to before concatenation.

    Returns:
        The trimmed head concatenated with the tokenized tail along dim 1. If
        the tail alone exceeds ``max_tokens`` the result is the tail only, and
        is therefore still longer than ``max_tokens``.
    """
    separator = example['truncation_seperator'].strip()
    suffix = example['input'][suffix_index:].strip()
    tail_ids = tokenizer(f"{separator}\n\n{suffix}\n", return_tensors="pt").input_ids.to(device)

    # Clamp at zero: a negative bound would be read by the slice as "drop the
    # last k tokens", keeping nearly the whole document exactly when there is
    # no room for any of it.
    head_budget = max(0, max_tokens - tail_ids.shape[1])
    head_ids = tokenized_input_full[:, :head_budget]
    return torch.cat([head_ids, tail_ids], dim=1)

In [None]:
# Resolve the device once, before anything is moved to it, so the CPU fallback
# applies to every component below (SaT was previously sent to a hard-coded
# "cuda" before this value was computed).
device = "cuda" if torch.cuda.is_available() else "cpu"

# Sentence segmenter used by `extract_segments`.
sat = SaT("sat-3l")
if device == "cuda":
    # Half precision is only requested on GPU; on CPU the segmenter is left as loaded.
    sat.half().to(device)

model_name = "/assets/models/meta-llama-3.2-instruct-3b"

print("Loading tokenizer")
# Left padding keeps each prompt's last token adjacent to the generated
# continuation when prompts are batched for generation.
tokenizer = AutoTokenizer.from_pretrained(
    model_name, trust_remote_code=True, padding_side="left")
# Reuse EOS as the padding token.
tokenizer.pad_token_id = tokenizer.eos_token_id
print(f"Loading model: {model_name}")

max_input_length = model_to_max_input_tokens[model_name]

config = AutoConfig.from_pretrained(model_name)

# NOTE(review): this replaces whatever RoPE scaling the checkpoint config
# ships with. Presumably it is meant to stretch the usable context towards the
# 8192 * 32 token budget configured for this model — confirm these values.
config.rope_scaling = {
    "factor": 64.0,
    "high_freq_factor": 8.0,
    "low_freq_factor": 1.0,
    "original_max_position_embeddings": 8192,
    "rope_type": "llama3"
}


model = AutoModelForCausalLM.from_pretrained(
    model_name, device_map="auto", config=config, torch_dtype=torch.bfloat16)


# Alternative: a dedicated seq2seq model for the rewriting step.
# summary_model_name = "google/flan-t5-xl"
# summary_tokenizer = AutoTokenizer.from_pretrained(summary_model_name)
# summary_model = AutoModelForSeq2SeqLM.from_pretrained(
#     summary_model_name, torch_dtype=torch.bfloat16).to(device)

# The rewriting step currently shares the main model and tokenizer.
summary_model = model
summary_tokenizer = tokenizer

In [None]:
def extract_segments(text, chunksize=CHUNK_SIZE):
    """Split ``text`` into segments of ``chunksize`` consecutive sentences.

    Sentences come from the module-level ``sat`` splitter. Each group is joined
    with a single space and stripped; groups that end up empty are dropped.

    Args:
        text: Document text to segment.
        chunksize: Number of sentences per segment (the last one may be shorter).

    Returns:
        List of non-empty segment strings in document order.
    """
    sentences = sat.split(text)
    grouped = (
        " ".join(sentences[start:start + chunksize]).strip()
        for start in range(0, len(sentences), chunksize)
    )
    return [segment for segment in grouped if segment]


def batch_concise_rewrite_chunks(chunks, query = None, min_length=50):
    """Ask the summary model for a shorter version of each sufficiently long chunk.

    Chunks whose token count (per ``summary_tokenizer``) is below ``min_length``
    are returned unchanged. All remaining chunks are turned into rewrite
    prompts, wrapped in the chat template registered for ``model_name`` (or
    left as-is when none is registered), and generated in one padded batch.

    Args:
        chunks: Sequence of text segments.
        query: Optional query; when truthy, the prompt asks the model to keep
            information relevant to it.
        min_length: Token-count threshold below which a chunk is not rewritten.

    Returns:
        List of strings with the same length and order as ``chunks``; each
        entry is either the original chunk or the decoded model output.
    """
    # Start from the originals: short chunks pass through untouched and the
    # rewritten ones are overwritten in place below.
    rewritten_chunks = list(chunks)

    wrap = model_to_chat_template.get(model_name, lambda x: x)
    to_rewrite = []
    indices = []
    for i, chunk in enumerate(chunks):
        tok_len = len(summary_tokenizer(chunk, return_tensors="pt").input_ids[0])
        if tok_len < min_length:
            continue
        if query:
            prompt = f"Make the following text shorter while preserving information that is relevant to the following query.\n\nQuery: {query}.\n\nText: {chunk}\n\nShort version:"
        else:
            prompt = f"Make the following text shorter while preserving its core meaning.\n\nText: {chunk}\n\nShort version:"
        indices.append(i)
        to_rewrite.append(wrap(prompt))

    if not to_rewrite:
        return rewritten_chunks

    with torch.no_grad():
        inputs = summary_tokenizer(
            to_rewrite, return_tensors="pt", padding=True).to(device)
        input_ids = inputs["input_ids"]
        # padding=True pads to the longest prompt in the batch, so the padded
        # width IS the longest prompt length; no second tokenization pass is
        # needed to find it. The output budget is capped at that length.
        prompt_width = input_ids.shape[1]
        outputs = summary_model.generate(
            input_ids=input_ids,
            attention_mask=inputs["attention_mask"],
            max_new_tokens=prompt_width,
            do_sample=True,
            repetition_penalty=1.1,
            early_stopping=True,
            num_beams=2,
            eos_token_id=summary_tokenizer.eos_token_id,
            pad_token_id=summary_tokenizer.pad_token_id,
        )
        # Drop the prompt columns so only newly generated tokens are decoded.
        compressed = summary_tokenizer.batch_decode(
            outputs[:, prompt_width:], skip_special_tokens=True)
    torch.cuda.empty_cache()

    for idx, text in zip(indices, compressed):
        rewritten_chunks[idx] = text

    return rewritten_chunks



def process_model_input_chunking(tokenizer, example, max_tokens, device, dataset):
    """Build the tokenized model input for ``example`` after compressing its document.

    The document span of ``example["input"]`` is split into segments, each
    batch of ``BATCH_SIZE`` segments is rewritten to a shorter form, and the
    prompt is reassembled as instruction + compressed document + truncation
    separator + query. If the compressed document is at least 95% of the
    original length (in characters), the original document is used instead.
    Both the final prompt and the untouched original input are written to
    ``compressed_input/<dataset>/<id>.txt`` and ``original_input/<dataset>/<id>.txt``.

    NOTE(review): the truncation separator is inserted unconditionally even
    though nothing is truncated here — confirm that this is intended.

    Args:
        tokenizer: Tokenizer applied to the final prompt.
        example: Mapping with ``"input"``, ``"id"``, ``"truncation_seperator"``,
            ``"document_start_index"``, ``"document_end_index"`` and
            ``"query_start_index"``.
        max_tokens: Unused; no truncation is applied in this function.
        device: Device the token ids are moved to.
        dataset: Task name, used as the sub-directory for the dumped files.

    Returns:
        Tensor of token ids for the final prompt, on ``device``.
    """
    full_text = example["input"]
    instruction = full_text[:example['document_start_index']]
    truncation_seperator = example['truncation_seperator']

    # An empty query (nothing after the query offset) is treated as "no query".
    query = full_text[example['query_start_index']:] or None
    doc = full_text[example['document_start_index']:example['document_end_index']]

    # Group sentences into fixed-size segments.
    chunks = extract_segments(doc)

    # Rewrite the segments BATCH_SIZE at a time.
    compressed_chunks = []
    for start in range(0, len(chunks), BATCH_SIZE):
        batch = chunks[start:start + BATCH_SIZE]
        compressed_chunks.extend(batch_concise_rewrite_chunks(batch, query))

    compressed_doc = "\n".join(compressed_chunks)

    # Character-length ratio of compressed to original document. An empty
    # document cannot be compressed; treat it as ratio 1.0 instead of dividing
    # by zero.
    ratio = len(compressed_doc) / len(doc) if doc else 1.0
    print(f"Compression ratio: {ratio:.2f}")
    if ratio >= 0.95:
        print("Compression ratio is too high, skipping compression")
        compressed_doc = doc

    input_text = f"{instruction}{compressed_doc}{truncation_seperator}{query or ''}"

    compressed_input = model_to_chat_template.get(model_name, lambda x: x)(input_text)

    # Dump the final prompt and the original input side by side for inspection.
    # Paths are built without nesting same-type quotes inside an f-string,
    # which is a SyntaxError before Python 3.12; directories are created on
    # demand and the encoding is fixed so arbitrary document text can be written.
    example_id = example["id"]
    dumps = (
        ("compressed_input", compressed_input),
        ("original_input", full_text),
    )
    for root, text in dumps:
        out_dir = os.path.join(root, dataset)
        os.makedirs(out_dir, exist_ok=True)
        with open(os.path.join(out_dir, f"{example_id}.txt"), "w", encoding="utf-8") as f:
            f.write(text)

    tokenized_input_full = tokenizer(
        compressed_input, return_tensors="pt").input_ids.to(device)

    return tokenized_input_full


def process_model_input(tokenizer, example, max_tokens, device):
    """Build the tokenized model input for ``example`` without compression.

    The prompt is reassembled from ``example["input"]`` as instruction +
    document + truncation separator + query, wrapped in the chat template
    registered for ``model_name`` (or left unchanged when none is registered),
    and tokenized.

    Args:
        tokenizer: Tokenizer applied to the final prompt.
        example: Mapping with ``"input"``, ``"truncation_seperator"``,
            ``"document_start_index"``, ``"document_end_index"`` and
            ``"query_start_index"``.
        max_tokens: Unused; no truncation is applied in this function.
        device: Device the token ids are moved to.

    Returns:
        Tensor of token ids for the final prompt, on ``device``.
    """
    raw_text = example["input"]
    doc_start = example['document_start_index']

    instruction = raw_text[:doc_start]
    separator = example['truncation_seperator']
    # Everything from the query offset onward; an empty string when there is no query.
    question = raw_text[example['query_start_index']:]
    document = raw_text[doc_start:example['document_end_index']]

    body = f"{instruction}{document}{separator}{question}"
    wrap = model_to_chat_template.get(model_name, lambda x: x)
    prompt = wrap(body)

    return tokenizer(prompt, return_tensors="pt").input_ids.to(device)

In [None]:
# Base directory for prediction files; a per-model sub-directory is appended
# to it in a later cell before anything is written.
generations_dir = "generations/ipynb"


In [None]:
# Seed every RNG in use so sampling-based generation is repeatable.
# NOTE(review): hf_set_seed is transformers' `set_seed`; per its documentation
# it also seeds torch — confirm for the installed version.
seed = 43
random.seed(seed)
np.random.seed(seed)
hf_set_seed(seed)
print("Params:")
print(f"model: {model_name}")
# The last path component of the model identifier names the output sub-directory.
model_suffix = model_name.split("/")[-1]
if chunking:
    model_suffix = f"{model_suffix}-chunking"
# Append the model sub-directory only once: without the guard, re-running this
# cell would nest it again (".../<model>/<model>").
if os.path.basename(os.path.normpath(generations_dir)) != model_suffix:
    generations_dir = os.path.join(generations_dir, model_suffix)
print(f"generations_dir: {generations_dir}")
print(f"max_examples_per_task: {max_examples_per_task}")
print("=" * 50)
time = datetime.now().strftime("%d_%m_%Y_%H_%M_%S")
print(f"time as start: {time}")




In [None]:
# Inference only: switch the model to evaluation mode.
model = model.eval()

# NOTE(review): not referenced by any visible cell; kept in case another cell reads it.
mid_layer_index = 16
print(f"model loaded!, device:{model.device}")

print("Will write to:", generations_dir)
os.makedirs(generations_dir, exist_ok=True)
for dataset in datasets:
    # Predictions for this task, keyed by example id.
    generations = dict()
    print(f"Processing {dataset}")
    time = datetime.now().strftime("%d_%m_%Y_%H_%M_%S")
    print(f"time as start {dataset}: {time}")
    print(f"Loading {dataset}")
    data = load_dataset("tau/zero_scrolls", dataset, trust_remote_code=True)
    print(f"Loaded {dataset}")
    # Directories the chunking path dumps its prompts into.
    compressed_dir = os.path.join("compressed_input", dataset)
    original_dir = os.path.join("original_input", dataset)
    os.makedirs(compressed_dir, exist_ok=True)
    os.makedirs(original_dir, exist_ok=True)

    # Wrap the split itself (not the enumerate object) so the bar can show a total.
    for i, example in enumerate(tqdm(data["validation"], desc=dataset)):
        # Check the cap first so no "Processing example" line is logged for an
        # example that is never processed. A non-positive cap means "no cap".
        if max_examples_per_task > 0 and i == max_examples_per_task:
            print(f"Reached {max_examples_per_task} for {dataset}. Breaking")
            break

        print("Processing example:", example["id"])

        try:
            if chunking:
                model_input = process_model_input_chunking(
                    tokenizer, example, max_input_length, device, dataset)
            else:
                model_input = process_model_input(
                    tokenizer, example, max_input_length, device)
        except Exception as e:
            # Best effort: report the failure and move on to the next example.
            print(f"Error processing example {i} in {dataset}: {e}")
            continue

        # Sample a continuation for the prompt.
        with torch.no_grad():
            prediction_token_ids = model.generate(model_input,
                                                  # Single unpadded sequence: every position is real.
                                                  attention_mask=torch.ones_like(model_input),
                                                  max_new_tokens=512,
                                                  do_sample=True,
                                                  top_p=0.9,
                                                  top_k=0,
                                                  temperature=0.5,
                                                  pad_token_id=tokenizer.eos_token_id,
                                                  )

            # Decode only the newly generated tokens, not the echoed prompt.
            predicted_text = tokenizer.decode(
                prediction_token_ids[0][model_input.shape[1]:], skip_special_tokens=True)
            # Release the large tensors before the next (possibly long) example.
            del prediction_token_ids, model_input
            torch.cuda.empty_cache()

        generations[example["id"]] = predicted_text

    out_file_path_pred = os.path.join(generations_dir, f"{dataset}.json")
    with open(out_file_path_pred, 'w') as f_out:
        json.dump(generations, f_out, indent=4)

    print(f"Done generating {len(generations)} examples from {dataset}")
    time = datetime.now().strftime("%d_%m_%Y_%H_%M_%S")
    print(f"time at end: {time}")
    print(f"Look for predictions in {generations_dir}")