In [1]:
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
from transformers import BertTokenizer,BertModel
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import AdamW, get_linear_schedule_with_warmup
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, TensorDataset


from captum.attr import visualization as viz
from captum.attr import LayerConductance, LayerIntegratedGradients
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer

# 解决服务器挂掉的问题
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

In [2]:
# Run on CPU deliberately (the checkpoint below is also mapped to CPU).
# To use a GPU when one is available, switch to:
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device( "cpu")

In [3]:
# Hyper-parameters of the original training setup.
# NOTE(review): the model's position-embedding table has 512 rows (see the model repr), so
# sequences longer than 512 tokens cannot be fed to it; MAX_LEN = 1024 is not referenced in
# the visible cells — confirm what it is meant to control.
MAX_LEN = 1024
MAX_EPOCHS = 10
BATCH_SIZE = 4
LEARNING_RATE = 1e-5
NUM_CLASSES = 2
WEIGHT_DECAY = 1e-3
# Backward-compatible alias for the original (misspelled) name.
WEIGTH_DECAY = WEIGHT_DECAY

In [4]:
# Root of the code/comment inconsistency dataset. Each comment type (summary, param, return)
# has its own sub-directory holding train.json / valid.json.
DATA_ROOT = "D:/BERT_learing/code_comment_inconsistency_detection/data"

# Concatenation order of the comment-type sub-directories (summary first).
_COMMENT_TYPES = ("summary", "param", "return")


def _retrieve_split(split, data_dir):
    """Load ``<data_dir>/<type>/<split>.json`` for every comment type and stack them row-wise.

    Frames are concatenated in the order summary, param, return, and the index is reset
    so it runs 0..n-1 without duplicates (later cells index rows with ``.loc[i, ...]``).
    """
    frames = [pd.read_json(f"{data_dir}/{comment_type}/{split}.json") for comment_type in _COMMENT_TYPES]
    return pd.concat(frames, axis=0).reset_index(drop=True)


def retrieve_train_data(data_dir=DATA_ROOT):
    """Return the combined training split (summary + param + return) found under ``data_dir``."""
    return _retrieve_split("train", data_dir)


def retrieve_valid_data(data_dir=DATA_ROOT):
    """Return the combined validation split (summary + param + return) found under ``data_dir``."""
    return _retrieve_split("valid", data_dir)

In [5]:
# Load the combined (summary + param + return) training and validation splits.
train_df = retrieve_train_data()
valid_df = retrieve_valid_data()
train_df.head()

Unnamed: 0,id,label,comment_type,old_comment_raw,old_comment_subtokens,new_comment_raw,new_comment_subtokens,span_minimal_diff_comment_subtokens,old_code_raw,old_code_subtokens,new_code_raw,new_code_subtokens,span_diff_code_subtokens,token_diff_code_subtokens
0,grails-plugins_grails-plugin-converters-5-Asso...,1,Summary,Parses the given JSON and returns ether a JSON...,"[parses, the, given, json, and, returns, ether...",Parses the given JSON and returns either a JSO...,"[parses, the, given, json, and, returns, eithe...","[<REPLACE_OLD>, ether, <REPLACE_NEW>, either, ...",public static JSONElement parse(InputStrea...,"[public, static, jsonelement, parse, (, input,...",public static JSONElement parse(InputStrea...,"[public, static, jsonelement, parse, (, input,...","[<KEEP>, public, static, jsonelement, parse, (...","[<KEEP>, public, <KEEP>, static, <KEEP>, jsone..."
1,jitsi_jitsi-4343-FirstSentence-0,0,Summary,Loads an image from a given image identifier.,"[loads, an, image, from, a, given, image, iden...",Loads an image from a given image identifier.,"[loads, an, image, from, a, given, image, iden...",[],public static byte[] getImageInBytes(Strin...,"[public, static, byte, [, ], get, image, in, b...",public static byte[] getImageInBytes(Strin...,"[public, static, byte, [, ], get, image, in, b...","[<KEEP>, public, static, byte, [, ], get, imag...","[<KEEP>, public, <KEEP>, static, <KEEP>, byte,..."
2,dropwizard_metrics-26-Associations-FirstSentence,1,Summary,Creates a new CounterMetric and registers it ...,"[creates, a, new, counter, metric, and, regist...",Creates a new com.yammer.metrics.core.Counter...,"[creates, a, new, com, ., yammer, ., metrics, ...","[<INSERT_OLD_KEEP_BEFORE>, new, <INSERT_NEW_KE...",public static CounterMetric newCounter(Cla...,"[public, static, counter, metric, new, counter...",public static Counter newCounter(Class<?> ...,"[public, static, counter, new, counter, (, cla...","[<KEEP>, public, static, counter, <KEEP_END>, ...","[<KEEP>, public, <KEEP>, static, <KEEP>, count..."
3,google_ExoPlayer-92-FirstSentence-0,0,Summary,Derives a sample format corresponding to a giv...,"[derives, a, sample, format, corresponding, to...",Derives a sample format corresponding to a giv...,"[derives, a, sample, format, corresponding, to...",[],private static Format getSampleFormat(Format...,"[private, static, format, get, sample, format,...",private static Format getSampleFormat(Format...,"[private, static, format, get, sample, format,...","[<KEEP>, private, static, format, get, sample,...","[<KEEP>, private, <KEEP>, static, <KEEP>, form..."
4,slachiewicz_orekit-main-661-Associations-First...,1,Summary,Revert a rotation/rotation rate pair.,"[revert, a, rotation, /, rotation, rate, pair, .]",Revert a rotation/rotation rate/ rotation acce...,"[revert, a, rotation, /, rotation, rate, /, ro...","[<REPLACE_OLD>, pair, <REPLACE_NEW>, /, rotati...",public AngularCoordinates revert() {\n ...,"[public, angular, coordinates, revert, (, ), {...",public AngularCoordinates revert() {\n ...,"[public, angular, coordinates, revert, (, ), {...","[<KEEP>, public, angular, coordinates, revert,...","[<KEEP>, public, <KEEP>, angular, <KEEP>, coor..."


In [6]:
# Peek at the first validation rows.
valid_df.head()

Unnamed: 0,id,label,comment_type,old_comment_raw,old_comment_subtokens,new_comment_raw,new_comment_subtokens,span_minimal_diff_comment_subtokens,old_code_raw,old_code_subtokens,new_code_raw,new_code_subtokens,span_diff_code_subtokens,token_diff_code_subtokens
0,todoroo_astrid-987-FirstSentence-0,1,Summary,Return SQL selector query for getting tasks wi...,"[return, sql, selector, query, for, getting, t...",Return SQL selector query for getting tasks wi...,"[return, sql, selector, query, for, getting, t...","[<INSERT_OLD_KEEP_BEFORE>, tag, <INSERT_NEW_KE...",public QueryTemplate queryTemplate(Cri...,"[public, query, template, query, template, (, ...",public static QueryTemplate queryTempl...,"[public, static, query, template, query, templ...","[<KEEP>, public, <KEEP_END>, <INSERT>, static,...","[<KEEP>, public, <INSERT>, static, <KEEP>, que..."
1,Red5_red5_server-43-FirstSentence-0,0,Summary,Return period of ghost connections cleanup tas...,"[return, period, of, ghost, connections, clean...",Return period of ghost connections cleanup tas...,"[return, period, of, ghost, connections, clean...",[],public int getGhostConnsCleanupPeriod() {\...,"[public, int, get, ghost, conns, cleanup, peri...",public int getGhostConnsCleanupPeriod() {\...,"[public, int, get, ghost, conns, cleanup, peri...","[<KEEP>, public, int, get, ghost, conns, clean...","[<KEEP>, public, <KEEP>, int, <KEEP>, get, <KE..."
2,nickman_Rindle-11-Associations-FirstSentence,1,Summary,Allocates an initialized and initially unlocke...,"[allocates, an, initialized, and, initially, u...",Allocates an initialized and initially unlocke...,"[allocates, an, initialized, and, initially, u...","[<INSERT_OLD_KEEP_BEFORE>, unlocked, <INSERT_N...",\tpublic static long allocateSpinLock() {\r\n\...,"[public, static, long, allocate, spin, lock, (...",\tpublic static SpinLock allocateSpinLock() {\...,"[public, static, spin, lock, allocate, spin, l...","[<KEEP>, public, static, <KEEP_END>, <REPLACE_...","[<KEEP>, public, <KEEP>, static, <REPLACE_OLD>..."
3,h2oai_h2o_2-427-FirstSentence-0,0,Summary,Rebalance a frame for load balancing,"[rebalance, a, frame, for, load, balancing]",Rebalance a frame for load balancing,"[rebalance, a, frame, for, load, balancing]",[],"private Frame reBalance(final Frame fr, bool...","[private, frame, re, balance, (, final, frame,...",private static Frame reBalance(final Frame f...,"[private, static, frame, re, balance, (, final...","[<KEEP>, private, <KEEP_END>, <INSERT>, static...","[<KEEP>, private, <INSERT>, static, <KEEP>, fr..."
4,sonatype_sonatype-aether-11-Associations-First...,1,Summary,Sets the host of this proxy.,"[sets, the, host, of, this, proxy, .]",Sets the host of the proxy.,"[sets, the, host, of, the, proxy, .]","[<REPLACE_OLD>, this, <REPLACE_NEW>, the, <REP...",public Proxy setHost( String host )\n {...,"[public, proxy, set, host, (, string, host, ),...",public Proxy setHost( String host )\n {...,"[public, proxy, set, host, (, string, host, ),...","[<KEEP>, public, proxy, set, host, (, string, ...","[<KEEP>, public, <KEEP>, proxy, <KEEP>, set, <..."


In [7]:
def format_data(df):
    """Return a copy of ``df`` with whitespace normalised in the code and comment columns.

    ``new_code_raw`` and ``old_comment_raw`` are converted to ``str`` and every run of
    whitespace (spaces, tabs, carriage returns, newlines) is collapsed to a single space,
    with leading/trailing whitespace removed. All other columns are left as they are.

    The input frame is not modified, so the raw frame stays available next to the
    cleaned one.
    """
    def _collapse_whitespace(values):
        # str.split() with no argument splits on any whitespace run (newlines included),
        # so re-joining with ' ' both flattens multi-line code and squeezes repeated spaces.
        return [' '.join(str(value).split()) for value in values]

    formatted = df.copy()
    formatted['new_code_raw'] = _collapse_whitespace(df['new_code_raw'].values)
    formatted['old_comment_raw'] = _collapse_whitespace(df['old_comment_raw'].values)
    return formatted

In [8]:
# Whitespace-normalised training data (code flattened to a single line).
train_df_clean = format_data(train_df)
train_df_clean.head()

Unnamed: 0,id,label,comment_type,old_comment_raw,old_comment_subtokens,new_comment_raw,new_comment_subtokens,span_minimal_diff_comment_subtokens,old_code_raw,old_code_subtokens,new_code_raw,new_code_subtokens,span_diff_code_subtokens,token_diff_code_subtokens
0,grails-plugins_grails-plugin-converters-5-Asso...,1,Summary,Parses the given JSON and returns ether a JSON...,"[parses, the, given, json, and, returns, ether...",Parses the given JSON and returns either a JSO...,"[parses, the, given, json, and, returns, eithe...","[<REPLACE_OLD>, ether, <REPLACE_NEW>, either, ...",public static JSONElement parse(InputStrea...,"[public, static, jsonelement, parse, (, input,...",public static JSONElement parse(InputStream is...,"[public, static, jsonelement, parse, (, input,...","[<KEEP>, public, static, jsonelement, parse, (...","[<KEEP>, public, <KEEP>, static, <KEEP>, jsone..."
1,jitsi_jitsi-4343-FirstSentence-0,0,Summary,Loads an image from a given image identifier.,"[loads, an, image, from, a, given, image, iden...",Loads an image from a given image identifier.,"[loads, an, image, from, a, given, image, iden...",[],public static byte[] getImageInBytes(Strin...,"[public, static, byte, [, ], get, image, in, b...",public static byte[] getImageInBytes(String im...,"[public, static, byte, [, ], get, image, in, b...","[<KEEP>, public, static, byte, [, ], get, imag...","[<KEEP>, public, <KEEP>, static, <KEEP>, byte,..."
2,dropwizard_metrics-26-Associations-FirstSentence,1,Summary,Creates a new CounterMetric and registers it u...,"[creates, a, new, counter, metric, and, regist...",Creates a new com.yammer.metrics.core.Counter...,"[creates, a, new, com, ., yammer, ., metrics, ...","[<INSERT_OLD_KEEP_BEFORE>, new, <INSERT_NEW_KE...",public static CounterMetric newCounter(Cla...,"[public, static, counter, metric, new, counter...",public static Counter newCounter(Class<?> klas...,"[public, static, counter, new, counter, (, cla...","[<KEEP>, public, static, counter, <KEEP_END>, ...","[<KEEP>, public, <KEEP>, static, <KEEP>, count..."
3,google_ExoPlayer-92-FirstSentence-0,0,Summary,Derives a sample format corresponding to a giv...,"[derives, a, sample, format, corresponding, to...",Derives a sample format corresponding to a giv...,"[derives, a, sample, format, corresponding, to...",[],private static Format getSampleFormat(Format...,"[private, static, format, get, sample, format,...",private static Format getSampleFormat(Format c...,"[private, static, format, get, sample, format,...","[<KEEP>, private, static, format, get, sample,...","[<KEEP>, private, <KEEP>, static, <KEEP>, form..."
4,slachiewicz_orekit-main-661-Associations-First...,1,Summary,Revert a rotation/rotation rate pair.,"[revert, a, rotation, /, rotation, rate, pair, .]",Revert a rotation/rotation rate/ rotation acce...,"[revert, a, rotation, /, rotation, rate, /, ro...","[<REPLACE_OLD>, pair, <REPLACE_NEW>, /, rotati...",public AngularCoordinates revert() {\n ...,"[public, angular, coordinates, revert, (, ), {...",public AngularCoordinates revert() { return ne...,"[public, angular, coordinates, revert, (, ), {...","[<KEEP>, public, angular, coordinates, revert,...","[<KEEP>, public, <KEEP>, angular, <KEEP>, coor..."


In [9]:
# Whitespace-normalised validation data.
valid_df_clean = format_data(valid_df)
valid_df_clean.head()

Unnamed: 0,id,label,comment_type,old_comment_raw,old_comment_subtokens,new_comment_raw,new_comment_subtokens,span_minimal_diff_comment_subtokens,old_code_raw,old_code_subtokens,new_code_raw,new_code_subtokens,span_diff_code_subtokens,token_diff_code_subtokens
0,todoroo_astrid-987-FirstSentence-0,1,Summary,Return SQL selector query for getting tasks wi...,"[return, sql, selector, query, for, getting, t...",Return SQL selector query for getting tasks wi...,"[return, sql, selector, query, for, getting, t...","[<INSERT_OLD_KEEP_BEFORE>, tag, <INSERT_NEW_KE...",public QueryTemplate queryTemplate(Cri...,"[public, query, template, query, template, (, ...",public static QueryTemplate queryTemplate(Crit...,"[public, static, query, template, query, templ...","[<KEEP>, public, <KEEP_END>, <INSERT>, static,...","[<KEEP>, public, <INSERT>, static, <KEEP>, que..."
1,Red5_red5_server-43-FirstSentence-0,0,Summary,Return period of ghost connections cleanup tas...,"[return, period, of, ghost, connections, clean...",Return period of ghost connections cleanup tas...,"[return, period, of, ghost, connections, clean...",[],public int getGhostConnsCleanupPeriod() {\...,"[public, int, get, ghost, conns, cleanup, peri...",public int getGhostConnsCleanupPeriod() { retu...,"[public, int, get, ghost, conns, cleanup, peri...","[<KEEP>, public, int, get, ghost, conns, clean...","[<KEEP>, public, <KEEP>, int, <KEEP>, get, <KE..."
2,nickman_Rindle-11-Associations-FirstSentence,1,Summary,Allocates an initialized and initially unlocke...,"[allocates, an, initialized, and, initially, u...",Allocates an initialized and initially unlocke...,"[allocates, an, initialized, and, initially, u...","[<INSERT_OLD_KEEP_BEFORE>, unlocked, <INSERT_N...",\tpublic static long allocateSpinLock() {\r\n\...,"[public, static, long, allocate, spin, lock, (...",public static SpinLock allocateSpinLock() { lo...,"[public, static, spin, lock, allocate, spin, l...","[<KEEP>, public, static, <KEEP_END>, <REPLACE_...","[<KEEP>, public, <KEEP>, static, <REPLACE_OLD>..."
3,h2oai_h2o_2-427-FirstSentence-0,0,Summary,Rebalance a frame for load balancing,"[rebalance, a, frame, for, load, balancing]",Rebalance a frame for load balancing,"[rebalance, a, frame, for, load, balancing]",[],"private Frame reBalance(final Frame fr, bool...","[private, frame, re, balance, (, final, frame,...","private static Frame reBalance(final Frame fr,...","[private, static, frame, re, balance, (, final...","[<KEEP>, private, <KEEP_END>, <INSERT>, static...","[<KEEP>, private, <INSERT>, static, <KEEP>, fr..."
4,sonatype_sonatype-aether-11-Associations-First...,1,Summary,Sets the host of this proxy.,"[sets, the, host, of, this, proxy, .]",Sets the host of the proxy.,"[sets, the, host, of, the, proxy, .]","[<REPLACE_OLD>, this, <REPLACE_NEW>, the, <REP...",public Proxy setHost( String host )\n {...,"[public, proxy, set, host, (, string, host, ),...",public Proxy setHost( String host ) { return n...,"[public, proxy, set, host, (, string, host, ),...","[<KEEP>, public, proxy, set, host, (, string, ...","[<KEEP>, public, <KEEP>, proxy, <KEEP>, set, <..."


In [10]:
# Load the fine-tuned classifier; the file stores a whole pickled model object
# (a BertForSequenceClassification, per the repr below), not a state_dict.
# GPU variant: model = torch.load('save_GCBmodel.pt', map_location=torch.device('cuda:0'))
# NOTE(review): unpickling can execute arbitrary code — only load checkpoints from a trusted source.
model = torch.load('D:/BERT_learing/CCDP/for_captum/save_model/save_bertmodel.pt',
                   map_location=torch.device('cpu'), weights_only=False)

# This notebook only does inference/attribution: eval() switches off the Dropout(p=0.1)
# layers so predictions and attributions are deterministic. to() and eval() both return the
# model, so the cell still displays its repr.
model.to(device).eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

In [11]:
# Report the torch build and whether CUDA is usable in this environment.
import torch
print(torch.__version__)
print(torch.version.cuda)
print(torch.cuda.is_available())  # True means the CUDA installation works


2.3.1+cpu
None
False


In [12]:
# Device all tensors in this notebook are created on.
device

device(type='cpu')

In [14]:
# NOTE(review): this must be the tokenizer the saved model was fine-tuned with; its word
# embedding table has 30522 rows (see the repr above), matching bert-base-uncased — confirm.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")



In [21]:
# The embedding module that LayerIntegratedGradients attributes through.
model.bert.embeddings

BertEmbeddings(
  (word_embeddings): Embedding(30522, 768, padding_idx=0)
  (position_embeddings): Embedding(512, 768)
  (token_type_embeddings): Embedding(2, 768)
  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

In [22]:
# First of the 12 encoder layers analysed with LayerConductance below.
model.bert.encoder.layer[0]

BertLayer(
  (attention): BertAttention(
    (self): BertSelfAttention(
      (query): Linear(in_features=768, out_features=768, bias=True)
      (key): Linear(in_features=768, out_features=768, bias=True)
      (value): Linear(in_features=768, out_features=768, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (output): BertSelfOutput(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (intermediate): BertIntermediate(
    (dense): Linear(in_features=768, out_features=3072, bias=True)
    (intermediate_act_fn): GELUActivation()
  )
  (output): BertOutput(
    (dense): Linear(in_features=3072, out_features=768, bias=True)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

In [None]:
# input_embeddings = model.get_input_embeddings()

In [None]:
# NOTE: predict and squad_pos_forward_func could be merged into a single function.
def predict(inputs, position_ids=None, attention_mask=None):
    """Classify ``inputs`` with the global ``model`` and summarise the softmax output.

    Returns:
        (max_prob, predicted_label, probs):
        - max_prob: the highest class probability for each input sample;
        - predicted_label: the argmax class index for each sample;
        - probs: the full class-probability matrix (softmax over dim 1).
    """
    logits = model(inputs, position_ids=position_ids, attention_mask=attention_mask).logits
    probs = nn.functional.softmax(logits, dim=1)
    max_prob = probs.max(1).values
    predicted_label = torch.argmax(probs, dim=-1)
    return max_prob, predicted_label, probs

In [None]:
def squad_pos_forward_func(inputs,position_ids=None, attention_mask=None, position=0):
    """Captum forward function: the maximum class probability for each input sample.

    ``position`` is accepted so it can be supplied through ``additional_forward_args``,
    but it is not used.
    """
    return predict(inputs, position_ids=position_ids, attention_mask=attention_mask)[0]

In [None]:
# Special-token ids (values for the bert-base-uncased vocabulary).
ref_token_id = tokenizer.pad_token_id # [PAD] = 0, used as the attribution baseline token
sep_token_id = tokenizer.sep_token_id # [SEP] = 102
cls_token_id = tokenizer.cls_token_id # [CLS] = 101
ref_token_id,sep_token_id,cls_token_id

In [None]:
# Builds the input for ONE (comment, code) pair; batching a whole dataset is still TODO.
def construct_input_ref_pair(comment,AST_type,  ref_token_id, sep_token_id, cls_token_id, max_length=512):
    """Build ``[CLS] comment [SEP] code [SEP]`` token ids and a padding reference.

    Args:
        comment: comment text (first segment).
        AST_type: code text (second segment).
        ref_token_id: id placed at every non-special position of the baseline.
        sep_token_id, cls_token_id: special-token ids of the tokenizer.
        max_length: maximum total length INCLUDING the three special tokens. The default
            512 is the size of BERT's position-embedding table; longer inputs would index
            past it and crash the model. The comment keeps priority and the code is
            truncated to the remaining room.

    Returns:
        (input_ids, ref_input_ids, comment_len): two (1, seq_len) long tensors on ``device``
        and the number of comment tokens kept.
    """
    content_budget = max(max_length - 3, 0)   # room left after [CLS], [SEP], [SEP]
    comment = tokenizer.encode(comment, add_special_tokens=False,truncation=True,max_length=512)
    AST_type = tokenizer.encode(AST_type, add_special_tokens=False,truncation=True,max_length=512)
    comment = comment[:content_budget]
    AST_type = AST_type[:content_budget - len(comment)]

    # construct input token ids
    input_ids = [cls_token_id] + comment + [sep_token_id] + AST_type + [sep_token_id]

    # construct reference token ids: same layout, every content token replaced by ref_token_id
    ref_input_ids = [cls_token_id] + [ref_token_id] * len(comment) + [sep_token_id] + \
        [ref_token_id] * len(AST_type) + [sep_token_id]

    return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device), len(comment)

def construct_input_ref_token_type_pair(input_ids, sep_ind=0):
    """Return ``(token_type_ids, ref_token_type_ids)`` for a single-row ``input_ids``.

    Positions ``0..sep_ind`` get segment 0 and every later position segment 1; the
    reference is all zeros. Both are (1, seq_len) long tensors created on
    ``input_ids.device`` (instead of the global ``device``), so they always sit next to
    the input they describe.

    NOTE(review): callers pass the comment length as ``sep_ind``; the first [SEP] is at
    index ``comment_len + 1`` and therefore falls into segment 1 — confirm this is intended.
    """
    seq_len = input_ids.size(1)
    positions = torch.arange(seq_len, device=input_ids.device)
    token_type_ids = (positions > sep_ind).long().unsqueeze(0)
    ref_token_type_ids = torch.zeros_like(token_type_ids)
    return token_type_ids, ref_token_type_ids

def construct_input_ref_pos_id_pair(input_ids):
    """Return ``(position_ids, ref_position_ids)`` shaped like ``input_ids``.

    Every row of ``position_ids`` is ``0..seq_len-1``; the reference uses position 0
    everywhere (a random permutation via ``torch.randperm`` would be another option).
    Both are long tensors on ``input_ids.device``, expanded (not copied) to the input's shape.
    """
    seq_length = input_ids.size(1)
    target_device = input_ids.device
    position_ids = torch.arange(seq_length, dtype=torch.long, device=target_device)
    ref_position_ids = torch.zeros(seq_length, dtype=torch.long, device=target_device)
    return (position_ids.unsqueeze(0).expand_as(input_ids),
            ref_position_ids.unsqueeze(0).expand_as(input_ids))

def construct_attention_mask(input_ids):
    """All-ones mask with the shape, dtype and device of ``input_ids`` (no padding is masked)."""
    return input_ids.new_ones(input_ids.shape)

def construct_whole_bert_embeddings(input_ids, ref_input_ids, \
                                    token_type_ids=None, ref_token_type_ids=None, \
                                    position_ids=None, ref_position_ids=None):
    """Embed the input and reference ids with BERT's full embedding module.

    Returns ``(input_embeddings, ref_input_embeddings)`` produced by
    ``model.bert.embeddings`` with the given (optional) token-type and position ids.
    """
    # BertForSequenceClassification keeps its embeddings under `.bert`; it has no
    # `.embeddings` attribute of its own.
    embeddings = model.bert.embeddings
    input_embeddings = embeddings(input_ids, token_type_ids=token_type_ids, position_ids=position_ids)
    ref_input_embeddings = embeddings(ref_input_ids, token_type_ids=ref_token_type_ids, position_ids=ref_position_ids)

    return input_embeddings, ref_input_embeddings

In [None]:
# NOTE(review): duplicated by the next cell, which also sets code_list.
code_list = train_df_clean.loc[2,'new_code_raw'] 

In [None]:
# Pick training row 2 as the single worked example: code, comment and gold label.
code_list = train_df_clean.loc[2,'new_code_raw'] 
comment_list = train_df_clean.loc[2,'old_comment_raw']
ground_lable = train_df_clean.loc[2,'label']
print(code_list)

In [None]:
# Comment text of the worked example.
print(comment_list)

In [None]:
# Gold label of the worked example.
print(ground_lable)

In [None]:
# Build every model input for the worked example.
input_ids_t, ref_input_ids_t, comment_len_t = construct_input_ref_pair(comment_list,code_list, ref_token_id, sep_token_id, cls_token_id)
# NOTE(review): comment_len_t is passed as sep_ind, so the first [SEP] (index comment_len_t + 1)
# lands in segment 1; the token-type ids are also not passed to the predict() calls below.
token_type_ids_t, ref_token_type_ids_t = construct_input_ref_token_type_pair(input_ids_t, comment_len_t)
position_ids_t, ref_position_ids_t = construct_input_ref_pos_id_pair(input_ids_t)
attention_mask_t = construct_attention_mask(input_ids_t)

# Tokens decoded back from the ids, used as plot labels.
indices_t = input_ids_t[0].detach().tolist()
all_tokens_t = tokenizer.convert_ids_to_tokens(indices_t) 

In [None]:
# Inspect input ids, baseline ids and decoded tokens side by side.
input_ids_t,ref_input_ids_t,all_tokens_t

In [None]:
# Raw model output (logits) for the worked example.
output = model(input_ids_t,position_ids=position_ids_t,attention_mask=attention_mask_t)
output

In [None]:
# Max probability, predicted class index and full probability matrix.
pred ,sen_type,pred_tensor= predict(input_ids_t,position_ids=position_ids_t,attention_mask=attention_mask_t)
pred ,sen_type,pred_tensor

In [None]:
# The scalar-per-sample output that Captum attributes.
pred = squad_pos_forward_func(input_ids_t,position_ids=position_ids_t,attention_mask=attention_mask_t)
pred

In [None]:
# Human-readable prediction: probability rounded to 3 decimals and a class name.
pre ,out,_ = predict(input_ids_t,position_ids=position_ids_t,attention_mask=attention_mask_t)
sen_type = 'pos' if out == 1 else 'nag'
pre = float("{:.3f}".format(pre.item()))
pre ,sen_type

In [None]:
# First ten training examples: code (new_code_raw), comment (old_comment_raw) and gold label.
AST_list = [train_df_clean.loc[i, 'new_code_raw'] for i in range(10)]
comment_list = [train_df_clean.loc[i, 'old_comment_raw'] for i in range(10)]
ground_lable = [train_df_clean.loc[i, 'label'] for i in range(10)]

# Show example 2 (the same row used in the single-example cells above).
for preview in (AST_list[2], comment_list[2], ground_lable[2]):
    print(preview)

def input_data_list(AST_list,comment_list):
    """Build model inputs for every (code, comment) pair.

    ``comment_list[i]`` is paired with ``AST_list[i]``.

    Returns:
        Six parallel lists, in this order: input ids, reference ids, position ids,
        attention masks, token-type ids, and the decoded tokens of each example.
    """
    columns = ([], [], [], [], [], [])
    for idx, code in enumerate(AST_list):
        input_ids, ref_input_ids, comment_len = construct_input_ref_pair(
            comment_list[idx], code, ref_token_id, sep_token_id, cls_token_id)
        token_type_ids, _unused_ref_token_types = construct_input_ref_token_type_pair(input_ids, comment_len)
        position_ids, _unused_ref_positions = construct_input_ref_pos_id_pair(input_ids)
        attention_mask = construct_attention_mask(input_ids)
        tokens = tokenizer.convert_ids_to_tokens(input_ids[0].detach().tolist())

        row = (input_ids, ref_input_ids, position_ids, attention_mask, token_type_ids, tokens)
        for column, value in zip(columns, row):
            column.append(value)

    return columns


In [None]:
# Prepare model inputs for the ten examples collected above.
input_ids_all,ref_input_ids_all,position_ids_all,attention_mask_all,token_type_ids_all,all_tokens_all= input_data_list(AST_list,comment_list)
# Debug aid: inspect one example, e.g. print(input_ids_all[1]) or print(ref_input_ids_all[1]).

In [None]:
# Top-k entries (words, token types or positions) with the largest and smallest attribution.
def get_topk_attributed_tokens(attrs,all_token_t, k=5):
    """Return the ``k`` highest- and ``k`` lowest-attributed entries of ``attrs``.

    Args:
        attrs: attribution scores, ranked along their last dimension (1-D in this notebook).
        all_token_t: labels indexed by position (the decoded tokens here).
        k: entries per end; clamped to the number of available scores so a sequence
            shorter than ``k`` does not make ``torch.topk`` raise.

    Returns:
        (top_tokens_max, values_max, indices_max, top_tokens_min, values_min, indices_min):
        labels, attribution values and indices — descending for the max side,
        ascending for the min side.
    """
    k = min(k, attrs.size(-1))
    values_max, indices_max = torch.topk(attrs, k)
    top_tokens_max = [all_token_t[idx] for idx in indices_max]
    values_min, indices_min = torch.topk(attrs, k, largest=False)
    top_tokens_min = [all_token_t[idx] for idx in indices_min]

    return top_tokens_max, values_max, indices_max,top_tokens_min,values_min,indices_min


In [None]:
from captum.attr import LayerIntegratedGradients
from captum.attr import visualization as viz

# Layer Integrated Gradients must be given a *layer*: attribute with respect to the output of
# BERT's embedding module. (`input_embeddings` is not defined at this point, and later cells
# rebind that name to an embedding tensor.)
lig = LayerIntegratedGradients(squad_pos_forward_func, model.bert.embeddings)

# VisualizationDataRecord objects collected by interpret_sentence, rendered by viz.visualize_text.
vis_data_records_ig = []

def interpret_sentence(input_ids,ref_input_ids, token_type_ids, position_ids, attention_mask, all_tokens, ground_lable):
    """Explain one example with Layer Integrated Gradients and queue it for visualisation.

    Steps:
    1. Predict the example.
    2. Attribute the forward function of ``lig`` against the padding baseline ``ref_input_ids``.
    3. Sum the attributions over the last dimension and L2-normalise them.
    4. Append a record to ``vis_data_records_ig``.
    5. Return the 6-tuple produced by ``get_topk_attributed_tokens``.

    NOTE(review): ``token_type_ids`` is accepted but not used.
    """
    probability, label, _ = predict(input_ids, position_ids=position_ids, attention_mask=attention_mask)
    sen_type = 'pos' if label == 1 else 'nag'
    # Rounded to three decimals for the visualisation table.
    pre = float("{:.3f}".format(probability.item()))

    attributions_ig, delta_ig = lig.attribute(
        input_ids,
        baselines=ref_input_ids,
        additional_forward_args=(position_ids, attention_mask, 0),
        return_convergence_delta=True,
        internal_batch_size=8,
    )

    # Assumes attributions are (1, seq_len, hidden): collapse hidden -> one score per token.
    token_scores = attributions_ig.sum(dim=2).squeeze(0)
    token_scores = token_scores / torch.norm(token_scores)

    add_attributions_to_visualizer(token_scores, all_tokens, pre, ground_lable, sen_type, delta_ig, vis_data_records_ig)
    return get_topk_attributed_tokens(token_scores, all_tokens)


def add_attributions_to_visualizer(attributions, all_tokens, pre, ground_lable, sen_type, delta, vis_data_records):
    """Append one ``viz.VisualizationDataRecord`` for a sample to ``vis_data_records``.

    The record's positional fields are (word_attributions, pred_prob, pred_class,
    true_class, attr_class, attr_score, raw_input_ids, convergence_score). The predicted
    class is the label name ``sen_type``; ``pre`` is only the probability.

    NOTE(review): ``ground_lable`` is the raw 0/1 label while ``sen_type`` is 'pos'/'nag';
    consider mapping them to the same vocabulary so the table is comparable.
    """
    vis_data_records.append(viz.VisualizationDataRecord(
                            attributions,         # word_attributions
                            pre,                  # pred_prob
                            sen_type,             # pred_class
                            ground_lable,         # true_class
                            sen_type,             # attr_class: the forward func scores the max-probability class
                            attributions.sum(),   # attr_score
                            all_tokens,           # tokens shown in the visualisation
                            delta))               # convergence_score


In [None]:
# Explain each prepared example and print its most positive / most negative tokens.
for i in range(len(AST_list)):
    example_args = (input_ids_all[i], ref_input_ids_all[i], token_type_ids_all[i],
                    position_ids_all[i], attention_mask_all[i], all_tokens_all[i], ground_lable[i])
    (top_tokens_max, values_max, indices_max,
     top_tokens_min, values_min, indices_min) = interpret_sentence(*example_args)
    print(f'第{i}个top：分别为贡献最大，贡献最大值，token值，贡献最小，贡献最小值\n{top_tokens_max}\n, {values_max}\n, {indices_max}\n,{top_tokens_min}\n,{values_min}\n,{indices_min}\n')


In [None]:
# Render the collected records as a colour-coded HTML table.
print('Visualize attributions based on Integrated Gradients')
_ = viz.visualize_text(vis_data_records_ig)

#### Interpreting Bert Layers
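Distribution of each token's attribution scores across all BERT layers (每个token在所有层的归因分数分布)。
This section uses LayerConductance. The forward function is rewritten as `squad_pos_forward_func2`; only the way inputs reach the model changes: it takes embeddings via `inputs_embeds` instead of token ids (此处使用了LayerConductance进行分析，更改了前向传播函数；只更改了模型的输入参数)。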

In [None]:
# Same idea as predict(), but fed embeddings (inputs_embeds) instead of token ids so
# LayerConductance can interpolate between input and baseline embeddings.
def squad_pos_forward_func2(input_emb, attention_mask=None, position=0):
    """Return the maximum softmax probability per sample; ``position`` is unused."""
    logits = model(inputs_embeds=input_emb, attention_mask=attention_mask).logits
    return nn.functional.softmax(logits, dim=1).max(1).values

In [None]:
def summarize_attributions(attributions):
    """Collapse attributions to one L2-normalised score per position.

    Sums over the last dimension, squeezes a leading size-1 batch dimension and
    divides by the L2 norm. An all-zero result (e.g. input identical to the
    baseline) is returned as zeros instead of the NaNs a 0/0 division would produce.
    """
    attributions = attributions.sum(dim=-1).squeeze(0)
    norm = torch.norm(attributions)
    if norm == 0:
        return attributions
    return attributions / norm

In [None]:
# Position ids (0..seq_len-1) and their all-zero reference for the worked example.
position_ids_t,ref_position_ids_t = construct_input_ref_pos_id_pair(input_ids_t)
position_ids_t,ref_position_ids_t 

In [None]:
# Accumulator: one row of per-token attributions per encoder layer.
layer_attrs = []

# The token that we would like to examine separately (next step: compute this for every token).
token_to_explain = 2   # index into all_tokens_t of the token whose per-layer attributions are kept
layer_attrs_dist = []

# Variant that also passes token-type ids:
# input_embeddings, ref_input_embeddings = construct_whole_bert_embeddings(input_ids_t, ref_input_ids_t, \
#                                          token_type_ids=token_type_ids_t, ref_token_type_ids=ref_token_type_ids_t, \
#                                          position_ids=position_ids_t, ref_position_ids=ref_position_ids_t)
input_embeddings, ref_input_embeddings = construct_whole_bert_embeddings(input_ids_t, ref_input_ids_t, \
                                         position_ids=position_ids_t, ref_position_ids=ref_position_ids_t)

In [None]:
# Shape of the embedded input fed to the conductance loop.
input_embeddings.shape

In [None]:
# Run LayerConductance on every encoder layer for the worked example.
# Reset the accumulators so re-running this cell does not append a second set of rows
# (the heat map below expects exactly one row per layer).
layer_attrs = []
layer_attrs_dist = []
for i in range(model.config.num_hidden_layers):
    # BertForSequenceClassification keeps its encoder under `.bert`.
    lc = LayerConductance(squad_pos_forward_func2, model.bert.encoder.layer[i])
    layer_attributions = lc.attribute(inputs=input_embeddings, baselines=ref_input_embeddings,
                                      additional_forward_args=(attention_mask_t, 0), internal_batch_size=1)
    # One normalised score per token for this layer (a row of the heat map).
    layer_attrs.append(summarize_attributions(layer_attributions).cpu().detach().tolist())

    # Attributions of token `token_to_explain` in this layer, kept for the box plot.
    layer_attrs_dist.append(layer_attributions[0,token_to_explain,:].cpu().detach().tolist())

In [None]:
# Heat map: per-token attributions (columns) for every encoder layer (rows).
fig, ax = plt.subplots(figsize=(25,10))
xticklabels=all_tokens_t
# One 1-based label per computed layer, derived from the data rather than hard-coded to 12.
yticklabels=list(range(1, len(layer_attrs) + 1))
ax = sns.heatmap(np.array(layer_attrs), xticklabels=xticklabels, yticklabels=yticklabels, linewidth=0.2, ax=ax)
plt.xlabel('Tokens')
plt.ylabel('Layers')
plt.show()

In [None]:
# Spread of the explained token's attributions in each layer (one box per layer).
fig, ax = plt.subplots(figsize=(20,10))
ax = sns.boxplot(data=layer_attrs_dist, ax=ax)
ax.set_xlabel('Layers')
ax.set_ylabel('Attribution')
plt.show()

In [None]:
# Should equal the number of encoder layers.
len(layer_attrs)

In [None]:
# Conductance of the last layer processed by the loop.
layer_attributions

In [None]:
# layer_attributions[0,token_to_explain,:]
# selects element 0 of the first dimension, element token_to_explain of the second,
# and every element of the third dimension.
layer_attributions.shape

#### Visualizing Attention Matrices

In [None]:
# output.attentions is only populated when output_attentions=True is passed.
output = model(input_ids_t, output_attentions=True)
print(output.attentions)

In [18]:
# NOTE(review): this redefines the earlier `predict` with a different signature and return
# value; cells above that unpack three values from predict() break if re-run after this one.
def predict(inputs, token_type_ids=None, position_ids=None, attention_mask=None):
    """Run the model with attentions enabled and return ``(logits, attentions)``."""
    result = model(
        inputs,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        attention_mask=attention_mask,
        output_attentions=True,
    )
    return result.logits, result.attentions

In [19]:
# Logits and per-layer attention tensors for the worked example (needs the cells that build input_ids_t).
output_logits,output_attentions= predict(input_ids_t,attention_mask=attention_mask_t)
output_logits,output_attentions

NameError: name 'input_ids_t' is not defined

In [None]:
# Stack the per-layer attention tensors into one tensor.
# shape -> layer x batch x head x seq_len x seq_len
output_attentions_all = torch.stack(output_attentions)
output_attentions_all

In [None]:
# Attention of the last encoder layer (index 11 of the 12 layers).
output_attentions_all[11]

In [None]:
def visualize_token2token_scores(scores_mat,all_tokens, x_label_name='Head'):
    """Plot one token-to-token heat map per matrix in ``scores_mat``.

    Args:
        scores_mat: sequence of square score matrices (e.g. one per head or per layer),
            each expected to be (len(all_tokens), len(all_tokens)).
        all_tokens: tick labels for both axes.
        x_label_name: prefix of each subplot's x-axis label, followed by its 1-based index.

    Subplots are laid out three per row; the number of rows grows with ``len(scores_mat)``
    (4 rows for the 12 heads/layers of BERT-base).
    """
    n_cols = 3
    n_rows = max(1, (len(scores_mat) + n_cols - 1) // n_cols)
    fig = plt.figure(figsize=(50, 50))
    fontdict = {'fontsize': 15}

    for idx, scores in enumerate(scores_mat):
        scores_np = np.array(scores)
        ax = fig.add_subplot(n_rows, n_cols, idx+1)
        # append the attention weights
        im = ax.imshow(scores_np, cmap='viridis')

        ax.set_xticks(range(len(all_tokens)))
        ax.set_yticks(range(len(all_tokens)))

        ax.set_xticklabels(all_tokens, fontdict=fontdict, rotation=90)
        ax.set_yticklabels(all_tokens, fontdict=fontdict)
        ax.set_xlabel('{} {}'.format(x_label_name, idx+1))

        # Attach the colour bar to this subplot explicitly.
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    plt.tight_layout()
    plt.show()

In [None]:
def visualize_token2head_scores(scores_mat,all_tokens):
    """Plot one 2-D score matrix per layer: rows are (presumably) heads, columns are tokens.

    Args:
        scores_mat: sequence of 2-D arrays, one per layer, each assumed to be
            (num_rows, len(all_tokens)); rows are labelled 0..num_rows-1.
        all_tokens: x-axis tick labels.

    Subplots are laid out two per row; the number of rows grows with ``len(scores_mat)``
    (6 rows for 12 layers).
    """
    n_cols = 2
    n_rows = max(1, (len(scores_mat) + n_cols - 1) // n_cols)
    fig = plt.figure(figsize=(50, 50))
    fontdict = {'fontsize': 15}

    for idx, scores in enumerate(scores_mat):
        scores_np = np.array(scores)
        ax = fig.add_subplot(n_rows, n_cols, idx+1)
        # append the attention weights
        im = ax.matshow(scores_np, cmap='viridis')

        ax.set_xticks(range(len(all_tokens)))
        ax.set_xticklabels(all_tokens, fontdict=fontdict, rotation=90)
        # One y tick per matrix row; matplotlib requires as many labels as ticks, so the
        # labels are derived from the row count (not the column count).
        n_rows_in_matrix = scores_np.shape[0]
        ax.set_yticks(range(n_rows_in_matrix))
        ax.set_yticklabels(range(n_rows_in_matrix), fontdict=fontdict)
        ax.set_xlabel('Layer {}'.format(idx+1))

        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    plt.tight_layout()
    plt.show()

In [None]:
# 0-based index of the encoder layer to visualise (11 = the last of the 12 layers).
layer = 11

In [None]:
# One token-to-token attention map per head of the chosen layer.
visualize_token2token_scores(output_attentions_all[layer].squeeze().detach().cpu().numpy(),all_tokens_t)

In [None]:
# Prefer torch.linalg.norm (added in torch 1.7) when this build provides it. Detect the
# feature directly: on builds where torch.__version__ is a plain str, comparing version
# strings mis-orders releases such as '1.10.0' < '1.7.0'.
if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'norm'):
    norm_fn = torch.linalg.norm
else:
    norm_fn = torch.norm

In [None]:
# Norm over the head dimension (dim 2) -> one token-to-token map per layer.
visualize_token2token_scores(norm_fn(output_attentions_all, dim=2).squeeze().detach().cpu().numpy(),all_tokens_t,x_label_name='Layer')

#### Interpreting Outputs and Self-Attention Matrices in Each Layer

In [None]:
# NOTE(review): this redefines construct_whole_bert_embeddings from earlier in the notebook.
def construct_whole_bert_embeddings(input_ids, ref_input_ids, \
                                    token_type_ids=None, ref_token_type_ids=None, \
                                    position_ids=None, ref_position_ids=None):
    """Look up word embeddings for the input and reference ids.

    Uses ``interpretable_embedding`` (configured in a later cell), so only the word
    embeddings are produced; the token-type and position arguments are accepted for
    signature compatibility but ignored.
    """
    to_embeddings = interpretable_embedding.indices_to_embeddings
    return to_embeddings(input_ids), to_embeddings(ref_input_ids)

In [None]:
# NOTE(review): this redefines squad_pos_forward_func from earlier; this version takes
# embeddings (inputs_embeds) and returns the maximum of the selected output (logits for
# position 0), not a probability.
def squad_pos_forward_func(inputs, attention_mask=None, position=0):
    """Return, per sample, the maximum of model output number ``position``."""
    outputs = model(inputs_embeds=inputs, attention_mask=attention_mask)
    selected = outputs[position]
    return selected.max(1).values

In [None]:
# Wrap BERT's word-embedding layer with Captum's interpretable embedding so embeddings can
# be computed outside the model. The model is a BertForSequenceClassification, so the layer
# lives under `bert.` (there is no `roberta` sub-module).
# NOTE(review): the wrapper stays installed until
# remove_interpretable_embedding_layer(model, interpretable_embedding) is called; per Captum's
# docs the model then expects embeddings rather than token ids — confirm before re-running
# earlier cells.
interpretable_embedding = configure_interpretable_embedding_layer(model, 'bert.embeddings.word_embeddings')

In [None]:
# Fresh accumulators for the per-layer conductance loop below.
layer_attrs = []
layer_attn_mat = []

# With the redefinition above, only word embeddings are produced (position ids are ignored).
input_embeddings, ref_input_embeddings = construct_whole_bert_embeddings(input_ids_t, ref_input_ids_t, \
                                         ref_position_ids=ref_position_ids_t)

In [None]:
# Shape of the word-embedding input.
input_embeddings.shape

In [None]:
# Sanity check: the model runs on precomputed embeddings passed as inputs_embeds.
pred = model(inputs_embeds=input_embeddings, attention_mask=attention_mask_t, )

In [None]:
# Per-layer conductance, keeping both the layer-output and attention-matrix attributions.
# NOTE(review): indexing [1] assumes each BertLayer returns (hidden_states, attention_probs),
# which in transformers requires attentions to be output (e.g. model.config.output_attentions
# = True) — confirm.
# Reset so re-running this cell does not append a second set of layers.
layer_attrs = []
layer_attn_mat = []
for i in range(model.config.num_hidden_layers):
    # BertForSequenceClassification keeps its encoder under `.bert` (there is no `.roberta`).
    lc = LayerConductance(squad_pos_forward_func, model.bert.encoder.layer[i])
    layer_attributions = lc.attribute(inputs=input_embeddings, baselines=ref_input_embeddings, additional_forward_args=(attention_mask_t, 0),internal_batch_size=1)
    # Captum returns a tuple for a layer with several outputs; a tuple has no `.shape`.
    if isinstance(layer_attributions, tuple):
        print([attr.shape for attr in layer_attributions])
    else:
        print(layer_attributions.shape)
    layer_attrs.append(summarize_attributions(layer_attributions[0]))

    layer_attn_mat.append(layer_attributions[1])

#### Interpreting Attribution Scores for Attention Matrices

In [None]:
# Attribution scores of the chosen layer's attention matrix, one heat map per head.
# all_tokens is a required argument of visualize_token2token_scores.
visualize_token2token_scores(layer_attn_mat[layer].squeeze().cpu().detach().numpy(), all_tokens_t)

In [None]:
# layer_attn_mat is a Python list of per-layer tensors: stack it into one tensor before
# taking the norm over dim 2 (presumably the head dimension, as with output_attentions_all).
visualize_token2token_scores(norm_fn(torch.stack(layer_attn_mat), dim=2).squeeze().detach().cpu().numpy(),
                             all_tokens_t, x_label_name='Layer')

In [None]:
# Interpreting Attribution Scores for Attention Matrices
# (As bare text inside a code cell, the heading above was a SyntaxError.)
# NOTE(review): this cell repeats the two plots from the previous cells.

visualize_token2token_scores(layer_attn_mat[layer].squeeze().cpu().detach().numpy(), all_tokens_t)

visualize_token2token_scores(norm_fn(torch.stack(layer_attn_mat), dim=2).squeeze().detach().cpu().numpy(),
                             all_tokens_t, x_label_name='Layer')