In [None]:
import json
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "6"
import sys
from datetime import datetime
import random
import gc

import numpy as np
import torch
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import set_seed as hf_set_seed
from tqdm import tqdm

In [None]:
# Make the LM-Infinite folder importable. The path is the parent of the
# current working directory, so this only works when the notebook is started
# from a sub-directory of that folder.
lm_infinite_root = os.path.abspath('..')

# Prepend it so Python finds 'models/llama.py' first
sys.path.insert(0, lm_infinite_root)

# With the path in place, import the converter from the 'models' package.
from models.llama import convert_llama_model

In [None]:
def llama3_prompt(user_message):
    """Wrap ``user_message`` in a Llama 3 style chat prompt.

    The prompt holds a fixed system message, the user message, and an opened
    assistant header so that generation continues as the assistant turn.

    Args:
        user_message: Text placed verbatim in the user turn.

    Returns:
        The complete prompt string, starting with ``<|begin_of_text|>`` and
        ending right after the assistant header.
    """
    BEGIN = "<|begin_of_text|>"
    START = "<|start_header_id|>"
    END = "<|end_header_id|>"
    EOT = "<|eot_id|>"

    # Join with explicit spaces: bare adjacent string literals concatenate
    # without a separator and would fuse the sentences together
    # ("...carefully.The first paragraph...").
    system_prompt = " ".join((
        "Always follow the task instruction carefully.",
        "The first paragraph before the first double line break contains the task instruction.",
        "Generate text as a natural continuation of the user message.",
        "Do not include any meta-commentary or explanations or your own thoughts.",
    ))

    # NOTE(review): a "\n" is emitted after each <|eot_id|>; the reference
    # Llama 3 prompt format appears to put the next header directly after
    # <|eot_id|> - confirm whether the extra newline is intentional.
    prompt = (
        f"{BEGIN}"
        f"{START}system{END}\n\n{system_prompt}{EOT}\n"
        f"{START}user{END}\n\n{user_message}{EOT}\n"
        f"{START}assistant{END}\n\n"
    )
    return prompt


# Maps a model path to the function that renders its chat prompt.
model_to_chat_template = {
    "/assets/models/meta-llama-3.2-instruct-3b": llama3_prompt,
}

In [None]:
# Task names passed to load_dataset("tau/zero_scrolls", ...), in processing order.
datasets = [
    "gov_report",
    "summ_screen_fd",
    "qmsum",
    "qasper",
    "narrative_qa",
    "quality",
]

In [None]:
# Input token budget. Despite the "model_to_" prefix this is a single int,
# not a per-model mapping; it is copied into max_input_length further down.
model_to_max_input_tokens = 4096

In [None]:
def trim_doc_keeping_suffix(tokenizer, tokenized_input_full, example, suffix_index, max_tokens, device):
    """Trim a tokenized input to fit ``max_tokens`` while keeping its suffix.

    The suffix is the stripped truncation separator, a blank line, and the
    stripped text of ``example['input']`` from ``suffix_index`` onwards. It is
    tokenized on its own and appended to a truncated copy of
    ``tokenized_input_full``.

    Args:
        tokenizer: Callable returning an object with ``input_ids`` when called
            with ``return_tensors="pt"``.
        tokenized_input_full: 2-D tensor of token ids; it is sliced along
            dim 1 and must be on ``device`` for the concatenation.
        example: Mapping with the keys ``truncation_seperator`` and ``input``.
        suffix_index: Character offset in ``example['input']`` where the
            suffix text starts.
        max_tokens: Target total length along dim 1.
        device: Device the suffix token ids are moved to.

    Returns:
        The truncated prefix concatenated with the suffix along dim 1. If the
        suffix alone is longer than ``max_tokens`` the result is just the
        suffix (and therefore still longer than ``max_tokens``).
    """
    suffix_text = f"{example['truncation_seperator'].strip()}\n\n{example['input'][suffix_index:].strip()}\n"
    # NOTE(review): the tokenizer runs with its default special-token handling;
    # if it prepends a BOS token, that token ends up mid-sequence - confirm.
    suffix_ids = tokenizer(suffix_text, return_tensors="pt").input_ids.to(device)

    # Clamp at zero: a negative slice bound would count from the end and keep
    # most of the document instead of none of it, blowing past max_tokens.
    prefix_budget = max(0, max_tokens - suffix_ids.shape[1])
    prefix_ids = tokenized_input_full[:, :prefix_budget]

    return torch.cat([prefix_ids, suffix_ids], dim=1)

