From 58305aa819c8ffc02ca04518ee9437aae086329a Mon Sep 17 00:00:00 2001 From: yueshuangyan Date: Wed, 7 Jul 2021 20:13:52 +0800 Subject: [PATCH] fix(converter): fix name attr visit of fetch_targets --- packages/paddlejs-converter/convertModel.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/packages/paddlejs-converter/convertModel.py b/packages/paddlejs-converter/convertModel.py index 5ddd1ff0..207b51ad 100644 --- a/packages/paddlejs-converter/convertModel.py +++ b/packages/paddlejs-converter/convertModel.py @@ -42,6 +42,18 @@ # 在转换过程中新生成的、需要添加到vars中的variable appendedVarList = [] +class ObjDict(dict): + """ + Makes a dictionary behave like an object,with attribute-style access. + """ + def __getattr__(self,name): + try: + return self[name] + except: + raise AttributeError(name) + def __setattr__(self,name,value): + self[name]=value + def validateShape(shape, name): """检验shape长度,超过4则截断""" if len(shape) > 4: @@ -323,7 +335,7 @@ def appendConnectOp(fetch_targets): # 从fetch_targets中提取输出算子信息 for target in fetch_targets: - name = target['name'] + name = target.name curVar = fluid.global_scope().find_var(name) curTensor = np.array(curVar.get_tensor()) shape = list(curTensor.shape) @@ -409,7 +421,8 @@ def convertToPaddleJSModel(): for input, value in op['inputs'].items(): if len(value) <= 0: continue - cur = {'name': value[0]} + cur = ObjDict() + cur.name = value[0] inputNames.append(cur) targets = appendConnectOp(inputNames) # op['inputs'] = targets