# Utility Methods

In [22]:
#| default_exp utils

In [50]:
#| export
import CodeCheckList
import ast
import random
import re
import time
from func_timeout import func_set_timeout, FunctionTimedOut
from multiprocessing import Process, Queue

In [24]:
#| export
# From: https://github.com/github/CodeSearchNet/tree/master/function_parser
def traverse(
    node,       # tree-sitter node
    results,    # list to append results to
) -> None:
    """Collect the leaves of a tree-sitter (sub)tree into `results`, depth-first and left to right.

    A node of type 'string' is treated as a leaf: it is appended as a whole and
    its children are not visited.
    """
    # 'string' nodes are kept whole; nodes without children are the actual leaves.
    if node.type == 'string' or not node.children:
        results.append(node)
        return
    for child in node.children:
        traverse(child, results)

In [25]:
#| export 
def find_nodes(
    node,               #Tree-sitter AST node to start the search from
    target_node_type,   #Target node type to search in the tree
    results,            #List to append the results to
) -> None: 
    """Append to `results` every outermost node of type `target_node_type`, in pre-order.

    The search does not descend into a matching node, so matches nested inside
    another match are not reported. The walk uses an explicit stack instead of
    recursion so that deeply nested trees cannot exhaust the interpreter's
    recursion limit.
    """
    pending = [node]
    while pending:
        current = pending.pop()
        if current.type == target_node_type:
            results.append(current)
            continue
        # Push children right-to-left so they are popped (visited) left-to-right.
        pending.extend(reversed(current.children))

In [26]:
#| export
def get_node_type_list(
    node    # tree-sitter node used as the root of the traversal
) -> list:  # type of every visited node, in pre-order
    """Traverses the tree and get all the node types

    Returns one entry per node (duplicates are kept), with each parent listed
    before its children and children listed left to right. The walk uses an
    explicit stack instead of recursion so that deeply nested trees cannot
    exhaust the interpreter's recursion limit.
    """
    node_types = []
    pending = [node]
    while pending:
        current = pending.pop()
        node_types.append(current.type)
        # Push children right-to-left so they are popped (visited) left-to-right.
        pending.extend(reversed(current.children))
    return node_types

In [27]:
# TODO: implement calculate_tree_edit_distance(predicted_code_tree, source_code_tree)

In [28]:
#| export
def unroll_node_types(
    nested_node_types: dict, # node_types from tree-sitter
) -> list: # list of node types
    """Flatten tree-sitter node-type entries into a de-duplicated list of type names.

    `nested_node_types` is iterated as a sequence of entries. Besides each
    entry's own "type", the names found under "subtypes", under
    "children" -> "types" and under "fields" -> "alias" -> "types" are included,
    and 'ERROR' is always added. The order of the returned list is unspecified
    because it comes from a set.
    """
    own_names = []
    subtype_names = []
    child_names = []
    alias_names = []
    for entry in nested_node_types:
        own_names.append(entry["type"])
        if "subtypes" in entry:
            subtype_names.extend(sub["type"] for sub in entry["subtypes"])
        if "children" in entry:
            child_names.extend(child["type"] for child in entry["children"]["types"])
        # Only the "alias" field contributes names; other fields are ignored.
        if "fields" in entry and "alias" in entry["fields"]:
            alias_names.extend(alias["type"] for alias in entry["fields"]["alias"]["types"])
    every_name = own_names + subtype_names + child_names + alias_names + ['ERROR']
    return list(set(every_name))

In [29]:
#| export
def convert_to_offset(
    point,              #point to convert
    lines: list         #list of lines in the source code
):
    """Translate a (row, column) point into a character offset within the source text.

    Every line before `row` contributes its length plus one extra character
    (presumably the line separator that is not part of the entries in `lines`).
    The column is applied as a slice of the target line, so a column past the
    end of that line counts as the line's length.
    """
    row, column = point
    chars_before_row = sum(len(line) for line in lines[:row]) + row
    chars_within_row = len(lines[row][:column])
    return chars_before_row + chars_within_row

