In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

from incremental_parsing.generation.constrained_generation import prefix_suffix_constrained_generation
from incremental_parsing.lex_earley.lark_grammar import get_python_context
from ansi.colour import bg, fg
import datetime

In [None]:
# Re-import edited project modules (e.g. incremental_parsing) automatically before each cell runs,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Decoding settings forwarded to prefix_suffix_constrained_generation in the generation cell.
BEAM_SIZE: int = 1  # passed as beam_size
MAX_GENERATION_LENGTH: int = 200  # passed as max_generation_length; units (presumably tokens) are defined by the library -- TODO confirm

In [None]:
# Load the tokenizer and model, move the model to DEVICE, and build the parsing context from get_python_context().
# NOTE: If you get a warning in this cell about weights not being fine-tuned, use an older version of the transformers library
import torch

MODEL_NAME = "bigcode/santacoder"
# Prefer the first CUDA GPU, but fall back to CPU so the notebook still runs on machines without one
# (expect generation to be much slower on CPU).
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# trust_remote_code=True lets transformers execute the model repository's own modeling code.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True).to(DEVICE)
context = get_python_context()

In [None]:
import random
import time

import datasets
import transformers

from incremental_parsing.evaluation.text_cuts import cut_text_random

# Select a dataset example (idx) and which random cut of it to use (cut). Both feed the seeds below,
# so the same (idx, cut) pair reproduces the same prefix/middle/suffix split and generation.
idx = 0
cut = 0

dataset = datasets.load_dataset("bigcode/the-stack-smol-xl", data_dir="data/python")["train"]
data = dataset[idx]["content"]

# hash() of a tuple of ints is not subject to PYTHONHASHSEED randomization, so these seeds are stable
# across runs (for a given Python version); "% 2 ** 32" keeps them non-negative and within 32 bits.
random.seed(hash((idx, cut)) % (2 ** 32))
prefix_text, middle, suffix_text = cut_text_random(data, 0, .9, .2, None)
suffix_text += "\n"

# Time with a monotonic clock: datetime.now() is wall-clock time and can jump (NTP sync, DST changes),
# which would corrupt the measured duration.
begin = time.perf_counter()
transformers.set_seed(hash((idx, cut, 0)) % (2 ** 32))

middle_text, *_ = prefix_suffix_constrained_generation(
    tokenizer=tokenizer, model=model, context=context, prefix_text=prefix_text,
    suffix_text=suffix_text, beam_size=BEAM_SIZE, max_generation_length=MAX_GENERATION_LENGTH, device=DEVICE,
    debug=True
)
elapsed_seconds = time.perf_counter() - begin

# middle_text is None when constrained generation did not produce a completion.
if middle_text is None:
    print("Generation failed")
else:
    # Highlight only the generated middle so it stands out between the original prefix and suffix.
    print(prefix_text + bg.boldgreen(fg.black(middle_text)) + suffix_text)

print(f"{elapsed_seconds:.3f} seconds elapsed")

In [None]:
import ast

# Sanity check: the stitched prefix + generated middle + suffix should be valid Python.
# ast.parse raises SyntaxError if it is not.
if middle_text is None:
    # Concatenating None would raise TypeError; there is simply nothing to check.
    print("Generation failed; nothing to parse")
else:
    ast.parse(prefix_text + middle_text + suffix_text)
    print("Combined program parses as valid Python")

In [None]:
# If the source file uses non-"\n" line endings (e.g. bare "\r"), the highlighted result may not display
# properly. Use this cell and the next one to show the result with normalized line endings instead.


def normalize_newlines(text):
    """Return ``text`` with "\\r\\n" and bare "\\r" line endings converted to "\\n".

    "\\r\\n" is collapsed first so Windows line endings do not turn into blank lines.
    """
    return text.replace("\r\n", "\n").replace("\r", "\n")


if middle_text is None:
    print("Generation failed")
else:
    # Only the generated middle is highlighted; the suffix stays outside bg.boldgreen, matching the main cell.
    print(
        normalize_newlines(prefix_text)
        + bg.boldgreen(fg.black(normalize_newlines(middle_text)))
        + normalize_newlines(suffix_text)
    )

In [None]:
# Show the full original file with line endings normalized: "\r\n" is collapsed first so Windows
# line endings don't become blank lines, then any remaining bare "\r" becomes "\n".
print(data.replace("\r\n", "\n").replace("\r", "\n"))