# Sinkhorn algorithm for Optimal Transport on Graphs

This notebook does not import anything from the main project: it is self-contained so that it can be imported into Google Colab and run there if you do not have access to a GPU locally.

## Imports

In [None]:
# Some scipy and networkx versions are incompatible with each other:
# both packages are upgraded first, then scipy is pinned to version 1.8.1.
%pip install --upgrade scipy networkx
%pip install scipy==1.8.1

In [1]:
import json
import random
import warnings
from copy import deepcopy
from functools import wraps
from time import perf_counter
from typing import Tuple, List, Dict, Optional, Union, Any, Callable

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import torch

## Utility functions

In [2]:
def timeit(func: Callable[..., Any]) -> Callable[..., Any]:
    """
    Decorator that reports how long each call of a function takes.

    Args:
        func: The callable to instrument.

    Returns:
        A wrapper that forwards its arguments to ``func``, prints the elapsed time
        (in seconds, two decimals) and hands back whatever ``func`` returned.
    """

    @wraps(func)
    def timeit_wrapper(*args, **kwargs):
        started_at = perf_counter()
        outcome = func(*args, **kwargs)
        # the clock is read right after the call so that printing is not part of the measure
        total_time = perf_counter() - started_at
        print(f"Function {func.__name__} took {total_time:.2f} seconds")
        return outcome

    return timeit_wrapper


def return_runtime(func: Callable[..., Any]) -> Callable[..., Tuple[float, ...]]:
    """
    Decorator that makes a function return its execution time alongside its result.
    The typing of the decorated function is not preserved by this decorator.

    Args:
        func: The callable to instrument.

    Returns:
        A wrapper returning the pair ``(elapsed_seconds, result_of_func)``.
    """

    @wraps(func)
    def timeit_wrapper(*args, **kwargs):
        tic = perf_counter()
        output = func(*args, **kwargs)
        # the elapsed time comes first in the returned pair
        return perf_counter() - tic, output

    return timeit_wrapper


def checkpoint(time_ref: Optional[float] = None) -> Callable[..., None]:
    """
    Closure that stores a time checkpoint that is updated at every call.
    Each call prints the time elapsed since the last checkpoint with a custom message.

    Args:
        time_ref: The time reference to start from (a ``perf_counter`` reading).
            When omitted, the time at which ``checkpoint`` is called is used.

    Returns:
        The closure.
    """
    # A default of `perf_counter()` in the signature would be evaluated only once, when the
    # function is defined, so the reference is resolved here, at call time, instead.
    if time_ref is None:
        time_ref = perf_counter()

    def _closure(message: str = "") -> None:
        """
        Prints the time elapsed since the previous call and moves the checkpoint to now.

        Args:
            message: Custom message to print. The overall result will be: 'message: time_elapsed'.
                With an empty message nothing is printed, but the checkpoint is still updated.
        """
        nonlocal time_ref
        current_time = perf_counter()
        if message != "":
            print(f"{message}: {current_time - time_ref:.4f}")
        time_ref = current_time

    return _closure


## Data generation

In [None]:
def create_path_graph(graph_size: int) -> nx.DiGraph:
    """
    Builds the directed version of a path graph (linear graph) with `graph_size` nodes.
    """
    undirected_path = nx.path_graph(graph_size)
    return undirected_path.to_directed()


def create_cycle_graph(graph_size: int) -> nx.DiGraph:
    """
    Builds the directed version of a cycle graph (circular graph) with `graph_size` nodes.
    """
    undirected_cycle = nx.cycle_graph(graph_size)
    return undirected_cycle.to_directed()


def create_wheel_graph(graph_size: int) -> nx.DiGraph:
    """
    Builds the directed version of a wheel graph (see Networkx's documentation of `wheel_graph`).
    """
    undirected_wheel = nx.wheel_graph(graph_size)
    return undirected_wheel.to_directed()


