In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [2]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# args
checkpoint = "meta-llama/Meta-Llama-3.1-8B-Instruct"
# checkpoint = "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8"
assistant_checkpoint = "meta-llama/Llama-3.2-1B-Instruct"
device = "cuda" if torch.cuda.is_available() else "cpu"

# model
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map='auto')
assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint, device_map='auto')

# overwrite _assisted_decoding
import types
from generation_utils import assisted_decoding
model._assisted_decoding = types.MethodType(assisted_decoding, model)

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 4/4 [00:08<00:00,  2.20s/it]


In [None]:

# input
messages = [
    {'role': 'user', 'content': 'Answer step by step: Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?'},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(prompt, return_tensors="pt").to(device)

# generate output
outputs = model.generate(**inputs, assistant_model=assistant_model,
                         return_dict_in_generate=True, max_length=200,
                         pad_token_id=tokenizer.eos_token_id)

output, token_dict = outputs.sequences, outputs.token_dict
# print(tokenizer.decode(output[0], skip_special_tokens=False))

# highlight tokens
def highlight_tokens(self, token_dict):
    highlight_accept = lambda text: f"<span style='color: orange;'>{text}</span>"
    highlight_reject = lambda text: f"<span style='color: gray; text-decoration: line-through;'>{text}</span>"

    # Improved decode function that handles newlines correctly
    def decode(token):
        if len(token) == 0:
            return ""
        text = self.tokenizer.decode(token[0], skip_special_tokens=False)
        # Replace newlines with HTML break tags
        text = text.replace('\n', '<br>')
        return text

    output_text = ""
    for accept_tokens, reject_tokens, next_token in zip(
        token_dict["accept_tokens"], token_dict["reject_tokens"], token_dict["next_token"]):
        output_text += highlight_accept(decode(accept_tokens))
        output_text += highlight_reject(decode(reject_tokens))
        output_text += decode(next_token)

    # Wrap in a div with white-space: pre-wrap to preserve other whitespace
    output_text = f"<div style='white-space: pre-wrap;'>{output_text}</div>"

    return output_text

text = highlight_tokens(token_dict)
from IPython.display import display, HTML
display(HTML(text))

In [6]:
print(tokenizer.decode(output[0], skip_special_tokens=False))

<|begin_of_text|><|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

<|eot_id|><|start_header_id|>user<|end_header_id|>

Answer step by step: Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

To find out how much Weng earned, we need to first convert the time from minutes to hours, since her hourly rate is given in dollars per hour. 

1 hour = 60 minutes
50 minutes = 50 / 60 hours
50 minutes = 5/6 hour

Now, we multiply the time worked in hours by the hourly rate to find the total earnings.
Weng's hourly rate = $12
Time worked = 5/6 hour
Total earnings = Hourly rate * Time worked
Total earnings = $12 * 5/6
Total earnings = $10

So, Weng earned $10


In [5]:
print(text)

<span style='color: orange; font-size: 16px;'>To find out how much Weng earned, we need to</span><span style='color: gray; font-size: 16px; text-decoration: line-through;'> multiply the number of hours she worked (</span><span style='color: black; font-size: 16px;'> first</span><span style='color: orange; font-size: 16px;'> convert the</span><span style='color: gray; font-size: 16px; text-decoration: line-through;'> 50 minutes into hours. 

There are 60 minutes in an hour, so </span><span style='color: black; font-size: 16px;'> time</span><span style='color: orange; font-size: 16px;'> from minutes to hours</span><span style='color: gray; font-size: 16px; text-decoration: line-through;'>. 

There are 60 minutes in an hour, so 50 minutes is</span><span style='color: black; font-size: 16px;'>,</span><span style='color: orange; font-size: 16px;'> since her hourly</span><span style='color: gray; font-size: 16px; text-decoration: line-through;'> wage is given in dollars per hour. 

There are