In [1]:
import numpy as np
import torch
import torchvision
import cv2
import matplotlib.pyplot as plt
import  torchvision.transforms as transforms
import timm
import json

In [2]:
with open('classes.json', 'r') as f:
    classes = json.load(f)

In [3]:
model = timm.create_model(model_name='efficientnet_b4', pretrained=True)
model.eval()

EfficientNet(
  (conv_stem): Conv2d(3, 48, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): SiLU(inplace=True)
  (blocks): Sequential(
    (0): Sequential(
      (0): DepthwiseSeparableConv(
        (conv_dw): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=48, bias=False)
        (bn1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): SiLU(inplace=True)
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(48, 12, kernel_size=(1, 1), stride=(1, 1))
          (act1): SiLU(inplace=True)
          (conv_expand): Conv2d(12, 48, kernel_size=(1, 1), stride=(1, 1))
          (gate): Sigmoid()
        )
        (conv_pw): Conv2d(48, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn2): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): Identity()
 

In [4]:
# model = torchvision.models.mobilenet_v3_small(pretrained=False, progress=False, )
# model.state_dict = torch.load('../model/mobilenet_v3_small.pth')
# model.eval()

In [5]:
batch_size = 1
channels = 3
height = 1024
width = 1024
# image = cv2.imread('../../../images/swin.jpeg')
image = cv2.imread('../../../images/n01491361_tiger_shark.jpeg')
resized_image = cv2.resize(image, (height, width))


In [6]:
transform = transforms.ToTensor()
prediction = torch.argmax(model.float().forward(torch.unsqueeze(transform(resized_image), 0)).data).numpy()
classes[prediction]

'tiger shark'

In [7]:
transform(resized_image)

tensor([[[0.5176, 0.5176, 0.5176,  ..., 0.4510, 0.4549, 0.4588],
         [0.5176, 0.5176, 0.5176,  ..., 0.4549, 0.4549, 0.4588],
         [0.5176, 0.5176, 0.5176,  ..., 0.4510, 0.4510, 0.4549],
         ...,
         [0.3961, 0.3961, 0.3961,  ..., 0.3294, 0.3333, 0.3373],
         [0.3961, 0.3961, 0.3961,  ..., 0.3294, 0.3373, 0.3412],
         [0.3961, 0.3961, 0.3961,  ..., 0.3294, 0.3373, 0.3412]],

        [[0.4118, 0.4118, 0.4118,  ..., 0.3294, 0.3333, 0.3373],
         [0.4118, 0.4118, 0.4118,  ..., 0.3333, 0.3333, 0.3373],
         [0.4118, 0.4118, 0.4118,  ..., 0.3294, 0.3294, 0.3333],
         ...,
         [0.2667, 0.2667, 0.2706,  ..., 0.2706, 0.2784, 0.2824],
         [0.2667, 0.2667, 0.2706,  ..., 0.2706, 0.2824, 0.2863],
         [0.2667, 0.2667, 0.2706,  ..., 0.2706, 0.2824, 0.2863]],

        [[0.3294, 0.3294, 0.3294,  ..., 0.2000, 0.2039, 0.2078],
         [0.3294, 0.3294, 0.3294,  ..., 0.2000, 0.2039, 0.2078],
         [0.3294, 0.3294, 0.3294,  ..., 0.2000, 0.2000, 0.

In [8]:
transform = transforms.ToTensor()
prediction = torch.argmax(model.double().forward(torch.unsqueeze(transform(resized_image / 255.), 0)).data).numpy()
classes[prediction]

'tiger shark'

In [10]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])])
prediction = torch.argmax(model.double().forward(torch.unsqueeze(transform(resized_image / 255.), 0)).data).numpy()
classes[prediction]

'tiger shark'

In [12]:
transform(resized_image)

tensor([[[ 0.1426,  0.1426,  0.1426,  ..., -0.1486, -0.1314, -0.1143],
         [ 0.1426,  0.1426,  0.1426,  ..., -0.1314, -0.1314, -0.1143],
         [ 0.1426,  0.1426,  0.1426,  ..., -0.1486, -0.1486, -0.1314],
         ...,
         [-0.3883, -0.3883, -0.3883,  ..., -0.6794, -0.6623, -0.6452],
         [-0.3883, -0.3883, -0.3883,  ..., -0.6794, -0.6452, -0.6281],
         [-0.3883, -0.3883, -0.3883,  ..., -0.6794, -0.6452, -0.6281]],

        [[-0.1975, -0.1975, -0.1975,  ..., -0.5651, -0.5476, -0.5301],
         [-0.1975, -0.1975, -0.1975,  ..., -0.5476, -0.5476, -0.5301],
         [-0.1975, -0.1975, -0.1975,  ..., -0.5651, -0.5651, -0.5476],
         ...,
         [-0.8452, -0.8452, -0.8277,  ..., -0.8277, -0.7927, -0.7752],
         [-0.8452, -0.8452, -0.8277,  ..., -0.8277, -0.7752, -0.7577],
         [-0.8452, -0.8452, -0.8277,  ..., -0.8277, -0.7752, -0.7577]],

        [[-0.3404, -0.3404, -0.3404,  ..., -0.9156, -0.8981, -0.8807],
         [-0.3404, -0.3404, -0.3404,  ..., -0

In [11]:
dummy_input = torch.randn(1, 3, 1024, 1024)
torch.onnx.export(model, dummy_input, 'efficientnet_b4.onnx', export_params=True)

In [12]:
import onnx
onnx_model = onnx.load('efficientnet_b4.onnx')
onnx.checker.check_model(onnx_model)

In [None]:
import onnxruntime

ort_session = onnxruntime.InferenceSession("efficientnet_b4.onnx")
