In [1]:
!pip install flash-attn transformers accelerate

import time
from datetime import timedelta

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
from transformers.utils import is_flash_attn_2_available

# Seed torch's RNG so runs of this notebook are repeatable.
torch.random.manual_seed(0)

# Single source of truth for the checkpoint so the model and the tokenizer
# cannot drift apart.
MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    # device_map already loads the weights onto the GPU, so no extra
    # .to("cuda") call is needed afterwards.
    device_map="cuda",
    torch_dtype="auto",
    # NOTE(review): trust_remote_code runs Python code shipped with the
    # checkpoint repository; only keep it for checkpoints you trust.
    trust_remote_code=True,
    # attn_implementation="flash_attention_2",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Streams generated text to stdout as it is produced, omitting the prompt.
streamer = TextStreamer(tokenizer, skip_prompt=True)

print("flash_attn_2 available:", is_flash_attn_2_available())



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


flash_attn_2 available: True


In [2]:
def gen(text, preview=True):
    """Generate the model's reply to a single user message.

    The text is wrapped in the ``<|user|> ... <|end|> <|assistant|>`` chat
    template, sent through ``model.generate`` with at most 1024 new tokens,
    and the newly generated tokens are decoded back to a string.

    Args:
        text: The user message to send to the model.
        preview: When true, stream the reply through ``streamer`` while it is
            generated and print a timing summary afterwards.

    Returns:
        The decoded reply, without the prompt and without the trailing stop
        token (when generation ended on one).
    """
    duration_start = time.perf_counter()
    prompt = "<|user|>\n{} <|end|>\n<|assistant|>".format(text)
    tokens = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
    outputs = model.generate(
        tokens,
        max_new_tokens=1024,
        return_dict_in_generate=True,
        streamer=streamer if preview else None,
    )
    # Everything after the prompt is newly generated.
    output_gen_tokens = outputs.sequences[0][tokens.shape[-1] :]

    # Collect the ids that can terminate generation. The generation config may
    # hold a single id, a list of ids, or nothing at all.
    eos_ids = model.generation_config.eos_token_id
    if eos_ids is None:
        eos_ids = []
    elif isinstance(eos_ids, int):
        eos_ids = [eos_ids]
    stop_ids = set(eos_ids) | {tokenizer.eos_token_id}

    # Strip the final token only when it really is a stop token. If generation
    # was cut off by max_new_tokens, the last token is content and must stay.
    if len(output_gen_tokens) > 0 and int(output_gen_tokens[-1]) in stop_ids:
        output_gen_tokens = output_gen_tokens[:-1]

    output_string = tokenizer.decode(output_gen_tokens)
    duration_seconds = time.perf_counter() - duration_start
    if preview:
        token_count = len(output_gen_tokens)
        # Avoid ZeroDivisionError when the model produced no content tokens.
        per_token = (
            timedelta(seconds=duration_seconds / token_count)
            if token_count
            else "n/a"
        )
        print(
            "== took {} ({} toks: {}/tok; {} tps) ==".format(
                timedelta(seconds=duration_seconds),
                token_count,
                per_token,
                token_count / duration_seconds,
            )
        )
        print()
    return output_string


# Smoke-test gen() on three unrelated prompts. preview defaults to True, so
# each call streams its reply and prints a timing line; only the last call's
# return value is echoed as the cell output.
gen("What is the closest star to the Sun?")
gen("What is the difference between hue, saturation, and value in exactly 30 words?")
gen("Where is Waldo?")

You are not running the flash-attention implementation, expect numerical differences.


The closest star to the Sun is Proxima Centauri. It is part of the Alpha Centauri star system, which also includes Alpha Centauri A and Alpha Centauri B. Proxima Centauri is approximately 4.24 light-years away from the Sun. It is a red dwarf star and is the closest known exoplanet host, with at least two confirmed planets, Proxima Centauri b and Proxima Centauri c, orbiting it.<|end|>
== took 0:00:05.284810 (113 toks: 0:00:00.046768/tok; 21.382037633841176 tps) ==

Hue refers to color's dominant wavelength, saturation measures intensity, and value indicates brightness or darkness. Together, they define a color's appearance.<|end|>
== took 0:00:01.412995 (37 toks: 0:00:00.038189/tok; 26.185519531137246 tps) ==

I'm unable to assist with that. However, Waldo is a fictional character from a series of children's books, and his location in those stories is always hidden among various illustrations. If you're referring to a different context, please provide more details.<|end|>
== took 0:00:

"I'm unable to assist with that. However, Waldo is a fictional character from a series of children's books, and his location in those stories is always hidden among various illustrations. If you're referring to a different context, please provide more details."

