Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Add a geomloss wrapper for sinkhorn solver #571

Merged
merged 17 commits into from Nov 21, 2023
Merged

[MRG] Add a geomloss wrapper for sinkhorn solver #571

merged 17 commits into from Nov 21, 2023

Conversation

rflamary
Copy link
Collaborator

@rflamary rflamary commented Nov 10, 2023

Types of changes

This is a first attempt to build on the awesome geomloss and Keops of @jeanfeydy. this function empirical_sinkhorn2_geomloss is a simple wrapper that will be called in ot.sovle_sample with method='geomloss' .

import numpy as np
import scipy as sp
import ot

n = 1000
rng = np.random.RandomState(0)

x = rng.randn(n, 2)
x2 = rng.randn(n//2, 2)+5

xb = torch.tensor(x, dtype=torch.float32)
xb2 = torch.tensor(x2, dtype=torch.float32)

a = torch.ones(n, dtype=torch.float32, requires_grad=True) / n
b = torch.ones(n//2, dtype=torch.float32, requires_grad=True) / (n//2)

#%%  empirical_sinkhorn2_geomloss wrapper for geomloss
reg=1

ot.tic()
value0, log0 = ot.bregman.empirical_sinkhorn2(xb, xb2, reg=reg, lazy=False, log=True)
ot.toc('Classical sinhorn solver : {}s')

ot.tic()
value, log = empirical_sinkhorn2_geomloss(xb, xb2, reg=reg, log= True)
ot.toc('Geomloss : {}s')
T = log['lazy_plan'] # recover lazy plan

#%% ot.solve_sample wrapper for geomloss
reg=1

# automatic solver
ot.tic()
sol = ot.solve_sample(x,x2, reg=reg, method='geomloss_auto', lazy=True)
ot.toc('Geomloss (automatic) : {}s')

# tensorized solver is fast but O(n^2) in memory
ot.tic()
sol = ot.solve_sample(x,x2, reg=reg, method='geomloss_tensorized', lazy=True)
ot.toc('Geomloss tensorized : {}s')

# online solver compute the distanec marix when necessary and is O(n) in memory
ot.tic()
sol = ot.solve_sample(x,x2, reg=reg, method='geomloss_online', lazy=True)
ot.toc('Geomloss online : {}s')

# multiscale is usually the fastest
ot.tic()
sol = ot.solve_sample(x,x2, reg=reg, method='geomloss_multiscale', lazy=True)
ot.toc('Geomloss multiscale : {}s')

The computational time returned by the script above on CPU only machine are:

[KeOps] Warning : Cuda libraries were not detected on the system ; using cpu only mode
Classical sinhorn solver : 0.6667807102203369s
Geomloss : 0.037354469299316406s
Geomloss (automatic) : 0.05784487724304199s
Geomloss tensorized : 0.052454471588134766s
Geomloss online : 0.2332136631011963s
Geomloss multiscale : 0.04137110710144043s

For the moment this solver is compatible only with pytorch and numpy (with pytorch conversion)

Motivation and context / Related issue

How has this been tested (if it applies)

PR checklist

  • I have read the CONTRIBUTING document.
  • The documentation is up-to-date with the changes I made (check build artifacts).
  • All tests passed, and additional code has been covered with new tests.
  • I have added the PR and Issue fix to the RELEASES.md file.

Copy link

codecov bot commented Nov 10, 2023

Codecov Report

Merging #571 (49d2ef7) into master (cffb6cf) will increase coverage by 0.00%.
The diff coverage is 97.08%.

Additional details and impacted files
@@           Coverage Diff            @@
##           master     #571    +/-   ##
========================================
  Coverage   96.64%   96.65%            
========================================
  Files          74       75     +1     
  Lines       15187    15321   +134     
========================================
+ Hits        14678    14808   +130     
- Misses        509      513     +4     

@rflamary rflamary changed the title [WIP] Add a geomloss wrapper for sinkhorn solver [MRG] Add a geomloss wrapper for sinkhorn solver Nov 21, 2023
@rflamary rflamary merged commit 299f560 into master Nov 21, 2023
17 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

1 participant