In [1]:
import torch
from transformers import RobertaConfig, RobertaForMaskedLM, RobertaTokenizer
import argparse
import json
import os
import pandas as pd
import numpy as np 
from tqdm import tqdm
from scipy import stats
import javalang
from captum.attr import LayerIntegratedGradients, TokenReferenceBase, visualization
from captum.attr import IntegratedGradients, LayerConductance, LayerIntegratedGradients, LayerActivation
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer

if torch.__version__ >= '1.7.0':
    norm_fn = torch.linalg.norm
else:
    norm_fn = torch.norm

In [2]:
device = torch.device("cpu")
MODEL_CLASSES = {'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)}

config_class, model_class, tokenizer_class = MODEL_CLASSES['roberta']
config = config_class.from_pretrained('roberta-base')
tokenizer = tokenizer_class.from_pretrained('roberta-base')

model = RobertaForMaskedLM.from_pretrained('microsoft/codebert-base-mlm', 
                                           output_attentions=True, output_hidden_states=True)

In [11]:
ref_token_id = tokenizer.pad_token_id # A token used for generating token reference
sep_token_id = tokenizer.sep_token_id # A token used as a separator between question and text and it is also added to the end of the text.
cls_token_id = tokenizer.cls_token_id # A token used for prepending to the concatenated question-text word sequence

token_reference = TokenReferenceBase(reference_token_idx=ref_token_id)

In [3]:
def get_cloze_words(filename, tokenizer):
    with open(filename, 'r', encoding='utf-8') as fp:
        words = fp.read().split('\n')
    idx2word = {tokenizer.encoder[w]: w for w in words}
    return idx2word

In [4]:
cloze_results = []
cloze_words_file = 'data/cloze-all/cloze_test_words.txt'
file_path = 'data/cloze-all/java/clozeTest.json'

idx2word = get_cloze_words(cloze_words_file, tokenizer)
lines = json.load(open(file_path))
len(lines)

In [6]:
def read_answers(filename):
    answers = {}
    with open(filename, 'r', encoding='utf-8') as f:
        for line in f.readlines():
            line = line.strip()
            answers[line.split('<CODESPLIT>')[0]] = line.split('<CODESPLIT>')[1]
    return answers

answer_file = 'evaluator/answers/java/answers.txt'
answers = read_answers(answer_file)
answer_list = list(answers.values())
print(len(answer_list))

40492


In [66]:
bestSampleWithMaxPairLength = []
bestSampleWithMaxPairLength_LEN =[]

number_of_samples = 10
for i in range(len(lines[:number_of_samples])):
    code = ' '.join(lines[i]['pl_tokens'])
    bestStr = "<s> " + code + " </s>"
    bestLen = len(bestStr.split(" "))
    bestSampleWithMaxPairLength.append(bestStr)
    bestSampleWithMaxPairLength_LEN.append(bestLen)

In [67]:
lengths=[]
codes=[]
selected_answers = []

for index, code in enumerate(bestSampleWithMaxPairLength):
  l = len(tokenizer.tokenize(code))
  if l<=256:
    lengths.append(l)
    codes.append(code)
    selected_answers.append(answer_list[index])

len(codes), len(selected_answers)

(6, 6)

### Extract attribution score

In [9]:
def summarize_attributions(attributions):
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / norm_fn(attributions)
    return attributions

In [10]:
def construct_whole_bert_embeddings(input_ids, ref_input_ids):
    input_embeddings = interpretable_embedding.indices_to_embeddings(input_ids)
    ref_input_embeddings = interpretable_embedding.indices_to_embeddings(ref_input_ids)
    
    return input_embeddings, ref_input_embeddings

In [40]:
def predict_forward_func(input_embeddings):
    output = model(inputs_embeds=input_embeddings)
    index = tokenized_text.index(tokenizer.mask_token_id)
    output_list = output.logits[0][index]
    output_list = output_list.unsqueeze(0)
    
    return output_list.max(1).values
    

In [12]:
# use code as example
code = codes[0]
tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(code))
input_ids = torch.tensor([tokenized_text])
reference_indices = token_reference.generate_reference(input_ids.shape[1], device=device).unsqueeze(0)

torch.Size([1, 78])
torch.Size([1, 78])


In [14]:
interpretable_embedding = configure_interpretable_embedding_layer(model, 'roberta.embeddings.word_embeddings')



### Calculate Average attribution on CLS 

In [60]:
cls_data = np.zeros((12,12))

