In [2]:
import time
import torch
from transformers import AutoTokenizer
from utils.utils import load_json

## Prepare text and embeddings

In [3]:
# Runtime configuration for this notebook.
device="mps"  # device string handed to Transformer when the model is built
base_model_dir = "./Phi-3-mini-4k-instruct"
# base_model_dir already carries its own "./" prefix, so the config path is built
# from it directly. Prepending a second "./" would yield "././..." and would turn
# an absolute base_model_dir into a relative path.
config = load_json(f"{base_model_dir}/config.json")

In [4]:
tokenizer = AutoTokenizer.from_pretrained(base_model_dir, clean_up_tokenization_spaces=False)
# Token ids that end generation: the tokenizer's EOS token plus the end-of-turn
# marker. The prompt in this notebook delimits turns with "<|end|>", so that is
# the marker to stop on; "<|eot_id|>" comes from a different chat template and is
# not expected to be present in this vocabulary.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|end|>"),
]
# Reuse the EOS token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
def text_to_ids(text, tokenizer):
    """Tokenize one prompt or a batch of prompts.

    Args:
        text: A single prompt, or a list of prompts. Any value that is not a
            list (a tuple included) is wrapped in a one-element list.
        tokenizer: Callable that receives the list of prompts and returns a
            mapping with an "input_ids" entry.

    Returns:
        The tokenizer's "input_ids" value for the batch.
    """
    # isinstance instead of an exact type comparison, so that list subclasses
    # are treated as a batch rather than being wrapped a second time.
    if not isinstance(text, list):
        text = [text]
    return tokenizer(text)["input_ids"]

## Forward passes

In [6]:
from ops.transformer_ops import Transformer
from ops.generation import generate_text, generate_text_stream

In [7]:
# Weights directory for the project's Transformer
# (the name suggests int8 pickled weights — TODO confirm).
model_dir = "PHI3-MINI-4K-PKL-int8"
# Build the model from the weights directory, the loaded config and the device string.
model = Transformer(
    model_dir,
    config,
    device=device,
)

In [8]:
# One-prompt batch; the prompt carries the user/assistant turn markers inline.
texts = [
    "<|user|>I am going to Paris, what should I see?<|end|><|assistant|>",
]
# Token ids for every prompt in the batch.
input_ids = text_to_ids(text=texts, tokenizer=tokenizer)

In [9]:
# Toggle for the generation cell: True prints pieces as they are produced,
# False generates the full completion first and prints it once.
streaming = True

In [10]:
# Generate a completion for the prompt batch and report throughput.
MAX_GEN_LEN = 128  # value passed as max_gen_len in both modes

# perf_counter is a monotonic, high-resolution clock, so the measured interval
# is not affected by wall-clock adjustments the way time.time() can be.
start_time = time.perf_counter()
if streaming:
    # Echo each piece as soon as it arrives while keeping a running token count.
    pieces = []
    total_tokens_count = 0
    for word, n_tokens in generate_text_stream(
        model, tokenizer, input_ids, max_gen_len=MAX_GEN_LEN, stop_tokens_ids=terminators
    ):
        print(f"{word}", end='', flush=True)
        pieces.append(word)
        total_tokens_count += n_tokens
    delta_time = time.perf_counter() - start_time
    # Full completion as a single string, kept for later inspection.
    output_text = "".join(pieces)
    print()
else:
    outputs, total_tokens_count = generate_text(
        model, tokenizer, input_ids, max_gen_len=MAX_GEN_LEN, stop_tokens_ids=terminators
    )
    delta_time = time.perf_counter() - start_time
    # Show each prompt followed by its completion.
    print([text+output for text, output in zip(texts, outputs)])
print(f"Generation took {delta_time} seconds, {total_tokens_count/delta_time} tokens/s.")

torch.Size([1, 32, 15, 96]) torch.Size([1, 32, 15, 96]) torch.Size([1, 32, 15, 96])
torch.Size([1, 32, 15, 96]) torch.Size([1, 32, 15, 96]) torch.Size([1, 32, 15, 96])
torch.Size([1, 32, 15, 96]) torch.Size([1, 32, 15, 96]) torch.Size([1, 32, 15, 96])
torch.Size([1, 32, 15, 96]) torch.Size([1, 32, 15, 96]) torch.Size([1, 32, 15, 96])
torch.Size([1, 32, 15, 96]) torch.Size([1, 32, 15, 96]) torch.Size([1, 32, 15, 96])
torch.Size([1, 32, 15, 96]) torch.Size([1, 32, 15, 96]) torch.Size([1, 32, 15, 96])
torch.Size([1, 32, 15, 96]) torch.Size([1, 32, 15, 96]) torch.Size([1, 32, 15, 96])
torch.Size([1, 32, 15, 96]) torch.Size([1, 32, 15, 96]) torch.Size([1, 32, 15, 96])
torch.Size([1, 32, 15, 96]) torch.Size([1, 32, 15, 96]) torch.Size([1, 32, 15, 96])
torch.Size([1, 32, 15, 96]) torch.Size([1, 32, 15, 96]) torch.Size([1, 32, 15, 96])
torch.Size([1, 32, 15, 96]) torch.Size([1, 32, 15, 96]) torch.Size([1, 32, 15, 96])
torch.Size([1, 32, 15, 96]) torch.Size([1, 32, 15, 96]) torch.Size([1, 32, 1

KeyboardInterrupt: 