In [1]:
import torch

# Environment report: show which Torch / CUDA / cuDNN builds this notebook runs against.
print(
    'Torch version: {}'.format(torch.__version__),
    'CUDA available: {}'.format(torch.cuda.is_available()),
    'CUDA version: {}'.format(torch.version.cuda),
    'CUDNN version: {}'.format(torch.backends.cudnn.version()),
    sep='\n',
)

# Run on the first GPU when CUDA is usable, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Turn on cuDNN benchmark mode (the flag is set regardless of the selected device).
torch.backends.cudnn.benchmark = True

from src.megawass import MegaWass

Torch version: 1.10.1
CUDA available: False
CUDA version: 10.2
CUDNN version: 7605


# Preparing toy data

In [2]:
# Generate simulated data.
# Seed the global RNG first so the toy data is identical on every
# "Restart Kernel -> Run All" and whenever this cell is re-run on its own.
SEED = 0
torch.manual_seed(SEED)

# Sizes of the two toy datasets: x is (nx, dx) and y is (ny, dy).
nx = 100
dx = 30
ny = 200
dy = 20

# Uniform random data matrices, created on the default device and then moved
# to `device`, so the values come from the RNG stream seeded above.
x = torch.rand(nx, dx).to(device)
y = torch.rand(ny, dy).to(device)

# Squared Euclidean pairwise distances within each dataset: (nx, nx) and (ny, ny).
# These square matrices are the inputs of the GW-type solver further down.
Cx = torch.cdist(x, x, p=2)**2
Cy = torch.cdist(y, y, p=2)**2

# Random (nx, ny) and (dx, dy) matrices, passed further down as the optional
# prior-knowledge argument D on the sample and feature couplings.
D_samp = torch.rand(nx, ny).to(device)
D_feat = torch.rand(dx, dy).to(device)

# UCOOT

Given $2$ matrices of arbitrary size: $X_1 \in \mathbb R^{n_1 \times d_1}$ and $X_2 \in \mathbb R^{n_2 \times d_2}$, and $4$ corresponding histograms assigned to their rows and columns $\mu_{n_1}, \mu_{d_1}, \mu_{n_2}$ and $\mu_{d_2}$, the method $\texttt{solver_fucoot}$ solves
\begin{equation*}
    \begin{split}
        \text{FUCOOT}_{\rho, \alpha, \varepsilon}(X_1, X_2) 
        &= \inf_{\substack{P_s \in \mathbb R^{n_1 \times n_2}_{\geq 0} \\ P_f \in \mathbb R^{d_1 \times d_2}_{\geq 0}}} \text{func}(P_s, P_f) 
    \end{split}
\end{equation*}
where the function
\begin{equation*}
    \begin{split}
        \text{func}(P_s, P_f) 
        &= \langle | X_1 - X_2 |^2, P_s \otimes P_f \rangle 
        + \alpha_s \langle D_s, P_s \rangle + \alpha_f \langle D_f, P_f \rangle \\
        &+ \rho_1 \text{KL}(\text{some function of $P_s$ and $P_f$} \vert \mu_{n_1} \otimes \mu_{d_1}) 
        + \rho_2 \text{KL}(\text{some function of $P_s$ and $P_f$} \vert \mu_{n_2} \otimes \mu_{d_2})  \\
        &+ \varepsilon_s \text{KL}(P_s | \mu_{n_1} \otimes \mu_{n_2}) + 
        \varepsilon_f \text{KL}(P_f | \mu_{d_1} \otimes \mu_{d_2}).
    \end{split}
\end{equation*}
Here, the subscripts "s" and "f" mean sample and feature.

Some notes on the input arguments:

- By default, all histograms are uniform distributions, so just leave them as None unless you want something else.

- The input matrices $D_s \in \mathbb R^{n_1 \times n_2}$ and $D_f \in \mathbb R^{d_1 \times d_2}$ represent prior knowledge (if available) on the sample and feature couplings, respectively. If they are not available, just leave them as None.

- The marginal relaxation parameters $\rho_1$ and $\rho_2$ can take any nonnegative values. It is also possible to use an infinite value (e.g. by setting $\rho_1 = \rho_2 = \texttt{float("inf")}$). In that case, you are doing (balanced) COOT and your epsilon **must** be **strictly positive**.

- The regularisation parameters $\varepsilon_s$ and $\varepsilon_f$ can take **any** nonnegative values, **even zero**. **Important note**: if at least one of them is zero, then $\rho_1$ and $\rho_2$ must **not** contain infinite values (because a different algorithm is used to solve the zero-epsilon case, and it does not work with infinite values).

- If you use zero epsilon, it may be desirable to increase the argument $\texttt{nits_uot}$, because the algorithm may not converge fast enough.

- It is possible to trigger early stopping when the current and previous costs do not differ much. To do this, set your threshold via the argument $\texttt{early_stopping_tol}$.

