Skip to content

Commit

Permalink
Allow model path override for bert (mlcommons#1279)
Browse files Browse the repository at this point in the history
  • Loading branch information
arjunsuresh committed Nov 7, 2022
1 parent 7937922 commit b9c5daa
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 5 deletions.
2 changes: 1 addition & 1 deletion language/bert/onnxruntime_SUT.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self, args):
print("Loading ONNX model...")
self.quantized = args.quantized

model_path = os.environ.get("MODEL_FILE")
model_path = os.environ.get("ML_MODEL_FILE_WITH_PATH")
if not model_path:
if self.quantized:
model_path = "build/data/bert_tf_v1_1_large_fp32_384_v2/bert_large_v1_1_fake_quant.onnx"
Expand Down
3 changes: 2 additions & 1 deletion language/bert/pytorch_SUT.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ def __init__(self, args):
self.model = BertForQuestionAnswering(config)
self.model.to(self.dev)
self.model.eval()
self.model.load_state_dict(torch.load("build/data/bert_tf_v1_1_large_fp32_384_v2/model.pytorch"), strict=True)
model_file = os.environ.get("ML_MODEL_FILE_WITH_PATH", "build/data/bert_tf_v1_1_large_fp32_384_v2/model.pytorch")
self.model.load_state_dict(torch.load(model_file), strict=True)

print("Constructing SUT...")
self.sut = lg.ConstructSUT(self.issue_queries, self.flush_queries)
Expand Down
4 changes: 2 additions & 2 deletions language/bert/tf_SUT.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def __init__(self, args):
if 'TF_INTER_OP_PARALLELISM_THREADS' in os.environ else os.cpu_count()
infer_config.use_per_session_threads = 1
self.sess = tf.compat.v1.Session(config=infer_config)

with gfile.FastGFile('build/data/bert_tf_v1_1_large_fp32_384_v2/model.pb', 'rb') as f:
model_file = os.environ.get('ML_MODEL_FILE_WITH_PATH', 'build/data/bert_tf_v1_1_large_fp32_384_v2/model.pb')
with gfile.FastGFile(model_file, 'rb') as f:
graph_def = tf.compat.v1.GraphDef()
graph_def.ParseFromString(f.read())
self.sess.graph.as_default()
Expand Down
2 changes: 1 addition & 1 deletion language/bert/tf_estimator_SUT.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(self, batch_size=8):

model_fn = self.model_fn_builder(
bert_config=bert_config,
init_checkpoint="build/data/bert_tf_v1_1_large_fp32_384_v2/model.ckpt-5474")
init_checkpoint=os.environ.get("ML_MODEL_FILE_WITH_PATH", "build/data/bert_tf_v1_1_large_fp32_384_v2/model.ckpt-5474"))

self.estimator = tf.estimator.Estimator(model_fn=model_fn)
self.batch_size = batch_size
Expand Down

0 comments on commit b9c5daa

Please sign in to comment.