此 .ipynb为按照行级挑选最有代表性的代码注释对

In [None]:
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import re

import torch
import torch.nn as nn
from transformers import RobertaTokenizer, RobertaForSequenceClassification
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import AdamW, get_linear_schedule_with_warmup
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, TensorDataset


from captum.attr import visualization as viz
from captum.attr import LayerConductance, LayerIntegratedGradients
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer

# 解决服务器挂掉的问题
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

In [None]:
# Everything in this notebook runs on the CPU; the commented-out line is the
# "CUDA if available" alternative.
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device( "cpu")

In [None]:
# Hyper-parameter constants.
# NOTE(review): none of these names is referenced anywhere else in the visible
# part of this notebook (the tokenisation helpers hard-code 512 instead of
# using MAX_LEN) — presumably carried over from the training script; confirm
# before relying on them.
MAX_LEN = 1024
MAX_EPOCHS = 10
BATCH_SIZE = 4
LEARNING_RATE = 1e-5
NUM_CLASSES = 2
# NOTE(review): the name is spelled "WEIGTH"; left unchanged because other
# cells or files may refer to it.
WEIGTH_DECAY = 1e-3

In [None]:
def retrieve_test_data():
    """Load the three test splits and stack them into a single DataFrame.

    Reads the ``param``, ``return`` and ``summary`` ``test.json`` files from
    the hard-coded data directory and concatenates them row-wise in the order
    summary, param, return.

    Returns:
        The concatenated DataFrame with a fresh 0..n-1 index.
    """
    param_part = pd.read_json("D:/BERT_learing/code_comment_inconsistency_detection/data/param/test.json")
    return_part = pd.read_json("D:/BERT_learing/code_comment_inconsistency_detection/data/return/test.json")
    summary_part = pd.read_json("D:/BERT_learing/code_comment_inconsistency_detection/data/summary/test.json")
    stacked = pd.concat([summary_part, param_part, return_part], axis=0)
    return stacked.reset_index(drop=True)
# Load the combined test set once; later cells add columns to this frame.
test_df = retrieve_test_data()

In [None]:
def get_lines_count(df):
    """Add a ``line_counts`` column with the number of lines of each code sample.

    The count is the number of newline-separated pieces of ``new_code_raw``,
    so a string without any newline counts as one line and a trailing newline
    adds one (empty) line.

    Args:
        df: DataFrame with a string column ``new_code_raw``.

    Returns:
        The same DataFrame object, with the ``line_counts`` column added
        (or overwritten) in place.
    """
    # Iterate over the column values instead of ``df.loc[i]`` so the function
    # works for any index, not only a fresh 0..n-1 RangeIndex.
    df['line_counts'] = [len(code.split('\n')) for code in df['new_code_raw']]
    return df
# Annotate every test sample with its line count and preview the first rows.
test_df = get_lines_count(test_df)
test_df.head()

In [None]:
def remove_last(df):
    """Strip trailing periods from every ``old_comment_raw`` entry.

    ``str.rstrip('.')`` removes *all* trailing ``.`` characters, so
    ``"wait..."`` becomes ``"wait"``.

    Args:
        df: DataFrame with a string column ``old_comment_raw``.

    Returns:
        The same DataFrame object, with the column rewritten in place.
    """
    # Assign the whole column at once. The previous element-wise chained
    # assignment (``df[col][i] = ...``) is not guaranteed to write through to
    # ``df`` and also required the index to be exactly 0..n-1.
    df['old_comment_raw'] = [comment.rstrip('.') for comment in df['old_comment_raw']]
    return df
# Normalise the test comments by dropping their trailing periods.
test_df = remove_last(test_df)

In [None]:
# df_clean is a second name for the same DataFrame object (no copy is made),
# so changes through either name are visible through the other.
df_clean = test_df
df_clean.head()

In [None]:
# Load the fine-tuned classifier. The file holds a whole pickled model object
# (the result is used directly as `model` below), mapped onto the CPU.
# NOTE(review): torch.load unpickles arbitrary Python objects — only load
# checkpoints from a trusted source. Newer PyTorch releases may also require
# weights_only=False for whole-model pickles — confirm for the installed version.
# model = torch.load('save_GCBmodel.pt',map_location=torch.device('cuda:0'))
model = torch.load('D:/BERT_learing/CCDP/for_captum/save_model/save_GCBmodel.pt',map_location=torch.device('cpu'))

model.to(device)
# Put the network in inference mode so dropout is disabled; otherwise the
# predictions and attributions computed below can vary from run to run if the
# model was pickled while in training mode.
model.eval()

