In [151]:
import pickle
from typing import Dict, Tuple, List
import os
import numpy as np
import json
import logging
import pandas as pd
import sys

import glob

import torch
from torch.utils.data import DataLoader

# from evaluation import evaluation
import evaluation
from model import Distmult, Complex, Conve, Transe
import utils

In [152]:
'''
Pseudocode - 
    - Load the poisoned dataset, test.txt is the file with target triples, influential_triples.txt has influential triples
    - (but need to load the target triples from target dataset to get correct to_skip_eval; otherwise can regenerate the dicts)
    - Load the original model and compute ranks on target triples
    - Load the poisoned model and compute ranks on target triples 
    - Compute the difference in original and poisoned ranks
    - Sort the indexes of target triples based on the difference in their ranks
    - identify the influential triple for highest rank diff and lowest rank diff
'''

'\nPseudocode - \n    - Load the poisoned dataset, test.txt is the file with target triples, influential_triples.txt has influential triples\n    - (but need to load the target triples from target dataset to get correct to_skip_eval; otherwise can regenerate the dicts)\n    - Load the original model and compute ranks on target triples\n    - Load the poisoned model and compute ranks on target triples \n    - Compute the difference in original and poisoned ranks\n    - Sort the indexes of target triples based on the difference in their ranks\n    - identify the influential triple for highest rank diff and lowest rank diff\n'

In [153]:
# Console logging for the notebook: timestamp, level, logger name, message.
# (Pass filename=<path> to basicConfig as well to log to a file instead.)
LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s'
LOG_DATEFMT = '%m/%d/%Y %H:%M:%S'

logging.basicConfig(level=logging.INFO, format=LOG_FORMAT, datefmt=LOG_DATEFMT)

# Module-level logger used by every cell below.
logger = logging.getLogger(__name__)

In [154]:
## Build the argument namespace that is later passed to the model loader.
## The parser comes from the project's utils module.
parser = utils.get_argument_parser()
## Replace sys.argv with a bare program name so that parse_args() (which reads
## sys.argv[1:]) sees no command-line options and falls back to the parser's
## defaults -- presumably to avoid picking up the notebook kernel's own arguments.
sys.argv = ['prog.py']
args = parser.parse_args()

In [155]:
## experiment selection: attack, model and clean dataset
attack_method = 'com_add_3'
args.model = 'complex'
args.original_data = 'FB15k-237'

## directory name of the poisoned dataset, derived from the three choices above
args.data = '{}_{}_{}_1_1_1'.format(
    attack_method,
    args.model,
    args.original_data,
)

In [156]:
## let utils fill in the hyperparameters on the argument namespace
args = utils.set_hyperparams(args)

## pick the torch device (GPU when available) - legacy code to re-use functions from utils
use_gpu = torch.cuda.is_available()
device = torch.device("cuda" if use_gpu else "cpu")

In [157]:
## record the configuration this run uses
for template, setting in (
    ('Model name: {}\n', args.model),
    ('Dataset name: {} \n', args.data),
    ('Original dataset name: {} \n', args.original_data),
):
    logger.info(template.format(setting))

09/06/2022 14:36:09 - INFO - __main__ -   Model name: complex

09/06/2022 14:36:09 - INFO - __main__ -   Dataset name: com_add_3_complex_FB15k-237_1_1_1 

09/06/2022 14:36:09 - INFO - __main__ -   Original dataset name: FB15k-237 



In [158]:
## Load the target dataset and corresponding eval dictionaries
logger.info('------------ Load the target dataset ----------')
data_path = 'data/target_{}_{}_1'.format(args.model, args.original_data)

n_ent, n_rel, ent_to_id, rel_to_id = utils.generate_dicts(data_path)

data  = utils.load_data(data_path)
train_data, valid_data, test_data = data['train'], data['valid'], data['test']

## Filter dictionaries used when ranking. The context manager guarantees the
## file handle is closed even if unpickling fails.
## NOTE(security): pickle.load can execute arbitrary code -- only load files
## generated by this project, never untrusted ones.
with open(os.path.join(data_path, 'to_skip_eval.pickle'), 'rb') as inp_f:
    to_skip_eval: Dict[str, Dict[Tuple[int, int], List[int]]] = pickle.load(inp_f)

