Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 117 additions & 76 deletions ot/da.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,32 +591,64 @@ class OTDA(object):

"""

def __init__(self, metric='sqeuclidean'):
def __init__(self, metric='sqeuclidean', norm = None,
distribution_estimation='uniform', target_samples = None,
mapping_type = 'barycentric_interpolation',
reg = 1, eta = 1):
""" Class initialization"""
self.xs = 0
self.xt = 0
self.G = 0
self.metric = metric
self.computed = False

def fit(self, xs, xt, ws=None, wt=None, norm=None):
self.norm = norm

self.distribution_estimation = distribution_estimation
self.target_samples = target_samples
self.mapping_type = mapping_type

self.reg = reg
self.eta = eta


def _estimate_distribution(self,X):

if self.distribution_estimation == "uniform":
return unif(X.shape[0])

return X

def fit(self, X, y=None):
""" Fit domain adaptation between samples is xs and xt (with optional weights)"""
self.xs = xs
self.xt = xt

if wt is None:
wt = unif(xt.shape[0])
if ws is None:
ws = unif(xs.shape[0])

self.ws = ws
self.wt = wt

self.M = dist(xs, xt, metric=self.metric)
self.normalizeM(norm)
self.G = emd(ws, wt, self.M)
self.xs = X

self.ws = self._estimate_distribution(X)

return self

def transform(self,X):
"""Scikitlearn compatible. Direction is determined by
the existence of an a priori over the target distribution """

return self._transport(X) if self._xor(np.array([(X == self.xs)]).all(),self.target_samples is None) else X

def _transport(self,X):

self.xt = X if self.target_samples is None else self.target_samples
self.wt = self._estimate_distribution(X)

self.M = dist(self.xs, self.xt, metric=self.metric)
self.normalizeM(self.norm)
self.G = emd(self.ws, self.wt, self.M)
self.computed = True


if self.mapping_type == 'OoS_mapping':
pass #TODO

return self.interp(1) if self.target_samples is None else self.interp(-1)


def _xor(self,logical_expression_1,logical_expression_2):

return ((logical_expression_1 or logical_expression_2)
and ((not logical_expression_1) or (not logical_expression_2)))

def interp(self, direction=1):
"""Barycentric interpolation for the source (1) or target (-1) samples

Expand Down Expand Up @@ -652,7 +684,8 @@ def interp(self, direction=1):
else:
print("Warning, model not fitted yet, returning None")
return None



def predict(self, x, direction=1):
""" Out of sample mapping using the formulation from [6]

Expand Down Expand Up @@ -703,71 +736,79 @@ def normalizeM(self, norm):
class OTDA_sinkhorn(OTDA):

"""Class for domain adaptation with optimal transport with entropic regularization"""

def fit(self, xs, xt, reg=1, ws=None, wt=None, norm=None, **kwargs):
""" Fit regularized domain adaptation between samples is xs and xt (with optional weights)"""
self.xs = xs
self.xt = xt

if wt is None:
wt = unif(xt.shape[0])
if ws is None:
ws = unif(xs.shape[0])

self.ws = ws
self.wt = wt

self.M = dist(xs, xt, metric=self.metric)
self.normalizeM(norm)
self.G = sinkhorn(ws, wt, self.M, reg, **kwargs)

def _transport(self,X):
""" Regularized domain adaptation between samples xs and xt """
self.xt = X if self.target_samples is None else self.target_samples
self.wt = self._estimate_distribution(X)

self.M = dist(self.xs, self.xt, metric=self.metric)
self.normalizeM(self.norm)
self.G = sinkhorn(self.ws, self.wt, self.M, self.reg)
self.computed = True



if self.mapping_type == 'OoS_mapping':
pass #TODO

return self.interp(1) if self.target_samples is None else self.interp(-1)

class OTDA_lpl1(OTDA):

"""Class for domain adaptation with optimal transport with entropic and group regularization"""

def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None, **kwargs):
""" Fit regularized domain adaptation between samples is xs and xt (with optional weights), See ot.da.sinkhorn_lpl1_mm for fit parameters"""
self.xs = xs
self.xt = xt

if wt is None:
wt = unif(xt.shape[0])
if ws is None:
ws = unif(xs.shape[0])

self.ws = ws
self.wt = wt

self.M = dist(xs, xt, metric=self.metric)
self.normalizeM(norm)
self.G = sinkhorn_lpl1_mm(ws, ys, wt, self.M, reg, eta, **kwargs)

def fit(self, X, y):
""" Regularized domain adaptation between samples xs and xt """
self.xs = X
self.ys = y

self.ws = self._estimate_distribution(X)

return self

def _transport(self,X):

self.xt = X if self.target_samples is None else self.target_samples
self.wt = self._estimate_distribution(X)

self.M = dist(self.xs, self.xt, metric=self.metric)
self.normalizeM(self.norm)
self.G = sinkhorn_lpl1_mm(self.ws, self.ys, self.wt, self.M, self.reg, self.eta)
self.computed = True


if self.mapping_type == 'OoS_mapping':
pass #TODO

return self.interp(1) if self.target_samples is None else self.interp(-1)


class OTDA_l1l2(OTDA):

"""Class for domain adaptation with optimal transport with entropic and group lasso regularization"""

def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None, **kwargs):
""" Fit regularized domain adaptation between samples is xs and xt (with optional weights), See ot.da.sinkhorn_lpl1_gl for fit parameters"""
self.xs = xs
self.xt = xt

if wt is None:
wt = unif(xt.shape[0])
if ws is None:
ws = unif(xs.shape[0])

self.ws = ws
self.wt = wt

self.M = dist(xs, xt, metric=self.metric)
self.normalizeM(norm)
self.G = sinkhorn_l1l2_gl(ws, ys, wt, self.M, reg, eta, **kwargs)
def fit(self, X, y):
""" Regularized domain adaptation between samples xs and xt """
self.xs = X
self.ys = y

self.ws = self._estimate_distribution(X)

return self

def _transport(self,X):

self.xt = X if self.target_samples is None else self.target_samples
self.wt = self._estimate_distribution(X)

self.M = dist(self.xs, self.xt, metric=self.metric)
self.normalizeM(self.norm)
self.G = sinkhorn_l1l2_gl(self.ws, self.ys, self.wt, self.M, self.reg, self.eta)
self.computed = True


if self.mapping_type == 'OoS_mapping':
pass #TODO

return self.interp(1) if self.target_samples is None else self.interp(-1)

class OTDA_mapping_linear(OTDA):

Expand Down