In [1]:
import torch
from transformers import RobertaConfig, RobertaForMaskedLM, RobertaTokenizer
import argparse
import json
import os
import pandas as pd
import numpy as np 
from tqdm import tqdm
from scipy import stats
import javalang

In [2]:
MODEL_CLASSES = {'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)}

config_class, model_class, tokenizer_class = MODEL_CLASSES['roberta']
config = config_class.from_pretrained('roberta-base')
tokenizer = tokenizer_class.from_pretrained('roberta-base')

model = RobertaForMaskedLM.from_pretrained('microsoft/codebert-base-mlm', 
                                           output_attentions=True, output_hidden_states=True)

In [3]:
tokenizer.cls_token_id, tokenizer.sep_token_id, tokenizer.mask_token_id

(0, 2, 50264)

In [4]:
tokenizer.cls_token, tokenizer.sep_token, tokenizer.mask_token

('<s>', '</s>', '<mask>')

In [5]:
cloze_results = []
cloze_words_file = 'data/cloze-all/cloze_test_words.txt'
file_path = 'data/cloze-all/java/clozeTest.json'

In [6]:
def get_cloze_words(filename, tokenizer):
    with open(filename, 'r', encoding='utf-8') as fp:
        words = fp.read().split('\n')
    idx2word = {tokenizer.encoder[w]: w for w in words}
    return idx2word

In [7]:
idx2word = get_cloze_words(cloze_words_file, tokenizer)

In [8]:
lines = json.load(open(file_path))
len(lines)

40492

In [9]:
def read_answers(filename):
    answers = {}
    with open(filename, 'r', encoding='utf-8') as f:
        for line in f.readlines():
            line = line.strip()
            answers[line.split('<CODESPLIT>')[0]] = line.split('<CODESPLIT>')[1]
    return answers

answer_file = 'evaluator/answers/java/answers.txt'
answers = read_answers(answer_file)
answer_list = list(answers.values())

In [11]:
# predict_word = idx2word[predict_word_id]
# print('Predicted word: {}'.format(predict_word))
# print('Ground truth word: {}'.format(answers['all-1']))

In [12]:
bestSampleWithMaxPairLength = []
bestSampleWithMaxPairLength_LEN =[]

number_of_samples = 100
for i in range(len(lines[:number_of_samples])):
    code = ' '.join(lines[i]['pl_tokens'])
    bestStr = "<s> " + code + " </s>"
    bestLen = len(bestStr.split(" "))
    bestSampleWithMaxPairLength.append(bestStr)
    bestSampleWithMaxPairLength_LEN.append(bestLen)

In [13]:
lengths=[]
codes=[]
selected_answers = []

for index, code in enumerate(bestSampleWithMaxPairLength):
  l = len(tokenizer.tokenize(code))
  if l<=256:
    lengths.append(l)
    codes.append(code)
    selected_answers.append(answer_list[index])
    

In [14]:
len(codes), len(selected_answers)

(85, 85)

In [15]:
code = codes[0]
with torch.no_grad():

    tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(code))
    input_ids = torch.tensor([tokenized_text])
    output_from_model = model(input_ids)

    _attention = output_from_model["attentions"]# attention shape is layers, batchsize, heads, tokenLen, tokenLen


In [16]:
len(_attention)

12

In [17]:
_attention[0].shape

torch.Size([1, 12, 78, 78])

### Javalang

In [14]:

tree = list(javalang.tokenizer.tokenize(codes[0]))

In [15]:
type(tree), tree[0]

(list, Operator "<" line 1, position 1)

In [16]:
tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(codes[0]))

### Calculate Average attention on CLS 

In [18]:
cls_data = np.zeros((12,12))

with torch.no_grad():
  for eachCode in tqdm(codes):
    tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(eachCode))[:512]
    input_ids = torch.tensor([tokenized_text])
    output_from_model = model(input_ids)
    
    _attention = output_from_model["attentions"]# attention shape is layers, batchsize, heads, tokenLen, tokenLen
    
    for layer in range(12):
      for head in range(12):
        cls_data[layer][head] += _attention[layer][0][head][:, 0:1].mean().cpu().detach().numpy() 

CLS_atten = cls_data/len(codes)

100%|██████████| 85/85 [00:03<00:00, 24.20it/s]


In [19]:
print(CLS_atten.shape)
CLS_atten_sum = np.sum(CLS_atten, axis=1)
print(CLS_atten_sum)

(12, 12)
[0.53873959 5.2101759  5.69606181 5.00920714 3.93256064 4.73496921
 5.23256336 5.21673604 7.89330531 8.38361245 8.12879069 5.91060665]


### Calculating Average attention put on SEP token

In [25]:
sep_data = np.zeros((12,12))

