Skip to content
Merged
20 changes: 19 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ on:

jobs:
test:
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }}
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ (matrix.python && 'system Python') || 'conda' }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
Expand All @@ -23,13 +23,27 @@ jobs:
- windows-latest
arch:
- x64
python:
- ''
- python
include:
- version: '1'
os: ubuntu-latest
arch: x64
coverage: true
steps:
- uses: actions/checkout@v2
- name: Install python
uses: actions/setup-python@v2
with:
python-version: '3.x'
architecture: ${{ matrix.arch }}
if: matrix.python
# Limitation of pip: https://pythonot.github.io/index.html#pip-installation
- run: python -m pip install cython numpy
if: matrix.python
- run: python -m pip install pot
if: matrix.python
- uses: julia-actions/setup-julia@v1
with:
version: ${{ matrix.version }}
Expand All @@ -45,7 +59,11 @@ jobs:
${{ runner.os }}-test-
${{ runner.os }}-
- uses: julia-actions/julia-buildpkg@v1
env:
PYTHON: ${{ matrix.python }}
- uses: julia-actions/julia-runtest@v1
env:
PYTHON: ${{ matrix.python }}
- uses: julia-actions/julia-processcoverage@v1
if: matrix.coverage
- uses: codecov/codecov-action@v1
Expand Down
3 changes: 3 additions & 0 deletions .github/workflows/Docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ jobs:
using Pkg
Pkg.develop(PackageSpec(path=pwd()))
Pkg.instantiate()'
env:
PYTHON: ''
- run: julia --project=docs docs/make.jl
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }}
PYTHON: ''
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@ uuid = "3c485715-4278-42b2-9b5f-8f00e43c12ef"
authors = ["David Widmann"]
version = "0.1.0"

[deps]
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"

[compat]
PyCall = "1"
julia = "1"

[extras]
Expand Down
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# POT.jl

