diff --git a/neuralogic/nn/module/general/attention.py b/neuralogic/nn/module/general/attention.py index 12742b14..5467dbe2 100644 --- a/neuralogic/nn/module/general/attention.py +++ b/neuralogic/nn/module/general/attention.py @@ -73,9 +73,9 @@ def __init__( self.embed_dim = embed_dim self.num_heads = num_heads self.output_name = output_name - self.query_name = query_name - self.key_name = key_name - self.value_name = value_name + self.queries = query_name + self.keys = key_name + self.values = value_name self.vdim = vdim if vdim is not None else embed_dim self.kdim = kdim if kdim is not None else embed_dim self.mask_name = mask_name @@ -118,9 +118,9 @@ def __call__(self): for i in range(self.num_heads): meta = [Transformation.SLICE(rows=(i * size, (i + 1) * size))] - multihead_rules.append((q_proj(i, *terms) <= R.get(self.query_name)(terms)[q_weight:dim, dim]) | meta) - multihead_rules.append((v_proj(i, *terms) <= R.get(self.value_name)(terms)[v_weight:dim, dim]) | meta) - multihead_rules.append((k_proj(i, *terms) <= R.get(self.key_name)(terms)[k_weight:dim, dim]) | meta) + multihead_rules.append((q_proj(i, *terms) <= R.get(self.queries)(terms)[q_weight:dim, dim]) | meta) + multihead_rules.append((v_proj(i, *terms) <= R.get(self.values)(terms)[v_weight:dim, self.vdim]) | meta) + multihead_rules.append((k_proj(i, *terms) <= R.get(self.keys)(terms)[k_weight:dim, self.kdim]) | meta) attention_concat.append(R.get(attention_name)(i, *terms)) multihead_rules.append( @@ -128,11 +128,11 @@ def __call__(self): ) else: multihead_rules = [ - (q_proj(terms)[q_weight:dim, dim] <= R.get(self.query_name)(terms)) | [Transformation.IDENTITY], + (q_proj(terms)[q_weight:dim, dim] <= R.get(self.queries)(terms)) | [Transformation.IDENTITY], q_proj / self.arity | [Transformation.IDENTITY], - (v_proj(terms)[v_weight:dim, self.vdim] <= R.get(self.value_name)(terms)) | [Transformation.IDENTITY], + (v_proj(terms)[v_weight:dim, self.vdim] <= R.get(self.values)(terms)) | [Transformation.IDENTITY], v_proj / self.arity | [Transformation.IDENTITY], - (k_proj(terms)[k_weight:dim, self.kdim] <= R.get(self.key_name)(terms)) | [Transformation.IDENTITY], + (k_proj(terms)[k_weight:dim, self.kdim] <= R.get(self.keys)(terms)) | [Transformation.IDENTITY], k_proj / self.arity | [Transformation.IDENTITY], (output_rel(terms)[dim, dim] <= R.get(attention_name)(terms)) | [Transformation.IDENTITY], output_rel / self.arity | [Transformation.IDENTITY],