Skip to content

Commit

Permalink
Merge pull request #1475 from axinc-ai/e5_opset17
Browse files Browse the repository at this point in the history
Implement opset 17 version for e5
  • Loading branch information
kyakuno committed May 26, 2024
2 parents 3380562 + 0cf0247 commit 96d5b67
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 7 deletions.
1 change: 1 addition & 0 deletions natural_language_processing/multilingual-e5/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,5 @@ ONNX opset=11
## Netron

[multilingual-e5-base.onnx.prototxt](https://netron.app/?url=https://storage.googleapis.com/ailia-models/multilingual-e5/multilingual-e5-base.onnx.prototxt)
[multilingual-e5-base.opt.onnx.prototxt](https://netron.app/?url=https://storage.googleapis.com/ailia-models/multilingual-e5/multilingual-e5-base.opt.onnx.prototxt)
[multilingual-e5-large.onnx.prototxt](https://netron.app/?url=https://storage.googleapis.com/ailia-models/multilingual-e5/multilingual-e5-large.onnx.prototxt)
31 changes: 24 additions & 7 deletions natural_language_processing/multilingual-e5/multilingual-e5.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,16 @@
action='store_true',
help='execute onnxruntime version.'
)
# Optional switch: select the opset-17 ("opt") export of the base model.
parser.add_argument('--opt', action='store_true',
                    help='use opset17 version.')
args = update_parser(parser)

# With --opt, redirect the base-model weight/prototxt paths to the
# opset-17 files before they are downloaded/loaded.
if args.opt:
    WEIGHT_BASE_PATH = 'multilingual-e5-base.opt.onnx'
    MODEL_BASE_PATH = 'multilingual-e5-base.opt.onnx.prototxt'

# ======================
# Secondary Functions
Expand Down Expand Up @@ -99,8 +107,8 @@ def closest_sentence(embs, q_emb):
# Main functions
# ======================

def predict(models, sentences):
input_texts = ['query: {}'.format(t) for t in sentences]
def predict(models, sentences, header):
input_texts = [header + ': {}'.format(t) for t in sentences]

tokenizer = models['tokenizer']
batch_dict = tokenizer(
Expand Down Expand Up @@ -138,18 +146,21 @@ def recognize_from_sentence(models):
logger.info("Generating embeddings...")
if args.benchmark:
logger.info('BENCHMARK mode')
total = 0
for i in range(5):
start = int(round(time.time() * 1000))
embs = predict(models, sentences)
embs = predict(models, sentences, "passage")
end = int(round(time.time() * 1000))
logger.info(f'\tailia processing time {end - start} ms')
exit()
total = total + end - start
logger.info(f'average time {total / 5} ms\n')
return
else:
embs = predict(models, sentences)
embs = predict(models, sentences, "passage")

# check prompt from command line argument
if prompt is not None:
prompt_emb = predict(models, [prompt])
prompt_emb = predict(models, [prompt], "query")

idx, sim = closest_sentence(embs, prompt_emb)

Expand All @@ -160,7 +171,7 @@ def recognize_from_sentence(models):
# application
prompt = input('User (press q to exit): ')
while prompt not in ('q', 'q'):
prompt_emb = predict(models, [prompt])
prompt_emb = predict(models, [prompt], "query")

idx, sim = closest_sentence(embs, prompt_emb)

Expand Down Expand Up @@ -199,8 +210,14 @@ def main():
"tokenizer": tokenizer,
}

if args.profile:
net.set_profile_mode(True)

recognize_from_sentence(models)

if args.profile:
print(net.get_summary())


# Script entry point: run inference only when executed directly,
# not when imported as a module.
if __name__ == '__main__':
    main()

0 comments on commit 96d5b67

Please sign in to comment.