def create_complete_graph(graph_size: int) -> nx.DiGraph:
    """
    Builds the directed version of a complete graph with `graph_size` nodes.
    """
    undirected_complete = nx.complete_graph(graph_size)
    return undirected_complete.to_directed()


def create_watts_strogatz_graph(graph_size: int) -> nx.DiGraph:
    """
    Builds the directed version of a connected graph sampled from the
    Watts–Strogatz random graph generation model.
    """
    # second and third positional arguments of nx.connected_watts_strogatz_graph
    second_argument = 3
    third_argument = 0.4
    undirected_graph = nx.connected_watts_strogatz_graph(
        graph_size, second_argument, third_argument
    )
    return undirected_graph.to_directed()


def create_gnp_graph(graph_size: int) -> nx.DiGraph:
    """
    Creates a connected graph using the Erdős–Rényi random graph generation model.
    Graphs are sampled until a connected one is found (at most 100 attempts);
    something similar is actually performed in nx.connected_watts_strogatz_graph.

    Args:
        graph_size: number of nodes of the graph.

    Returns:
        The directed version of the first connected graph that was sampled.

    Raises:
        RuntimeError: if none of the sampled graphs is connected.
    """
    n_attempts = 100
    # for p > (1 + eps) ln(n) / n, the Erdős–Rényi graph should be connected almost surely
    edge_probability = 2 * np.log(graph_size) / graph_size
    for _ in range(n_attempts):
        gnp_graph = nx.gnp_random_graph(graph_size, edge_probability)
        if nx.is_connected(gnp_graph):
            print(
                f"Sparsity of the graph: {gnp_graph.size() / len(gnp_graph) ** 2 * 100:.2f}%."
            )
            return gnp_graph.to_directed()

    # failing loudly instead of implicitly returning None, which would only crash later on
    raise RuntimeError(
        f"Could not sample a connected Erdős–Rényi graph of size {graph_size} "
        f"in {n_attempts} attempts."
    )


def create_bipartite_graph(graph_size: int) -> nx.DiGraph:
    """
    Creates a connected bipartite graph using the Erdős–Rényi random graph generation model.
    Graphs are sampled until a connected one is found (at most 100 attempts);
    something similar is actually performed in nx.connected_watts_strogatz_graph.

    Args:
        graph_size: target number of nodes; each of the two node sets receives
            `graph_size // 2` nodes.

    Returns:
        The directed version of the first connected bipartite graph that was sampled.

    Raises:
        RuntimeError: if none of the sampled graphs is connected.
    """
    n_attempts = 100
    half_size = graph_size // 2
    # for p > (1 + eps) ln(n) / n, the Erdős–Rényi graph should be connected almost surely
    edge_probability = 2 * np.log(graph_size) / graph_size
    for _ in range(n_attempts):
        bipartite_graph = nx.bipartite.random_graph(
            half_size, half_size, edge_probability
        )
        if nx.is_connected(bipartite_graph):
            print(
                f"Sparsity of the graph: {bipartite_graph.size() / len(bipartite_graph) ** 2 * 100:.2f}%."
            )
            return bipartite_graph.to_directed()

    # failing loudly instead of implicitly returning None, which would only crash later on
    raise RuntimeError(
        f"Could not sample a connected bipartite graph of size {graph_size} "
        f"in {n_attempts} attempts."
    )


def create_graph(graph_size: int, graph_type: str) -> nx.DiGraph:
    """
    Creates a graph using one of the generators of this file depending on the input graph type.

    Args:
        graph_size: number of nodes requested, forwarded to the generator.
        graph_type: one of 'bipartite', 'cycle', 'path', 'complete', 'wheel' or
            'watts_strogatz'. Any other value (e.g. 'gnp') selects the Erdős–Rényi generator.

    Returns:
        The graph produced by the selected generator.
    """
    generators: Dict[str, Callable[[int], nx.DiGraph]] = {
        "bipartite": create_bipartite_graph,
        "cycle": create_cycle_graph,
        "path": create_path_graph,
        "complete": create_complete_graph,
        "wheel": create_wheel_graph,
        "watts_strogatz": create_watts_strogatz_graph,
    }
    # unknown types keep falling back to the Erdős–Rényi generator, as before
    generator = generators.get(graph_type, create_gnp_graph)
    return generator(graph_size)

