In [1]:
import ast
import random

from ansi.colour import bg, fg
from transformers import AutoTokenizer, AutoModelForCausalLM

from incremental_parsing.evaluation.text_cuts import cut_text_random
from incremental_parsing.generation.constrained_generation import unconstrained_generation, \
    prefix_suffix_constrained_generation, do_prefix_suffix_constrained_generation
from incremental_parsing.generation.utils import tokenizer_int64, create_balanced_context
from incremental_parsing.lex_earley.lark_grammar import get_python_context
import datasets

In [2]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
# Passed as `beam_size=` to the Counter demo generation calls below.
BEAM_SIZE = 1
# Passed as `max_new_tokens=` (unconstrained) and `max_generation_length=`
# (constrained) to the Counter demo generation calls below.
MAX_GENERATION_LENGTH = 250

In [4]:
# Checkpoint name and target device used by the cells below.
MODEL_NAME = "bigcode/santacoder"
DEVICE = "cuda:0"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# NOTE(review): trust_remote_code=True allows Python code shipped with the
# checkpoint to run locally -- confirm the checkpoint source is trusted.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = model.to(DEVICE)

# Handed to the constrained generation calls as `context=`.
context = get_python_context()

# Only the "train" split of the Python data directory is kept.
dataset = datasets.load_dataset(
    "bigcode/the-stack-smol-xl",
    data_dir="data/python",
)["train"]

Found cached dataset json (/home/ec2-user/.cache/huggingface/datasets/bigcode___json/bigcode--the-stack-smol-xl-1fe14832da3eae85/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)


  0%|          | 0/1 [00:00<?, ?it/s]

In [5]:
import transformers

# Reset the seed so this generation can be compared with the constrained run,
# which resets to the same seed.
transformers.set_seed(0)

# Text placed before the gap the model has to fill.
prefix_text = """
class Counter():
  def __init__():
    self.count = 0

  def inc(self):
    self.count"""

# Text placed after the gap.
suffix_text = """  1

"""

# Baseline: generate the middle with `unconstrained_generation`
# (no `context=` grammar argument is passed here).
middle_text = unconstrained_generation(
    tokenizer=tokenizer,
    model=model,
    prefix_text=prefix_text,
    suffix_text=suffix_text,
    beam_size=BEAM_SIZE,
    max_new_tokens=MAX_GENERATION_LENGTH,
    device=DEVICE,
)

print(middle_text)



Setting `pad_token_id` to `eos_token_id`:49152 for open-end generation.


 += 1

  def dec(self):
    self.count -= 1

  def get(self):
    return self.count

c = Counter()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c


In [6]:
# Reuses the currently bound prefix_text / suffix_text with the same seed, but
# routes generation through the constrained generator and its grammar context.
transformers.set_seed(0)
middle_text, _ = prefix_suffix_constrained_generation(
    tokenizer=tokenizer,
    model=model,
    context=context,
    prefix_text=prefix_text,
    suffix_text=suffix_text,
    beam_size=BEAM_SIZE,
    max_generation_length=MAX_GENERATION_LENGTH,
    device=DEVICE,
    debug=True,  # alternative value noted by the author: "filtered"
)

print(middle_text)