In [None]:
import torch
# Environment sanity check: PyTorch version, the CUDA version it reports, and
# whether CUDA is usable.
print(torch.__version__)
print(torch.version.cuda)
print(torch.cuda.is_available())  # True means the CUDA install is working


In [None]:
# Display the selected device as the cell output.
device

In [None]:
# RoBERTa tokenizer loaded from the local "graphcodebert" directory; every
# encode/decode call in this notebook goes through this global.
tokenizer = RobertaTokenizer.from_pretrained("D:/BERT_learing/code_comment_inconsistency_detection/graphcodebert")

In [None]:
# Inspect the embedding sub-module of the loaded model (cell output only).
model.roberta.embeddings

In [None]:
# Inspect the first encoder layer of the loaded model (cell output only).
model.roberta.encoder.layer[0]

In [None]:
# Layer that Layer Integrated Gradients is attached to further down.
input_embeddings = model.get_input_embeddings()

In [None]:
# NOTE: predict and squad_pos_forward_func could be merged into one function.
def predict(inputs, position_ids=None, attention_mask=None):
    """Run the global ``model`` and summarise its softmax output.

    Args:
        inputs: Passed to ``model`` as its first positional argument.
        position_ids: Forwarded to ``model`` as a keyword argument.
        attention_mask: Forwarded to ``model`` as a keyword argument.

    Returns:
        A tuple ``(top_prob, top_class, probs)``:
        the largest class probability of each row, the index of that class,
        and the full softmax (over dim 1) of the logits.
    """
    logits = model(inputs, position_ids=position_ids, attention_mask=attention_mask).logits
    probs = nn.functional.softmax(logits, dim=1)
    top_prob = probs.max(1).values
    top_class = torch.argmax(probs, dim=-1)
    return top_prob, top_class, probs

In [None]:
def squad_pos_forward_func(inputs, position_ids=None, attention_mask=None, position=0):
    """Forward function for Captum: return only the top class probability.

    Args:
        inputs: Forwarded to ``predict``.
        position_ids: Forwarded to ``predict``.
        attention_mask: Forwarded to ``predict``.
        position: Accepted for signature compatibility; not used.

    Returns:
        The first element of ``predict``'s result (largest class probability
        per row).
    """
    top_prob, _, _ = predict(inputs, position_ids=position_ids, attention_mask=attention_mask)
    return top_prob

In [None]:
# Special-token ids read from the tokenizer; the pad id doubles as the
# baseline ("reference") token for Integrated Gradients.
# NOTE(review): for a RoBERTa vocabulary these are commonly <s>=0, <pad>=1,
# </s>=2 (not the BERT-style 0/101/102) — confirm against the cell output.
ref_token_id = tokenizer.pad_token_id
sep_token_id = tokenizer.sep_token_id
cls_token_id = tokenizer.cls_token_id
ref_token_id,sep_token_id,cls_token_id

In [None]:
# Mind the sequence length.
def truncate(ids, len_tru=512):
    """Cut ``ids`` down to at most ``len_tru`` leading elements.

    Args:
        ids: Any sliceable sequence.
        len_tru: Maximum length to keep.

    Returns:
        ``ids`` itself when it is not longer than ``len_tru``; otherwise a
        slice holding its first ``len_tru`` elements.
    """
    if len(ids) <= len_tru:
        return ids
    return ids[:len_tru]

In [None]:
# This handles a single sample; dataset-level processing still needs thought.
def construct_input_ref_pair(comment, AST_type, ref_token_id, sep_token_id, cls_token_id):
    """Build the model input ids and the matching all-reference baseline ids.

    Both texts are tokenised without special tokens (each capped at 512
    tokens) and laid out as ``cls + comment + sep + AST_type + sep``. The
    baseline keeps the special tokens and replaces every other position with
    ``ref_token_id``. Both sequences are then cut to 512 ids.

    NOTE(review): when the combined sequence exceeds 512 ids the final
    ``sep`` token is cut off together with the overflow.

    Args:
        comment: First text segment.
        AST_type: Second text segment (the caller in this notebook passes raw
            code).
        ref_token_id: Id used for baseline positions.
        sep_token_id: Separator token id.
        cls_token_id: Start token id.

    Returns:
        ``(input_ids, ref_input_ids, n_comment_tokens)`` where the first two
        are tensors of shape ``(1, seq_len)`` on ``device``.
    """
    encode_options = dict(add_special_tokens=False, truncation=True, max_length=512)
    first_ids = tokenizer.encode(comment, **encode_options)
    second_ids = tokenizer.encode(AST_type, **encode_options)

    full_ids = [cls_token_id, *first_ids, sep_token_id, *second_ids, sep_token_id]
    baseline_ids = (
        [cls_token_id]
        + [ref_token_id] * len(first_ids)
        + [sep_token_id]
        + [ref_token_id] * len(second_ids)
        + [sep_token_id]
    )

    full_ids = truncate(full_ids)
    baseline_ids = truncate(baseline_ids)
    return (
        torch.tensor([full_ids], device=device),
        torch.tensor([baseline_ids], device=device),
        len(first_ids),
    )

