In [1]:
from transformers import AutoTokenizer, AutoModelForMaskedLM, pipeline

# Tokenizer for the "microsoft/graphcodebert-base" checkpoint; it is handed to
# the fill-mask pipeline further down in the notebook.
tokenizer_GraphCodeBert = AutoTokenizer.from_pretrained("microsoft/graphcodebert-base")

# Masked-language-model variant of the same checkpoint, used by the fill-mask
# pipeline to propose tokens for the masked variable name.
model_GraphCodeBert = AutoModelForMaskedLM.from_pretrained("microsoft/graphcodebert-base")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import ast

def getLineAssignment(tree, lineno):
    """Return the first ``ast.Name`` node found on a given source line.

    Args:
        tree: An AST node (typically the module returned by ``ast.parse``).
        lineno: 1-based line number to look for.

    Returns:
        The first ``ast.Name`` node, in ``ast.walk`` order, whose ``lineno``
        equals ``lineno``; ``None`` when no such node exists.
    """
    for candidate in ast.walk(tree):
        # Check the node type first: not every AST node carries a ``lineno``.
        if isinstance(candidate, ast.Name) and candidate.lineno == lineno:
            return candidate
    return None

def get_variables(expression):
    """Collect the variable names bound by plain assignment statements.

    The source is parsed and every ``ast.Assign`` node is inspected. The names
    are taken from the assignment *targets* themselves (including chained
    targets such as ``a = b = 1`` and unpacking such as ``a, b = 1, 2``), so
    several statements on one line are all reported and names that only appear
    on the right-hand side, as a subscripted object (``x[0] = ...``) or as an
    attribute owner (``obj.attr = ...``) are not mistaken for new variables.

    Augmented (``+=``) and annotated (``x: int = 1``) assignments are not
    ``ast.Assign`` nodes and are therefore ignored.

    Args:
        expression: Python source code as a string.

    Returns:
        A list of distinct variable names in order of first discovery
        (``ast.walk`` order of the assignment statements).

    Raises:
        SyntaxError: If ``expression`` is not valid Python (from ``ast.parse``).
    """
    tree = ast.parse(expression)
    variables = []
    for node in ast.walk(tree):
        if not isinstance(node, ast.Assign):
            continue
        for target in node.targets:
            # Walk the target so tuple/list/starred unpacking is covered; only
            # names in Store context are actually being bound here.
            for sub in ast.walk(target):
                if (
                    isinstance(sub, ast.Name)
                    and isinstance(sub.ctx, ast.Store)
                    and sub.id not in variables
                ):
                    variables.append(sub.id)
    return variables

In [3]:
class replaceName(ast.NodeTransformer):
    """AST transformer that renames one identifier throughout a tree.

    Every ``ast.Name`` node whose ``id`` equals ``varName`` has its ``id``
    overwritten with ``replaceName``. Nodes are modified in place and other
    node types are left untouched.
    """

    def __init__(self, varName, replaceName):
        # varName: identifier to look for; replaceName: text to substitute.
        self.varName, self.replaceName = varName, replaceName

    def visit_Name(self, node):
        """Rename ``node`` when it refers to the targeted identifier."""
        if node.id != self.varName:
            return node
        node.id = self.replaceName
        return node


In [18]:
# Sample program fed to the renaming experiment below: it multiplies a 3x3
# matrix X by a 3x4 matrix Y with three nested loops, accumulating into
# `result`, and prints each row of the result.
CODE = """
X = [[12,7,3],
    [4 ,5,6],
    [7 ,8,9]]
Y = [[5,8,1,2],
    [6,7,3,0],
    [4,5,9,1]]
result = [[0,0,0,0],
         [0,0,0,0],
         [0,0,0,0]]
for i in range(len(X)):
   for j in range(len(Y[0])):
       for k in range(len(Y)):
           result[i][j] += X[i][k] * Y[k][j]
for r in result:
   print(r)
"""

In [21]:
# Rename every assigned variable in CODE to a name suggested by the
# GraphCodeBERT fill-mask model, then print the original and rewritten program.
import keyword

variables = get_variables(CODE)
print(variables)
tree = ast.parse(CODE)

# Building the pipeline does not depend on the loop variable, so do it once.
fill_mask_GCB = pipeline("fill-mask", model=model_GraphCodeBert, tokenizer=tokenizer_GraphCodeBert)

# Every identifier already present in the program (variables, loop indices,
# called builtins). A replacement must not collide with any of them, otherwise
# the rewritten program would change meaning.
used_names = {node.id for node in ast.walk(tree) if isinstance(node, ast.Name)}

for i, variable in enumerate(variables):
    # Mask the variable in a fresh copy of the original program so the model
    # always sees the untouched context for the remaining names.
    masked = ast.unparse(replaceName(variable, '<mask>').visit(ast.parse(CODE)))
    candidates = fill_mask_GCB(masked)
    # With a single mask the pipeline returns a flat list of predictions; with
    # several masks it returns one list per mask. Normalise to the nested form.
    if candidates and isinstance(candidates[0], dict):
        candidates = [candidates]
    # Predictions for the first occurrence of the mask, best score first.
    first_mask_candidates = candidates[0] if candidates else []

    # Keep the original name unless a usable suggestion is found.
    replacement = variable
    for candidate in first_mask_candidates:
        token = candidate['token_str'].strip()
        # Accept only a syntactically valid, non-keyword, unused identifier.
        if token.isidentifier() and not keyword.iskeyword(token) and token not in used_names:
            replacement = token
            break

    if replacement != variable:
        tree = replaceName(variable, replacement).visit(tree)
        used_names.add(replacement)
        variables[i] = replacement
print(variables)
new_code = ast.unparse(tree)
print(CODE)
print(new_code)


['X', 'Y', 'result']
['W', 'y', 'R']

X = [[12,7,3],
    [4 ,5,6],
    [7 ,8,9]]
Y = [[5,8,1,2],
    [6,7,3,0],
    [4,5,9,1]]
result = [[0,0,0,0],
         [0,0,0,0],
         [0,0,0,0]]
for i in range(len(X)):
   for j in range(len(Y[0])):
       for k in range(len(Y)):
           result[i][j] += X[i][k] * Y[k][j]
for r in result:
   print(r)

W = [[12, 7, 3], [4, 5, 6], [7, 8, 9]]
y = [[5, 8, 1, 2], [6, 7, 3, 0], [4, 5, 9, 1]]
R = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]
for i in range(len(W)):
    for j in range(len(y[0])):
        for k in range(len(y)):
            R[i][j] += W[i][k] * y[k][j]
for r in R:
    print(r)
