In [1]:
import logging
import sys, json, os
import numpy as np
import argparse

In [2]:
def read_answers(filename):
    answers = {}
    with open(filename, 'r', encoding='utf-8') as f:
        for line in f.readlines():
            line = line.strip()
            answers[line.split('<CODESPLIT>')[0]] = line.split('<CODESPLIT>')[1]
    return answers

In [3]:
def read_predictions(filename):
    predictions = {}
    with open(filename, 'r', encoding='utf-8') as f:
        for line in f.readlines():
            line = line.strip()
            predictions[line.split('<CODESPLIT>')[0]] = line.split('<CODESPLIT>')[1]
    return predictions

In [4]:
def calculate_scores(answers, predictions):
    scores = []
    for key in answers:
        if key not in predictions:
            logging.error("Missing prediction for index {}.".format(key))
            sys.exit()
        a = answers[key]
        p = predictions[key]
        scores.append(a==p)
    result = sum(scores) / len(scores)
    return result

In [5]:
answer_file = 'evaluator/answers/java/answers.txt'
prediction_file = 'evaluator/predictions/java/predictions.txt'

In [6]:
answers = read_answers(answer_file)

In [7]:
predictions = read_predictions(prediction_file)

In [16]:
print(len(answers))
print(len(predictions))

40492
40492


In [11]:
acc = calculate_scores(answers, predictions)
acc

0.8062580262767954

## Run prediction 

In [1]:
import torch

from transformers import RobertaConfig, RobertaForMaskedLM, RobertaTokenizer
import argparse
import json
import os

In [2]:
MODEL_CLASSES = {'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)}

In [3]:
config_class, model_class, tokenizer_class = MODEL_CLASSES['roberta']
config = config_class.from_pretrained('roberta-base')

In [4]:
tokenizer = tokenizer_class.from_pretrained('roberta-base')

In [5]:
model = RobertaForMaskedLM.from_pretrained('microsoft/codebert-base-mlm', 
                                           output_attentions=True, output_hidden_states=True)

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

RobertaForMaskedLM(
  (roberta): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0): RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNor

In [7]:
cloze_results = []

In [8]:
cloze_words_file = 'data/cloze-all/cloze_test_words.txt'

In [9]:
file_path = 'data/cloze-all/java/clozeTest.json'

In [10]:
def get_cloze_words(filename, tokenizer):
    with open(filename, 'r', encoding='utf-8') as fp:
        words = fp.read().split('\n')
    idx2word = {tokenizer.encoder[w]: w for w in words}
    return idx2word

In [11]:
idx2word = get_cloze_words(cloze_words_file, tokenizer)

In [12]:
len(idx2word)

930

In [13]:
lines = json.load(open(file_path))

In [14]:
def test_single(text, model, idx2word, tokenizer, device):
    tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))[:510]
    inputs = tokenizer.build_inputs_with_special_tokens(tokenized_text)
    index = inputs.index(tokenizer.mask_token_id)

    inputs = torch.tensor([inputs])
    inputs = inputs.to(device)

    with torch.no_grad():
        scores = model(inputs)[0]
        score_list = scores[0][index]
        word_index = torch.LongTensor(list(idx2word.keys())).to(device)
        word_index = torch.zeros(score_list.shape[0]).to(device).scatter(0, word_index, 1)
        score_list = score_list + (1-word_index) * -1e6
        predict_word_id = torch.argmax(score_list).data.tolist()

    return predict_word_id


In [15]:
line = lines[0]
text = ' '.join(line['nl_tokens'] + line['pl_tokens'])

In [16]:
text

'/ * ( non - Javadoc ) @ Override public int peekBit ( ) throws AACException { int ret ; if ( bitsCached > 0 ) { ret = ( cache >> ( bitsCached - 1 ) ) & 1 ; } else { final int word = readCache ( true ) ; ret = ( <mask> >> WORD_BITS - 1 ) & 1 ; } return ret ; }'

In [17]:
tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))[:510]
inputs = tokenizer.build_inputs_with_special_tokens(tokenized_text)
index = inputs.index(tokenizer.mask_token_id)