- It is recommended that you set $\texttt{verbose = True}$, so that you can see the evolution of costs. 
It is also possible to save the training cost by setting $\texttt{log = True}$.

### Basic usage

In [3]:
# Marginal relaxation pair (rho_1, rho_2); use (float("inf"), float("inf")) for balanced COOT.
rho = (1e-1, 1e-1)
# Regularisation pair, presumably ordered (eps_s, eps_f) - confirm against the solver.
# A zero entry is allowed here because rho contains no infinite value.
eps = (1e-2, 0)

# Solver configuration (iteration budgets, tolerances and evaluation intervals).
megawass = MegaWass(
    nits_bcd=20,
    nits_uot=1000,
    tol_bcd=1e-6,
    tol_uot=1e-6,
    eval_bcd=1,
    eval_uot=20,
)

# Solve with default (uniform) histograms and no prior matrices.
# log=True returns the cost logs, verbose=True prints the cost per iteration,
# and early_stopping_tol stops once successive costs barely differ.
# The second returned item is not needed here and is discarded.
(pi_samp, pi_feat), _, log_cost, log_ent_cost = megawass.solver_fucoot(
    X=x, Y=y,
    rho=rho, eps=eps,
    log=True, verbose=True,
    early_stopping_tol=1e-6,
)

Cost at iteration 1: 0.11049536615610123
Cost at iteration 2: 0.1103777140378952
Cost at iteration 3: 0.110260508954525
Cost at iteration 4: 0.11012265086174011
Cost at iteration 5: 0.10992490500211716
Cost at iteration 6: 0.10953307151794434
Cost at iteration 7: 0.10852140188217163
Cost at iteration 8: 0.10715685784816742
Cost at iteration 9: 0.10599719732999802
Cost at iteration 10: 0.10484065115451813
Cost at iteration 11: 0.10354764759540558
Cost at iteration 12: 0.10196802765130997
Cost at iteration 13: 0.10028758645057678
Cost at iteration 14: 0.09893098473548889
Cost at iteration 15: 0.0973796397447586
Cost at iteration 16: 0.09529273957014084
Cost at iteration 17: 0.09278176724910736
Cost at iteration 18: 0.09028378129005432
Cost at iteration 19: 0.08852900564670563
Cost at iteration 20: 0.08738391101360321


### A bit more complicated 

In [4]:
# Marginal relaxation pair (rho_1, rho_2); use (float("inf"), float("inf")) for balanced COOT.
rho = (1e-1, 1e-1)
# Regularisation pair, presumably ordered (eps_s, eps_f) - confirm against the solver.
# A zero entry is allowed here because rho contains no infinite value.
eps = (1e-2, 0)
# Optional: weights and prior matrices for the sample / feature couplings.
# Only relevant when D_s and / or D_f is available.
alpha = (1, 1)
D = (D_samp, D_feat)

# Solver configuration (iteration budgets, tolerances and evaluation intervals).
megawass = MegaWass(
    nits_bcd=20,
    nits_uot=1000,
    tol_bcd=1e-6,
    tol_uot=1e-6,
    eval_bcd=1,
    eval_uot=20,
)

# Same call as the basic usage, plus the prior-knowledge terms (alpha, D).
# The second returned item is not needed here and is discarded.
(pi_samp, pi_feat), _, log_cost, log_ent_cost = megawass.solver_fucoot(
    X=x, Y=y,
    rho=rho, eps=eps,
    alpha=alpha, D=D,
    log=True, verbose=True,
    early_stopping_tol=1e-6,
)

Cost at iteration 1: 0.149433434009552
Cost at iteration 2: 0.1471438705921173
Cost at iteration 3: 0.14715427160263062
Cost at iteration 4: 0.14716514945030212


# UGW

Given $2$ square matrices of arbitrary size: $X_1 \in \mathbb R^{n_1 \times n_1}$ and $X_2 \in \mathbb R^{n_2 \times n_2}$, and $2$ corresponding histograms $\mu_{n_1}$ and $\mu_{n_2}$, the method $\texttt{solver_fugw_simple}$ solves
\begin{equation*}
    \begin{split}
        \text{FUGW}_{\rho, \alpha, \varepsilon}(X_1, X_2) 
        &= \inf_{P \in \mathbb R^{n_1 \times n_2}_{\geq 0}} \text{func}(P) 
    \end{split}
\end{equation*}
where the function
\begin{equation*}
    \begin{split}
        \text{func}(P) 
        &= \langle | X_1 - X_2 |^2, P \otimes P \rangle 
        + 2\alpha \; \langle D, P \rangle \\
        &+ \rho_1 \text{KL}(\text{some function of $P$} \vert \mu_{n_1} \otimes \mu_{n_1}) 
        + \rho_2 \text{KL}(\text{some function of $P$} \vert \mu_{n_2} \otimes \mu_{n_2})  \\
        &+ 2 \varepsilon \; \text{KL}(P | \mu_{n_1} \otimes \mu_{n_2}).
    \end{split}
