In [1]:
import openke
from openke.config import Trainer, Tester
from openke.module.model import TransH
from openke.module.loss import MarginLoss
from openke.module.strategy import NegativeSampling
from openke.data import TrainDataLoader, TestDataLoader, EmbeddingDataLoader
import torch
import torch.nn.functional as F
import random
import time
from tqdm import tqdm

In [2]:
def compute_embedding(model, e1, e2):
    """Return the difference between the embeddings of two entities.

    Args:
        model: Object exposing an ``ent_embeddings`` embedding layer
            (here the TransH model built later in the notebook).
        e1: Entity id, or a sequence/tensor of ids, for the first entity.
        e2: Entity id(s) for the second entity; must be shaped like ``e1``.

    Returns:
        ``model.ent_embeddings(e1) - model.ent_embeddings(e2)``. The result
        stays attached to the autograd graph so it can be back-propagated.
    """
    # Build the index tensors on the device that actually holds the embedding
    # table instead of relying on a notebook-global `device` that is only
    # defined in a later cell.
    target_device = model.ent_embeddings.weight.device
    # as_tensor avoids the extra copy (and the UserWarning) that torch.tensor
    # produces when the ids are already tensors.
    e1 = torch.as_tensor(e1, dtype=torch.long, device=target_device)
    e2 = torch.as_tensor(e2, dtype=torch.long, device=target_device)
    return model.ent_embeddings(e1) - model.ent_embeddings(e2)

In [3]:
# Settings shared by the cells below.
lr = 5e-6                      # learning rate; passed to Trainer(alpha=...) and on to Adam
n_cluster = 3                  # n_clusters argument of CosineSchemaDataLoader
device = torch.device("cuda")  # device the model is moved to

In [4]:
# Training-triple loader for YAGO3-10. It supplies the entity/relation
# counts for the model and the batch size used by NegativeSampling.
Cosine_dataloader = TrainDataLoader(
    in_path=None,
    tri_file='./benchmarks/YAGO3-10/train2id.txt',
    ent_file="./benchmarks/YAGO3-10/entity2id.txt",
    rel_file="./benchmarks/YAGO3-10/relation2id.txt",
    nbatches=100,
    threads=8,
    sampling_mode="normal",
    bern_flag=1,
    filter_flag=1,
    neg_ent=25,
    neg_rel=0,
)

# TransH model sized from the loader's vocabulary.
n_entities = Cosine_dataloader.get_ent_tot()
n_relations = Cosine_dataloader.get_rel_tot()
Cosine_model = TransH(
    ent_tot=n_entities,
    rel_tot=n_relations,
    dim=200,
    p_norm=1,
    norm_flag=True,
)

# Move the model to the target device first, then restore the pretrained
# weights that the unlearning loop starts from.
Cosine_model.to(device)
Cosine_model.load_checkpoint('./checkpoint/YAGO/YAGO_TransH.ckpt')

In [5]:
# Input files for the schema-aware loader.
tri_file = './benchmarks/YAGO3-10/train2id.txt'
unlearn_file = './benchmarks/YAGO3-10/deleted_node_unlearning.txt'
schema_file = './benchmarks/YAGO3-10/type_constrain.txt'
weight_file = './checkpoint/YAGO/YAGO_TransH.ckpt'

# Loader that exposes the removed triples and the query/match lookup
# consumed by the unlearning loop.
Schema_dataloader = EmbeddingDataLoader.CosineSchemaDataLoader(
    n_clusters=n_cluster,
    tri_file=tri_file,
    unlearn_file=unlearn_file,
    schema_file=schema_file,
    weight_file=weight_file,
)

# Wrap the model with a margin loss and negative sampling.
margin_loss = MarginLoss(margin=5.0)
Cosine_Sampling = NegativeSampling(
    model=Cosine_model,
    loss=margin_loss,
    batch_size=Cosine_dataloader.get_batch_size(),
)

trainer = Trainer(
    model=Cosine_Sampling,
    data_loader=Cosine_dataloader,
    train_times=1000,
    alpha=lr,
    use_gpu=True,
)

# Attach an Adam optimizer directly to the trainer; the unlearning loop
# calls trainer.optimizer.zero_grad() / .step() itself.
adam = torch.optim.Adam(
    trainer.model.parameters(),
    lr=trainer.alpha,
    weight_decay=trainer.weight_decay,
)
trainer.optimizer = adam

  super()._check_params_vs_input(X, default_n_init=10)