def construct_input_ref_token_type_pair(input_ids, sep_ind=0):
    """Build segment ids for an input and an all-zero baseline.

    Args:
        input_ids: Tensor whose second dimension is the sequence length.
        sep_ind: Last position (inclusive) that belongs to segment 0.

    Returns:
        ``(token_type_ids, ref_token_type_ids)``: a ``(1, seq_len)`` tensor
        that is 0 up to and including ``sep_ind`` and 1 afterwards, and a
        zero tensor of the same shape, both on ``device``.
    """
    n_tokens = input_ids.size(1)
    segment_flags = [0 if pos <= sep_ind else 1 for pos in range(n_tokens)]
    token_type_ids = torch.tensor([segment_flags], device=device)
    ref_token_type_ids = torch.zeros_like(token_type_ids, device=device)
    return token_type_ids, ref_token_type_ids

def construct_input_ref_pos_id_pair(input_ids):
    """Build position ids 0..n-1 for an input and all-zero baseline positions.

    NOTE(review): Hugging Face RoBERTa derives its default position ids
    starting at ``padding_idx + 1``; passing 0..n-1 differs from that default
    — confirm this matches how the model was trained.

    Args:
        input_ids: Tensor whose second dimension is the sequence length.

    Returns:
        ``(position_ids, ref_position_ids)``, each expanded to the shape of
        ``input_ids`` and placed on ``device``.
    """
    n_tokens = input_ids.size(1)
    running = torch.arange(n_tokens, dtype=torch.long, device=device).unsqueeze(0)
    # A random permutation (`torch.randperm(n_tokens, device=device)`) would be
    # another possible baseline.
    zeros = torch.zeros(n_tokens, dtype=torch.long, device=device).unsqueeze(0)
    return running.expand_as(input_ids), zeros.expand_as(input_ids)

def construct_attention_mask(input_ids):
    """Return an all-ones mask with the same shape and dtype as ``input_ids``."""
    mask = torch.ones_like(input_ids)
    return mask

def construct_whole_bert_embeddings(input_ids, ref_input_ids,
                                    token_type_ids=None, ref_token_type_ids=None,
                                    position_ids=None, ref_position_ids=None):
    """Embed the input and the baseline with the model's embedding module.

    Args:
        input_ids: Token ids of the real input.
        ref_input_ids: Token ids of the baseline.
        token_type_ids: Optional segment ids for the input.
        ref_token_type_ids: Optional segment ids for the baseline.
        position_ids: Optional position ids for the input.
        ref_position_ids: Optional position ids for the baseline.

    Returns:
        ``(input_embeddings, ref_input_embeddings)`` as produced by
        ``model.roberta.embeddings``.
    """
    embed = model.roberta.embeddings
    embedded_input = embed(input_ids, token_type_ids=token_type_ids, position_ids=position_ids)
    embedded_ref = embed(ref_input_ids, token_type_ids=ref_token_type_ids, position_ids=ref_position_ids)
    return embedded_input, embedded_ref

In [None]:
# Top-k most and least attributed entries (words, positions or token types).
def get_topk_attributed_tokens(attrs, all_token_t, k=5):
    """Return the ``k`` highest- and ``k`` lowest-scoring tokens.

    Args:
        attrs: 1-D tensor of attribution scores.
        all_token_t: Sequence of tokens, indexable by the positions of
            ``attrs``.
        k: Number of entries to return on each side. It is clamped to the
            number of scores so short inputs no longer make ``torch.topk``
            raise.

    Returns:
        ``(top_tokens_max, values_max, top_tokens_min, values_min)``: the
        tokens with the largest scores (descending) and their scores, then
        the tokens with the smallest scores (ascending) and their scores.
    """
    # torch.topk raises when k exceeds the size of the last dimension.
    k = min(k, attrs.size(-1))
    values_max, indices_max = torch.topk(attrs, k)
    values_min, indices_min = torch.topk(attrs, k, largest=False)
    top_tokens_max = [all_token_t[pos] for pos in indices_max.tolist()]
    top_tokens_min = [all_token_t[pos] for pos in indices_min.tolist()]
    return top_tokens_max, values_max, top_tokens_min, values_min


In [None]:
import re
import string