In [3]:
def add_random_weights(
    graph: nx.Graph,
    plot: bool = True,
    positions: Optional[Dict[int, Tuple[float, float]]] = None,
) -> None:
    """
    Stores a random integer weight (drawn between 0 and 10 included) in the
    'weight' attribute of every edge of the graph, in place.

    Args:
        graph: the graph whose edges receive a weight.
        plot: whether to draw the graph with the weight of each edge.
        positions: positions of the nodes used for the drawing;
            a spectral layout is computed when none are given.
    """
    # noinspection PyArgumentList
    for _source, _target, attributes in graph.edges(data=True):
        attributes["weight"] = random.randint(0, 10)

    if not plot:
        return

    print("Plotting the weights on each edge.")
    plt.figure()

    layout = positions or nx.spectral_layout(graph)
    edge_labels = {}
    for u, v, weight in graph.edges.data("weight"):
        edge_labels[(u, v)] = weight

    nx.draw_networkx_edge_labels(graph, layout, edge_labels=edge_labels)
    nx.draw(graph, layout, with_labels=True, node_size=500)

    plt.show()

In [None]:
def add_random_distributions(
    graph: nx.Graph,
    plot: bool = True,
    positions: Optional[Dict[int, Tuple[float, float]]] = None,
    distribution: str = "dirichlet",
    nonzero_ratio: float = 1.0,
) -> None:
    """
    Adds two random distributions on the nodes of the graph.
    The distributions are added as node attributes (rho_0 and rho_1), in place.
    Both distributions are nonzero on the same randomly chosen subset of nodes
    and are set to 0 everywhere else.

    Args:
        graph: the graph whose nodes receive the two distributions.
        plot: whether to draw the graph twice, once per distribution.
        positions: positions of the nodes used for the drawing;
            a spectral layout is computed when none are given.
        distribution: 'dirichlet' to sample from a Dirichlet distribution; any other value
            samples uniform random values that are then divided by their sum.
        nonzero_ratio: proportion of the nodes on which the distributions are nonzero.
    """
    n_nodes = len(graph)
    n_nonzero = int(nonzero_ratio * n_nodes)
    # sampling the distributions using Dirichlet distributions (no need to divide by the sum)
    if distribution == "dirichlet":
        rho_0 = np.random.dirichlet(np.ones(n_nonzero), size=1)[0]
        rho_1 = np.random.dirichlet(np.ones(n_nonzero), size=1)[0]
    else:
        rho_0 = np.random.random(n_nonzero)
        rho_0 = rho_0 / np.sum(rho_0)
        rho_1 = np.random.random(n_nonzero)
        rho_1 = rho_1 / np.sum(rho_1)

    # stored as a set so that the membership test in the loop below is O(1) per node
    nonzero_indexes = set(
        np.random.choice(n_nodes, size=n_nonzero, replace=False).tolist()
    )
    j = 0
    # noinspection PyArgumentList
    for i, (_, w) in enumerate(graph.nodes(data=True)):
        if i in nonzero_indexes:
            w["rho_0"] = rho_0[j]
            w["rho_1"] = rho_1[j]
            j += 1
        else:
            w["rho_0"] = 0.0
            w["rho_1"] = 0.0

    if plot:
        print("Plotting the two distributions on each node.")
        positions = positions or nx.spectral_layout(graph)
        plt.figure(figsize=(14, 7))

        # one subplot per distribution: rho_0 on the left, rho_1 on the right
        for subplot_index, attribute in ((121, "rho_0"), (122, "rho_1")):
            axis = plt.subplot(subplot_index)
            node_values = {
                node: round(graph.nodes[node][attribute], 2) for node in graph.nodes()
            }
            nx.draw_networkx_labels(
                graph, positions, labels=node_values, font_size=10
            )
            nx.draw(
                graph,
                positions,
                node_color=list(node_values.values()),
                node_size=800,
                ax=axis,
            )

        plt.show()

