In [190]:
import random
import tree_sitter_python as tspython
from tree_sitter import Language, Parser

In [4]:
# Wrap the grammar shipped by tree_sitter_python in a Language object and
# create a Parser bound to it; `parser` is reused by the cells below.
PY_LANGUAGE = Language(tspython.language())
parser = Parser(PY_LANGUAGE)

In [None]:
# Sample source to parse. It has to be bytes, not str; a str can be converted
# with bytes(string, "utf8") or string.encode("utf8").
code = b"""def hello(a,b):
    return a + b

def world(c):
    if true:
        return 0
    else:
        return c"""

# Example of the S-expression printed for the parsed sample.
# Note that the lowercase `true` in the condition shows up as an (identifier).
#(module 
# (function_definition name: 
# (identifier) parameters: (parameters (identifier) (identifier)) 
# body: (block (return_statement (binary_operator left: (identifier) right: (identifier)))))
# 
# (function_definition name: 
# (identifier) parameters: (parameters (identifier))
# body: (block (if_statement condition: (identifier)  consequence: (block (return_statement (integer))) 
# alternative: (else_clause body: (block (return_statement (identifier))))))))

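Identifier nodes are the names that appear in the code: variable, parameter, and function names. In this sample the lowercase `true` is also parsed as an identifier.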

In [28]:
# Parse the sample bytes (declared as UTF-8) into a syntax tree, then print
# the Tree API documentation for exploration.
tree = parser.parse(code,encoding="utf8")
help(tree)

Help on Tree in module tree_sitter object:

class Tree(builtins.object)
 |  A tree that represents the syntactic structure of a source code file.
 |
 |  Methods defined here:
 |
 |  __copy__(self, /)
 |      Use :func:`copy.copy` to create a copy of the tree.
 |
 |  changed_ranges(self, /, new_tree)
 |      Compare this old edited syntax tree to a new syntax tree representing the same document, returning a sequence of ranges whose syntactic structure has changed.
 |
 |      Returns
 |      -------
 |      Ranges where the hierarchical structure of syntax nodes (from root to leaf) has changed between the old and new trees. Characters outside these ranges have identical ancestor nodes in both trees.
 |
 |      Note
 |      ----
 |      The returned ranges may be slightly larger than the exact changed areas, but Tree-sitter attempts to make them as small as possible.
 |
 |      Tip
 |      ---
 |      For this to work correctly, this syntax tree must have been edited such that its ranges 

In [29]:
# Show the tree's language, its included ranges and the root node's
# S-expression, one value per line.
print(tree.language, tree.included_ranges, tree.root_node, sep="\n")


<Language id=140737184724032, version=15, name="python">
[<Range start_point=(0, 0), end_point=(4294967295, 4294967295), start_byte=0, end_byte=4294967295>]
(module (function_definition name: (identifier) parameters: (parameters (identifier) (identifier)) body: (block (return_statement (binary_operator left: (identifier) right: (identifier))))) (function_definition name: (identifier) parameters: (parameters (identifier)) body: (block (if_statement condition: (identifier) consequence: (block (return_statement (integer))) alternative: (else_clause body: (block (return_statement (identifier))))))))


In [30]:
# Keep the root node; the cells below inspect and traverse it.
node = tree.root_node

In [41]:
# Print the Node API documentation for exploration.
help(node)

Help on Node in module tree_sitter object:

class Node(builtins.object)
 |  A single node within a syntax ``Tree``.
 |
 |  Methods defined here:
 |
 |  __eq__(self, value, /)
 |      Return self==value.
 |
 |  __ge__(self, value, /)
 |      Return self>=value.
 |
 |  __gt__(self, value, /)
 |      Return self>value.
 |
 |  __hash__(self, /)
 |      Return hash(self).
 |
 |  __le__(self, value, /)
 |      Return self<=value.
 |
 |  __lt__(self, value, /)
 |      Return self<value.
 |
 |  __ne__(self, value, /)
 |      Return self!=value.
 |
 |  __repr__(self, /)
 |      Return repr(self).
 |
 |  __str__(self, /)
 |      Return str(self).
 |
 |  child(self, index, /)
 |      Get this node's child at the given index, where ``0`` represents the first child.
 |
 |      Caution
 |      -------
 |      This method is fairly fast, but its cost is technically ``log(i)``, so if you might be iterating over a long list of children, you should use :attr:`children` or :meth:`walk` instead.
 |
 |  ch

In [39]:
# node.text holds UTF-8 encoded bytes, so decode it to display a str
(node.type, node.text.decode("utf8"), node.children)  # type: ignore

('module',
 'def hello(a,b):\n    return a + b\n\ndef world(c):\n    if true:\n        return 0\n    else:\n        return c',
 [<Node type=function_definition, start_point=(0, 0), end_point=(1, 16)>,
  <Node type=function_definition, start_point=(3, 0), end_point=(7, 16)>])

In [230]:
def rec_tree(node):
    """Collect the byte spans of every identifier node below ``node``.

    The descendants of ``node`` are visited in pre-order (document order);
    ``node`` itself is not tested, only its children and their subtrees.

    Parameters
    ----------
    node
        Any object exposing ``children`` (a sequence of nodes of the same
        kind). Each visited node must expose ``type``, ``start_byte`` and
        ``end_byte``.

    Returns
    -------
    list of (start_byte, end_byte) tuples, one per node whose ``type`` is
    ``'identifier'``. A new list is built on every call, so repeated calls
    never accumulate results from earlier ones.
    """
    spans = []
    # An explicit stack replaces recursion so deeply nested trees cannot hit
    # the interpreter's recursion limit. Children are pushed reversed so that
    # pop() yields them left to right.
    pending = list(reversed(node.children))
    while pending:
        current = pending.pop()
        if current.type == 'identifier':
            # Byte offsets (not row/column points) are what the masking
            # step slices with.
            spans.append((current.start_byte, current.end_byte))
        pending.extend(reversed(current.children))
    return spans

# Collect the identifier spans of the whole sample (starting at the root
# node) and report how many were found.
result = rec_tree(node)
print(len(result))

9


In [231]:
from torchtext.data import get_tokenizer

# torchtext's "basic_english" tokenizer; used below to count the tokens of the sample.
tokenizer = get_tokenizer("basic_english")


In [234]:
# Show the sample as a str (bytes.decode() defaults to UTF-8).
code.decode()

'def hello(a,b):\n    return a + b\n\ndef world(c):\n    if true:\n        return 0\n    else:\n        return c'

In [None]:
# Masking budget: 15% of the tokenizer's token count, rounded down.
tokens_sample = tokenizer(code.decode())
num_tokens_to_mask = 15 * len(tokens_sample) // 100

In [None]:
# random.sample raises ValueError when asked for more items than the
# population holds, so cap the request at the number of identifier spans.
sample_size = min(num_tokens_to_mask, len(result))
# Spans are (start_byte, end_byte) tuples; sorting them orders the selection
# by start byte so the text can be rewritten from left to right.
to_mask = sorted(random.sample(result, sample_size))
to_mask

In [None]:
# Replace every selected identifier span with the mask token and keep the
# bytes that were masked out.
MASK_TOKEN = b'<mask>'

source_bytes = node.text
labels = []   # original bytes of each masked identifier, left to right
pieces = []   # untouched chunks interleaved with mask tokens
cursor = 0    # first byte of source_bytes that has not been copied yet

# The spans index the ORIGINAL bytes, so walk them in ascending order and copy
# the gaps between them; this needs no running offset and does not depend on
# the mask token's length. Spans are assumed not to overlap.
for start, end in sorted(to_mask):
    labels.append(source_bytes[start:end])
    pieces.append(source_bytes[cursor:start])
    pieces.append(MASK_TOKEN)
    cursor = end

# Copy whatever follows the last masked span.
pieces.append(source_bytes[cursor:])
mod = b''.join(pieces)
    

b'b'
b'def hello(a,<mask>):\n    return a + b\n\ndef world(c):\n    if true:\n        return 0\n    else:\n        return c'
1 13 12
12 19
offset is  5
b'world'
b'def hello(a,<mask>):\n    return a + b\n\ndef <mask>(c):\n    if true:\n        return 0\n    else:\n        return c'
5 43 38
38 49
offset is  6
b'c'
b'def hello(a,<mask>):\n    return a + b\n\ndef <mask>(<mask>):\n    if true:\n        return 0\n    else:\n        return c'
1 45 44
44 51
offset is  11


In [263]:
# Show the masked sample as a str.
mod.decode()

'def <mask>(a,b):\n    return a + b\n\ndef world(c):\n    if<mask>:\n        return 0\n    else:\n        retur<mask>'

True