# Using the BigCode Models

*Authored by Arjun Guha, with technical assistance from Raymond Li and Carolyn Jane Anderson.*

To use this notebook, you will need a recent version of the Transformers library. I don't know what the minimum version is, 
but version 4.20 does not work, whereas version 4.25 does work.
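If you want the notebook to fail early on an older install, a small check like the one below can help. This is a sketch. It uses 4.25 as the threshold only because that is the oldest version reported to work above, not because it is the true minimum. The helper name `version_at_least` is hypothetical, and it compares only the leading numeric parts of a version string.

```python
import re
from importlib.metadata import PackageNotFoundError, version


def version_at_least(installed: str, minimum: str) -> bool:
    """Compare dotted version strings by their leading numeric components.

    A rough sketch: "4.26.0.dev0" is treated as (4, 26, 0). Use
    `packaging.version` instead if you need full PEP 440 semantics.
    """
    def numeric_parts(v: str) -> tuple:
        parts = []
        for piece in v.split("."):
            match = re.match(r"\d+", piece)
            if match is None:
                break
            parts.append(int(match.group()))
        return tuple(parts)

    return numeric_parts(installed) >= numeric_parts(minimum)


# 4.25 is the oldest version reported to work, not a confirmed minimum.
try:
    installed = version("transformers")
    if not version_at_least(installed, "4.25"):
        print(f"transformers {installed} may be too old; 4.25 is known to work.")
except PackageNotFoundError:
    print("transformers is not installed.")
```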

In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

Change the following to `"cuda"` if you have a GPU. This notebook will work on a GPU with 8GB VRAM, 
and will probably work with less.

In [2]:
DEVICE = "cpu"
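Instead of editing the cell by hand, you can let PyTorch choose the device. This is a minimal sketch built on `torch.cuda.is_available()`. It does not check whether the GPU has enough free memory for the model, and the `ImportError` fallback exists only so the cell also runs where PyTorch is missing.

```python
try:
    import torch
    has_cuda = torch.cuda.is_available()
except ImportError:  # safety net; the notebook itself requires PyTorch
    has_cuda = False

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if has_cuda else "cpu"
print(DEVICE)
```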

The BigCode models require a little more configuration than off-the-shelf Transformers models, which I've included below. I expect some of this won't be necessary in the future.

In [3]:
MODEL_NAME = "bigcode/santacoder"
MODEL_REVISION = "dedup-alt-comments"
FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>"
FIM_SUFFIX = "<fim-suffix>"
FIM_PAD = "<fim-pad>"
ENDOFTEXT = "<|endoftext|>"

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, revision=MODEL_REVISION, trust_remote_code=True).to(DEVICE)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="left")

# Note that the special tokens must be listed in the order below.
tokenizer.add_special_tokens({
  "additional_special_tokens": [ ENDOFTEXT, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD ],
  "pad_token": ENDOFTEXT,
})

5

The BigCode models support fill-in-the-middle (FIM), and their output may include the special tokens used for the FIM task. If we simply strip those tokens away when decoding (e.g., with `skip_special_tokens=True`), the output will be jumbled. The code below cleans up special
tokens in a way that makes sense for a FIM model.
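To see why stripping jumbles the output, consider a hypothetical decoded FIM sequence. The prompt lists the prefix, then the suffix, and the model generates the middle last, so deleting the special tokens leaves the pieces out of source order. The strings below are made up for illustration, the token constants repeat the ones defined earlier so the cell runs on its own, and `reassemble` is a hypothetical helper, not part of Transformers. The naive loop only approximates `skip_special_tokens=True`, which works at the token level.

```python
FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>"
FIM_SUFFIX = "<fim-suffix>"
ENDOFTEXT = "<|endoftext|>"

# A made-up decoded output: prompt (prefix, then suffix), then the generated middle.
decoded = (
    f"{FIM_PREFIX}def inc(x):\n"
    f"{FIM_SUFFIX}\n    return y"
    f"{FIM_MIDDLE}    y = x + 1{ENDOFTEXT}"
)

# Naive cleanup: delete every special token, roughly what skip_special_tokens=True does.
naive = decoded
for tok in (FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, ENDOFTEXT):
    naive = naive.replace(tok, "")


def reassemble(s: str) -> str:
    """Put the three FIM pieces back in source order: prefix + middle + suffix."""
    prefix = s[s.index(FIM_PREFIX) + len(FIM_PREFIX):s.index(FIM_SUFFIX)]
    suffix = s[s.index(FIM_SUFFIX) + len(FIM_SUFFIX):s.index(FIM_MIDDLE)]
    start = s.index(FIM_MIDDLE) + len(FIM_MIDDLE)
    end = s.find(ENDOFTEXT, start)
    middle = s[start:] if end == -1 else s[start:end]
    return prefix + middle + suffix


print(repr(naive))          # suffix text appears before the generated middle
print(reassemble(decoded))  # prefix, middle, suffix in source order
```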

In [4]:
def strip_left_padding(output_tensor):
    """
    Since we are not using skip_special_tokens as described above, when
    batching results of varying length, the output will contain
    <|endoftext|> tokens on the left. This code strips those out, but
    does not remove special tokens after the text.
    """
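    start_index = 0
    # Bound the loop so a row made entirely of padding cannot index past the end.
    while start_index < len(output_tensor) and output_tensor[start_index].item() == tokenizer.pad_token_id:
        start_index += 1
    return output_tensor[start_index:]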


def truncate_at_first_special_token(output_str):
    """
    This function clips the output at the first special token that it finds.
    We use it after strip_left_padding.
    """
    truncate_index = len(output_str)
    for special_token in [ FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD, ENDOFTEXT ]:
        ix = output_str.find(special_token)
        if ix != -1 and ix < truncate_index:
            truncate_index = ix
    return output_str[:truncate_index]


def stop_at_stop_token(decoded_string, stop_tokens):
    """
    Produces the prefix of decoded_string that ends at the first occurrence of
    a stop_token. The decoded_string must not include the prompt, which may
    itself contain a stop token.
    """
    min_stop_index = len(decoded_string)
    for stop_token in stop_tokens:
        stop_index = decoded_string.find(stop_token)
        if stop_index != -1 and stop_index < min_stop_index:
            min_stop_index = stop_index
    return decoded_string[:min_stop_index]


def completions(prompts, stop_tokens = ( "\n\n", ), max_length: int = 128, temperature: float = 0.2, top_p: float = 0.95):
    """
    This function generates up to max_length new tokens per prompt, cut off
    at the first stop token. The default stop token is two newlines, which
    often marks the boundary between top-level functions in many programming
    languages.
    """
    
    output_list = True
    if type(prompts) == str:
        prompts = [prompts]
        output_list = False

    # `.rstrip` is essential. Trailing whitespace produces really bad completions.
    prompts = [ p.rstrip() for p in prompts ]
    # `return_token_type_ids=False` is essential, or we get nonsense output.
    inputs = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False).to(DEVICE)
    max_length = max_length + inputs.input_ids[0].size(0)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            do_sample=True,
            top_p=top_p,
            temperature=temperature,
            max_length=max_length,
            pad_token_id=tokenizer.pad_token_id
        )
    cleaned_output_strs = [
        truncate_at_first_special_token(tokenizer.decode(strip_left_padding(output))) 
        for output in outputs
    ]
    result = [
        stop_at_stop_token(output_str[len(prompt):], stop_tokens) 
        for (output_str, prompt) in zip(cleaned_output_strs, prompts)
    ]
    return result if output_list else result[0]
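The post-processing in `completions` can be traced on a toy batch without loading the model. This sketch uses a made-up vocabulary and token ids (`VOCAB`, `PAD_ID`, and the helper names are hypothetical) and follows the same order of operations: strip left padding, decode, truncate at `<|endoftext|>`, drop the prompt, then stop at the first stop token.

```python
ENDOFTEXT = "<|endoftext|>"
PAD_ID = 0  # hypothetical id for <|endoftext|>, which doubles as the pad token

# Hypothetical token ids -> strings, just enough to decode the toy batch below.
VOCAB = {
    0: ENDOFTEXT,
    1: "if",
    2: " __name__",
    3: " == '__main__':",
    4: "\n    main()",
    5: "def hello(",
    6: "):\n    print('hi')",
    7: "\n\n",
    8: "x = 1",
}


def toy_decode(ids):
    return "".join(VOCAB[i] for i in ids)


def strip_left_pad(ids):
    start = 0
    while start < len(ids) and ids[start] == PAD_ID:
        start += 1
    return ids[start:]


def clean(ids, prompt, stop_tokens=("\n\n",)):
    text = toy_decode(strip_left_pad(ids))
    # Truncate at the first remaining special token (only <|endoftext|> here).
    eot = text.find(ENDOFTEXT)
    if eot != -1:
        text = text[:eot]
    completion = text[len(prompt):]
    # Cut at the earliest stop token, if any appears.
    cut = len(completion)
    for tok in stop_tokens:
        idx = completion.find(tok)
        if idx != -1:
            cut = min(cut, idx)
    return completion[:cut]


# Row 0 has the shorter prompt, so it is left-padded; row 1 ends with <|endoftext|>.
batch = [[0, 5, 6, 7, 8], [1, 2, 3, 4, 0]]
toy_prompts = ["def hello(", "if __name__"]
for ids, prompt in zip(batch, toy_prompts):
    print(repr(clean(ids, prompt)))
```

Without `strip_left_pad`, row 0 would begin with `<|endoftext|>`, so truncating at the first special token would leave an empty string. That is why `truncate_at_first_special_token` runs after `strip_left_padding`.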

Here is a simple example where we get one completion.

In [5]:
prompt = 'def foo(n):\n    """Check if n is prime"""'
     
print(prompt + completions(prompt))

def foo(n):
    """Check if n is prime"""
    if n < 2:
        return False
    for i in range(2, n):
        if n % i == 0:
            return False
    return True


The example below is interesting because it batches two prompts of different lengths, so the shorter one is left-padded. Without the postprocessing above, the results would contain special tokens.

In [6]:
prompts = ["def hello(", "if __name__"]
results = completions(prompts)
for (i, (prompt, result)) in enumerate(zip(prompts, results)):
    print(f"******* Result {i} *********")
    print(prompt + result)   

******* Result 0 *********
def hello(request):
    return HttpResponse("Hello, world. You're at the polls index.")
******* Result 1 *********
if __name__ == '__main__':
    main()



The next example uses the model to generate a docstring from a function name. It uses `"""` as a stop token to terminate at the end of the docstring.

In [7]:
print(completions('def fac(n):\n    """', stop_tokens= [ '"""' ]))


    Factorial of a number
    :param n: number
    :return: factorial of n
    


## Using Fill-in-the-Middle (FIM)

The BigCode models support fill-in-the-middle (FIM), also called infilling. The `infill` function below takes a prefix-suffix tuple (or a list of them) and generates the code that goes between each prefix and suffix.

In [8]:
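def extract_fim_part(s: str):
    """
    Return the text between <fim-middle> and the next <|endoftext|>,
    or the end of the string if there is no <|endoftext|>.
    """
    start = s.find(FIM_MIDDLE) + len(FIM_MIDDLE)
    # str.find returns -1 when there is no <|endoftext|>; fall back to the end of the string.
    stop = s.find(ENDOFTEXT, start)
    if stop == -1:
        stop = len(s)
    return s[start:stop]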

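def infill(prefix_suffix_tuples, max_tokens: int = 50, temperature: float = 0.2, top_p: float = 0.95):
    """
    Generate up to max_tokens of code to go between each (prefix, suffix) pair.
    Accepts a single tuple (returns a string) or a list of tuples (returns a list).
    """
    output_list = True
    if type(prefix_suffix_tuples) == tuple:
        prefix_suffix_tuples = [prefix_suffix_tuples]
        output_list = False
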
    prompts = [f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}" for prefix, suffix in prefix_suffix_tuples]
    # `return_token_type_ids=False` is essential, or we get nonsense output.
    inputs = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False).to(DEVICE)
    max_length = inputs.input_ids[0].size(0) + max_tokens
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            do_sample=True,
            top_p=top_p,
            temperature=temperature,
            max_length=max_length,
            pad_token_id=tokenizer.pad_token_id
        )
    # WARNING: cannot use skip_special_tokens, because it blows away the FIM special tokens.
    result = [
        extract_fim_part(tokenizer.decode(tensor, skip_special_tokens=False)) for tensor in outputs
    ]
    return result if output_list else result[0]
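It can help to look at the exact string the model receives. The sketch below rebuilds the prompt that `infill` constructs for the Fibonacci example that follows. `fim_prompt` is a hypothetical helper that reproduces the f-string in `infill`, and the token constants repeat the ones defined earlier so the cell runs on its own. The middle marker comes last, so everything the model generates after it is the infilled code.

```python
FIM_PREFIX = "<fim-prefix>"
FIM_MIDDLE = "<fim-middle>"
FIM_SUFFIX = "<fim-suffix>"


def fim_prompt(prefix: str, suffix: str) -> str:
    """Build a prefix-suffix-middle prompt in the same order `infill` uses."""
    return f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"


fib_prompt = fim_prompt("def fib(n):", "    else:\n        return fib(n - 2) + fib(n - 1)")
print(fib_prompt)
```

Because the prefix here ends without a newline, the generated middle has to supply the line break and indentation before the base case.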

In the example below, we use FIM to fill in the base cases of the Fibonacci function.

In [9]:
prefix = """def fib(n):"""

suffix = """    else:
        return fib(n - 2) + fib(n - 1)"""

middle = infill((prefix, suffix))
print(prefix + middle + suffix)

def fib(n):
    if n == 0:
        return 0
    elif n == 1:
        return 1
    else:
        return fib(n - 2) + fib(n - 1)
