In [3]:
import sys
import numpy as np

from ais_bench.infer.interface import InferSession


# Minimal runnable example
def infer_simple(device_id=0, model_path='./v14/nnunet.om', input_path='./test_input.npy'):
    """Run a single inference on an input stored in a ``.npy`` file and print the result.

    Args:
        device_id: Device id handed to ``InferSession`` as its first argument.
            Defaults to 0.
        model_path: Path of the model file handed to ``InferSession``.
            Defaults to ``'./v14/nnunet.om'``.
        input_path: Path of the ``.npy`` file read with ``np.load``; the loaded
            array is used as the only model input.
            Defaults to ``'./test_input.npy'``.

    Returns:
        None. The value returned by ``session.infer`` and its type are printed.
    """
    session = InferSession(device_id, model_path)

    # Alternative kept for reference: build a placeholder input sized from the
    # model's input description instead of loading one from disk.
    # barray = bytearray(session.get_inputs()[0].realsize)
    # ndata = np.frombuffer(barray)
    ndata = np.load(input_path)

    # Input is passed as a list of numpy arrays; the outputs are printed as returned.
    outputs = session.infer([ndata])
    print("outputs:{} type:{}".format(outputs, type(outputs)))

    # print("static infer avg:{} ms".format(np.mean(session.sumary().exec_time_list)))

def infer_torch_tensor():
    """Run inference with torch tensors as input and print output information.

    Two runs are made on the same session: one with a freshly created uint8
    tensor and one with a permuted view of another tensor. After each run the
    shape of the first output and the type of the returned container are
    printed; finally the mean of ``session.sumary().exec_time_list`` is printed.

    NOTE(review): ``model_path`` is read from the surrounding module scope and
    is not defined inside this function -- confirm it is set before calling.
    """
    import torch

    session = InferSession(0, model_path)

    # Run 1: tensor passed exactly as allocated.
    plain_input = torch.zeros([1, 3, 256, 256], dtype=torch.uint8)
    results = session.infer([plain_input])
    first_shape = results[0].shape
    print("in torch tensor outputs[0].shape:{} type:{}".format(first_shape, type(results)))

    # Run 2: permuted view of a [1, 256, 3, 256] tensor, used as the
    # "discontinuous" input case.
    base_tensor = torch.zeros([1, 256, 3, 256], dtype=torch.uint8)
    permuted_input = base_tensor.permute(0, 2, 1, 3)
    results = session.infer([permuted_input])
    first_shape = results[0].shape
    print("in discontinuous torch tensor outputs[0].shape:{} type:{}".format(first_shape, type(results)))

    mean_exec_time = np.mean(session.sumary().exec_time_list)
    print("static infer avg:{} ms".format(mean_exec_time))

def infer_dymshape():
    """Run "dymshape" mode inference twice and print the results.

    The same zero-filled float32 input of shape [1, 3, 224, 224] is used for
    both runs; ``custom_sizes`` is given first as an int and then as a
    one-element list. After both runs the mean of
    ``session.sumary().exec_time_list`` is printed.

    NOTE(review): ``model_path`` is read from the surrounding module scope and
    is not defined inside this function -- confirm it is set before calling.
    """
    session = InferSession(0, model_path)
    zero_input = np.zeros([1, 3, 224, 224], dtype=np.float32)
    infer_mode = "dymshape"

    # First pass: custom_sizes as a plain int; second pass: as a list.
    for size_arg in (100000, [100000]):
        results = session.infer([zero_input], infer_mode, custom_sizes=size_arg)
        print("inputs: custom_sizes: {} outputs:{} type:{}".format(size_arg, results, type(results)))

    mean_exec_time = np.mean(session.sumary().exec_time_list)
    print("dymshape infer avg:{} ms".format(mean_exec_time))

def infer_dymdims():
    """Run one "dymdims" mode inference on a zero-filled input and print the result.

    The input is a float32 array of shape [1, 3, 224, 224]. The returned
    outputs and their container type are printed, followed by the mean of
    ``session.sumary().exec_time_list``.

    NOTE(review): ``model_path`` is read from the surrounding module scope and
    is not defined inside this function -- confirm it is set before calling.
    """
    session = InferSession(0, model_path)
    zero_input = np.zeros([1, 3, 224, 224], dtype=np.float32)

    results = session.infer([zero_input], "dymdims")
    print("outputs:{} type:{}".format(results, type(results)))

    mean_exec_time = np.mean(session.sumary().exec_time_list)
    print("dymdims infer avg:{} ms".format(mean_exec_time))

# Retrieve model information
def get_model_info():
    """Print information about the loaded model and its input/output tensors.

    First prints ``session.session``, then one line per entry returned by
    ``session.get_inputs()`` and ``session.get_outputs()`` showing the entry's
    index, shape, datatype, datatype as int, realsize and size.

    NOTE(review): ``model_path`` is read from the surrounding module scope and
    is not defined inside this function -- confirm it is set before calling.
    """
    def _print_descs(template, descs):
        # One formatted line per tensor description, numbered from 0.
        for idx, desc in enumerate(descs):
            print(template.format(
                idx, desc.shape, desc.datatype, int(desc.datatype), desc.realsize, desc.size))

    session = InferSession(0, model_path)

    # Method 2: printing the underlying session object also shows the model info.
    print(session.session)

    # Method 3: query the descriptions directly through the get_* accessors.
    _print_descs(
        "input info i:{} shape:{} type:{} val:{} realsize:{} size:{}",
        session.get_inputs())
    _print_descs(
        "outputs info i:{} shape:{} type:{} val:{} realsize:{} size:{}",
        session.get_outputs())

In [4]:
# Run the minimal example; it prints the raw inference outputs and their type.
infer_simple()

outputs:[array([[[[[ 5.15625  ,  5.9140625,  5.84375  , ...,  6.4804688,
            6.3671875,  6.125    ],
          [ 5.6171875,  6.1523438,  5.9492188, ...,  5.2617188,
            5.3046875,  4.6328125],
          [ 5.4140625,  5.8242188,  5.625    , ...,  6.1875   ,
            5.9765625,  5.4726562],
          ...,
          [ 5.5195312,  6.0742188,  5.8515625, ...,  6.6289062,
            6.4414062,  5.5664062],
          [ 5.6289062,  6.1796875,  5.8203125, ...,  6.21875  ,
            5.9609375,  5.2304688],
          [ 5.28125  ,  5.4257812,  5.1796875, ...,  5.5390625,
            5.359375 ,  4.921875 ]],

         [[ 5.6054688,  6.3125   ,  5.6835938, ...,  6.9453125,
            6.7695312,  6.3359375],
          [ 6.28125  ,  6.2617188,  6.0898438, ...,  5.4648438,
            5.4765625,  4.7148438],
          [ 5.7460938,  5.8789062,  5.7382812, ...,  6.265625 ,
            6.0703125,  5.3554688],
          ...,
          [ 6.0625   ,  6.4023438,  6.0585938, ...,  6.8554