In [1]:
from typing import Generator
from tree_sitter import Language, Parser, Tree, Node

LANGUAGE = Language('parser_folder/my-languages.so','go')



In [2]:
parser = Parser()
parser.set_language(LANGUAGE)

In [3]:
def traverse_tree(tree: Tree) -> Generator[Node, None, None]:
    cursor = tree.walk()

    visited_children = False
    while True:
        if not visited_children:
            yield cursor.node
            if not cursor.goto_first_child():
                visited_children = True
        elif cursor.goto_next_sibling():
            visited_children = False
        elif not cursor.goto_parent():
            break

In [4]:
def c(node:Node,var):
        var_name = ""
        f = False
        right = None
        if node.type.count("declaration") > 0:
            child = node.child_by_field_name("declarator")
            var_name = child.child_by_field_name("declarator").text.decode("utf-8")
            right = child.child_by_field_name("value")
            f = True
        elif node.type == "assignment_expression":
            child = node.child_by_field_name("left")
            var_name = child.text.decode("utf-8")
            right = node.child_by_field_name("right")
            f = True
        if var_name == var:
                return True,var_name,f,right
        return False,var_name,f,right

In [5]:
code = \
"""
package main

import "fmt"

func main() {
    var flag bool = false
    var h bool = flag
    fmt.Println("Go Language.")
    
    if flag {
        fmt.Println("This will never be executed due to flag.")
    }
    
    flag = h
    
    for h {
        fmt.Println("This loop will never run due to h.")
        break
    }
}
"""

In [12]:
tree = parser.parse(bytes(code, 'utf8'))
tree.root_node
for node in traverse_tree(tree):
    
    print(node)
    print(node.text.decode("utf-8"))

<Node type=source_file, start_point=(1, 0), end_point=(21, 0)>
package main

import "fmt"

func main() {
    var flag bool = false
    var h bool = flag
    fmt.Println("Go Language.")
    
    if flag {
        fmt.Println("This will never be executed due to flag.")
    }
    
    flag = h
    
    for h {
        fmt.Println("This loop will never run due to h.")
        break
    }
}

<Node type=package_clause, start_point=(1, 0), end_point=(1, 12)>
package main
<Node type="package", start_point=(1, 0), end_point=(1, 7)>
package
<Node type=package_identifier, start_point=(1, 8), end_point=(1, 12)>
main
<Node type="
", start_point=(1, 12), end_point=(3, 0)>



<Node type=import_declaration, start_point=(3, 0), end_point=(3, 12)>
import "fmt"
<Node type="import", start_point=(3, 0), end_point=(3, 6)>
import
<Node type=import_spec, start_point=(3, 7), end_point=(3, 12)>
"fmt"
<Node type=interpreted_string_literal, start_point=(3, 7), end_point=(3, 12)>
"fmt"
<Node type=""", start_point=

In [102]:
def find_var_assignments(code,var,lang):
    #根据code找到某个变量的所有定义和赋值语句
    code = code.replace("\\n", "\n")
    tree = parser.parse(bytes(code, 'utf8'))
    ret = []
    for node in traverse_tree(tree):
        if (node.type.count("assignment") > 0 or node.type.count("declaration"))and node.child_by_field_name("left").text.decode("utf-8") == var:
            ret.append(node)
            print(node.child_by_field_name("right").child_count)
            print(node.child_by_field_name("right"))
            print(node.child_by_field_name("right").children)
    return ret
        

In [105]:
code_python = \
"""
tree = [1,2,3]
for node in tree:
    pass
    if False:
        print(node)
flag = False
a = [1,2,b] + a
while flag:
    print(1)
