In [1]:
from tree_sitter import Language, Parser
import javalang
import numpy as np 

Language.build_library(
	# Store the library in the `build` directory
	'build/my-languages.so',	
	# Include one or more languages
	[
		'/Users/jirigesi/Documents/tree-sitter-java'
	]
)
JAVA_LANGUAGE = Language('build/my-languages.so', 'java')
parser = Parser()

parser.set_language(JAVA_LANGUAGE)

In [2]:
def traverse(code, node,depth=0):
    declaration = {}
    stack = []
    stack.append(node)
    while stack:
        node = stack.pop()
        if ('declaration' in node.type and node.type != "local_variable_declaration") or 'if_statement' in node.type or 'else' in node.type:
            data = code[node.start_byte:node.end_byte].split('{')[0].strip().split(' ')
            if node.type in declaration:
                declaration[node.type].append(data)
            else:
                declaration[node.type] = [data]
        for child in node.children:
            stack.append(child)
    return declaration

def label_tokens(token_list, declaration):
    types = [] 
    for token in token_list:
        flag = False
        for key in declaration:
            for value in declaration[key]:
                if token in value:
                    types.append(key)
                    flag = True
                    break
            if flag:
                break
        if not flag:
            types.append("other")
    return types

In [3]:
def get_extended_types(token_list, types):
    tree = list(javalang.tokenizer.tokenize(" ".join(token_list)))
    code = ' '.join(token_list)
    right = 0
    left = 0
    postion_mapping = [] 

    while right < len(code):
        if code[right] == ' ':
            postion_mapping.append((left, right))
            left = right + 1
        right += 1

    # add the last token
    postion_mapping.append((left, right))
    code = ["<s>"]
    extended_types = []
    left = 0
    for node in tree:
        # rewrite code
        node = str(node).split(' ')
        if node[1] == '"MASK"':
            code.append('<mask>')
        else:
            code.append(node[1][1:-1])
        # extend types
        left = int(node[-1]) -1
        right = left + len(node[1][1:-1])
        # check (left, right) in postion_mapping and get the index
        for i in range(len(postion_mapping)):
            if left >= postion_mapping[i][0] and right <= postion_mapping[i][1]:
                extended_types.append([types[i], node[1]])
                break
    code.append("</s>")
    return extended_types, ' '.join(code)

In [4]:
def get_ast_types(code):
    code = code.replace("{", " {")
    code = " ".join(code.split())
    code_list = code.split(' ')
    tree = parser.parse(bytes(code, "utf8"))
    root_node = tree.root_node
    
    declaration = traverse(code, root_node)
    types = label_tokens(code_list, declaration)

    ast_types, rewrote_code = get_extended_types(code_list, types)
    # check the index of first second value is the "{"
    if ast_types[0][1] == '"class"':
        return ['[CLS]'] + [i[0] for i in ast_types] + ['[SEP]'], rewrote_code
    index_ = 0
    # if not class declaration, find the first "{" and add method_declaration before it
    for i in range(len(ast_types)):
        if ast_types[i][1] == '"{"':
            index_ = i
            break
    final_types = [] 
    final_types.append('[CLS]')
    for i in range(len(ast_types)):
        if i < index_:
            final_types.append("method_declaration")
        else:
            final_types.append(ast_types[i][0])
    final_types.append('[SEP]')
    return final_types, rewrote_code

code = "class Simple{ public static void main(String args[]){ System.out.println( 'Hello Java'); }}"
final_types, rewrote_code = get_ast_types(code)

In [5]:
len(final_types), len(rewrote_code.split(' '))

(27, 27)

In [6]:
final_types

['[CLS]',
 'class_declaration',
 'class_declaration',
 'other',
 'method_declaration',
 'method_declaration',
 'method_declaration',
 'method_declaration',
 'method_declaration',
 'method_declaration',
 'method_declaration',
 'method_declaration',
 'method_declaration',
 'method_declaration',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 '[SEP]']

In [7]:
rewrote_code

"<s> class Simple { public static void main ( String args [ ] ) { System . out . println ( 'Hell ) ; } } </s>"

In [8]:
import json

file_path = "../dataset/valid.txt"
postfix=file_path.split('/')[-1].split('.txt')[0]
index_filename=file_path
url_to_code={}
with open('/'.join(index_filename.split('/')[:-1])+'/data.jsonl') as f:
    for line in f:
        line=line.strip()
        js=json.loads(line)
        url_to_code[js['idx']]=js['func']
data=[]
cache={}
f=open(index_filename)
with open(index_filename) as f:
    lines = 1000
    added_lines = 0
    for line in f:
        # control number of read data 
        if added_lines >= lines:
            break
        line=line.strip()
        url1,url2,label=line.split('\t')
        if url1 not in url_to_code or url2 not in url_to_code:
            continue
        if label=='0':
            label=0
        else:
            label=1
        data.append((url1,url2,label,' '.join(url_to_code[url1].split()), ' '.join(url_to_code[url2].split())))
        added_lines += 1

In [9]:
len(data)

1000

In [10]:
code_sample = data[10]

In [11]:
types_1, rewrote_code_1 = get_ast_types(code_sample[3])

In [12]:
len(types_1), len(rewrote_code_1.split(' '))

(144, 144)

In [13]:
types_2, rewrote_code_2 = get_ast_types(code_sample[4])

In [14]:
len(types_2), len(rewrote_code_2.split(' '))

(109, 109)

In [15]:
from transformers import  RobertaConfig, RobertaModel, RobertaTokenizer
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaModel.from_pretrained('microsoft/codebert-base',
                                    output_attentions=True, 
                                    output_hidden_states=True)

  from .autonotebook import tqdm as notebook_tqdm


