Skip to content

Commit

Permalink
Merge pull request #295 from nikitos9000/improve_fp16_stability
Browse files Browse the repository at this point in the history
Improve TriangularMultiplicativeUpdate stability in fp16 mode
  • Loading branch information
gahdritz committed Apr 10, 2023
2 parents ee5d2c3 + 6625e8d commit 208cce6
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions openfold/model/triangular_multiplicative_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,8 +392,13 @@ def forward(self,
b = mask
b = b * self.sigmoid(self.linear_b_g(z))
b = b * self.linear_b_p(z)

if(is_fp16_enabled()):

# Prevents overflow of torch.matmul in combine projections in
# reduced-precision modes
a = a / a.std()
b = b / b.std()

if(is_fp16_enabled()):
with torch.cuda.amp.autocast(enabled=False):
x = self._combine_projections(a.float(), b.float())
else:
Expand Down

0 comments on commit 208cce6

Please sign in to comment.