# Optimal transport

*Selected Topics in Mathematical Optimization: 2017-2018*

**Michiel Stock** ([email](mailto:michiel.stock@ugent.be))

![](Figures/logo.png)

In [None]:
from optimal_transport import red, green, orange, yellow, blue, black
import matplotlib.pyplot as plt
import numpy as np
from optimal_transport import pairwise_distances
%matplotlib inline

## Cell tracking

In a microscopy imaging experiment, we monitor ten moving cells at time $t_1$ and again at a later time $t_2$. Between these times, the cells have moved. An image-processing algorithm has determined the coordinates of every cell in both images. We want to know which cell in the first image corresponds to which cell in the second image. To this end, we search for the assignment that minimizes the sum of the squared Euclidean distances between each cell in the first image and its corresponding cell in the second image.

1. `X1` and `X2` contain the $x, y$ coordinates of the cells in the two images. Compute the matrix $C$ containing the pairwise squared Euclidean distances. You can use the function `pairwise_distances` from `sklearn`.
2. Complete the function `monge_brute_force` so that it uses brute-force search to find the best permutation.
3. Make a plot connecting each cell in the first image to its matched cell in the second image.

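As a hedged sketch, the squared Euclidean cost matrix can also be computed with plain NumPy broadcasting. The helper name `squared_euclidean_cost` and the random coordinates are hypothetical and serve only as illustration. With the real data you would pass `X1` and `X2`. If `sklearn` is available, `pairwise_distances(X1, X2, metric='sqeuclidean')` should return the same matrix.

```python
import numpy as np

def squared_euclidean_cost(X1, X2):
    """Hypothetical helper: C[i, j] = ||X1[i] - X2[j]||^2."""
    # X1[:, None, :] has shape (n, 1, d) and X2[None, :, :] has shape (1, m, d),
    # so their difference broadcasts to shape (n, m, d)
    diff = X1[:, None, :] - X2[None, :, :]
    return np.sum(diff ** 2, axis=-1)

# hypothetical (x, y) positions of 10 cells at two time points
rng = np.random.default_rng(0)
X1_demo = rng.uniform(size=(10, 2))
X2_demo = X1_demo + 0.05 * rng.normal(size=(10, 2))

C_demo = squared_euclidean_cost(X1_demo, X2_demo)
print(C_demo.shape)  # (10, 10)
```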
In [None]:
from cell_tracking import X1, X2, plot_cells

In [None]:
# all permutations can easily be generated in Python
from itertools import permutations

for perm in permutations([1, 2, 3]):
    print(perm)

In [None]:
fig, ax = plot_cells(X1, X2)

In [None]:
from itertools import permutations

def monge_brute_force(C):
    """
    Solves the Monge assignment problem using
    brute force.

    Inputs:
        - C: cost matrix (square, size n x n)

    Outputs:
        - best_perm: optimal assignments (list of n indices matching the rows
                to the columns)
        - best_cost: optimal cost corresponding to the best permutation

    DO NOT USE FOR PROBLEMS OF A SIZE LARGER THAN 12!!!
    """
    n, m = C.shape
    assert n == m  # C should be square
    best_perm = None
    best_cost = np.inf
    # loop over all n! permutations and keep the
    # matching with the lowest cost
    for perm in permutations(range(n)):
        # row i is matched to column perm[i]
        cost = C[np.arange(n), list(perm)].sum()
        if cost < best_cost:
            best_perm, best_cost = list(perm), cost
    return best_perm, best_cost

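The warning in the `monge_brute_force` docstring follows from the size of the search space. Brute force enumerates all $n!$ permutations, so the work grows factorially. A quick check with the standard library:

```python
from math import factorial

# number of candidate assignments that brute force has to evaluate
for n in (5, 10, 12, 15):
    print(n, factorial(n))
# 5 120
# 10 3628800
# 12 479001600
# 15 1307674368000
```

For the ten cells in this exercise, that is about 3.6 million permutations. At $n = 12$ it is already close to half a billion.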
In [None]:
from optimal_transport import monge_brute_force

In [None]:
# get the cost matrix (i.e. pairwise squared
# Euclidean distances between the cells at the different times)

C = ...

In [None]:
# get matching

best_perm, best_cost = monge_brute_force(C)

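Brute force only works for very small $n$. As a cross-check that is not part of the original exercise, SciPy's `scipy.optimize.linear_sum_assignment` solves the same square assignment problem in polynomial time. The sketch below uses a hypothetical random cost matrix and a compact brute-force search to show that both approaches reach the same optimal cost.

```python
from itertools import permutations

import numpy as np
from scipy.optimize import linear_sum_assignment

rng = np.random.default_rng(1)
C_demo = rng.uniform(size=(6, 6))  # hypothetical 6 x 6 cost matrix

# brute force: evaluate all 6! = 720 permutations
bf_cost = min(C_demo[np.arange(6), list(p)].sum()
              for p in permutations(range(6)))

# polynomial-time solver; row row_ind[k] is matched to column col_ind[k]
row_ind, col_ind = linear_sum_assignment(C_demo)
lsa_cost = C_demo[row_ind, col_ind].sum()

print(np.isclose(bf_cost, lsa_cost))  # True
```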
In [None]:
# make a plot with the connections of the cells

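One possible way to draw the matching draws one line segment for every matched pair. It assumes that `best_perm[i]` is the index in `X2` of the cell matched to cell `i` in `X1`. The positions and the permutation below are hypothetical placeholders, and the colours are plain matplotlib choices rather than the notebook's own palette.

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(2)
X1_demo = rng.uniform(size=(10, 2))  # hypothetical positions at t1
X2_demo = rng.uniform(size=(10, 2))  # hypothetical positions at t2
perm_demo = rng.permutation(10)      # hypothetical matching: i -> perm_demo[i]

fig, ax = plt.subplots()
ax.scatter(X1_demo[:, 0], X1_demo[:, 1], label="$t_1$")
ax.scatter(X2_demo[:, 0], X2_demo[:, 1], label="$t_2$")
for i, j in enumerate(perm_demo):
    # segment from cell i at t1 to its matched cell j at t2
    ax.plot([X1_demo[i, 0], X2_demo[j, 0]],
            [X1_demo[i, 1], X2_demo[j, 1]], color="gray", lw=1)
ax.legend()
```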
## Cell differentiation

Three types of cells are cultured together. At $t_1$, we know the expression of two genes for some cells of every type. Some time later, at $t_2$, the cells have multiplied and have differentiated somewhat. A new gene expression analysis is done for a set of cells from the culture, this time without information about their type. How did the expression change for every type?

1. Link the cells from the two time points using optimal transport. Use the Sinkhorn algorithm with $\lambda=10$ and the squared Euclidean distance as the cost.
2. Plot the mapping (use the `alpha` argument to set the shade of a color).
3. Compute the 'drift' (difference in average gene expression) for every cell type.

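For reference, the Sinkhorn algorithm solves an entropy-regularized version of the transport problem. The set of admissible transport matrices is $U(a, b) = \{P \geq 0 : P\mathbf{1} = a,\; P^\top \mathbf{1} = b\}$. The notation below is consistent with `np.exp(- lam * C)` in the code, so a larger $\lambda$ means weaker regularization:

$$
\min_{P \in U(a, b)} \; \sum_{i,j} P_{ij} C_{ij} - \frac{1}{\lambda} H(P),
\qquad H(P) = -\sum_{i,j} P_{ij} \log P_{ij}.
$$

The minimizer has the form

$$
P^\star = \operatorname{diag}(u)\, e^{-\lambda C}\, \operatorname{diag}(v),
$$

where the exponential is taken elementwise and $u$, $v$ are positive vectors. Sinkhorn alternately rescales the rows and the columns of $e^{-\lambda C}$ until both marginal constraints hold.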
In [None]:
# X1 and X2 are gene expressions for the cells at time 1 and 2
# y1 is the indicator of the type of cells, only known at t1
from cell_differentiation import X1, X2, y1, plot_cells

In [None]:
fig, ax = plt.subplots()
plot_cells(ax)

In [None]:
def compute_optimal_transport(C, a, b, lam, epsilon=1e-8,
                              verbose=False, return_iterations=False):
    """
    Computes the optimal transport matrix and Sinkhorn distance using the
    Sinkhorn-Knopp algorithm

    Inputs:
        - C : cost matrix (n x m)
        - a : vector of marginals (n, )
        - b : vector of marginals (m, )
        - lam : inverse strength of the entropic regularization (larger values
                give a sharper, less regularized transport matrix)
        - epsilon : convergence parameter
        - verbose : report number of steps while running
        - return_iterations : report number of iterations till convergence,
                default False

    Output:
        - P : optimal transport matrix (n x m)
        - dist : Sinkhorn distance
        - n_iterations : number of iterations, if `return_iterations` is set to
                        True
    """
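    n, m = C.shape
    P = np.exp(- lam * C)
    P /= P.sum()  # normalize so that P is a joint distribution
    iteration = 0
    while True:
        iteration += 1
        u = P.sum(1)  # marginals of rows
        max_deviation = np.max(np.abs(u - a))
        if verbose:
            print('Iteration {}: max deviation={}'.format(
                iteration, max_deviation))
        if max_deviation < epsilon:
            break
        # scale rows so that they sum to a
        P *= (a / u).reshape((-1, 1))
        # scale columns so that they sum to b
        P *= (b / P.sum(0)).reshape((1, -1))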
    if return_iterations:
        return P, np.sum(P * C), iteration
    else:
        return P, np.sum(P * C)

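`compute_optimal_transport` works directly on the matrix `P` and scales its rows and columns in turn. The hedged alternative sketch below uses the hypothetical name `sinkhorn_scaling` and writes the same iterations with two scaling vectors `u` and `v`. The plan is then `P = u[:, None] * K * v[None, :]` with `K = np.exp(-lam * C)`. The toy cells, the uniform marginals, and the `max_iter` safeguard are assumptions chosen for illustration.

```python
import numpy as np

def sinkhorn_scaling(C, a, b, lam, epsilon=1e-8, max_iter=10_000):
    """Hypothetical helper: Sinkhorn-Knopp with explicit scaling vectors."""
    K = np.exp(-lam * C)  # Gibbs kernel
    u = np.ones_like(a, dtype=float)
    v = np.ones_like(b, dtype=float)
    for _ in range(max_iter):
        u = a / (K @ v)    # make the row sums equal to a
        v = b / (K.T @ u)  # make the column sums equal to b
        # after the v-update the columns match b exactly; check the rows
        if np.max(np.abs(u * (K @ v) - a)) < epsilon:
            break
    P = u[:, None] * K * v[None, :]
    return P, np.sum(P * C)

# hypothetical toy data: 5 cells at t1 and 7 cells at t2 with 2 genes each
rng = np.random.default_rng(3)
X1_demo = rng.uniform(size=(5, 2))
X2_demo = rng.uniform(size=(7, 2))
C_demo = ((X1_demo[:, None, :] - X2_demo[None, :, :]) ** 2).sum(-1)
a = np.full(5, 1 / 5)  # uniform marginals: every cell carries equal mass
b = np.full(7, 1 / 7)
P_demo, dist = sinkhorn_scaling(C_demo, a, b, lam=10)
```

Keeping `u` and `v` separate avoids rewriting the full `n x m` matrix at every step. The price is forming `P` explicitly at the end.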
In [None]:
from optimal_transport import compute_optimal_transport

In [None]:
# get the cost matrix (i.e. pairwise squared
# Euclidean distances of the expression vectors
# of the cells at the different times)

C = ...

In [None]:
# get matching
P, _ = compute_optimal_transport(...

In [None]:
# plot the cells with the mapping between the times

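A possible way to show the soft mapping follows the hint about `alpha`. It draws a line for every pair $(i, j)$ with an opacity proportional to $P_{ij}$. The random expression values and the stand-in plan `P_demo` are hypothetical placeholders. In the notebook you would use `X1`, `X2` and the `P` returned by `compute_optimal_transport`.

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(4)
X1_demo = rng.normal(size=(4, 2))  # hypothetical expression of 2 genes at t1
X2_demo = rng.normal(size=(6, 2))  # hypothetical expression at t2
P_demo = rng.uniform(size=(4, 6))  # stand-in for a transport plan
P_demo /= P_demo.sum()

fig, ax = plt.subplots()
ax.scatter(X1_demo[:, 0], X1_demo[:, 1], label="$t_1$")
ax.scatter(X2_demo[:, 0], X2_demo[:, 1], label="$t_2$")
alphas = P_demo / P_demo.max()  # rescale to (0, 1] for the alpha argument
for i in range(P_demo.shape[0]):
    for j in range(P_demo.shape[1]):
        ax.plot([X1_demo[i, 0], X2_demo[j, 0]],
                [X1_demo[i, 1], X2_demo[j, 1]],
                color="gray", alpha=alphas[i, j])
ax.set_xlabel("gene 1")
ax.set_ylabel("gene 2")
ax.legend()
```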
In [None]:
# compute the drift (average change in gene expression
# for different classes between the two time points)
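
One way to make the drift concrete is a barycentric projection. It assumes that each cell at $t_1$ moves to the $P$-weighted average of the cells at $t_2$ that it is linked to. Every cell is mapped this way, its original expression is subtracted, and the result is averaged per type. The function `compute_drift` and the toy data are hypothetical. The toy check uses a plan that links each cell only to its own descendant, so the known shifts should be recovered.

```python
import numpy as np

def compute_drift(X1, X2, y1, P):
    """Hypothetical helper: average expression change per cell type.

    Cell i at t1 is mapped to sum_j P[i, j] * X2[j] / sum_j P[i, j].
    """
    X1_mapped = (P @ X2) / P.sum(axis=1, keepdims=True)
    displacement = X1_mapped - X1
    return {k: displacement[y1 == k].mean(axis=0) for k in np.unique(y1)}

# hypothetical toy data: two cell types, each with a known shift
rng = np.random.default_rng(5)
X1_demo = rng.normal(size=(4, 2))
y1_demo = np.array([0, 0, 1, 1])
shifts = np.array([[1.0, 0.0], [0.0, -2.0]])
X2_demo = X1_demo + shifts[y1_demo]
P_demo = np.eye(4) / 4  # each cell is linked only to its own descendant

drift = compute_drift(X1_demo, X2_demo, y1_demo, P_demo)
print(drift)
```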