Skip to content

Commit

Permalink
Provide possibility to add provider
Browse files Browse the repository at this point in the history
  • Loading branch information
Tanja Bayer committed Aug 24, 2021
1 parent cdc3d4e commit 6538da9
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 13 deletions.
18 changes: 9 additions & 9 deletions python-package/insightface/app/face_analysis.py
Expand Up @@ -6,23 +6,22 @@


from __future__ import division
import collections
import numpy as np

import glob
import os
import os.path as osp
from numpy.linalg import norm

import numpy as np
import onnxruntime
from numpy.linalg import norm

from ..model_zoo import model_zoo
from ..utils import face_align
from ..utils import ensure_available
from ..utils import DEFAULT_MP_NAME, ensure_available
from .common import Face
from ..utils import DEFAULT_MP_NAME

__all__ = ['FaceAnalysis']

class FaceAnalysis:
def __init__(self, name=DEFAULT_MP_NAME, root='~/.insightface', allowed_modules=None):
def __init__(self, name=DEFAULT_MP_NAME, root='~/.insightface', allowed_modules=None, **kwargs):
onnxruntime.set_default_logger_severity(3)
self.models = {}
self.model_dir = ensure_available('models', name, root=root)
Expand All @@ -32,7 +31,8 @@ def __init__(self, name=DEFAULT_MP_NAME, root='~/.insightface', allowed_modules=
if onnx_file.find('_selfgen_')>0:
#print('ignore:', onnx_file)
continue
model = model_zoo.get_model(onnx_file)
model = model_zoo.get_model(onnx_file, **kwargs)

if model is None:
print('model not recognized:', onnx_file)
elif allowed_modules is not None and model.taskname not in allowed_modules:
Expand Down
9 changes: 5 additions & 4 deletions python-package/insightface/model_zoo/model_zoo.py
Expand Up @@ -22,12 +22,13 @@ class ModelRouter:
def __init__(self, onnx_file):
self.onnx_file = onnx_file

def get_model(self):
session = onnxruntime.InferenceSession(self.onnx_file, None)
def get_model(self, **kwargs):
session = onnxruntime.InferenceSession(self.onnx_file, **kwargs)
print(f'Applied providers: {session._providers}, with options: {session._provider_options}')
input_cfg = session.get_inputs()[0]
input_shape = input_cfg.shape
outputs = session.get_outputs()
#print(input_shape)

if len(outputs)>=5:
return SCRFD(model_file=self.onnx_file, session=session)
elif input_shape[2]==112 and input_shape[3]==112:
Expand Down Expand Up @@ -66,6 +67,6 @@ def get_model(name, **kwargs):
assert osp.exists(model_file), 'model_file should exist'
assert osp.isfile(model_file), 'model_file should be file'
router = ModelRouter(model_file)
model = router.get_model()
model = router.get_model(providers=kwargs.get('providers'), provider_options=kwargs.get('provider_options'))
return model

0 comments on commit 6538da9

Please sign in to comment.