In [1]:
from tqdm import tqdm,trange
import ujson as json
import numpy as np

import argparse
import os
import time
from datetime import datetime
import torch
import pickle
import copy

import ujson as json
from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoModel, AutoTokenizer
from transformers.optimization import AdamW, get_linear_schedule_with_warmup
from model_balanceloss import DocREModel
from utils_sample import set_seed, collate_fn
from evaluation import to_official, official_evaluate
from prepro import ReadDataset
from train_balanceloss_reset import evaluate_o

from collections import defaultdict
import nltk
from nltk.corpus import wordnet as wn
import collections


# NOTE(review): MAX_SEQ_LENGTH is not referenced by the visible cells; the
# argparse default (--max_seq_length) is what ReadDataset actually receives.
MAX_SEQ_LENGTH = 1024
rel2id_path = 'dataset/meta/rel2id.json'
# Maps the relation identifiers found in predictions (the `r` field) to integer ids.
# A context manager closes the file handle instead of leaking it.
with open(rel2id_path) as rel2id_file:
    docred_rel2id = json.load(rel2id_file)

# Load the tokenizer and the new keyword dataset

In [3]:
# Load the new dev keyword annotations and the keyword dictionary.
with open('dataset/docred/dev_keys_new.json') as dev_keys_file:
    dev_keys_new = json.load(dev_keys_file)
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load pickles from a trusted source.
with open('dataset/docred/keywords_dict.pkl', 'rb') as kdict_file:
    kdict = pickle.load(kdict_file)
model_type = 'roberta-large'
# NOTE: `tokenizer` is rebuilt in the model-setup cell from the local pretrained directory.
tokenizer = AutoTokenizer.from_pretrained(model_type)
len(dev_keys_new), len(kdict.keys())

(699, 644)

## Model predictions on the new dev keyword dataset

In [4]:
def report(args, model, features):
    """Run ``model`` over ``features`` and convert the predictions with ``to_official``.

    Args:
        args: namespace providing ``test_batch_size`` and ``device``.
        model: DocRE model called with ``input_ids``, ``attention_mask``,
            ``entity_pos`` and ``hts``; it must return a torch tensor.
        features: preprocessed feature dicts, batched via ``collate_fn``.

    Returns:
        The result of ``to_official`` on the stacked float32 predictions.
        Elsewhere in this notebook it is consumed as a list of dicts with
        ``title`` and ``r`` keys. Returns ``[]`` when ``features`` is empty.
    """
    # Nothing to predict: np.concatenate on an empty list would raise.
    if len(features) == 0:
        return []

    dataloader = DataLoader(features, batch_size=args.test_batch_size, shuffle=False,
                            collate_fn=collate_fn, drop_last=False)
    # Inference mode is loop-invariant, so set it once instead of per batch.
    model.eval()
    preds = []
    for batch in tqdm(dataloader):
        # batch[2] is not used here (presumably the labels -- TODO confirm).
        inputs = {
            'input_ids': batch[0].to(args.device),
            'attention_mask': batch[1].to(args.device),
            'entity_pos': batch[3],
            'hts': batch[4],
        }
        with torch.no_grad():
            pred = model(**inputs).cpu().numpy()
        # Zero out NaNs so they cannot leak into the official output.
        pred[np.isnan(pred)] = 0
        preds.append(pred)

    preds = np.concatenate(preds, axis=0).astype(np.float32)
    return to_official(preds, features)