with torch.no_grad():
    for code in tqdm(codes):
        tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(code))
        input_ids = torch.tensor([tokenized_text])
        reference_indices = token_reference.generate_reference(input_ids.shape[1], device=device).unsqueeze(0)

        layer_attrs = []
        layer_attn_mat = []
        input_embeddings, ref_input_embeddings = construct_whole_bert_embeddings(input_ids, reference_indices)

        for i in range(model.config.num_hidden_layers):
            lc = LayerConductance(predict_forward_func, 
                                model.roberta.encoder.layer[i])
            layer_attributions = lc.attribute(inputs=input_embeddings, 
                                                    baselines=ref_input_embeddings, 
                                                    additional_forward_args=())
            layer_attrs.append(summarize_attributions(layer_attributions[0]))
            layer_attn_mat.append(layer_attributions[1])
        # layer x seq_len
        layer_attrs = torch.stack(layer_attrs)
        # layer x batch x head x seq_len x seq_len
        layer_attn_mat = torch.stack(layer_attn_mat)
        for layer in range(12):
            for head in range(12):
                cls_data[layer][head] += layer_attn_mat[layer][0][head][:, 0:1].mean().cpu().detach().numpy()
            
CLS_atten = cls_data/len(codes)

100%|██████████| 85/85 [10:41<00:00,  7.55s/it]


In [58]:
# import seaborn as sns
# import matplotlib.pyplot as plt

# indices = input_ids[0].detach().tolist()
# all_tokens = tokenizer.convert_ids_to_tokens(indices)

# fig, ax = plt.subplots(figsize=(15,5))
# xticklabels=all_tokens
# yticklabels=list(range(1,13))
# ax = sns.heatmap(layer_attrs.cpu().detach().numpy(), xticklabels=xticklabels, yticklabels=yticklabels, linewidth=0.2)
# plt.xlabel('Tokens')
# plt.ylabel('Layers')
# plt.show()

In [61]:
print(CLS_atten.shape)
CLS_atten_sum = np.sum(CLS_atten, axis=1)
print(CLS_atten_sum)

(12, 12)
[ 3.22083902e-07 -1.26793037e-05  1.03388186e-05  8.71080211e-06
  3.89063882e-06 -1.46472271e-05  1.53170432e-05 -3.57923516e-06
  2.00554549e-06 -3.51545516e-07  6.12477663e-06  9.59633963e-06]


### Calculating Average attribution put on SEP token

In [68]:
sep_data = np.zeros((12,12))

with torch.no_grad():
    for code in tqdm(codes):
        tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(code))
        input_ids = torch.tensor([tokenized_text])
        reference_indices = token_reference.generate_reference(input_ids.shape[1], device=device).unsqueeze(0)

        layer_attrs = []
        layer_attn_mat = []
        input_embeddings, ref_input_embeddings = construct_whole_bert_embeddings(input_ids, reference_indices)

        for i in range(model.config.num_hidden_layers):
            lc = LayerConductance(predict_forward_func, 
                                model.roberta.encoder.layer[i])
            layer_attributions = lc.attribute(inputs=input_embeddings, 
                                                    baselines=ref_input_embeddings, 
                                                    additional_forward_args=())
            layer_attrs.append(summarize_attributions(layer_attributions[0]))
            layer_attn_mat.append(layer_attributions[1])
        # layer x seq_len
        layer_attrs = torch.stack(layer_attrs)
        # layer x batch x head x seq_len x seq_len
        layer_attn_mat = torch.stack(layer_attn_mat)
        for layer in range(12):
          for head in range(12):
            for each_sep_index in torch.where(input_ids[0]==2)[0].cpu().detach().numpy():
              sep_data[layer][head] += layer_attn_mat[layer][0][head][:, each_sep_index].mean().cpu().detach().numpy() / len(torch.where(input_ids[0]==2)[0].cpu().detach().numpy())

            
SEP_atten = sep_data/len(codes)

100%|██████████| 6/6 [00:40<00:00,  6.71s/it]


### Average attention on Syntactic Types

In [69]:
def get_syntax_types_for_code(code_snippet):
  types = ["[CLS]"]
  code = ["<s>"]
  tree = list(javalang.tokenizer.tokenize(code_snippet))
  
  for i in tree:
    j = str(i)
    j = j.split(" ")
    if j[1] == '"MASK"':
      types.append('[MASK]')
      code.append('<mask>')
    else:
      types.append(j[0].lower())
      code.append(j[1][1:-1])
    
  types.append("[SEP]")
  code.append("</s>")
  return np.array(types), ' '.join(code)

In [70]:
def get_start_end_of_token_when_tokenized(code, types, tokenizer):
  reindexed_types = []
  start = 0
  end = 0
  for index, each_token in enumerate(code.split(" ")):
    tokenized_list = tokenizer.tokenize(each_token)
    for i in range(len(tokenized_list)):
      end += 1
    reindexed_types.append((start, end-1))
    start = end
  return reindexed_types