In [None]:
def plot_transportation_plan(
    graph: nx.Graph, positions: Dict[int, Tuple[float, float]]
) -> None:
    """
    Draws a graph and labels its edges with the transportation plan they carry.

    Args:
        graph: networkx graph whose edges have an 'ot' attribute holding the value
            of the transportation plan.
        positions: positions of each node.
    """
    plt.figure()

    # only the edges with a nonzero (truthy) value get a label, rounded to 3 decimals
    edge_labels = {}
    for u, v, ot in graph.edges.data("ot"):
        if ot:
            edge_labels[(u, v)] = round(ot, 3)

    nx.draw_networkx_edge_labels(graph, positions, edge_labels=edge_labels)
    nx.draw(graph, positions, with_labels=True, node_size=200, arrowsize=15)

    plt.show()

## Graph manipulation

In [4]:
def compute_cost_matrix(graph: nx.Graph) -> Tuple[np.ndarray, List[List[List[int]]]]:
    """
    Computes the matrix of shortest path lengths between the nodes of a graph,
    together with the corresponding shortest paths.
    Nodes are used directly as row / column indexes.

    Returns:
        The cost matrix and a nested list such that `shortest_paths[origin][destination]`
        is the list of nodes of the shortest path found between the two nodes.
        Entries that Dijkstra does not report keep their initial value (0 and empty list).
    """
    size = len(graph)
    costs = np.zeros((size, size))
    paths_table = [[[] for _ in range(size)] for _ in range(size)]

    # using networkx built-in shortest path (dijkstra algorithm)
    for origin, (lengths, paths) in nx.all_pairs_dijkstra(graph):
        for destination, length in lengths.items():
            costs[origin, destination] = length
        for destination, path in paths.items():
            paths_table[origin][destination] = path

    return costs, paths_table

In [None]:
def collect_graph(
    graph: nx.Graph,
    transportation_plan: np.ndarray,
    shortest_paths: List[List[List[int]]],
) -> nx.Graph:
    """
    Builds a copy of the graph whose edges carry an 'ot' attribute: every value of the
    transportation plan between a couple of nodes is added on all the edges that form
    the path between the two nodes. Values that are not above 1e-6 are ignored.
    """
    result = deepcopy(graph)
    nx.set_edge_attributes(result, 0., "ot")

    for origin, plan_row in enumerate(transportation_plan):
        for destination, mass in enumerate(plan_row):
            if not (mass and mass > 1e-6):
                continue
            # walking along the path: every edge (tail, head) on the way receives the mass
            hops = shortest_paths[origin][destination][1:]
            for tail, head in zip([origin, *hops], hops):
                result.edges[tail, head]["ot"] += mass

    return result

## Sinkhorn algorithm using PyTorch

In [None]:
# use the first CUDA device when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Running Sinkhorn on {device}.")

