In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

In [2]:
# from models import BigCodeModel, add_model_args, add_infilling_args
# import argparse
# parser = argparse.ArgumentParser()
# add_model_args(parser)
# add_infilling_args(parser)

# args = parser.parse_args([])
# model = BigCodeModel(args, "bigcode/large-model")

In [3]:
# model.infill(["def read_file(filename):\n   \"\"\"", "\"\"\"\n    with open(filename, 'r') as f:\n        return f.read()"], stop_words=None, truncation_parameters=None, temperature=0.8)

In [4]:
END_OF_TEXT = "<|endoftext|>"

In [5]:
MODEL_NAME = "bigcode/large-model"
FIM_PREFIX = "<fim_prefix>"
FIM_MIDDLE = "<fim_middle>"
FIM_SUFFIX = "<fim_suffix>"
FIM_PAD = "<fim-pad>"

In [6]:
MODEL_NAME = "bigcode/santacoder"
FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>"
FIM_SUFFIX = "<fim-suffix>"
FIM_PAD = "<fim-pad>"

In [7]:
SPEC_TOKS = [END_OF_TEXT, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD]

In [8]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="left")
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16, trust_remote_code=True).cuda()

Explicitly passing a `revision` is encouraged when loading a configuration with custom code to ensure no malicious code has been contributed in a newer revision.
Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.


In [9]:
tokenizer.add_special_tokens({
    'additional_special_tokens': SPEC_TOKS,
    'pad_token': END_OF_TEXT,
})

0

In [10]:
def complete(text, **kwargs):
    encoded = tokenizer.batch_encode_plus([text], return_tensors="pt")
    encoded = encoded.to(torch.device("cuda"))
    with torch.inference_mode():
        generated = model.generate(**encoded, **kwargs, pad_token_id=tokenizer.pad_token_id)
#     print(generated)
    for ix in generated.flatten():
        print((ix.item(), tokenizer.decode(ix)))
    return tokenizer.batch_decode(generated)[0]

In [11]:
def infill(prefix, suffix, **kwargs):
    prompt = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"
    text = complete(prompt, **kwargs)
    if END_OF_TEXT in text:
        text = text.split(END_OF_TEXT)[0]
    _, middle = text.split(FIM_MIDDLE)
    return middle

In [12]:
prefix = '''def test_invite_too_man_users(self) -> '''
suffix = '''
    self.login(self.example_email("iago"))'''

In [13]:
print(infill(prefix, suffix, max_new_tokens=28, do_sample=False))

(49153, '<fim-prefix>')
(563, 'def')
(703, ' test')
(62, '_')
(18970, 'invite')
(62, '_')
(15267, 'too')
(62, '_')
(3165, 'man')
(62, '_')
(3685, 'users')
(7, '(')
(314, 'self')
(8, ')')
(1208, ' ->')
(207, ' ')
(49155, '<fim-suffix>')
(258, '\n   ')
(366, ' self')
(13, '.')
(3157, 'login')
(7, '(')
(314, 'self')
(13, '.')
(1887, 'example')
(62, '_')
(2490, 'email')
(372, '("')
(3990, 'iag')
(78, 'o')
(1611, '"))')
(49154, '<fim-middle>')
(1138, 'None')
(25, ':')
(258, '\n   ')
(635, ' """')
(804, 'Test')
(954, ' that')
(4852, ' inv')
(268, 'it')
(291, 'ing')
(373, ' a')
(931, ' user')
(669, ' with')
(7833, ' too')
(7053, ' many')
(8696, ' members')
(9502, ' fails')
(2364, '."""')
(258, '\n   ')
(366, ' self')
(13, '.')
(3157, 'login')
(7, '(')
(314, 'self')
(0, '!')
(13, '.')
(1887, 'example')
(62, '_')
(2490, 'email')
None:
    """Test that inviting a user with too many members fails."""
    self.login(self!.example_email


In [15]:
tokenizer.special_tokens_map

{'bos_token': '<|endoftext|>',
 'eos_token': '<|endoftext|>',
 'unk_token': '<|endoftext|>',
 'pad_token': '<|endoftext|>',
 'additional_special_tokens': ['<|endoftext|>',
  '<fim-prefix>',
  '<fim-middle>',
  '<fim-suffix>',
  '<fim-pad>']}