Skip to content

Commit

Permalink
Fix template masking bug
Browse files Browse the repository at this point in the history
  • Loading branch information
gahdritz committed Aug 17, 2022
1 parent 4d513bb commit e56b597
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
11 changes: 8 additions & 3 deletions openfold/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,16 @@ def embed_templates(self, batch, z, pair_mask, templ_dim, inplace_safe):
use_lma=self.globals.use_lma,
)

t_mask = torch.sum(batch["template_mask"], dim=-1) > 0
# Append singletons
t_mask = t_mask.reshape(
*t_mask.shape, *([1] * (len(t.shape) - len(t_mask.shape)))
)

if(inplace_safe):
t *= (torch.sum(batch["template_mask"], dim=-1) > 0)
t *= t_mask
else:
t = t * (torch.sum(batch["template_mask"], dim=-1) > 0)
t = t * t_mask

ret = {}

Expand Down Expand Up @@ -380,7 +386,6 @@ def iteration(self, feats, prevs, _recycle=True):
pair_mask=pair_mask.to(dtype=input_tensors[1].dtype),
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
use_flash=self.globals.use_flash,
_mask_trans=self.config._mask_trans,
)

Expand Down
1 change: 0 additions & 1 deletion openfold/utils/trace_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,6 @@ def verify_arg_order(fn, arg_list):
# Trim unspecified arguments
fn_arg_names = fn_arg_names[:len(arg_list)]
name_tups = list(zip(fn_arg_names, [n for n, _ in arg_list]))
print(name_tups)
assert(all([n1 == n2 for n1, n2 in name_tups]))

evoformer_attn_chunk_size = max(
Expand Down

0 comments on commit e56b597

Please sign in to comment.