with torch.no_grad():
  for eachCode in tqdm(codes):
    tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(eachCode))
    # index = inputs.index(tokenizer.mask_token_id)
    inputs_id = torch.tensor([tokenized_text])
    output_from_model = model(inputs_id)
    
    _attention = output_from_model["attentions"]# attention shape is layers, batchsize, heads, tokenLen, tokenLen
    
    for layer in range(12):
      for head in range(12):
        for each_sep_index in torch.where(inputs_id[0]==2)[0].cpu().detach().numpy():
          sep_data[layer][head] += _attention[layer][0][head][:, each_sep_index].mean().cpu().detach().numpy() 

SEP_atten = sep_data/len(codes)

100%|██████████| 85/85 [00:03<00:00, 24.83it/s]


In [26]:
print(SEP_atten.shape)
SEP_atten_sum = np.sum(SEP_atten, axis=1)
print(SEP_atten_sum)

(12, 12)
[0.1957119  0.13042349 0.0988445  0.14585112 0.21771341 0.21395908
 0.14760725 0.114327   0.11055515 0.06455206 0.06105491 0.04995303]


### Average attention on Syntactic Types

In [27]:
def get_syntax_types_for_code(code_snippet):
  types = ["[CLS]"]
  code = ["<s>"]
  tree = list(javalang.tokenizer.tokenize(code_snippet))
  
  for i in tree:
    j = str(i)
    j = j.split(" ")
    if j[1] == '"MASK"':
      types.append('[MASK]')
      code.append('<mask>')
    else:
      types.append(j[0].lower())
      code.append(j[1][1:-1])
    
  types.append("[SEP]")
  code.append("</s>")
  return np.array(types), ' '.join(code)

In [28]:
def get_start_end_of_token_when_tokenized(code, types, tokenizer):
  reindexed_types = []
  start = 0
  end = 0
  for index, each_token in enumerate(code.split(" ")):
    tokenized_list = tokenizer.tokenize(each_token)
    for i in range(len(tokenized_list)):
      end += 1
    reindexed_types.append((start, end-1))
    start = end
  return reindexed_types

In [29]:
def getSyntaxAttentionScore(codes, tokenizer, syntaxType):

  with torch.no_grad():
    identifier = np.zeros((12,12))
    number = 0 
    failed_calculate = 0
    for eachCode in tqdm(codes, desc=syntaxType):
      try: 
        cleancode = eachCode.replace("<s> ", "").replace(" </s>", "").replace('<mask>', 'MASK')
        types, rewrote_code = get_syntax_types_for_code(cleancode)
        # send input to model
        tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(rewrote_code))
        input_ids = torch.tensor([tokenized_text])
        output_from_model = model(input_ids)
        # get attention from model
        _attention = output_from_model["attentions"]# attention shape is layers, batchsize, heads, tokenLen, tokenLen
        # get start and end index of each token
        start_end = get_start_end_of_token_when_tokenized(rewrote_code, types, tokenizer)
        if syntaxType in types:
          number += 1
        for layer in range(12):
          for head in range(12):
            for each_sep_index in np.where(types==syntaxType)[0]:
              start_index, end_index = start_end[each_sep_index]
              interim_value = _attention[layer][0][head][:, start_index:end_index+1].mean().cpu().detach().numpy()
              if np.isnan(interim_value):
                  pass
              else: 
                  identifier[layer][head] += interim_value
      except:
        failed_calculate += 1
    print("failed calculate: ", failed_calculate)
                
    identifier = identifier/number
  return identifier, number


In [4]:
syntax_list = ['annotation', 'basictype', 'boolean', 
          'decimalinteger', 'identifier', 'keyword',
          'modifier', 'operator', 'separator', 'null',
          'string', 'decimalfloatingpoint']

In [30]:
avg_attns = {}
avg_attens_sum = {}
syntax_frequenct = {}

for syntax in syntax_list:
    avg_attns[syntax] = np.zeros((12, 12))
    avg_attns[syntax], syntax_frequenct[syntax] = getSyntaxAttentionScore(codes, tokenizer, syntax)
    avg_attens_sum[syntax] = np.sum(avg_attns[syntax], axis=1)

annotation: 100%|██████████| 85/85 [00:03<00:00, 27.36it/s]


failed calculate:  0


basictype: 100%|██████████| 85/85 [00:02<00:00, 30.38it/s]


failed calculate:  0


boolean: 100%|██████████| 85/85 [00:02<00:00, 33.75it/s]


failed calculate:  0


decimalinteger: 100%|██████████| 85/85 [00:02<00:00, 32.55it/s]


failed calculate:  0


identifier: 100%|██████████| 85/85 [00:07<00:00, 11.09it/s]


failed calculate:  0


