In [1]:
import openke
from openke.config import Trainer, Tester
from openke.module.model import TransD
from openke.module.loss import MarginLoss
from openke.module.strategy import NegativeSampling
from openke.data import TrainDataLoader, TestDataLoader, EmbeddingDataLoader
import torch
import torch.nn.functional as F
import random
import time
from tqdm import tqdm

In [2]:
def compute_embedding(model, e1, e2):
    e1 = torch.tensor(e1, dtype=torch.long).to(device)
    e2 = torch.tensor(e2, dtype=torch.long).to(device)
    embed_1 = model.ent_embeddings(e1)
    embed_2 = model.ent_embeddings(e2)
    embedding = embed_1 - embed_2
    return embedding

In [3]:
lr = 5e-6
n_cluster = 3
device = torch.device('cuda')

In [4]:
Cosine_dataloader = TrainDataLoader(
    in_path=None,
    tri_file='./benchmarks/YAGO3-10/train2id.txt',
    ent_file="./benchmarks/YAGO3-10/entity2id.txt",
    rel_file="./benchmarks/YAGO3-10/relation2id.txt",
    nbatches=100,
    threads=8,
    sampling_mode="normal",
    bern_flag=1,
    filter_flag=1,
    neg_ent=25,
    neg_rel=0)
Cosine_model = TransD(
	ent_tot = Cosine_dataloader.get_ent_tot(),
	rel_tot = Cosine_dataloader.get_rel_tot(),
	dim_e = 200, 
	dim_r = 200, 
	p_norm = 1, 
	norm_flag = True)

Cosine_model.to(device)
Cosine_model.load_checkpoint('./checkpoint/YAGO/YAGO_TransD.ckpt')

In [5]:
tri_file = './benchmarks/YAGO3-10/train2id.txt'
unlearn_file = './benchmarks/YAGO3-10/deleted_node_unlearning.txt'
schema_file = './benchmarks/YAGO3-10/type_constrain.txt'
weight_file = './checkpoint/YAGO/YAGO_TransD.ckpt'

Schema_dataloader = EmbeddingDataLoader.CosineSchemaDataLoader(
    n_clusters=n_cluster,
    tri_file=tri_file,
    unlearn_file=unlearn_file,
    schema_file=schema_file,
    weight_file=weight_file)
Cosine_Sampling = NegativeSampling(
	model = Cosine_model, 
	loss = MarginLoss(margin = 5.0),
	batch_size = Cosine_dataloader.get_batch_size()
)

trainer = Trainer(model = Cosine_Sampling, 
                  data_loader = Cosine_dataloader, 
                  train_times = 1000, 
                  alpha = lr, 
                  use_gpu = True)
trainer.optimizer = torch.optim.Adam(
    trainer.model.parameters(),
    lr=trainer.alpha,
    weight_decay=trainer.weight_decay,
)

  super()._check_params_vs_input(X, default_n_init=10)


In [6]:
start_time = time.time()
total_loss = 0.0


with tqdm(total=len(Schema_dataloader.removed_triples), desc="Processing triples") as pbar:
    for idx, data in enumerate(Schema_dataloader.removed_triples):
        trainer.optimizer.zero_grad()
        max_similarity = float('-inf')
        all_iterations = 10
        loss_value = 0.0
        e1, e2, e3, e4 = None, None, None, None
        attempts = 0
        max_attempts = 10
        while (e1 is None or e2 is None or e3 is None or e4 is None) and attempts < max_attempts:
            e1, e2, e3, e4 = Schema_dataloader.query_match_entity(Schema_dataloader.triples, 
                                                           Schema_dataloader.adj_matrix, 
                                                           data,
                                                           Schema_dataloader.labels)
            attempts += 1
        if attempts >= max_attempts:
            pbar.update(1)
            continue
            
        Embed_Query = compute_embedding(Cosine_model, e1, e2)
        Embed_Match = compute_embedding(Cosine_model, e3, e4)
        cosine_similarity = torch.nn.functional.cosine_similarity(Embed_Query, Embed_Match, dim=0)
        if cosine_similarity > max_similarity:
            max_similarity = cosine_similarity
        loss_value = 1 - max_similarity
        loss_value.backward()
        total_loss += loss_value.item()
        trainer.optimizer.step()
        pbar.set_description(f"Processing triples (Loss: {total_loss / (idx + 1):.4f})")
        pbar.update(1)
print(f'Running Time: {time.time() - start_time}s')

Processing triples (Loss: 0.6570):   0%|                                           | 17/95344 [00:00<8:34:03,  3.09it/s]

Training Files Path : ./benchmarks/YAGO3-10/train2id.txt
Entity Files Path : ./benchmarks/YAGO3-10/entity2id.txt
Relation Files Path : ./benchmarks/YAGO3-10/relation2id.txt
The toolkit is importing datasets.
The total of relations is 37.
The total of entities is 123182.
The total of train triples is 1079040.


Processing triples (Loss: 0.5794): 100%|██████████████████████████████████████████| 95344/95344 [16:49<00:00, 94.44it/s]

Running Time: 1009.6234107017517s





In [7]:
Cosine_model.save_checkpoint(f'./checkpoint/YAGO/Entity_Unlearning_TransD_YAGO.ckpt')

test_dataloader = TestDataLoader("./benchmarks/YAGO3-10/", "link")
tester = Tester(model = Cosine_model, data_loader = test_dataloader, use_gpu = True)
tester.run_link_prediction(type_constrain = False)


Input Files Path : ./benchmarks/YAGO3-10/
The total of test triples is 5000.
The total of valid triples is 5000.


100%|██████████████████████████████████████████████████████████████████████████████| 5000/5000 [00:37<00:00, 135.12it/s]

0.510699987411499





(0.29329913854599,
 1541.779052734375,
 0.510699987411499,
 0.3443000018596649,
 0.18140000104904175)

no type constraint results:
metric:			 MRR 		 MR 		 hit@10 	 hit@3  	 hit@1 
l(raw):			 0.044089 	 4716.991211 	 0.108600 	 0.036600 	 0.009400 
r(raw):			 0.206214 	 473.699799 	 0.564000 	 0.198400 	 0.075800 
averaged(raw):		 0.125151 	 2595.345459 	 0.336300 	 0.117500 	 0.042600 

l(filter):		 0.165946 	 2615.651855 	 0.307200 	 0.184200 	 0.092800 
r(filter):		 0.420652 	 467.906189 	 0.714200 	 0.504400 	 0.270000 
averaged(filter):	 0.293299 	 1541.779053 	 0.510700 	 0.344300 	 0.181400 
0.510700
