In [48]:
import time
import json
import datetime
import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

In [2]:
def format_time(elapsed: float) -> str:
    """Takes a time in seconds and formats it to hh:mm:ss:ms.

    Args:
        elapsed (float): Time period elapsed in seconds.

    Returns:
        str: Elapsed period formatted as a string including milliseconds.
    """
    elapsed_timedelta = datetime.timedelta(seconds=elapsed)
    hrs, remain = divmod(elapsed_timedelta.total_seconds(), 3600)
    mins, secs = divmod(remain, 60)
    ms = int((secs - int(secs)) * 1000)
    time = "{:02}:{:02}:{:02}.{:03}".format(int(hrs), int(mins), int(secs), ms)
    return time

In [3]:
device = "cuda" # or "cpu"
model_path = "ibm-granite/granite-8b-code-instruct"

In [4]:
config = AutoConfig.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device)

Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:08<00:00,  2.16s/it]


### Basic Example

In [5]:
chat = [
    { "role": "user", "content": "Write a code to find the maximum value in a list of numbers." },
]
model.eval()

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(49152, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-35): 36 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=True)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=True)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=True)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=True)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=True)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=True)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=True)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (l

In [6]:
chat = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
input_tokens = tokenizer(chat, return_tensors="pt")
for i in input_tokens:
    input_tokens[i] = input_tokens[i].to(device)

In [7]:
start = time.time()
output = model.generate(**input_tokens, max_new_tokens=100)
print(f"Generation took: {format_time(time.time() - start)} (hh:mm:ss:ms)")

Generation tooK: 00:00:08.434


In [8]:
output = tokenizer.batch_decode(output)
for i in output:
    print(i)

Question:
Write a code to find the maximum value in a list of numbers.

Answer:
```python
def find_max(numbers):
    max_value = numbers[0]
    for num in numbers:
        if num > max_value:
            max_value = num
    return max_value
```<|endoftext|>


### Passkey Example (Tests Recall from Ctx Window)

In [17]:
def get_passkey_prompts(file_name):
    prompts = []
    with open(file_name, 'r') as file:
        for line in file:
            example = json.loads(line)
        
            chat = [
                {"role": "system", "content": example["input"]},
                {"role": "user", "content": "What is the pass key?"}
            ]
            

            prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
            
            prompts.append(prompt)
    
    return prompts

In [51]:
def check_model_response(decoded_response, expected_ans): 
    try: 
        model_answer = decoded_response.split("Answer:\nThe pass key is")[1].strip()
        predicted = model_answer.split('.')[0]
    except IndexError: 
        print("Error: Couldn't find model response or format is incorrect")
        # print(f"Model Answer: {decoded_response}")
        return False
    return int(predicted) == expected_ans

In [19]:
ps_json_path = "passkey_examples.jsonl"
passkey_prompts = get_passkey_prompts(ps_json_path)

#### Short Prompt (fits in default ctx window, < 4K input tokens)

In [20]:
ps_prompt_ex = passkey_prompts[0]
input_tokens = tokenizer(ps_prompt_ex, return_tensors="pt")
input_tokens['input_ids'].shape

torch.Size([1, 3769])

In [24]:
for i in input_tokens:
    input_tokens[i] = input_tokens[i].to(device)

In [25]:
start = time.time()
output = model.generate(**input_tokens, max_new_tokens=100)
print(f"Generation took: {format_time(time.time() - start)} (hh:mm:ss:ms)")

Generation took: 00:00:04.238 (hh:mm:ss:ms)


In [26]:
output = tokenizer.batch_decode(output)
for i in output:
    print(i)

System:
There is an important info hidden inside a lot of irrelevant text. Find it and memorize it. I will quiz you about the important information there.
The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.
The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.
The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.
The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.
The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.
The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.
The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.
The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.
The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.
The grass is green. The sky is blue

In [48]:
check_model_response("".join(output), 72498)

True

### Config change to dynamic RoPE scaling

#### Long Prompt (longer than ctx window, 32K+ input tokens)

In [42]:
long_prompt = passkey_prompts[-1]
input_tokens = tokenizer(long_prompt, return_tensors="pt")
input_tokens['input_ids'].shape

torch.Size([1, 33289])

In [52]:
# values to test
theta_values = [50000, 80000, 160000, 320000]
scaling_factors = [1.0, 2.0, 4.0, 8.0]
results_df = pd.DataFrame(columns=["Theta", "Scaling Factor", "Correct", "Response Time", "Model Output"])

In [54]:
for theta in theta_values:
    for scaling_factor in scaling_factors:
        config = AutoConfig.from_pretrained(model_path)
        setattr(config, 'rope_scaling', {'type': 'dynamic', 'factor': scaling_factor})
        setattr(config, 'rope_theta', theta)
        model = AutoModelForCausalLM.from_pretrained(model_path, config=config, device_map=device)
        
        input_tokens = tokenizer(long_prompt, return_tensors="pt") # using longest prompt from passkey json (33K+ length)
        for i in input_tokens:
            input_tokens[i] = input_tokens[i].to(device)
        start = time.time()
        output = model.generate(**input_tokens, max_new_tokens=100)
        duration = time.time() - start
        response = tokenizer.batch_decode(output)

        expected_ans = 1127250844  # target passkey for long prompt
        is_correct = check_model_response("".join(response), expected_ans)
        new_row = pd.DataFrame({
                "Theta": [theta],
                "Scaling Factor": [scaling_factor],
                "Correct": [is_correct],
                "Response Time": [duration],
                "Model Output": [response]
            })
        results_df = pd.concat([results_df, new_row], ignore_index=True)

Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:08<00:00,  2.22s/it]
  results_df = pd.concat([results_df, new_row], ignore_index=True)


