forked from uma-pi1/kge
-
Notifications
You must be signed in to change notification settings - Fork 0
/
rotate.py
213 lines (167 loc) · 7.38 KB
/
rotate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
import torch
import math
from kge import Config, Dataset
from kge.job import Job
from kge.model.kge_model import RelationalScorer, KgeModel
from torch.nn import functional as F
# TODO sp_ and _po scoring with RotatE leads to *large* intermediate results. It's
# unclear whether this can be fixed. Expect out-of-memory errors when using RotatE with
# 1vsAll or KvsAll training. To do validation/evaluation, you may want to set
# eval.chunk_size.
class RotatEScorer(RelationalScorer):
    r"""Implementation of the RotatE KGE scorer.

    Entity embeddings are complex vectors stored as the concatenation of
    their real and imaginary halves; relation embeddings are rotation
    phases in radians. The score is the negative lp-norm of s*p - o in
    complex space.
    """

    def __init__(self, config: Config, dataset: Dataset, configuration_key=None):
        super().__init__(config, dataset, configuration_key)
        # which lp-norm to take over the per-dimension distances
        self._norm = self.get_option("l_norm")

    def score_emb(self, s_emb, p_emb, o_emb, combine: str):
        n = p_emb.size(0)
        # split entity embeddings into real and imaginary halves
        s_re, s_im = torch.chunk(s_emb, 2, dim=1)
        o_re, o_im = torch.chunk(o_emb, 2, dim=1)
        # map relation phases (radians) to points on the complex unit circle
        p_re, p_im = torch.cos(p_emb), torch.sin(p_emb)

        if combine == "spo":
            # difference vector s*p - o, one row per triple
            sp_re, sp_im = hadamard_complex(s_re, s_im, p_re, p_im)
            d_re, d_im = diff_complex(sp_re, sp_im, o_re, o_im)
            # elementwise complex magnitudes, then the lp-norm over dims
            d_abs = abs_complex(d_re, d_im)
            out = -norm_nonnegative(d_abs, dim=1, p=self._norm)
        elif combine == "sp_":
            # as above, but pair each sp-pair with every object
            sp_re, sp_im = hadamard_complex(s_re, s_im, p_re, p_im)  # sp x dim
            d_re, d_im = pairwise_diff_complex(
                sp_re, sp_im, o_re, o_im
            )  # sp x o x dim
            d_abs = abs_complex(d_re, d_im)  # sp x o x dim
            out = -norm_nonnegative(d_abs, dim=2, p=self._norm)
        elif combine == "_po":
            # use || s*p - o || = || s - cc(p)*o || for a rotation p: rotate
            # the tail backwards with the complex conjugate of the relation
            po_re, po_im = hadamard_complex(p_re, -p_im, o_re, o_im)  # po x dim
            d_re, d_im = pairwise_diff_complex(
                po_re, po_im, s_re, s_im
            )  # po x s x dim
            d_abs = abs_complex(d_re, d_im)  # po x s x dim
            out = -norm_nonnegative(d_abs, dim=2, p=self._norm)
        else:
            return super().score_emb(s_emb, p_emb, o_emb, combine)
        return out.view(n, -1)
class RotatE(KgeModel):
    r"""Implementation of the RotatE KGE model.

    Entities are embedded as complex vectors (stored as the concatenation of
    real and imaginary halves, hence the even-dimensionality requirement);
    relations are embedded as rotation phases in radians.
    """
    def __init__(
        self,
        config: Config,
        dataset: Dataset,
        configuration_key=None,
        init_for_load_only=False,
    ):
        # read this model's options before the base class initializes the
        # embedders so that the relation dimensionality can be derived below
        self._init_configuration(config, configuration_key)
        # entity embeddings are chunked into real/imaginary halves by the
        # scorer, so they must have even dimensionality
        if self.get_option("entity_embedder.dim") % 2 != 0:
            raise ValueError(
                "RotatE requires embeddings of even dimensionality"
                " (got {})".format(self.get_option("entity_embedder.dim"))
            )
        # NOTE(review): a negative dim appears to mean "auto-derive"; use one
        # phase per complex entity coordinate (half the entity dimensionality)
        if self.get_option("relation_embedder.dim") < 0:
            self.set_option(
                "relation_embedder.dim",
                self.get_option("entity_embedder.dim") // 2,
                log=True,
            )
        super().__init__(
            config=config,
            dataset=dataset,
            scorer=RotatEScorer,
            configuration_key=self.configuration_key,
            init_for_load_only=init_for_load_only,
        )
        # whether to renormalize relation phases into [-pi, pi) during training
        self._normalize_phases = self.get_option("normalize_phases")
    @torch.no_grad()
    def normalize_phases(self):
        """Map all relation phases into [-pi, pi) in place."""
        out = self.get_p_embedder()._embeddings.weight.data
        # normalize phases so that they lie in [-pi,pi]
        # TODO this is a hack that assumes that we use a lookup embedder
        # first shift phases by pi
        out = out + math.pi
        # compute the modulo (result then in [0,2*pi))
        out = torch.remainder(out, 2.0 * math.pi)
        # shift back
        out = out - math.pi
        # write back the updated embeddings
        self.get_p_embedder()._embeddings.weight.data[:] = out[:]
    def prepare_job(self, job: Job, **kwargs):
        """Register phase-normalization hooks on training jobs (if enabled)."""
        from kge.job import TrainingJob
        super().prepare_job(job, **kwargs)
        if self._normalize_phases and isinstance(job, TrainingJob):
            from kge.model import LookupEmbedder
            # phase normalization writes directly into the lookup embedder's
            # weight tensor, so other relation embedder types are unsupported
            if not isinstance(self.get_p_embedder(), LookupEmbedder):
                raise ValueError(
                    "RotatE currently supports normalize_phases=True "
                    "only when a lookup embedder is used for relations; "
                    "current relation embedder is "
                    f"{self.get_option('relation_embedder.type')} "
                    "however"
                )
            # just to be sure it's right initially
            job.pre_run_hooks.append(lambda job: self.normalize_phases())
            # normalize after each batch
            job.post_batch_hooks.append(lambda job: self.normalize_phases())