Julia interface for the Python Optimal Transport (POT) library
*Julia interface for the [Python Optimal Transport (POT) package](https://pythonot.github.io/)*

[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://devmotion.github.io/POT.jl/stable)
[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://devmotion.github.io/POT.jl/dev)
[![Build Status](https://github.com/devmotion/POT.jl/workflows/CI/badge.svg)](https://github.com/devmotion/POT.jl/actions)
[![Coverage](https://codecov.io/gh/devmotion/POT.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/devmotion/POT.jl)
[![Coverage](https://coveralls.io/repos/github/devmotion/POT.jl/badge.svg?branch=master)](https://coveralls.io/github/devmotion/POT.jl?branch=master)
[![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle)

This package was originally part of [OptimalTransport.jl](https://github.com/zsteve/OptimalTransport.jl).
2 changes: 1 addition & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ makedocs(;
canonical="https://devmotion.github.io/POT.jl",
assets=String[],
),
pages=["Home" => "index.md"],
pages=["Home" => "index.md", "api.md"],
strict=true,
checkdocs=:exports,
)
Expand Down
22 changes: 22 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# API

## Exact optimal transport (Kantorovich) problem

```@docs
emd
emd2
```

## Entropically regularised optimal transport

```@docs
sinkhorn
sinkhorn2
```

## Unbalanced optimal transport

```@docs
sinkhorn_unbalanced
sinkhorn_unbalanced2
```
15 changes: 2 additions & 13 deletions docs/src/index.md
Original file line number Diff line number Diff line change
@@ -1,14 +1,3 @@
```@meta
CurrentModule = POT
```
# POT.jl

# POT

Documentation for [POT](https://github.com/devmotion/POT.jl).

```@index
```

```@autodocs
Modules = [POT]
```
*Julia interface for the [Python Optimal Transport (POT) package](https://pythonot.github.io/)*
12 changes: 11 additions & 1 deletion src/POT.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
module POT

# Write your package code here.
using PyCall: PyCall

export emd, emd2, sinkhorn, sinkhorn2, sinkhorn_unbalanced, sinkhorn_unbalanced2

const pot = PyCall.PyNULL()

include("lib.jl")

function __init__()
return copy!(pot, PyCall.pyimport_conda("ot", "pot", "conda-forge"))
end

end
145 changes: 145 additions & 0 deletions src/lib.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
"""
emd(mu, nu, C)

Compute transport map for Monge-Kantorovich problem with source and target marginals `mu`
and `nu` and a cost matrix `C` of dimensions `(length(mu), length(nu))`.

Return optimal transport coupling `γ` of the same dimensions as `C` which solves

```math
\\inf_{\\gamma \\in \\Pi(\\mu, \\nu)} \\langle \\gamma, C \\rangle
```

This function is a wrapper of the function
[`emd`](https://pythonot.github.io/all.html#ot.emd) in the Python Optimal Transport
package.
"""
function emd(mu, nu, C)
return pot.lp.emd(nu, mu, PyCall.PyReverseDims(C))'
end

"""
emd2(mu, nu, C)

Compute exact transport cost for Monge-Kantorovich problem with source and target marginals
`mu` and `nu` and a cost matrix `C` of dimensions `(length(mu), length(nu))`.

Returns optimal transport cost (a scalar), i.e. the optimal value

```math
\\inf_{\\gamma \\in \\Pi(\\mu, \\nu)} \\langle \\gamma, C \\rangle
```

This function is a wrapper of the function
[`emd2`](https://pythonot.github.io/all.html#ot.emd2) in the Python Optimal Transport
package.
"""
function emd2(mu, nu, C)
return pot.lp.emd2(nu, mu, PyCall.PyReverseDims(C))[1]
end

"""
sinkhorn(mu, nu, C, eps; tol=1e-9, max_iter = 1000, method = "sinkhorn", verbose = false)

Compute optimal transport map of histograms `mu` and `nu` with cost matrix `C` and entropic
regularization parameter `eps`.

Method can be a choice of `"sinkhorn"`, `"greenkhorn"`, `"sinkhorn_stabilized"`, or
`"sinkhorn_epsilon_scaling"` (Flamary et al., 2017).

This function is a wrapper of the function
[`sinkhorn`](https://pythonot.github.io/all.html?highlight=sinkhorn#ot.sinkhorn) in the
Python Optimal Transport package.
"""
function sinkhorn(mu, nu, C, eps; tol=1e-9, max_iter=1000, method="sinkhorn", verbose=false)
return pot.sinkhorn(
nu,
mu,
PyCall.PyReverseDims(C),
eps;
stopThr=tol,
numItermax=max_iter,
method=method,
verbose=verbose,
)'
end

"""
sinkhorn2(mu, nu, C, eps; tol=1e-9, max_iter = 1000, method = "sinkhorn", verbose = false)

Compute optimal transport cost of histograms `mu` and `nu` with cost matrix `C` and
entropic regularization parameter `eps`.

Method can be a choice of `"sinkhorn"`, `"greenkhorn"`, `"sinkhorn_stabilized"`, or
`"sinkhorn_epsilon_scaling"` (Flamary et al., 2017).

This function is a wrapper of the function
[`sinkhorn2`](https://pythonot.github.io/all.html?highlight=sinkhorn#ot.sinkhorn2) in the
Python Optimal Transport package.
"""
function sinkhorn2(
mu, nu, C, eps; tol=1e-9, max_iter=1000, method="sinkhorn", verbose=false
)
return pot.sinkhorn2(
nu,
mu,
PyCall.PyReverseDims(C),
eps;
stopThr=tol,
numItermax=max_iter,
method=method,
verbose=verbose,
)[1]
end

"""
sinkhorn_unbalanced(mu, nu, C, eps, lambda; tol = 1e-9, max_iter = 1000, method = "sinkhorn", verbose = false)

Compute optimal transport map of histograms `mu` and `nu` with cost matrix `C`, using
entropic regularisation parameter `eps` and marginal weighting functions `lambda`.

This function is a wrapper of the function
[`sinkhorn_unbalanced`](https://pythonot.github.io/all.html?highlight=sinkhorn_unbalanced#ot.sinkhorn_unbalanced)
in the Python Optimal Transport package.
"""
function sinkhorn_unbalanced(
mu, nu, C, eps, lambda; tol=1e-9, max_iter=1000, method="sinkhorn", verbose=false
)
return pot.sinkhorn_unbalanced(
nu,
mu,
PyCall.PyReverseDims(C),
eps,
lambda;
stopThr=tol,
numItermax=max_iter,
method=method,
verbose=verbose,
)'
end

"""
sinkhorn_unbalanced2(mu, nu, C, eps, lambda; tol = 1e-9, max_iter = 1000, method = "sinkhorn", verbose = false)

Compute optimal transport cost of histograms `mu` and `nu` with cost matrix `C`, using
entropic regularisation parameter `eps` and marginal weighting functions `lambda`.

This function is a wrapper of the function
[`sinkhorn_unbalanced2`](https://pythonot.github.io/all.html#ot.sinkhorn_unbalanced2) in
the Python Optimal Transport package.
"""
function sinkhorn_unbalanced2(
mu, nu, C, eps, lambda; tol=1e-9, max_iter=1000, method="sinkhorn", verbose=false
)
return pot.sinkhorn_unbalanced2(
nu,
mu,
PyCall.PyReverseDims(C),
eps,
lambda;
stopThr=tol,
numItermax=max_iter,
method=method,
verbose=verbose,
)[1]
end