Error: Couldn't find model response or format is incorrect


Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:08<00:00,  2.24s/it]


Error: Couldn't find model response or format is incorrect


Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:08<00:00,  2.22s/it]


Error: Couldn't find model response or format is incorrect


Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:08<00:00,  2.23s/it]


Error: Couldn't find model response or format is incorrect


Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:08<00:00,  2.22s/it]


Error: Couldn't find model response or format is incorrect


Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:08<00:00,  2.22s/it]


Error: Couldn't find model response or format is incorrect


Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:08<00:00,  2.22s/it]


Error: Couldn't find model response or format is incorrect


Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:08<00:00,  2.22s/it]


Error: Couldn't find model response or format is incorrect


Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:08<00:00,  2.23s/it]


Error: Couldn't find model response or format is incorrect


Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:08<00:00,  2.21s/it]


Error: Couldn't find model response or format is incorrect


Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:08<00:00,  2.21s/it]


Error: Couldn't find model response or format is incorrect


Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:08<00:00,  2.21s/it]


Error: Couldn't find model response or format is incorrect


Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:08<00:00,  2.25s/it]


Error: Couldn't find model response or format is incorrect


Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:08<00:00,  2.23s/it]


Error: Couldn't find model response or format is incorrect


Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:08<00:00,  2.25s/it]


Error: Couldn't find model response or format is incorrect


Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:08<00:00,  2.24s/it]


Error: Couldn't find model response or format is incorrect


In [55]:
results_df

Unnamed: 0,Theta,Scaling Factor,Correct,Response Time,Model Output
0,50000,1.0,False,42.560839,[System:\nThere is an important info hidden in...
1,50000,2.0,False,67.333625,[System:\nThere is an important info hidden in...
2,50000,4.0,False,67.353889,[System:\nThere is an important info hidden in...
3,50000,8.0,False,67.348533,[System:\nThere is an important info hidden in...
4,80000,1.0,False,67.36725,[System:\nThere is an important info hidden in...
5,80000,2.0,False,67.325441,[System:\nThere is an important info hidden in...
6,80000,4.0,False,67.347181,[System:\nThere is an important info hidden in...
7,80000,8.0,False,67.366786,[System:\nThere is an important info hidden in...
8,160000,1.0,False,67.360672,[System:\nThere is an important info hidden in...
9,160000,2.0,False,67.332454,[System:\nThere is an important info hidden in...


In [56]:
results_df.to_csv("RoPE_hparams_results.csv", index=False)