In [73]:
code = codes[0]

cleancode = code.replace("<s> ", "").replace(" </s>", "").replace('<mask>', 'MASK')
types, rewrote_code = get_syntax_types_for_code(cleancode)

In [78]:
tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(rewrote_code))
input_ids = torch.tensor([tokenized_text])

In [79]:
reference_indices = token_reference.generate_reference(input_ids.shape[1], device=device).unsqueeze(0)

layer_attrs = []
layer_attn_mat = []
input_embeddings, ref_input_embeddings = construct_whole_bert_embeddings(input_ids, reference_indices)

In [80]:
for i in range(model.config.num_hidden_layers):
    lc = LayerConductance(predict_forward_func, 
                        model.roberta.encoder.layer[i])
    layer_attributions = lc.attribute(inputs=input_embeddings, 
                                            baselines=ref_input_embeddings, 
                                            additional_forward_args=())
    layer_attrs.append(summarize_attributions(layer_attributions[0]))
    layer_attn_mat.append(layer_attributions[1])

In [81]:
 # layer x seq_len
layer_attrs = torch.stack(layer_attrs)
# layer x batch x head x seq_len x seq_len
layer_attn_mat = torch.stack(layer_attn_mat)

In [83]:
# get start and end index of each token
start_end = get_start_end_of_token_when_tokenized(rewrote_code, types, tokenizer)
 

In [84]:

syntaxType = 'annotation'
for layer in range(12):
    for head in range(12):
        for each_sep_index in np.where(types==syntaxType)[0]:
            start_index, end_index = start_end[each_sep_index]
            interim_value = layer_attn_mat[layer][0][head][:, start_index:end_index+1].mean().cpu().detach().numpy()

In [85]:
def getSyntaxAttributionScore(codes, tokenizer, syntaxType):

  with torch.no_grad():
    identifier = np.zeros((12,12))
    number = 0 
    failed_calculate = 0
    for eachCode in tqdm(codes, desc=syntaxType):
      try: 
        cleancode = eachCode.replace("<s> ", "").replace(" </s>", "").replace('<mask>', 'MASK')
        types, rewrote_code = get_syntax_types_for_code(cleancode)
        # send input to model
        tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(rewrote_code))
        input_ids = torch.tensor([tokenized_text])
        # get reference indices
        reference_indices = token_reference.generate_reference(input_ids.shape[1], device=device).unsqueeze(0)

        layer_attrs = []
        layer_attn_mat = []
        input_embeddings, ref_input_embeddings = construct_whole_bert_embeddings(input_ids, reference_indices)
        # get layer attribution
        for i in range(model.config.num_hidden_layers):
          lc = LayerConductance(predict_forward_func, 
                              model.roberta.encoder.layer[i])
          layer_attributions = lc.attribute(inputs=input_embeddings, 
                                                  baselines=ref_input_embeddings, 
                                                  additional_forward_args=())
          layer_attrs.append(summarize_attributions(layer_attributions[0]))
          layer_attn_mat.append(layer_attributions[1])
          
        # layer x seq_len
        layer_attrs = torch.stack(layer_attrs)
        # layer x batch x head x seq_len x seq_len
        layer_attn_mat = torch.stack(layer_attn_mat)
        # get start and end index of each token
        start_end = get_start_end_of_token_when_tokenized(rewrote_code, types, tokenizer)
        if syntaxType in types:
          number += 1
        for layer in range(12):
          for head in range(12):
            for each_sep_index in np.where(types==syntaxType)[0]:
              start_index, end_index = start_end[each_sep_index]
              interim_value = layer_attn_mat[layer][0][head][:, start_index:end_index+1].mean().cpu().detach().numpy()
              if np.isnan(interim_value):
                  pass
              else: 
                  identifier[layer][head] += interim_value
      except:
        failed_calculate += 1
    print("failed calculate: ", failed_calculate)
                
    identifier = identifier/number
  return identifier, number

In [72]:
syntax_list = ['annotation']

In [86]:
avg_attns = {}
avg_attens_sum = {}
syntax_frequenct = {}

for syntax in syntax_list:
    avg_attns[syntax] = np.zeros((12, 12))
    avg_attns[syntax], syntax_frequenct[syntax] = getSyntaxAttributionScore(codes, tokenizer, syntax)
    avg_attens_sum[syntax] = np.sum(avg_attns[syntax], axis=1)

annotation: 100%|██████████| 6/6 [00:31<00:00,  5.26s/it]

failed calculate:  3