inputs = torch.tensor([inputs])
inputs = inputs.to(device)

with torch.no_grad():
    scores = model(inputs)[0]
    score_list = scores[0][index]
    word_index = torch.LongTensor(list(idx2word.keys())).to(device)
    word_index = torch.zeros(score_list.shape[0]).to(device).scatter(0, word_index, 1)
    score_list = score_list + (1-word_index) * -1e6
    predict_word_id = torch.argmax(score_list).data.tolist()

In [18]:
results = []
results.append({'idx': line['idx'],
                        'prediction': idx2word[predict_word_id]})
results

[{'idx': 'all-1', 'prediction': 'word'}]

## Attention analysis

In [21]:
from flair.data import Sentence
from flair.models import SequenceTagger
from transformers import BertModel, BertTokenizer, BertForPreTraining, BertConfig
import json
import torch
import pandas as pd
import numpy as np
import javalang
import os
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import collections
import pickle
import matplotlib
import sklearn
from matplotlib import pyplot as plt
from matplotlib import cm
from sklearn import manifold
from tqdm import tqdm

In [20]:
avg_attns = {
    k: np.zeros((12, 12)) for k in [
      "self", "right", "left", "sep", "sep_sep", "rest_sep",
      "cls", "all_basictype", "all_identifier", "all_keyword", "all_modifier", "all_operator", "all_separator",
      "cls_idf", "basictype_idf", "identifier_idf", "keyword_idf", "modifier_idf", "operator_idf", "separator_idf", "sep_idf"
      ,"all_function", "all_object", "all_class", "all_return-type"
      ]
}

In [22]:
def tokenize(code_snippet):
  tree = list(javalang.tokenizer.tokenize(code_snippet))
  tokens = []
  types = []
  for i in tree:
    j = str(i)
    j = j.split()
    typee = j[0]
    token = j[1].strip('"')
    if typee=="DecimalInteger":
      token = "NUM"
    elif typee.lower()=="string":
      token="STR"
    tokens.append(token)

  return " ".join(tokens)

In [23]:
new_codes = [line['pl_tokens']]

In [26]:
print(new_codes[0])

['@', 'Override', 'public', 'int', 'peekBit', '(', ')', 'throws', 'AACException', '{', 'int', 'ret', ';', 'if', '(', 'bitsCached', '>', '0', ')', '{', 'ret', '=', '(', 'cache', '>>', '(', 'bitsCached', '-', '1', ')', ')', '&', '1', ';', '}', 'else', '{', 'final', 'int', 'word', '=', 'readCache', '(', 'true', ')', ';', 'ret', '=', '(', '<mask>', '>>', 'WORD_BITS', '-', '1', ')', '&', '1', ';', '}', 'return', 'ret', ';', '}']


In [27]:
code = " ".join(new_codes[0])

In [30]:
new_codes = code

In [33]:
allTokenizedCode = []
for index, code in enumerate(new_codes):
  splittedCode = code.split("\n")
  s = ""
  try:
    for eachLine in splittedCode:
      tokenizedCode = tokenize(eachLine)
      if tokenizedCode.strip(" ")=="":
        print(eachLine)
        continue
      else:
        s = s + tokenizedCode + " \n "
    allTokenizedCode.append(s)
    s = ""
  except:
    print("RMRD************ERROR-WAS-THROWN**************RMRD" + str(index))

In [34]:
allTokenizedCode

['@ Override public int peekBit ( ) throws AACException { int ret ; if ( bitsCached > NUM ) { ret = ( cache > > ( bitsCached - NUM ) ) & NUM ; } else { final int word = readCache ( true ) ; ret = ( < mask > > > WORD_BITS - NUM ) & NUM ; } return ret ; } \n ']

In [35]:
code

'@ Override public int peekBit ( ) throws AACException { int ret ; if ( bitsCached > 0 ) { ret = ( cache >> ( bitsCached - 1 ) ) & 1 ; } else { final int word = readCache ( true ) ; ret = ( <mask> >> WORD_BITS - 1 ) & 1 ; } return ret ; }'