In [8]:
from tree_sitter import Language, Parser

Language.build_library(
    ".local/build/json_prefix.so",
    [
        "tree-sitter-json-prefix"
    ]
)
JSON_LANGUAGE = Language(".local/build/json_prefix.so", "json_prefix")
parser = Parser()
parser.set_language(JSON_LANGUAGE)

def depth_first_traversal(node):
    """Depth-first traversal of the tree"""
    yield node
    for child in node.children:
        yield from depth_first_traversal(child)

def json_error_or_prefix(s):
    tree = parser.parse(s)
    nodes = [(node.has_error, node.type.startswith("prefix_")) for node in depth_first_traversal(tree.root_node)]
    error = any([node[0] for node in nodes])
    prefix = any([node[1] for node in nodes])
    return error, prefix

In [9]:
import llama_cpp

MODEL_PATH = "../llms/models/ggml-alpaca.bin"

llama = llama_cpp.Llama(MODEL_PATH)

llama_model_load: loading model from '../llms/models/ggml-alpaca.bin' - please wait ...
llama_model_load: n_vocab = 32000
llama_model_load: n_ctx   = 512
llama_model_load: n_embd  = 4096
llama_model_load: n_mult  = 256
llama_model_load: n_head  = 32
llama_model_load: n_layer = 32
llama_model_load: n_rot   = 128
llama_model_load: f16     = 2
llama_model_load: n_ff    = 11008
llama_model_load: n_parts = 1
llama_model_load: type    = 1
llama_model_load: ggml map size = 4017.70 MB
llama_model_load: ggml ctx size =  81.25 KB
llama_model_load: mem required  = 5809.78 MB (+ 2052.00 MB per state)
llama_model_load: loading tensors from '../llms/models/ggml-alpaca.bin'
llama_model_load: model size =  4017.27 MB / num tensors = 291
llama_init_from_file: kv self size  =  512.00 MB
AVX = 1 | AVX2 = 1 | AVX512 = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | VSX = 0 | 


In [10]:
import math

n_vocab = llama_cpp.llama_n_vocab(llama.ctx)
vocab = [llama_cpp.llama_token_to_str(llama.ctx, i) for i in range(n_vocab)]

text = b'["the", "quick", "brown", "fox", "'

tokens = llama.tokenize(text)
llama.reset()

completion = []

n = 256
for i in range(n):
    # Eval
    llama.eval(tokens)

    # Sample
    logits_raw = llama_cpp.llama_get_logits(llama.ctx)
    logits = logits_raw[:n_vocab]
    logprobs = [math.log(1.0 + math.exp(logit)) for logit in logits]
    top_logprobs = sorted(zip(vocab, range(n_vocab), logprobs), key=lambda x: x[2], reverse=True)
    tokens = None
    for token, index, logprob in top_logprobs:
        error, prefix = json_error_or_prefix(text + token)
        if error:
            continue

        if prefix:
            print(token, index, logprob)
            completion.append((token, index, logprob))
            tokens = [index]
            break

        print(token, index, logprob)
        completion.append((token, index, logprob))
        tokens = None
        break
    if tokens is None:
        break

b'j' 29926 22.213199615703903
b'umps' 17204 26.740674972536617
b'",' 613 25.598178863533025
b' "' 376 24.236846923857915
b'over' 957 22.420814514343288
b'",' 613 26.923004150392654
b' "' 376 23.873390197796752
b'the' 1552 24.115888595614674
b'",' 613 23.48542785650847
b' "' 376 24.954296112075085
b'la' 433 20.75459289647697
b'zy' 1537 22.62041854873398
b'",' 613 24.2032699585269
b' "' 376 24.050636291539792
b'dog' 26169 23.77336120610204
b'"]' 3108 24.41419982912651


In [11]:
import json

json.loads(text + b"".join([token for token, _, _ in completion]))

['the', 'quick', 'brown', 'fox', 'jumps', 'over', 'the', 'lazy', 'dog']