In [3]:
# Generate one sample conversation and keep it in `conv`; it is reused below
# as the text to summarize.
conv = gen(
    "Please generate an example conversation between participant1 and participant2 about frogs. Make it last 20 utterances and have a follow-up. Please label utterance numbers and separate utterances with newlines."
)

1. Participant1: Hey, have you ever been fascinated by frogs?
2. Participant2: Actually, yes! I find them quite interesting. Why do you ask?
3. Participant1: I was just reading about their life cycle and it's quite fascinating.
4. Participant2: Oh, I'd love to hear more about it. What did you learn?
5. Participant1: Well, did you know that frogs start their life as eggs?
6. Participant2: Yes, I've heard about that. The eggs hatch into tadpoles, right?
7. Participant1: Exactly! Tadpoles live in water and breathe through gills.
8. Participant2: That's so cool! And then they undergo metamorphosis to become adult frogs.
9. Participant1: Yes, during metamorphosis, they develop lungs and legs for life on land.
10. Participant2: I've always wondered how they manage to jump so high and far.
11. Participant1: It's all about their powerful hind legs and the elastic energy stored in their tendons.
12. Participant2: That's amazing! I've also heard that frogs have a unique way of communicating.
13.

In [4]:
# Summarization prompt template: the single {} placeholder is filled with the
# conversation text. The literal's leading newline and indentation are part of
# what is sent to the model.
prompt = """
    Summarize the following conversation between two participants in no more than 200 words. Include the topic of conversation, 2-3 bullet points of discussion, and any follow-up action items. Output each of these pieces of information as its own section with a markdown heading. When referencing particular text in the conversation, specify the utterance with the utterance number in brackets, [123].
    
    {}
"""
# Summarize the conversation generated earlier (`conv`).
gen(prompt.format(conv))

# Topic of Conversation
- Participants discussing the fascinating life cycle and characteristics of frogs.

# Discussion Points
- Frogs start their life as eggs, which hatch into tadpoles that live in water and breathe through gills.
- During metamorphosis, tadpoles develop lungs and legs to live on land.
- Frogs have powerful hind legs and elastic tendons that enable them to jump high and far.
- Frogs communicate using vocalizations, body language, and chemical signals.
- Some frogs can change their skin color.

# Follow-up Actions
- Participant1 will research nature reserves and zoos with a good frog exhibit.
- Participant2 will look for local frog enthusiasts and experts to join the trip.
- Both participants will reconvene next week to share their findings and plan the trip.<|end|>
== took 0:00:08.768498 (209 toks: 0:00:00.041955/tok; 23.83532570165608 tps) ==



'# Topic of Conversation\n- Participants discussing the fascinating life cycle and characteristics of frogs.\n\n# Discussion Points\n- Frogs start their life as eggs, which hatch into tadpoles that live in water and breathe through gills.\n- During metamorphosis, tadpoles develop lungs and legs to live on land.\n- Frogs have powerful hind legs and elastic tendons that enable them to jump high and far.\n- Frogs communicate using vocalizations, body language, and chemical signals.\n- Some frogs can change their skin color.\n\n# Follow-up Actions\n- Participant1 will research nature reserves and zoos with a good frog exhibit.\n- Participant2 will look for local frog enthusiasts and experts to join the trip.\n- Both participants will reconvene next week to share their findings and plan the trip.'

# Strategy A: Ask LLM to generate example data of different types

In [5]:
# Ask the model for seed topics. The reply is one comma-separated string, and
# the model does not necessarily return exactly 30 items (the recorded run
# returned 32).
nouns = gen("Please give a list of 30 nouns, comma-separated.")

apple, book, car, dog, elephant, flower, guitar, house, ice cream, jacket, kite, laptop, mountain, necklace, ocean, piano, rainbow, skateboard, tiger, umbrella, violin, waterfall, xylophone, yacht, zebra, apple tree, bookshelf, carrot, doghouse, elephant tusk, flower vase, guitar case<|end|>
== took 0:00:03.602849 (96 toks: 0:00:00.037530/tok; 26.645581174504784 tps) ==



In [7]:
# Split on bare commas and trim each item, so uneven spacing or a trailing
# newline in the model's reply does not leak into the nouns. Empty items
# (e.g. from a trailing comma) are dropped.
nouns_list = [noun.strip() for noun in nouns.split(",") if noun.strip()]
nouns_list

['apple',
 'book',
 'car',
 'dog',
 'elephant',
 'flower',
 'guitar',
 'house',
 'ice cream',
 'jacket',
 'kite',
 'laptop',
 'mountain',
 'necklace',
 'ocean',
 'piano',
 'rainbow',
 'skateboard',
 'tiger',
 'umbrella',
 'violin',
 'waterfall',
 'xylophone',
 'yacht',
 'zebra',
 'apple tree',
 'bookshelf',
 'carrot',
 'doghouse',
 'elephant tusk',
 'flower vase',
 'guitar case']

