In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import matplotlib.patches as mpatches
import numpy as np
import cv2
from PIL import Image

import onnx
import onnxruntime as rt

from src.training.data import JetbotDataset
from src.training.transforms import HalfCrop

In [None]:
# Single source of truth for the model file, so the checked graph and the
# inference session cannot silently refer to different files.
MODEL_PATH = "sharp_loss_model.onnx"

onnx_model = onnx.load(MODEL_PATH)
onnx.checker.check_model(onnx_model)
sess = rt.InferenceSession(MODEL_PATH, providers=rt.get_available_providers())

# The notebook feeds a single input tensor and reads a single output tensor.
input_meta = sess.get_inputs()[0]
input_name = input_meta.name
input_shape = input_meta.shape
input_type = input_meta.type
print("input name", input_name)
print("input shape", input_shape)
print("input type", input_type)
label_name = sess.get_outputs()[0].name

In [None]:
# Load the Jetbot dataset and stack its labels into one numpy array so the
# label statistics and baseline below can be computed vectorised.
data = JetbotDataset("./data/dataset/")
labs = np.array(data.labels)

In [None]:
# Column 1 of each label is the turn component (the display helpers unpack
# labels as (forward, left)): positive counts as left, negative as right and
# exactly 0 as straight ahead. np.count_nonzero avoids iterating the boolean
# masks element-by-element in Python, as builtin sum() would.
turn = labs[:, 1]
left = np.count_nonzero(turn > 0)
right = np.count_nonzero(turn < 0)
forward = np.count_nonzero(turn == 0)

In [None]:
# Fraction of frames whose turn component is non-zero (left + right turns).
(left+right)/len(labs)

In [None]:
def mae(y,y_pred):
	"""Return the mean absolute error, averaged over every element of ``y_pred - y``."""
	residual = y_pred - y
	return np.mean(np.absolute(residual))

In [None]:
# Baseline predictor: always drive forward (column 0 = 1) with a turn drawn
# uniformly from [-1, 1). Seeded so the baseline MAE is reproducible across
# kernel restarts. (The name `random` is kept because later cells use it.)
rng = np.random.default_rng(0)
random = rng.uniform(-1, 1, size=(len(labs), 2))
random[:, 0] = 1

In [None]:
# Inspect the baseline prediction array.
random

In [None]:
# MAE of the forward-plus-random-turn baseline against the ground-truth labels,
# as a reference point for the model's error.
mae(labs,random)

In [None]:
def display_img(img,label):
	"""Show ``img`` with its ground-truth label in the title.

	Assumes a channel-first image (the leading axis is moved last for imshow).
	``label`` is unpacked as (forward, left).
	"""
	hwc = np.transpose(img, axes=(1, 2, 0))
	plt.imshow(hwc)
	fwd, lft = label
	plt.title(f"Forward {fwd} Left {lft}")
	plt.show()

def half_image(img):
	"""Apply ``HalfCrop(224)`` to ``img`` and return the cropped image as a numpy array."""
	cropper = HalfCrop(224)
	# The second argument is passed as None (presumably the label slot of the transform).
	result = cropper(img, None)
	return result[0].numpy()

def display_img_with_pred(img,label,pred):
	"""Show the ``half_image`` crop of ``img`` titled with ground truth and prediction.

	Both ``label`` and ``pred`` are unpacked as (forward, left).
	"""
	view = np.transpose(half_image(img), axes=(1, 2, 0))
	plt.imshow(view)
	gt_fwd, gt_left = label
	pred_fwd, pred_left = pred
	plt.title(f"Forward {gt_fwd} Left {gt_left}\n Predictions\nForward {pred_fwd} Left {pred_left}")
	plt.show()

In [None]:
def preprocess(img):
	"""Build the float32 input batch that the ONNX session is run on.

	Steps: crop with ``half_image``, add a leading batch axis of size 1, swap the
	two trailing axes, and scale values by 1/255.

	No colour-space conversion is applied: channels are fed in the order
	``half_image`` returns them.

	Args:
		img: one image as returned by ``data[i]``.

	Returns:
		np.ndarray of dtype float32 with a leading batch axis of size 1.
	"""
	cropped = half_image(img)
	batch = cropped[None, :, :, :]
	# NOTE(review): (0, 1, 3, 2) swaps the two trailing (presumably spatial) axes.
	# Kept unchanged because the model's expected input layout is not visible
	# here; confirm against sess.get_inputs()[0].shape.
	batch = np.transpose(batch, axes=(0, 1, 3, 2))
	return batch.astype(np.float32) / 255

In [None]:
# Dataset index of the frame inspected in the next two cells.
SAMPLE_IDX = 5795
img, label, _ = data[SAMPLE_IDX]

display_img(img, label)

In [None]:
# Run the ONNX session on the sample frame and show the prediction next to its label.
preproc = preprocess(img)
feed = {input_name: preproc}
out = sess.run([label_name], feed)

# out[0] is the single requested output; [0] then takes its first row
# (presumably the only item of the size-1 batch built by preprocess).
display_img_with_pred(img, label, out[0][0])