def split_punctuation(s):
    """Insert a space at every boundary between a word character and punctuation.

    A boundary is any position with a ``\\w`` character on one side and a
    character from the punctuation class on the other. ``_`` belongs to both
    groups, so it is separated from its neighbours as well. The markers
    ``<s>`` and ``</s>`` are re-joined after the split.

    NOTE(review): inside the character class the backslash of
    ``string.punctuation`` escapes the following ``]``, so the backslash
    itself is not treated as punctuation.

    Args:
        s: Text to process.

    Returns:
        The text with the extra spaces inserted.
    """
    punct = string.punctuation
    boundary = r'(?<=\w)(?=[{}])|(?<=[{}])(?=\w)'.format(punct, punct)
    # Every match is zero-width, so substituting a space inserts one at each
    # boundary without consuming any character.
    spaced = re.sub(boundary, ' ', s)
    spaced = spaced.replace("< s >", "<s>")
    return spaced.replace("</ s >", "</s>")


In [None]:
# Merge the sub-word tokens in all_tokens back into whole words and sum their
# attributions.
def get_restore_words(code,comment,input_ids,all_tokens,attribution_num):
    """Regroup per-token attributions into per-word attributions.

    The comment and code are re-encoded and decoded to build a reference
    string ``<cls> comment <sep> code <sep>``. That string is cut to the
    length of the decoded ``input_ids`` plus 4, punctuation-separated with
    ``split_punctuation`` and whitespace-split into "words". Consecutive
    entries of ``all_tokens`` are then concatenated until they spell each
    word, and the attributions of the tokens forming a word are summed.

    Args:
        code: Raw code text (placed after the first separator).
        comment: Raw comment text (placed before the first separator).
        input_ids: Token ids of one sample, passed to ``tokenizer.decode``.
        all_tokens: One decoded string per token.
        attribution_num: Per-token attribution scores, sliced in token order.

    Returns:
        ``(words, attribute_sum)``: the list of words and a list with one
        summed attribution per word.

    NOTE(review): if the tokens never concatenate to exactly the expected
    word, the inner ``while`` loop keeps advancing and eventually indexes
    past the end of the token list (IndexError).
    """
    all_tokens_decode = tokenizer.decode(input_ids)
    # NOTE(review): the +4 presumably compensates for the four separator
    # spaces added when the reference string is assembled below — confirm.
    len_all_tokens_decode = len(all_tokens_decode) + 4
    # Strip the leading space the tokenizer attaches to tokens, e.g. ' a' -> 'a'.
    all_tokens_clean = []
    for token in all_tokens:
        s_without_leading_space = token.lstrip()
        all_tokens_clean.append(s_without_leading_space)


    # Build the reference string code_comment_baseline.
    code = tokenizer.encode(code, add_special_tokens=False,truncation=True,max_length=512)
    comment = tokenizer.encode(comment, add_special_tokens=False,truncation=True,max_length=512)
    code_decode = tokenizer.decode(code)
    comment_decode = tokenizer.decode(comment)
    
    code_comment_baseline = tokenizer.decode(tokenizer.cls_token_id) + ' '+ comment_decode \
                            + ' '+ tokenizer.decode(tokenizer.sep_token_id) + ' ' + code_decode \
                            + ' ' + tokenizer.decode(tokenizer.sep_token_id)
    
    # Cut by characters, then separate punctuation and split into words.
    code_comment_baseline = code_comment_baseline[:len_all_tokens_decode]
    code_comment_baseline = split_punctuation(code_comment_baseline)
    code_comment_baseline = code_comment_baseline.split()


    # times[j] = number of consecutive tokens that make up word j; used below
    # to sum the attributions per word.
    # (The local name `time` shadows the `time` module inside this function.)
    times = []
    token_index = 0
    for code_comment in code_comment_baseline:
        temp = ''
        time = 0
        while temp != code_comment:
            temp = temp + all_tokens_clean[token_index]
            token_index = token_index + 1
            time = time + 1
        times.append(time)
    
    # Sum the attributions of the tokens belonging to each word.
    attribute_sum = []
    start = 0
    for time in times:
        end = start + time
        attribute = sum(attribution_num[start:end])
        attribute_sum.append(attribute)
        start = end
    return code_comment_baseline ,attribute_sum 