"""

In [106]:
find_var_assignments(code_python,"a","flag")

3
<Node type=binary_operator, start_point=(7, 4), end_point=(7, 15)>
[<Node type=list, start_point=(7, 4), end_point=(7, 11)>, <Node type="+", start_point=(7, 12), end_point=(7, 13)>, <Node type=identifier, start_point=(7, 14), end_point=(7, 15)>]


[<Node type=assignment, start_point=(7, 0), end_point=(7, 15)>]

In [107]:
import copy
def dead_code_judge(code,lang):
    #判断一段代码中所有存在的死代码，并将死代码的位置返回
    code = code.replace("\\n", "\n")
    tree = parser.parse(bytes(code, 'utf8'))
    ret = []
    for node in traverse_tree(tree):
        if node.type.count("if_statement")> 0 or  node.type.count("while_statement") > 0:
            if node.child_by_field_name("condition").text.decode("utf-8") in ["False","false"]:
                ret.append([node.start_byte,node.end_byte])

    return ret
                
def delete_dead_code(code,lang):
    #删除一段代码中所有的死代码
    ret = dead_code_judge(code,lang)
    acc = 0
    for v in ret:
        code = code[:v[0] - acc] + code[v[1] - acc:]
        acc += v[1] - v[0]
    return code


In [6]:
print(code_python)


tree = [1,2,3]
for node in tree:
    if False:
        print(node)
flag = False
while flag:
    print(1)



In [7]:
print(delete_dead_code(code_python,"python"))


tree = [1,2,3]
for node in tree:
    
flag = False
while flag:
    print(1)



In [7]:
for node in traverse_tree(tree):
    print(node)
    print(node.text.decode("utf-8"))

<Node type=module, start_point=(1, 0), end_point=(8, 0)>
tree = [1,2,3]
for node in tree:
    if node == 1:
        print(node)
flag = False
while flag:
    print(1)

<Node type=expression_statement, start_point=(1, 0), end_point=(1, 14)>
tree = [1,2,3]
<Node type=assignment, start_point=(1, 0), end_point=(1, 14)>
tree = [1,2,3]
<Node type=identifier, start_point=(1, 0), end_point=(1, 4)>
tree
<Node type="=", start_point=(1, 5), end_point=(1, 6)>
=
<Node type=list, start_point=(1, 7), end_point=(1, 14)>
[1,2,3]
<Node type="[", start_point=(1, 7), end_point=(1, 8)>
[
<Node type=integer, start_point=(1, 8), end_point=(1, 9)>
1
<Node type=",", start_point=(1, 9), end_point=(1, 10)>
,
<Node type=integer, start_point=(1, 10), end_point=(1, 11)>
2
<Node type=",", start_point=(1, 11), end_point=(1, 12)>
,
<Node type=integer, start_point=(1, 12), end_point=(1, 13)>
3
<Node type="]", start_point=(1, 13), end_point=(1, 14)>
]
<Node type=for_statement, start_point=(2, 0), end_point=(4, 19)>
for n

In [26]:
tree.root_node

<Node type=source_file, start_point=(0, 0), end_point=(2, 1)>

In [27]:
node.start_byte

122

In [28]:
code[122:]

'return autoConvert_policy_Eviction_To_v1beta1_Eviction(in, out, s)\n}'

In [21]:
node.end_point

(1, 67)

In [53]:
root_node = tree.root_node

In [65]:
root_node.children[0]

<Node type=source_file, start_point=(1, 0), end_point=(11, 0)>

In [40]:
root_node.children

[<Node type=expression_statement, start_point=(0, 1), end_point=(0, 16)>,
 <Node type=expression_statement, start_point=(1, 2), end_point=(1, 25)>,
 <Node type=for_statement, start_point=(2, 2), end_point=(20, 43)>,
 <Node type=ERROR, start_point=(22, 0), end_point=(22, 1)>]

In [3]:
from run_parser import extract_dataflow

In [4]:
tree = parser.parse(bytes(code_python, 'utf-8'))

In [5]:
root_node = tree.root_node

In [89]:
def get_program_stmt_nums(code,lang):
    #得到源代码中statement的位置和数量
    code = code.replace("\\n", "\n")
    global tree
    if not tree:
        tree = parser.parse(bytes(code, 'utf8'))
    ret = []
    for node in traverse_tree(tree):
        if node.type.count("statement") > 0: #是不是每一种编程语言的语句type都叫statement?如果不是，这里就有问题
            ret.append((node.text.decode("utf-8"),node.end_byte,node.start_byte,node.start_point,node.end_point))
    #根据end_type进行筛选，只保留最长的statement
    realret = {}
    for v in ret:
        realret[v[1]] = (v[0],v[2],v[3],v[4]) if len(realret.get(v[1],("",0,0,0))[0]) < len(v[0]) else realret.get(v[1],("",0,0,0))
    ret = [(v,k) for k,v in realret.items()]
    ret.sort(key=lambda x:x[1])
    strs = [x[0][0] for x in ret]
    end_types = [x[1] for x in ret]
    start_bytes = [x[0][1] for x in ret]
    start_points = [x[0][2] for x in ret]
    end_points = [x[0][3] for x in ret]
    return strs,end_types,start_bytes,start_points,end_points

In [90]:
lst,end_bytes,start_bytes,start_points,end_points = get_program_stmt_nums(code,"python")

In [91]:
index = 1
lst = lst[:index] + ["print(123)"] + lst[index:]
code_new = "\n".join(lst)
tree.edit(start_byte=start_bytes[index],old_end_byte=end_bytes[index],new_end_byte=end_bytes[index]+len("print(123)"),\
    start_point=start_points[index],old_end_point=end_points[index],new_end_point=(end_points[index][0],end_points[index][1] + len("print(123)")))
#This will run much faster than if you were parsing from scratch.

In [100]:
print(code_python)


tree = [1,2,3]
for node in tree:
    if node == 1:
        print(node)
flag = False
while flag:
    print(1)



In [101]:
print(code_new)

tree = [1,2,3]
print(123)
for node in tree:
    if node == 1:
        print(node)
flag = False
while flag:
    print(1)


In [104]:
tree = parser.parse(bytes(code_new,"utf-8"))

In [102]:
new_tree = parser.parse(bytes(code_new,"utf-8"), tree)

In [105]:
for node in traverse_tree(tree):
    print(node)
    print(node.text.decode("utf-8"))
    #对于python语言来说，每一个statement模块后面的都是可以插入的位置


<Node type=module, start_point=(0, 0), end_point=(7, 12)>
tree = [1,2,3]
print(123)
for node in tree:
    if node == 1:
        print(node)
flag = False
while flag:
    print(1)
<Node type=expression_statement, start_point=(0, 0), end_point=(0, 14)>
tree = [1,2,3]
<Node type=assignment, start_point=(0, 0), end_point=(0, 14)>
tree = [1,2,3]
<Node type=identifier, start_point=(0, 0), end_point=(0, 4)>
tree
<Node type="=", start_point=(0, 5), end_point=(0, 6)>
=
<Node type=list, start_point=(0, 7), end_point=(0, 14)>
[1,2,3]
<Node type="[", start_point=(0, 7), end_point=(0, 8)>
[
<Node type=integer, start_point=(0, 8), end_point=(0, 9)>
1
<Node type=",", start_point=(0, 9), end_point=(0, 10)>
,
<Node type=integer, start_point=(0, 10), end_point=(0, 11)>
2
<Node type=",", start_point=(0, 11), end_point=(0, 12)>
,
<Node type=integer, start_point=(0, 12), end_point=(0, 13)>
3
<Node type="]", start_point=(0, 13), end_point=(0, 14)>
]
<Node type=expression_statement, start_point=(1, 0), end_po

In [82]:
node.is_named

True

In [5]:
h = extract_dataflow(code_python,'python')

In [6]:
h[0]

[('T', 0, 'computedFrom', ['int', 'raw_input'], [2, 4]),
 ('t', 9, 'computedFrom', ['T', 'xrange'], [11, 13]),
 ('T', 13, 'comesFrom', ['T'], [0]),
 ('line',
  16,
  'computedFrom',
  ['split', 'raw_input', 'map', 'int'],
  [18, 20, 22, 26]),
 ('N', 30, 'computedFrom', ['3', '0', 'line'], [36, 38, 40]),
 ('S', 32, 'computedFrom', ['3', '0', 'line'], [36, 38, 40]),
 ('p', 34, 'computedFrom', ['3', '0', 'line'], [36, 38, 40]),
 ('line', 36, 'comesFrom', ['line'], [16]),
 ('a', 42, 'computedFrom', ['3', 'line'], [44, 46]),
 ('line', 44, 'comesFrom', ['line'], [16]),
 ('cnt', 49, 'computedFrom', ['0'], [51]),
 ('a', 52, 'comesFrom', ['a'], [42]),
 ('x', 61, 'computedFrom', ['a'], [63]),
 ('a', 63, 'comesFrom', ['a'], [42]),
 ('x', 67, 'comesFrom', ['x'], [61]),
 ('p', 74, 'comesFrom', ['p'], [34]),
 ('cnt', 76, 'computedFrom', ['1'], [78]),
 ('S', 80, 'comesFrom', ['S'], [32, 103]),
 ('x', 86, 'comesFrom', ['x'], [61]),
 ('x', 91, 'comesFrom', ['x'], [61]),
 ('p', 98, 'comesFrom', ['p'], [

In [11]:
code = """
def build_table():\n  \t\n    key = []\ntable = key\n  \t\n    model = 0\n\n    chain = 31\nfor i in range( model, chain ):\n  \t\t\n    all = ( get_max_score( i ), get_max_surprise_score( i ) )\ntable.append( all )\n  \n  \treturn table\n  \n  \n  def get_max_score( i ):\n  \t\n    td = 0\n\n    buffer = 10\n\n    file = 2\n\n    list = 3\nreturn max( td, min( buffer, ( i + file ) / list ) )\n  \n  def get_max_surprise_score( i ):\n  \t\n    tree = 0\n\n    c = 10\n\n    root = 4\n\n    note = 3\nreturn min( i, max( tree, min( c, ( i + root ) / note ) ) )\n  \n  def get_max( case, scores, score_needed, num_surprises ):\n  \tscores = sorted( scores, reverse=True )\n  \t\n    section = 0\nnumPass = section\n  \t\n    bl = 0\ni = bl;\n  \twhile ( i < len(scores) ):\n  \t\t\n    py = 0\nif ( case[scores[i]][py] >= score_needed ):\n  \t\t\t\n    case = 1\nnumPass += case\n  \t\telse:\n  \t\t\tbreak\n  \t\t\ncore = 1\ni += core\n  \n  \t\nche = 0\nwhile ( i < len(scores) and num_surprises > che ):\n  \t\t\narray = 1\nif ( case[scores[i]][array] >= score_needed ):\n  \t\t\t\norder = 1\nnumPass += order\n  \t\t\t\nl = 1\nnum_surprises -= l\n  \t\t\t\n  \t\t\n    key = 1\ni += key\n  \n  \treturn numPass\n  \n  case = build_table()\n  \n  num_cases = input()\n  \n\n    ti = 1\n\n    t = 1\n  for i in range( ti, num_cases + t ):\n  \tline = raw_input().split()\n  \t\n    fi = 1\nnum_surprises = int(line[fi])\n  \t\n    b = 2\nscore_needed = int(line[b])\n  \t\n    ih = 3\nscores_raw = line[ih:]\n  \n  \tscores = [ int(y) for y in scores_raw ]\n  \n  \t\n        qi = 'Case #'\n\n        phi = ': '\nprint qi + str( i ) + phi + str( get_max( case, scores, score_needed, num_surprises ) )\n
"""

In [12]:
print(code)


def build_table():
  	
    key = []
table = key
  	
    model = 0

    chain = 31
for i in range( model, chain ):
  		
    all = ( get_max_score( i ), get_max_surprise_score( i ) )
table.append( all )
  
  	return table
  
  
  def get_max_score( i ):
  	
    td = 0

    buffer = 10

    file = 2

    list = 3
return max( td, min( buffer, ( i + file ) / list ) )
  
  def get_max_surprise_score( i ):
  	
    tree = 0

    c = 10

    root = 4

    note = 3
return min( i, max( tree, min( c, ( i + root ) / note ) ) )
  
  def get_max( case, scores, score_needed, num_surprises ):
  	scores = sorted( scores, reverse=True )
  	
    section = 0
numPass = section
  	
    bl = 0
i = bl;
  	while ( i < len(scores) ):
  		
    py = 0
if ( case[scores[i]][py] >= score_needed ):
  			
    case = 1
numPass += case
  		else:
  			break
  		
core = 1
i += core
  
  	
che = 0
while ( i < len(scores) and num_surprises > che ):
  		
array = 1
if ( case[scores[i]][array] >= score_needed ):
  			
order = 

In [3]:
code = """
def build_table():
  	table = []
  	for i in range( 0, 31 ):
  		table.append( ( get_max_score( i ), get_max_surprise_score( i ) ) )
  
  	return table
  
  
  def get_max_score( i ):
  	return max( 0, min( 10, ( i + 2 ) / 3 ) )
  
  def get_max_surprise_score( i ):
  	return min( i, max( 0, min( 10, ( i + 4 ) / 3 ) ) )
  
  def get_max( x, scores, score_needed, num_surprises ):
  	scores = sorted( scores, reverse=True )
  	numPass = 0
  	i = 0;
  	while ( i < len(scores) ):
  		if ( x[scores[i]][0] >= score_needed ):
  			numPass += 1
  		else:
  			break
  		i += 1
  
  	while ( i < len(scores) and num_surprises > 0 ):
  		if ( x[scores[i]][1] >= score_needed ):
  			numPass += 1
  			num_surprises -= 1
  			
  		i += 1
  
  	return numPass
  
  x = build_table()
  
  num_cases = input()
  
  for i in range( 1, num_cases + 1 ):
  	line = raw_input().split()
  	num_surprises = int(line[1])
  	score_needed = int(line[2])
  	scores_raw = line[3:]
  
  	scores = [ int(y) for y in scores_raw ]
  
  	print 'Case #' + str( i ) + ': ' + str( get_max( x, scores, score_needed, num_surprises ) )

