In [1]:
import openke
import torch
import torch.autograd
import time
import random
from openke.config import Trainer, Tester
from openke.module.model import TransE
from openke.module.loss import MarginLoss
from openke.module.strategy import NegativeSampling
from openke.data import TrainDataLoader, TestDataLoader

In [127]:
def filter_triples(input_entity_file, input_triple_file, output_triple_file, remove_num=500):
    with open(input_entity_file, 'r') as f:
        total_entities = int(f.readline().strip())

    selected_indices = set(random.sample(range(total_entities), remove_num))

    with open(input_triple_file, 'r') as f:
        lines = f.readlines()

    header = lines[0]  
    triples = lines[1:]  

    filtered_triples = []
    for triple in triples:
        h, t, r = map(int, triple.split())
        if h not in selected_indices and t not in selected_indices:
            filtered_triples.append(triple)

    with open(output_triple_file, 'w') as f:
        f.write(str(len(filtered_triples)) + '\n') 
        f.writelines(filtered_triples)  

def filter_relations(input_relation_file, input_type_file, input_triple_file, output_triple_file, remove_num=5):
    # 读取关系文件以获取关系总数
    with open(input_relation_file, 'r') as f:
        f.readline()  # Skip the first line which contains the total number of relations
        relations = [line.split()[-1] for line in f]  # Assume relation ID is the second column

    # 随机选择要删除的关系
    selected_relations = set(random.sample(relations, remove_num))

    # 读取三元组文件，并过滤掉包含选定关系的三元组
    with open(input_triple_file, 'r') as f:
        lines = f.readlines()

    header = lines[0]  # Header with the total count of triples, which we ignore
    triples = lines[1:]  # Actual triples start from the second line

    filtered_triples = []
    for triple in triples:
        _, r, _ = triple.split()
        if r not in selected_relations:
            filtered_triples.append(triple)

    # 将过滤后的三元组写入输出文件
    with open(output_triple_file, 'w') as f:
        f.write(str(len(filtered_triples)) + '\n')  # Write the count of remaining triples
        f.writelines(filtered_triples)  # Write the triples
        

def calculate_gradients(model, data):
    model.eval()  

    loss = model.model({
        'batch_h': torch.autograd.Variable(torch.from_numpy(data['batch_h']).cuda()),
        'batch_t': torch.autograd.Variable(torch.from_numpy(data['batch_t']).cuda()),
        'batch_r': torch.autograd.Variable(torch.from_numpy(data['batch_r']).cuda()),
        'batch_y': torch.autograd.Variable(torch.from_numpy(data['batch_y']).cuda()),
        'mode': data['mode']
    })
    loss_scalar = torch.mean(loss)
    params = filter(lambda p: p.requires_grad, model.parameters())
    params_to_update = [param for name, param in model.named_parameters() if name.endswith('.weight')]
    return torch.autograd.grad(loss_scalar, params_to_update, create_graph=True)

def hvps(grad_all, model_params, h_estimate):
    element_product = 0
    for grad_elem, v_elem in zip(grad_all, h_estimate):
        element_product += torch.sum(grad_elem * v_elem)
    return_grads = torch.autograd.grad(element_product, model_params, create_graph=True)
    
    return return_grads
    
def GIF_unleanring(model, train_dataloader, test_dataloader, iteration=1, damp=0.0, scale=50):
    start_time = time.time()

    for data in train_dataloader:
        grad_full = calculate_gradients(model, data)

    for data in test_dataloader:
        grad_removed = calculate_gradients(model, data)

    grad1 = [g1 - g2 for g1, g2 in zip(grad_full, grad_removed)]
    grad2 = grad_removed
    res_tuple = (grad_full, grad1, grad2)

    v = tuple(grad1 - grad2 for grad1, grad2 in zip(res_tuple[1], res_tuple[2]))
    h_estimate = tuple(grad1 - grad2 for grad1, grad2 in zip(res_tuple[1], res_tuple[2]))
    
    for _ in range(iteration):
        model_params  = [p for p in model.parameters() if p.requires_grad]
        hv = hvps(res_tuple[0], model_params, h_estimate)
        with torch.no_grad():
            h_estimate = [ v1 + (1-damp)*h_estimate1 - hv1/scale for v1, h_estimate1, hv1 in zip(v, h_estimate, hv)]
            
    params_change = [h_est / scale for h_est in h_estimate]
    params_esti   = [p1 + p2 for p1, p2 in zip(params_change, model_params)]

    print(time.time() - start_time)
    
    return params_esti

