# Optimal transport


In [1]:
from optimal_transport import *
from measure_spaces.measure_spaces import *
import ipywidgets as widgets
import scipy.stats as sts
import numpy as np

In [2]:
def build_f(M):
    """Return the quadratic-form cost function ``f(x_bar) = (M @ x_bar) . x_bar``.

    Parameters
    ----------
    M : array_like
        Square matrix defining the quadratic form (``build_ot`` passes a 2x2
        matrix read from the cost-matrix widgets).

    Returns
    -------
    callable
        ``f(x_bar)``. For a 1-D ``x_bar`` the result is the scalar
        ``x_bar^T M x_bar``. Any array-like ``x_bar`` (list, tuple, ndarray)
        is accepted.
    """
    def f(x_bar):
        # Convert any non-ndarray sequence (list, tuple, ...) so ``.T`` is
        # available; ndarrays, including subclasses, are used as-is.
        if not isinstance(x_bar, np.ndarray):
            x_bar = np.asarray(x_bar)
        return np.dot(np.dot(M, x_bar), x_bar.T)
    return f


def _histogram_measure(samples, bins):
    """Bin ``samples`` into ``bins`` equal-width bins and return a DiscreteMeasure.

    The support is the bin centres. The mass is each bin's count divided by the
    number of samples. ``np.histogram`` with an integer ``bins`` spans the full
    data range, so every sample is counted and the masses sum to 1.
    """
    counts, edges = np.histogram(samples, bins=bins)
    centres = (edges[:-1] + edges[1:]) / 2
    return DiscreteMeasure(support=centres, mass=counts / len(samples))


def build_ot(*args, **kwargs):
    """Build a ``DiscreteOT`` problem from the current widget values.

    The source measure ``alpha`` is a histogram of a two-component gamma
    mixture: shape ``a1``, plus shape ``a2`` shifted by ``a3``. The target
    ``beta`` is a histogram of a gamma with shape ``b``. Both use
    ``n_bins`` bins. The cost is the quadratic form built from the 2x2
    ``f_grid`` entries.

    Any positional/keyword arguments are ignored. This lets the function be
    used directly as a widget ``observe`` callback, which passes a change
    payload positionally.

    Returns
    -------
    DiscreteOT
        The (unsolved) transport problem.
    """
    # Re-seed the global RNG so repeated calls with the same widget values
    # draw identical samples.
    np.random.seed(754)

    bins = n_bins.value
    n = 100 * bins
    half = n // 2

    x = np.concatenate([
        sts.gamma(a1.value).rvs(half),
        sts.gamma(a2.value, loc=a3.value).rvs(half),
    ])
    y = sts.gamma(b.value).rvs(n)

    alpha = _histogram_measure(x, bins)
    beta = _histogram_measure(y, bins)

    # NOTE: M[0, 1] is read from f_grid[1, 0] (and vice versa), so M is the
    # transpose of the on-screen [row, col] grid. The quadratic form
    # x^T M x has the same value either way.
    M = np.array([[f_grid[0, 0].value, f_grid[1, 0].value],
                  [f_grid[0, 1].value, f_grid[1, 1].value]])

    return DiscreteOT(alpha, beta, cost_function=build_f(M))

alenght = '300px'  # width of the horizontal sliders (name kept as-is: sic)

grid_cb = widgets.Checkbox(
    value=True,
    description='Grid',
    disabled=False,
    indent=False
)


def _int_slider(value, lo, hi, *, orientation='horizontal', layout=None):
    """Return an IntSlider with the settings shared by every slider in this panel.

    ``continuous_update=False`` commits the value only when the slider is
    released, so dependent callbacks are not re-run while dragging.
    A fresh ``Layout`` of width ``alenght`` is used when none is given.
    """
    if layout is None:
        layout = widgets.Layout(width=alenght)
    return widgets.IntSlider(value=value,
                             min=lo, max=hi,
                             step=1,
                             disabled=False,
                             continuous_update=False,
                             orientation=orientation,
                             readout_format='',
                             layout=layout)


n_bins = _int_slider(30, 10, 200)

# Parameters of the source mixture: shape of component 1, shape of
# component 2, and the location shift of component 2.
a1 = _int_slider(1, 1, 10)
a2 = _int_slider(4, 1, 10)
a3 = _int_slider(5, -10, 10)

# Shape of the target gamma distribution.
b = _int_slider(10, 5, 30,
                orientation='vertical',
                layout=widgets.Layout(width='70px', height='70px'))


f_len = '50px'


def _coef_box(value):
    """Return one editable entry of the cost matrix, bounded to [-2, 2]."""
    return widgets.BoundedFloatText(
        value=value,
        min=-2,
        max=2,
        step=0.1,
        disabled=False,
        layout=widgets.Layout(width=f_len)
    )


# 2x2 cost-matrix editor, initialised to the identity.
f_grid = widgets.GridspecLayout(4, 4)
f_grid[0, 0] = _coef_box(1)
f_grid[0, 1] = _coef_box(0)
f_grid[1, 0] = _coef_box(0)
f_grid[1, 1] = _coef_box(1)

# NOTE(review): build_ot also reads n_bins, which is not listed here; add it
# if these inputs are ever wired to build_ot via ``observe``.
all_inputs = [a1, a2, a3, b, f_grid[0, 0], f_grid[1, 0], f_grid[0, 1], f_grid[1, 1]]


def plot_solution(a1, a2, a3, b, f00, f10, f01, f11, n_bins, grid_cb):
    """Rebuild, solve and plot the transport problem.

    The arguments exist so ``widgets.interactive_output`` re-runs this function
    whenever one of those widgets changes. Only ``grid_cb`` is used directly,
    and it is forwarded to ``plot``. ``build_ot`` reads the other settings from
    the widgets itself.
    """
    problem = build_ot()
    problem.solve()
    problem.plot(grid_cb)
    plt.show()


def plot_cost_functon(f00, f10, f01, f11):
    """Plot the cost function of the current transport problem.

    The four cost-matrix entries are accepted only so ``interactive_output``
    re-runs this function when one of them changes. ``build_ot`` reads the
    widgets directly.
    """
    problem = build_ot()
    problem.plot_cost_function()
    plt.show()


# The four cost-matrix widgets feed both plots.
_cost_matrix_controls = {
    'f00': f_grid[0, 0],
    'f10': f_grid[1, 0],
    'f01': f_grid[0, 1],
    'f11': f_grid[1, 1],
}

out_sol = widgets.interactive_output(
    plot_solution,
    {
        'a1': a1,
        'a2': a2,
        'a3': a3,
        'b': b,
        **_cost_matrix_controls,
        'n_bins': n_bins,
        'grid_cb': grid_cb,
    },
)
out_cf = widgets.interactive_output(plot_cost_functon, dict(_cost_matrix_controls))

_controls_row = widgets.HBox([
    widgets.VBox([a1, a2, a3]),
    b,
    f_grid,
    widgets.VBox([n_bins, grid_cb]),
])
_plots_row = widgets.HBox([out_sol, out_cf])

# Last expression of the cell so the notebook renders the dashboard.
widgets.VBox([_controls_row, _plots_row])
    

NameError: name 'widgets' is not defined