In [1]:
%matplotlib inline

import sys
sys.path.append('..')

import math
import matplotlib.pyplot as plt
import numpy as np
import os
import pickle
from time import time
import urllib

import navbench as nb
from navbench import imgproc as ip

# Figure defaults shared by every plot in this notebook
plt.rcParams.update({
    'figure.figsize': [15, 8],
    'font.size': 20,
})

# Directory that every database path below is relative to
DBROOT = '../datasets/rc_car/Stanmer_park_dataset'

# Test window around each goal: e.g. 50 means the 50 frames on either side of the goal are tested
FRAME_DIST = 50

# Frame subsampling: 1 keeps every frame, 2 keeps every other frame, and so on
FR_STEP = 1

# Rotational step used for RIDF-type calculations
RIDF_STEP = 2

# Size of the median filter applied to computed values
MEDFILT_SIZE = 5

# Image preprocessing pipeline: when PREPROC is a tuple of functions, each one is applied to the image in turn
IM_SIZE = (25, 90)
RESIZE = ip.resize(*IM_SIZE)
PREPROC = (RESIZE, ip.remove_sky)

# Every SNAP_STEP-th image is used as a training snapshot
SNAP_STEP = 10

def load_db(path):
    """Open the navbench database at ``DBROOT/path``, print its image count and return it."""
    full_path = os.path.join(DBROOT, path)
    database = nb.Database(full_path)
    num_images = len(database)
    print('Database %s has %d images' % (path, num_images))
    return database

# def get_ca_sizes(db, images, snapshots):
#     ca_sizes = []
    
#     # Cache result because this step takes a very long time
#     pkl_path = 'ca_sizes_%s.pkl' % urllib.parse.quote_plus(db.path)
#     if os.path.exists(pkl_path):
#         ca_sizes, num_inf = pickle.load(open(pkl_path, 'rb'))
#     else:
#         num_inf = 0
#         for goal_idx, snap in enumerate(snapshots):
#             diffs = nb.route_ridf(images, snap, step=2)
#             ca = nb.calculate_ca(diffs, goal_idx, MEDFILT_SIZE)
#             size = db.calculate_distance(*ca.get_finite_bounds())
#             num_inf += math.isinf(ca.size())
#             ca_sizes.append(size)
#         pickle.dump((ca_sizes, num_inf), open(pkl_path, 'wb'))

#     if num_inf:
#         print("WARNING: %i/%i CAs extended beyond route end" % (num_inf, len(ca_sizes)))

#     print('CA range: [%f, %f]' % (np.amin(ca_sizes), np.amax(ca_sizes)))
#     return ca_sizes

names = ['0511/unwrapped_dataset%d' % num for num in range(1, 6)]
descriptions = ['0 m away', '0 m away (repeat)', '1 m away (left)',
                '1 m away (right)', '2 m away (right)']

# Load each database and its preprocessed images; dbs[i] and images[i] stay index-aligned.
# Loading is interleaved per database so progress messages appear in the same order as before.
dbs, images = [], []
for name in names:
    route_db = load_db(name)
    dbs.append(route_db)
    images.append(route_db.read_images(preprocess=PREPROC))

from navbench.infomax import InfoMax

@nb.cache_result
def get_ann(images, im_size, lrate, seed):
    """Train an InfoMax network on a sequence of images.

    The result is memoised by ``nb.cache_result`` (which saves it to an
    on-disk cache), so repeated calls with the same arguments can reuse it.

    Parameters
    ----------
    images : iterable
        Training images, passed to ``ann.train`` one at a time, in order.
    im_size : sequence of two ints
        Image dimensions; the network's input size is ``im_size[0] * im_size[1]``.
    lrate : float
        Learning rate passed to ``InfoMax``.
    seed :
        Seed passed to ``InfoMax``. It must not be None.

    Returns
    -------
    InfoMax
        The trained network.

    Raises
    ------
    AssertionError
        If ``seed`` is None.
    """
    # Presumably required because an unseeded network would differ between runs,
    # which would make a cached result misleading.
    assert seed is not None, 'get_ann() requires an explicit seed'

    ann = InfoMax(im_size[0] * im_size[1], learning_rate=lrate, seed=seed)
    for im in images:
        ann.train(im)
    return ann

# Train InfoMax on every SNAP_STEP-th image of the first database (names[0]), with fixed seed 42
snapshots = images[0][::SNAP_STEP]
ann = get_ann(snapshots, IM_SIZE, InfoMax.DEFAULT_LEARNING_RATE, 42)

# print('Calculating CAs...')
# t0 = time()
# snapshots = images[0][0::SNAP_STEP]
# ca_sizes = [get_ca_sizes(db, im, snapshots) for db, im in zip(dbs, images)]
# print('Elapsed: %g s' % (time() - t0))

