In [1]:
from transformers.utils import is_flash_attn_2_available
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
import numpy as np
import torch.nn.functional as F
import torch
from datetime import timedelta
import time
from collections import namedtuple
import json

# Fix the global PyTorch RNG seed so any sampling done later is repeatable across kernel restarts.
torch.manual_seed(0)

<torch._C.Generator at 0x7f16b5139df0>

In [2]:
def D(obj):
    """Debug helper for inspecting intermediate values in the notebook.

    - tuple: prints only its length (the contents are not shown).
    - torch.Tensor: prints its shape, then renders the tensor via ``display``.
    - anything else: renders the object via ``display``.

    ``display`` is not imported here; it is the function the IPython/Jupyter
    kernel makes available in notebook cells.
    """
    if isinstance(obj, tuple):
        print(len(obj))
        return

    if isinstance(obj, torch.Tensor):
        # Shape first, so the dimensions are visible even when the repr is elided.
        print(obj.shape)

    display(obj)

In [3]:
# One identifier shared by the model and tokenizer loads so the two cannot drift apart.
checkpoint_id = "microsoft/Phi-3-mini-4k-instruct"

model = AutoModelForCausalLM.from_pretrained(
    checkpoint_id,
    torch_dtype=torch.bfloat16,
    device_map='auto',
    # NOTE(review): this flag allows code shipped with the checkpoint to run locally;
    # keep it only for checkpoints you trust.
    trust_remote_code=True,
    use_cache=True,
    # attn_implementation='flash_attention_2',
)
tokenizer = AutoTokenizer.from_pretrained(checkpoint_id)

# Device on which the candidate buffers are allocated by init_candidates().
device = "cuda" if torch.cuda.is_available() else "cpu"

`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.
Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [7]:
# --- Search budget -----------------------------------------------------------
# Number of rows allocated in every candidate buffer.
max_candidates = 16
# Number of token columns reserved beyond the prompt length.
max_new_tokens = 100
# Number of candidate rows sliced off per batch in infer().
batch_size = 8

# --- Planned knobs: UNIMPLEMENTED, no code in this notebook reads them yet ----
p_falloff = 0.5
prune_similar_sequences = True
prune_similar_branches = True
prune_similar_embeddings = True

In [8]:
def init_candidates(text: str):
    """Allocate the candidate buffers and seed slot 0 with the tokenized prompt.

    The user text is wrapped in a ``<|user|> ... <|end|> <|assistant|>`` prompt,
    tokenized with the notebook-level ``tokenizer``, and written into row 0 of
    zero-initialised buffers created on the notebook-level ``device``. Buffer
    sizes come from the notebook-level ``max_candidates`` and ``max_new_tokens``.

    Args:
        text: The user message to place in the prompt.

    Returns:
        A 4-tuple ``(candidates, candidate_masks, candidate_parents,
        candidate_logprobs)``:
          * candidates: long, (max_candidates, prompt_len + max_new_tokens) token ids.
          * candidate_masks: bool, same shape; True where row 0 holds prompt tokens.
          * candidate_parents: long, (max_candidates,), all zeros.
          * candidate_logprobs: float32, (max_candidates,), all zeros.
    """
    encoded = tokenizer("<|user|>\n{} <|end|>\n<|assistant|>".format(text), return_tensors='pt')
    prompt_len = encoded.input_ids.shape[1]

    # Every 2-D buffer has room for the prompt plus the full generation budget.
    buffer_shape = (max_candidates, prompt_len + max_new_tokens)

    candidates = torch.zeros(buffer_shape, dtype=torch.long, device=device)
    candidate_masks = torch.zeros(buffer_shape, dtype=torch.bool, device=device)
    candidate_parents = torch.zeros(max_candidates, dtype=torch.long, device=device)
    candidate_logprobs = torch.zeros(max_candidates, dtype=torch.float32, device=device)

    # Slot 0 is the root candidate: the prompt itself.
    candidates[0, :prompt_len] = encoded.input_ids
    candidate_masks[0, :prompt_len] = encoded.attention_mask
    # Already zero from the allocation above; written explicitly to state the root's values.
    candidate_parents[0] = 0
    candidate_logprobs[0] = 0.0

    return candidates, candidate_masks, candidate_parents, candidate_logprobs

# Build the initial search state for a sample question and inspect each buffer.
candidates, candidate_masks, candidate_parents, candidate_logprobs = init_candidates('What is the cutest dog?')
# D() prints only the length when given a tuple, so show the buffers one at a time.
for state_tensor in (candidates, candidate_masks, candidate_parents, candidate_logprobs):
    D(state_tensor)

torch.Size([16, 112])


tensor([[    1, 32010,  1724,  ...,     0,     0,     0],
        [    0,     0,     0,  ...,     0,     0,     0],
        [    0,     0,     0,  ...,     0,     0,     0],
        ...,
        [    0,     0,     0,  ...,     0,     0,     0],
        [    0,     0,     0,  ...,     0,     0,     0],
        [    0,     0,     0,  ...,     0,     0,     0]], device='cuda:0')

torch.Size([16, 112])


tensor([[ True,  True,  True,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]], device='cuda:0')

torch.Size([16])


tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')

torch.Size([16])


tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0')

In [12]:
def infer(candidates, candidate_masks, candidate_parents, candidate_logprobs):
    """Walk the candidate buffers in chunks of ``batch_size`` rows and display each chunk.

    Work in progress: the model forward pass is still commented out, so this
    function currently only slices the buffers, shows each slice with ``D``,
    and returns ``None``.

    The number of batches is derived from the row count of ``candidates``
    itself rather than from the notebook-level ``max_candidates``, so buffers
    of any size are covered exactly once, with no empty trailing slices and
    no skipped rows.

    Args:
        candidates: 2-D tensor of token ids, one row per candidate.
        candidate_masks: Tensor sliced along dim 0 with the same row ranges as
            ``candidates``.
        candidate_parents: Currently unused.
        candidate_logprobs: Currently unused.

    Returns:
        None (the ``return outputs`` line is still commented out).
    """
    total_rows = candidates.shape[0]
    # Striding by batch_size gives ceil(total_rows / batch_size) chunks; the last may be shorter.
    for start in range(0, total_rows, batch_size):
        stop = start + batch_size
        batch_candidates = candidates[start:stop]
        D(batch_candidates)
        batch_candidate_masks = candidate_masks[start:stop]
        D(batch_candidate_masks)
    # outputs = model(input_ids=candidates, attention_mask=candidate_masks)

    # return outputs

# infer() has no active return statement yet, so `outputs` is None and D() just displays None.
outputs = infer(
    candidates=candidates,
    candidate_masks=candidate_masks,
    candidate_parents=candidate_parents,
    candidate_logprobs=candidate_logprobs,
)
D(outputs)

2
torch.Size([8, 112])


tensor([[    1, 32010,  1724,   338,   278,  5700,   342, 11203, 29973, 29871,
         32007, 32001,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0

torch.Size([8, 112])


tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False

torch.Size([8, 112])


tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

torch.Size([8, 112])


tensor([[False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False

None

In [6]:
# def candidates_generator(text: str):
#     print(text)
#     candidates, candidate_masks, candidate_parents, candidate_logprobs = init_candidates(text)

#     return candidates, candidate_masks, candidate_parents, candidate_logprobs
        
