In [1]:
!pip install flash-attn transformers accelerate termcolor altair

import time
from datetime import timedelta

import torch
import torch.nn.functional as F
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
from transformers.utils import is_flash_attn_2_available

# Seed torch's RNG so any sampling done later in the notebook is repeatable.
torch.random.manual_seed(0)

# Load Phi-3-mini straight onto the GPU. `device_map="cuda"` already places the
# weights on the device, so no extra `.to("cuda")` call is needed afterwards.
# NOTE(review): trust_remote_code=True executes Python shipped with the model
# repository; only keep it for checkpoints you trust.
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",
    device_map="cuda",
    torch_dtype="auto",
    trust_remote_code=True,
    # attn_implementation="flash_attention_2",
)
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
# Streams generated text to stdout as it is produced; the prompt itself is not echoed.
streamer = TextStreamer(tokenizer, skip_prompt=True)

print("flash_attn_2 available:", is_flash_attn_2_available())



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


flash_attn_2 available: True


In [2]:
def gen(text, preview=True):
    """Generate a reply to ``text`` with the notebook-level ``model`` and return it.

    Args:
        text: The user message. It is wrapped in the
            ``<|user|> ... <|end|> <|assistant|>`` prompt template below.
        preview: When true, tokens are streamed to stdout through ``streamer``
            while they are generated and a timing summary is printed afterwards.

    Returns:
        The decoded newly generated tokens (the prompt is excluded). A trailing
        end-of-sequence token is dropped. If generation stopped for another
        reason (e.g. it hit ``max_new_tokens``), the last token is real
        content and is kept.
    """
    duration_start = time.perf_counter()
    prompt = "<|user|>\n{} <|end|>\n<|assistant|>".format(text)
    tokens = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
    outputs = model.generate(
        tokens,
        max_new_tokens=1024,
        return_dict_in_generate=True,
        streamer=streamer if preview else None,
    )
    # Everything after the prompt is newly generated.
    output_gen_tokens = outputs.sequences[0][tokens.shape[1]:]

    # Collect the ids that can terminate generation. The generation config may
    # hold a single id, a list of ids, or nothing at all.
    configured_eos = getattr(model.generation_config, "eos_token_id", None)
    if configured_eos is None:
        stop_ids = set()
    elif isinstance(configured_eos, int):
        stop_ids = {configured_eos}
    else:
        stop_ids = set(configured_eos)
    if tokenizer.eos_token_id is not None:
        stop_ids.add(tokenizer.eos_token_id)

    # Only strip the final token when it really is a stop token; slicing off the
    # last position unconditionally loses a word when the token budget ran out.
    if len(output_gen_tokens) > 0 and output_gen_tokens[-1].item() in stop_ids:
        output_gen_tokens = output_gen_tokens[:-1]

    output_string = tokenizer.decode(output_gen_tokens)
    duration_seconds = time.perf_counter() - duration_start
    if preview:
        num_tokens = len(output_gen_tokens)
        # Guard both divisions: the model may emit nothing but a stop token.
        per_token = (
            timedelta(seconds=duration_seconds / num_tokens) if num_tokens else "n/a"
        )
        tokens_per_second = num_tokens / duration_seconds if duration_seconds > 0 else 0.0
        print(
            "== took {} ({} toks: {}/tok; {} tps) ==".format(
                timedelta(seconds=duration_seconds),
                num_tokens,
                per_token,
                tokens_per_second,
            )
        )
        print()
    return output_string


def embed(text, mean_layers=False, mean_tokens=False, prompt_prefix=""):
    """Turn ``text`` into a single vector taken from the model's hidden states.

    Args:
        text: The text to embed.
        mean_layers: If true, average the hidden states over all returned
            layers; otherwise use only the last layer.
        mean_tokens: If true, average over the token positions; otherwise use
            only the last token position.
        prompt_prefix: Optional instruction. When given, the prompt becomes the
            prefix followed by ``text`` inside a fenced code block.

    Returns:
        A 1-D CPU tensor (the hidden vector of the first and only batch element),
        detached from autograd.
    """
    if prompt_prefix:
        prompt = "<|user|>\n{}\n```\n{}\n``` <|end|>\n<|assistant|>".format(
            prompt_prefix, text
        )
    else:
        prompt = "<|user|>\n{} <|end|>\n<|assistant|>".format(text)
    tokens = tokenizer.encode(prompt, return_tensors="pt").to("cuda")

    # Pure inference: without no_grad the forward pass records an autograd graph
    # for every layer's hidden state, which wastes GPU memory for no benefit.
    with torch.no_grad():
        hidden_states = model(tokens, output_hidden_states=True).hidden_states

    if mean_layers:
        embedding = torch.stack(hidden_states).mean(dim=0)  # Mean layers
    else:
        embedding = hidden_states[-1]  # Take last layer

    if mean_tokens:
        embedding = embedding.mean(dim=1)  # Mean tokens
    else:
        embedding = embedding[:, -1, :]  # Take last token

    embedding = embedding[0]  # Take first and only element of batch

    return embedding.detach().to("cpu")

