In [1]:
import re
import os
import json

def normalize_empty_lines(code: str) -> str:
    """
    Normalize consecutive empty lines in a string to a maximum of two.

    Args:
        code (str): Code to normalize.
    
    Returns:
        str: Normalized code.
    """
    normalized_code = re.sub(r'\n{4,}', '\n\n', code)
    return normalized_code

def construct_prompt(
    data: dict, 
    version: str = "special",
    repo_token: str = "<repo_name>",
    file_token: str = "<file_sep>",
    fim: bool = False,
    language: str = "python",
    tokenizer= None,
    max_prompt_length = None
    ) -> str:
    """
    Constructs a prompt for the specified model version.

    Args:
        data: the data to construct the prompt from
        version: 'special', 'normal' or 'baseline'
        repo_token: the token to use for the repo name
        file_token: the token to use for the file path
        fim: whether to use FIM (Fill-In-the-Middle) or not
        tokenizer: the tokenizer to use for tokenizing the prompt if specified
        max_prompt_length: the maximum length of the prompt if specified
    
    Returns:
        prompt: the constructed prompt or a list of prompts if version is 'all'
    """
    
    assert version in ["special", "normal", "baseline"], "version must be one of ['special', 'normal', 'baseline']"
    assert language in ["python", "java"], "language must be one of ['python', 'java']"
    assert tokenizer is not None, "tokenizer must be specified"
    assert max_prompt_length is not None, "max_prompt_length must be specified"

    
    repo_name = data['repo_name']
    file_path = data['file_path']
    code = data['cropped_code']
    import_statement = data['import_statement']

    # special token version
    if version == "special":
        repo_prompt = f"{repo_token}{repo_name}"
        for snippet in data['context']:
            repo_prompt += f"{file_token}{snippet['path']}\n{snippet['snippet']}"
            
        if fim:
            in_file_prompt = f"{file_token}<fim_prefix>{file_path}\n{import_statement}\n{code}<fim_suffix><fim_middle>"
        else:
            in_file_prompt = f"{file_token}{file_path}\n{import_statement}\n{code}"
        
        if tokenizer is not None and max_prompt_length is not None:
            repo_prompt_token_num = len(tokenizer.encode(repo_prompt))
            in_file_prompt_token_num = len(tokenizer.encode(in_file_prompt))
            
            extra_token_num = repo_prompt_token_num + in_file_prompt_token_num - max_prompt_length
            if extra_token_num > 0:
                # split the repo prompt by lines
                repo_prompt_lines = repo_prompt.split("\n")
                # drop lines from end until the extra token number is less than 0
                for i in range(len(repo_prompt_lines)-1, -1, -1):
                    extra_token_num -= len(tokenizer.encode(repo_prompt_lines[i]))
                    if extra_token_num < 0:
                        break
                
                # join the lines back
                repo_prompt = "\n".join(repo_prompt_lines[:i+1])
            
        prompt = repo_prompt + in_file_prompt
    
    # normal version
    elif version == "normal":
        comment_symbol = "#" if language == "python" else "//"
        repo_prompt = f"{comment_symbol} Repo Name: {data['repo_name']}\n"
        for snippet in data['context']:
            repo_prompt += f"{comment_symbol} Path: {snippet['path']}\n{snippet['snippet']}" + "\n"
        
        if fim:
            in_file_prompt = f"<fim_prefix>{comment_symbol} Path: {file_path}\n{import_statement}\n{code}<fim_suffix><fim_middle>"
        else:
            in_file_prompt = f"{comment_symbol} Path: {file_path}\n{import_statement}\n{code}"
        
        repo_prompt_token_num = len(tokenizer.encode(repo_prompt))
        in_file_prompt_token_num = len(tokenizer.encode(in_file_prompt))
        
        extra_token_num = repo_prompt_token_num + in_file_prompt_token_num - max_prompt_length
        if extra_token_num > 0:
            # split the repo prompt by lines
            repo_prompt_lines = repo_prompt.split("\n")
            # drop lines from end until the extra token number is less than 0
            for i in range(len(repo_prompt_lines)-1, -1, -1):
                extra_token_num -= len(tokenizer.encode(repo_prompt_lines[i]))
                if extra_token_num < 0:
                    break
            
            # join the lines back
            repo_prompt = "\n".join(repo_prompt_lines[:i+1])+ "\n"
            
        prompt = repo_prompt + in_file_prompt
    
    # baseline version
    elif version == "baseline":
        comment_symbol = "#" if language == "python" else "//"
        in_file_prompt = f"{comment_symbol} Path: {file_path}\n{import_statement}\n{code}"
        
        if fim:
            in_file_prompt = f"<fim_prefix>{in_file_prompt}<fim_suffix><fim_middle>"
        else:
            pass
        
        prompt = in_file_prompt
        
    return normalize_empty_lines(prompt)