Setting `pad_token_id` to `eos_token_id`:49152 for open-end generation.


 [41m[30m+[0m[0m[42m[30m=[0m[0m[43m[30m [0m[0m1[44m[30m
[0m[0m[45m[30m
[0m[0m [46m[30m [0m[0mdef dec(self)[47m[30m:[0m[0m
[40m[37m [0m[0m[41m[30m [0m[0m[42m[30m [0m[0m[43m[30m [0m[0mself.count [44m[30m-[0m[0m[45m[30m=[0m[0m[46m[30m [0m[0m1[47m[30m
[0m[0m[40m[37m
[0m[0m [41m[30m [0m[0mdef get(self)[42m[30m:[0m[0m
[43m[30m [0m[0m[44m[30m [0m[0m[45m[30m [0m[0m[46m[30m [0m[0mretur[47m[30mn[0m[0m[40m[37m [0m[0mself.count[41m[30m
[0m[0m[42m[30m
[0m[0mc [43m[30m=[0m[0m[44m[30m [0m[0mCounter()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c.inc()
c
Scores:
[41m[30m +[0m[0m : -1.367
[42m[30m=[

In [7]:
transformers.set_seed(0)

# The prefix stops right after an opening ''' and the suffix begins with only
# two quote characters, so the generated middle has to supply the rest.
prefix_text = """
class Counter():
  def __init__():
    self.count = 0

  def inc(self):
    '''"""

suffix_text = """''
    self.count += 1

"""

# Baseline generation with `unconstrained_generation` for the new prompt.
middle_text = unconstrained_generation(
    tokenizer=tokenizer,
    model=model,
    prefix_text=prefix_text,
    suffix_text=suffix_text,
    beam_size=BEAM_SIZE,
    max_new_tokens=MAX_GENERATION_LENGTH,
    device=DEVICE,
)

print(middle_text)

Setting `pad_token_id` to `eos_token_id`:49152 for open-end generation.



    Increments the counter
    '


In [8]:
def show_example(data_idx: int, cut_idx: int, also_unconstrained: bool = False, *,
                 max_generation_length: int = 500, beam_size: int = 1) -> None:
    """Cut one dataset file into prefix/middle/suffix and print a generated infill.

    The file at ``dataset[data_idx]["content"]`` is cut with ``cut_text_random``
    and the gap is filled by ``do_prefix_suffix_constrained_generation``.  The
    result is printed between the prefix and the suffix with a coloured
    background:

    * green   -- the constrained generator returned a completion;
    * cyan    -- it returned ``None``; its second return value is shown instead;
    * red     -- output of a plain ``model.generate`` call on the same inputs
      (only when ``also_unconstrained`` is true), after a magenta separator.

    Relies on the notebook-level ``dataset``, ``tokenizer``, ``model``,
    ``context`` and ``DEVICE``.

    Args:
        data_idx: Index of the file in ``dataset``.
        cut_idx: Combined with ``data_idx`` to seed ``random`` before cutting,
            so different values select different cuts of the same file.
        also_unconstrained: Also run and print ``model.generate`` output.
        max_generation_length: Generation budget forwarded to the context
            builder, the constrained generator and ``model.generate``.
        beam_size: Beam width forwarded to both generators.
    """
    source_text = dataset[data_idx]["content"]

    # Re-seed the global RNG from the index pair so the same pair is seeded
    # identically on every call (int-tuple hashes are not salted per process).
    # Presumably cut_text_random draws from `random` -- TODO confirm.
    random.seed(hash((data_idx, cut_idx)) % (2 ** 32))
    prefix, _middle, suffix = cut_text_random(source_text, 0, .9, .2)
    suffix += "\n"

    pre_input_ids, pre_attention_mask = tokenizer_int64(tokenizer, prefix)
    post_input_ids, post_attention_mask = tokenizer_int64(tokenizer, suffix)

    input_ids, attention_mask = create_balanced_context(
        pre_input_ids=pre_input_ids, pre_attention_mask=pre_attention_mask,
        post_input_ids=post_input_ids, post_attention_mask=post_attention_mask,
        tokenizer=tokenizer, max_generation_length=max_generation_length, device=DEVICE
    )

    constrained_text, full_constrained = do_prefix_suffix_constrained_generation(
        tokenizer=tokenizer, model=model, context=context,
        input_ids=input_ids, attention_mask=attention_mask,
        prefix_text=prefix, suffix_text=suffix,
        pre_input_ids=pre_input_ids, post_input_ids=post_input_ids,
        beam_size=beam_size, max_generation_length=max_generation_length,
        debug=False
    )

    if constrained_text is not None:
        print(prefix + bg.boldgreen(fg.black(constrained_text)) + suffix)
    else:
        # No completion was returned; show the second return value instead.
        print(prefix + bg.boldcyan(fg.black(full_constrained)) + suffix)

    if not also_unconstrained:
        return

    print("\n" + bg.boldmagenta("----------------") + "\n")
    # early_stopping is a beam-search option, so only enable it when beams are used.
    generated = model.generate(input_ids=input_ids, attention_mask=attention_mask,
                               max_new_tokens=max_generation_length,
                               num_beams=beam_size, num_return_sequences=1,
                               early_stopping=beam_size > 1)

    # Drop the prompt tokens; keep only what the model appended.
    prompt_length = input_ids.shape[1]
    unconstrained_text = tokenizer.decode(generated[0][prompt_length:],
                                          skip_special_tokens=True,
                                          clean_up_tokenization_spaces=False)
    print(prefix + bg.boldred(fg.black(unconstrained_text)) + suffix)
            


In [None]:
# Dataset item 1470, cut 5 -- also print the unconstrained generation.
show_example(data_idx=1470, cut_idx=5, also_unconstrained=True)

In [None]:
# Dataset item 8440, cut 1 -- constrained generation only.
show_example(data_idx=8440, cut_idx=1)

In [None]:
# Dataset item 1804, cut 4 -- constrained generation only.
show_example(data_idx=1804, cut_idx=4)

In [None]:
# Dataset item 8551, cut 8 -- constrained generation only.
show_example(data_idx=8551, cut_idx=8)

In [None]:
# Dataset item 7396, cut 1 -- constrained generation only.
show_example(data_idx=7396, cut_idx=1)

In [None]:
# Dataset item 8, cut 5 -- constrained generation only.
show_example(data_idx=8, cut_idx=5)