Loading database at ../datasets/rc_car/Stanmer_park_dataset/0511/unwrapped_dataset1...
Database contains 411 images
Database 0511/unwrapped_dataset1 has 411 images
Loading database at ../datasets/rc_car/Stanmer_park_dataset/0511/unwrapped_dataset2...
Database contains 367 images
Database 0511/unwrapped_dataset2 has 367 images
Loading database at ../datasets/rc_car/Stanmer_park_dataset/0511/unwrapped_dataset3...
Database contains 402 images
Database 0511/unwrapped_dataset3 has 402 images
Loading database at ../datasets/rc_car/Stanmer_park_dataset/0511/unwrapped_dataset4...
Database contains 452 images
Database 0511/unwrapped_dataset4 has 452 images
Loading database at ../datasets/rc_car/Stanmer_park_dataset/0511/unwrapped_dataset5...
Database contains 386 images
Database 0511/unwrapped_dataset5 has 386 images
Starting get_ann()...
Seed for InfoMax net: 42
get_ann() took 37.5686 s to run (without caching)
Saving result to cache '.navbench_cache/2164391825.py_get_ann_3a58a07d366fb480ead16

In [None]:
from pathos.multiprocessing import Pool

# Plot UTM coordinates
def plot_route(ax, db, *args, **kwargs):
    """Draw a database's route (its ``x`` and ``y`` positions) on ``ax``.

    Any extra positional and keyword arguments are forwarded unchanged to
    ``ax.plot``, and its return value is passed straight back to the caller.
    """
    route_x, route_y = db.x, db.y
    return ax.plot(route_x, route_y, *args, **kwargs)

@nb.cache_result
def get_infomax_heads(ann, head_offset, ims):
    """Estimate a heading for each image using an InfoMax network.

    Each heading is ``head_offset`` plus the angle that ``nb.ridf_to_radians``
    derives from ``ann.ridf(image)``. Images are processed in parallel with a
    pathos worker pool. The result is a list in the same order as ``ims`` and
    is memoised by ``nb.cache_result``.
    """
    def _heading_for(image):
        ridf_angle = nb.ridf_to_radians(ann.ridf(image))
        return head_offset + ridf_angle

    with Pool() as workers:
        headings = workers.map(_heading_for, ims)
    return headings

def plot_arrows_infomax(ax, ann, route_dbs=None, route_images=None, route_lines=None,
                        test_dists=range(0, 55, 3)):
    """Overlay InfoMax heading estimates as arrows on each plotted route.

    For each route, the frame whose ``db.distance`` is nearest to each value in
    ``test_dists`` is selected. The heading returned by ``get_infomax_heads``
    for that frame's image is drawn as a unit vector at the frame's (x, y)
    position, in the same colour as the route's line.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    ann :
        Trained InfoMax network, forwarded to ``get_infomax_heads``.
    route_dbs, route_images, route_lines : sequences, optional
        Per-route databases, preprocessed images, and ``ax.plot`` return values
        (each a list whose first element is the route's line). They must all
        have the same length. If omitted, they default to the notebook globals
        ``dbs``, ``images`` and ``lines``, looked up when the function is called.
    test_dists : iterable of numbers, optional
        Distances along each route at which to draw arrows (default 0, 3, ..., 54).

    Raises
    ------
    ValueError
        If the three per-route sequences differ in length.
    """
    # Fall back to the notebook globals only when the caller did not pass data explicitly
    route_dbs = list(dbs if route_dbs is None else route_dbs)
    route_images = list(images if route_images is None else route_images)
    route_lines = list(lines if route_lines is None else route_lines)
    if not len(route_dbs) == len(route_images) == len(route_lines):
        # zip() would otherwise silently drop the unmatched routes from the figure
        raise ValueError('route_dbs, route_images and route_lines must have the same length '
                         '(got %d, %d, %d)' % (len(route_dbs), len(route_images), len(route_lines)))
    # Materialise once so a one-shot iterator is reused for every route
    test_dists = list(test_dists)

    for db, ims, line in zip(route_dbs, route_images, route_lines):
        colour = line[0].get_color()
        head_offset = db.calculate_heading_offset(0.5)
        # Index of the frame nearest to each requested distance along the route
        test_frames = [np.argmin(np.abs(db.distance - dist)) for dist in test_dists]

        heads = get_infomax_heads(ann, head_offset, [ims[fr] for fr in test_frames])
        x = [db.x[i] for i in test_frames]
        y = [db.y[i] for i in test_frames]
        u = [math.cos(head) for head in heads]
        v = [math.sin(head) for head in heads]

        # With scale_units='xy', scale=0.5 draws each unit vector 2 data units long
        ax.quiver(x, y, u, v, angles='xy', color=colour, zorder=10,
                  scale_units='xy', scale=0.5)



fig, ax = plt.subplots()
ax.axis('equal')
lines = [plot_route(ax, db) for db in dbs]
plot_arrows_infomax(ax, ann)
# Pair each description with its route line explicitly. Passing labels alone makes
# matplotlib match them to artists purely by creation order, which can silently
# mislabel the legend once other artists (e.g. the quiver arrows) are on the axes.
ax.legend([route_line[0] for route_line in lines], descriptions)

fig.savefig('infomax_rc_car.svg')

Starting get_infomax_heads()...
