Merge pull request #331 from dingquanyu/multimer
Batch size fix for tm computation
christinaflo committed Jul 19, 2023
2 parents 51556d5 + 8569353 commit 8332aa0
Showing 3 changed files with 15 additions and 7 deletions.
15 changes: 11 additions & 4 deletions openfold/utils/loss.py
@@ -691,9 +691,13 @@ def compute_tm(
 
     n = residue_weights.shape[-1]
     pair_mask = residue_weights.new_ones((n, n), dtype=torch.int32)
-    if interface:
+    if interface and (asym_id is not None):
+        if len(asym_id.shape)>1:
+            assert len(asym_id.shape)<=2
+            batch_size = asym_id.shape[0]
+            pair_mask = residue_weights.new_ones((batch_size,n, n), dtype=torch.int32)
         pair_mask *= (asym_id[..., None] != asym_id[..., None, :]).to(dtype=pair_mask.dtype)
 
     predicted_tm_term *= pair_mask
 
     pair_residue_weights = pair_mask * (
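
For orientation, a minimal standalone sketch (not part of the commit) of the batched interface mask this hunk introduces; the asym_id values and shapes below are invented for illustration:

    import torch

    # Hypothetical input: a batch of 2 complexes, 4 residues each.
    # asym_id labels which chain each residue belongs to.
    asym_id = torch.tensor([[0, 0, 1, 1],
                            [0, 1, 1, 1]])
    batch_size, n = asym_id.shape

    # Batched pair mask with the extra leading dimension added by the commit.
    pair_mask = torch.ones((batch_size, n, n), dtype=torch.int32)

    # 1 where the two residues belong to different chains, 0 otherwise.
    pair_mask *= (asym_id[..., None] != asym_id[..., None, :]).to(dtype=pair_mask.dtype)

    print(pair_mask.shape)  # torch.Size([2, 4, 4])

With a batched asym_id, the chain comparison yields a (batch, n, n) tensor, which cannot be multiplied in place into an (n, n) mask; allocating the mask with the batch dimension up front avoids that shape error.
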
@@ -1440,7 +1444,10 @@ def violation_loss(
         + l_clash
     )
 
-    return loss
+    # Average over the batch dimension
+    mean = torch.mean(loss)
+
+    return mean
 
 
 def compute_renamed_ground_truth(
@@ -1563,7 +1570,7 @@ def experimentally_resolved_loss(
 ) -> torch.Tensor:
     errors = sigmoid_cross_entropy(logits, all_atom_mask)
     loss = torch.sum(errors * atom37_atom_exists, dim=-1)
-    loss = loss / (eps + torch.sum(atom37_atom_exists, dim=(-1, -2)))
+    loss = loss / (eps + torch.sum(atom37_atom_exists, dim=(-1, -2)).unsqueeze(-1))
     loss = torch.sum(loss, dim=-1)
 
     loss = loss * (
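
As a rough illustration of why the unsqueeze(-1) is needed (the shapes below are assumptions for the sketch, not taken from the repository):

    import torch

    # Hypothetical shapes: batch of 2 structures, 5 residues, 37 atom slots.
    eps = 1e-8
    errors = torch.rand(2, 5, 37)
    atom37_atom_exists = torch.ones(2, 5, 37)

    per_residue = torch.sum(errors * atom37_atom_exists, dim=-1)   # shape [2, 5]
    denom = torch.sum(atom37_atom_exists, dim=(-1, -2))            # shape [2]

    # [2, 5] / [2] does not broadcast; [2, 5] / [2, 1] does.
    loss = per_residue / (eps + denom.unsqueeze(-1))               # shape [2, 5]
    loss = torch.sum(loss, dim=-1)                                 # shape [2]
    print(loss.shape)
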
2 changes: 1 addition & 1 deletion openfold/utils/validation_metrics.py
@@ -29,7 +29,7 @@ def prep_d(structure):
     if(mask is not None):
         drmsd = drmsd * (mask[..., None] * mask[..., None, :])
     drmsd = torch.sum(drmsd, dim=(-1, -2))
-    n = d1.shape[-1] if mask is None else torch.sum(mask, dim=-1)
+    n = d1.shape[-1] if mask is None else torch.min(torch.sum(mask, dim=-1))
     drmsd = drmsd * (1 / (n * (n - 1))) if n > 1 else (drmsd * 0.)
     drmsd = torch.sqrt(drmsd)
 
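
A small sketch of the changed expression (mask values are made up) showing why the added torch.min matters once the mask carries a batch dimension:

    import torch

    # Hypothetical mask for a batch of 2 structures with 4 residue positions.
    mask = torch.tensor([[1., 1., 1., 0.],
                         [1., 1., 1., 1.]])

    per_sample = torch.sum(mask, dim=-1)   # tensor([3., 4.]) -- one count per sample
    n = torch.min(per_sample)              # tensor(3.) -- a single scalar

    # A Python comparison like `if n > 1` is only well defined for a 0-d tensor;
    # on the per-sample vector it would raise an ambiguous-truth-value error.
    print(bool(n > 1))  # True

Taking the minimum count is a conservative choice: one scalar normalization is shared by every sample in the batch.
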
5 changes: 3 additions & 2 deletions train_openfold.py
@@ -88,9 +88,10 @@ def _log(self, loss_breakdown, batch, outputs, train=True):
            )
 
        for k,v in other_metrics.items():
+            assert(len(v.shape) == 1)
            self.log(
-                f"{phase}/{k}",
-                v,
+                f"{phase}/{k}",
+                torch.mean(v),
                on_step=False, on_epoch=True, logger=True
            )
 
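
Lastly, a tiny sketch of the logging change (the metric values are invented); the commit reduces each per-example validation metric to its batch mean before handing it to self.log:

    import torch

    # Hypothetical per-example validation metric for a batch of 3 structures.
    v = torch.tensor([0.82, 0.77, 0.91])
    assert len(v.shape) == 1   # same shape check the commit adds

    # Log the batch mean rather than the raw per-example vector.
    print(torch.mean(v))  # tensor(0.8333)
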
