In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import cv2

import onnx
import onnxruntime as rt

from src.training.data import JetbotDataset
from src.training.transforms import HalfCrop

In [None]:
# Single source of truth for the exported model file (was repeated as a literal).
MODEL_PATH = "model.onnx"

# Validate the exported graph before building an inference session for it.
onnx_model = onnx.load(MODEL_PATH)
onnx.checker.check_model(onnx_model)
sess = rt.InferenceSession(MODEL_PATH, providers=rt.get_available_providers())

# Later cells feed the model through its first input only; fetch that metadata once.
model_input = sess.get_inputs()[0]
input_name = model_input.name
print("input name", input_name)
input_shape = model_input.shape
print("input shape", input_shape)
input_type = model_input.type
print("input type", input_type)
# NOTE: despite its name this is the name of the model's first *output* (the
# prediction tensor); the name is kept because the inference cell uses it.
label_name = sess.get_outputs()[0].name

In [None]:
DATASET_DIR = "./data/dataset/"
data = JetbotDataset(DATASET_DIR)
# Stack all labels into one array; later cells index it as labs[:, 1] and compare
# it against an array of shape (len(labs), 2).
labs = np.array(data.labels)

In [None]:
# Column 1 of each label is the "left" (steering) component -- display_img unpacks
# a label as `forward, left`. Positive counts as a left turn, negative as a right
# turn, exactly zero as straight ahead.
steering = labs[:, 1]
# np.count_nonzero is vectorized; builtin sum() iterated the array in Python.
left = np.count_nonzero(steering > 0)
right = np.count_nonzero(steering < 0)
forward = np.count_nonzero(steering == 0)

In [None]:
# Share of samples whose steering value is non-zero (left or right turn).
turn_fraction = (left + right) / len(labs)
turn_fraction

In [None]:
def mae(y,y_pred):
	"""Mean absolute error between ``y`` and ``y_pred``, averaged over every element."""
	residuals = y_pred - y
	return np.mean(np.abs(residuals))

In [None]:
# Baseline predictor: column 0 (forward) is always 1, column 1 (steering) is drawn
# uniformly from [-1, 1). Seeded so the baseline MAE computed below is reproducible.
# NOTE: `random` shadows the stdlib module name; kept because later cells use it.
BASELINE_SEED = 0
rng = np.random.default_rng(BASELINE_SEED)
random = np.column_stack([
	np.ones(len(labs)),
	rng.uniform(-1, 1, size=len(labs)),
])

In [None]:
# Peek at the first few baseline rows rather than the whole (len(labs), 2) array.
random[:5]

In [None]:
# MAE of the random-steering baseline against the ground-truth labels.
mae(y=labs, y_pred=random)

In [None]:
def display_img(img,label):
	"""Plot a channel-first image with its ``(forward, left)`` label as the title."""
	# imshow wants channels last, so move axis 0 to the end.
	plt.imshow(np.transpose(img, axes=(1, 2, 0)))
	fwd, lft = label
	plt.title(f"Forward {fwd} Left {lft}")
	plt.show()

def half_image(img):
	"""Apply ``HalfCrop(224)`` to ``img`` and return the cropped image as a NumPy array."""
	crop = HalfCrop(224)
	# The transform is called with no target; element 0 of its result is the image.
	cropped = crop(img, None)[0]
	return cropped.numpy()

def display_img_with_pred(img,label,pred):
	"""Plot the half-cropped image titled with its ground truth and the model prediction.

	Args:
		img: channel-first image as returned by the dataset. It is cropped with
			``half_image``, the same crop ``preprocess`` feeds to the model.
		label: ground-truth ``(forward, left)`` pair.
		pred: predicted ``(forward, left)`` pair, e.g. one row of the ONNX output.
	"""
	img = half_image(img)
	# imshow wants channels last.
	plt.imshow(np.transpose(img, axes=(1, 2, 0)))
	forward, left = label
	pf, pl = pred
	# Round predictions so raw float32 values don't flood the title.
	plt.title(f"Forward {forward} Left {left}\n Predictions\nForward {pf:.3f} Left {pl:.3f}")
	plt.show()

In [None]:
def preprocess(img):
	"""Turn a dataset image into a batched float32 array for the ONNX session.

	Args:
		img: channel-first image as returned by the dataset (display_img's
			(1, 2, 0) transpose implies a (C, H, W) layout).

	Returns:
		float32 ndarray: the ``half_image`` crop scaled by 1/255, with a leading
		batch axis of size 1 and its last two axes swapped.
	"""
	# Crop, then scale into [0, 1]; assumes 8-bit pixel values -- TODO confirm.
	preproc = half_image(img).astype(np.float32) / 255
	# NOTE(review): no colour-channel reordering is applied here. An earlier
	# BGR->RGB conversion was computed but never used, so it has been dropped.
	# Confirm which channel order the model was trained on.
	# Add a leading batch dimension: (C, A, B) -> (1, C, A, B).
	preproc = preproc[None, :, :, :]
	# NOTE(review): swaps the two spatial axes; presumably matches the exported
	# model's input layout -- check against sess.get_inputs()[0].shape.
	return np.transpose(preproc, axes=(0, 1, 3, 2))

In [None]:
# Example frame to inspect; change the index to look at another sample.
sample_idx = 5720
img, label, _ = data[sample_idx]
display_img(img, label)

In [None]:
preproc = preprocess(img)
# Request only the first output tensor, feeding the preprocessed batch to the first input.
out = sess.run(output_names=[label_name], input_feed={input_name: preproc})
# out[0] is the requested output; [0] selects the single item of the size-1 batch.
first_pred = out[0][0]
display_img_with_pred(img, label, first_pred)