\end{equation*}

Some notes on the input arguments: almost the same as above.

- By default, all histograms are uniform distributions, so just leave them as None unless you want something else.

- The input matrix $D \in \mathbb R^{n_1 \times n_2}$ represents prior knowledge (if available) on the sample couplings. If it is not available, just leave it as None.

- The marginal relaxation parameters $\rho_1$ and $\rho_2$ can take any nonnegative values. It is also possible to use an infinite value (e.g. by setting $\rho_1 = \rho_2 = \texttt{float("inf")}$). In that case, you are doing (balanced) GW and your epsilon **must** be **strictly positive**.

- The regularisation parameter $\varepsilon$ can take **any** nonnegative value, **even zero**. **Important note**: if it is zero, then $\rho_1$ and $\rho_2$ must **not** contain infinite values (because a different algorithm is used to solve the zero-epsilon case, and it does not work with infinite values).

- If you use zero epsilon, it may be desirable to increase the argument $\texttt{nits_uot}$, because the algorithm may not converge fast enough.

- It is possible to trigger early stopping when the current and previous costs do not differ much. To do this, set your threshold via the argument $\texttt{early_stopping_tol}$.

- It is recommended that you set $\texttt{verbose = True}$, so that you can see the evolution of costs. 
It is also possible to save the training cost by setting $\texttt{log = True}$.

### Basic usage

In [5]:
# Marginal relaxation pair (rho_1, rho_2); use (float("inf"), float("inf")) for balanced GW.
rho = (1e-1, 1e-1)
# Single regularisation value; zero is allowed here because rho contains no infinite value.
eps = 0

# Solver configuration (iteration budgets, tolerances and evaluation intervals).
megawass = MegaWass(
    nits_bcd=20,
    nits_uot=1000,
    tol_bcd=1e-6,
    tol_uot=1e-6,
    eval_bcd=1,
    eval_uot=20,
)

# The GW-type solver takes the square pairwise-distance matrices as X and Y.
# Its first returned item is unpacked as a pair of couplings; the second
# returned item is not needed here and is discarded.
(pi_samp, pi_feat), _, log_cost, log_ent_cost = megawass.solver_fugw_simple(
    X=Cx, Y=Cy,
    rho=rho, eps=eps,
    log=True, verbose=True,
    early_stopping_tol=1e-6,
)

Cost at iteration 1: 0.1998400092124939
Cost at iteration 2: 0.19886314868927002
Cost at iteration 3: 0.19852057099342346
Cost at iteration 4: 0.19836637377738953
Cost at iteration 5: 0.19827359914779663
Cost at iteration 6: 0.19818103313446045
Cost at iteration 7: 0.1981189250946045
Cost at iteration 8: 0.198068767786026
Cost at iteration 9: 0.19804254174232483
Cost at iteration 10: 0.19803324341773987
Cost at iteration 11: 0.19802765548229218
Cost at iteration 12: 0.19801917672157288
Cost at iteration 13: 0.1979912370443344
Cost at iteration 14: 0.19797450304031372
Cost at iteration 15: 0.19795718789100647
Cost at iteration 16: 0.19795098900794983
Cost at iteration 17: 0.1979484260082245
Cost at iteration 18: 0.1979471892118454
Cost at iteration 19: 0.19794538617134094
Cost at iteration 20: 0.19794116914272308


### A bit more complicated

In [6]:
# Marginal relaxation pair (rho_1, rho_2); use (float("inf"), float("inf")) for balanced GW.
rho = (1e-1, 1e-1)
# Single, strictly positive regularisation value.
eps = 1e-2
# Optional: weight and prior matrix for the sample coupling.
# Only relevant when D is available.
alpha = 1
D = D_samp

# Solver configuration (iteration budgets, tolerances and evaluation intervals).
megawass = MegaWass(
    nits_bcd=20,
    nits_uot=1000,
    tol_bcd=1e-6,
    tol_uot=1e-6,
    eval_bcd=1,
    eval_uot=20,
)

# Same call as the basic GW usage, plus the prior-knowledge term (alpha, D).
# The first returned item is unpacked as a pair of couplings; the second
# returned item is not needed here and is discarded.
(pi_samp, pi_feat), _, log_cost, log_ent_cost = megawass.solver_fugw_simple(
    X=Cx, Y=Cy,
    rho=rho, eps=eps,
    alpha=alpha, D=D,
    log=True, verbose=True,
    early_stopping_tol=1e-6,
)

Cost at iteration 1: 0.20035411417484283
Cost at iteration 2: 0.19998568296432495
Cost at iteration 3: 0.19994251430034637
Cost at iteration 4: 0.19991283118724823
Cost at iteration 5: 0.1998777836561203
Cost at iteration 6: 0.19985756278038025
Cost at iteration 7: 0.19984951615333557
Cost at iteration 8: 0.1998465210199356