In [4]:
def D(obj):
    """Debug helper: print a short size summary of ``obj``.

    * tuple  -> prints its length.
    * tensor -> prints its shape, then renders the tensor with ``display``.
    * anything else -> does nothing.
    """
    if isinstance(obj, tuple):
        print(len(obj))
        return
    if not isinstance(obj, torch.Tensor):
        return
    print(obj.shape)
    display(obj)

In [5]:
# Build a single-turn chat prompt by hand and run one plain forward pass
# (no generation) so the raw next-token logits can be inspected in the cells below.
text = 'Write a haiku about symmetry.'
prompt = "<|user|>\n{} <|end|>\n<|assistant|>".format(text)
# Tokenize to PyTorch tensors and move them to the GPU.
inputs = tokenizer(prompt, return_tensors='pt').to('cuda')
# Gradients are not disabled here, so the tensors inspected later carry a grad_fn.
outputs = model(**inputs)

You are not running the flash-attention implementation, expect numerical differences.


In [6]:
# The logits hold one score per vocabulary entry for every prompt position:
# (BATCH_SIZE, NUM_TOKENS, VOCAB_SIZE)
outputs.logits.shape

torch.Size([1, 12, 32064])

In [7]:
# Take the first (only) batch element and its last position: a 1-D vector of
# VOCAB_SIZE scores, i.e. the model's scores for the token that follows the prompt.
logits = outputs.logits[0, -1, :]
D(logits)

torch.Size([32064])


tensor([5.6875, 6.0312, 3.9375,  ..., 0.0000, 0.0000, 0.0000], device='cuda:0',
       grad_fn=<SliceBackward0>)

In [8]:
# Rank the vocabulary from most to least likely, convert the ranked scores to
# probabilities, and accumulate them so entry k is the total mass of the top k+1 tokens.
sorted_logits, sorted_indices = logits.sort(descending=True)
probs = sorted_logits.softmax(dim=-1)
cum_probs = probs.cumsum(dim=-1)
D(cum_probs)

torch.Size([32064])


tensor([0.6870, 0.7924, 0.8978,  ..., 1.0000, 1.0000, 1.0000], device='cuda:0',
       grad_fn=<CumsumBackward0>)

In [12]:
# Mark the ranked positions whose cumulative mass is still below 0.9, then shift
# the mask one step to the right so the token that crosses the threshold is kept too.
# roll(1) wraps the last flag into slot 0; that slot is overwritten immediately
# because the most likely token is always kept.
sorted_keep_indices = cum_probs.lt(0.9).roll(1)
sorted_keep_indices[0] = True
D(sorted_keep_indices)
sorted_keep_indices.sum()

torch.Size([32064])


tensor([ True,  True,  True,  ..., False, False, False], device='cuda:0')

tensor(4, device='cuda:0')

In [13]:
# Map the kept ranked positions back to vocabulary ids: sorted_indices holds the
# original index of each ranked entry, and the boolean mask selects the kept ones.
keep_indices = sorted_indices[sorted_keep_indices]
D(keep_indices)

torch.Size([4])


tensor([11612,  2431,  7392,  9897], device='cuda:0')

In [14]:
# Decode the kept token ids to text to see which first tokens survive the cut.
tokenizer.decode(keep_indices)

'Mir Per Bal Ref'

In [15]:
# Probabilities of the kept tokens (same ranked order as the mask). Their sum is
# at least the 0.9 threshold because the token that crosses it is included.
keep_probs = probs[sorted_keep_indices]
D(keep_probs)
keep_probs.sum()

torch.Size([4])


tensor([0.6870, 0.1054, 0.1054, 0.0235], device='cuda:0',
       grad_fn=<IndexBackward0>)

tensor(0.9213, device='cuda:0', grad_fn=<SumBackward0>)

