In [1]:
import json
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "5"
import sys
from datetime import datetime
import random
import gc

import numpy as np
import torch
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import set_seed as hf_set_seed
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
lm_infinite_root = os.path.abspath('..')

# Prepend it so Python finds 'models/llama.py' first
sys.path.insert(0, lm_infinite_root)

# Now you can import the converter
from models.llama import convert_llama_model

In [3]:
def llama3_prompt(user_message):
    BEGIN = "<|begin_of_text|>"
    START = "<|start_header_id|>"
    END = "<|end_header_id|>"
    EOT = "<|eot_id|>"

    system_prompt = (
        "Always follow the task instruction carefully."
        "The first paragraph before the first double line break contains the task instruction."
        "Generate text as a natural continuation of the user message."
        "Do not include any meta-commentary or explanations or your own thoughts."
    )

    prompt = (
        f"{BEGIN}"
        f"{START}system{END}\n\n{system_prompt}{EOT}\n"
        f"{START}user{END}\n\n{user_message}{EOT}\n"
        f"{START}assistant{END}\n\n"
    )
    return prompt


model_to_chat_template = {
    "/assets/models/meta-llama-3.2-instruct-3b": llama3_prompt 
}

In [4]:
datasets =["gov_report", "summ_screen_fd", "qmsum", "qasper","narrative_qa", "quality"]

In [5]:
model_to_max_input_tokens = 4096

In [6]:
def trim_doc_keeping_suffix(tokenizer, tokenized_input_full, example, suffix_index, max_tokens, device):
    seperator_and_suffix = f"{example['truncation_seperator'].strip()}\n\n{example['input'][suffix_index:].strip()}\n"
    tokenized_seperator_and_suffix = tokenizer(seperator_and_suffix, return_tensors="pt").input_ids.to(device)
    tokenized_input_trimmed = tokenized_input_full[:, :max_tokens - tokenized_seperator_and_suffix.shape[1]]
    tokenized_input = torch.cat([tokenized_input_trimmed, tokenized_seperator_and_suffix], dim=1)
    return tokenized_input

In [7]:
model_name = "/assets/models/meta-llama-3.2-instruct-3b"
model_print_name = "llama-basic_4096"
max_examples_per_task = -1


In [8]:
def process_model_input(tokenizer, example, max_tokens, device):
    instruction = example["input"][:example['document_start_index']]
    truncation_seperator = example['truncation_seperator']

    query = example["input"][example['query_start_index']:]
    if len(query) == 0:
        query = None
    doc = example["input"][example['document_start_index']
        :example['document_end_index']]
    
    input_text = f"{instruction}{doc}{truncation_seperator}{query or ''}"
    input = model_to_chat_template.get(model_name, lambda x: x)(input_text)
    # print(f"Input: {input}")
    tokenized_input = tokenizer(
        input, return_tensors="pt").input_ids.to(device)

    return tokenized_input

In [9]:
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"

generations_dir = "generations/ipynb"
seed = 43
random.seed(seed)
np.random.seed(seed)
hf_set_seed(seed)
print("Params:")
print(f"model: {model_name}")
generations_dir = os.path.join(generations_dir, model_print_name.replace("/", "_"))
print(f"generations_dir: {generations_dir}")
print(f"max_examples_per_task: {max_examples_per_task}")
print("=" * 50)
time = datetime.now().strftime("%d_%m_%Y_%H_%M_%S")
print(f"time as start: {time}")

print("Loading tokenizer")
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token_id = tokenizer.eos_token_id
print(f"Loading model: {model_name}")
device = "cuda" if torch.cuda.is_available() else "cpu"

max_input_length = model_to_max_input_tokens

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    use_flash_attention_2=False,
    trust_remote_code=True,
)
torch.cuda.empty_cache()
gc.collect()
model.past_key_values = None

#model = convert_llama_model(model, local_branch=8192, global_branch=256, safe_mode=False)