In [16]:
from model2 import Model
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = RobertaConfig.from_pretrained('microsoft/codebert-base')
model=Model(model, config, tokenizer)
# checkpoint_prefix = "/Users/jirigesi/Documents/icse2023/attentionBias/Clone-detection-BigCloneBench/code/saved_models/checkpoint-best-f1/model.bin"
# model.load_state_dict(torch.load(checkpoint_prefix, map_location='cpu'))
# model = model.to(device)

In [17]:
tokenized_ids_1 = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(rewrote_code_1))
tokenized_ids_2 = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(rewrote_code_2))

In [18]:
block_size = 400
if len(tokenized_ids_2) > 400:
    tokenized_ids_2 = tokenized_ids_2[:399] + [tokenizer.sep_token_id]

if len(tokenized_ids_1) > 400:
    tokenized_ids_1 = tokenized_ids_1[:399] + [tokenizer.sep_token_id]

padding_length = block_size - len(tokenized_ids_1)
tokenized_ids_1+=[tokenizer.pad_token_id]*padding_length
padding_length = block_size - len(tokenized_ids_2)
tokenized_ids_2+=[tokenizer.pad_token_id]*padding_length

In [19]:
source_ids = tokenized_ids_1 + tokenized_ids_2
labels = code_sample[2]
source_ids = torch.tensor(source_ids).unsqueeze(0).to(device)
labels = torch.tensor(labels).unsqueeze(0).to(device)

In [20]:
with torch.no_grad():
    output = model(block_size,source_ids,labels)

In [24]:
attention = output[2].attentions

In [25]:
len(attention), attention[0].shape

(12, torch.Size([2, 12, 400, 400]))

In [54]:
def get_start_end_of_token_when_tokenized(code_list, types, tokenizer):
  reindexed_types = []
  start = 0
  end = 0
  for each_token in code_list: 
      tokenized_list = tokenizer.tokenize(each_token)
      end += len(tokenized_list)
      reindexed_types.append((start, end-1))
      start = end
  return reindexed_types


In [32]:
rewrote_code_1

'<s> public static void copyFile ( String file1 , String file2 ) { File filedata1 = new java . io . File ( file1 ) ; if ( filedata1 . exists ( ) ) { try { BufferedOutputStream out = new BufferedOutputStream ( new FileOutputStream ( file2 ) ) ; BufferedInputStream in = new BufferedInputStream ( new FileInputStream ( file1 ) ) ; try { int read ; while ( ( read = in . read ( ) ) != - 1 ) { out . write ( read ) ; } out . flush ( ) ; } catch ( IOException ex1 ) { ex1 . printStackTrace ( ) ; } finally { out . close ( ) ; in . close ( ) ; } } catch ( Exception ex ) { ex . printStackTrace ( ) ; } } } </s>'

In [51]:
reindexed_types = []
start = 0
end = 0
for each_token in code_list: 
    tokenized_list = tokenizer.tokenize(each_token)
    end += len(tokenized_list)
    reindexed_types.append((start, end-1))
    start = end

In [53]:
len(reindexed_types), len

144

In [55]:
code_list = rewrote_code_1.split(' ')
start_end = get_start_end_of_token_when_tokenized(code_list, types_1, tokenizer)

In [56]:
len(start_end), len(types_1), len(code_list)

(144, 144, 144)

In [59]:
types_1

['[CLS]',
 'method_declaration',
 'method_declaration',
 'method_declaration',
 'method_declaration',
 'method_declaration',
 'method_declaration',
 'method_declaration',
 'method_declaration',
 'method_declaration',
 'method_declaration',
 'method_declaration',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'if_statement',
 'if_statement',
 'if_statement',
 'if_statement',
 'if_statement',
 'if_statement',
 'if_statement',
 'if_statement',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 

In [69]:
attention_weights = [[[] for col in range(12)] for row in range(12)]
for layer in range(12):
    for head in range(12):
        for each_sep_index in np.where(types=='if_statement')[0]:
            print(each_sep_index)
            start_index, end_index = start_end[each_sep_index]
            interim_value = attention[layer][0][head][:, start_index:end_index+1].mean().cpu().detach().numpy()
            if np.isnan(interim_value):
                pass
            else:
                attention_weights[layer][head].append(interim_value)

26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
26
27
28
29
30
3

In [68]:
attention_weights

[[[array(0.00407362, dtype=float32),
   array(0.00362173, dtype=float32),
   array(0.00385879, dtype=float32),
   array(0.00255353, dtype=float32),
   array(0.00251998, dtype=float32),
   array(0.00254145, dtype=float32),
   array(0.00188253, dtype=float32),
   array(0.00259924, dtype=float32),
   array(0.00223621, dtype=float32),
   array(0.00193295, dtype=float32),
   array(0.00237823, dtype=float32)],
  [array(0.0029461, dtype=float32),
   array(0.00313395, dtype=float32),
   array(0.00270069, dtype=float32),
   array(0.00287512, dtype=float32),
   array(0.00211838, dtype=float32),
   array(0.00234901, dtype=float32),
   array(0.00235187, dtype=float32),
   array(0.00220437, dtype=float32),
   array(0.00246631, dtype=float32),
   array(0.00262648, dtype=float32),
   array(0.00232628, dtype=float32)],
  [array(0.00386662, dtype=float32),
   array(0.00369617, dtype=float32),
   array(0.00622805, dtype=float32),
   array(0.0035585, dtype=float32),
   array(0.00267758, dtype=float32),
 

In [66]:
types = np.array(types_1)