# -*- coding: utf-8 -*-
"""
Dimension reduction with OT

.. warning::
    Note that by default the module is not imported in :mod:`ot`. In order to
    use it you need to explicitly import :mod:`ot.dr`.
"""

# Author: Remi Flamary <remi.flamary@unice.fr>
#
# License: MIT License

from scipy import linalg
import autograd.numpy as np

from pymanopt.manifolds import Stiefel
from pymanopt import Problem
from pymanopt.solvers import SteepestDescent, TrustRegions


def dist(x1, x2):
    """Compute the squared Euclidean distance between samples (autograd).
    """
    x1p2 = np.sum(np.square(x1), 1)
    x2p2 = np.sum(np.square(x2), 1)
    return x1p2.reshape((-1, 1)) + x2p2.reshape((1, -1)) - 2 * np.dot(x1, x2.T)
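

# A minimal sanity check for dist (illustration only; the toy arrays below
# are assumptions, not part of the module):
#
#     a = np.array([[0.0, 0.0], [1.0, 0.0]])
#     b = np.array([[0.0, 1.0]])
#     dist(a, b)  # -> [[1.], [2.]], squared Euclidean distances to b

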
def sinkhorn(w1, w2, M, reg, k):
    """Sinkhorn algorithm with a fixed number of iterations (autograd).
    """
    K = np.exp(-M / reg)
    ui = np.ones((M.shape[0],))
    vi = np.ones((M.shape[1],))
    for i in range(k):
        vi = w2 / (np.dot(K.T, ui))
        ui = w1 / (np.dot(K, vi))
    G = ui.reshape((M.shape[0], 1)) * K * vi.reshape((1, M.shape[1]))
    return G
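

# A minimal usage sketch for sinkhorn (illustration only; the uniform
# marginals and random point clouds below are assumptions):
#
#     w1 = np.ones((3,)) / 3
#     w2 = np.ones((4,)) / 4
#     M = dist(np.random.randn(3, 2), np.random.randn(4, 2))
#     G = sinkhorn(w1, w2, M, reg=1.0, k=10)
#     # G.sum(1) ~ w1 and G.sum(0) ~ w2 once the iterations have converged

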
def split_classes(X, y):
    """Split samples in X by the classes in y.
    """
    lstsclass = np.unique(y)
    return [X[y == i, :].astype(np.float32) for i in lstsclass]
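

# For instance (assumed toy labels), split_classes groups rows by label:
#
#     X = np.arange(8, dtype=np.float64).reshape((4, 2))
#     y = np.array([0, 1, 0, 1])
#     xc = split_classes(X, y)  # [rows 0 and 2 (class 0), rows 1 and 3 (class 1)]

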
def fda(X, y, p=2, reg=1e-16):
    """Fisher Discriminant Analysis

    Parameters
    ----------
    X : ndarray, shape (n, d)
        Training samples.
    y : ndarray, shape (n,)
        Labels for training samples.
    p : int, optional
        Size of the dimensionality reduction.
    reg : float, optional
        Regularization term >0 (ridge regularization).

    Returns
    -------
    P : ndarray, shape (d, p)
        Optimal projection matrix.
    proj : callable
        Projection function including mean centering.
    """

    # center the data (per-feature mean)
    mx = np.mean(X, 0)
    X -= mx.reshape((1, -1))

    # data split between classes
    d = X.shape[1]
    xc = split_classes(X, y)
    nc = len(xc)

    p = min(nc - 1, p)

    # within-class covariance, averaged over classes
    Cw = 0
    for x in xc:
        Cw += np.cov(x, rowvar=False)
    Cw /= nc

    # per-class means
    mxc = np.zeros((d, nc))
    for i in range(nc):
        mxc[:, i] = np.mean(xc[i], 0)

    # between-class covariance
    mx0 = np.mean(mxc, 1)
    Cb = 0
    for i in range(nc):
        Cb += (mxc[:, i] - mx0).reshape((-1, 1)) * \
            (mxc[:, i] - mx0).reshape((1, -1))

    # generalized eigenvalue problem; keep the p leading directions
    w, V = linalg.eig(Cb, Cw + reg * np.eye(d))

    idx = np.argsort(w.real)

    Popt = V[:, idx[-p:]]

    def proj(X):
        return (X - mx.reshape((1, -1))).dot(Popt)

    return Popt, proj
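

# A minimal usage sketch for fda (illustration only; the two-class Gaussian
# toy data below is an assumption):
#
#     n = 100
#     X = np.vstack((np.random.randn(n, 10) + 2, np.random.randn(n, 10)))
#     y = np.hstack((np.zeros(n), np.ones(n)))
#     Popt, proj = fda(X, y, p=1)
#     Xp = proj(X)  # samples projected on the discriminant axis, shape (2n, 1)

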
def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None):
    """
    Wasserstein Discriminant Analysis [11]_

    The function solves the following optimization problem:

    .. math::
        P = \\text{arg}\\min_P \\frac{\\sum_i W(PX^i,PX^i)}{\\sum_{i,j\\neq i} W(PX^i,PX^j)}

    where:

    - :math:`P` is a linear projection operator in the Stiefel(d, p) manifold
    - :math:`W` is the entropic regularized Wasserstein distance
    - :math:`X^i` are the samples in the dataset corresponding to class i

    Parameters
    ----------
    X : ndarray, shape (n, d)
        Training samples.
    y : ndarray, shape (n,)
        Labels for training samples.
    p : int, optional
        Size of the dimensionality reduction.
    reg : float, optional
        Regularization term >0 (entropic regularization).
    k : int, optional
        Number of Sinkhorn iterations.
    solver : None | str, optional
        None for steepest descent, 'TrustRegions' for the trust regions
        algorithm; otherwise pass a solver instance from pymanopt.solvers.
    maxiter : int, optional
        Maximum number of iterations for the solver.
    P0 : ndarray, shape (d, p)
        Initial starting point for the projection.
    verbose : int, optional
        Print information along iterations.

    Returns
    -------
    P : ndarray, shape (d, p)
        Optimal projection matrix.
    proj : callable
        Projection function including mean centering.

    References
    ----------
    .. [11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016).
        Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063.
    """  # noqa

    # center the data (per-feature mean)
    mx = np.mean(X, 0)
    X -= mx.reshape((1, -1))

    # data split between classes
    d = X.shape[1]
    xc = split_classes(X, y)
    # compute uniform weights
    wc = [np.ones((x.shape[0]), dtype=np.float32) / x.shape[0] for x in xc]

    def cost(P):
        # wda loss: ratio of within-class to between-class Wasserstein costs
        loss_b = 0
        loss_w = 0
        for i, xi in enumerate(xc):
            xi = np.dot(xi, P)
            for j, xj in enumerate(xc[i:]):
                xj = np.dot(xj, P)
                M = dist(xi, xj)
                G = sinkhorn(wc[i], wc[j + i], M, reg, k)
                if j == 0:
                    # j == 0 means xj is the same class as xi (within-class)
                    loss_w += np.sum(G * M)
                else:
                    loss_b += np.sum(G * M)
        # the ratio is within / between because the solver minimizes
        return loss_w / loss_b

    # declare manifold and problem
    manifold = Stiefel(d, p)
    problem = Problem(manifold=manifold, cost=cost)

    # declare solver and solve
    if solver is None:
        solver = SteepestDescent(maxiter=maxiter, logverbosity=verbose)
    elif solver in ['tr', 'TrustRegions']:
        solver = TrustRegions(maxiter=maxiter, logverbosity=verbose)

    Popt = solver.solve(problem, x=P0)

    def proj(X):
        return (X - mx.reshape((1, -1))).dot(Popt)

    return Popt, proj
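

# A minimal usage sketch for wda (illustration only; the toy data and the
# reg, k and maxiter values below are assumptions):
#
#     n = 100
#     X = np.vstack((np.random.randn(n, 10) + 2, np.random.randn(n, 10)))
#     y = np.hstack((np.zeros(n), np.ones(n)))
#     Pwda, projwda = wda(X, y, p=2, reg=1, k=10, maxiter=30)
#     Xp = projwda(X)  # samples projected on the learned 2-d subspace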