In [None]:
# Widen the notebook container to the full browser width so wide figures and model plots are not cramped.
# `IPython.display` is the public home of these helpers; importing them from
# `IPython.core.display` is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
# Load IPython's autoreload extension so edits to local modules (e.g. `utils`) are picked up without restarting the kernel.
%load_ext autoreload

In [None]:
# Mode 2: reload all imported modules before each cell executes.
%autoreload 2

In [None]:
from IPython.core.debugger import set_trace

In [None]:
from utils import get_X_and_Y, rgb_frame, make_rgb_movie, find_waypoints


# Decimation factor passed to get_X_and_Y / make_rgb_movie; should agree with the
# loaded model (its filename records decimation=2).
DECIMATION = 2
# Batch size for model inference.
BATCH_SIZE = 32

# Semantic-segmentation camera streams. Cameras without 'Top' in their id are the
# model inputs; the 'Top' camera is the bird's-eye-view target.
CAMERA_IDS = ['FrontSS', 'LeftSS', 'RightSS', 'RearSS', 'TopSS']

# Semantic classes grouped into output classes — presumably each inner list becomes
# one class (3 groups, cf. numclasses=3 in the model filename).
CLASSES_NAMES = [
    # road surface
    ['Roads', 'RoadLines'],
    # everything that is neither road nor vehicle
    ['None', 'Buildings', 'Fences', 'Other', 'Pedestrians', 'Poles',
     'Walls', 'TrafficSigns', 'Vegetation', 'Sidewalks'],
    # other cars
    ['Vehicles'],
]

# Alternative finer-grained grouping (7 groups), kept for reference; presumably it
# needs a model trained with a matching number of classes.
# CLASSES_NAMES = [
#     ['Roads', 'RoadLines'],
#     ['Sidewalks'],
#     ['Buildings'],
#     ['Fences', 'Other', 'Pedestrians', 'Poles', 'Walls', 'TrafficSigns'],
#     ['Vehicles'],
#     ['Vegetation'],
#     ['None'],
# ]

In [None]:
import keras.backend as K
from keras.models import Model, load_model
from keras.layers import Input


# Trained multi-camera model; the filename records its training configuration.
multi_model = load_model('models/multi_model__sweep=23_decimation=2_numclasses=3_valloss=0.044.h5')

# Shape of one camera input with the batch axis dropped.
one_inp_shape = K.int_shape(multi_model.input[0])[1:]

# Model inputs are assumed to follow the order of CAMERA_IDS; keep every camera
# except the 'Top' view, which is what the reconstruction head predicts.
ground_camera_indices = [idx for idx, cam in enumerate(CAMERA_IDS) if 'Top' not in cam]
inp = [multi_model.inputs[idx] for idx in ground_camera_indices]

# Bird's-eye-view reconstruction output of the multi-camera model.
out = multi_model.get_layer('reconstruction').output

# Sub-model mapping the ground-level cameras straight to the reconstruction.
birds_view_model = Model(inp, out)

In [None]:
# Print the layer-by-layer summary of the extracted bird's-eye-view sub-model.
birds_view_model.summary()

In [None]:
import os

from IPython.display import SVG
from keras.utils.vis_utils import model_to_dot, plot_model


# plot_model writes the PNG via pydot, which does not create missing directories,
# so make sure the output folder exists on a fresh checkout.
os.makedirs('images', exist_ok=True)
# Save a PNG of the architecture (pass show_shapes=True to annotate tensor shapes).
plot_model(birds_view_model, to_file='images/birds_view_model.png')
# Render the same graph inline as SVG; this is the cell's displayed output.
SVG(model_to_dot(birds_view_model).create(prog='dot', format='svg'))

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from utils import class_names_to_class_numbers, extract_observation_for_batch


racetrack = 'Town01'
episode = 102
decimation = DECIMATION
camera_ids = CAMERA_IDS
classes_names = CLASSES_NAMES
episode_len = 1000
batch_size = BATCH_SIZE
frame_step = 100  # display every `frame_step`-th frame of the episode


# Load one episode. Cameras without 'Top' in their id are the model inputs;
# the 'Top' camera is the bird's-eye-view target.
storage = get_X_and_Y([racetrack], [episode], decimation, camera_ids)
X = [storage[id_] for id_ in camera_ids if 'Top' not in id_]
Y = [storage[id_] for id_ in camera_ids if 'Top' in id_][0]

classes_numbers = class_names_to_class_numbers(classes_names)

# One list of observations per input camera (sized from X rather than a hard-coded 4).
X_final, y_final = [[] for _ in X], []
for index in range(episode_len):
    X_out, y_out = extract_observation_for_batch(X, Y, index, False, classes_numbers)
    for j, camera_obs in enumerate(X_final):
        camera_obs.append(X_out[j])
    y_final.append(y_out)

X_final = [np.stack(x) for x in X_final]
y_final = np.stack(y_final)

preds = birds_view_model.predict(X_final, batch_size=batch_size)

# BTW, let's evaluate the model
birds_view_model.compile(optimizer='adam', loss='categorical_crossentropy')
categorical_crossentropy = birds_view_model.evaluate(X_final, y_final, batch_size=batch_size)
print('Categorical cross entropy on this racetrack and episode: {:.4f}'.format(categorical_crossentropy))


# Show a sample of frames. Create a fresh figure for each frame: with the inline
# backend plt.show() closes the figures it displays, so reusing a single figure
# across iterations would only ever display the first frame.
for i in range(0, episode_len, frame_step):
    frame = rgb_frame(preds, y_final, i, draw_waypoints=True)

    fig, ax = plt.subplots()
    fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0)
    ax.margins(0, 0)
    ax.grid(False)

    ax.imshow(frame, aspect='auto')
    ax.text(40, 177, 'actual', color='white', fontsize=24, fontweight='bold')
    ax.text(145, 177, 'predicted', color='white', fontsize=24, fontweight='bold')
    ax.axis('off')
    plt.show()
    plt.close(fig)  # release the figure on non-inline backends too

In [None]:
import time


def _single_frame(frame_idx):
    """Return the model inputs for one frame as a batch of size one (leading batch axis added)."""
    return [np.expand_dims(x[frame_idx], 0) for x in X_final]


# Warm-up call: the first predict() typically includes one-off setup work, which
# would otherwise inflate the mean per-frame latency.
if len(X_final[0]):
    birds_view_model.predict(_single_frame(0))

# Time single-frame inference, one frame at a time. Use a dedicated name for the
# per-frame inputs so the episode data `X` from the evaluation cell is not overwritten.
times = []
for i in range(len(X_final[0])):
    single_input = _single_frame(i)
    start = time.perf_counter()  # monotonic, high-resolution clock for benchmarking
    birds_view_model.predict(single_input)
    times.append(time.perf_counter() - start)


print('Mean prediction time: {:.4f}s'.format(np.mean(times)))

In [None]:
# Render a prediction movie for every (racetrack, episode) combination.
# Distinct loop names keep the `racetrack` / `episode` globals from the evaluation
# cell intact, so re-running later cells does not silently pick up Town02/103.
MOVIE_RACETRACKS = ['Town01', 'Town02']
MOVIE_EPISODES = [102, 103]

for movie_racetrack in MOVIE_RACETRACKS:
    for movie_episode in MOVIE_EPISODES:
        make_rgb_movie(birds_view_model, movie_racetrack, movie_episode, DECIMATION, CLASSES_NAMES, CAMERA_IDS)