In [None]:
def after_sinkhorn(
    transportation_plan: np.ndarray,
    graph: nx.Graph,
    shortest_paths: List[List[List[int]]],
    cost_matrix: np.ndarray,
    f: np.ndarray,
    verbose: bool,
) -> Tuple[float, float, np.ndarray, float, nx.Graph]:
    """
    Converts a transportation plan between couples of nodes into a flow on the edges of
    the graph and computes the associated metrics.

    Args:
        transportation_plan: matrix whose entry [origin, destination] is the mass moved
            between the two nodes.
        graph: the graph on which the transport takes place.
        shortest_paths: `shortest_paths[origin][destination]` is the path (list of nodes)
            along which the mass is moved.
        cost_matrix: matrix of the costs between couples of nodes.
        f: vector that the product of the oriented incidence matrix with the flow is compared to.
        verbose: whether to print the transportation plan and the flow.

    Returns:
        A tuple made of:
            - the cost, i.e. the sum of `cost_matrix * transportation_plan`,
            - the quadratic term, i.e. the sum of the squared values of the flow,
            - the flow, i.e. the values of the 'ot' attribute of the edges of the collected graph,
            - the error, i.e. the norm of `incidence_matrix @ flow - f`,
            - the collected graph (see `collect_graph`).
    """
    if verbose:
        nonzero = np.count_nonzero(transportation_plan)
        print(
            f"Optimal transportation plan (number of nonzero: {nonzero} / {transportation_plan.size}):"
        )
        print(np.round(transportation_plan, 2))

    # adding the values of the ot on the initial graph by putting back the plan on the actual edges
    collected_graph = collect_graph(graph, transportation_plan, shortest_paths)

    flow = np.array(list(nx.get_edge_attributes(collected_graph, "ot").values()))
    cost = float(np.sum(cost_matrix * transportation_plan))
    quadratic_term = float(np.sum(np.square(flow)))
    err = np.linalg.norm(nx.incidence_matrix(graph, oriented=True).toarray() @ flow - f)

    if verbose:
        nonzero = np.count_nonzero(flow)
        print(f"Optimal flow (number of nonzero: {nonzero} / {graph.size()}):")
        print(np.round(flow, 2))

    return cost, quadratic_term, flow, err, collected_graph

### Standard Sinkhorn

In [5]:
def sinkhorn(
    graph: nx.Graph, alpha: float, verbose: bool = True, n_iter: int = 4000
) -> Tuple[float, float, np.ndarray, float, nx.Graph]:
    """
    Runs the standard Sinkhorn iterations between the distributions stored in the
    'rho_0' and 'rho_1' attributes of the nodes of the graph.

    Args:
        graph: graph whose nodes have attributes 'rho_0' and 'rho_1'.
        alpha: value used as the regularization parameter epsilon of the kernel.
        verbose: forwarded to `after_sinkhorn`.
        n_iter: number of iterations performed (there is no convergence check).

    Returns:
        The output of `after_sinkhorn`: cost, quadratic term, flow, error and collected graph.
    """
    epsilon = alpha

    n_nodes = len(graph)
    rho_0 = np.array(list(nx.get_node_attributes(graph, "rho_0").values()))
    rho_1 = np.array(list(nx.get_node_attributes(graph, "rho_1").values()))
    cost_matrix, shortest_paths = compute_cost_matrix(graph)

    # NOTE(review): the kernel uses the squared cost whereas stable_sinkhorn and the cost
    # reported by after_sinkhorn use the cost itself -- confirm that this is intended.
    K = np.exp(-(cost_matrix**2) / epsilon)

    K1 = torch.from_numpy(K).type(torch.FloatTensor).to(device)
    a1 = torch.from_numpy(rho_0).type(torch.FloatTensor).to(device)
    b1 = torch.from_numpy(rho_1).type(torch.FloatTensor).to(device)
    u = torch.ones(n_nodes).to(device)
    v = torch.ones(n_nodes).to(device)

    for _ in range(n_iter):
        u = a1 / (K1 * v[None, :]).sum(1)
        v = b1 / (K1 * u[:, None]).sum(0)

    # plan[i, j] = u[i] * K[i, j] * v[j]: broadcasting avoids building two dense diagonal
    # matrices and the two O(n^3) matrix products that come with them
    u_values = u.cpu().numpy()
    v_values = v.cpu().numpy()
    transportation_plan = u_values[:, None] * K * v_values[None, :]

    return after_sinkhorn(
        transportation_plan, graph, shortest_paths, cost_matrix, rho_1 - rho_0, verbose
    )

Running Sinkhorn on cuda:0.


### Log-domain Sinkhorn

Implementation based on a log-sum-exp trick to improve numerical stability.