In [None]:
# Path of the checkpoint to load; also the lookup key into model_to_chat_template.
model_name = "/assets/models/meta-llama-3.2-instruct-3b"
# Label used to name the output sub-directory for the generations.
model_print_name = "llama-basic_4096"
# Per-task cap on processed examples; values <= 0 (such as -1) disable the cap.
max_examples_per_task = -1


In [None]:
def process_model_input(tokenizer, example, max_tokens, device):
    """Build the tokenized prompt for one example.

    The prompt text is instruction + document + truncation separator + query,
    wrapped by the chat template registered for the global ``model_name`` in
    ``model_to_chat_template`` when one exists.

    Args:
        tokenizer: Callable returning an object with ``input_ids`` when called
            with ``return_tensors="pt"``.
        example: Mapping with the keys ``input``, ``document_start_index``,
            ``document_end_index``, ``query_start_index`` and
            ``truncation_seperator``.
        max_tokens: Accepted for interface compatibility but not used; no
            truncation is performed here.
            NOTE(review): confirm that unbounded inputs are intended.
        device: Device the token ids are moved to.

    Returns:
        The tokenizer's ``input_ids`` for the prompt, moved to ``device``.
    """
    full_text = example["input"]
    doc_start = example['document_start_index']

    instruction = full_text[:doc_start]
    doc = full_text[doc_start:example['document_end_index']]
    # An empty query slice is simply an empty string in the prompt.
    query = full_text[example['query_start_index']:]

    # NOTE(review): the truncation separator is inserted even though the
    # document is not truncated in this function - confirm this is intended.
    input_text = f"{instruction}{doc}{example['truncation_seperator']}{query}"

    chat_template = model_to_chat_template.get(model_name)
    if chat_template is None:
        prompt_text = input_text
        # Plain text: keep the tokenizer's default special-token handling.
        add_special_tokens = True
    else:
        prompt_text = chat_template(input_text)
        # The template writes its own special tokens (llama3_prompt starts with
        # <|begin_of_text|>); letting the tokenizer add them again can
        # duplicate the beginning-of-text token.
        add_special_tokens = False

    encoded = tokenizer(
        prompt_text,
        return_tensors="pt",
        add_special_tokens=add_special_tokens,
    )
    return encoded.input_ids.to(device)

In [None]:
# Configure PyTorch's CUDA allocator through the environment.
# NOTE(review): torch is already imported and torch.cuda calls have already run
# in the first cell by the time this is set - confirm the setting still takes effect.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"

# Base output directory; the model label is appended below.
generations_dir = "generations/ipynb"
# Seed Python's random module, NumPy and the transformers seeding helper.
seed = 43
random.seed(seed)
np.random.seed(seed)
hf_set_seed(seed)
print("Params:")
print(f"model: {model_name}")
# One sub-directory per model label; "/" is replaced so the label stays a single path component.
generations_dir = os.path.join(generations_dir, model_print_name.replace("/", "_"))
print(f"generations_dir: {generations_dir}")
print(f"max_examples_per_task: {max_examples_per_task}")
print("=" * 50)
# Timestamp string (day_month_year_hour_minute_second), used only for logging.
time = datetime.now().strftime("%d_%m_%Y_%H_%M_%S")
print(f"time as start: {time}")

