To use this notebook, you will need a recent version of the Transformers library. I have not determined the exact minimum version, but version 4.20 does not work, whereas version 4.25 does.
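
One way to fail early on an unsupported install is to compare version strings before loading the model. The sketch below is only an illustration: `version_tuple` and `looks_supported` are hypothetical helper names, and the 4.25 threshold is simply the version reported to work above, not a verified minimum.

```python
import re

# Assumption: 4.25 is the oldest version known to work (per the note above);
# the true minimum may lie anywhere between 4.20 and 4.25.
KNOWN_GOOD = (4, 25)


def version_tuple(version_str):
    """Parse the leading 'major.minor' of a version string into a tuple of ints."""
    match = re.match(r"(\d+)\.(\d+)", version_str)
    if match is None:
        raise ValueError(f"unrecognized version string: {version_str!r}")
    return (int(match.group(1)), int(match.group(2)))


def looks_supported(version_str):
    """Return True if version_str is at least the known-good version."""
    return version_tuple(version_str) >= KNOWN_GOOD
```

In the notebook, this could be called as `looks_supported(transformers.__version__)` after `import transformers`.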

In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

In [2]:
DEVICE = "cuda"

The BigCode models require a little more configuration than off-the-shelf Transformers models, which I've included below. I expect some of this won't be necessary in the future.

In [4]:
MODEL_NAME = "bigcode/christmas-models"
MODEL_REVISION = "dedup-alt-comments"
FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>"
FIM_SUFFIX = "<fim-suffix>"
FIM_PAD = "<fim-pad>"
ENDOFTEXT = "<|endoftext|>"

# Move the model to DEVICE so it matches where `completions` sends its inputs.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, revision=MODEL_REVISION, trust_remote_code=True).to(DEVICE)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="left")

tokenizer.add_special_tokens({
  "additional_special_tokens": [ FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD, ENDOFTEXT ],
  "pad_token": ENDOFTEXT,
})

5

The BigCode models support fill-in-the-middle (FIM), and their output may include the special tokens used for the FIM task. If we simply
strip those tokens when decoding (e.g., with `skip_special_tokens=True`), we get jumbled output. The code below instead cleans up special
tokens in a manner that makes sense for a FIM model.

In [7]:
def truncate_at_first_special_token(output_str):
    """
    Clips the output at the first special token that it finds, instead of
    deleting the special tokens and keeping the text around them.
    """
    truncate_index = len(output_str)
    for special_token in [ FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD, ENDOFTEXT ]:
        ix = output_str.find(special_token)
        if ix != -1 and ix < truncate_index:
            truncate_index = ix
    return output_str[:truncate_index]


def strip_left_padding(output_tensor):
    """
    Since we are not using skip_special_tokens as described above, when batching results of varying length,
    the output will contain <|endoftext|> tokens on the left. This code strips those out.
    """
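    start_index = 0
    # Bound the scan so that an all-padding tensor does not raise an IndexError.
    while start_index < len(output_tensor) and output_tensor[start_index].item() == tokenizer.pad_token_id: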
        start_index += 1
    return output_tensor[start_index:]


def stop_at_stop_token(decoded_string, stop_tokens):
    """
    Produces the prefix of decoded_string that ends at the first occurrence of
    a stop_token.
    WARNING: the decoded_string *must not* include the prompt, which may have stop tokens
    itself.
    """
    min_stop_index = len(decoded_string)
    for stop_token in stop_tokens:
        stop_index = decoded_string.find(stop_token)
        if stop_index != -1 and stop_index < min_stop_index:
            min_stop_index = stop_index
    return decoded_string[:min_stop_index]


def completions(prompts, stop_tokens = ("\n\n",), max_length: int = 128, temperature: float = 0.2, top_p: float = 0.95):
    """
    Generates completions of up to max_length tokens beyond the (padded) prompt, or up to the first stop token.
    The default stop token is two newlines, which is usually the boundary between top-level functions in many programming languages.
    """
    
    if isinstance(prompts, str):
        prompts = [prompts]

    # `.rstrip` is essential. Trailing whitespace produces really bad completions.
    prompts = [ p.rstrip() for p in prompts ]
    # `return_token_type_ids=False` is essential, or we get nonsense output.
    inputs = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False).to(DEVICE)
    max_length = max_length + inputs.input_ids[0].size(0)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            do_sample=True,
            top_p=top_p,
            temperature=temperature,
            max_length=max_length,
            pad_token_id=tokenizer.pad_token_id
        )
    cleaned_output_strs = [ truncate_at_first_special_token(tokenizer.decode(strip_left_padding(output))) for output in outputs ]
    return [
        stop_at_stop_token(output_str[len(prompt):], stop_tokens) for (output_str, prompt) in zip(cleaned_output_strs, prompts)
    ]
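
To see what the string-level cleanup does without loading the model, the sketch below applies the same two steps (truncate at the first special token, then cut at the first stop token) to a made-up decoded string. `first_cut` and `clean` are hypothetical names introduced for this illustration, and the sample text is invented, not real model output.

```python
# Same token strings as in the configuration cell above.
SPECIAL_TOKENS = ["<fim-prefix>", "<fim-middle>", "<fim-suffix>", "<fim-pad>", "<|endoftext|>"]


def first_cut(text, markers):
    """Return text up to (not including) the earliest occurrence of any marker."""
    cut = len(text)
    for marker in markers:
        ix = text.find(marker)
        if ix != -1 and ix < cut:
            cut = ix
    return text[:cut]


def clean(decoded, prompt, stop_tokens=("\n\n",)):
    """Drop everything from the first special token on, remove the prompt, then apply stop tokens."""
    without_specials = first_cut(decoded, SPECIAL_TOKENS)
    completion = without_specials[len(prompt):]
    return first_cut(completion, stop_tokens)


prompt = "def add(a, b):"
# Invented stand-in for tokenizer.decode(...) output after left padding is stripped.
decoded = prompt + "\n    return a + b\n\ndef sub(a, b):<|endoftext|>junk"
result = clean(decoded, prompt)
print(repr(result))
```

Everything after the blank line, including the second function and the text following `<|endoftext|>`, is discarded, so only the body of `add` remains.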

Here is a simple example where we get one completion.

In [9]:
prompt = 'def foo(n):\n    """Check if n is prime"""'
     
print(prompt + completions(prompt)[0])

def foo(n):
    """Check if n is prime"""
    if n == 2:
        return True
    if n == 3:
        return True
    if n % 2 == 0 or n % 3 == 0:
        return False
    for i in range(5, int(n ** 0.5) + 1, 6):
        if n % i == 0 or n % (i + 2) == 0:
            return False
    return True


The example below is interesting because it has two inputs of different lengths in a batch, so one of them gets padded. Without the postprocessing above, the output below would contain special tokens.

In [76]:
prompts = ["def hello(", "if __name__"]
results = completions(prompts)
for (prompt, result) in zip(prompts, results):
    print("********")
    print(prompt + result)   

********
def hello(request):
    return HttpResponse("Hello, world. You're at the polls index.")
********
if __name__ == '__main__':
    main()
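
The FIM tokens configured earlier can also be used to ask for an infill rather than a left-to-right completion. The sketch below only builds the prompt string. It assumes the prefix-suffix-middle ordering (`<fim-prefix>`, then `<fim-suffix>`, then `<fim-middle>`), which this notebook does not confirm for this checkpoint, so check it against the model card before relying on it. `fim_prompt` is a hypothetical helper.

```python
# Same token strings as in the configuration cell above.
FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>"
FIM_SUFFIX = "<fim-suffix>"


def fim_prompt(prefix, suffix):
    """Build an infilling prompt. ASSUMPTION: prefix-suffix-middle token order."""
    return f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"


# The model would be expected to generate the text that belongs between the two pieces.
example = fim_prompt("def add(a, b):\n    return ", "\n\nprint(add(1, 2))")
print(example)
```

Note that `completions` as written truncates the decoded text, prompt included, at the first special token, so it would need adapting before it could return an infill for a prompt like this.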

