In [1]:
!pip install transformers accelerate optimum nvidia-ml-py



In [2]:
from transformers.utils import is_flash_attn_2_available
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
import numpy as np
import torch.nn.functional as F
import torch
from datetime import timedelta
import time
from collections import namedtuple
import json

torch.random.manual_seed(0)

<torch._C.Generator at 0x7f7c3356e1f0>

In [3]:
from pynvml import *

def check_gpu(step):
    nvmlInit()
    handle = nvmlDeviceGetHandleByIndex(0)
    info = nvmlDeviceGetMemoryInfo(handle)
    print(f"{step}: GPU memory used: {info.used // 1024**2} MB.")

In [4]:
def D(obj):
    if isinstance(obj, tuple):
        print(len(obj))
    elif isinstance(obj, torch.Tensor):
        print(obj.shape)
        display(obj)
    else:
        display(obj)

In [5]:
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",
    torch_dtype=torch.bfloat16,
    device_map='auto',
    trust_remote_code=True,
    use_cache=True,
    # attn_implementation='flash_attention_2',
)
tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct")
device = "cuda" if torch.cuda.is_available() else "cpu"
print('device', device)

check_gpu('model init')

`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.
Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


device cuda
model init: GPU memory used: 7813 MB.


In [6]:
max_candidates = 16
max_new_tokens = 3
batch_size = 8
p_falloff = 0.5 # UNIMPLEMENTED
prune_similar_sequences = True # UNIMPLEMENTED
prune_similar_branches = True # UNIMPLEMENTED
prune_similar_embeddings = True # UNIMPLEMENTED

In [7]:
def init_candidates(text: str):
    prompt = "<|user|>\n{} <|end|>\n<|assistant|>".format(text)
    inputs = tokenizer(prompt, return_tensors='pt')

    max_total_tokens = inputs.input_ids.shape[1] + max_new_tokens

    # (max_candidates, max_total_tokens)
    candidates = torch.zeros((max_candidates, max_total_tokens), dtype=torch.long, device=device)
    # (max_candidates, max_total_tokens)
    candidate_masks = torch.zeros((max_candidates, max_total_tokens), dtype=torch.bool, device=device)
    # (max_candidates)
    candidate_parents = torch.zeros((max_candidates), dtype=torch.long, device=device)
    # (max_candidates)
    candidate_logprobs = torch.zeros((max_candidates), dtype=torch.float32, device=device)

    candidates[0, :inputs.input_ids.shape[1]] = inputs.input_ids
    candidate_masks[0, :inputs.input_ids.shape[1]] = inputs.attention_mask
    candidate_parents[0] = 0
    candidate_logprobs[0] = 0.0

    return candidates, candidate_masks, candidate_parents, candidate_logprobs

candidates, candidate_masks, candidate_parents, candidate_logprobs = init_candidates('What is the most popular breed of dog?')
D(candidates)
D(candidate_masks)
D(candidate_parents)
D(candidate_logprobs)

check_gpu('candidates init')

torch.Size([16, 18])


tensor([[    1, 32010,  1724,   338,   278,  1556,  5972,  2078,   287,   310,
         11203, 29973, 29871, 32007, 32001,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0, 

torch.Size([16, 18])


tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, 

torch.Size([16])


tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')

torch.Size([16])


tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0')

candidates init: GPU memory used: 7843 MB.


In [8]:
# For testing batch inputs
inputs2 = tokenizer("A dog", return_tensors='pt')
candidates[11, :inputs2.input_ids.shape[1]] = inputs2.input_ids
candidate_masks[11, :inputs2.input_ids.shape[1]] = inputs2.attention_mask
candidate_parents[11] = 0
candidate_logprobs[11] = -1.3

check_gpu('test addl inputs added')

test addl inputs added: GPU memory used: 7843 MB.


In [14]:
test_be = tokenizer(["String A", "String B which is longer", "String C which is even longer"], return_tensors="pt", padding=True)
test_be

{'input_ids': tensor([[32000, 32000, 32000, 32000,     1,  1714,   319],
        [32000,     1,  1714,   350,   607,   338,  5520],
        [    1,  1714,   315,   607,   338,  1584,  5520]]), 'attention_mask': tensor([[0, 0, 0, 0, 1, 1, 1],
        [0, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1]])}

In [15]:
tokenizer.batch_decode(test_be.input_ids)

['<|endoftext|><|endoftext|><|endoftext|><|endoftext|><s> String A',
 '<|endoftext|><s> String B which is longer',
 '<s> String C which is even longer']

In [9]:
def infer(candidates, candidate_masks, candidate_parents, candidate_logprobs):
    with torch.inference_mode():
        batches = (max_candidates + batch_size - 1) // batch_size  # Round up to nearest whole number of batches

        check_gpu('infer start')
        for i in range(0, batches, 1):
            batch_candidates = candidates[i * batch_size:(i + 1) * batch_size]
            D(batch_candidates)
            batch_candidate_masks = candidate_masks[i * batch_size:(i + 1) * batch_size]
            D(batch_candidate_masks)

            check_gpu('batch views made')

            batch_outputs = model(input_ids=batch_candidates, attention_mask=batch_candidate_masks)
            D(batch_outputs.logits)

            # Possibly turn off caching to save memory here?
            check_gpu('batch forward run')

            del batch_outputs

            check_gpu('batch outputs deleted')

outputs = infer(candidates, candidate_masks, candidate_parents, candidate_logprobs)
D(outputs)

check_gpu('all batches run')

infer start: GPU memory used: 7843 MB.
torch.Size([8, 18])


tensor([[    1, 32010,  1724,   338,   278,  1556,  5972,  2078,   287,   310,
         11203, 29973, 29871, 32007, 32001,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0, 

torch.Size([8, 18])


tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, 

batch views made: GPU memory used: 7843 MB.


You are not running the flash-attention implementation, expect numerical differences.


torch.Size([8, 18, 32064])


tensor([[[ 1.8125,  1.3438, -0.4473,  ...,  0.0000,  0.0000,  0.0000],
         [ 4.2812,  9.6875, 10.1250,  ...,  0.0000,  0.0000,  0.0000],
         [ 6.0625,  2.9844,  3.8281,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 1.5859, -1.2031, -3.5625,  ...,  0.0000,  0.0000,  0.0000],
         [ 1.0703,  0.7109, -5.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 3.5938, -3.8750, -3.3281,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         [ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         [ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         [ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         [ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         [ 2.2031, -0.1865, -1.1719,  ...,  0

batch forward run: GPU memory used: 8035 MB.
batch outputs deleted: GPU memory used: 8035 MB.
torch.Size([8, 18])


tensor([[    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [    1,   319, 11203,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0, 

torch.Size([8, 18])


tensor([[False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False],
        [ True,  True,  True, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, 

batch views made: GPU memory used: 8035 MB.
torch.Size([8, 18, 32064])


tensor([[[ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         [ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         [ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         [ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         [ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         [ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         [ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         [ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         [ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 2.2031, -0.1865, -1.1719,  ...,  0.0000,  0.0000,  0.0000],
         [ 2.2031, -0.1865, -1.1719,  ...,  0

batch forward run: GPU memory used: 8035 MB.
batch outputs deleted: GPU memory used: 8035 MB.


None

all batches run: GPU memory used: 8035 MB.


In [10]:
def top_p_tokens(logits, top_p=0.9):
    """logits of shape (batch_size, vocab_size)"""
    with torch.inference_mode():
    
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        cum_probs = torch.cumsum(probs, dim=-1)
        # Create tensor of bools indicating which indices are cumulatively less than top_p
        sorted_keep_indices = cum_probs < 0.9
        # Keep the last element that went over top_p
        sorted_keep_indices[1:] = sorted_keep_indices[:-1].clone()
        sorted_keep_indices[0] = 1  # Always keep the first element
        keep_toks = sorted_indices[sorted_keep_indices]
        keep_probs = probs[sorted_keep_indices]
        return keep_toks, keep_probs


In [11]:
# def candidates_generator(text: str):
#     print(text)
#     candidates, candidate_masks, candidate_parents, candidate_logprobs = _init_candidates(text)

#     return candidates, candidate_masks, candidate_parents, candidate_logprobs
        