In [18]:
def top_p_tokens(logits, top_p=0.9):
    """Return the nucleus (top-p) set of tokens for one position.

    Does not support batches yet. logits must be of shape (VOCAB_SIZE).

    Tokens are ranked by probability and kept until their cumulative
    probability reaches ``top_p``; the token that crosses the threshold is
    included, and the most likely token is always kept.

    Args:
        logits: 1-D, non-empty tensor of unnormalized scores, one per token id.
        top_p: Cumulative-probability threshold.

    Returns:
        ``(keep_toks, keep_probs)``: the kept token ids and their probabilities
        (softmax over the full vocabulary, not renormalized), both ordered from
        most to least likely.

    Raises:
        ValueError: If ``logits`` is not a non-empty 1-D tensor.
    """
    if logits.dim() != 1 or logits.numel() == 0:
        raise ValueError(
            "logits must be a non-empty 1-D tensor of shape (VOCAB_SIZE), got shape {}".format(
                tuple(logits.shape)
            )
        )
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    probs = F.softmax(sorted_logits, dim=-1)
    cum_probs = torch.cumsum(probs, dim=-1)
    # Bools marking ranked positions whose cumulative mass is still below top_p.
    # (The threshold used to be hard-coded to 0.9, silently ignoring the argument.)
    sorted_keep_indices = cum_probs < top_p
    # Shift right by one so the element that went over top_p is kept as well.
    sorted_keep_indices[1:] = sorted_keep_indices[:-1].clone()
    sorted_keep_indices[0] = True  # Always keep the first element
    keep_toks = sorted_indices[sorted_keep_indices]
    keep_probs = probs[sorted_keep_indices]
    return keep_toks, keep_probs

In [40]:
from collections import namedtuple
class Candidate(namedtuple('Candidate', ['sequence', 'prob'])):
    """A partial generation: ``sequence`` holds the token ids so far and
    ``prob`` the probability accumulated for that sequence.

    ``__repr__`` decodes the first row of ``sequence`` with the notebook-level
    ``tokenizer`` so candidates print as readable text.
    """

    def __repr__(self):
        decoded_text = tokenizer.decode(self.sequence[0])
        return f'Candidate [{self.prob}]: {decoded_text} \n  ({self.sequence})'

# Breadth-first expansion of candidate continuations: every round extends each
# candidate with its top-p next tokens, then prunes to the best `max_candidates`.
max_candidates = 100
text = 'Write a haiku about symmetry.'
prompt = "<|user|>\n{} <|end|>\n<|assistant|>".format(text)
inputs = tokenizer(prompt, return_tensors='pt').to('cuda')

# Start from the bare prompt with probability 1.
candidates = [
    Candidate(inputs.input_ids, 1.0)
]

# Nucleus threshold handed to top_p_tokens; halved after every round so later
# steps branch less.
p = 0.99999999
for candidate in candidates:
    print(candidate)
print()
# OPTIM: Batch inference, which means rewrite top_p_tokens to use batch
# OPTIM: Keep previous_key values
# OPTIM: Use Tensors to keep track of candidates (with masked values)
# OPTIM: Log probs
# Inference only: no_grad avoids building an autograd graph on every forward pass.
with torch.no_grad():
    for i in range(10):
        new_candidates = []
        for candidate in candidates:
            outputs = model(input_ids=candidate.sequence)
            new_toks, new_probs = top_p_tokens(outputs.logits[0, -1, :], p)
            # One tolist() transfer instead of a GPU sync per token via .item().
            for new_tok, new_prob in zip(new_toks, new_probs.tolist()):
                new_candidate = Candidate(
                    torch.cat([candidate.sequence, new_tok.reshape(1, 1)], dim=1),
                    candidate.prob * new_prob,
                )
                new_candidates.append(new_candidate)
        # Rank before pruning so the most probable sequences survive; a plain
        # slice would keep whichever parents happened to be expanded first.
        new_candidates.sort(key=lambda cand: cand.prob, reverse=True)
        candidates = new_candidates[:max_candidates]
        print(i, p)
        for candidate in candidates:
            print(candidate)
        print()
        p *= 0.5

Candidate [1.0]: <s><|user|> Write a haiku about symmetry. <|end|><|assistant|> 
  (tensor([[    1, 32010, 14350,   263,   447, 18282,  1048, 18446, 29889, 29871,
         32007, 32001]], device='cuda:0'))

0 0.99999999
Candidate [0.6870314478874207]: <s><|user|> Write a haiku about symmetry. <|end|><|assistant|> Mir 
  (tensor([[    1, 32010, 14350,   263,   447, 18282,  1048, 18446, 29889, 29871,
         32007, 32001, 11612]], device='cuda:0'))
Candidate [0.10535968840122223]: <s><|user|> Write a haiku about symmetry. <|end|><|assistant|> Per 
  (tensor([[    1, 32010, 14350,   263,   447, 18282,  1048, 18446, 29889, 29871,
         32007, 32001,  2431]], device='cuda:0'))
Candidate [0.10535968840122223]: <s><|user|> Write a haiku about symmetry. <|end|><|assistant|> Bal 
  (tensor([[    1, 32010, 14350,   263,   447, 18282,  1048, 18446, 29889, 29871,
         32007, 32001,  7392]], device='cuda:0'))
Candidate [0.02350892312824726]: <s><|user|> Write a haiku about symmetry. <|end|>