In [None]:
# Interactive matplotlib figures via the classic-notebook `nbagg` backend.
# NOTE(review): `%matplotlib notebook` does not work in JupyterLab / Notebook 7;
# there `%matplotlib widget` (ipympl) or `inline` is needed instead.
%matplotlib notebook
# Re-import edited modules (e.g. path_planning_utils) before every cell runs,
# so changes to the helper module are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import networkx as nx
from time import time
import matplotlib.pyplot as plt
import path_planning_utils as utils

In [None]:
# Default grid size. The sweep cell below rebinds `n` as its loop variable.
n = 51
# Labels of the source and target nodes passed to utils.line_graph and
# used as the endpoints of the shortest-path query.
s, t = 's', 't'
# Forwarded to utils.generate_boxes / utils.intersect.
# NOTE(review): presumably the box side lengths — confirm in path_planning_utils.
sides = [2, 0.5]

In [None]:
def solve_batch(n, batch_size, max_attempts=None):
    """Benchmark Dijkstra vs. convex path optimisation on random box instances.

    Instances are generated with consecutive seeds 0, 1, 2, ... Any instance
    whose line graph has no s-t path is skipped. Generation continues until
    `batch_size` solvable instances have been collected.

    Parameters
    ----------
    n : int
        Size parameter forwarded to ``utils.generate_boxes``. The reference
        straight line runs between boxes ``(0, 0)`` and ``(n - 1, n - 1)``.
    batch_size : int
        Number of solvable instances to collect.
    max_attempts : int, optional
        Maximum number of seeds to try. ``None`` (default) tries indefinitely,
        as before. When the bound is reached, RuntimeError is raised instead
        of looping forever on a size with no feasible instances.

    Returns
    -------
    times : dict
        ``'dijk'``: elapsed time of each ``nx.shortest_path`` call.
        ``'conv'``: solve time returned by ``utils.optimize_path``.
    lengths : dict
        ``'dijk'``: length of the polyline through the ``.center`` of each
        intersection on the shortest path (the s and t endpoints excluded).
        ``'conv'``: length returned by ``utils.optimize_path``.
        ``'line'``: distance from ``boxes[(0, 0)].top`` to
        ``boxes[(n - 1, n - 1)].bot``.

    Raises
    ------
    RuntimeError
        If `max_attempts` seeds are exhausted before `batch_size` solvable
        instances are found.
    """
    # perf_counter is monotonic and high-resolution; time.time() can be too
    # coarse for sub-millisecond graph searches and may jump with clock changes.
    from time import perf_counter

    times = {'dijk': [], 'conv': []}
    lengths = {'dijk': [], 'conv': [], 'line': []}

    seed = 0
    while len(times['dijk']) < batch_size:
        if max_attempts is not None and seed >= max_attempts:
            raise RuntimeError(
                f'only {len(times["dijk"])} of {batch_size} solvable instances '
                f'found after {max_attempts} seeds (n={n})'
            )

        boxes = utils.generate_boxes(n, sides, seed=seed)
        seed += 1
        inters = utils.intersect(boxes, sides)
        G = utils.line_graph(boxes, inters, s, t)

        tic = perf_counter()
        try:
            path = nx.shortest_path(G, source=s, target=t, weight='weight')
        except nx.NetworkXNoPath:
            continue  # disconnected instance: move on to the next seed
        times['dijk'].append(perf_counter() - tic)

        # Length of the discrete path through intersection centers; the
        # first/last path entries are the s/t nodes and are skipped.
        traj = np.array([inters[v].center for v in path[1:-1]])
        length = sum(np.linalg.norm(y - x) for x, y in zip(traj[:-1], traj[1:]))
        lengths['dijk'].append(length)

        traj, length, solve_time = utils.optimize_path(path, inters)
        times['conv'].append(solve_time)
        lengths['conv'].append(length)

        d = boxes[(n - 1, n - 1)].bot - boxes[(0, 0)].top
        lengths['line'].append(np.linalg.norm(d))

    return times, lengths

In [None]:
import gc

