In [None]:
!pip install -q -U bitsandbytes

In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import re
import pandas as pd
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
device = 'cuda'

models = {'tiny_starcoder_py': {'quantization': False},
          'starcoder2-3b': {'quantization': True},
          'starcoder2-7b': {'quantization': True},
          'starcoder2-15b': {'quantization': True},
          }

quantization_config = BitsAndBytesConfig(load_in_8bit=True)

In [None]:
# Set your path
python_dataset_path = '/content/drive/MyDrive/code_completion_jb/data/python_dataset.csv'
df = pd.read_csv(python_dataset_path)
df.head(3)

Unnamed: 0.1,Unnamed: 0,prefix,tag,content,suffix,file_name
0,0,import argparse\nfrom rules import create_rule...,code_by_description,"datetime.now().strftime(""%Y%m%d%H%M%S%f"")",\n\n\ndef parse(message):\n # Extracts the ...,multi_agent_simulation.txt
1,1,import argparse\nfrom rules import create_rule...,code_by_description,assert '<s>' in message and '</s>' in message\...,\n return message[start:end]\n\n\ndef parse...,multi_agent_simulation.txt
2,2,import argparse\nfrom rules import create_rule...,code_by_description,assert '<s>' in message and '</s>' in message\...,"\n return message[start:end], message[end +...",multi_agent_simulation.txt


In [4]:
def format_prompt(prefix, suffix):
    return f"""<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>"""

def format_middle_output(text):
    prefix = re.search('<fim_prefix>(.*?)<fim_suffix>', text, re.DOTALL).group(1)
    suffix = re.search('<fim_suffix>(.*?)<fim_middle>', text, re.DOTALL).group(1)
    try:
        output = re.search('<fim_middle>(.*?)<file_sep>', text, re.DOTALL).group(1)
    except:
        output = re.search('<fim_middle>(.*)', text).group(1).replace('<|endoftext|>', '')
    return (prefix, output, suffix)

In [5]:
params = {
    'max_new_tokens': 128,
    'temperature': 0.2,
    'top_k': 50,
    'top_p': 0.1,
    'repetition_penalty': 1.17,
    'do_sample': True
}

# Generation
This function generates code completions for each row in the dataset using the provided model and tokenizer, then stores the generated output in the specified column of the dataframe.<br>
Final dataset consists of:
`prefix, tag, content, suffix, file_name, gen_tiny_starcoder_py, gen_starcoder2_3b, gen_starcoder2_7b, gen_starcoder2_15b`


In [6]:
color = lambda s: f"\033[96m{s}\033[00m"

def generate_code(model, tokenizer, dataset, column, verbose=True):
    for index, row in dataset.iterrows():
        prompt = format_prompt(row.prefix, row.suffix)
        inputs = tokenizer.encode(prompt, return_tensors="pt").to(device)

        outputs = model.generate(inputs, pad_token_id=tokenizer.eos_token_id, **params)
        prefix, output, suffix = format_middle_output(tokenizer.decode(outputs[0]))
        if verbose:
            print(f'Index: {color(index + 1)}\n'
                  f'Code: {prefix[-250:].lstrip()}{color(output)}{suffix[:250].rstrip()}\n'
                  f'Tag: {color(row.tag)}\n')

        df.at[index, column] = output

This code iterates over a dictionary of model configurations, loads the tokenizer and model for each checkpoint, generates code completions for each model using the generate_code function, and then stores the generated output in the corresponding column of the dataframe. Then saves results after each model's generation.


