In [61]:
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'

In [37]:
from llama_cpp import Llama

# Quantized open-source model provided at https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.2-GGUF, located in adjacent models/ directory.
llm = Llama(
    model_path='./models/mistral-7b-instruct-v0.2.Q5_K_M.gguf', n_ctx=2048, n_gpu_layers=16, n_batch=512, logits_all=True)


def mistral_prompt(text):
    return f'<s>[INST] {text} [/INST]'

llama_model_loader: loaded meta data with 24 key-value pairs and 291 tensors from ./models/mistral-7b-instruct-v0.2.Q5_K_M.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.name str              = mistralai_mistral-7b-instruct-v0.2
llama_model_loader: - kv   2:                       llama.context_length u32              = 32768
llama_model_loader: - kv   3:                     llama.embedding_length u32              = 4096
llama_model_loader: - kv   4:                          llama.block_count u32              = 32
llama_model_loader: - kv   5:                  llama.feed_forward_length u32              = 14336
llama_model_loader: - kv   6:                 llama.rope.dimension_count u32              = 128
llama_model_loader: - kv   7:         

In [38]:
input_ids = llm.tokenize(b'Tell me a joke.')
input_ids

[1, 15259, 528, 264, 13015, 28723]

In [39]:
llm.eval(input_ids)

In [40]:
llm.eval_tokens

deque([1, 15259, 528, 264, 13015, 28723], maxlen=2048)

In [41]:
llm.eval_logits