## Normalise every key to a tuple of plain Python ints for both ranking sides.
for side in ('lhs', 'rhs'):
    to_skip_eval[side] = {(int(k[0]), int(k[1])): v for k, v in to_skip_eval[side].items()}

09/06/2022 14:36:09 - INFO - __main__ -   ------------ Load the target dataset ----------


In [159]:
## example name of original model
## FB15k-237_distmult_200_0.5_0.3_0.3.model

## example name of poisoned model
## sym_add_1_distmult_FB15k-237_1_1_1_distmult_200_0.5_0.3_0.3.model

In [160]:
logger.info('-------- Load the original model -----------')
## set the model path without hyperparam arguments
model_dir = 'saved_models/{}_{}_*.model'.format(args.original_data, args.model)

## Resolve the wildcard explicitly. glob returns matches in arbitrary order and
## an empty result would otherwise leave model_path undefined (or stale from an
## earlier run of this notebook), so fail loudly and pick deterministically.
matching_models = sorted(glob.glob(model_dir))
if not matching_models:
    raise FileNotFoundError('No saved model matches pattern: {}'.format(model_dir))
if len(matching_models) > 1:
    logger.warning('Multiple saved models match {}: {}; using the last one'.format(model_dir, matching_models))
model_path = matching_models[-1]

# add a model and load the pre-trained params
original_model = utils.load_model(model_path, args, n_ent, n_rel, device)

09/06/2022 14:36:10 - INFO - __main__ -   -------- Load the original model -----------
09/06/2022 14:36:10 - INFO - utils -   Loading saved model from saved_models/FB15k-237_complex_200_0.5_0.3_0.3.model
09/06/2022 14:36:10 - INFO - utils -   Key:emb_e_real.weight, Size:torch.Size([14505, 200]), Count:2901000
09/06/2022 14:36:10 - INFO - utils -   Key:emb_e_img.weight, Size:torch.Size([14505, 200]), Count:2901000
09/06/2022 14:36:10 - INFO - utils -   Key:emb_rel_real.weight, Size:torch.Size([237, 200]), Count:47400
09/06/2022 14:36:10 - INFO - utils -   Key:emb_rel_img.weight, Size:torch.Size([237, 200]), Count:47400
09/06/2022 14:36:10 - INFO - utils -   Complex(
  (emb_e_real): Embedding(14505, 200)
  (emb_e_img): Embedding(14505, 200)
  (emb_rel_real): Embedding(237, 200)
  (emb_rel_img): Embedding(237, 200)
  (inp_drop): Dropout(p=0.5, inplace=False)
  (loss): BCEWithLogitsLoss()
)


In [161]:
logger.info('------- Ranks on target dataset from original model ----------')

### legacy code: get_ranking receives n_rel when reciprocal relations are in use, otherwise 0
num_rel = n_rel if args.add_reciprocals else 0

## move the target triples onto the evaluation device as an int64 tensor
test_data = torch.from_numpy(test_data.astype('int64')).to(device)

## ranks for both sides, converted to numpy arrays
ranks_lhs, ranks_rhs = (
    np.array(side_ranks)
    for side_ranks in evaluation.get_ranking(original_model, test_data, num_rel, to_skip_eval, device)
)

## per-triple mean of the lhs and rhs ranks
ranks = np.array([ranks_lhs, ranks_rhs]).mean(axis=0)

09/06/2022 14:36:10 - INFO - __main__ -   ------- Ranks on target dataset from original model ----------


In [162]:
## mean rank per side and overall, accumulated in float64
mr_lhs, mr_rhs, mr = (
    np.mean(rank_values, dtype=np.float64)
    for rank_values in (ranks_lhs, ranks_rhs, ranks)
)
### these should match the mean values from log files
original_summary = 'Original mean ranks. Lhs:{}, Rhs:{}, Mean:{}\n'.format(mr_lhs, mr_rhs, mr)
logger.info(original_summary)

09/06/2022 14:36:13 - INFO - __main__ -   Original mean ranks. Lhs:3.271501272264631, Rhs:2.589821882951654, Mean:2.9306615776081424



In [163]:
## Load the poisoned dataset and corresponding eval dictionaries
logger.info('------------ Load the poisoned dataset ----------')
data_path = 'data/{}'.format(args.data)

n_ent, n_rel, ent_to_id, rel_to_id = utils.generate_dicts(data_path)

