[Model] Update KG emb methods (#374)
* Small Changes

* Small Changes
QingFei1 committed Aug 16, 2022
1 parent d7372ec commit 1be3266
Showing 6 changed files with 21 additions and 13 deletions.
9 changes: 8 additions & 1 deletion cogdl/models/emb/complex.py
@@ -18,7 +18,14 @@ def add_args(parser):
         parser.add_argument("--embedding_size", type=int, default=500, help="Dimensionality of embedded vectors")
         parser.add_argument("--gamma", type=float,default=12.0, help="Hyperparameter for embedding")
         parser.add_argument("--double_entity_embedding", default=True)
-        parser.add_argument("--double_relation_embedding", default=True)
+        parser.add_argument("--double_relation_embedding", default=True)
+
+    def __init__(
+        self, nentity, nrelation, hidden_dim, gamma, double_entity_embedding, double_relation_embedding
+    ):
+        super(ComplEx, self).__init__(nentity, nrelation, hidden_dim, gamma, double_entity_embedding, double_relation_embedding)
+
+
     def score(self, head, relation, tail, mode):
         re_head, im_head = torch.chunk(head, 2, dim=2)
         re_relation, im_relation = torch.chunk(relation, 2, dim=2)
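For context, the torch.chunk calls in score split the doubled embeddings into real and imaginary halves; ComplEx scores a triple as Re(⟨h, r, conj(t)⟩). A minimal standalone sketch of that computation on plain tensors (not the cogdl API):

import torch

def complex_score(head, relation, tail):
    # Split the doubled embedding dimension into real and imaginary parts,
    # mirroring the torch.chunk(..., 2, dim=...) calls in ComplEx.score above.
    re_h, im_h = torch.chunk(head, 2, dim=-1)
    re_r, im_r = torch.chunk(relation, 2, dim=-1)
    re_t, im_t = torch.chunk(tail, 2, dim=-1)
    # Re(<h, r, conj(t)>), summed over the embedding dimension.
    return (re_h * re_r * re_t
            + im_h * re_r * im_t
            + re_h * im_r * im_t
            - im_h * im_r * re_t).sum(dim=-1)
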
2 changes: 1 addition & 1 deletion cogdl/models/emb/distmult.py
@@ -9,7 +9,7 @@ class DistMult(KGEModel):
     """

     def __init__(
-        self, nentity, nrelation, hidden_dim, gamma, double_entity_embedding=False, double_relation_embedding=False
+        self, nentity, nrelation, hidden_dim, gamma, double_entity_embedding, double_relation_embedding
     ):
         super(DistMult, self).__init__(
             nentity, nrelation, hidden_dim, gamma, double_entity_embedding, double_relation_embedding
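For comparison, DistMult keeps a single real-valued embedding per entity and relation and scores a triple with a trilinear product, which is why it never needs the doubled dimension. A one-function sketch of that score:

import torch

def distmult_score(head, relation, tail):
    # DistMult: <h, r, t> = sum_k h_k * r_k * t_k (a symmetric bilinear form).
    return (head * relation * tail).sum(dim=-1)
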
4 changes: 2 additions & 2 deletions cogdl/models/emb/knowledge_base.py
@@ -25,7 +25,7 @@ def build_model_from_args(cls, args):
         )

     def __init__(
-        self, nentity, nrelation, hidden_dim, gamma, double_entity_embedding=False, double_relation_embedding=False
+        self, nentity, nrelation, hidden_dim, gamma, double_entity_embedding, double_relation_embedding
     ):

         super(KGEModel, self).__init__()
@@ -39,7 +39,7 @@ def __init__(
         self.embedding_range = nn.Parameter(
             torch.Tensor([(self.gamma.item() + self.epsilon) / hidden_dim]), requires_grad=False
         )

         self.entity_dim = hidden_dim * 2 if double_entity_embedding else hidden_dim
         self.relation_dim = hidden_dim * 2 if double_relation_embedding else hidden_dim

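The constructor above derives embedding_range from gamma and doubles the embedding width when double_entity_embedding or double_relation_embedding is set. A sketch of how a base class like this typically allocates and initializes its tables; the uniform init and the epsilon value are assumptions not shown in this hunk, and the entity/relation counts are hypothetical example sizes:

import torch
import torch.nn as nn

# Illustrative sketch only: epsilon and the uniform init are assumed,
# nentity/nrelation are hypothetical example sizes.
nentity, nrelation, hidden_dim = 14541, 237, 500
gamma, epsilon = 12.0, 2.0
double_entity_embedding, double_relation_embedding = True, False

embedding_range = (gamma + epsilon) / hidden_dim
entity_dim = hidden_dim * 2 if double_entity_embedding else hidden_dim
relation_dim = hidden_dim * 2 if double_relation_embedding else hidden_dim

# One row per entity/relation, drawn uniformly from [-embedding_range, embedding_range].
entity_embedding = nn.Parameter(torch.zeros(nentity, entity_dim))
relation_embedding = nn.Parameter(torch.zeros(nrelation, relation_dim))
nn.init.uniform_(entity_embedding, a=-embedding_range, b=embedding_range)
nn.init.uniform_(relation_embedding, a=-embedding_range, b=embedding_range)
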
4 changes: 2 additions & 2 deletions cogdl/models/emb/rotate.py
@@ -22,9 +22,9 @@ def add_args(parser):
         parser.add_argument("--double_entity_embedding", default=True)
         parser.add_argument("--double_relation_embedding", action="store_true")
     def __init__(
-        self, nentity, nrelation, hidden_dim, gamma, double_entity_embedding=False, double_relation_embedding=False
+        self, nentity, nrelation, hidden_dim, gamma, double_entity_embedding, double_relation_embedding
     ):
-        super(RotatE, self).__init__(nentity, nrelation, hidden_dim, gamma, True, double_relation_embedding)
+        super(RotatE, self).__init__(nentity, nrelation, hidden_dim, gamma,double_entity_embedding, double_relation_embedding)

     def score(self, head, relation, tail, mode):
         pi = 3.14159265358979323846
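RotatE interprets the relation embedding as a phase (hence the pi constant above): each relation rotates the head in the complex plane, and the score is gamma minus the distance to the tail. A standalone sketch under that reading; the embedding_range scaling of the phase is carried over from the common RotatE formulation and is an assumption here, not shown in this hunk:

import torch

def rotate_score(head, relation, tail, gamma, embedding_range):
    pi = 3.14159265358979323846
    re_h, im_h = torch.chunk(head, 2, dim=-1)
    re_t, im_t = torch.chunk(tail, 2, dim=-1)
    # Map the single-width relation embedding to a phase in [-pi, pi].
    phase = relation / (embedding_range / pi)
    re_r, im_r = torch.cos(phase), torch.sin(phase)
    # Rotate the head by the relation phase and measure the distance to the tail.
    re_score = re_h * re_r - im_h * im_r - re_t
    im_score = re_h * im_r + im_h * re_r - im_t
    dist = torch.stack([re_score, im_score], dim=0).norm(dim=0).sum(dim=-1)
    return gamma - dist
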
4 changes: 2 additions & 2 deletions cogdl/models/emb/transe.py
@@ -13,9 +13,9 @@ class TransE(KGEModel):
     """

     def __init__(
-        self, nentity, nrelation, hidden_dim, gamma, double_entity_embedding=False, double_relation_embedding=False
+        self, nentity, nrelation, hidden_dim, gamma, double_entity_embedding, double_relation_embedding
     ):
-        super(TransE, self).__init__(nentity, nrelation, hidden_dim, gamma, True, True)
+        super(TransE, self).__init__(nentity, nrelation, hidden_dim, gamma, double_entity_embedding, double_relation_embedding)


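TransE likewise works on single-width embeddings, treating a relation as a translation vector; the usual score is gamma minus the L1 distance between the translated head and the tail. A minimal sketch of that formulation:

import torch

def transe_score(head, relation, tail, gamma):
    # TransE: the relation translates the head embedding toward the tail embedding.
    return gamma - torch.norm(head + relation - tail, p=1, dim=-1)
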
11 changes: 6 additions & 5 deletions cogdl/trainer/trainer_utils.py
@@ -53,16 +53,17 @@ def evaluation_comp(monitor, compare="<"):
 def save_model(model, path, epoch):
     print(f"Saving {epoch}-th model to {path} ...")
     torch.save(model.state_dict(), path)
-
+    model=model.model
+    emb_path=os.path.dirname(path)
     if hasattr(model, "entity_embedding"):
-        entity_embedding = model.entity_embedding.numpy()
+        entity_embedding = model.entity_embedding.detach().numpy()
         print('Saving entity_embedding to ',path)
-        np.save(os.path.join(path, "entity_embedding"), entity_embedding)
+        np.save(os.path.join(emb_path, "entity_embedding"), entity_embedding)

     if hasattr(model, "relation_embedding"):
-        relation_embedding = model.relation_embedding.numpy()
+        relation_embedding = model.relation_embedding.detach().numpy()
         print('Saving entity_embedding to ',path)
-        np.save(os.path.join(entity_embedding, "relation_embedding"), relation_embedding)
+        np.save(os.path.join(emb_path, "relation_embedding"), relation_embedding)


 def load_model(model, path):
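The corrected save_model detaches the embedding tensors before calling .numpy() (required once the parameters carry gradient history; a CUDA tensor would additionally need .cpu()) and writes the .npy files into the directory containing the checkpoint instead of treating the checkpoint file path as a directory. A sketch of reading those arrays back, with a hypothetical checkpoint path; the file names follow the np.save calls above, which append the .npy suffix:

import os
import numpy as np

checkpoint_path = "./checkpoints/model.pt"   # hypothetical path passed to save_model
emb_dir = os.path.dirname(checkpoint_path)   # same directory save_model writes to

entity_embedding = np.load(os.path.join(emb_dir, "entity_embedding.npy"))
relation_embedding = np.load(os.path.join(emb_dir, "relation_embedding.npy"))
print(entity_embedding.shape, relation_embedding.shape)
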
