### Download Grammar

In [33]:
from CodeCheckList import loader

"""define language"""
python_language = "python"

languages = [python_language]

loader.download_grammars(languages)

/home/svelascodimate/miniconda3/envs/code-check-list/lib/python3.10/site-packages/CodeCheckList/grammars


### Load Model

In [34]:

"""define the model checkpoint"""
checkpoint = "huggingface/CodeBERTa-small-v1"

### Create Modules

In [35]:
from CodeCheckList.tokenizer import CodeTokenizer
from CodeCheckList.masker import Masker

#create code tokenizer 
bert_tokenizer = CodeTokenizer.from_pretrained(checkpoint, python_language)

#create code masker
code_masker = Masker(bert_tokenizer)

### Node Types

In [36]:
print(bert_tokenizer.node_types)

['yield', '%', ':=', 'lambda', 'expression', 'float', '^', '>>=', '^=', 'case', 'default_parameter', 'or', '~', '->', '&=', 'decorator', 'in', '_compound_statement', '+', 'while', 'global_statement', 'try_statement', ':', 'for', 'list_pattern', '@=', 'with', 'as', 'parameter', '.', 'finally_clause', '<', '-', '{', 'del', 'block', 'await', '//=', 'exec', 'set_comprehension', 'try', 'pair', 'subscript', 'conditional_expression', 'if_statement', '>=', 'class_definition', 'elif_clause', '(', 'global', 'pass', 'string', 'delete_statement', 'slice', 'raise_statement', 'comment', 'false', 'wildcard_import', 'none', 'tuple', 'aliased_import', ';', 'assert_statement', 'else', '}}', 'augmented_assignment', 'nonlocal_statement', 'keyword_argument', 'for_in_clause', 'format_specifier', 'ellipsis', 'type_conversion', '=', '**=', 'module', 'typed_default_parameter', '!=', 'match_statement', 'dictionary_comprehension', 'expression_statement', '|=', 'true', 'return', '>>', '}', 'attribute', 'assignmen

### Encodings

In [37]:
"""example source code"""

code = "def multiply_numbers(a,b):\n    return a*b"
target_node_type = "*"

#encoding 
source_code_encoding = bert_tokenizer(code)

#masking
masked_code_encoding = code_masker(code, bert_tokenizer(code), bert_tokenizer.node_types.index(target_node_type))

assert len(source_code_encoding['input_ids']) == len(masked_code_encoding['input_ids'])

#masked code
masked_code = bert_tokenizer.tokenizer.decode(list(filter(lambda token_id: False if token_id == bert_tokenizer.tokenizer.bos_token_id or 
            token_id == bert_tokenizer.tokenizer.eos_token_id else True, masked_code_encoding['input_ids'])))

print(masked_code)

def multiply_numbers(a,b):
    return a<mask>b


### Code Prediction

In [38]:
import torch
from transformers import AutoModelForMaskedLM


model = AutoModelForMaskedLM.from_pretrained(checkpoint)

masked_indexes = list(map(lambda entry: entry[0],
    list(filter(lambda entry: True if entry[1] == bert_tokenizer.tokenizer.mask_token_id else False, enumerate(masked_code_encoding['input_ids'])))))

code_encoding = bert_tokenizer.tokenizer(code, return_tensors='pt')
code_encoding['input_ids'][0] = torch.tensor([torch.tensor(input_id) for input_id in masked_code_encoding['input_ids']])

model_prediction = model(**code_encoding)

for masked_index in masked_indexes:
    values, predictions = model_prediction['logits'][0][masked_index].topk(2)
    print(values)
    print(bert_tokenizer.tokenizer.decode(predictions))
    code_encoding['input_ids'][0][masked_index] = predictions[0]

predicted_code = bert_tokenizer.tokenizer.decode(code_encoding['input_ids'][0])

tensor([14.7716, 14.7396], grad_fn=<TopkBackward0>)
,*


### Evaluation

In [39]:
print(code)
print(masked_code)
print(predicted_code)

def multiply_numbers(a,b):
    return a*b
def multiply_numbers(a,b):
    return a<mask>b
<s>def multiply_numbers(a,b):
    return a,b</s>
