Skip to content

Commit

Permalink
Cast weights from fp16 to fp32 after loadin.g
Browse files Browse the repository at this point in the history
  • Loading branch information
wzzju committed Jan 8, 2021
1 parent 969939e commit bffb5a9
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion PaddleNLP/benchmark/transformer/static/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,19 @@
logger = logging.getLogger(__name__)


def cast_parameters_to_fp32(place, program, scope=None):
all_parameters = []
for block in program.blocks:
all_parameters.extend(block.all_parameters())

var_scope = scope if scope else paddle.static.global_scope()
for param in all_parameters:
tensor = var_scope.find_var(param.name).get_tensor()
if 'FP16' in str(tensor._dtype()) and 'FP32' in str(param.dtype):
data = np.array(tensor)
tensor.set(np.float32(data), place)


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
Expand Down Expand Up @@ -56,7 +69,7 @@ def do_train(args):
place = paddle.set_device("cpu")

# Define data loader
test_loader, to_tokens = reader.create_infer_loader(args)
test_loader, to_tokens = reader.create_infer_loader(args, use_all_vocab=True)

test_program = paddle.static.Program()
startup_program = paddle.static.Program()
Expand Down Expand Up @@ -93,6 +106,9 @@ def do_train(args):
os.path.join(args.init_from_params, "transformer"), exe)
print("finish initing model from params from %s" % (args.init_from_params))

# cast weights from fp16 to fp32 after loading
cast_parameters_to_fp32(place, test_program)

f = open(args.output_file, "w")
for data in test_loader:
finished_sequence, = exe.run(test_program,
Expand Down

0 comments on commit bffb5a9

Please sign in to comment.