"""

In [4]:
print(code)


def build_table():
  	table = []
  	for i in range( 0, 31 ):
  		table.append( ( get_max_score( i ), get_max_surprise_score( i ) ) )
  
  	return table
  
  
  def get_max_score( i ):
  	return max( 0, min( 10, ( i + 2 ) / 3 ) )
  
  def get_max_surprise_score( i ):
  	return min( i, max( 0, min( 10, ( i + 4 ) / 3 ) ) )
  
  def get_max( x, scores, score_needed, num_surprises ):
  	scores = sorted( scores, reverse=True )
  	numPass = 0
  	i = 0;
  	while ( i < len(scores) ):
  		if ( x[scores[i]][0] >= score_needed ):
  			numPass += 1
  		else:
  			break
  		i += 1
  
  	while ( i < len(scores) and num_surprises > 0 ):
  		if ( x[scores[i]][1] >= score_needed ):
  			numPass += 1
  			num_surprises -= 1
  			
  		i += 1
  
  	return numPass
  
  x = build_table()
  
  num_cases = input()
  
  for i in range( 1, num_cases + 1 ):
  	line = raw_input().split()
  	num_surprises = int(line[1])
  	score_needed = int(line[2])
  	scores_raw = line[3:]
  
  	scores = [ int(y) for y in scores