Skip to content

Commit

Permalink
Squeeze templates
Browse files Browse the repository at this point in the history
  • Loading branch information
gahdritz committed Aug 19, 2022
1 parent 913903e commit 2986436
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion openfold/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def embed_templates(self, batch, z, pair_mask, templ_dim, inplace_safe):
for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map(
lambda t: torch.index_select(t, templ_dim, idx),
lambda t: torch.index_select(t, templ_dim, idx).squeeze(templ_dim),
batch,
)

Expand Down
2 changes: 1 addition & 1 deletion openfold/model/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ def embed_templates_offload(
for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map(
lambda t: torch.index_select(t, templ_dim, idx),
lambda t: torch.index_select(t, templ_dim, idx).squeeze(templ_dim),
batch,
)

Expand Down

0 comments on commit 2986436

Please sign in to comment.