In [2]:
!pip install flash-attn transformers accelerate termcolor altair

import time
from datetime import timedelta

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
from transformers.utils import is_flash_attn_2_available

torch.random.manual_seed(0)

model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",
    device_map="cuda",
    torch_dtype="auto",
    trust_remote_code=True,
    # attn_implementation="flash_attention_2",
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
streamer = TextStreamer(tokenizer, skip_prompt=True)

print("flash_attn_2 available:", is_flash_attn_2_available())



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


flash_attn_2 available: True


In [4]:
def gen(text, preview=True):
    duration_start = time.perf_counter()
    prompt = "<|user|>\n{} <|end|>\n<|assistant|>".format(text)
    tokens = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
    outputs = model.generate(
        tokens,
        max_new_tokens=1024,
        return_dict_in_generate=True,
        streamer=streamer if preview else None,
    )
    output_tokens = outputs.sequences[0]
    output_gen_tokens = output_tokens[
        len(tokens[0]) : -1
    ]  # From just after prompt to just before <|end|> token
    output_string = tokenizer.decode(output_gen_tokens)
    duration_seconds = time.perf_counter() - duration_start
    if preview:
        print(
            "== took {} ({} toks: {}/tok; {} tps) ==".format(
                timedelta(seconds=duration_seconds),
                len(output_gen_tokens),
                timedelta(seconds=duration_seconds / len(output_gen_tokens)),
                len(output_gen_tokens) / duration_seconds,
            )
        )
        print()
    del tokens, outputs, output_tokens, output_gen_tokens
    return output_string


def embed(text, mean_layers=False, mean_tokens=False, prompt_prefix=""):
    duration_start = time.perf_counter()
    if prompt_prefix:
        prompt = "<|user|>\n{}\n```\n{}\n``` <|end|>\n<|assistant|>".format(
            prompt_prefix, text
        )
    else:
        prompt = "<|user|>\n{} <|end|>\n<|assistant|>".format(text)
    tokens = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
    outputs = model(tokens, output_hidden_states=True)
    embedding = outputs.hidden_states
    # print(len(embedding), embedding[0].shape)
    if mean_layers:
        # print(torch.stack(embedding).shape)
        embedding = torch.stack(embedding).mean(dim=0)  # Mean layers
    else:
        embedding = embedding[-1]  # Take last layer

    if mean_tokens:
        embedding = embedding.mean(dim=1)  # Mean tokens
    else:
        embedding = embedding[:, -1, :]  # Take last token

    embedding = embedding[0]  # Take first and only element of batch

    embedding_cpu = embedding.to("cpu").detach()
    del tokens, outputs, embedding
    return embedding_cpu

In [6]:
def gen_diverse(text, preview=True):
    duration_start = time.perf_counter()
    prompt = "<|user|>\n{} <|end|>\n<|assistant|>".format(text)
    tokens = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
    outputs = model.generate(
        tokens,
        max_new_tokens=1024,
        return_dict_in_generate=True,
        streamer=streamer if preview else None,
        num_beams=5,
        num_beam_groups=5,
        diversity_penalty=1.0,
    )
    display(outputs.sequences)
    print(len(outputs.sequences))
    # output_tokens = outputs.sequences[0]
    # output_gen_tokens = output_tokens[
    #     len(tokens[0]) : -1
    # ]  # From just after prompt to just before <|end|> token
    # output_string = tokenizer.decode(output_gen_tokens)
    # duration_seconds = time.perf_counter() - duration_start
    # if preview:
    #     print(
    #         "== took {} ({} toks: {}/tok; {} tps) ==".format(
    #             timedelta(seconds=duration_seconds),
    #             len(output_gen_tokens),
    #             timedelta(seconds=duration_seconds / len(output_gen_tokens)),
    #             len(output_gen_tokens) / duration_seconds,
    #         )
    #     )
    #     print()
    # del tokens, outputs, output_tokens, output_gen_tokens
    # return output_string


gen_diverse("What is a good setting for a picnic in autumn?", preview=False)

You are not running the flash-attention implementation, expect numerical differences.


tensor([[    1, 32010,  1724,   338,   263,  1781,  4444,   363,   263, 11942,
          7823,   297,  1120,  1227, 29973, 29871, 32007, 32001,   319,  1781,
          4444,   363,   263, 11942,  7823,   297,  1120,  1227,   723,   367,
           263,  5763,   293, 14089,   411,   325,  4626,   424,  6416,   900,
         29875,   482, 29889, 13001,   284, 14354,  1033,  3160, 29901,    13,
            13,    13, 29896, 29889,  3579,  3388,   280,  6070,   345,  4815,
          1068, 29901,  8360,   776,   363,   967,   380, 27389,  2479,   310,
          2654, 29892, 24841, 29892,   322, 13328,  2910,   280, 11308, 29892,
           445, 14089, 16688,   263, 14956,   802,  1250,  8865,   363,   263,
          6416, 11942,  7823, 29889,   450, 14089,   756, 20947,   310,  1722,
          8162,   363, 11942, 19254,   292, 29892,   408,  1532,   408, 22049,
          1020,  2719,   363,   263,   454,   275,   545,   368,   380,  1245,
         29889,    13,    13,    13, 29906, 29889,  

1
