Skip to content
This repository has been archived by the owner on Feb 25, 2022. It is now read-only.

Commit

Permalink
fix sampling error when recompute_grad = True
Browse files Browse the repository at this point in the history
  • Loading branch information
sdtblck committed Sep 17, 2020
1 parent ab942b1 commit 8415e56
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion models/gpt2/gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,8 @@ def fn(x):
def model(mtf_features, other_features, params, mesh, variable_dtype, context=None):
"""A GPT style model implemented in mesh tensorflow."""
results = {}
recompute_grad = params["recompute_grad"] == True # if true, enable gradient checkpointing
# if true and in train mode, enable gradient checkpointing
recompute_grad = params["recompute_grad"] and (params["mode"] == "train") == True
use_axial_pos_emb = params["axial_pos_emb"] != None
no_weight_tie_emb = params["no_weight_tie"] == True
share_parameters = exists(params["share_parameters"]) and params["share_parameters"] == True
Expand Down

0 comments on commit 8415e56

Please sign in to comment.