data  = utils.load_data(data_path)
train_data, valid_data, test_data = data['train'], data['valid'], data['test']

## Filter dictionaries used when ranking. The context manager guarantees the
## file handle is closed even if unpickling fails.
## NOTE(security): pickle.load can execute arbitrary code -- only load files
## generated by this project, never untrusted ones.
with open(os.path.join(data_path, 'to_skip_eval.pickle'), 'rb') as inp_f:
    to_skip_eval: Dict[str, Dict[Tuple[int, int], List[int]]] = pickle.load(inp_f)

## Normalise every key to a tuple of plain Python ints for both ranking sides.
for side in ('lhs', 'rhs'):
    to_skip_eval[side] = {(int(k[0]), int(k[1])): v for k, v in to_skip_eval[side].items()}

09/06/2022 14:36:13 - INFO - __main__ -   ------------ Load the poisoned dataset ----------


In [164]:
# Read the JSON summary of adversarial edits stored next to the poisoned dataset
summary_path = os.path.join(data_path, 'summary_edits.json')
with open(summary_path) as f:
    summary_dict = json.load(f)

In [165]:
# Count the targets that received a full (3-element) adversarial triple on each side.
# The position of the s-side triple inside a summary entry depends on the attack:
# for 'com_add' attacks each side contributes two triples, so the entry is laid out
# as [_, o-side, o-side (2nd), s-side, s-side (2nd)] and the s-side triple sits at
# index 3; for the other attacks it sits at index 2. This mirrors the indexing used
# further down when the most/least affected target triples are extracted.
s_side_idx = 3 if 'com_add' in attack_method else 2

num_adv_o = 0
num_adv_s = 0
for key, value in summary_dict.items():
    adv_o = value[1]
    # tolerate short entries instead of raising; they simply do not count
    adv_s = value[s_side_idx] if len(value) > s_side_idx else []
    if len(adv_o) == 3:
        num_adv_o += 1
    if len(adv_s) == 3:
        num_adv_s += 1

In [166]:
#### sanity check: adversarial triples counted on both sides add up to twice the number of target triples
n_targets = test_data.shape[0]
assert 2 * n_targets == num_adv_o + num_adv_s

In [167]:
logger.info('-------- Load the poisoned model -----------')
## set the model path without hyperparam arguments
model_dir = 'saved_models/{}_{}_*.model'.format(args.data, args.model)

## Resolve the wildcard explicitly. Without this check an empty glob would
## silently leave model_path pointing at the original model selected earlier in
## the notebook, and the "poisoned" ranks would be computed with the wrong model.
## glob order is arbitrary, so sort to make the choice deterministic.
matching_models = sorted(glob.glob(model_dir))
if not matching_models:
    raise FileNotFoundError('No saved model matches pattern: {}'.format(model_dir))
if len(matching_models) > 1:
    logger.warning('Multiple saved models match {}: {}; using the last one'.format(model_dir, matching_models))
model_path = matching_models[-1]

# add a model and load the pre-trained params
poisoned_model = utils.load_model(model_path, args, n_ent, n_rel, device)

09/06/2022 14:36:14 - INFO - __main__ -   -------- Load the poisoned model -----------
09/06/2022 14:36:14 - INFO - utils -   Loading saved model from saved_models/com_add_3_complex_FB15k-237_1_1_1_complex_200_0.5_0.3_0.3.model
09/06/2022 14:36:14 - INFO - utils -   Key:emb_e_real.weight, Size:torch.Size([14505, 200]), Count:2901000
09/06/2022 14:36:14 - INFO - utils -   Key:emb_e_img.weight, Size:torch.Size([14505, 200]), Count:2901000
09/06/2022 14:36:14 - INFO - utils -   Key:emb_rel_real.weight, Size:torch.Size([237, 200]), Count:47400
09/06/2022 14:36:14 - INFO - utils -   Key:emb_rel_img.weight, Size:torch.Size([237, 200]), Count:47400
09/06/2022 14:36:14 - INFO - utils -   Complex(
  (emb_e_real): Embedding(14505, 200)
  (emb_e_img): Embedding(14505, 200)
  (emb_rel_real): Embedding(237, 200)
  (emb_rel_img): Embedding(237, 200)
  (inp_drop): Dropout(p=0.5, inplace=False)
  (loss): BCEWithLogitsLoss()
)


