
Fix multihead attention embedding dimensions
LukasZahradnik committed Jan 8, 2023
1 parent d6f895f commit 38d8744
Showing 1 changed file with 9 additions and 9 deletions.
neuralogic/nn/module/general/attention.py (9 additions, 9 deletions)
@@ -73,9 +73,9 @@ def __init__(
         self.embed_dim = embed_dim
         self.num_heads = num_heads
         self.output_name = output_name
-        self.query_name = query_name
-        self.key_name = key_name
-        self.value_name = value_name
+        self.queries = query_name
+        self.keys = key_name
+        self.values = value_name
         self.vdim = vdim if vdim is not None else embed_dim
         self.kdim = kdim if kdim is not None else embed_dim
         self.mask_name = mask_name
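
For orientation, the renamed attributes above simply store the names of the query/key/value input relations, and kdim/vdim default to embed_dim when omitted. A minimal usage sketch, assuming the class is named MultiheadAttention and accepts these keyword arguments (the import path follows the changed file; the class name, argument order, and any further parameters such as mask_name are assumptions, not confirmed by the diff):

    from neuralogic.nn.module.general.attention import MultiheadAttention  # class name assumed

    # Keys/values whose feature sizes differ from embed_dim; with this fix,
    # the key/value projection weights become [embed_dim, kdim] and [embed_dim, vdim].
    attention = MultiheadAttention(
        embed_dim=8,
        num_heads=2,
        output_name="attention_out",
        query_name="query",
        key_name="key",
        value_name="value",
        kdim=4,
        vdim=6,
    )
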
@@ -118,21 +118,21 @@ def __call__(self):
 
             for i in range(self.num_heads):
                 meta = [Transformation.SLICE(rows=(i * size, (i + 1) * size))]
-                multihead_rules.append((q_proj(i, *terms) <= R.get(self.query_name)(terms)[q_weight:dim, dim]) | meta)
-                multihead_rules.append((v_proj(i, *terms) <= R.get(self.value_name)(terms)[v_weight:dim, dim]) | meta)
-                multihead_rules.append((k_proj(i, *terms) <= R.get(self.key_name)(terms)[k_weight:dim, dim]) | meta)
+                multihead_rules.append((q_proj(i, *terms) <= R.get(self.queries)(terms)[q_weight:dim, dim]) | meta)
+                multihead_rules.append((v_proj(i, *terms) <= R.get(self.values)(terms)[v_weight:dim, self.vdim]) | meta)
+                multihead_rules.append((k_proj(i, *terms) <= R.get(self.keys)(terms)[k_weight:dim, self.kdim]) | meta)
                 attention_concat.append(R.get(attention_name)(i, *terms))
 
             multihead_rules.append(
                 (output_rel(terms)[dim, dim] <= attention_concat) | [Transformation.IDENTITY, Combination.CONCAT]
             )
         else:
             multihead_rules = [
-                (q_proj(terms)[q_weight:dim, dim] <= R.get(self.query_name)(terms)) | [Transformation.IDENTITY],
+                (q_proj(terms)[q_weight:dim, dim] <= R.get(self.queries)(terms)) | [Transformation.IDENTITY],
                 q_proj / self.arity | [Transformation.IDENTITY],
-                (v_proj(terms)[v_weight:dim, self.vdim] <= R.get(self.value_name)(terms)) | [Transformation.IDENTITY],
+                (v_proj(terms)[v_weight:dim, self.vdim] <= R.get(self.values)(terms)) | [Transformation.IDENTITY],
                 v_proj / self.arity | [Transformation.IDENTITY],
-                (k_proj(terms)[k_weight:dim, self.kdim] <= R.get(self.key_name)(terms)) | [Transformation.IDENTITY],
+                (k_proj(terms)[k_weight:dim, self.kdim] <= R.get(self.keys)(terms)) | [Transformation.IDENTITY],
                 k_proj / self.arity | [Transformation.IDENTITY],
                 (output_rel(terms)[dim, dim] <= R.get(attention_name)(terms)) | [Transformation.IDENTITY],
                 output_rel / self.arity | [Transformation.IDENTITY],
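
The substance of the fix is in the multi-head loop: the value/key projection weights were declared as [v_weight:dim, dim] and [k_weight:dim, dim], i.e. square embed_dim x embed_dim matrices, even when the value/key inputs carry vdim/kdim features; the single-head else branch already used self.vdim and self.kdim. In PyNeuraLogic, [name: r, c] attaches a named learnable r x c weight to a rule body, so the second index must match the input's feature size, while the query projection correctly stays [dim, dim] because queries are already embed_dim-sized. A sketch of the shape argument in plain numpy (illustrative names only, not the library's API):

    import numpy as np

    embed_dim, kdim, vdim = 8, 4, 6
    key_vec = np.ones(kdim)      # key embedding with kdim features
    value_vec = np.ones(vdim)    # value embedding with vdim features

    # Before the fix: both projections assumed embed_dim-sized inputs.
    k_weight_old = np.zeros((embed_dim, embed_dim))
    # k_weight_old @ key_vec would raise: shapes (8, 8) and (4,) are misaligned.

    # After the fix: the input side follows kdim/vdim, the output stays embed_dim.
    k_weight = np.zeros((embed_dim, kdim))
    v_weight = np.zeros((embed_dim, vdim))
    assert (k_weight @ key_vec).shape == (embed_dim,)
    assert (v_weight @ value_vec).shape == (embed_dim,)
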