@torch.jit.script
def pairwise_sum(X, Y):
    """Pairwise sum of the rows of X with the rows of Y.

    Returns a tensor of shape len(X) x len(Y) x dim.
    """
    # insert a broadcast axis so every row of X meets every row of Y
    X_rows = X.unsqueeze(1)  # len(X) x 1 x dim
    return X_rows + Y
@torch.jit.script
def pairwise_diff(X, Y):
    """Pairwise difference of the rows of X and the rows of Y.

    Returns a tensor of shape len(X) x len(Y) x dim.
    """
    # insert a broadcast axis so every row of X meets every row of Y
    X_rows = X.unsqueeze(1)  # len(X) x 1 x dim
    return X_rows - Y
@torch.jit.script
def pairwise_hadamard(X, Y):
    """Pairwise Hadamard product of the rows of X and the rows of Y.

    Returns a tensor of shape len(X) x len(Y) x dim.
    """
    # insert a broadcast axis so every row of X meets every row of Y
    X_rows = X.unsqueeze(1)  # len(X) x 1 x dim
    return X_rows * Y
@torch.jit.script
def hadamard_complex(x_re, x_im, y_re, y_im):
    """Elementwise (Hadamard) product of complex vectors given as re/im parts."""
    # (a+bi)(c+di) = (ac - bd) + (ad + bc)i
    ac = x_re * y_re
    bd = x_im * y_im
    ad = x_re * y_im
    bc = x_im * y_re
    return ac - bd, ad + bc
@torch.jit.script
def pairwise_hadamard_complex(x_re, x_im, y_re, y_im):
    """Pairwise Hadamard product for complex vectors: result[i, j] = x[i] * y[j]."""
    # (a+bi)(c+di) = (ac - bd) + (ad + bc)i, broadcast over all row pairs
    rr = pairwise_hadamard(x_re, y_re)
    ii = pairwise_hadamard(x_im, y_im)
    ri = pairwise_hadamard(x_re, y_im)
    ir = pairwise_hadamard(x_im, y_re)
    return rr - ii, ri + ir
@torch.jit.script
def diff_complex(x_re, x_im, y_re, y_im):
    """Elementwise difference of complex vectors given as re/im parts."""
    # real and imaginary parts subtract independently
    re_part = x_re - y_re
    im_part = x_im - y_im
    return re_part, im_part
@torch.jit.script
def pairwise_diff_complex(x_re, x_im, y_re, y_im):
    """Pairwise difference of complex vectors: result[i, j] = x[i] - y[j]."""
    # real and imaginary parts subtract independently
    re_part = pairwise_diff(x_re, y_re)
    im_part = pairwise_diff(x_im, y_im)
    return re_part, im_part
@torch.jit.script
def abs_complex(x_re, x_im):
    """Magnitude sqrt(re^2 + im^2) of complex numbers given as re/im parts."""
    # stack real/imaginary along a fresh leading axis, then reduce it with
    # the Euclidean norm
    parts = torch.stack([x_re, x_im], dim=0)
    return torch.norm(parts, dim=0)
@torch.jit.script
def norm_nonnegative(x, dim: int, p: float):
    """lp-norm along `dim`, assuming every entry of x is non-negative."""
    if p != 1.0:
        return torch.norm(x, dim=dim, p=p)
    # fast path for the common p=1 case: for non-negative inputs |x| == x,
    # so the 1-norm reduces to a plain sum
    return torch.sum(x, dim=dim)