In [1]:
from transformers import AutoTokenizer, AutoModelForMaskedLM, pipeline

# Load the `neulab/codebert-python` masked-LM checkpoint and wrap it in a
# fill-mask pipeline; `fill_mask` is the model entry point used by every later cell.
MODEL_NAME = "neulab/codebert-python"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)
fill_mask = pipeline(task="fill-mask", model=model, tokenizer=tokenizer)

In [2]:
# Request ten candidates rather than the pipeline's default number.
outputs = fill_mask("<mask> numpy as np", top_k=10)
for candidate in outputs:
    print(candidate)

{'score': 0.9995896220207214, 'token': 41975, 'token_str': 'import', 'sequence': 'import numpy as np'}
{'score': 0.0001533473696326837, 'token': 7761, 'token_str': 'from', 'sequence': 'from numpy as np'}
{'score': 0.00015234711463563144, 'token': 6595, 'token_str': ' import', 'sequence': ' import numpy as np'}
{'score': 3.692310565384105e-05, 'token': 10431, 'token_str': '#', 'sequence': '# numpy as np'}
{'score': 1.0732301234384067e-05, 'token': 41929, 'token_str': 'Import', 'sequence': 'Import numpy as np'}
{'score': 8.840012924338225e-06, 'token': 20891, 'token_str': ' Import', 'sequence': ' Import numpy as np'}
{'score': 5.5008940762490965e-06, 'token': 6, 'token_str': ',', 'sequence': ', numpy as np'}
{'score': 4.772637112182565e-06, 'token': 46181, 'token_str': 'package', 'sequence': 'package numpy as np'}
{'score': 2.338359308851068e-06, 'token': 50118, 'token_str': '\n', 'sequence': '\n numpy as np'}
{'score': 2.3127922759158537e-06, 'token': 1437, 'token_str': ' ', 'sequence':

In [3]:
# Which connective does the model expect between the two conditions?
outputs = fill_mask("if (x is not None) <mask> (x > 0)")
for candidate in outputs:
    print(candidate)

{'score': 0.875249445438385, 'token': 8, 'token_str': ' and', 'sequence': 'if (x is not None) and (x > 0)'}
{'score': 0.017183667048811913, 'token': 50, 'token_str': ' or', 'sequence': 'if (x is not None) or (x > 0)'}
{'score': 0.013177888467907906, 'token': 463, 'token_str': 'and', 'sequence': 'if (x is not None)and (x > 0)'}
{'score': 0.012697670608758926, 'token': 671, 'token_str': ' return', 'sequence': 'if (x is not None) return (x > 0)'}
{'score': 0.010224265046417713, 'token': 48200, 'token_str': ' &&', 'sequence': 'if (x is not None) && (x > 0)'}


In [4]:
# With two <mask> tokens the result is iterated as one candidate list per
# mask position; a blank line separates the groups.
outputs = fill_mask("if var1 <mask> <mask> None:")
for mask_predictions in outputs:
    for candidate in mask_predictions:
        print(candidate)
    print()

{'score': 0.9669309854507446, 'token': 16, 'token_str': ' is', 'sequence': '<s>if var1 is<mask> None:</s>'}
{'score': 0.00770614156499505, 'token': 328, 'token_str': '!', 'sequence': '<s>if var1!<mask> None:</s>'}
{'score': 0.002733553759753704, 'token': 28696, 'token_str': ' <', 'sequence': '<s>if var1 <<mask> None:</s>'}
{'score': 0.002088442211970687, 'token': 50118, 'token_str': '\n', 'sequence': '<s>if var1\n<mask> None:</s>'}
{'score': 0.0018904786556959152, 'token': 35, 'token_str': ':', 'sequence': '<s>if var1:<mask> None:</s>'}

{'score': 0.9925784468650818, 'token': 45, 'token_str': ' not', 'sequence': '<s>if var1<mask> not None:</s>'}
{'score': 0.001338448142632842, 'token': 5214, 'token_str': '=', 'sequence': '<s>if var1<mask>= None:</s>'}
{'score': 0.0013240614207461476, 'token': 16, 'token_str': ' is', 'sequence': '<s>if var1<mask> is None:</s>'}
{'score': 0.0009433833765797317, 'token': 49333, 'token_str': '!=', 'sequence': '<s>if var1<mask>!= None:</s>'}
{'score': 0.000

In [5]:
# A bare call site: the model has to guess the callee's name.
outputs = fill_mask("<mask> ( x )")
for candidate in outputs:
    print(candidate)

{'score': 0.1390775889158249, 'token': 17265, 'token_str': 'print', 'sequence': 'print ( x )'}
{'score': 0.09198562055826187, 'token': 5780, 'token_str': ' print', 'sequence': ' print ( x )'}
{'score': 0.03168381378054619, 'token': 41975, 'token_str': 'import', 'sequence': 'import ( x )'}
{'score': 0.01466455589979887, 'token': 1423, 'token_str': ' y', 'sequence': ' y ( x )'}
{'score': 0.013982338830828667, 'token': 37131, 'token_str': ' eval', 'sequence': ' eval ( x )'}


In [6]:
from collections import defaultdict
from tokenize import generate_tokens

from Levenshtein import distance

def merge_outputs(outputs):
    """Sum fill-mask scores per candidate, ignoring surrounding whitespace.

    The model can propose the same word with and without a leading space
    (e.g. ``'import'`` and ``' import'``); those variants are pooled under
    their stripped spelling.

    Args:
        outputs: Iterable of predictions, each a mapping with ``"token_str"``
            and ``"score"`` keys.

    Returns:
        A ``defaultdict(float)`` mapping each stripped token string to its
        total score, in first-seen order.
    """
    totals = defaultdict(float)
    for prediction in outputs:
        candidate, score = prediction["token_str"].strip(), prediction["score"]
        totals[candidate] = totals[candidate] + score
    return totals

def tokenize(lines, line):
    """Tokenize ``lines`` as Python source and return the tokens of one line.

    Args:
        lines: The source split into lines, each ending with a newline.
        line: 0-based index of the line of interest.

    Returns:
        A tuple ``(curr_tokens, curr_token_strings, line_offset)``:
        ``curr_tokens`` are the ``TokenInfo`` objects whose start row is
        ``line``, ``curr_token_strings`` are their ``.string`` values, and
        ``line_offset`` is the character offset of the start of ``line`` in
        ``"".join(lines)``.

    Raises:
        IndexError: if ``line`` is not a valid index into ``lines``.

    Side effect: prints the space-joined token strings of the line.
    """
    if not 0 <= line < len(lines):
        raise IndexError(f"line {line} out of range for {len(lines)} lines")

    tokens_by_row = defaultdict(list)
    for token in generate_tokens(iter(lines).__next__):
        # TokenInfo rows are 1-based; store them 0-based to match `lines`.
        tokens_by_row[token.start[0] - 1].append(token)

    curr_tokens = tokens_by_row[line]
    curr_token_strings = [tok.string for tok in curr_tokens]
    # Use the raw line lengths rather than each row's last-token end column:
    # the latter undercounts backslash-continued lines and breaks on rows that
    # have no tokens at all (the interior of a multi-line string).
    line_offset = sum(len(text_line) for text_line in lines[:line])

    print(" ".join(curr_token_strings).strip())

    return curr_tokens, curr_token_strings, line_offset


def _leader_and_ratio(ranked):
    """Return (top key, top score / runner-up score) for a descending list of pairs."""
    (leader, top), (_, runner_up) = ranked[0], ranked[1]
    if runner_up > 0:
        return leader, top / runner_up
    # Distance-weighted scores can underflow to 0.0 (see get_best_output);
    # avoid ZeroDivisionError and report "no evidence" when both are zero.
    return leader, (float("inf") if top > 0 else 0.0)


def get_best_output(prev, merged_outputs, decay=0.2):
    """Pick the replacement for ``prev`` among the model's candidates.

    Two rankings are compared: the raw merged scores, and the same scores
    multiplied by ``decay ** distance(prev, candidate)`` so that candidates
    spelled like the original token are favoured. A ranking's confidence is
    the ratio of its top score to its runner-up, and the ranking with the
    higher ratio wins.

    Args:
        prev: The original token text that was masked out.
        merged_outputs: Mapping from candidate string to summed score; must
            hold at least two candidates.
        decay: Multiplicative penalty per edit in the distance-weighted
            ranking (default 0.2).

    Returns:
        ``(best_output, best_ratio)``: the chosen candidate and its
        top-to-runner-up score ratio.

    Raises:
        ValueError: if fewer than two candidates are given.
    """
    if len(merged_outputs) < 2:
        raise ValueError("get_best_output needs at least two candidates")

    # Stable descending sorts: equal scores keep their prior order.
    raw_ranked = sorted(merged_outputs.items(), key=lambda kv: -kv[1])
    best_output, best_ratio = _leader_and_ratio(raw_ranked)

    # Built from raw_ranked so ties in the weighted ranking fall back to raw
    # order. For a long `prev` (e.g. a masked docstring) decay ** distance
    # underflows to 0.0, which _leader_and_ratio tolerates.
    weighted_ranked = sorted(
        ((key, value * decay ** distance(prev, key)) for key, value in raw_ranked),
        key=lambda kv: -kv[1],
    )
    new_output, new_ratio = _leader_and_ratio(weighted_ranked)
    if new_ratio > best_ratio:
        best_output, best_ratio = new_output, new_ratio
    return best_output, best_ratio
    


def autocorrect(text, line, context_lines=2, min_confidence=0.4,
                min_ratio=5, keep_fraction=0.5):
    """Suggest single-token fixes for one line of Python source.

    Each token of ``line`` that has visible text is replaced in turn by
    ``<mask>``, and ``fill_mask`` is asked to predict it with up to
    ``context_lines`` lines of surrounding source as context. A replacement
    is recorded when the best merged candidate scores at least
    ``min_confidence``, the candidate chosen by ``get_best_output`` differs
    from the original token, and its confidence ratio exceeds ``min_ratio``.

    Args:
        text: Python source with lines separated by newlines.
        line: 0-based index of the line to correct.
        context_lines: Number of lines of context taken before and after.
        min_confidence: Tokens whose best merged score is below this are skipped.
        min_ratio: Minimum confidence ratio for a suggestion to be recorded.
        keep_fraction: Only suggestions whose ratio is at least this fraction
            of the best recorded ratio are returned.

    Returns:
        A list of ``((original, start, end), replacement)`` tuples, where
        ``start``/``end`` are character offsets of the original token in ``text``.
    """
    lines = [x + "\n" for x in text.split("\n")]
    curr_tokens, curr_token_strings, line_offset = tokenize(lines, line)

    prev_lines = "".join(lines[max(0, line - context_lines):line])
    next_lines = "".join(lines[line + 1:line + 1 + context_lines])
    # The masked line is rebuilt from stripped token strings and has lost its
    # trailing newline; re-insert it so it is not fused with the next line.
    line_sep = "\n" if next_lines else ""

    best_suggestion = 0
    suggestions = []
    for i, (token, prev) in enumerate(zip(curr_tokens, curr_token_strings)):
        # Skip tokens with no visible text (newlines, indentation, end marker).
        if not prev.strip():
            continue

        masked = curr_token_strings[:i] + ["<mask>"] + curr_token_strings[i + 1:]
        string = " ".join(masked).strip()

        outputs = fill_mask(prev_lines + string + line_sep + next_lines)
        merged_outputs = merge_outputs(outputs)
        if len(merged_outputs) < 2:
            continue
        if max(merged_outputs.values()) < min_confidence:
            continue

        best_output, best_ratio = get_best_output(prev, merged_outputs)
        if (best_output.strip() != prev.strip()
                and best_output.strip()
                and best_ratio > min_ratio):
            print("CHANGE {", prev, "} to {", best_output, "} (", best_ratio, ")")
            start = line_offset + token.start[1]
            # A token may span rows (e.g. a triple-quoted string); its end
            # column is relative to its last row, so add the rows before it.
            end = (line_offset
                   + sum(len(row) for row in lines[line:token.end[0] - 1])
                   + token.end[1])
            suggestions.append((((prev, start, end), best_output), best_ratio))
            best_suggestion = max(best_suggestion, best_ratio)

    return [s for s, ratio in suggestions if ratio >= keep_fraction * best_suggestion]


# Demo: line 0 contains a misspelled keyword for autocorrect to flag.
text = "imprt numpy as np"
line = 0  # index of the line to correct
suggestions = autocorrect(text, line)

print("", suggestions, sep="\n")

imprt numpy as np
CHANGE { imprt } to { import } ( 465763.09558493894 )

[(('imprt', 0, 5), 'import')]