In [168]:
logger.info('------- Ranks on target dataset from poisoned model ----------')
logger.info('(using eval dicts from poisoned data)')

### legacy code: get_ranking receives n_rel when reciprocal relations are in use, otherwise 0
num_rel = n_rel if args.add_reciprocals else 0

## move the target triples onto the evaluation device as an int64 tensor
test_data = torch.from_numpy(test_data.astype('int64')).to(device)

## ranks for both sides, converted to numpy arrays
pos_ranks_lhs, pos_ranks_rhs = (
    np.array(side_ranks)
    for side_ranks in evaluation.get_ranking(poisoned_model, test_data, num_rel, to_skip_eval, device)
)

## per-triple mean of the lhs and rhs ranks
pos_ranks = np.array([pos_ranks_lhs, pos_ranks_rhs]).mean(axis=0)

09/06/2022 14:36:14 - INFO - __main__ -   ------- Ranks on target dataset from poisoned model ----------
09/06/2022 14:36:14 - INFO - __main__ -   (using eval dicts from poisoned data)


In [169]:
## mean rank per side and overall, accumulated in float64
pos_mr_lhs, pos_mr_rhs, pos_mr = (
    np.mean(rank_values, dtype=np.float64)
    for rank_values in (pos_ranks_lhs, pos_ranks_rhs, pos_ranks)
)
### these should match the mean values from log files
poisoned_summary = 'Poisoned mean ranks. Lhs:{}, Rhs:{}, Mean:{}\n'.format(pos_mr_lhs, pos_mr_rhs, pos_mr)
logger.info(poisoned_summary)

09/06/2022 14:36:17 - INFO - __main__ -   Poisoned mean ranks. Lhs:14.514758269720101, Rhs:6.40941475826972, Mean:10.462086513994912



In [170]:
## change in mean rank per target triple (poisoned minus original)
ranks_diff = pos_ranks - ranks
## positions of the targets ordered from smallest to largest change
sorted_idx = ranks_diff.argsort()
## the changes themselves, in that same ascending order
sorted_diffs = ranks_diff[sorted_idx]

In [171]:
# Bring the target triples back to a host-side NumPy array.
# The previous form only converted CUDA tensors: on a CPU-only machine the
# tensor was left untouched, so the row values unpacked later were tensor
# scalars and failed as keys of the id -> name dictionaries. The bare
# `except:` also hid any unrelated error. Convert every tensor, regardless
# of device, and fall back to np.array for anything else.
if isinstance(test_data, torch.Tensor):
    test_data = test_data.cpu().numpy()  # remove the torch tensor
else:
    test_data = np.array(test_data)

In [172]:
# get the entities from IDs
id_to_ent = {ent_to_id[k]:k for k in ent_to_id.keys()}
id_to_rel = {rel_to_id[k]:k for k in rel_to_id.keys()}


In [173]:
# Positions (in test_data) of the targets with the largest and smallest rank change,
# and the matching summary entries (the summary is keyed by the stringified position).
max_idx, min_idx = sorted_idx[-1], sorted_idx[0]
max_edits = summary_dict[str(max_idx)]
min_edits = summary_dict[str(min_idx)]

# target triples themselves
max_s, max_p, max_o = test_data[max_idx]
min_s, min_p, min_o = test_data[min_idx]

# slot 1 always holds the adversarial triple for the o-side
max_ho, max_ro, max_to = max_edits[1]
min_ho, min_ro, min_to = min_edits[1]

if 'com_add' in attack_method:
    # two adversarial triples per side: slots 1-2 for the o-side, slots 3-4 for the s-side
    max_hod, max_rod, max_tod = max_edits[2]
    max_hs, max_rs, max_ts = max_edits[3]
    max_hsd, max_rsd, max_tsd = max_edits[4]

    min_hod, min_rod, min_tod = min_edits[2]
    min_hs, min_rs, min_ts = min_edits[3]
    min_hsd, min_rsd, min_tsd = min_edits[4]
else:
    # one adversarial triple per side: slot 2 holds the s-side triple
    max_hs, max_rs, max_ts = max_edits[2]
    min_hs, min_rs, min_ts = min_edits[2]