In [None]:
def stable_sinkhorn(
    graph: nx.Graph, alpha: float, verbose: bool = True, n_iter: int = 10000
) -> Tuple[float, float, np.ndarray, float, nx.Graph]:
    """
    Runs the Sinkhorn iterations in the log domain (log-sum-exp trick) between the
    distributions stored in the 'rho_0' and 'rho_1' attributes of the nodes of the graph.

    Args:
        graph: graph whose nodes have attributes 'rho_0' and 'rho_1'.
        alpha: value used as the regularization parameter epsilon.
        verbose: forwarded to `after_sinkhorn`.
        n_iter: number of iterations performed (there is no convergence check).

    Returns:
        The output of `after_sinkhorn`: cost, quadratic term, flow, error and collected graph.
    """
    epsilon = alpha

    n_nodes = len(graph)
    rho_0 = np.array(list(nx.get_node_attributes(graph, "rho_0").values()))
    rho_1 = np.array(list(nx.get_node_attributes(graph, "rho_1").values()))
    cost_matrix, shortest_paths = compute_cost_matrix(graph)

    # the deprecated torch.autograd.Variable wrapper is not needed: tensors are used directly
    C = torch.from_numpy(cost_matrix).to(device)

    def modified_cost(u_val, v_val):
        return (-C + u_val.unsqueeze(1) + v_val.unsqueeze(0)) / epsilon

    def stable_lse(A):
        # NOTE(review): the maximum is taken over the whole matrix, not row by row -- confirm
        # that this is intended before replacing it with a per-row log-sum-exp.
        max_value = torch.max(A)
        # adding 10^-6 to prevent NaN
        return torch.log(torch.exp(A - max_value).sum(1, keepdim=True) + 1e-6) + max_value

    u = torch.ones(n_nodes).to(device)
    v = torch.ones(n_nodes).to(device)

    a1 = torch.from_numpy(rho_0).type(torch.FloatTensor).to(device)
    b1 = torch.from_numpy(rho_1).type(torch.FloatTensor).to(device)

    # the logarithms of the marginals do not change across iterations
    log_a1 = torch.log(a1)
    log_b1 = torch.log(b1)

    for _ in range(n_iter):
        u = epsilon * (log_a1 - stable_lse(modified_cost(u, v)).squeeze()) + u
        v = epsilon * (log_b1 - stable_lse(modified_cost(u, v).t()).squeeze()) + v

    transportation_plan = torch.exp(modified_cost(u, v)).cpu()

    return after_sinkhorn(
        transportation_plan.numpy(),
        graph,
        shortest_paths,
        cost_matrix,
        rho_1 - rho_0,
        verbose,
    )

## Pipeline

In [None]:
# registry of the available algorithms, indexed by the name used in the results
choose_algo = dict(
    sinkhorn=sinkhorn,
    stable_sinkhorn=stable_sinkhorn,
)

In [None]:
def update_records(
    records: Dict[str, Union[float, int, List[np.ndarray]]],
    dist: float,
    quadratic_term: float,
    error: float,
    solution: np.ndarray,
    runtime: float,
) -> None:
    """
    Accumulates the outcome of one run in a dict of records, in place.
    A run counts as failed when its cost is below 1e-12, infinite or NaN: it then only
    increments the 'fails' counter. The runtime and the solution are recorded in both cases.
    """
    run_succeeded = dist >= 1e-12 and dist != np.inf and not np.isnan(dist)

    if run_succeeded:
        measures = (
            ("cost", dist),
            ("quadratic_term", quadratic_term),
            ("error", error),
        )
        for key, value in measures:
            records[key] += value
    else:
        records["fails"] += 1

    records["runtime"] += runtime
    records["solutions"].append(solution)