def arg_pre(argv=None):
    """Build the evaluation argument namespace.

    Args:
        argv: optional list of command-line tokens to parse. When ``None``
            (the default), an empty list is parsed, so only the defaults
            below are used. This keeps the function usable from Jupyter.

    Returns:
        argparse.Namespace holding the defaults below, overridden by ``argv``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--data_dir", default="./dataset/docred", type=str)
    parser.add_argument("--transformer_type", default="roberta", type=str)
    parser.add_argument("--model_name_or_path", default="roberta-large", type=str)

    parser.add_argument("--load_path", default="checkpoint/docred/roberta.pt", type=str)

    parser.add_argument("--config_name", default="", type=str,
                        help="Pretrained config name or path if not the same as model_name")
    parser.add_argument("--tokenizer_name", default="", type=str,
                        help="Pretrained tokenizer name or path if not the same as model_name")
    parser.add_argument("--max_seq_length", default=1024, type=int,
                        help="The maximum total input sequence length after tokenization. Sequences longer "
                             "than this will be truncated, sequences shorter will be padded.")

    parser.add_argument("--test_batch_size", default=8, type=int,
                        help="Batch size for testing.")

    parser.add_argument("--seed", type=int, default=66,
                        help="random seed for initialization")
    parser.add_argument("--num_class", type=int, default=97,
                        help="Number of relation types in dataset.")
    parser.add_argument("--num_labels", default=4, type=int,
                        help="Max number of labels in prediction.")

    parser.add_argument("--unet_in_dim", type=int, default=3,
                        help="unet_in_dim.")
    parser.add_argument("--unet_out_dim", type=int, default=256,
                        help="unet_out_dim.")
    parser.add_argument("--down_dim", type=int, default=256,
                        help="down_dim.")
    parser.add_argument("--channel_type", type=str, default='context-based',
                        help="channel_type.")
    parser.add_argument("--log_dir", type=str, default='',
                        help="Log directory.")
    # NOTE(review): presumably the maximum number of entities per document -- TODO confirm.
    parser.add_argument("--max_height", type=int, default=42,
                        help="max_height.")
    parser.add_argument("--train_from_saved_model", type=str, default='',
                        help="train from a saved model.")
    parser.add_argument("--dataset", type=str, default='docred',
                        help="dataset type")

    # Parse an explicit token list rather than sys.argv (needed for Jupyter execution).
    args = parser.parse_args(args=[] if argv is None else list(argv))
    return args

In [28]:
# Build the DocRE model on top of the pretrained encoder and restore the
# fine-tuned checkpoint.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
args = arg_pre()
args.n_gpu = torch.cuda.device_count()
args.device = device
args.model_name_or_path = model_type
args.transformer_type = 'roberta' if model_type.startswith('r') else 'bert'
# Checkpoint evaluated in this notebook; overrides the --load_path default.
args.load_path = 'checkpoint/docred/roberta_reemb.pt'

# Local directory holding the pretrained encoder weights. Set the
# PRETRAINED_DIR environment variable to use another location.
PRETRAINED_DIR = os.environ.get('PRETRAINED_DIR', '/cpfs/user/cht/cbsp/')
pretrained_path = PRETRAINED_DIR + args.model_name_or_path

config = AutoConfig.from_pretrained(
    args.config_name if args.config_name else pretrained_path,
    num_labels=args.num_class,
)
config.output_attentions = True
tokenizer = AutoTokenizer.from_pretrained(
    args.tokenizer_name if args.tokenizer_name else pretrained_path,
)

# NOTE(review): `Dataset` is not referenced by the visible cells.
Dataset = ReadDataset(args.dataset, tokenizer, args.max_seq_length)
encoder = AutoModel.from_pretrained(
    pretrained_path,
    from_tf=bool(".ckpt" in args.model_name_or_path),
    config=config,
)

config.cls_token_id = tokenizer.cls_token_id
config.sep_token_id = tokenizer.sep_token_id
config.transformer_type = args.transformer_type

# Seed (presumably from args.seed) before constructing DocREModel so its
# parameter initialisation is reproducible.
set_seed(args)
model = DocREModel(config, args, encoder, num_labels=args.num_labels)
model.to(device)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
state = torch.load(args.load_path, map_location=device)
model.load_state_dict(state['checkpoint'])
print('load model from ', args.load_path)

load model from  checkpoint/docred/roberta_reemb.pt


# Load the keyword attack datasets

In [18]:
# Reload `prepro` so edits to read_docred_att are picked up without a kernel restart.
import importlib
import prepro
importlib.reload(prepro)
from prepro import read_docred_att

attack_dir = 'attack_ds/'
# One dev file per keyword-attack variant, named attack_ds/<model_type>@<variant>_keyword_dev.json.
keyword_suffixes = {
    'ori': '@ori_keyword_dev.json',
    'mask': '@mask_keyword_dev.json',
    'anto': '@anto_keyword_dev.json',
    'ori_anto': '@ori_anto_keyword_dev.json',
    'syno': '@syno_keyword_dev.json',
    'ori_syno': '@ori_syno_keyword_dev.json',
}
keyword_paths = {variant: attack_dir + model_type + suffix
                 for variant, suffix in keyword_suffixes.items()}
(ori_path, mask_path, anto_path,
 ori_anto_path, syno_path, ori_syno_path) = keyword_paths.values()

# Files are read in the same order as listed above.
keyword_features = {variant: read_docred_att(path, tokenizer, ast=False)
                    for variant, path in keyword_paths.items()}
(ori_features, mask_features, anto_features,
 ori_anto_features, syno_features, ori_syno_features) = keyword_features.values()

len(ori_features), len(anto_features), len(syno_features)

Example: 100%|██████████| 7342/7342 [01:40<00:00, 73.27it/s] 


# of documents 7342.
# of positive examples 7342.
# of negative examples 0.
# 206 examples len>512 and max len is 732.


Example: 100%|██████████| 7342/7342 [01:37<00:00, 75.24it/s] 


# of documents 7342.
# of positive examples 7342.
# of negative examples 0.
# 206 examples len>512 and max len is 732.


Example: 100%|██████████| 2002/2002 [00:26<00:00, 75.48it/s] 


# of documents 2002.
# of positive examples 2002.
# of negative examples 0.
# 45 examples len>512 and max len is 715.


Example: 100%|██████████| 2002/2002 [00:27<00:00, 72.00it/s] 


# of documents 2002.
# of positive examples 2002.
# of negative examples 0.
# 45 examples len>512 and max len is 714.


Example: 100%|██████████| 5231/5231 [01:11<00:00, 72.94it/s] 


# of documents 5231.
# of positive examples 5231.
# of negative examples 0.
# 146 examples len>512 and max len is 732.


Example: 100%|██████████| 5231/5231 [01:13<00:00, 71.36it/s] 

# of documents 5231.
# of positive examples 5231.
# of negative examples 0.
# 146 examples len>512 and max len is 732.





(7342, 2002, 5231)

In [13]:
def metric_keyword(features):
    """Score keyword-attack predictions per document title.

    Each title is expected to end in ``_<relation id>``, which is compared
    against the predicted relation ids.

    Returns:
        ``(key_res, all_rnum, pos_rnum, no_rnum, true_rnum, wrong_rnum)``:
        - ``key_res``: title -> list of predicted relation ids, in title-sorted insertion order.
        - ``all_rnum``: ``len(features)``.
        - ``pos_rnum``: titles with at least one prediction.
        - ``no_rnum``: ``all_rnum - pos_rnum``.
        - ``true_rnum``: titles whose predictions include the id from the title.
        - ``wrong_rnum``: ``pos_rnum - true_rnum``.
    """
    pred = report(args, model, features)
    print(len(pred))

    # Convert every relation name to its id first, then group by title.
    ordered = [(item['title'], docred_rel2id[item['r']])
               for item in sorted(pred, key=lambda x: x['title'])]
    key_res = defaultdict(list)
    for title, rel_id in ordered:
        key_res[title].append(rel_id)

    all_rnum = len(features)
    pos_rnum = len(key_res)
    true_rnum = sum(1 for title, rels in key_res.items()
                    if int(title.split('_')[-1]) in rels)
    return key_res, all_rnum, pos_rnum, all_rnum - pos_rnum, true_rnum, pos_rnum - true_rnum

def attack_ratio(ori_key_res, key_res, features):
    """Compare per-title predictions before and after an attack.

    Args:
        ori_key_res: title -> predicted relation ids on the original data.
            Titles absent from the mapping had no prediction.
        key_res: the same mapping on the attacked data.
        features: iterable of feature dicts, each with a ``'title'`` key.

    Returns:
        ``(r_nor_ratio, nor_r_ratio, rel_arel_ratio, rel_srel_ratio)``:
        - ``r_nor_ratio``: originally-predicted titles (that appear in
          ``features``) which now have no prediction, over all originally-predicted titles.
        - ``nor_r_ratio``: titles in ``features`` without an original
          prediction that now have one, over all such titles.
        - ``rel_arel_ratio``: originally-predicted titles whose new predictions
          share no relation with the original ones, over all originally-predicted titles.
        - ``rel_srel_ratio``: originally-predicted titles whose new predictions
          share at least one relation, over all originally-predicted titles.

        A ratio whose denominator is an empty set is reported as 0.0 instead
        of raising ZeroDivisionError.
    """
    def _ratio(num, den):
        # Empty denominator set -> 0.0 rather than a crash.
        return num / den if den else 0.0

    all_titles = {f['title'] for f in features}
    ori_titles = set(ori_key_res)
    titles = set(key_res)
    no_titles = all_titles - ori_titles

    # ori no rel, now rel
    nor_r_ratio = _ratio(len(no_titles & titles), len(no_titles))
    # ori rel, now no rel
    r_nor_ratio = _ratio(len((all_titles - titles) & ori_titles), len(ori_titles))

    true_num = false_num = 0
    for title, ori_rels in ori_key_res.items():
        if title not in key_res:
            continue  # already counted in r_nor_ratio
        if set(key_res[title]) & set(ori_rels):
            true_num += 1
        else:
            false_num += 1
    # ori one rel, now another rel
    rel_arel_ratio = _ratio(false_num, len(ori_titles))
    # ori one rel, now rel covers
    rel_srel_ratio = _ratio(true_num, len(ori_titles))

    return r_nor_ratio, nor_r_ratio, rel_arel_ratio, rel_srel_ratio

# Entity Attack

In [23]:
from train_balanceloss_reset import evaluate_o

attack_dir = 'attack_ds/'
# Clean reference dev split plus one perturbed copy per entity attack.
ori_file_path = attack_dir + 'dev_wo_overlap.json'
en_mask_path, en_shuf_path, en_repl_path = (
    attack_dir + model_type + suffix
    for suffix in ('@en_mask_dev.json', '@en_shuf_dev.json', '@en_repl_dev.json')
)

# Read the files in a fixed order: original, mask, shuffle, replace.
ori_features, en_mask_features, en_shuf_features, en_repl_features = [
    read_docred_att(path, tokenizer)
    for path in (ori_file_path, en_mask_path, en_shuf_path, en_repl_path)
]

len(en_mask_features), len(en_shuf_features), len(en_repl_features)

Example: 100%|██████████| 884/884 [00:25<00:00, 35.09it/s]


# of documents 884.
# of positive examples 10295.
# of negative examples 335057.
# 50 examples len>512 and max len is 804.


Example: 100%|██████████| 884/884 [00:22<00:00, 40.04it/s]


# of documents 884.
# of positive examples 10295.
# of negative examples 335057.
# 50 examples len>512 and max len is 804.


Example: 100%|██████████| 884/884 [00:22<00:00, 39.81it/s]


# of documents 884.
# of positive examples 10295.
# of negative examples 335057.
# 51 examples len>512 and max len is 841.


Example: 100%|██████████| 884/884 [00:18<00:00, 48.49it/s]

# of documents 884.
# of positive examples 10295.
# of negative examples 335057.
# 44 examples len>512 and max len is 790.





(884, 884, 884)

In [33]:
# Point args at the overlap-free dev split (presumably read by evaluate_o -- TODO confirm).
setattr(args, 'dev_file', 'dev_wo_overlap.json')
args

Namespace(channel_type='context-based', config_name='', data_dir='./dataset/docred', dataset='docred', dev_file='dev_wo_overlap.json', device=device(type='cuda', index=0), down_dim=256, load_path='checkpoint/docred/roberta_reemb.pt', log_dir='', max_height=42, max_seq_length=1024, model_name_or_path='roberta-large', n_gpu=2, num_class=97, num_labels=4, seed=66, test_batch_size=8, tokenizer_name='', train_from_saved_model='', transformer_type='roberta', unet_in_dim=3, unet_out_dim=256)

In [27]:
# Docu ori roberta-large
all_feas = [ori_features, en_mask_features, en_shuf_features, en_repl_features] # ori_features, 
f1_outs = []
for fea in all_feas:
    _, f1_out = evaluate_o(args, model, fea, tag="dev")
    print(f1_out)
    f1_outs.append(f1_out)
en_attack_strs = ['original','entity mask', 'entity move', 'entity replace']
print()
for i in range(0, len(f1_outs)):
    print(f'& {en_attack_strs[i]} & {f1_outs[i]["dev_F1"]} & {f1_outs[i]["dev_F1_ign"]} \\\\ ')

{'dev_F1': 63.29041487839772, 'dev_F1_ign': 61.40522004808886, 'dev_re_p': 66.63319610402651, 'dev_re_r': 60.26700572155118, 'dev_average_loss': 0.37415633760057054}
{'dev_F1': 8.622610782549069, 'dev_F1_ign': 8.610971727686167, 'dev_re_p': 76.67682926829268, 'dev_re_r': 4.568159113613659, 'dev_average_loss': 1.066540204726898}
{'dev_F1': 8.077393787516755, 'dev_F1_ign': 7.670949823935591, 'dev_re_p': 11.27195836044242, 'dev_re_r': 6.293706293706294, 'dev_average_loss': 1.2295448796169177}
{'dev_F1': 18.54648059177915, 'dev_F1_ign': 18.449231939435034, 'dev_re_p': 57.84966698382493, 'dev_re_r': 11.043501952592862, 'dev_average_loss': 0.9096437628204758}

& original & 63.29041487839772 & 61.40522004808886 \\ 
& entity mask & 8.622610782549069 & 8.610971727686167 \\ 
& entity move & 8.077393787516755 & 7.670949823935591 \\ 
& entity replace & 18.54648059177915 & 18.449231939435034 \\ 