def get_first_line_not_comment(code:str, language:str="python"):
    """
    This function gets the first line of code that is not a comment.

    Args:
    code: Str, the code

    Returns:
    Str, the first line of code that is not a comment or the first line of code if there is no line that is not a comment
    """

    # check if the language is valid
    assert language in ["python", "java"], "language must be one of [python, java]"


    # first remove the \n at the beginning of the code
    code = code.lstrip('\n')

    lines = code.split('\n')
    in_multiline_comment = False

    if language == "python":
        for line in lines:
            # if the line is empty, then skip
            if not line.strip():
                continue
            # if the line is a start of a multiline comment, then set the in_multiline_comment to True and skip
            if not in_multiline_comment and (line.strip().startswith('"""') or line.strip().startswith("'''")):
                in_multiline_comment = True
                continue
            # if the line is the end of a multiline comment, then set the in_multiline_comment to False and skip
            if in_multiline_comment and (line.strip().endswith('"""') or line.strip().endswith("'''")):
                in_multiline_comment = False
                continue
            # if the line is in a multiline comment, then skip
            if in_multiline_comment:
                continue
            # if the line is a single line comment, then skip
            if line.strip().startswith('#'):
                continue
            # if the line is not a comment, then return the line
            return line
        
    elif language == "java":
        for line in lines:
            # if the line is empty, then skip
            if not line.strip():
                continue
            # if the line is a start of a multiline comment, then set the in_multiline_comment to True and skip
            if not in_multiline_comment and line.strip().startswith('/*'):
                in_multiline_comment = True
                continue
            # if the line is the end of a multiline comment, then set the in_multiline_comment to False and skip
            if in_multiline_comment and line.strip().endswith('*/'):
                in_multiline_comment = False
                continue
            # if the line is in a multiline comment, then skip
            if in_multiline_comment:
                continue
            # if the line is a single line comment, then skip
            if line.strip().startswith('//'):
                continue
            # if the line is not a comment, then return the line
            return line


    # if we cannot find a line that is not a comment, then return the first line
    return lines[0]

In [2]:
# load dataset
from datasets import load_dataset

dataset = load_dataset("tianyang/repobench_python_v1.1")

In [3]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# load model
model_name = "bigcode/starcoderbase-3b"

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    model_name, 
    device_map="auto",        
    trust_remote_code=True,
    torch_dtype=torch.bfloat16
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  return self.fget.__get__(instance, owner)()


In [4]:
prompt = construct_prompt(
    dataset['cross_file_first'][31],
    version="normal",
    tokenizer=tokenizer,
    max_prompt_length=7800,
)

In [5]:
print(prompt)

# Repo Name: see2023/Bert-VITS2-ext
# Path: config.py
class Resample_config:
class Preprocess_text_config:
class Bert_gen_config:
class Emo_gen_config:
class Train_ms_config:
class Webui_config:
class Server_config:
class Translate_config:
class Config:
    def __init__(self, in_dir: str, out_dir: str, sampling_rate: int = 44100):
    def from_dict(cls, dataset_path: str, data: Dict[str, any]):
    def __init__(
        self,
        transcription_path: str,
        cleaned_path: str,
        train_path: str,
        val_path: str,
        config_path: str,
        val_per_lang: int = 5,
        max_val_total: int = 10000,
        clean: bool = True,
    ):
    def from_dict(cls, dataset_path: str, data: Dict[str, any]):
    def __init__(
        self,
        config_path: str,
        num_processes: int = 2,
        device: str = "cuda",
        use_multi_device: bool = False,
    ):
    def from_dict(cls, dataset_path: str, data: Dict[str, any]):
    def __init__(
        self,
     

In [9]:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
model_outputs = model.generate(**model_inputs, max_new_tokens=128, do_sample=True, temperature=0.2, top_p=0.95)

generated_code = tokenizer.decode(model_outputs[0][model_inputs["input_ids"].shape[-1]:], skip_special_tokens=True)

print(generated_code)

Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


# Path: for_deploy/infer_utils.py
import sys
import torch
from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
    DebertaV2Model,
    DebertaV2Tokenizer,
    ClapModel,
    ClapProcessor,
)
from config import config
from text.japanese import text2sep_kata

class BertFeature:
    def __init__(self, model_path, language="ZH"):
        self.model_path = model_path
        self.language = language
        self.tokenizer = None
        self.model


In [7]:
model_inputs = tokenizer(f"<fim_prefix>{prompt}<fim_suffix><fim_middle>", return_tensors="pt").to(model.device)
model_outputs = model.generate(**model_inputs, max_new_tokens=128, do_sample=True, temperature=0.2, top_p=0.95)

generated_code = tokenizer.decode(model_outputs[0][model_inputs["input_ids"].shape[-1]:], skip_special_tokens=True)

print(generated_code)

Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


            sep, ph, accent = text2sep_kata(text)
        else:
            sep, ph, accent = text.split(), text.split(), None
        sep = sep[:config.bert_gen_config.max_sep_len]
        sep = sep + ["[SEP]"]
        sep = sep[:config.bert_gen_config.max_sep_len]
        sep = sep + ["[CLS]"]
        sep = sep[:config.bert_gen_config.max_sep_len]
        sep = sep + ["[SEP]"]
        sep = sep[:config