def average_records(
    records: Dict[str, Union[float, int, List[np.ndarray]]],
    n_runs_per_graph: int,
) -> None:
    """
    Averages the accumulated records, in place.
    The cost, quadratic term and error are divided by the number of successful runs
    (they are left untouched when every run failed), whereas the runtime is divided
    by the total number of runs.

    Args:
        records: dict with keys 'cost', 'quadratic_term', 'error', 'runtime' and 'fails'.
        n_runs_per_graph: total number of runs that were accumulated in the records.
    """
    # number of runs that actually contributed to the sums (a subtraction, not a comparison)
    n_successful_runs = n_runs_per_graph - records["fails"]

    if n_successful_runs > 0:
        records["cost"] /= n_successful_runs
        records["quadratic_term"] /= n_successful_runs
        records["error"] /= n_successful_runs

    # nothing to average when no run was performed
    if n_runs_per_graph > 0:
        records["runtime"] /= n_runs_per_graph

In [6]:
@timeit
def full_pipeline(
    graph: nx.Graph, n_runs_per_graph: int, nonzero_ratio: float, *args, **kwargs
) -> Dict[str, Dict[str, Union[float, int, List[np.ndarray]]]]:
    """
    Runs every algorithm of `choose_algo` several times on a graph and averages the results.
    At each run, new random distributions are sampled on a copy of the graph and all the
    algorithms are run on that same copy.

    Args:
        graph: the graph on which the algorithms are run (it is not modified).
        n_runs_per_graph: number of runs, i.e. of sampled couples of distributions.
        nonzero_ratio: forwarded to `add_random_distributions`.
        *args: positional arguments forwarded to every algorithm.
        **kwargs: keyword arguments forwarded to every algorithm.

    Returns:
        For each algorithm name, a dict of records with keys 'cost', 'quadratic_term',
        'error', 'runtime', 'fails' and 'solutions'.
    """
    results = {
        algo: {
            "cost": 0.0,
            "quadratic_term": 0.0,
            "error": 0.0,
            "runtime": 0,
            "fails": 0,
            "solutions": [],
        }
        for algo in choose_algo
    }

    for n_run in range(n_runs_per_graph):
        graph_copy = deepcopy(graph)
        add_random_distributions(graph_copy, plot=False, nonzero_ratio=nonzero_ratio)

        # iterating over the registry itself so that the algorithms that are run always
        # match the records initialized above
        for algo, solver in choose_algo.items():
            print(
                f"-- Run number {n_run:>{len(str(n_runs_per_graph))}} of algo {algo:<15}:",
                end=" ",
            )
            runtime, (dist, quad_term, sol, err, sol_graph) = return_runtime(solver)(
                graph_copy, *args, **kwargs
            )
            print(
                f"cost: {dist:.2f}, quadratic term: {quad_term:.2f}, err: {err:.2f}, runtime: {runtime:.2f} s"
            )
            update_records(results[algo], dist, quad_term, err, sol, runtime)
        print("")

    for record in results.values():
        average_records(record, n_runs_per_graph)

    return results

## Experiment

In [None]:
# NOTE: this silences every warning for the rest of the session
warnings.filterwarnings("ignore")

# parameters of the experiment
graph_sizes = [50, 100, 500, 1000]
graph_type = "bipartite"
alphas = [10, 5, 1, 0.1, 1e-2, 1e-3, 1e-4, 1e-5]
n_runs_per_graph = 10
# passed to full_pipeline as the ratio of nodes on which the distributions are nonzero
proportion_of_sink = 0.1

# results[graph_size][alpha][algorithm name] -> averaged records
results = {graph_size: {} for graph_size in graph_sizes}
for graph_size in graph_sizes:
    # one weighted graph per size, shared by all the values of alpha
    graph = create_graph(graph_size, graph_type)
    add_random_weights(graph, plot=False)

    for alpha in alphas:
        print(f"\nGraph size: {graph_size}, alpha: {alpha}")
        results[graph_size][alpha] = full_pipeline(
            graph,
            n_runs_per_graph,
            proportion_of_sink,
            alpha=alpha,
            verbose=False,
        )

# removing unserializable np.ndarray
# (dedicated loop variable names so that graph_size, alpha and algo are not rebound to dicts)
for results_per_alpha in results.values():
    for results_per_algo in results_per_alpha.values():
        for records in results_per_algo.values():
            records.pop("solutions")

print(json.dumps(results, indent=2))