def update_and_save_checkpoint(checkpoint_path, new_checkpoint_path, new_params):
    weights = torch.load(checkpoint_path)
    weights['ent_embeddings.weight'] = new_params[0]
    weights['rel_embeddings.weight'] = new_params[1]
    torch.save(weights, new_checkpoint_path)
    print(f"Updated checkpoint saved to {new_checkpoint_path}")

In [128]:
# removed_entities = 3000
# removed_files = f'./benchmarks/FB15K237/removed_{removed_entities}_train2id.txt'
# filter_triples('./benchmarks/FB15K237/entity2id.txt', './benchmarks/FB15K237/train2id.txt', removed_files, removed_entities)

In [129]:
removed_relations = 50
removed_files = f'./benchmarks/FB15K237/removed_{removed_relations}_train2id.txt'
filter_relations(input_relation_file='./benchmarks/FB15K237/relation2id.txt', input_triple_file='./benchmarks/FB15K237/train2id.txt',
                 output_triple_file = removed_files, remove_num=removed_relations)

In [130]:
retrain_dataloader = TrainDataLoader(
	in_path = None, 
    tri_file = removed_files,
    ent_file = "./benchmarks/FB15K237/entity2id.txt",
    rel_file = "./benchmarks/FB15K237/relation2id.txt",
	nbatches = 100,
	threads = 8, 
	sampling_mode = "normal", 
	bern_flag = 1, 
	filter_flag = 1, 
	neg_ent = 25,
	neg_rel = 0)

Training Files Path : ./benchmarks/FB15K237/removed_50_train2id.txt
Entity Files Path : ./benchmarks/FB15K237/entity2id.txt
Relation Files Path : ./benchmarks/FB15K237/relation2id.txt
The toolkit is importing datasets.
The total of relations is 237.
The total of entities is 14541.
The total of train triples is 258816.


In [119]:
test_dataloader = TestDataLoader("./benchmarks/FB15K237/", "link")

Input Files Path : ./benchmarks/FB15K237/
The total of test triples is 20466.
The total of valid triples is 17535.


In [105]:
retrain_transe = TransE(
	ent_tot = retrain_dataloader.get_ent_tot(),
	rel_tot = retrain_dataloader.get_rel_tot(),
	dim = 200, 
	p_norm = 1, 
	norm_flag = True)

In [106]:
retrain_model = NegativeSampling(
	model = retrain_transe, 
	loss = MarginLoss(margin = 5.0),
	batch_size = retrain_dataloader.get_batch_size()
)

In [107]:
retrain_trainer = Trainer(model = retrain_model, data_loader = retrain_dataloader, train_times = 1000, alpha = 1.0, use_gpu = True)
retrain_trainer.run()
retrain_transe.save_checkpoint(f'./checkpoint/retrain_{removed_relations}_transe.ckpt')

Finish initializing...


Epoch 999 | loss: 2.365449: 100%|███████████████████████████████████████████████████| 1000/1000 [17:00<00:00,  1.02s/it]


In [108]:
retrain_transe.load_checkpoint(f'./checkpoint/retrain_{removed_relations}_transe.ckpt')
retrain_tester = Tester(model = retrain_transe, data_loader = test_dataloader, use_gpu = True)
retrain_tester.run_link_prediction(type_constrain = False)

100%|███████████████████████████████████████████████████████████████████████████| 20466/20466 [00:16<00:00, 1226.95it/s]

0.4750317335128784





(0.28603631258010864,
 227.21011352539062,
 0.4750317335128784,
 0.3225349187850952,
 0.19041337072849274)

In [109]:
train_dataloader = TrainDataLoader(
	in_path = None, 
    tri_file = "./benchmarks/FB15K237/train2id.txt",
    ent_file = "./benchmarks/FB15K237/entity2id.txt",
    rel_file = "./benchmarks/FB15K237/relation2id.txt",
	nbatches = 100,
	threads = 8, 
	sampling_mode = "normal", 
	bern_flag = 1, 
	filter_flag = 1, 
	neg_ent = 25,
	neg_rel = 0)