In [None]:
def get_line_code_and_attribute(code,code_all_tokens,attributions_num):
    """Group word-level attributions into one score per source line.

    The code is punctuation-separated, each line is stripped and
    whitespace-normalised, and empty lines are dropped. Words from
    ``code_all_tokens`` are then consumed in order, joined with single
    spaces, until the joined text equals the current line; the attributions
    of the consumed words are added up for that line.

    Args:
        code: Raw code text.
        code_all_tokens: Words of the code part, in order.
        attributions_num: One attribution per entry of ``code_all_tokens``.

    Returns:
        ``(code_lineList, attribute)``: the normalised non-empty lines and one
        attribution value per line.

    NOTE(review): once the words are exhausted, the remaining lines reuse the
    ``attr`` of the last processed line (and ``attr`` is undefined if there
    are no words at all). If a line never matches, the loop consumes every
    remaining word into that line.
    """
    # Put spaces around punctuation so lines can be compared with the words.
    code = split_punctuation(code)
    
    # Drop leading whitespace on every line to ease word-to-line matching.
    code = [line.lstrip() for line in code.splitlines()]
    code = '\n'.join(code)  
    code_lineList = code.split('\n')
    # Collapse runs of whitespace inside each line to single spaces.
    code_lineList = [' '.join(x.split()) for x in code_lineList]
    # Drop empty lines.
    code_lineList = [item for item in code_lineList if item != '']
    
    attribute = []
    i = 0  
    for code_line in code_lineList:
        # Start the line with the next unconsumed word, if any is left.
        if i < len(code_all_tokens):
            temp = code_all_tokens[i]
            attr = attributions_num[i]
            i = i + 1
        # Keep appending words until the joined text equals the line.
        while((i < len(code_all_tokens)) and(temp != code_line)):
            attr = attr + attributions_num[i]
            temp = temp + ' ' + code_all_tokens[i]
            i = i + 1

        attribute.append(attr)
        
    # Lines consisting of a single character get an attribution of zero.
    for i in range(len(code_lineList)):
        if len(code_lineList[i])== 1:
            attribute[i] = torch.tensor(0, dtype=torch.float64)
            
    return code_lineList,attribute   

In [None]:
def remove_before_and_including(lst, element):
    """Return what follows the first ``element``, minus the next ``element``.

    Everything up to and including the first occurrence of ``element`` is
    dropped; if ``element`` occurs again in the remainder, its first
    occurrence there is removed too.

    Args:
        lst: List to cut.
        element: Marker value.

    Returns:
        A new list as described, or ``lst`` itself when ``element`` does not
        occur in it.
    """
    if element not in lst:
        return lst
    tail = lst[lst.index(element) + 1:]
    if element in tail:
        tail.remove(element)
    return tail
    
def remove_after_including(lst, element):
    """Keep the list up to and including the first ``element``.

    Args:
        lst: List to cut.
        element: Marker value.

    Returns:
        A new list ending with the first occurrence of ``element``, or ``lst``
        itself when ``element`` does not occur in it.
    """
    if element not in lst:
        return lst
    stop = lst.index(element) + 1
    return lst[:stop]

In [None]:
def final_all_attribute(code, all_tokens, attributions_num):
    """Combine word-level comment attributions with line-level code attributions.

    The word list is split at the first ``'</s>'``: everything up to and
    including it is kept word by word, while the rest (the code part) is
    regrouped into source lines by ``get_line_code_and_attribute``.

    NOTE(review): the code slice ``[first_sep + 1:-1]`` always drops the last
    attribution, i.e. it assumes the sequence ends with a closing ``'</s>'``.

    Args:
        code: Raw code text.
        all_tokens: Words of the whole input, containing at least one
            ``'</s>'`` (otherwise ``list.index`` raises ValueError).
        attributions_num: 1-D tensor with one attribution per word.

    Returns:
        ``(new_all_tokens, attributions_num_all, code_lineList_token,
        attribute_num_code)``: comment words followed by code lines, their
        attributions as one tensor, the code lines alone, and the code-line
        attributions alone.
    """
    separator = '</s>'
    # Drop "<s> comment </s>" and the closing "</s>" to get the code words.
    code_all_tokens = remove_before_and_including(all_tokens, separator)
    comment_all_tokens = remove_after_including(all_tokens, separator)

    split_at = all_tokens.index(separator) + 1
    attribute_comment = attributions_num[:split_at]
    attribute_code = attributions_num[split_at:-1]

    code_lineList_token, per_line_scores = get_line_code_and_attribute(
        code, code_all_tokens, attribute_code
    )
    attribute_num_code = torch.stack(per_line_scores)

    new_all_tokens = comment_all_tokens + code_lineList_token
    attributions_num_all = torch.cat((attribute_comment, attribute_num_code))
    return new_all_tokens, attributions_num_all, code_lineList_token, attribute_num_code

In [None]:
from captum.attr import LayerIntegratedGradients
from captum.attr import visualization as viz

# Layer Integrated Gradients over the model's input-embedding layer; the
# quantity being explained is whatever squad_pos_forward_func returns.
lig = LayerIntegratedGradients(squad_pos_forward_func,input_embeddings)