deque([[-5.509405136108398,
        -5.642429351806641,
        -0.25985416769981384,
        -4.011919021606445,
        -5.511569976806641,
        -5.508670330047607,
        -5.5106916427612305,
        -5.510843753814697,
        -5.511539459228516,
        -5.508193492889404,
        -5.509894371032715,
        -5.50925874710083,
        -0.9049124717712402,
        2.6067731380462646,
        -5.512192726135254,
        -5.511593818664551,
        -5.51029109954834,
        -5.512083053588867,
        -5.508544445037842,
        -5.510097980499268,
        -5.5108866691589355,
        -5.511232376098633,
        -5.512964725494385,
        -5.511204242706299,
        -5.51012659072876,
        -5.509476184844971,
        -5.5112481117248535,
        -5.508009910583496,
        -5.512011528015137,
        -5.511092185974121,
        -5.510834217071533,
        -5.513062477111816,
        -5.511260986328125,
        -5.513177871704102,
        -5.512188911437988,
        -5.508994

In [63]:
import torch
logits = torch.tensor(llm.eval_logits)

tensor([[ -5.5094,  -5.6424,  -0.2599,  ...,  -4.2934,  -3.3498,  -3.8549],
        [ -6.6826,  -6.8964,  -8.0603,  ...,  -6.5963,  -4.2556,  -3.6978],
        [ -8.9616,  -9.2665,  -3.4477,  ...,  -6.9543,  -7.3498,  -4.5341],
        [ -7.1710,  -7.1976,  -3.3864,  ...,  -5.5624,  -4.2146,  -3.4109],
        [ -9.7692, -10.6111,  -2.7398,  ...,  -6.9817,  -6.0972,  -7.3837],
        [ -8.5317,  -9.1659,  -1.6073,  ...,  -7.2843,  -4.7371,  -6.6491]])

In [43]:
logits.shape

torch.Size([6, 32000])

In [64]:
logits_view = logits.view(1, logits.shape[0], logits.shape[1])

tensor([[[ -5.5094,  -5.6424,  -0.2599,  ...,  -4.2934,  -3.3498,  -3.8549],
         [ -6.6826,  -6.8964,  -8.0603,  ...,  -6.5963,  -4.2556,  -3.6978],
         [ -8.9616,  -9.2665,  -3.4477,  ...,  -6.9543,  -7.3498,  -4.5341],
         [ -7.1710,  -7.1976,  -3.3864,  ...,  -5.5624,  -4.2146,  -3.4109],
         [ -9.7692, -10.6111,  -2.7398,  ...,  -6.9817,  -6.0972,  -7.3837],
         [ -8.5317,  -9.1659,  -1.6073,  ...,  -7.2843,  -4.7371,  -6.6491]]])

In [48]:
logits_view.shape

torch.Size([1, 6, 32000])

In [52]:
shift_logits = logits_view[..., :-1, :].contiguous()
print(shift_logits.shape)
shift_logits

torch.Size([1, 5, 32000])


tensor([[[ -5.5094,  -5.6424,  -0.2599,  ...,  -4.2934,  -3.3498,  -3.8549],
         [ -6.6826,  -6.8964,  -8.0603,  ...,  -6.5963,  -4.2556,  -3.6978],
         [ -8.9616,  -9.2665,  -3.4477,  ...,  -6.9543,  -7.3498,  -4.5341],
         [ -7.1710,  -7.1976,  -3.3864,  ...,  -5.5624,  -4.2146,  -3.4109],
         [ -9.7692, -10.6111,  -2.7398,  ...,  -6.9817,  -6.0972,  -7.3837]]])

In [65]:
labels = torch.tensor(input_ids)

tensor([    1, 15259,   528,   264, 13015, 28723])

In [66]:
labels_view = labels.view(1, labels.shape[0])

tensor([[    1, 15259,   528,   264, 13015, 28723]])

In [70]:
shift_labels = labels_view[..., 1:].contiguous()

tensor([[15259,   528,   264, 13015, 28723]])

In [71]:
from torch.nn import CrossEntropyLoss
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, logits.shape[-1])
shift_labels = shift_labels.view(-1)
loss = loss_fct(shift_logits, shift_labels)

tensor(4.7228)

In [93]:
import torch


def ppl(llm, prompt_formatter, input):
    llm.reset()
    prompt = prompt_formatter(input)
    print('prompt', prompt)
    input_ids = llm.tokenize(bytes(prompt, 'utf-8'))
    print('input_ids', len(input_ids), input_ids)
    llm.eval(input_ids)
    logits = torch.tensor(llm.eval_logits)
    print('logits', logits.shape)
    labels = torch.tensor(input_ids)
    print('labels', labels.shape)

    loss = torch.nn.functional.cross_entropy(logits, labels)
    print('loss', loss)
    return torch.exp(loss)


print(ppl(llm, mistral_prompt, 'the dog walked to the park'))
print('====')
print(ppl(llm, mistral_prompt, 'where fund cat salad before less'))

prompt <s>[INST] the dog walked to the park [/INST]
input_ids 17 [1, 523, 28713, 28767, 28792, 16289, 28793, 272, 3914, 5610, 298, 272, 4890, 733, 28748, 16289, 28793]
logits torch.Size([17, 32000])
labels torch.Size([17])
loss tensor(12.1269)
tensor(184770.7500)
====
prompt <s>[INST] where fund cat salad before less [/INST]
input_ids 17 [1, 523, 28713, 28767, 28792, 16289, 28793, 970, 3360, 5255, 25256, 1159, 2108, 733, 28748, 16289, 28793]
logits torch.Size([17, 32000])
labels torch.Size([17])
loss tensor(11.8530)
tensor(140502.1094)


In [115]:
prompt = 'the dog walked to the because'
tokens = llm.tokenize(bytes(prompt, 'utf-8'))

for i in range(len(tokens)):
    llm.reset()
    tokens_sub = tokens[:i + 1]
    print(tokens_sub, llm.detokenize(tokens_sub))
    llm.eval(tokens_sub)
    logits = torch.tensor(llm.eval_logits)
    display(logits)

    inputs = torch.tensor(tokens_sub)
    loss = torch.nn.functional.cross_entropy(logits, inputs)
    display(loss)
    print('=====')

[1] b''


tensor([[-5.5094, -5.6424, -0.2599,  ..., -4.2934, -3.3498, -3.8549]])

tensor(21.4871)

=====
[1, 272] b'the'


tensor([[-5.5094, -5.6424, -0.2599,  ..., -4.2934, -3.3498, -3.8549],
        [-7.7832, -8.0748, -7.9795,  ..., -4.3758, -5.5358, -6.0741]])

tensor(15.0426)

=====
[1, 272, 3914] b'the dog'


tensor([[-5.5094, -5.6424, -0.2599,  ..., -4.2934, -3.3498, -3.8549],
        [-7.7832, -8.0748, -7.9795,  ..., -4.3758, -5.5358, -6.0741],
        [-6.9367, -7.1049, -3.9957,  ..., -7.2401, -5.1559, -5.5102]])

tensor(13.3820)

=====
[1, 272, 3914, 5610] b'the dog walked'


tensor([[-5.5094, -5.6424, -0.2599,  ..., -4.2934, -3.3498, -3.8549],
        [-7.7832, -8.0748, -7.9795,  ..., -4.3758, -5.5358, -6.0741],
        [-6.9367, -7.1049, -3.9957,  ..., -7.2401, -5.1559, -5.5102],
        [-8.0252, -7.8293, -0.9691,  ..., -6.4636, -5.0039, -5.0019]])

tensor(12.9502)

=====
[1, 272, 3914, 5610, 298] b'the dog walked to'


tensor([[-5.5094, -5.6424, -0.2599,  ..., -4.2934, -3.3498, -3.8549],
        [-7.7832, -8.0748, -7.9795,  ..., -4.3758, -5.5358, -6.0741],
        [-6.9367, -7.1049, -3.9957,  ..., -7.2401, -5.1559, -5.5102],
        [-8.0252, -7.8293, -0.9691,  ..., -6.4636, -5.0039, -5.0019],
        [-7.7160, -8.0192, -2.9540,  ..., -6.5423, -4.7248, -4.7977]])

tensor(12.1390)

=====
[1, 272, 3914, 5610, 298, 272] b'the dog walked to the'


tensor([[-5.5094, -5.6424, -0.2599,  ..., -4.2934, -3.3498, -3.8549],
        [-7.7832, -8.0748, -7.9795,  ..., -4.3758, -5.5358, -6.0741],
        [-6.9367, -7.1049, -3.9957,  ..., -7.2401, -5.1559, -5.5102],
        [-8.0252, -7.8293, -0.9691,  ..., -6.4636, -5.0039, -5.0019],
        [-7.7160, -8.0192, -2.9540,  ..., -6.5423, -4.7248, -4.7977],
        [-7.2335, -7.0519, -1.7630,  ..., -7.1533, -4.6419, -5.2300]])

tensor(11.5601)

=====
[1, 272, 3914, 5610, 298, 272, 1096] b'the dog walked to the because'


tensor([[-5.5094, -5.6424, -0.2599,  ..., -4.2934, -3.3498, -3.8549],
        [-7.7832, -8.0748, -7.9795,  ..., -4.3758, -5.5358, -6.0741],
        [-6.9367, -7.1049, -3.9957,  ..., -7.2401, -5.1559, -5.5102],
        ...,
        [-7.7160, -8.0192, -2.9540,  ..., -6.5423, -4.7248, -4.7977],
        [-7.2335, -7.0519, -1.7630,  ..., -7.1533, -4.6419, -5.2300],
        [-7.1256, -6.6898, -4.2902,  ..., -7.1522, -5.5642, -5.2738]])

tensor(11.3797)

=====


In [120]:
# Example of very very simple autoregressive generation using logits from model.

prompt = '<s>[INST] What is the first letter of "bee"? [/INST] b </s> [INST] What is the first letter of "cat"? Answer with a single letter. [/INST]'
tokens = llm.tokenize(bytes(prompt, 'utf-8'))

for i in range(15):
    llm.reset()
    llm.eval(tokens)
    logits = torch.tensor(llm.eval_logits)
    sorted_logits, sorted_indices = torch.sort(logits[-1, :], descending=True)
    print(tokens, llm.detokenize(tokens))
    for logit, idx in zip(sorted_logits[:5], sorted_indices[:5]):
        print(logit, idx.item(), llm.detokenize([idx.item()]))
    tokens.append(sorted_indices[0])
    print('==========')

[1, 523, 28713, 28767, 28792, 16289, 28793, 1824, 349, 272, 907, 5498, 302, 345, 28255, 27257, 733, 28748, 16289, 28793, 287, 1867, 28713, 28767, 733, 16289, 28793, 1824, 349, 272, 907, 5498, 302, 345, 6272, 27257, 26307, 395, 264, 2692, 5498, 28723, 733, 28748, 16289, 28793] b'<s>[INST] What is the first letter of "bee"? [/INST] b </s> [INST] What is the first letter of "cat"? Answer with a single letter. [/INST]'
tensor(26.4068) 277 b' c'
tensor(15.1816) 415 b' The'
tensor(15.0178) 28717 b'c'
tensor(11.0393) 19983 b' cens'
tensor(9.7702) 334 b' C'
[1, 523, 28713, 28767, 28792, 16289, 28793, 1824, 349, 272, 907, 5498, 302, 345, 28255, 27257, 733, 28748, 16289, 28793, 287, 1867, 28713, 28767, 733, 16289, 28793, 1824, 349, 272, 907, 5498, 302, 345, 6272, 27257, 26307, 395, 264, 2692, 5498, 28723, 733, 28748, 16289, 28793, tensor(277)] b'<s>[INST] What is the first letter of "bee"? [/INST] b </s> [INST] What is the first letter of "cat"? Answer with a single letter. [/INST] c'
tensor(21.

In [140]:
llm.reset()

# How likely is this sentence?
prompt = b'The dog walked to the door.'
tokens = llm.tokenize(prompt)
llm.eval(tokens)
logits = torch.tensor(llm.eval_logits)
probs = torch.softmax(logits, dim=-1)

prob_log = 0.0
for idx, token in enumerate(tokens):
    print(idx, token, llm.detokenize(
        [token]), probs[idx, token], torch.log(probs[idx, token]))
    prob_log += torch.log(probs[idx, token]).item()
print('Log:', prob_log)

0 1 b'' tensor(4.6587e-10) tensor(-21.4871)
1 415 b' The' tensor(7.4426e-05) tensor(-9.5057)
2 3914 b' dog' tensor(5.9457e-06) tensor(-12.0328)
3 5610 b' walked' tensor(9.6749e-06) tensor(-11.5460)
4 298 b' to' tensor(2.1755e-05) tensor(-10.7357)
5 272 b' the' tensor(6.4987e-05) tensor(-9.6413)
6 2251 b' door' tensor(9.3067e-06) tensor(-11.5848)
7 28723 b'.' tensor(2.3318e-07) tensor(-15.2714)
Log: -101.80482578277588
Mult: 6.119640507577637e-45 tensor([-101.8926])
