In [40]:
import os
import re
import subprocess

from transformers import AutoTokenizer, AutoModelForCausalLM

def read_method_code(path) -> str:
    """Stub that is not implemented yet.

    The body does nothing, so every call currently returns ``None`` regardless
    of ``path`` (and despite the ``str`` return annotation). Per the author's
    note, a future version is meant to produce values that are passed on to
    ``form_prompt_for_method``.
    """
    return None

def initialize_models(model_name: str = "deepseek-ai/deepseek-coder-7b-base-v1.5"):
    """Load the tokenizer and the causal language model used for test generation.

    Args:
        model_name: Identifier of the checkpoint to load; the same value is
            used for both the tokenizer and the model. Defaults to the
            DeepSeek Coder checkpoint this notebook was written against.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    # NOTE(review): trust_remote_code=True allows code shipped with the
    # checkpoint to run locally; keep it only for checkpoints you trust.
    # The tokenizer has no weights to place, so device_map is passed to the model only.
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, device_map="auto")
    return tokenizer, model

def form_prompt_for_method(method_code: str) -> str:
    """Build the completion prompt for generating tests of ``method_code``.

    The prompt consists of three sections separated by blank lines: a
    ``import pytest`` header, the method's source, and a two-line comment
    asking for tests that cover each branch once.

    Args:
        method_code: Source code of the method to be tested.

    Returns:
        The assembled prompt text.
    """
    section_break = "\n\n"
    header = "import pytest"
    instruction = (
        "# test for the method above\n"
        "# those tests cover each possible branch just once, no excessive repeats"
    )
    return header + section_break + method_code + section_break + instruction
    
def generate_code_tests(method_code: str, tokenizer: AutoTokenizer, model: AutoModelForCausalLM, max_new_tokens: int = 200) -> str:
    """Generate pytest tests for ``method_code`` with a causal language model.

    Args:
        method_code: Source code of the method to be tested.
        tokenizer: Tokenizer matching ``model``.
        model: Model used for greedy (non-sampling) generation.
        max_new_tokens: Upper bound on the number of generated tokens. The
            output can be cut off mid-statement when this budget is reached,
            so raise it for longer methods.

    Returns:
        The decoded text of the first generated sequence, with special tokens
        removed.
    """
    prompt = form_prompt_for_method(method_code)
    # Place the input tensors on the model's device so generation does not mix devices.
    tokenized_prompt = tokenizer(prompt, return_tensors="pt").to(model.device)
    tokenized_output = model.generate(
        **tokenized_prompt,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        use_cache=True,
    )
    return tokenizer.decode(tokenized_output[0], skip_special_tokens=True)

def fix_interrupted_gen(whole_code: str) -> str:
    function_split = whole_code.split("def")
    function_split_count = len(function_split)
    # if there are more than two functions the last one could be interrupted in the middle
    # we want to get rid of such a function so that code is interpretable
    # if there are just two functions last line could be interrupted
    if function_split_count > 3:
        valid_parts = function_split[:-1]
        working_code = 'def'.join(valid_parts)
    else:
        line_split = whole_code.split("\n")
        valid_lines = line_split[:-1]
        working_code = '\n'.join(valid_lines)
    
    return working_code

def write_code_tests(code: str, path: str = "code_test.py") -> None:
    """Write ``code`` to ``path``, replacing any existing file.

    Args:
        code: Source text to store.
        path: Destination file.
    """
    # Python reads source files as UTF-8 by default, so write UTF-8 explicitly
    # instead of relying on the platform's default encoding.
    with open(path, "w", encoding="utf-8") as file:
        file.write(code)

def get_total_file_coverage(path: str = "code_test.py") -> int:
    """Run pytest with coverage on ``path`` and return the total coverage percentage.

    Args:
        path: Test file to run. Coverage is requested for the module named
            after the file (file name without directory and extension).

    Returns:
        The percentage from the ``TOTAL`` row of the report as an integer;
        a fractional percentage is truncated.

    Raises:
        RuntimeError: If the pytest output contains no ``TOTAL`` row (for
            example when collection fails or no coverage report is printed).
        FileNotFoundError: If the ``pytest`` executable cannot be found.
    """
    module_name = os.path.splitext(os.path.basename(path))[0]
    # An argument list (no shell) keeps paths containing spaces or shell
    # metacharacters from being split or interpreted.
    test_command = ["pytest", path, f"--cov={module_name}"]
    test_result = subprocess.run(test_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

    # identify total result line
    total_line = next(
        (line for line in test_result.stdout.split("\n") if line.startswith("TOTAL")),
        None,
    )
    if total_line is None:
        raise RuntimeError(
            "No TOTAL line found in pytest coverage output.\n"
            f"stdout:\n{test_result.stdout}\nstderr:\n{test_result.stderr}"
        )

    # retrieve just the coverage info: last column of the row, e.g. "90%"
    coverage_string = total_line.split()[-1]
    return int(float(coverage_string.rstrip("%")))

In [37]:
# Load the tokenizer and model once; the generation cell passes both to generate_code_tests.
tokenizer, model = initialize_models()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Loading checkpoint shards: 100%|██████████| 3/3 [00:08<00:00,  2.70s/it]


In [39]:
# Sample method handed to the model. The generated tests are trimmed,
# written to disk and then measured for coverage.
method_source = """def max_of_three(a, b, c):
    \"\"\"
    This function returns the maximum of three numbers.
    \"\"\"
    if a >= b and a >= c:
        return a
    elif b >= a and b >= c:
        return b
    else:
        return c

"""

whole_code = generate_code_tests(method_source, tokenizer, model)
working_code = fix_interrupted_gen(whole_code)
write_code_tests(working_code)
test_coverage = get_total_file_coverage()

print(f"Test coverage is {test_coverage}%")

Setting `pad_token_id` to `eos_token_id`:100001 for open-end generation.


Test coverage is 100%