In [30]:
#| export
def get_sub_set_test_set(test_set, test_size:int):
    """Return the first `test_size` samples of `test_set` as a list.

    `test_set` only needs to be iterable; iteration stops as soon as enough
    samples were collected. Fewer samples are returned when `test_set` is
    shorter than `test_size`, and an empty list when `test_size` is zero or
    negative.
    """
    sub_samples = []
    # Guard first: the loop below appends before checking the limit, which
    # would otherwise hand back one sample even when none were requested.
    if test_size <= 0:
        return sub_samples
    for sample in test_set:
        sub_samples.append(sample)
        if len(sub_samples) >= test_size:
            break
    return sub_samples

In [31]:
#| export
def get_random_sub_set_test_set(test_set, test_size:int):
    """Draw `test_size` samples from `test_set` at random, with replacement.

    Each draw picks an independent random index, so the same sample can appear
    more than once. `test_set` must support `len()` and integer indexing.
    """
    return [
        test_set[random.randrange(0, len(test_set))]
        for _ in range(test_size)
    ]

In [32]:
#| export 
def is_valid_code(code):
    """Return True when `code` can be parsed by `ast.parse`, False otherwise.

    Any ordinary exception raised while parsing (syntax errors, invalid source
    values, ...) counts as "not valid". Interpreter-level signals such as
    KeyboardInterrupt and SystemExit are deliberately not swallowed.
    """
    try:
        ast.parse(code)
    except Exception:
        return False
    return True
    

In [33]:
#| export
def is_balanced_snippet(snippet, threshold):
    """This method is used to prevent a kernel blocking when tree-sitter tries to parse a buggy snippet

    A snippet is considered balanced when all of the following hold:

    * the ratio of ASCII alphanumeric characters to all other characters is
      strictly greater than `threshold`;
    * it contains no "buggy assign": two runs of '=' separated and surrounded
      only by non-alphanumeric characters;
    * it contains no "buggy tab" line: a newline followed by tabs, then
      non-alphanumeric characters, then more tabs and non-alphanumeric characters.

    A snippet without any non-alphanumeric character has an infinite ratio;
    an empty snippet has a ratio of 0.
    """
    num_alphanumeric = len(re.findall(r"[a-zA-Z0-9]", snippet))
    # Every character that is not alphanumeric (newlines included) is "other".
    num_other = len(snippet) - num_alphanumeric
    if num_other:
        proportion = num_alphanumeric / num_other
    else:
        # Avoid dividing by zero for snippets made only of alphanumerics (or empty).
        proportion = float("inf") if num_alphanumeric else 0.0

    # NOTE(review): the previous version also counted a space-indented variant of
    # the tab pattern but never used it in the result. It is omitted here because
    # it also matches ordinary code such as "\n    ) -> None:" — confirm it was
    # not meant to be part of the check.
    return (
        proportion > threshold
        and re.search(r"[^a-zA-Z0-9]+[=]+[^a-zA-Z0-9=]+[=]+[^a-zA-Z0-9=]+", snippet) is None
        and re.search(r"\n+[\t]+[^a-zA-Z0-9\t]+[\t]+[^a-zA-Z0-9]+", snippet) is None
    )

In [34]:
#| export
def get_test_sets(test_set, language, max_token_number, model_tokenizer, with_ranks=False, num_proc=1):
    """Filter `test_set` down to the samples that are usable for `language`.

    A sample is kept when its 'language' equals `language`, its
    'func_code_tokens' and its tokenized 'whole_func_string' are both shorter
    than `max_token_number`, and `is_balanced_snippet` accepts the function
    string with a threshold of 1. `num_proc` is forwarded to `test_set.filter`;
    `with_ranks` is accepted but not used.
    """
    def keep_sample(sample):
        # Checks run in order and stop at the first failure, so the tokenizer
        # and the balance check only run for samples that passed the earlier ones.
        if sample['language'] != language:
            return False
        if len(sample['func_code_tokens']) >= max_token_number:
            return False
        token_ids = model_tokenizer.tokenizer(sample['whole_func_string'])['input_ids']
        if len(token_ids) >= max_token_number:
            return False
        return bool(is_balanced_snippet(sample['whole_func_string'], 1))

    return test_set.filter(keep_sample, num_proc=num_proc)