# Accumulates the visualization records (re-created before each driver loop).
vis_data_records_ig = []

def interpret_sentence(code, comment, old_code, input_ids, ref_input_ids, token_type_ids,
                       position_ids, attention_mask, all_tokens, ground_lable, vis_data_records):
    """Attribute one sample with Layer Integrated Gradients and record it.

    Per-token attributions are summed over the embedding dimension,
    normalised by their norm, merged into words, regrouped into code lines
    and appended to ``vis_data_records`` as a visualization record. The word
    list is printed along the way.

    Args:
        code: Raw code text.
        comment: Raw comment text.
        old_code: Returned unchanged; not otherwise used.
        input_ids: Token ids of the sample.
        ref_input_ids: Baseline token ids.
        token_type_ids: Not used.
        position_ids: Position ids forwarded to the model.
        attention_mask: Attention mask forwarded to the model.
        all_tokens: One decoded string per token.
        ground_lable: True label stored in the visualization record.
        vis_data_records: List that receives the new record.

    Returns:
        ``(top_tokens_max, values_max, top_tokens_min, values_min, code,
        old_code, code_lineList_token, attribute_num_code)``.
    """
    top_prob, predicted, _ = predict(input_ids, position_ids=position_ids, attention_mask=attention_mask)
    sen_type = 'pos' if predicted == 1 else 'nag'
    # Keep three decimals of the winning probability.
    pre = float("{:.3f}".format(top_prob.item()))

    attributions_ig, delta_ig = lig.attribute(
        input_ids,
        baselines=ref_input_ids,
        additional_forward_args=(position_ids, attention_mask, 0),
        return_convergence_delta=True,
        internal_batch_size=8,
    )

    token_scores = attributions_ig.sum(dim=2).squeeze(0)
    token_scores = token_scores / torch.norm(token_scores)

    # Merge sub-word tokens into whole words.
    words, word_scores = get_restore_words(code, comment, input_ids[0], all_tokens, token_scores)
    word_scores = torch.tensor(word_scores)
    print(words)

    # code_lineList_token and attribute_num_code feed later statistics.
    new_all_tokens, attributions_num_all, code_lineList_token, attribute_num_code = \
        final_all_attribute(code, words, word_scores)

    add_attributions_to_visualizer(attributions_num_all, new_all_tokens, pre, ground_lable,
                                   sen_type, delta_ig, vis_data_records)

    top_tokens_max, values_max, top_tokens_min, values_min = \
        get_topk_attributed_tokens(attributions_num_all, new_all_tokens)

    return (top_tokens_max, values_max, top_tokens_min, values_min,
            code, old_code, code_lineList_token, attribute_num_code)

def add_attributions_to_visualizer(attributions, all_tokens, pre, ground_lable, sen_type, delta, vis_data_records):
    """Append one Captum visualization record to ``vis_data_records``.

    The predicted-class field now receives the class name ``sen_type``
    instead of repeating the probability ``pre``, in line with the later
    definition of this function in this notebook.

    Args:
        attributions: Tensor of attribution scores, one per entry of
            ``all_tokens``.
        all_tokens: Tokens (words / code lines) to display.
        pre: Predicted probability.
        ground_lable: True label.
        sen_type: Predicted class name; also used as the attribution label.
        delta: Convergence delta reported by the attribution method.
        vis_data_records: List that receives the new record.
    """
    # Store a couple of samples in a list for visualization purposes.
    vis_data_records.append(viz.VisualizationDataRecord(
                            attributions,
                            pre,                  # predicted probability
                            sen_type,             # predicted class
                            ground_lable,         # true class
                            sen_type,             # attribution class
                            attributions.sum(),   # total attribution score
                            all_tokens,
                            delta))


