In [35]:
import warnings

import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

In [36]:
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # Llama-2 7B base checkpoint (Hugging Face Hub id)
)

Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.38s/it]


In [37]:
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
# Upper bound on tokenized length, enforced when truncation=True is requested.
tokenizer.model_max_length = 1024
# Reuse <unk> as the padding token so padding=True can be requested;
# presumably the base tokenizer defines no pad token of its own (TODO confirm).
tokenizer.pad_token = tokenizer.unk_token

In [38]:
# Renamed from `input`, which shadowed the Python builtin of the same name.
prompt = "hello how do you do?"
tok_inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
# Move every tensor to the model's device and drop `token_type_ids` if the tokenizer
# returned it (not needed by this decoder-only model — NOTE(review): confirm).
tok_inputs = {k: v.to(model.device) for k, v in tok_inputs.items() if k != "token_type_ids"}

In [39]:
# Decoding settings for the greedy model.generate call below.
gen_config = dict(max_new_tokens=10, do_sample=False)

Greedy Search

In [32]:
# gen_config sets do_sample=False, so generate decodes greedily (argmax at each step).
greedy_output = model.generate(**gen_config, **tok_inputs)



In [40]:
print("Output:", "-" * 100, sep="\n")
# skip_special_tokens is not set here, so the decoded text keeps the leading <s> token.
print(tokenizer.decode(greedy_output[0]))

Output:
----------------------------------------------------------------------------------------------------
<s> hello how do you do?
I'm a 20 year old


Beam Search

In [41]:
# Beam search keeping 5 candidate sequences per step. do_sample=False is passed
# explicitly so this stays deterministic beam search even if the checkpoint's
# generation_config enables sampling by default (NOTE(review): confirm what it sets).
beam_output = model.generate(
    **tok_inputs,
    max_new_tokens=10,
    num_beams=5,
    early_stopping=True,
    do_sample=False,
)

In [42]:
print("Output:", "-" * 100, sep="\n")
# generate returns one (the highest-scoring) beam per prompt; special tokens are stripped.
print(tokenizer.decode(beam_output[0], skip_special_tokens=True))

Output:
----------------------------------------------------------------------------------------------------
hello how do you do?
I'm new here and I'm


Sampling

In [43]:
# Fixed seed so this stochastic cell reproduces the same sample on every Run All.
torch.manual_seed(42)
# Sampling from the model's distribution; top_k=0 disables top-k filtering.
# Any temperature/top_p defaults in the checkpoint's generation_config still apply
# (NOTE(review): confirm).
sample_output = model.generate(
    **tok_inputs,
    max_new_tokens=10,
    do_sample=True,
    top_k=0,
)

In [44]:
print("Output:", "-" * 100, sep="\n")
# Decode the single sampled sequence, dropping special tokens such as <s>.
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))

Output:
----------------------------------------------------------------------------------------------------
hello how do you do?
I am a nice guy from the US


Top-k Sampling

In [45]:
# Fixed seed so this stochastic cell reproduces the same sample on every Run All.
torch.manual_seed(42)
# Top-k sampling: sample only among the 50 most likely next tokens.
top_k_output = model.generate(
    **tok_inputs,
    max_new_tokens=10,
    do_sample=True,
    top_k=50,
)

In [46]:
print("Output:", "-" * 100, sep="\n")
# Decode the top-k sample, dropping special tokens such as <s>.
print(tokenizer.decode(top_k_output[0], skip_special_tokens=True))

Output:
----------------------------------------------------------------------------------------------------
hello how do you do?
Hi, I'm new here and I


Nucleus Sampling

In [47]:
# Fixed seed so this stochastic cell reproduces the same sample on every Run All.
torch.manual_seed(42)
# Nucleus (top-p) sampling: sample from the smallest token set whose cumulative
# probability reaches 0.92; top_k=0 turns off top-k so only the top-p cutoff applies.
nucleus_output = model.generate(
    **tok_inputs,
    max_new_tokens=10,
    do_sample=True,
    top_p=0.92,
    top_k=0,
)

In [48]:
print("Output:", "-" * 100, sep="\n")
# Decode the nucleus sample, dropping special tokens such as <s>.
print(tokenizer.decode(nucleus_output[0], skip_special_tokens=True))

Output:
----------------------------------------------------------------------------------------------------
hello how do you do?
I'm a new member, and I


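Custom decoding function: a minimal re-implementation of `model.generate` that supports greedy decoding (`do_sample=False`) and temperature sampling (`do_sample=True`). With greedy decoding, its output matches the `model.generate` greedy output shown above.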

In [49]:
def mygenerate(model, tf_inputs, **gen_config):
    """Minimal re-implementation of ``model.generate`` (greedy or sampled decoding).

    Extends ``tf_inputs['input_ids']`` one token at a time. The key/value cache
    returned by the model is fed back on every step, so after the first forward
    pass only the newest token is run through the network instead of the whole
    context.

    Assumes an unpadded prompt (or a batch of equal-length unpadded prompts):
    ``attention_mask`` is not used.

    Args:
        model: causal LM whose forward accepts ``input_ids``, ``past_key_values``
            and ``use_cache`` keywords and returns an object with ``.logits``
            (batch, seq, vocab) and ``.past_key_values`` — e.g. a Hugging Face
            ``AutoModelForCausalLM``. ``model.config.max_position_embeddings``
            bounds the context window fed to the model.
        tf_inputs: tokenizer output mapping; only ``input_ids`` (batch, seq) is read.
        **gen_config: supported options
            - ``max_new_tokens`` (default 100): upper bound on generated tokens.
            - ``do_sample`` (default False): sample from the softmax instead of argmax.
            - ``temperature`` (default 1.0): logit divisor, used only when sampling.
            - ``eos_token_id`` (default ``model.config.eos_token_id``): stop once
              every row has produced it (rows that finish early keep generating
              until then); pass ``None`` to disable.
            Any other option is ignored with a ``UserWarning``.

    Returns:
        Tensor (batch, prompt_len + n_generated): the prompt followed by generated ids.

    Raises:
        ValueError: if ``do_sample`` is true and ``temperature`` is not positive.
    """
    max_new_tokens = gen_config.pop("max_new_tokens", 100)
    do_sample = gen_config.pop("do_sample", False)
    temperature = gen_config.pop("temperature", 1.0)
    eos_token_id = gen_config.pop("eos_token_id", getattr(model.config, "eos_token_id", None))
    if gen_config:
        # e.g. top_p / num_beams are not implemented here: say so instead of silently ignoring them.
        warnings.warn(f"mygenerate ignores unsupported options: {sorted(gen_config)}")
    if do_sample and temperature <= 0:
        raise ValueError(f"temperature must be > 0 when sampling, got {temperature}")

    block_size = model.config.max_position_embeddings
    context = tf_inputs["input_ids"]
    eos_ids = None
    if eos_token_id is not None:
        # Configs may hold a single id or a list of ids.
        eos_ids = torch.tensor(
            [eos_token_id] if isinstance(eos_token_id, int) else list(eos_token_id),
            device=context.device,
        )
    finished = torch.zeros(context.shape[0], dtype=torch.bool, device=context.device)
    past_key_values = None
    cache_len = 0  # number of positions currently held in past_key_values
    model.eval()

    with torch.no_grad():
        for _ in tqdm(range(max_new_tokens)):
            if past_key_values is None or cache_len >= block_size:
                # First step, or the cache already spans block_size positions:
                # (re)encode the trailing window of at most block_size tokens from scratch.
                step_ids = context[:, -block_size:]
                past_key_values = None
                cache_len = step_ids.shape[1]
            else:
                # The cache covers everything except the newest token.
                step_ids = context[:, -1:]
                cache_len += 1
            # Keyword arguments: the 2nd positional slot of an HF causal-LM forward is attention_mask.
            model_out = model(input_ids=step_ids, past_key_values=past_key_values, use_cache=True)
            past_key_values = model_out.past_key_values
            next_logits = model_out.logits[:, -1, :]
            if do_sample:
                probs = F.softmax(next_logits.float() / temperature, dim=-1)
                new_token = torch.multinomial(probs, 1)
            else:
                # argmax of the logits equals argmax of the softmax; temperature is irrelevant.
                new_token = torch.argmax(next_logits, dim=-1, keepdim=True)
            context = torch.cat([context, new_token], dim=-1)
            if eos_ids is not None:
                finished |= torch.isin(new_token.squeeze(-1), eos_ids)
                if finished.all():
                    break
    return context

In [50]:
# Same settings as the model.generate greedy run above, so the two can be compared.
gen_config = {
    "max_new_tokens": 10,
    "do_sample": False,
}

out = mygenerate(model, tok_inputs, **gen_config)
# The custom greedy decoder should reproduce model.generate's greedy tokens exactly.
assert torch.equal(out, greedy_output), "mygenerate diverged from model.generate (greedy)"
tokenizer.decode(out[0])

100%|██████████| 10/10 [00:39<00:00,  3.91s/it]


"<s> hello how do you do?\nI'm a 20 year old"