In [35]:
#| export 
def get_elements_by_percentage(elements, percentage):
    """Keep a random share of `elements`, preserving their original order.

    `percentage` is a fraction of the input length (0 keeps nothing, 1 keeps
    everything); the number of kept elements is `int(percentage * len(elements))`.
    """
    total = len(elements)
    keep_count = int(percentage * total)
    # Sample positions rather than values so duplicated values are handled per position.
    kept_positions = set(random.sample(range(total), keep_count))
    return [item for position, item in enumerate(elements) if position in kept_positions]

In [68]:
def run_parser_subprocess(code, parser):
    """Parse `code` with `parser` in a child process and print the outcome.

    The parse runs in a separate process so that a parse that never returns can
    be killed instead of blocking the caller. The parent waits up to 5 seconds
    for the child's completion marker ('hi'). On timeout it prints
    '-parser deadlock-' and kills the child if it is still alive. The marker,
    or None when none was received, is printed at the end; nothing is returned.

    NOTE(review): the worker is a nested function, which can only serve as a
    Process target when the start method does not need to pickle it (e.g.
    fork) — confirm against the platforms this is run on.
    """
    from queue import Empty  # raised by multiprocessing's Queue.get on timeout

    def parse_code(queue, parser, code):
        parser.parse(bytes(code, "utf8"))
        # Completion marker: its arrival tells the parent the parse returned.
        queue.put('hi')

    result_queue = Queue()
    parser_process = Process(target=parse_code, args=(result_queue, parser, code))
    parser_process.start()
    result = None

    try:
        # Let the queue enforce the timeout itself so no helper thread is left
        # blocked on the queue when the child never answers.
        result = result_queue.get(timeout=5)
    except Empty:
        if parser_process.is_alive():
            print('-parser deadlock-')
            parser_process.kill()

    # The child has either answered or been killed (or already exited):
    # reap it and release the queue so nothing lingers after the call.
    parser_process.join()
    result_queue.close()

    print(result)

## Testing

In [37]:
# Every requested fraction of a 10-element list must yield exactly that many elements.
a = list(range(10))
for percentage, expected_len in zip((0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1), range(11)):
    assert len(get_elements_by_percentage(a, percentage)) == expected_len
# Keeping everything must return the elements unchanged and in their original order.
assert get_elements_by_percentage(a, 1) == a

In [38]:
# Shared fixtures for the checks below: `tokenizer.parser` is handed to
# run_parser_subprocess, and `threshold` is the ratio passed to is_balanced_snippet.
from CodeCheckList.tokenizer import CodeTokenizer
# NOTE(review): the second argument is presumably the language of the parser — confirm.
tokenizer = CodeTokenizer.from_pretrained("huggingface/CodeBERTa-small-v1", "python")
threshold = 1


In [39]:
# Snippet dominated by whitespace, underscores and '=' runs: it must be rejected.
code = "def scale(self, center=True, scale=True):\n        \"\"\"\nthe the\n\n\n                                                                                                                                                          _\n                     ____________=_=_===========________===______________________________==_____________________\n_______\n____\n\n___\n\n\n\n\n\n\n\n\n        return return)"
print(code.encode("utf8"))

assert is_balanced_snippet(code, threshold) is False



In [40]:
# Short snippet padded with underscores: it must be rejected as well.
code = "def m(a,b):\n    r__urn a*b_________"
assert is_balanced_snippet(code, threshold) is False

In [70]:
#code = "def scale(self, center=True, scale=True):\n        \"\"\"\nthe the\n\n\n                                                                                                                                                          _\n                     ____________=_=_===========________===______________________________==_____________________\n_______\n____\n\n___\n\n\n\n\n\n\n\n\n        return return)"
# Parse the short malformed snippet in a child process; the function prints
# the marker it received from the child (or None if none arrived).
code = "def m(a,b):\n    r__urn a*b_________"
run_parser_subprocess(code, tokenizer.parser)

hi