In [None]:
def interpret_sentence_2(code,comment,input_ids,ref_input_ids, token_type_ids,\
                         position_ids, attention_mask, all_tokens, ground_lable,vis_data_records):
    """Attribute one sample with Layer Integrated Gradients, tolerating alignment failures.

    Per-token attributions are summed over the embedding dimension,
    normalised by their norm, merged into words, regrouped into code lines
    and appended to ``vis_data_records``. If the word/line alignment fails,
    the error is reported and a tuple of ``None`` placeholders is returned
    instead of raising.

    Args:
        code: Raw code text.
        comment: Raw comment text.
        input_ids: Token ids of the sample.
        ref_input_ids: Baseline token ids.
        token_type_ids: Not used.
        position_ids: Position ids forwarded to the model.
        attention_mask: Attention mask forwarded to the model.
        all_tokens: One decoded string per token.
        ground_lable: True label stored in the visualization record.
        vis_data_records: List that receives the new record on success.

    Returns:
        ``(top_tokens_max, values_max, top_tokens_min, values_min, code,
        code_lineList_token, attribute_num_code)`` on success, or a 7-tuple
        of ``None`` when post-processing fails.
    """
    pre, out, _ = predict(input_ids, position_ids=position_ids, attention_mask=attention_mask)
    sen_type = 'Inconsistency' if out == 1 else 'Consistency'
    # Keep three decimals of the winning probability.
    pre = float("{:.3f}".format(pre.item()))

    attributions_ig, delta_ig = lig.attribute(input_ids, baselines=ref_input_ids,
                           additional_forward_args=(position_ids, attention_mask, 0),
                           return_convergence_delta=True, internal_batch_size=8)

    attributions = attributions_ig.sum(dim=2).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    try:
        # Merge sub-word tokens into whole words.
        all_tokens, attributions = get_restore_words(code, comment, input_ids[0], all_tokens, attributions)
        attributions = torch.tensor(attributions)

        # code_lineList_token and attribute_num_code feed later statistics.
        new_all_tokens, attributions_num_all, code_lineList_token, attribute_num_code = \
            final_all_attribute(code, all_tokens, attributions)

        add_attributions_to_visualizer(attributions_num_all, new_all_tokens, pre, ground_lable,
                                       sen_type, delta_ig, vis_data_records)

        top_tokens_max, values_max, top_tokens_min, values_min = \
            get_topk_attributed_tokens(attributions_num_all, new_all_tokens)

        return top_tokens_max, values_max, top_tokens_min, values_min, \
               code, code_lineList_token, attribute_num_code
    except Exception as e:
        # Best effort: a sample whose tokens cannot be aligned is skipped, but
        # the cause is reported instead of being silently discarded.
        print("解析错误")
        print("{}: {}".format(type(e).__name__, e))
        # Explicit placeholders; previously the softmax tensor bound to `_`
        # was returned seven times, which could be mistaken for real results.
        return (None,) * 7

def add_attributions_to_visualizer(attributions, all_tokens, pre, ground_lable, sen_type, delta, vis_data_records):
    """Append one Captum visualization record to ``vis_data_records``.

    Args:
        attributions: Tensor of attribution scores, one per entry of
            ``all_tokens``.
        all_tokens: Tokens (words / code lines) to display.
        pre: Predicted probability.
        ground_lable: True label.
        sen_type: Predicted class name; also used as the attribution label.
        delta: Convergence delta reported by the attribution method.
        vis_data_records: List that receives the new record.
    """
    # Store a couple of samples in a list for visualization purposes.
    record = viz.VisualizationDataRecord(
        attributions,
        pre,
        sen_type,
        ground_lable,
        sen_type,
        attributions.sum(),
        all_tokens,
        delta,
    )
    vis_data_records.append(record)


In [None]:
def input_data_list(AST_list, comment_list):
    """Build the model inputs for every (code, comment) pair.

    Args:
        AST_list: Second-segment texts (the callers in this notebook pass raw
            code strings).
        comment_list: Comment texts, indexed in parallel with ``AST_list``.

    Returns:
        Six lists with one entry per sample, in this order: input ids,
        baseline input ids, position ids, attention masks, token type ids and
        the per-token decoded strings.
    """
    input_ids_all, ref_input_ids_all = [], []
    position_ids_all, attention_mask_all = [], []
    token_type_ids_all, all_tokens_all = [], []

    for sample_no, second_text in enumerate(AST_list):
        input_ids, ref_input_ids, comment_len = construct_input_ref_pair(
            comment_list[sample_no], second_text, ref_token_id, sep_token_id, cls_token_id
        )
        # The baseline variants of the token-type / position ids are not needed here.
        token_type_ids, _ = construct_input_ref_token_type_pair(input_ids, comment_len)
        position_ids, _ = construct_input_ref_pos_id_pair(input_ids)
        attention_mask = construct_attention_mask(input_ids)

        # Decode each id on its own so tokens stay aligned with attributions.
        all_tokens = [tokenizer.decode([token_id]) for token_id in input_ids[0].detach().tolist()]

        input_ids_all.append(input_ids)
        ref_input_ids_all.append(ref_input_ids)
        position_ids_all.append(position_ids)
        attention_mask_all.append(attention_mask)
        token_type_ids_all.append(token_type_ids)
        all_tokens_all.append(all_tokens)

    return (input_ids_all, ref_input_ids_all, position_ids_all,
            attention_mask_all, token_type_ids_all, all_tokens_all)