keyword: 100%|██████████| 85/85 [00:03<00:00, 24.56it/s]


failed calculate:  0


modifier: 100%|██████████| 85/85 [00:02<00:00, 29.09it/s]


failed calculate:  0


operator: 100%|██████████| 85/85 [00:03<00:00, 21.42it/s]


failed calculate:  0


separator: 100%|██████████| 85/85 [00:09<00:00,  9.08it/s]


failed calculate:  0


null: 100%|██████████| 85/85 [00:02<00:00, 34.47it/s]


failed calculate:  0


string: 100%|██████████| 85/85 [00:02<00:00, 33.50it/s]


failed calculate:  0


decimalfloatingpoint: 100%|██████████| 85/85 [00:02<00:00, 37.36it/s]

failed calculate:  0





### Split based on corrrect and incorrect predictions

In [31]:
correct_precition_index = []
misprediction_index = []

for i, code in enumerate(codes): 
    tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(code))
    index = tokenized_text.index(tokenizer.mask_token_id)
    input_ids = torch.tensor([tokenized_text])

    with torch.no_grad():
        scores = model(input_ids)[0]
        score_list = scores[0][index]
        word_index = torch.LongTensor(list(idx2word.keys()))
        word_index = torch.zeros(score_list.shape[0]).scatter(0, word_index, 1)
        score_list = score_list + (1-word_index) * -1e6
        predict_word_id = torch.argmax(score_list).data.tolist()
    
    predict_word = idx2word[predict_word_id]
    
    if predict_word == selected_answers[i]:
        correct_precition_index.append(i)
    else:
        misprediction_index.append(i)

In [32]:
len(correct_precition_index), len(misprediction_index)

(74, 11)

In [33]:
correct_codes = []
mispredic_codes = []

for i in correct_precition_index:
    correct_codes.append(codes[i])
    
for i in misprediction_index:
    mispredic_codes.append(codes[i])

In [34]:
correct_avg_attns = {}
correct_avg_attens_sum = {}
correct_syntax_frequenct = {}

for syntax in syntax_list:
    correct_avg_attns[syntax] = np.zeros((12, 12))
    correct_avg_attns[syntax], correct_syntax_frequenct[syntax] = getSyntaxAttentionScore(correct_codes, tokenizer, syntax)
    correct_avg_attens_sum[syntax] = np.sum(correct_avg_attns[syntax], axis=1)

annotation: 100%|██████████| 74/74 [00:02<00:00, 30.31it/s]


failed calculate:  0


basictype: 100%|██████████| 74/74 [00:02<00:00, 33.83it/s]


failed calculate:  0


boolean: 100%|██████████| 74/74 [00:02<00:00, 36.88it/s]


failed calculate:  0


decimalinteger: 100%|██████████| 74/74 [00:02<00:00, 33.89it/s]


failed calculate:  0


identifier: 100%|██████████| 74/74 [00:06<00:00, 10.92it/s]


failed calculate:  0


keyword: 100%|██████████| 74/74 [00:03<00:00, 19.52it/s]


failed calculate:  0


modifier: 100%|██████████| 74/74 [00:02<00:00, 26.14it/s]


failed calculate:  0


operator: 100%|██████████| 74/74 [00:03<00:00, 21.44it/s]


failed calculate:  0


separator: 100%|██████████| 74/74 [00:08<00:00,  8.89it/s]


failed calculate:  0


null: 100%|██████████| 74/74 [00:02<00:00, 32.66it/s]


failed calculate:  0


string: 100%|██████████| 74/74 [00:02<00:00, 32.76it/s]


failed calculate:  0


decimalfloatingpoint: 100%|██████████| 74/74 [00:02<00:00, 36.44it/s]

failed calculate:  0





In [35]:
mispredict_avg_attns = {}
mispredict_avg_attens_sum = {}
mispredict_syntax_frequenct = {}

for syntax in syntax_list:
    mispredict_avg_attns[syntax] = np.zeros((12, 12))
    mispredict_avg_attns[syntax], mispredict_syntax_frequenct[syntax] = getSyntaxAttentionScore(mispredic_codes, tokenizer, syntax)
    mispredict_avg_attens_sum[syntax] = np.sum(mispredict_avg_attns[syntax], axis=1)

annotation: 100%|██████████| 11/11 [00:00<00:00, 22.56it/s]


failed calculate:  0


basictype: 100%|██████████| 11/11 [00:00<00:00, 29.90it/s]


failed calculate:  0


boolean: 100%|██████████| 11/11 [00:00<00:00, 31.81it/s]
  identifier = identifier/number


failed calculate:  0


decimalinteger: 100%|██████████| 11/11 [00:00<00:00, 30.89it/s]


failed calculate:  0