print("Loading tokenizer")
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Reuse the EOS token id as the padding token id.
tokenizer.pad_token_id = tokenizer.eos_token_id
print(f"Loading model: {model_name}")
# Device that tokenized inputs are moved to; falls back to CPU without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Token budget handed to process_model_input as max_tokens.
max_input_length = model_to_max_input_tokens

# Load the causal LM in bfloat16 with automatic device placement.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    use_flash_attention_2=False,
    trust_remote_code=True,
)
torch.cuda.empty_cache()
gc.collect()
# Sets a past_key_values attribute on the model object to None.
# NOTE(review): presumably meant to drop cached state - confirm it has any effect.
model.past_key_values = None

# The LM-Infinite conversion below is disabled, so the model is used as loaded.
#model = convert_llama_model(model, local_branch=8192, global_branch=256, safe_mode=False)


In [None]:
# Switch to inference mode before generating.
model = model.eval()

print(f"{model} model loaded!, device:{model.device}")

print("Will write to:", generations_dir)
os.makedirs(generations_dir, exist_ok=True)

# For every task: load its validation split, generate a prediction per
# example, and write all predictions to <generations_dir>/<dataset>.json.
for dataset in datasets:
    generations = dict()
    # Inputs and reference outputs are kept in memory for inspection only;
    # they are not written to disk.
    input_task = dict()
    output_task = dict()
    print(f"Processing {dataset}")
    time = datetime.now().strftime("%d_%m_%Y_%H_%M_%S")
    print(f"time as start {dataset}: {time}")
    print(f"Loading {dataset}")
    data = load_dataset("tau/zero_scrolls", dataset, cache_dir="/home/athul/datasets_cache")
    print(f"Loaded {dataset}")

    # Wrap the dataset itself (not the enumerate object) so tqdm can read its
    # length and show a total and ETA.
    for i, example in enumerate(tqdm(data["validation"], desc=dataset)):
        # Check the cap first so the skipped example is not reported as processed.
        if 0 < max_examples_per_task <= i:
            print(f"Reached {max_examples_per_task} for {dataset}. Breaking")
            break

        print("Processing example:", example["id"])

        model_input = process_model_input(tokenizer, example, max_input_length, device)

        # Greedy decoding (do_sample=False); the sampling arguments are passed
        # through unchanged.
        prediction_token_ids = model.generate(model_input,
                                              max_new_tokens=512,
                                              do_sample=False,
                                              top_p=0,
                                              top_k=0,
                                              temperature=1,
                                              pad_token_id=tokenizer.eos_token_id, )
        # Release cached GPU memory between examples.
        model.past_key_values = None
        torch.cuda.empty_cache()
        gc.collect()

        # Decode only the newly generated tokens, skipping the prompt prefix.
        predicted_text = tokenizer.decode(prediction_token_ids[0][model_input.shape[1]:], skip_special_tokens=True)
        generations[example["id"]] = predicted_text
        input_task[example["id"]] = example["input"]
        output_task[example["id"]] = example["output"]

    out_file_path_pred = os.path.join(generations_dir, f"{dataset}.json")
    with open(out_file_path_pred, 'w') as f_out:
        json.dump(generations, f_out, indent=4)

    print(f"Done generating {len(generations)} examples from {dataset}")
    time = datetime.now().strftime("%d_%m_%Y_%H_%M_%S")
    print(f"time at end: {time}")
    print(f"Look for predictions in {generations_dir}")

In [None]:

    # Same write step as at the end of the generation loop. Run on its own it
    # relies on `dataset` and `generations` still holding the values from the
    # last loop iteration.
    out_file_path_pred = os.path.join(generations_dir, f"{dataset}.json")
    serialized_generations = json.dumps(generations, indent=4)
    with open(out_file_path_pred, 'w') as handle:
        handle.write(serialized_generations)

    print(f"Done generating {len(generations)} examples from {dataset}")
    # Timestamp string used only for the log line below.
    time = datetime.now().strftime("%d_%m_%Y_%H_%M_%S")
    print(f"time at end: {time}")
    print(f"Look for predictions in {generations_dir}")