In [174]:
def _triple_names(head_id, rel_id, tail_id):
    """Return [head, relation, tail] names for a triple given as IDs."""
    return [id_to_ent[head_id], id_to_rel[rel_id], id_to_ent[tail_id]]


# target with the largest rank change and its adversarial triples
max_target = _triple_names(max_s, max_p, max_o)
max_adv_o = _triple_names(max_ho, max_ro, max_to)
max_adv_s = _triple_names(max_hs, max_rs, max_ts)

# target with the smallest rank change and its adversarial triples
min_target = _triple_names(min_s, min_p, min_o)
min_adv_o = _triple_names(min_ho, min_ro, min_to)
min_adv_s = _triple_names(min_hs, min_rs, min_ts)

# 'com_add' attacks carry a second adversarial triple on each side
if 'com_add' in attack_method:
    max_adv_od = _triple_names(max_hod, max_rod, max_tod)
    max_adv_sd = _triple_names(max_hsd, max_rsd, max_tsd)

    min_adv_od = _triple_names(min_hod, min_rod, min_tod)
    min_adv_sd = _triple_names(min_hsd, min_rsd, min_tsd)

In [175]:
logger.info('---- For {} on {} {}\n'.format(attack_method, args.model, args.original_data))

# Message templates shared by the two report sections.
o_side_msg = 'Corresponding adversarial triple on o-side: {}\n'
s_side_msg = 'Corresponding adversarial triple on s-side: {}\n'
# 'com_add' attacks have a second adversarial triple per side, logged right after the first.
has_second_triple = 'com_add' in attack_method

# Assemble the (template, value) pairs in the order they are logged.
report_lines = [
    ('Maximum change in ranks: {}\n', sorted_diffs[-1]),
    ('Target triple with maximum change: {}\n', max_target),
    (o_side_msg, max_adv_o),
]
if has_second_triple:
    report_lines.append((o_side_msg, max_adv_od))
report_lines.append((s_side_msg, max_adv_s))
if has_second_triple:
    report_lines.append((s_side_msg, max_adv_sd))

report_lines.extend([
    ('Minimum change in ranks: {}\n', sorted_diffs[0]),
    ('Target triple with minimum change: {}\n', min_target),
    (o_side_msg, min_adv_o),
])
if has_second_triple:
    report_lines.append((o_side_msg, min_adv_od))
report_lines.append((s_side_msg, min_adv_s))
if has_second_triple:
    report_lines.append((s_side_msg, min_adv_sd))

for line_template, line_value in report_lines:
    logger.info(line_template.format(line_value))

09/06/2022 14:36:17 - INFO - __main__ -   ---- For com_add_3 on complex FB15k-237

09/06/2022 14:36:17 - INFO - __main__ -   Maximum change in ranks: 2180.5

09/06/2022 14:36:17 - INFO - __main__ -   Target triple with maximum change: ['/m/01wv9xn', '/common/topic/webpage./common/webpage/category', '/m/08mbj5d']

09/06/2022 14:36:17 - INFO - __main__ -   Corresponding adversarial triple on o-side: ['/m/01wv9xn', '/award/award_winning_work/awards_won./award/award_honor/award_winner', '/m/03rl84']

09/06/2022 14:36:17 - INFO - __main__ -   Corresponding adversarial triple on o-side: ['/m/03rl84', '/people/person/spouse_s./people/marriage/location_of_ceremony', '/m/030qb3t']

09/06/2022 14:36:17 - INFO - __main__ -   Corresponding adversarial triple on s-side: ['/m/01ckcd', '/award/award_winning_work/awards_won./award/award_honor/award_winner', '/m/0mjn2']

09/06/2022 14:36:17 - INFO - __main__ -   Corresponding adversarial triple on s-side: ['/m/0mjn2', '/people/person/spouse_s./people/m

Use the tool linked below to convert the Freebase IDs shown above (e.g. `/m/01wv9xn`) into human-readable values.

Link - https://freebase.toolforge.org/

Another method is to use the Google Knowledge Graph Search API

Link - https://developers.google.com/knowledge-graph/reference/rest/v1/

Original WN18RR dataset with definition files (to get entity values from IDs) - 
- Link1 - https://figshare.com/articles/dataset/WN18/11869548/2
- Link2 - https://everest.hds.utc.fr/doku.php?id=en:smemlj12