identifier: 100%|██████████| 11/11 [00:01<00:00, 10.68it/s]


failed calculate:  0


keyword: 100%|██████████| 11/11 [00:00<00:00, 27.24it/s]


failed calculate:  0


modifier: 100%|██████████| 11/11 [00:00<00:00, 29.70it/s]


failed calculate:  0


operator: 100%|██████████| 11/11 [00:00<00:00, 24.86it/s]


failed calculate:  0


separator: 100%|██████████| 11/11 [00:01<00:00, 10.29it/s]


failed calculate:  0


null: 100%|██████████| 11/11 [00:00<00:00, 35.57it/s]


failed calculate:  0


string: 100%|██████████| 11/11 [00:00<00:00, 34.73it/s]


failed calculate:  0


decimalfloatingpoint: 100%|██████████| 11/11 [00:00<00:00, 38.44it/s]

failed calculate:  0





annotation
Ttest_indResult(statistic=3.447045271752126, pvalue=0.0022979045790702934)
-----------------
basictype
Ttest_indResult(statistic=3.865915629913053, pvalue=0.0008359468076887228)
-----------------
boolean
Ttest_indResult(statistic=nan, pvalue=nan)
-----------------
decimalinteger
Ttest_indResult(statistic=6.582215566506413, pvalue=1.2791377235426926e-06)
-----------------
identifier
Ttest_indResult(statistic=-0.627945912747651, pvalue=0.5365039292636904)
-----------------
keyword
Ttest_indResult(statistic=0.5595072193479274, pvalue=0.5814677909015604)
-----------------
modifier
Ttest_indResult(statistic=-1.2976000912130323, pvalue=0.20786655274103427)
-----------------
operator
Ttest_indResult(statistic=0.5761339491592351, pvalue=0.5703723068155411)
-----------------
separator
Ttest_indResult(statistic=-0.018223083697325476, pvalue=0.9856251496106448)
-----------------
null
Ttest_indResult(statistic=-1.0264636168881236, pvalue=0.3158275780621145)
-----------------
string
Ttes

In [45]:
import pickle 

# read data from a pick file 
file1 = "results/CLS_atten_sum_correct.pkl"
file2 = "results/CLS_atten_sum_misprediction.pkl"

correct_CLS = pickle.load(open(file1, "rb"))
mispredict_CLS = pickle.load(open(file2, "rb"))

stats.ttest_ind(correct_CLS, mispredict_CLS)

Ttest_indResult(statistic=-0.03177469619889501, pvalue=0.974938237839359)

In [46]:
# read data from a pick file 
file1 = "results/SEP_atten_sum_correct.pkl"
file2 = "results/SEP_atten_sum_misprediction.pkl"

correct_CLS = pickle.load(open(file1, "rb"))
mispredict_CLS = pickle.load(open(file2, "rb"))

stats.ttest_ind(correct_CLS, mispredict_CLS)

Ttest_indResult(statistic=-0.5376287728995854, pvalue=0.5962304462821135)

In [3]:
import pickle 
# read data from a pick file 
file1 = "results/syntax_atten_attention_correct.pkl"
file2 = "results/syntax_atten_attention_misprediction.pkl"

correct_Syntax = pickle.load(open(file1, "rb"))
mispredict_Syntax = pickle.load(open(file2, "rb"))

# stats.ttest_ind(correct_CLS, mispredict_CLS)

In [6]:
from scipy import stats
for syntax in syntax_list:
    print(syntax)
    print(stats.ttest_ind(correct_Syntax[syntax], mispredict_Syntax[syntax]))
    print("-----------------")

annotation
Ttest_indResult(statistic=-0.10900912901747421, pvalue=0.9141834332621221)
-----------------
basictype
Ttest_indResult(statistic=0.4417422652789735, pvalue=0.6629873183982515)
-----------------
boolean
Ttest_indResult(statistic=0.380070647429576, pvalue=0.7075375418955014)
-----------------
decimalinteger
Ttest_indResult(statistic=0.6374880193760781, pvalue=0.5303862769151588)
-----------------
identifier
Ttest_indResult(statistic=0.04539532784516388, pvalue=0.9642017468560653)
-----------------
keyword
Ttest_indResult(statistic=0.2467891018367265, pvalue=0.8073597443604594)
-----------------
modifier
Ttest_indResult(statistic=-0.10819073376114986, pvalue=0.9148250755338021)
-----------------
operator
Ttest_indResult(statistic=0.23067892868031312, pvalue=0.819695913827764)
-----------------
separator
Ttest_indResult(statistic=0.21227993166624043, pvalue=0.8338431361591297)
-----------------
null
Ttest_indResult(statistic=-0.2872184961987086, pvalue=0.776632943548014)
-------