In [None]:
for model_name, config in models.items():
    checkpoint = f'bigcode/{model_name}'
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    if config['quantization']:
      model = AutoModelForCausalLM.from_pretrained(checkpoint, quantization_config=quantization_config)
    else:
      model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)

    gen_column = 'gen' + model_name.replace('-', '_')
    df[gen_column] = ''
    generate_code(model, tokenizer, df, gen_column)

    # Set your save path
    save_path = '/content/drive/MyDrive/code_completion_jb/data/python_dataset_gen.csv'
    df.to_csv(save_path, index=False)
    df.head(3)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/677 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/777k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/442k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.06M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/532 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/1.03k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/657M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Index: [96m1[00m
Code: logger
import time
import random
import problems_config as pcfg
from datetime import datetime


def generate_time_based_id():
    # Get the current time in the format YYYYMMDDHHMMSSFFF (year, month, day, hour, minute, second, millisecond)
    return [96mdatetime.now().strftime("%Y%m%d_%H%M%S")[00m


def parse(message):
    # Extracts the substring between <s> and </s> tags in the given message
    assert '<s>' in message and '</s>' in message
    start = message.index('<s>') + len('<s>')
    end = message.index('</s>')
    return message[star
Tag: [96mcode_by_description[00m

Index: [96m2[00m
Code: ent time in the format YYYYMMDDHHMMSSFFF (year, month, day, hour, minute, second, millisecond)
    return datetime.now().strftime("%Y%m%d%H%M%S%f")


def parse(message):
    # Extracts the substring between <s> and </s> tags in the given message
    [96mmessage = message.split("<s>")[00m
    return message[start:end]


def parse_action(message, choices):
    

tokenizer_config.json:   0%|          | 0.00/7.88k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/777k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/442k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.06M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/958 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/700 [00:00<?, ?B/s]

`low_cpu_mem_usage` was None, now set to True since model is quantized.


model.safetensors:   0%|          | 0.00/12.1G [00:00<?, ?B/s]

Index: [96m1[00m
Code: logger
import time
import random
import problems_config as pcfg
from datetime import datetime


def generate_time_based_id():
    # Get the current time in the format YYYYMMDDHHMMSSFFF (year, month, day, hour, minute, second, millisecond)
    return [96mdatetime.now().strftime('%Y%m%d%H%M%S%f')[00m


def parse(message):
    # Extracts the substring between <s> and </s> tags in the given message
    assert '<s>' in message and '</s>' in message
    start = message.index('<s>') + len('<s>')
    end = message.index('</s>')
    return message[star
Tag: [96mcode_by_description[00m

Index: [96m2[00m
Code: ent time in the format YYYYMMDDHHMMSSFFF (year, month, day, hour, minute, second, millisecond)
    return datetime.now().strftime("%Y%m%d%H%M%S%f")


def parse(message):
    # Extracts the substring between <s> and </s> tags in the given message
    [96massert '<s>' in message and '</s>' in message
    start = message.index('<s>') + len('<s>')
    end = messa

tokenizer_config.json:   0%|          | 0.00/7.88k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/777k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/442k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.06M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/958 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/893 [00:00<?, ?B/s]

`low_cpu_mem_usage` was None, now set to True since model is quantized.


model.safetensors.index.json:   0%|          | 0.00/41.6k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.89G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/4.51G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

Index: [96m1[00m
Code: logger
import time
import random
import problems_config as pcfg
from datetime import datetime


def generate_time_based_id():
    # Get the current time in the format YYYYMMDDHHMMSSFFF (year, month, day, hour, minute, second, millisecond)
    return [96mdatetime.now().strftime("%Y%m%d%H%M%S%f")[:-3][00m


def parse(message):
    # Extracts the substring between <s> and </s> tags in the given message
    assert '<s>' in message and '</s>' in message
    start = message.index('<s>') + len('<s>')
    end = message.index('</s>')
    return message[star
Tag: [96mcode_by_description[00m

Index: [96m2[00m
Code: ent time in the format YYYYMMDDHHMMSSFFF (year, month, day, hour, minute, second, millisecond)
    return datetime.now().strftime("%Y%m%d%H%M%S%f")


def parse(message):
    # Extracts the substring between <s> and </s> tags in the given message
    [96massert '<s>' in message and '</s>' in message
    start = message.index('<s>') + len('<s>')
    end = 