In [None]:
def create_data(start=5000,frames=300):
    """Collect labels, model predictions and display images for consecutive dataset frames.

    Reads the notebook globals ``data``, ``sess``, ``input_name`` and ``label_name``.

    Args:
        start: first dataset index.
        frames: number of consecutive indices to process.

    Returns:
        (labels, preds, images): three equal-length lists, one entry per index in
        ``range(start, start + frames)``. Labels are the dataset labels converted
        with ``.numpy()``, preds are ``sess.run(...)[0][0]`` and images have their
        leading axis moved last for ``imshow``.
    """
    labels, preds, images = [], [], []

    for idx in range(start, start + frames):
        sample_img, sample_label, _ = data[idx]
        labels.append(sample_label.numpy())
        # Display copy only: leading (presumably channel) axis moved last, no colour conversion.
        images.append(np.transpose(sample_img.numpy(), axes=(1, 2, 0)))
        model_input = preprocess(sample_img)
        outputs = sess.run([label_name], {input_name: model_input})
        preds.append(outputs[0][0])

    return labels, preds, images

In [None]:
def create_gif(images,labels,preds,name="animation.gif",start=5000):
    """Animate frames with ground-truth and predicted steering arrows and save them as a GIF.

    Each frame shows the image, a lime line marking the region the model sees,
    and two arrow pairs drawn in axes coordinates: red (x=0.2) for the ground
    truth and blue (x=0.8) for the prediction. The number next to each arrow is
    the signed value of that component.

    Args:
        images: sequence of displayable images (as returned by ``create_data``).
        labels: per-frame ground truth, each unpacked as ``(forward, left)``.
        preds: per-frame model output, each unpacked as ``(forward, left)``.
        name: output file passed to ``Animation.save`` (pillow writer, 12 fps).
        start: dataset index of the first frame; only used in the title.

    The figure is closed once saving finishes (or fails). Calling this in a
    loop therefore does not pile up open figures that the inline backend would
    otherwise render at the end of the cell.
    """
    frames = len(images)
    fig, ax = plt.subplots()
    try:
        # range(start, start + frames) is half-open: the last index is start + frames - 1.
        ax.set_title(f"Images {start}-{start+frames-1}")
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        img_display = ax.imshow(images[0])

        ax.legend(handles=[
            mpatches.Patch(color='red', label='Ground truth'),
            mpatches.Patch(color='blue', label='Prediction'),
            mpatches.Patch(color='lime', label='Model sees below this line')
            ])

        # NOTE(review): hard-coded for 224-pixel-wide frames cropped at row 112;
        # update if the image size or the HalfCrop boundary changes.
        ax.plot([0,223],[112,112],color="lime")
        arrows = []
        numbers = []
        arrow_len = 0.1
        origin_y = 0.2

        def update(frame):
            img_display.set_array(images[frame])
            # Drop the previous frame's arrows and labels before drawing new ones.
            for artist in arrows + numbers:
                artist.remove()
            arrows.clear()
            numbers.clear()

            gt_forward, gt_left = labels[frame]
            pred_forward, pred_left = preds[frame]

            # (left component, forward component, colour, x position in axes coords)
            specs = [
                (0, gt_forward, "red", 0.2),
                (gt_left, 0, "red", 0.2),
                (0, pred_forward, "blue", 0.8),
                (pred_left, 0, "blue", 0.8),
            ]
            for left_c, forward_c, colour, x in specs:
                # dx is negated so a positive "left" value points left on screen.
                arrows.append(ax.arrow(x, origin_y, -arrow_len*left_c, arrow_len*forward_c,
                                       head_width=0.025, head_length=0.05, fc=colour, ec=colour,
                                       transform=ax.transAxes))
                # One of the two components is always 0, so the sum is the drawn value.
                numbers.append(ax.text(x-(arrow_len+0.05)*left_c, origin_y+(arrow_len+0.05)*forward_c,
                                       f"{forward_c+left_c:.4f}", color='black', ha='center', va='center',
                                       transform=ax.transAxes))
            return [img_display] + arrows + numbers

        ani = animation.FuncAnimation(fig, update, frames=frames, blit=True)
        ani.save(name, writer='pillow', fps=12)
    finally:
        plt.close(fig)

In [None]:
def animate(start=5000,frames=300,name="anim.gif"):
    """Predict on ``frames`` dataset frames starting at ``start`` and save the animation to ``name``."""
    gt, predictions, frame_images = create_data(start, frames)
    create_gif(frame_images, gt, predictions, name, start)

In [None]:
# Render frames 5000-5299 (animate's defaults) to anim.gif.
animate()

In [None]:
# (start index, number of frames) for each scene to render.
scenes = [
    [5780,200],
    [2137,200],
    [6150,300],
]

for scene_start, scene_len in scenes:
    gif_name = f"scene_{scene_start}_{scene_len}.gif"
    animate(scene_start, scene_len, gif_name)
    print("Created gif", scene_start, scene_len)