## Imports

In [12]:
import warnings
warnings.filterwarnings("ignore")
import sys
import os
sys.path.append(os.path.abspath(os.path.dirname(os.getcwd())))
import time
import json
import datetime
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import SelfExtend

In [2]:
def format_time(elapsed: float) -> str:
    """Takes a time in seconds and formats it to hh:mm:ss:ms.

    Args:
        elapsed (float): Time period elapsed in seconds.

    Returns:
        str: Elapsed period formatted as a string including milliseconds.
    """
    elapsed_timedelta = datetime.timedelta(seconds=elapsed)
    hrs, remain = divmod(elapsed_timedelta.total_seconds(), 3600)
    mins, secs = divmod(remain, 60)
    ms = int((secs - int(secs)) * 1000)
    time = "{:02}:{:02}:{:02}.{:03}".format(int(hrs), int(mins), int(secs), ms)
    return time

## Setup

### Load Model

In [3]:
device = "cuda" # or "cpu"
model_path = "ibm-granite/granite-8b-code-instruct"

In [5]:
config = AutoConfig.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, device_map='auto', attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16)

Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:51<00:00, 12.76s/it]


### Self-Extend

In [6]:
def get_hparams(L, N):
    """
    Determine all possible neighbor window size W and group size G based on empirical rule.

    Args:
        L (int): Pretraining context window.
        N (int): Target extension length.

    Returns:
        list[tuple[int, int]]: List of tuples, where each tuple contains a valid W and G.
    """
    alpha_range = [0.5, 2/3]
    valid_hparams = []

    for alpha in alpha_range:
        max_allowed = alpha * L

        for W in range(1, L+1):
            if W == max_allowed:
                continue
            G = (N - W) / (max_allowed - W)
            if G > 0 and G == int(G):
                valid_hparams.append((W, int(G)))

    if not valid_hparams:
        raise ValueError("No valid combination of W and G found")

    return valid_hparams

In [7]:
L = 4096
N = 32_768 * 2

2-64 are reasonable for `group_size`; 512-1536 are feasible for `neighbor_window`. But larger `group_size` and smaller `neighbor_window` are also good in many cases.

In [8]:
wg_pairs = get_hparams(L, N)
wg_pairs

[(64, 33),
 (1024, 63),
 (1056, 65),
 (1536, 125),
 (1552, 129),
 (1792, 249),
 (1800, 257),
 (1920, 497),
 (1924, 513),
 (1984, 993),
 (1986, 1025),
 (2016, 1985),
 (2017, 2049),
 (2032, 3969),
 (2040, 7937),
 (2044, 15873),
 (2046, 31745),
 (2047, 63489),
 (768, 33)]

In [9]:
window_size = 1024
group_size = 63
use_flash = True

In [10]:
SelfExtend.apply(model, group_size, window_size, enable_flash_attention=use_flash, flash_attention_impl="flash_attn")

Using flash_attn flash self_extend!!


### Passkey Setup for NIAH

In [11]:
def get_passkey_prompts(file_name):
    prompts = []
    with open(file_name, 'r') as file:
        for line in file:
            example = json.loads(line)
        
            chat = [
                {"role": "system", "content": example["input"]},
                {"role": "user", "content": "What is the pass key?"}
            ]
            

            prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
            
            prompts.append(prompt)
    
    return prompts

In [12]:
def get_passkey_targets(file_name):
    targets = []
    positions = []
    with open(file_name, 'r') as file:
        for line in file:
            example = json.loads(line)
            targets.append(example['target'])
            positions.append(example['passkey_position'])
    return targets, positions

In [13]:
def check_model_response(decoded_response, expected_ans): 
    try: 
        model_answer = decoded_response.split("Answer:\nThe pass key is")[1].strip()
        predicted = model_answer.split('.')[0]
    except IndexError: 
        print("Error: Couldn't find model response or format is incorrect")
        # print(f"Model Answer: {decoded_response}")
        return False
    return int(predicted) == int(expected_ans)

In [14]:
ps_json_path = "passkey_examples2.jsonl"
passkey_prompts = get_passkey_prompts(ps_json_path)
passkey_targets, passkey_positions = get_passkey_targets(ps_json_path)

## NIAH using Passkey Prompts

In [15]:
tokenized_prompts = []
for prompt, pos, target in zip(passkey_prompts, passkey_positions, passkey_targets): 
    input_tokens = tokenizer(prompt, return_tensors="pt")
    tokenized_prompts.append(input_tokens)
    print(f"Prompt Length: {input_tokens['input_ids'].shape[1]:,}; Target value (hidden in context @ position {pos:,}): {int(target):,}")

Prompt Length: 3,769; Target value (hidden in context @ position 3,457): 72,498
Prompt Length: 5,355; Target value (hidden in context @ position 3,457): 89,427
Prompt Length: 10,399; Target value (hidden in context @ position 5,907): 58,328
Prompt Length: 33,289; Target value (hidden in context @ position 17,707): 1,127,250,844
Prompt Length: 129,773; Target value (hidden in context @ position 17,707): 123,456,789


In [16]:
answers = []
for input_tokens, pos, target in zip(tokenized_prompts, passkey_positions, passkey_targets): 
    for i in input_tokens:
        input_tokens[i] = input_tokens[i].to(device)
    start = time.time()
    output = model.generate(**input_tokens, max_new_tokens=100)
    print(f"Generation took: {format_time(time.time() - start)} (hh:mm:ss:ms)")
    length = input_tokens['input_ids'].shape[1]
    del input_tokens
    response = tokenizer.batch_decode(output)
    found = check_model_response("".join(response), target)
    if found: 
        print(f"\tModel recalled passkey ({int(target):,}) from position {pos:,} in a prompt of {length:,} length!")
    answers.append(response)
    torch.cuda.empty_cache()
    

Generation took: 00:00:06.803 (hh:mm:ss:ms)
	Model recalled passkey (72,498) from position 3,457 in a prompt of 3,769 length!


This is a friendly reminder - the current text generation call will exceed the model's predefined maximum length (4096). Depending on the model, you may observe exceptions, performance degradation, or nothing at all.


Generation took: 00:00:01.251 (hh:mm:ss:ms)
	Model recalled passkey (89,427) from position 3,457 in a prompt of 5,355 length!
Generation took: 00:00:01.945 (hh:mm:ss:ms)
	Model recalled passkey (58,328) from position 5,907 in a prompt of 10,399 length!
Generation took: 00:00:07.618 (hh:mm:ss:ms)
	Model recalled passkey (1,127,250,844) from position 17,707 in a prompt of 33,289 length!
Generation took: 00:00:43.580 (hh:mm:ss:ms)
Error: Couldn't find model response or format is incorrect


In [33]:
for i, ans in enumerate(answers): 
    print(f"Prompt {i+1} answer:")
    print(ans[0][-85:])
    print('-' * 50)

Prompt 1 answer:
 again.

Question:
What is the pass key?

Answer:
The pass key is 72498.<|endoftext|>
--------------------------------------------------
Prompt 2 answer:
 again.

Question:
What is the pass key?

Answer:
The pass key is 89427.<|endoftext|>
--------------------------------------------------
Prompt 3 answer:
 again.

Question:
What is the pass key?

Answer:
The pass key is 58328.<|endoftext|>
--------------------------------------------------
Prompt 4 answer:
n.

Question:
What is the pass key?

Answer:
The pass key is 1127250844.<|endoftext|>
--------------------------------------------------
Prompt 5 answer:
again.

Question:
What is the pass key?

Answer:
The answer is: 12345678<|endoftext|>
--------------------------------------------------