# Problem sizes to sweep, and number of solvable instances collected per size.
ns = range(11, 102, 10)
batch_size = 10
times = {}
lengths = {}
for n in ns:
    print(n)  # progress indicator
    # Disable the cyclic GC so collection pauses do not distort the timings,
    # and always re-enable it, even if a batch raises, so the kernel is not
    # left running without garbage collection.
    gc.disable()
    try:
        # NOTE(review): solve_batch gets n + 1 while the plot uses (n - 1)**2
        # as the box count — confirm against utils.generate_boxes.
        times[n], lengths[n] = solve_batch(n + 1, batch_size)
    finally:
        gc.enable()

In [None]:
def extract_stats(data, key, sizes=None):
    """Summarise one metric as min / median / max for each problem size.

    Parameters
    ----------
    data : dict
        Mapping ``size -> {key: sequence of samples}``, e.g. the ``times`` or
        ``lengths`` dicts built by the sweep cell.
    key : str
        Which metric to summarise (e.g. ``'dijk'``, ``'conv'``, ``'line'``).
    sizes : iterable, optional
        Sizes to report, in order. Defaults to the notebook-global ``ns``.
        Pass it explicitly to avoid relying on that global.

    Returns
    -------
    dict
        ``{'min': [...], 'med': [...], 'max': [...]}``, each list aligned
        with `sizes`.

    Raises
    ------
    ValueError
        If any size has an empty sample list (from ``min``/``max``).
    """
    if sizes is None:
        sizes = ns
    samples = [data[size][key] for size in sizes]
    return {
        'min': [min(values) for values in samples],
        'med': [np.median(values) for values in samples],
        'max': [max(values) for values in samples],
    }

# Per-size min / median / max of the solve times for both methods ...
dijk_times_stats, conv_times_stats = (
    extract_stats(times, method) for method in ('dijk', 'conv')
)

# ... and of the trajectory lengths, including the straight-line reference.
dijk_length_stats, conv_length_stats, line_length_stats = (
    extract_stats(lengths, method) for method in ('dijk', 'conv', 'line')
)

In [None]:
# Two stacked panels over the same x values: solve times (top) and
# trajectory lengths (bottom), each drawn as min / median / max over a batch.
fig, (ax_time, ax_length) = plt.subplots(2, 1, figsize=(5, 6))

# NOTE(review): x uses (n - 1)**2 while the sweep calls solve_batch(n + 1, ...);
# confirm this equals the number of boxes utils.generate_boxes builds.
x = (np.array(ns) - 1) ** 2


def plot_stats(stats, color, ax=None, label=None):
    """Draw min and max as solid lines and the median as a dashed line.

    Parameters
    ----------
    stats : dict
        Mapping with keys ``'min'``, ``'med'``, ``'max'``, each aligned with
        the notebook-global ``x``.
    color : str
        Matplotlib color used for all three lines.
    ax : matplotlib.axes.Axes, optional
        Target axes. Defaults to the current pyplot axes (previous behaviour).
    label : str, optional
        Legend label. It is attached only to the solid ``'min'`` line, so each
        method appears once in the legend with a solid swatch.
    """
    if ax is None:
        ax = plt.gca()
    ax.plot(x, stats['min'], color=color, label=label)
    ax.plot(x, stats['med'], color=color, linestyle='--')
    ax.plot(x, stats['max'], color=color)


plot_stats(dijk_times_stats, 'r', ax=ax_time, label='Dijkstra')
plot_stats(conv_times_stats, 'b', ax=ax_time, label='Mosek')
ax_time.set_ylabel('Time (s)')
ax_time.grid()
ax_time.legend()

for stats, color, label in [
    (dijk_length_stats, 'r', 'Dijkstra'),
    (conv_length_stats, 'b', 'Mosek'),
    (line_length_stats, 'g', 'Straight line'),
]:
    plot_stats(stats, color, ax=ax_length, label=label)
ax_length.set_ylabel('Trajectory length')
ax_length.grid()
ax_length.legend()
ax_length.set_xlabel('Number of boxes')

fig.savefig('path_planning_stats.pdf', bbox_inches='tight')