Params:
model: /assets/models/meta-llama-3.2-instruct-3b
generations_dir: generations/ipynb/llama-basic_4096
max_examples_per_task: -1
time as start: 27_04_2025_21_40_14
Loading tokenizer
Loading model: /assets/models/meta-llama-3.2-instruct-3b


Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.41it/s]


In [10]:
model = model.eval()

print(f"{model} model loaded!, device:{model.device}")

print("Will write to:", generations_dir)
os.makedirs(generations_dir, exist_ok=True)
for dataset in datasets:
    generations = dict()
    input_task = dict()
    output_task = dict()
    print(f"Processing {dataset}")
    time = datetime.now().strftime("%d_%m_%Y_%H_%M_%S")
    print(f"time as start {dataset}: {time}")
    print(f"Loading {dataset}")
    data = load_dataset("tau/zero_scrolls", dataset, cache_dir="/home/athul/datasets_cache")
    print(f"Loaded {dataset}")

    for i, example in tqdm(enumerate(data["validation"])):
        print("Processing example:", example["id"])

        if 0 < max_examples_per_task == i:
            print(f"Reached {max_examples_per_task} for {dataset}. Breaking")
            break

        model_input = process_model_input(tokenizer, example, max_input_length, device)

        prediction_token_ids = model.generate(model_input,
                                                  max_new_tokens=512,
                                                  do_sample=False,
                                                  top_p=0,
                                                  top_k=0,
                                                  temperature=1,
                                                  pad_token_id=tokenizer.eos_token_id, )
        model.past_key_values = None
        torch.cuda.empty_cache()
        gc.collect()

        predicted_text = tokenizer.decode(prediction_token_ids[0][model_input.shape[1]:], skip_special_tokens=True)
        generations[example["id"]] = predicted_text
        input_task[example["id"]] = example["input"]
        output_task[example["id"]] = example["output"]
        #break

    out_file_path_pred = os.path.join(generations_dir, f"{dataset}.json")
    with open(out_file_path_pred, 'w') as f_out:
        json.dump(generations, f_out, indent=4)
    


    print(f"Done generating {len(generations)} examples from {dataset}")
    time = datetime.now().strftime("%d_%m_%Y_%H_%M_%S")
    print(f"time at end: {time}")
    print(f"Look for predictions in {generations_dir}")

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 3072)
    (layers): ModuleList(
      (0-27): 28 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=3072, out_features=3072, bias=False)
          (k_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (v_proj): Linear(in_features=3072, out_features=1024, bias=False)
          (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=3072, out_features=8192, bias=False)
          (up_proj): Linear(in_features=3072, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=3072, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((3072,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((3072,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((3072,), eps=1e-05)
    (rotary_emb

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Processing example: crs_R45461


1it [00:17, 17.52s/it]

Processing example: crs_R44668


2it [00:29, 14.37s/it]

Processing example: crs_R45546


3it [00:45, 15.21s/it]

Processing example: crs_R45732


4it [01:01, 15.36s/it]

Processing example: crs_R45237


5it [01:11, 13.49s/it]

Processing example: crs_RS21212


6it [01:24, 13.24s/it]

Processing example: crs_RL30478


7it [01:56, 19.26s/it]

Processing example: crs_RL32665


8it [02:23, 21.72s/it]

Processing example: crs_RS22373


9it [02:59, 26.22s/it]

Processing example: gao_GAO-19-207


10it [03:10, 21.59s/it]

Processing example: gao_GAO-18-486


11it [03:28, 20.53s/it]

Processing example: gao_GAO-19-220


12it [03:37, 16.88s/it]

Processing example: gao_GAO-19-148


13it [03:48, 15.17s/it]

Processing example: gao_GAO-18-271


14it [04:07, 16.38s/it]

Processing example: gao_GAO-18-78


15it [04:23, 16.38s/it]

Processing example: gao_GAO-18-420


16it [04:39, 16.09s/it]

Processing example: gao_GAO-17-781T


17it [04:49, 14.27s/it]

Processing example: gao_GAO-18-249


18it [05:03, 14.27s/it]

Processing example: gao_GAO-18-7


19it [05:23, 16.06s/it]

Processing example: gao_GAO-18-403


20it [05:34, 16.74s/it]


Done generating 20 examples from gov_report
time at end: 27_04_2025_21_45_54
Look for predictions in generations/ipynb/llama-basic_4096
Processing summ_screen_fd
time as start summ_screen_fd: 27_04_2025_21_45_54
Loading summ_screen_fd
Loaded summ_screen_fd


0it [00:00, ?it/s]

Processing example: fd_FRIENDS_07x04


1it [00:04,  4.78s/it]

Processing example: fd_Buffy_the_Vampire_Slayer_04x14


2it [00:23, 12.80s/it]

Processing example: fd_The_Office_05x24


3it [00:28,  9.22s/it]

Processing example: fd_Angel_03x07


4it [00:35,  8.34s/it]

Processing example: fd_The_Office_03x20


5it [00:39,  6.77s/it]

Processing example: fd_Doctor_Who_1963_14x21


6it [00:44,  6.37s/it]

Processing example: fd_Frasier_07x04


7it [00:50,  6.21s/it]

Processing example: fd_Frasier_10x16


8it [00:55,  5.67s/it]

Processing example: fd_Justified_06x12


9it [01:01,  5.86s/it]

Processing example: fd_Charmed_06x13


10it [01:07,  5.92s/it]

Processing example: fd_Pretty_Little_Liars_01x03


11it [01:13,  6.05s/it]

Processing example: fd_Buffy_the_Vampire_Slayer_06x14


12it [01:23,  7.05s/it]

Processing example: fd_Justified_04x01


13it [01:27,  6.31s/it]

Processing example: fd_Doctor_Who_01x02


14it [01:36,  7.04s/it]

Processing example: fd_The_Office_08x23


15it [01:41,  6.46s/it]

Processing example: fd_Justified_06x11


16it [01:52,  7.82s/it]

Processing example: fd_CSI__Crime_Scene_Investigation_06x02


17it [02:00,  7.96s/it]

Processing example: fd_Alias_01x04


18it [02:06,  7.25s/it]

Processing example: fd_The_Vampire_Diaries_01x11


19it [02:12,  6.76s/it]

Processing example: fd_The_O.C._03x01


20it [02:28,  7.44s/it]


Done generating 20 examples from summ_screen_fd
time at end: 27_04_2025_21_48_26
Look for predictions in generations/ipynb/llama-basic_4096
Processing qmsum
time as start qmsum: 27_04_2025_21_48_26
Loading qmsum
Loaded qmsum


0it [00:00, ?it/s]

Processing example: va-sq-11


1it [00:02,  2.30s/it]

Processing example: va-sq-50


2it [00:09,  5.09s/it]

Processing example: va-sq-57


3it [00:12,  4.35s/it]

Processing example: va-sq-58


4it [00:21,  6.21s/it]

Processing example: va-gq-66


5it [00:28,  6.24s/it]

Processing example: va-gq-93


6it [00:39,  7.91s/it]

Processing example: va-sq-108


7it [00:48,  8.48s/it]

Processing example: va-sq-117


8it [00:57,  8.53s/it]

Processing example: va-sq-125


9it [01:02,  7.24s/it]

Processing example: va-sq-145


10it [01:03,  5.56s/it]

Processing example: va-gq-166


11it [01:13,  6.72s/it]

Processing example: va-sq-179


12it [01:16,  5.66s/it]

Processing example: va-sq-184


13it [01:23,  6.19s/it]

Processing example: va-sq-186


14it [01:31,  6.75s/it]

Processing example: va-sq-216


15it [01:33,  5.27s/it]

Processing example: va-sq-219


16it [01:35,  4.08s/it]

Processing example: va-sq-225


17it [01:47,  6.63s/it]

Processing example: va-sq-250


18it [01:54,  6.79s/it]

Processing example: va-sq-258


19it [01:59,  6.14s/it]

Processing example: va-sq-264


20it [02:03,  6.17s/it]


Done generating 20 examples from qmsum
time at end: 27_04_2025_21_50_32
Look for predictions in generations/ipynb/llama-basic_4096
Processing qasper
time as start qasper: 27_04_2025_21_50_32
Loading qasper
Loaded qasper


0it [00:00, ?it/s]

Processing example: 3fad42be0fb2052bb404b989cc7d58b440cd23a0


1it [00:00,  2.11it/s]

Processing example: 3fad42be0fb2052bb404b989cc7d58b440cd23a0


2it [00:00,  2.15it/s]

Processing example: 8bf7f1f93d0a2816234d36395ab40c481be9a0e0


3it [00:01,  1.57it/s]

Processing example: 8bf7f1f93d0a2816234d36395ab40c481be9a0e0


4it [00:02,  1.38it/s]

Processing example: 0f12dc077fe8e5b95ca9163cea1dd17195c96929


5it [00:04,  1.04s/it]

Processing example: 0f12dc077fe8e5b95ca9163cea1dd17195c96929


6it [00:05,  1.22s/it]

Processing example: 518dae6f936882152c162058895db4eca815e649


7it [00:06,  1.14s/it]

Processing example: 58ef2442450c392bfc55c4dc35f216542f5f2dbb


8it [00:07,  1.03s/it]

Processing example: 58ef2442450c392bfc55c4dc35f216542f5f2dbb


9it [00:08,  1.05it/s]

Processing example: 290ee79b5e3872e0496a6a0fc9b103ab7d8f6c30


10it [00:09,  1.05it/s]

Processing example: ab9b0bde6113ffef8eb1c39919d21e5913a05081


11it [00:10,  1.06it/s]

Processing example: ff338921e34c15baf1eae0074938bf79ee65fdd2


12it [00:12,  1.41s/it]

Processing example: 1b1a30e9e68a9ae76af467e60cefb180d135e285


13it [00:13,  1.23s/it]

Processing example: dea9e7fe8e47da5e7f31d9b1a46ebe34e731a596


14it [00:14,  1.14s/it]

Processing example: dea9e7fe8e47da5e7f31d9b1a46ebe34e731a596


15it [00:15,  1.07s/it]

Processing example: 3355918bbdccac644afe441f085d0ffbbad565d7


16it [00:19,  1.90s/it]

Processing example: d9980676a83295dda37c20cfd5d58e574d0a4859


17it [00:22,  2.42s/it]

Processing example: d9980676a83295dda37c20cfd5d58e574d0a4859


18it [00:26,  2.80s/it]

Processing example: 79a44a68bb57b375d8a57a0a7f522d33476d9f33


19it [00:27,  2.20s/it]

Processing example: 79a44a68bb57b375d8a57a0a7f522d33476d9f33


20it [00:28,  1.78s/it]

Processing example: 76ed74788e3eb3321e646c48ae8bf6cdfe46dca1


21it [00:33,  2.76s/it]

Processing example: 8e52637026bee9061f9558178eaec08279bf7ac6


22it [00:33,  2.15s/it]

Processing example: 8e52637026bee9061f9558178eaec08279bf7ac6


23it [00:34,  1.73s/it]

Processing example: 3116453e35352a3a90ee5b12246dc7f2e60cfc59


24it [00:36,  1.76s/it]

Processing example: 3116453e35352a3a90ee5b12246dc7f2e60cfc59


25it [00:38,  1.78s/it]

Processing example: 4e748cb2b5e74d905d9b24b53be6cfdf326e8054


26it [00:38,  1.42s/it]

Processing example: b970f48d30775d3468952795bc72976baab3438e


27it [00:40,  1.58s/it]

Processing example: c70bafc35e27be9d1efae60596bc0dd390c124c0


28it [00:41,  1.48s/it]


Done generating 19 examples from qasper
time at end: 27_04_2025_21_51_15
Look for predictions in generations/ipynb/llama-basic_4096
Processing narrative_qa
time as start narrative_qa: 27_04_2025_21_51_15
Loading narrative_qa
Loaded narrative_qa


0it [00:00, ?it/s]

Processing example: 3858


1it [00:05,  5.26s/it]

Processing example: 3858


2it [00:10,  5.26s/it]

Processing example: 5947


3it [00:22,  8.27s/it]

Processing example: 5947


4it [00:34,  9.69s/it]

Processing example: 12164


5it [01:14, 20.83s/it]

Processing example: 12164


6it [01:55, 27.56s/it]

Processing example: 23148


7it [01:58, 19.64s/it]

Processing example: 23148


8it [02:02, 14.44s/it]

Processing example: 25278


9it [02:03, 10.33s/it]

Processing example: 25278


10it [02:04,  7.54s/it]

Processing example: 33068


11it [02:21, 10.40s/it]

Processing example: 33068


12it [02:38, 12.37s/it]

Processing example: 35134


13it [02:42,  9.72s/it]

Processing example: 35134


14it [02:45,  7.88s/it]

Processing example: 38322


15it [02:47,  6.17s/it]

Processing example: 38322


16it [02:50,  4.98s/it]

Processing example: 42586


17it [02:52,  4.15s/it]

Processing example: 42586


18it [02:54,  3.57s/it]

Processing example: 16886


19it [03:03,  5.05s/it]

Processing example: 16886


20it [03:11,  9.58s/it]


Done generating 10 examples from narrative_qa
time at end: 27_04_2025_21_54_29
Look for predictions in generations/ipynb/llama-basic_4096
Processing quality
time as start quality: 27_04_2025_21_54_29
Loading quality
Loaded quality


0it [00:00, ?it/s]

Processing example: 62139_J05FWZR6_1


1it [00:00,  1.36it/s]

Processing example: 52855_MV65I88C_9


2it [00:02,  1.30s/it]

Processing example: 62085_C1SL2YBE_3


3it [00:03,  1.01s/it]

Processing example: 63616_MQ1O9T2Q_6


4it [00:03,  1.10it/s]

Processing example: 63833_V187YO4H_2


5it [00:04,  1.20it/s]

Processing example: 63392_7YS4HHFI_6


6it [00:05,  1.26it/s]

Processing example: 63473_1VIHQ8TY_4


7it [00:06,  1.27it/s]

Processing example: 51650_B3KKWWD1_7


8it [00:06,  1.26it/s]

Processing example: 51274_8Q2YNHG5_6


9it [00:07,  1.25it/s]

Processing example: 20077_ZF5G55FD_1


10it [00:08,  1.48it/s]

Processing example: 22579_RQ3GB4A1_3


11it [00:08,  1.39it/s]

Processing example: 22867_TJ9SPIHC_9


12it [00:09,  1.30it/s]

Processing example: 22875_L821878U_6


13it [00:10,  1.49it/s]

Processing example: 22967_0XT2L7PI_7


14it [00:10,  1.48it/s]

Processing example: 22867_IZGAWLCJ_4


15it [00:11,  1.36it/s]

Processing example: 22462_BUA2LH2S_5


16it [00:12,  1.34it/s]

Processing example: 31736_TV0CUXDH_4


17it [00:13,  1.35it/s]

Processing example: 99927_EVLEI3Q2_6


18it [00:13,  1.48it/s]

Processing example: 31282_BQYW9TCH_4


19it [00:14,  1.37it/s]

Processing example: 99914_0Q5X8VEX_4


20it [00:15,  1.58it/s]

Processing example: 32665_VRYQXG3Y_9


21it [00:15,  1.35it/s]

Done generating 21 examples from quality
time at end: 27_04_2025_21_54_47
Look for predictions in generations/ipynb/llama-basic_4096