In [8]:
import random

# Dedicated, seeded RNG so the requested utterance counts are the same on
# every run, without touching the global `random` state.
utterance_rng = random.Random(0)

# One generated conversation per noun, in the same order as nouns_list.
strat_a_convs = []
for noun in nouns_list:
    # randrange(5, 20) yields a length from 5 to 19 inclusive.
    utterance_count = utterance_rng.randrange(5, 20)
    # Append directly instead of binding `conv`, so the conversation stored
    # in `conv` by the earlier cell is not overwritten.
    strat_a_convs.append(
        gen(
            "Please generate an example conversation between participant1 and participant2 about {}. Make it last {} utterances and have a follow-up. Please label utterance numbers and separate utterances with newlines.".format(
                noun, utterance_count
            )
        )
    )

Participant1: Hey, have you tried the new variety of apples at the farmer's market?

Participant2: No, I haven't. What's special about them?

Participant1: They're called Honeycrisp apples. They have a really crisp texture and a sweet-tart flavor.

Participant2: That sounds delicious! I usually prefer Fuji apples, but I'm open to trying new ones.

Participant1: You should definitely give Honeycrisp a try. They're great for snacking or even in salads.

Participant2: I'll make sure to pick some up next time I'm at the market. Thanks for the recommendation!

Follow-up:
Participant1: No problem! Let me know how you like them. Maybe we can share a recipe using Honeycrisp apples.<|end|>
== took 0:00:08.027627 (213 toks: 0:00:00.037688/tok; 26.533371561733905 tps) ==

Participant1: Hey, have you read the new book by John Green?

Participant2: No, I haven't. What's it about?

Utterance 2:

Participant1: It's called "Turtles All the Way Down." It's a story about a teenage girl named Aza who str

In [15]:
# Inspect the third generated conversation; print() renders its newlines,
# whereas a bare expression would show the quoted repr with "\n" escapes.
print(strat_a_convs[2])

Participant1: Hey, I'm thinking about buying a new car. Any recommendations? (1)

Participant2: Sure! What are you looking for in a car? (2)

Participant1: I want something reliable and fuel-efficient. (3)

Participant2: Have you considered a hybrid or electric vehicle? They're both reliable and fuel-efficient. (4)

Participant1: I've heard good things about electric cars, but I'm not sure if they're worth the investment. (5)

Participant2: Electric cars have lower maintenance costs and are better for the environment. Plus, there are more charging stations now. (6)

Participant1: That's true. What about the driving range? I don't want to be stranded. (7)

Participant2: Most electric cars have a range of 200-300 miles on a single charge. Some even have ranges of over 400 miles. (8)

Participant1: That's impressive. What about the price? Are they affordable? (9)

Participant2: Electric cars have become more affordable in recent years. You can find some great deals, especially with govern

In [32]:
# INTERESTING: Does proper chat format matter for embeddings?
# TODO: Evaluate taking just last layer vs. all layers for embeddings
def embed(text):
    """Embed ``text`` as the mean of the model's final hidden states.

    The text is wrapped in the same chat template used for generation, run
    through the model once, and the last entry of ``hidden_states`` is
    averaged over the token dimension.

    Args:
        text: The text to embed.

    Returns:
        A 1-D CPU tensor in the dtype of the model's hidden states
        (bfloat16 in the recorded run).
    """
    prompt = "<|user|>\n{} <|end|>\n<|assistant|>".format(text)
    tokens = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
    # Inference only: without no_grad the forward pass would record an
    # autograd graph over every layer just to have it discarded.
    with torch.no_grad():
        outputs = model(tokens, output_hidden_states=True)
        # Drop the batch dimension of size 1, then average over tokens.
        embedding = outputs.hidden_states[-1].squeeze(0).mean(dim=0)
    # Tensors created under no_grad carry no graph, so no detach() is needed.
    return embedding.to("cpu")


# Sanity check: embed the first generated conversation and display the vector.
embed(strat_a_convs[0])

tensor([-0.2109,  0.0420,  0.6680,  ...,  0.6172, -0.4258,  0.4316],
       dtype=torch.bfloat16)

In [33]:
# Embed every Strategy A conversation. The tensors are upcast to float32
# before .numpy() -- presumably because NumPy has no bfloat16 dtype; confirm.
strat_a_embeds = [embed(conv).float().numpy() for conv in strat_a_convs]
# Display the first embedding as a quick check.
strat_a_embeds[0]

array([-0.2109375 ,  0.04199219,  0.66796875, ...,  0.6171875 ,
       -0.42578125,  0.43164062], dtype=float32)

# Strategy B: Randomly perturb words in the input

# Strategy C: Use a diversity reward during decoding to find inputs that are as different from each other as possible