In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
def inter_extrapolation(x, y, e):
    """ Piecewise-linear interpolation with linear extrapolation outside the data range.

    The sample points (x, y) are first sorted by x. Evaluation points inside
    [min(x), max(x)] are linearly interpolated between neighbouring samples;
    points below min(x) are extrapolated along the line through the first two
    sorted samples, and points above max(x) along the line through the last two.

    Arguments
    ===
    x: numpy array (or array-like) of sample abscissae, in any order
    y: numpy array (or array-like) of sample ordinates, aligned with x
    e: numpy array (or array-like / scalar) of points at which to evaluate

    Return
    ===
    return: numpy array with one value per element of e (a scalar e yields a
            length-1 array)
    """

    x = np.asarray(x)
    y = np.asarray(y)
    e = np.atleast_1d(np.asarray(e))

    # Nothing to evaluate: return an empty float array without touching x / y.
    if e.size == 0:
        return np.array([])

    order = np.argsort(x)
    new_x = x[order]
    new_y = y[order]

    # np.interp requires increasing sample points, so the *sorted* arrays must
    # be used here. Out-of-range points are clamped by np.interp and are
    # overwritten with the extrapolated values below.
    result = np.interp(e, new_x, new_y)

    # Extrapolation needs two samples; only index them when actually required.
    below = e < new_x[0]
    if below.any():
        result[below] = new_y[0] + (e[below] - new_x[0]) * (new_y[1] - new_y[0]) / (new_x[1] - new_x[0])

    above = e > new_x[-1]
    if above.any():
        result[above] = new_y[-1] + (e[above] - new_x[-1]) * (new_y[-1] - new_y[-2]) / (new_x[-1] - new_x[-2])

    return result

In [None]:
def calculate_inventory_trading_speed(alpha, phi, t, tt, T, b, k):
    """ Intended to compute the optimal trading speed and the investor's inventory along the optimal path.

        For the time grid t, the function is meant to return the optimal speed of trading (v) and the investor's
        inventory along the optimal path (Q). It is also meant to return the same two quantities evaluated at tt
        (vt and Qt), where tt is a vector of time points chosen by the user for marking.

        NOTE(review): the computation itself is absent from this body. Q, v, Qt and vt are never assigned inside
        the function, so unless names with those spellings exist at module level the call raises NameError.
        The model parameters below are described only as far as the signature and the visible caller allow.

    Arguments
    ===
    alpha: float
    phi: numpy array (NOTE(review): the visible caller passes a single entry, phi[line] - confirm the expected type)
    t: numpy array (time grid; the caller plots Q and v against it)
    tt: numpy array (marker time points; the caller plots Qt and vt against it)
    T: int
    b: float
    k: float

    Return
    ===
    Q: numpy array
    v: numpy array
    Qt: numpy array
    vt: numpy array
    """
        
    # NOTE(review): placeholder - none of the returned names is defined in this function.
    return Q, v, Qt, vt

In [None]:
def plot_inventory_trading_speed(alpha0, phi, symb, t, tt, T, b, k, labels, main):
    """ Draw optimal inventories (left panel) and trading speeds (right panel) side by side.

    One curve is produced per entry of phi. For each entry the values on the
    grid t are drawn as a solid line, and the values at the marker times tt
    are drawn as unconnected markers using symb[i] and labelled labels[i].
    Colours are spread evenly over the rainbow colormap. The figure title is
    taken from main.
    """

    fig, axes = plt.subplots(ncols=2, figsize=(20, 10))
    ax_inv, ax_trad = axes

    n_curves = phi.shape[0]
    shades = np.linspace(0, 1, n_curves)

    for idx in range(n_curves):
        colour = plt.cm.rainbow(shades[idx])
        inv_line, trad_line, inv_dot, trad_dot = calculate_inventory_trading_speed(
            alpha0, phi[idx], t, tt, T, b, k
        )
        # Markers first, then the continuous path, on each panel.
        for axis, dots, curve in ((ax_inv, inv_dot, inv_line), (ax_trad, trad_dot, trad_line)):
            axis.plot(tt, dots, color=colour, label=labels[idx], marker=symb[idx], linestyle='None')
            axis.plot(t, curve, linestyle='-', color=colour)

    for axis, y_caption in ((ax_inv, "Inventory"), (ax_trad, "Trading Speed")):
        axis.legend()
        axis.set_xlabel("Time", fontsize=18)
        axis.set_ylabel(y_caption, fontsize=18)

    # Position the right-hand panel's y-label explicitly.
    ax_trad.yaxis.set_label_coords(-0.1, 0.5)
    plt.suptitle(main, fontsize=20)
    fig.canvas.draw()