In [None]:
def remove_other_index(df, index):
    """Select the rows labelled ``index`` and remember their original labels.

    Args:
        df: Source DataFrame.
        index: Row labels to keep, in the desired order.

    Returns:
        A new DataFrame with only those rows, re-indexed from 0, plus a
        ``raw_data_index_column`` column holding the original labels.
    """
    subset = df.loc[index].reset_index(drop=True)
    subset["raw_data_index_column"] = index
    return subset

In [None]:
def retrieve_train_data():
    """Load the three training splits and stack them into a single DataFrame.

    Reads the ``param``, ``return`` and ``summary`` ``train.json`` files from
    the hard-coded data directory and concatenates them row-wise in the order
    summary, param, return.

    Returns:
        The concatenated DataFrame with a fresh 0..n-1 index.
    """
    param_part = pd.read_json("D:/BERT_learing/code_comment_inconsistency_detection/data/param/train.json")
    return_part = pd.read_json("D:/BERT_learing/code_comment_inconsistency_detection/data/return/train.json")
    summary_part = pd.read_json("D:/BERT_learing/code_comment_inconsistency_detection/data/summary/train.json")
    stacked = pd.concat([summary_part, param_part, return_part], axis=0)
    return stacked.reset_index(drop=True)
# Load the combined training set.
train_df = retrieve_train_data()

In [None]:
def remove_last(df):
    """Strip trailing periods from every ``old_comment_raw`` entry.

    ``str.rstrip('.')`` removes *all* trailing ``.`` characters, so
    ``"wait..."`` becomes ``"wait"``.

    Args:
        df: DataFrame with a string column ``old_comment_raw``.

    Returns:
        The same DataFrame object, with the column rewritten in place.
    """
    # Assign the whole column at once. The previous element-wise chained
    # assignment (``df[col][i] = ...``) is not guaranteed to write through to
    # ``df`` and also required the index to be exactly 0..n-1.
    df['old_comment_raw'] = [comment.rstrip('.') for comment in df['old_comment_raw']]
    return df
# Normalise the training comments by dropping their trailing periods.
train_df = remove_last(train_df)

In [None]:
# Hand-picked rows of the *training* frame (the `test_` prefix of the variable
# below notwithstanding) to explain and visualize.
list_false_index_ast = [610,716, 363]
test_df_true_set = remove_other_index(train_df,list_false_index_ast)
code_list = list(test_df_true_set['new_code_raw'])
comment_list = list(test_df_true_set['old_comment_raw'])
ground_lable = list(test_df_true_set['label'])

# Tokenise every selected (code, comment) pair.
input_ids_all,ref_input_ids_all,position_ids_all,\
attention_mask_all,token_type_ids_all,all_tokens_all= input_data_list(code_list,comment_list)

# Attribute each sample. The unpacked return values are overwritten on every
# iteration; only the records collected in vis_data_records_ig are used below.
vis_data_records_ig = []
for i in range(len(code_list)):  
    top_tokens_max, values_max, top_tokens_min,values_min,code,_,_ \
    = interpret_sentence_2(code_list[i],comment_list[i],input_ids_all[i],ref_input_ids_all[i],\
                           token_type_ids_all[i], position_ids_all[i], attention_mask_all[i], all_tokens_all[i],\
                           ground_lable[i],vis_data_records_ig)
    
    
print('Visualize attributions based on Integrated Gradients')
_ = viz.visualize_text(vis_data_records_ig)

In [None]:
# Same pipeline as the previous cell, applied to row 65 of the test frame.
list_false_index_ast = [65]
test_df_true_set = remove_other_index(test_df,list_false_index_ast)
code_list = list(test_df_true_set['new_code_raw'])
comment_list = list(test_df_true_set['old_comment_raw'])
ground_lable = list(test_df_true_set['label'])

# Tokenise the selected (code, comment) pair.
input_ids_all,ref_input_ids_all,position_ids_all,\
attention_mask_all,token_type_ids_all,all_tokens_all= input_data_list(code_list,comment_list)

# Attribute each sample. The unpacked return values are overwritten on every
# iteration; only the records collected in vis_data_records_ig are used below.
vis_data_records_ig = []
for i in range(len(code_list)):  
    top_tokens_max, values_max, top_tokens_min,values_min,code,_,_ \
    = interpret_sentence_2(code_list[i],comment_list[i],input_ids_all[i],ref_input_ids_all[i],\
                           token_type_ids_all[i], position_ids_all[i], attention_mask_all[i], all_tokens_all[i],\
                           ground_lable[i],vis_data_records_ig)
    
    
print('Visualize attributions based on Integrated Gradients')
_ = viz.visualize_text(vis_data_records_ig)