In [None]:
# Unlearning pass: for every removed triple, pull the (e1 - e2) embedding
# difference towards the (e3 - e4) difference of a matched entity pair by
# minimising 1 - cosine similarity.

# How many times a query/match quadruple is requested for one removed triple
# before that triple is skipped.
MAX_MATCH_ATTEMPTS = 10

start_time = time.time()
total_loss = 0.0
n_updates = 0  # triples that actually produced an optimizer step
n_skipped = 0  # triples with no usable quadruple or a non-finite loss

removed_triples = Schema_dataloader.removed_triples
# Iterating over the tqdm object advances the bar on every path (including
# `continue`), so no manual pbar.update() bookkeeping is needed.
with tqdm(removed_triples, total=len(removed_triples), desc="Processing triples") as pbar:
    for data in pbar:
        # Retry while any of the four returned entities is None. The for/else
        # only skips the triple when *every* attempt failed, so a quadruple
        # found on the last attempt is still used.
        for _ in range(MAX_MATCH_ATTEMPTS):
            e1, e2, e3, e4 = Schema_dataloader.query_match_entity(
                Schema_dataloader.triples,
                Schema_dataloader.adj_matrix,
                data,
                Schema_dataloader.labels)
            if all(entity is not None for entity in (e1, e2, e3, e4)):
                break
        else:
            n_skipped += 1
            continue

        embed_query = compute_embedding(Cosine_model, e1, e2)
        embed_match = compute_embedding(Cosine_model, e3, e4)
        similarity = F.cosine_similarity(embed_query, embed_match, dim=0)
        loss_value = 1 - similarity

        # A NaN/inf loss would write NaNs into the embeddings on step();
        # leave the weights untouched and count the triple as skipped.
        if not torch.isfinite(loss_value).all():
            n_skipped += 1
            continue

        trainer.optimizer.zero_grad()
        loss_value.backward()
        trainer.optimizer.step()

        total_loss += loss_value.item()
        n_updates += 1
        # Average over triples that contributed a loss, not over all triples
        # seen so far, so skipped triples do not deflate the reported value.
        pbar.set_description(f"Processing triples (Loss: {total_loss / n_updates:.4f})")

print(f'Updated on {n_updates} triples, skipped {n_skipped}')
print(f'Running Time: {time.time() - start_time}s')

Processing triples (Loss: 0.9433):   0%|                                           | 11/95344 [00:00<8:35:25,  3.08it/s]

Training Files Path : ./benchmarks/YAGO3-10/train2id.txt
Entity Files Path : ./benchmarks/YAGO3-10/entity2id.txt
Relation Files Path : ./benchmarks/YAGO3-10/relation2id.txt
The toolkit is importing datasets.
The total of relations is 37.
The total of entities is 123182.
The total of train triples is 1079040.


Processing triples (Loss: 0.6266):  50%|████████████████████▍                    | 47544/95344 [08:18<07:32, 105.63it/s]

In [7]:
# Persist the model as it stands after the unlearning cell.
unlearned_ckpt = './checkpoint/YAGO/Entity_Unlearning_TransH_YAGO.ckpt'
Cosine_model.save_checkpoint(unlearned_ckpt)

# Link-prediction evaluation on the YAGO3-10 test split, without type
# constraints. The call is left as the last expression so the returned
# metrics tuple is displayed as the cell output.
test_dataloader = TestDataLoader("./benchmarks/YAGO3-10/", "link")
tester = Tester(
    model=Cosine_model,
    data_loader=test_dataloader,
    use_gpu=True,
)
tester.run_link_prediction(type_constrain=False)


100%|██████████████████████████████████████████████████████████████████████████████| 5000/5000 [00:32<00:00, 153.33it/s]

0.5518999695777893





(0.33844101428985596,
 2177.32861328125,
 0.5518999695777893,
 0.400299996137619,
 0.227400004863739)

no type constraint results:
metric:			 MRR 		 MR 		 hit@10 	 hit@3  	 hit@1 
l(raw):			 0.039901 	 5523.730957 	 0.099000 	 0.031600 	 0.005800 
r(raw):			 0.215559 	 731.536194 	 0.590800 	 0.209400 	 0.088400 
averaged(raw):		 0.127730 	 3127.633545 	 0.344900 	 0.120500 	 0.047100 

l(filter):		 0.201675 	 3628.959961 	 0.376600 	 0.230000 	 0.114800 
r(filter):		 0.475207 	 725.697388 	 0.727200 	 0.570600 	 0.340000 
averaged(filter):	 0.338441 	 2177.328613 	 0.551900 	 0.400300 	 0.227400 
0.551900