print('----------------------------------------------------------------------')
retrain_dataloader = TrainDataLoader(
	in_path = None, 
    tri_file = removed_files,
    ent_file = "./benchmarks/FB15K237/entity2id.txt",
    rel_file = "./benchmarks/FB15K237/relation2id.txt",
	nbatches = 100,
	threads = 8, 
	sampling_mode = "normal", 
	bern_flag = 1, 
	filter_flag = 1, 
	neg_ent = 25,
	neg_rel = 0)

no type constraint results:
metric:			 MRR 		 MR 		 hit@10 	 hit@3  	 hit@1 
l(raw):			 0.088172 	 566.850403 	 0.204192 	 0.084286 	 0.033030 
r(raw):			 0.249192 	 168.386154 	 0.439509 	 0.271328 	 0.156650 
averaged(raw):		 0.168682 	 367.618286 	 0.321851 	 0.177807 	 0.094840 

l(filter):		 0.188467 	 311.960815 	 0.357129 	 0.210349 	 0.105052 
r(filter):		 0.383606 	 142.459396 	 0.592935 	 0.434721 	 0.275774 
averaged(filter):	 0.286036 	 227.210114 	 0.475032 	 0.322535 	 0.190413 
0.475032
Training Files Path : ./benchmarks/FB15K237/train2id.txt
Entity Files Path : ./benchmarks/FB15K237/entity2id.txt
Relation Files Path : ./benchmarks/FB15K237/relation2id.txt
The toolkit is importing datasets.
The total of relations is 237.
The total of entities is 14541.
----------------------------------------------------------------------
The total of train triples is 272115.
Training Files Path : ./benchmarks/FB15K237/removed_10_train2id.txt
Entity Files Path : ./benchmarks/FB15K237/ent

In [110]:
model = TransE(
	ent_tot = train_dataloader.get_ent_tot(),
	rel_tot = train_dataloader.get_rel_tot(),
	dim = 200, 
	p_norm = 1, 
	norm_flag = True)
model.to('cuda')
model.load_checkpoint('./checkpoint/transe.ckpt')
model = NegativeSampling(
	model = model, 
	loss = MarginLoss(margin = 5.0),
	batch_size = train_dataloader.get_batch_size()
)

The total of train triples is 270934.


In [114]:
params_esti = GIF_unleanring(model, train_dataloader, retrain_dataloader, iteration=100, damp=0.0, scale=50)

3.6354987621307373


In [115]:
update_and_save_checkpoint(checkpoint_path='./checkpoint/transe.ckpt', 
                           new_checkpoint_path=f'./checkpoint/GIF_{removed_relations}_TransE.ckpt', 
                           new_params=params_esti)

Updated checkpoint saved to ./checkpoint/GIF_10_TransE.ckpt


In [116]:
# dataloader for test
test_dataloader = TestDataLoader("./benchmarks/FB15K237/", "link")

# define the model
unlearn_transe = TransE(
	ent_tot = train_dataloader.get_ent_tot(),
	rel_tot = train_dataloader.get_rel_tot(),
	dim = 200, 
	p_norm = 1, 
	norm_flag = True)

# test the model
unlearn_transe.load_checkpoint(f'./checkpoint/GIF_{removed_entities}_TransE.ckpt')
unlearn_tester = Tester(model = unlearn_transe, data_loader = test_dataloader, use_gpu = True)
unlearn_tester.run_link_prediction(type_constrain = False)

Input Files Path : ./benchmarks/FB15K237/
The total of test triples is 20466.
The total of valid triples is 17535.


100%|███████████████████████████████████████████████████████████████████████████| 20466/20466 [00:16<00:00, 1215.12it/s]

0.47078078985214233





(0.27326950430870056,
 231.47970581054688,
 0.47078078985214233,
 0.3112967610359192,
 0.17475324869155884)

no type constraint results:
metric:			 MRR 		 MR 		 hit@10 	 hit@3  	 hit@1 
l(raw):			 0.088247 	 569.296326 	 0.206049 	 0.086680 	 0.032200 
r(raw):			 0.247412 	 167.087662 	 0.436871 	 0.269032 	 0.155135 
averaged(raw):		 0.167830 	 368.191986 	 0.321460 	 0.177856 	 0.093668 

l(filter):		 0.184681 	 321.508118 	 0.352731 	 0.207564 	 0.101437 
r(filter):		 0.361858 	 141.451279 	 0.588830 	 0.415030 	 0.248070 
averaged(filter):	 0.273270 	 231.479706 	 0.470781 	 0.311297 	 0.174753 
0.470781
