diff --git a/ot/da.py b/ot/da.py index 5039fbd2f..db5f3ca5a 100644 --- a/ot/da.py +++ b/ot/da.py @@ -591,32 +591,64 @@ class OTDA(object): """ - def __init__(self, metric='sqeuclidean'): + def __init__(self, metric='sqeuclidean', norm = None, + distribution_estimation='uniform', target_samples = None, + mapping_type = 'barycentric_interpolation', + reg = 1, eta = 1): """ Class initialization""" - self.xs = 0 - self.xt = 0 - self.G = 0 self.metric = metric - self.computed = False - - def fit(self, xs, xt, ws=None, wt=None, norm=None): + self.norm = norm + + self.distribution_estimation = distribution_estimation + self.target_samples = target_samples + self.mapping_type = mapping_type + + self.reg = reg + self.eta = eta + + + def _estimate_distribution(self,X): + + if self.distribution_estimation == "uniform": + return unif(X.shape[0]) + + return X + + def fit(self, X, y=None): """ Fit domain adaptation between samples is xs and xt (with optional weights)""" - self.xs = xs - self.xt = xt - - if wt is None: - wt = unif(xt.shape[0]) - if ws is None: - ws = unif(xs.shape[0]) - - self.ws = ws - self.wt = wt - - self.M = dist(xs, xt, metric=self.metric) - self.normalizeM(norm) - self.G = emd(ws, wt, self.M) + self.xs = X + + self.ws = self._estimate_distribution(X) + + return self + + def transform(self,X): + """Scikitlearn compatible. Direction is determined by + the existence of an a priori over the target distribution """ + + return self._transport(X) if self._xor(np.array([(X == self.xs)]).all(),self.target_samples is None) else X + + def _transport(self,X): + + self.xt = X if self.target_samples is None else self.target_samples + self.wt = self._estimate_distribution(X) + + self.M = dist(self.xs, self.xt, metric=self.metric) + self.normalizeM(self.norm) + self.G = emd(self.ws, self.wt, self.M) self.computed = True - + + if self.mapping_type == 'OoS_mapping': + pass #TODO + + return self.interp(1) if self.target_samples is None else self.interp(-1) + + + def _xor(self,logical_expression_1,logical_expression_2): + + return ((logical_expression_1 or logical_expression_2) + and ((not logical_expression_1) or (not logical_expression_2))) + def interp(self, direction=1): """Barycentric interpolation for the source (1) or target (-1) samples @@ -652,7 +684,8 @@ def interp(self, direction=1): else: print("Warning, model not fitted yet, returning None") return None - + + def predict(self, x, direction=1): """ Out of sample mapping using the formulation from [6] @@ -703,71 +736,79 @@ def normalizeM(self, norm): class OTDA_sinkhorn(OTDA): """Class for domain adaptation with optimal transport with entropic regularization""" - - def fit(self, xs, xt, reg=1, ws=None, wt=None, norm=None, **kwargs): - """ Fit regularized domain adaptation between samples is xs and xt (with optional weights)""" - self.xs = xs - self.xt = xt - - if wt is None: - wt = unif(xt.shape[0]) - if ws is None: - ws = unif(xs.shape[0]) - - self.ws = ws - self.wt = wt - - self.M = dist(xs, xt, metric=self.metric) - self.normalizeM(norm) - self.G = sinkhorn(ws, wt, self.M, reg, **kwargs) + + def _transport(self,X): + """ Regularized domain adaptation between samples xs and xt """ + self.xt = X if self.target_samples is None else self.target_samples + self.wt = self._estimate_distribution(X) + + self.M = dist(self.xs, self.xt, metric=self.metric) + self.normalizeM(self.norm) + self.G = sinkhorn(self.ws, self.wt, self.M, self.reg) self.computed = True - - + + if self.mapping_type == 'OoS_mapping': + pass #TODO + + return self.interp(1) if self.target_samples is None else self.interp(-1) + class OTDA_lpl1(OTDA): """Class for domain adaptation with optimal transport with entropic and group regularization""" - def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None, **kwargs): - """ Fit regularized domain adaptation between samples is xs and xt (with optional weights), See ot.da.sinkhorn_lpl1_mm for fit parameters""" - self.xs = xs - self.xt = xt - - if wt is None: - wt = unif(xt.shape[0]) - if ws is None: - ws = unif(xs.shape[0]) - - self.ws = ws - self.wt = wt - - self.M = dist(xs, xt, metric=self.metric) - self.normalizeM(norm) - self.G = sinkhorn_lpl1_mm(ws, ys, wt, self.M, reg, eta, **kwargs) + + def fit(self, X, y): + """ Regularized domain adaptation between samples xs and xt """ + self.xs = X + self.ys = y + + self.ws = self._estimate_distribution(X) + + return self + + def _transport(self,X): + + self.xt = X if self.target_samples is None else self.target_samples + self.wt = self._estimate_distribution(X) + + self.M = dist(self.xs, self.xt, metric=self.metric) + self.normalizeM(self.norm) + self.G = sinkhorn_lpl1_mm(self.ws, self.ys, self.wt, self.M, self.reg, self.eta) self.computed = True - + + if self.mapping_type == 'OoS_mapping': + pass #TODO + + return self.interp(1) if self.target_samples is None else self.interp(-1) + class OTDA_l1l2(OTDA): """Class for domain adaptation with optimal transport with entropic and group lasso regularization""" - def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None, **kwargs): - """ Fit regularized domain adaptation between samples is xs and xt (with optional weights), See ot.da.sinkhorn_lpl1_gl for fit parameters""" - self.xs = xs - self.xt = xt - - if wt is None: - wt = unif(xt.shape[0]) - if ws is None: - ws = unif(xs.shape[0]) - - self.ws = ws - self.wt = wt - - self.M = dist(xs, xt, metric=self.metric) - self.normalizeM(norm) - self.G = sinkhorn_l1l2_gl(ws, ys, wt, self.M, reg, eta, **kwargs) + def fit(self, X, y): + """ Regularized domain adaptation between samples xs and xt """ + self.xs = X + self.ys = y + + self.ws = self._estimate_distribution(X) + + return self + + def _transport(self,X): + + self.xt = X if self.target_samples is None else self.target_samples + self.wt = self._estimate_distribution(X) + + self.M = dist(self.xs, self.xt, metric=self.metric) + self.normalizeM(self.norm) + self.G = sinkhorn_l1l2_gl(self.ws, self.ys, self.wt, self.M, self.reg, self.eta) self.computed = True - + + if self.mapping_type == 'OoS_mapping': + pass #TODO + + return self.interp(1) if self.target_samples is None else self.interp(-1) class OTDA_mapping_linear(OTDA):